# medai_server / models.py
# mnni43353 — "Update models.py" (commit 41c6f3a, verified)
import torch
import torchvision.transforms as T
from torchvision import models
from monai.networks.nets import UNet
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, MarianTokenizer, MarianMTModel
# Compute device used by the model loaders in this module: CUDA when PyTorch
# reports it as available, otherwise the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def load_vision_model(checkpoint_path=None, num_classes=14):
    """Build a ResNet-50 classifier in eval mode on ``DEVICE``.

    The stock ``fc`` layer is replaced by a ``torch.nn.Linear`` with
    ``num_classes`` outputs.

    Args:
        checkpoint_path: Optional path to a saved ``state_dict`` for the
            modified network. When falsy (``None`` or ``""``), the backbone
            is created with torchvision's pretrained weights and the new
            head keeps its fresh initialisation.
        num_classes: Number of outputs of the replacement ``fc`` layer.

    Returns:
        The model, moved to ``DEVICE`` and switched to eval mode.

    Any exception raised by ``torch.load`` or ``load_state_dict`` (for
    example a missing file or mismatching keys) propagates to the caller.
    """
    has_checkpoint = bool(checkpoint_path)

    # Only fetch the pretrained weights when they will actually be used.
    # With a checkpoint, ``load_state_dict`` (strict by default) overwrites
    # every parameter and buffer, so downloading the pretrained weights first
    # would only add a network round-trip and break offline use.
    # NOTE(review): newer torchvision releases prefer the ``weights=``
    # argument; ``pretrained=`` is kept for compatibility with whichever
    # version is installed — confirm before migrating.
    model = models.resnet50(pretrained=not has_checkpoint)

    # Swap the classification head for one sized to this task.
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
    model.to(DEVICE).eval()

    if has_checkpoint:
        # NOTE(security): ``torch.load`` unpickles the file; only load
        # checkpoints that come from a trusted source.
        state = torch.load(checkpoint_path, map_location=DEVICE)
        model.load_state_dict(state)
    return model
def load_segmentation_model(checkpoint_path=None):
    """Create the 2-D MONAI ``UNet`` in eval mode on ``DEVICE``.

    The network is configured with one input channel, one output channel,
    channel widths ``(16, 32, 64, 128, 256)``, strides ``(2, 2, 2, 2)`` and
    two residual units.

    Args:
        checkpoint_path: Optional path to a saved ``state_dict``. When falsy,
            the network is returned with its initial weights.

    Returns:
        The ``UNet`` instance on ``DEVICE``, in eval mode.
    """
    unet_config = {
        "spatial_dims": 2,
        "in_channels": 1,
        "out_channels": 1,
        "channels": (16, 32, 64, 128, 256),
        "strides": (2, 2, 2, 2),
        "num_res_units": 2,
    }
    net = UNet(**unet_config).to(DEVICE)
    net.eval()

    # Nothing to restore: hand back the freshly built network.
    if not checkpoint_path:
        return net

    saved_state = torch.load(checkpoint_path, map_location=DEVICE)
    net.load_state_dict(saved_state)
    return net
def load_text_models(summarizer_name="google/t5-small-ssm"):
    """Load the tokenizer and seq2seq model named by ``summarizer_name``.

    Args:
        summarizer_name: Identifier passed unchanged to both
            ``AutoTokenizer.from_pretrained`` and
            ``AutoModelForSeq2SeqLM.from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple; the model has been moved to
        ``DEVICE``.
    """
    text_tokenizer = AutoTokenizer.from_pretrained(summarizer_name)
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(summarizer_name)
    seq2seq = seq2seq.to(DEVICE)
    return text_tokenizer, seq2seq
def load_translation_model():
    """Load the Marian tokenizer and model for ``Helsinki-NLP/opus-mt-en-ar``.

    NOTE(review): the ``en-ar`` suffix suggests English-to-Arabic
    translation — confirm against the model card.

    Returns:
        A ``(tokenizer, model)`` tuple; the model has been moved to
        ``DEVICE``.
    """
    # Single source of truth so tokenizer and model cannot drift apart.
    checkpoint = "Helsinki-NLP/opus-mt-en-ar"
    marian_tokenizer = MarianTokenizer.from_pretrained(checkpoint)
    marian_model = MarianMTModel.from_pretrained(checkpoint)
    marian_model = marian_model.to(DEVICE)
    return marian_tokenizer, marian_model
def get_image_transform():
    """Return the torchvision preprocessing pipeline for input images.

    The composed steps, in order: convert to a PIL image, resize to
    224x224, convert to grayscale replicated over three channels, convert
    to a tensor, and normalise each channel with the mean/std below.

    Returns:
        A ``T.Compose`` instance wrapping the five transforms.
    """
    # Per-channel statistics; these match the ImageNet values quoted in the
    # torchvision documentation.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    steps = [
        T.ToPILImage(),
        T.Resize((224, 224)),
        T.Grayscale(num_output_channels=3),
        T.ToTensor(),
        T.Normalize(channel_mean, channel_std),
    ]
    return T.Compose(steps)