Update models.py
models.py
CHANGED
@@ -2,17 +2,21 @@ import torch
 import torchvision.transforms as T
 from torchvision import models
 from monai.networks.nets import UNet
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, MarianTokenizer, MarianMTModel
 
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-def load_vision_model(num_classes=14):
+def load_vision_model(checkpoint_path=None, num_classes=14):
     model = models.resnet50(pretrained=True)
     in_feats = model.fc.in_features
     model.fc = torch.nn.Linear(in_feats, num_classes)
     model.to(DEVICE).eval()
+    if checkpoint_path:
+        state = torch.load(checkpoint_path, map_location=DEVICE)
+        model.load_state_dict(state)
     return model
 
-def load_segmentation_model():
+def load_segmentation_model(checkpoint_path=None):
     net = UNet(
         spatial_dims=2,
         in_channels=1,
@@ -22,14 +26,25 @@ def load_segmentation_model():
         num_res_units=2
     ).to(DEVICE)
     net.eval()
+    if checkpoint_path:
+        net.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE))
     return net
 
+def load_text_models(summarizer_name="google/t5-small-ssm"):
+    tokenizer = AutoTokenizer.from_pretrained(summarizer_name)
+    model = AutoModelForSeq2SeqLM.from_pretrained(summarizer_name).to(DEVICE)
+    return tokenizer, model
+
+def load_translation_model():
+    tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
+    model = MarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-en-ar").to(DEVICE)
+    return tokenizer, model
+
 def get_image_transform():
     return T.Compose([
         T.ToPILImage(),
         T.Resize((224, 224)),
         T.Grayscale(num_output_channels=3),
         T.ToTensor(),
-        T.Normalize([0.485, 0.456, 0.406],
-                    [0.229, 0.224, 0.225])
+        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
     ])
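As a usage sketch (not part of the commit), the vision loader and transform above could be exercised roughly as follows; the placeholder image array, the checkpoint path, and the sigmoid over the 14 outputs are illustrative assumptions, since the Space's inference code is not shown here:

import numpy as np
import torch

from models import DEVICE, get_image_transform, load_vision_model

model = load_vision_model()                         # or load_vision_model("path/to/checkpoint.pt")
transform = get_image_transform()

image = np.zeros((512, 512), dtype=np.uint8)        # placeholder for a real grayscale scan
tensor = transform(image).unsqueeze(0).to(DEVICE)   # shape: 1 x 3 x 224 x 224

with torch.no_grad():
    logits = model(tensor)                          # shape: 1 x 14
    probs = torch.sigmoid(logits)                   # assuming multi-label findings over 14 classes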
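Similarly, a minimal sketch of how the new text loaders might be called for summarization and English-to-Arabic translation; the sample report text, the "summarize:" prefix (a T5 convention), and the generation settings are assumptions rather than code from this repository:

import torch

from models import DEVICE, load_text_models, load_translation_model

sum_tokenizer, sum_model = load_text_models()
ar_tokenizer, ar_model = load_translation_model()

report = "The chest X-ray shows no acute cardiopulmonary abnormality."

with torch.no_grad():
    # Seq2seq summarization with the default "google/t5-small-ssm" checkpoint.
    enc = sum_tokenizer("summarize: " + report, return_tensors="pt").to(DEVICE)
    summary_ids = sum_model.generate(**enc, max_new_tokens=60)
    summary = sum_tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    # English -> Arabic translation with the Helsinki-NLP MarianMT model.
    enc = ar_tokenizer(summary, return_tensors="pt").to(DEVICE)
    translated_ids = ar_model.generate(**enc, max_new_tokens=120)
    arabic = ar_tokenizer.decode(translated_ids[0], skip_special_tokens=True)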