mnni43353 commited on
Commit
41c6f3a
·
verified ·
1 Parent(s): 8659f57

Update models.py

Browse files
Files changed (1) hide show
  1. models.py +19 -4
models.py CHANGED
@@ -2,17 +2,21 @@ import torch
2
  import torchvision.transforms as T
3
  from torchvision import models
4
  from monai.networks.nets import UNet
 
5
 
6
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7
 
8
- def load_vision_model(num_classes=14):
9
  model = models.resnet50(pretrained=True)
10
  in_feats = model.fc.in_features
11
  model.fc = torch.nn.Linear(in_feats, num_classes)
12
  model.to(DEVICE).eval()
 
 
 
13
  return model
14
 
15
- def load_segmentation_model():
16
  net = UNet(
17
  spatial_dims=2,
18
  in_channels=1,
@@ -22,14 +26,25 @@ def load_segmentation_model():
22
  num_res_units=2
23
  ).to(DEVICE)
24
  net.eval()
 
 
25
  return net
26
 
 
 
 
 
 
 
 
 
 
 
27
  def get_image_transform():
28
  return T.Compose([
29
  T.ToPILImage(),
30
  T.Resize((224, 224)),
31
  T.Grayscale(num_output_channels=3),
32
  T.ToTensor(),
33
- T.Normalize([0.485, 0.456, 0.406],
34
- [0.229, 0.224, 0.225])
35
  ])
 
2
  import torchvision.transforms as T
3
  from torchvision import models
4
  from monai.networks.nets import UNet
5
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, MarianTokenizer, MarianMTModel
6
 
7
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
 
9
+ def load_vision_model(checkpoint_path=None, num_classes=14):
10
  model = models.resnet50(pretrained=True)
11
  in_feats = model.fc.in_features
12
  model.fc = torch.nn.Linear(in_feats, num_classes)
13
  model.to(DEVICE).eval()
14
+ if checkpoint_path:
15
+ state = torch.load(checkpoint_path, map_location=DEVICE)
16
+ model.load_state_dict(state)
17
  return model
18
 
19
+ def load_segmentation_model(checkpoint_path=None):
20
  net = UNet(
21
  spatial_dims=2,
22
  in_channels=1,
 
26
  num_res_units=2
27
  ).to(DEVICE)
28
  net.eval()
29
+ if checkpoint_path:
30
+ net.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE))
31
  return net
32
 
33
+ def load_text_models(summarizer_name="google/t5-small-ssm"):
34
+ tokenizer = AutoTokenizer.from_pretrained(summarizer_name)
35
+ model = AutoModelForSeq2SeqLM.from_pretrained(summarizer_name).to(DEVICE)
36
+ return tokenizer, model
37
+
38
+ def load_translation_model():
39
+ tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
40
+ model = MarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-en-ar").to(DEVICE)
41
+ return tokenizer, model
42
+
43
  def get_image_transform():
44
  return T.Compose([
45
  T.ToPILImage(),
46
  T.Resize((224, 224)),
47
  T.Grayscale(num_output_channels=3),
48
  T.ToTensor(),
49
+ T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
 
50
  ])