Spaces:
Sleeping
Sleeping
File size: 1,726 Bytes
6e253d5 41c6f3a 6e253d5 41c6f3a 6e253d5 41c6f3a 6e253d5 41c6f3a 6e253d5 41c6f3a 6e253d5 41c6f3a 6e253d5 41c6f3a 6e253d5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
import torch
import torchvision.transforms as T
from torchvision import models
from monai.networks.nets import UNet
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, MarianTokenizer, MarianMTModel
# Target device for every model built in this module: CUDA when torch reports
# it is available, otherwise the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def load_vision_model(checkpoint_path=None, num_classes=14):
    """Build a ResNet-50 classifier with a ``num_classes``-way head.

    The network is moved to ``DEVICE`` and put in eval mode before it is
    returned.

    Args:
        checkpoint_path: Optional path to a state dict for this exact
            architecture (ResNet-50 with a ``num_classes``-way ``fc``). When
            truthy, it is loaded with ``load_state_dict`` (strict by default).
        num_classes: Output size of the replacement final linear layer.

    Returns:
        The ResNet-50 module on ``DEVICE``, in eval mode.
    """
    # Only fetch ImageNet weights when no checkpoint will replace them. The
    # strict load_state_dict below overwrites every parameter and buffer, so
    # downloading the pretrained backbone first would be wasted work and an
    # unnecessary network dependency.
    model = models.resnet50(pretrained=not checkpoint_path)
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
    model.to(DEVICE).eval()
    if checkpoint_path:
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code. Only load trusted checkpoints (consider
        # weights_only=True on torch >= 1.13).
        state = torch.load(checkpoint_path, map_location=DEVICE)
        model.load_state_dict(state)
    return model
def load_segmentation_model(checkpoint_path=None):
    """Create a 2-D MONAI UNet (1 input channel, 1 output channel).

    The network is placed on ``DEVICE`` and switched to eval mode. If
    ``checkpoint_path`` is truthy, the state dict stored there is loaded
    (mapped onto ``DEVICE``) before the network is returned.
    """
    unet_config = dict(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        # Five channel widths paired with four strides.
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    )
    segmenter = UNet(**unet_config).to(DEVICE)
    segmenter.eval()
    if not checkpoint_path:
        return segmenter
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    weights = torch.load(checkpoint_path, map_location=DEVICE)
    segmenter.load_state_dict(weights)
    return segmenter
def load_text_models(summarizer_name="google/t5-small-ssm"):
    """Load a seq2seq tokenizer/model pair for ``summarizer_name``.

    Both are resolved through the ``Auto*`` factories. The model is moved to
    ``DEVICE``; the tokenizer is returned as loaded.

    Returns:
        ``(tokenizer, model)`` tuple.
    """
    tok = AutoTokenizer.from_pretrained(summarizer_name)
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(summarizer_name)
    seq2seq = seq2seq.to(DEVICE)
    return tok, seq2seq
def load_translation_model(model_name="Helsinki-NLP/opus-mt-en-ar"):
    """Load a MarianMT tokenizer/model pair, with the model on ``DEVICE``.

    Args:
        model_name: Hugging Face Hub id or local path of a Marian checkpoint.
            Defaults to ``Helsinki-NLP/opus-mt-en-ar``, the model this
            function always loaded before it was parameterized.

    Returns:
        ``(tokenizer, model)`` tuple.
    """
    # Use a single source for the checkpoint id so the tokenizer and model
    # can never drift apart.
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(model_name).to(DEVICE)
    return tokenizer, model
def get_image_transform():
    """Build the torchvision preprocessing pipeline for input images.

    Steps, in order:
      1. convert to a PIL image;
      2. resize to 224x224;
      3. convert to grayscale replicated across 3 channels;
      4. convert to a tensor;
      5. normalize per channel with the constants below.

    The constants are the ImageNet mean/std values used by torchvision's
    pretrained models.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    steps = [
        T.ToPILImage(),
        T.Resize((224, 224)),
        T.Grayscale(num_output_channels=3),
        T.ToTensor(),
        T.Normalize(channel_mean, channel_std),
    ]
    return T.Compose(steps)