"""Loaders for the classification, segmentation, summarisation and translation models.

Every loader moves its model to ``DEVICE`` and returns it in eval mode.
"""

import torch
import torchvision.transforms as T
from torchvision import models
from monai.networks.nets import UNet
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    MarianMTModel,
    MarianTokenizer,
)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ImageNet per-channel statistics. The classifier backbone is ImageNet-pretrained,
# so inputs are normalised the same way it was trained.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def _load_checkpoint(checkpoint_path):
    """Load a saved state dict from *checkpoint_path*, mapping tensors onto ``DEVICE``.

    NOTE(security): ``torch.load`` unpickles the file; on PyTorch versions where
    ``weights_only`` is not the default this can execute arbitrary code. Only load
    checkpoints from trusted sources.
    """
    return torch.load(checkpoint_path, map_location=DEVICE)


def _imagenet_resnet50():
    """Build a ResNet-50 with ImageNet weights, across torchvision versions."""
    weights_enum = getattr(models, "ResNet50_Weights", None)
    if weights_enum is None:
        # torchvision < 0.13 only understands the legacy boolean flag.
        return models.resnet50(pretrained=True)
    # IMAGENET1K_V1 is exactly what the deprecated ``pretrained=True`` loaded.
    return models.resnet50(weights=weights_enum.IMAGENET1K_V1)


def load_vision_model(checkpoint_path=None, num_classes=14):
    """Return a ResNet-50 classifier whose final layer emits ``num_classes`` logits.

    Args:
        checkpoint_path: Optional path to a state dict for the full network
            (backbone plus the replaced ``fc`` head). It is loaded strictly,
            so every parameter is overwritten by the checkpoint.
        num_classes: Output size of the replacement ``fc`` linear layer.

    Returns:
        The model on ``DEVICE`` in eval mode.
    """
    if checkpoint_path:
        # A strict load replaces every weight, so downloading ImageNet
        # weights first would be wasted work (and needs network access).
        model = models.resnet50()
    else:
        model = _imagenet_resnet50()
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
    model.to(DEVICE).eval()
    if checkpoint_path:
        model.load_state_dict(_load_checkpoint(checkpoint_path))
    return model


def load_segmentation_model(checkpoint_path=None):
    """Return a 2-D MONAI UNet with 1 input channel and 1 output channel.

    Args:
        checkpoint_path: Optional path to a state dict for this exact UNet
            configuration; loaded strictly when given.

    Returns:
        The network on ``DEVICE`` in eval mode.
    """
    net = UNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(DEVICE)
    net.eval()
    if checkpoint_path:
        net.load_state_dict(_load_checkpoint(checkpoint_path))
    return net


def load_text_models(summarizer_name="google/t5-small-ssm"):
    """Load a Hugging Face seq2seq model and its tokenizer for summarisation.

    Args:
        summarizer_name: Hub identifier or local path of the seq2seq checkpoint.

    Returns:
        ``(tokenizer, model)`` with the model on ``DEVICE`` in eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(summarizer_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(summarizer_name).to(DEVICE)
    # from_pretrained already returns eval mode; stated explicitly to match the other loaders.
    model.eval()
    return tokenizer, model


def load_translation_model(model_name="Helsinki-NLP/opus-mt-en-ar"):
    """Load a MarianMT translation model and its tokenizer.

    Args:
        model_name: Hub identifier or local path of a Marian checkpoint. The
            default is the English-to-Arabic OPUS-MT model.

    Returns:
        ``(tokenizer, model)`` with the model on ``DEVICE``.
    """
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(model_name).to(DEVICE)
    return tokenizer, model


def get_image_transform():
    """Return the preprocessing pipeline for the classifier input.

    The pipeline converts the input to a PIL image, resizes it to 224x224 and
    replicates it to three grayscale channels. It then converts the result to a
    float tensor in [0, 1] and applies ImageNet mean/std normalisation.
    """
    return T.Compose([
        # NOTE(review): input presumably an HxWxC ndarray or CxHxW tensor — confirm with callers.
        T.ToPILImage(),
        T.Resize((224, 224)),
        # Three identical channels so a grayscale image fits the RGB backbone.
        T.Grayscale(num_output_channels=3),
        T.ToTensor(),
        T.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ])