# Model Card: ResNet-50 Graph Radiology Captioner

## Model Details
- Model Name: ResNet-50 Graph Radiology Captioner
- Architecture: Custom Encoder-Decoder with Graph Propagation
- Language: English
- Dataset: eltorio/ROCOv2-radiology
- License: Apache-2.0
## Model Description
This model is designed to generate automated textual descriptions (captions) for radiology images. It utilizes a hybrid architecture combining Convolutional Neural Networks (CNNs) with Graph Propagation and Transformer decoders.
The architecture consists of three main components:
- Visual Encoder: A pre-trained ResNet-50 (from Hugging Face) extracts visual features from the input image.
- Graph Propagation Layer: A custom layer processes the spatial feature map using a sliding window mechanism (window size $3$) to aggregate neighborhood information over $3$ steps. This allows the model to better capture spatial relationships between anatomical structures.
- Textual Decoder: A standard Transformer Decoder (4 layers, 8 heads) generates the caption token by token, attending to the enhanced visual features.
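The tensor shapes implied by these components can be traced end to end. The sketch below is illustrative, assuming the standard ResNet-50 output (a $7 \times 7$ grid with 2048 channels for a $224 \times 224$ input) and the embedding dimension listed in the hyperparameters; it is not taken from the released training code.

```python
import torch

B = 2                                 # batch size (illustrative)
feat = torch.randn(B, 2048, 7, 7)     # ResNet-50 final feature map for a 224x224 input

# Flatten the spatial grid into a sequence of visual tokens: (B, 49, 2048)
tokens = feat.permute(0, 2, 3, 1).reshape(B, 7 * 7, 2048)

# Project into the decoder's embedding space: (B, 49, 512)
proj = torch.nn.Linear(2048, 512)
memory = proj(tokens)
print(memory.shape)  # torch.Size([2, 49, 512])
```

The graph propagation layer and the Transformer decoder's cross-attention both operate on this `(batch, 49, 512)` token sequence.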
## Intended Uses
- Primary Use: Generating preliminary descriptive reports for medical imaging datasets.
- Research: Investigating graph-based spatial reasoning in medical image captioning.
## Limitations
- The model may generate repetitive phrases or hallucinate findings not present in the image.
- The inference logic relies on a greedy decoder; beam search could improve results.
- Performance is strictly tied to the ROCOv2 dataset distribution; generalizability to other clinical domains may vary.
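On the beam-search point above: a minimal, model-agnostic sketch of the idea is shown below. The `toy_scorer` function is hypothetical and stands in for the captioner's next-token log-probabilities; this is not part of the released code.

```python
import math

def beam_search(log_probs, bos, eos, beam_width=3, max_len=10):
    """Minimal beam search over a next-token scorer.

    `log_probs(seq)` returns a {token: log-prob} dict for the next token.
    """
    beams = [([bos], 0.0)]
    for _ in range(max_len):
        candidates = []
        for seq, score in beams:
            if seq[-1] == eos:                 # finished hypotheses carry over
                candidates.append((seq, score))
                continue
            for tok, lp in log_probs(seq).items():
                candidates.append((seq + [tok], score + lp))
        # Keep only the top-scoring hypotheses
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_width]
        if all(seq[-1] == eos for seq, _ in beams):
            break
    return beams[0]

# Hypothetical scorer: prefers token 1 early, forces end-of-sequence (2) later
def toy_scorer(seq):
    if len(seq) < 4:
        return {1: math.log(0.6), 2: math.log(0.4)}
    return {2: 0.0}

best_seq, best_score = beam_search(toy_scorer, bos=0, eos=2)
print(best_seq)  # -> [0, 2]
```

Swapping the greedy loop in the inference code for something like this would require batching partial hypotheses through the model at each step.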
## Training Data
The model was trained on the ROCOv2-radiology dataset.
- Dataset Split:
- Train: 56,963 images
- Validation: 2,999 images
- Image Modality: Radiology (X-ray, CT, MRI, etc.)
- Preprocessing: Images resized to $224 \times 224$, normalized using ImageNet statistics, and augmented with random horizontal flips and rotations ($\pm 10^\circ$).
## Training Procedure
The model was trained for 15 epochs in a TPU environment using a cross-entropy loss.
### Hyperparameters
| Parameter | Value |
|---|---|
| Batch Size | 200 |
| Learning Rate | $1 \times 10^{-4}$ |
| Optimizer | AdamW |
| Epochs | 15 |
| Max Sequence Length | 80 |
| Embedding Dimension | 512 |
| Feed Forward Dimension | 1024 |
| Decoder Layers | 4 |
| Attention Heads | 8 |
| Graph Window Size | 3 |
| Graph Propagation Steps | 3 |
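The table above translates into a standard teacher-forcing setup. The sketch below shows how these hyperparameters wire into one training step; the `nn.Linear` is a stand-in for the full captioner, and details such as an `ignore_index` for padding are not specified in this card.

```python
import torch
import torch.nn as nn

vocab_size, embed_dim, seq_len = 100, 512, 80     # embed dim / max length from the table
model = nn.Linear(embed_dim, vocab_size)          # placeholder for ResNetGraphCaptioner
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

decoder_states = torch.randn(8, seq_len, embed_dim)   # (batch, seq, embed)
targets = torch.randint(0, vocab_size, (8, seq_len))  # gold next-token ids

# Flatten (batch, seq) so each position contributes one classification term
logits = model(decoder_states)
loss = criterion(logits.reshape(-1, vocab_size), targets.reshape(-1))

optimizer.zero_grad()
loss.backward()
optimizer.step()
```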
### Training Results
The model achieved a steady decrease in training loss and perplexity over 15 epochs.
- Final Training Loss: $2.4942$
- Final Perplexity: $12.1121$
| Epoch | Training Loss | Perplexity |
|---|---|---|
| 1 | 6.0415 | 420.53 |
| 5 | 3.3029 | 27.19 |
| 10 | 2.7978 | 16.41 |
| 15 | 2.4942 | 12.11 |
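The perplexity column is simply the exponential of the cross-entropy training loss, $\mathrm{PPL} = e^{\mathcal{L}}$, which can be checked against the table's endpoints:

```python
import math

# Perplexity reported above is exp(cross-entropy loss); checking the endpoints
for loss, ppl in [(6.0415, 420.53), (2.4942, 12.11)]:
    print(f"exp({loss}) = {math.exp(loss):.2f}  (table: {ppl})")
```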
## Evaluation
### Inference Sample
Below is a comparison between the ground truth caption and the model's generation on a validation set image.
Ground Truth: "Figure 1: PICC outlined by the arrows."

Generated Caption: "radiograph of the abdomen showing the presence of a large, well - defined, oval, calcified, and calcifications in the distal part of the right lower quadrant ( arrows )."
## How to Use
To use this model, you need to install torch, transformers, and pillow. The model architecture must be defined in your environment to load the weights.
```python
import torch
import torch.nn as nn
from transformers import BertTokenizer, AutoModel
from PIL import Image
from torchvision import transforms

# 1. Define the model architecture (must match the training script)
class GraphPropagationLayer(nn.Module):
    def __init__(self, embed_dim, window_size=3, max_steps=3):
        super().__init__()
        self.window_size = window_size
        self.max_steps = max_steps
        self.msg_linear = nn.Linear(embed_dim, embed_dim)
        self.activation = nn.Hardtanh()

    def forward(self, x):
        # x: (batch, num_tokens, embed_dim)
        pad = self.window_size // 2
        for _ in range(self.max_steps):
            # Gather each token's neighbors within the sliding window
            neighbors = [torch.roll(x, shifts=i, dims=1) for i in range(-pad, pad + 1)]
            stacked = torch.stack(neighbors, dim=2)   # (B, L, window, D)
            aggregated = stacked.mean(dim=2)          # mean over the window
            messages = self.msg_linear(aggregated)
            x = self.activation(x + messages)         # residual update
        return x


class ResNetGraphCaptioner(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_layers, num_heads):
        super().__init__()
        self.encoder = AutoModel.from_pretrained("microsoft/resnet-50")
        encoder_dim = self.encoder.config.hidden_sizes[-1]  # 2048 for ResNet-50
        self.vis_proj = nn.Linear(encoder_dim, embed_dim)
        self.graph_layer = GraphPropagationLayer(embed_dim, window_size=3, max_steps=3)
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_encoder = nn.Parameter(torch.randn(1, 2000, embed_dim))
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=embed_dim, nhead=num_heads, dim_feedforward=1024, batch_first=True
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.fc_out = nn.Linear(embed_dim, vocab_size)

    def forward(self, images, captions):
        # Encode: the ResNet's last hidden state is a (B, C, H, W) feature map
        outputs = self.encoder(pixel_values=images)
        B, C, H, W = outputs.last_hidden_state.shape
        visual_features = outputs.last_hidden_state.permute(0, 2, 3, 1).reshape(B, H * W, C)
        visual_tokens = self.vis_proj(visual_features)
        memory = self.graph_layer(visual_tokens)

        # Decode with a causal mask so each position attends only to earlier tokens
        caption_emb = self.embedding(captions)
        seq_len = caption_emb.size(1)
        caption_emb = caption_emb + self.pos_encoder[:, :seq_len, :]
        tgt_mask = torch.triu(
            torch.full((seq_len, seq_len), float('-inf')), diagonal=1
        ).to(images.device)
        output = self.decoder(tgt=caption_emb, memory=memory, tgt_mask=tgt_mask)
        return self.fc_out(output)


# 2. Load checkpoint
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Initialize the model with the configuration used during training
model = ResNetGraphCaptioner(
    vocab_size=tokenizer.vocab_size,
    embed_dim=512,
    num_layers=4,
    num_heads=8,
)

# Load weights (assuming the checkpoint file is available locally)
checkpoint_path = "erfanasghariyan/resnet50_roco_radiocaptioner_big/roco_best_model.bin"
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.to(device)
model.eval()


# 3. Inference function (greedy decoding)
def generate_caption(image_path, model, tokenizer, max_len=80):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    image = Image.open(image_path).convert("RGB")
    img_tensor = transform(image).unsqueeze(0).to(device)

    input_ids = [tokenizer.cls_token_id]
    with torch.no_grad():
        for _ in range(max_len):
            tensor_input = torch.tensor([input_ids], device=device)
            outputs = model(img_tensor, tensor_input)
            next_token = outputs[0, -1, :].argmax().item()
            if next_token == tokenizer.sep_token_id:
                break
            input_ids.append(next_token)
    return tokenizer.decode(input_ids[1:], skip_special_tokens=True)


# Example usage:
# caption = generate_caption("path/to/xray.jpg", model, tokenizer)
# print(caption)
```