File size: 1,726 Bytes
6e253d5
 
 
 
41c6f3a
6e253d5
 
 
41c6f3a
6e253d5
 
 
 
41c6f3a
 
 
6e253d5
 
41c6f3a
6e253d5
 
 
 
 
 
 
 
 
41c6f3a
 
6e253d5
 
41c6f3a
 
 
 
 
 
 
 
 
 
6e253d5
 
 
 
 
 
41c6f3a
6e253d5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torch
import torchvision.transforms as T
from torchvision import models
from monai.networks.nets import UNet
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, MarianTokenizer, MarianMTModel

# Compute device shared by every loader below: CUDA when torch reports it as available, otherwise CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def load_vision_model(checkpoint_path=None, num_classes=14):
    """Build a ResNet-50 classifier with a ``num_classes``-way head, ready for inference.

    Args:
        checkpoint_path: Optional path to a state dict for this exact
            architecture (ResNet-50 with the replaced ``fc`` head). When truthy,
            it is loaded with ``load_state_dict``'s default strict matching, so
            it must supply every parameter and buffer of the model.
        num_classes: Number of outputs of the replacement fully-connected layer.

    Returns:
        The model, moved to ``DEVICE`` and switched to eval mode.
    """
    # ImageNet weights are only worth fetching when they will actually be used.
    # With a checkpoint, the strict load below overwrites all of them, so
    # downloading them would only add network traffic and an online dependency.
    # NOTE(review): torchvision >= 0.13 deprecates ``pretrained`` in favour of
    # ``weights=``; migrate once the minimum supported torchvision allows it.
    model = models.resnet50(pretrained=not checkpoint_path)
    # Replace the stock classification head with one sized for this task.
    in_feats = model.fc.in_features
    model.fc = torch.nn.Linear(in_feats, num_classes)
    model.to(DEVICE).eval()
    if checkpoint_path:
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state = torch.load(checkpoint_path, map_location=DEVICE)
        model.load_state_dict(state)
    return model

def load_segmentation_model(checkpoint_path=None):
    """Create the 2-D MONAI UNet used for segmentation and put it in eval mode.

    The network takes 1 input channel and produces 1 output channel. It has five
    encoder levels (16 to 256 channels), stride-2 downsampling between levels and
    2 residual units per level.

    Args:
        checkpoint_path: Optional path to a saved state dict. When truthy, it is
            loaded into the network after it has been moved to ``DEVICE``.

    Returns:
        The UNet on ``DEVICE`` in eval mode.
    """
    unet_config = dict(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    )
    segmenter = UNet(**unet_config).to(DEVICE)
    segmenter.eval()

    if not checkpoint_path:
        return segmenter

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    weights = torch.load(checkpoint_path, map_location=DEVICE)
    segmenter.load_state_dict(weights)
    return segmenter

def load_text_models(summarizer_name="google/t5-small-ssm"):
    """Load a sequence-to-sequence tokenizer/model pair for text generation.

    Args:
        summarizer_name: Hugging Face hub id or local path of the seq2seq
            checkpoint. Both the tokenizer and the model are read from it.

    Returns:
        ``(tokenizer, model)``, with the model moved to ``DEVICE``.
    """
    tok = AutoTokenizer.from_pretrained(summarizer_name)
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(summarizer_name)
    seq2seq = seq2seq.to(DEVICE)
    return tok, seq2seq

def load_translation_model(model_name="Helsinki-NLP/opus-mt-en-ar"):
    """Load a MarianMT tokenizer/model pair for machine translation.

    Args:
        model_name: Hugging Face hub id or local path of a Marian checkpoint.
            Defaults to the English-to-Arabic OPUS-MT model, so existing
            zero-argument callers behave exactly as before.

    Returns:
        ``(tokenizer, model)``, with the model moved to ``DEVICE``.
    """
    # One source of truth for the checkpoint id. The tokenizer and model must
    # come from the same checkpoint or their vocabularies will not line up.
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(model_name).to(DEVICE)
    return tokenizer, model

def get_image_transform(size=(224, 224)):
    """Return the preprocessing pipeline that turns a raw image into a model input.

    The pipeline runs these steps in order:

    1. Convert the input (tensor or ndarray) to a PIL image.
    2. Resize it to ``size``.
    3. Convert to grayscale replicated across 3 channels, so the output has
       3 channels whatever the input had.
    4. Convert to a float tensor (``ToTensor`` scales 8-bit images to [0, 1]).
    5. Normalize with the ImageNet per-channel mean and std.

    Args:
        size: Passed straight to ``T.Resize``. A ``(height, width)`` tuple gives
            an exact output size; an int resizes the shorter edge instead.
            Defaults to ``(224, 224)``, the previously hard-coded size.

    Returns:
        A ``torchvision.transforms.Compose`` callable.
    """
    return T.Compose([
        T.ToPILImage(),
        T.Resize(size),
        # Replicate the luminance into 3 channels. This presumably matches the
        # 3-channel ResNet-50 stem built by load_vision_model; confirm if other
        # models consume this transform.
        T.Grayscale(num_output_channels=3),
        T.ToTensor(),
        # ImageNet channel statistics, matching the pretrained backbone.
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])