Model Card: ResNet-50 Graph Radiology Captioner

Model Details

  • Model Name: ResNet-50 Graph Radiology Captioner
  • Architecture: Custom Encoder-Decoder with Graph Propagation
  • Language: English
  • Dataset: eltorio/ROCOv2-radiology
  • License: Apache-2.0

Model Description

This model is designed to generate automated textual descriptions (captions) for radiology images. It uses a hybrid architecture that combines a convolutional neural network (CNN) encoder with a graph-propagation layer and a Transformer decoder.

The architecture consists of three main components:

  1. Visual Encoder: A pre-trained ResNet-50 (from Hugging Face) extracts visual features from the input image.
  2. Graph Propagation Layer: A custom layer processes the spatial feature map using a sliding window mechanism (window size $3$) to aggregate neighborhood information over $3$ steps. This allows the model to better capture spatial relationships between anatomical structures.
  3. Textual Decoder: A standard Transformer Decoder (4 layers, 8 heads) generates the caption token by token, attending to the enhanced visual features.
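The neighborhood aggregation in component 2 can be sketched in plain Python: each position is replaced by the mean of its window (with circular wraparound, as `torch.roll` gives), applied once per propagation step. This is a minimal illustration of the windowed averaging only, not the full layer, which also applies a learned linear message transform and a Hardtanh activation (see the code in "How to Use"):

```python
def window_mean(x, window_size=3):
    """Mean over a circular sliding window: one graph-propagation
    aggregation step over a 1-D sequence of scalar features."""
    n = len(x)
    pad = window_size // 2
    return [
        sum(x[(i + j) % n] for j in range(-pad, pad + 1)) / window_size
        for i in range(n)
    ]

# Each value is averaged with its two neighbors (wrapping at the ends):
print(window_mean([1.0, 2.0, 3.0, 4.0]))
```

In the model this averaging runs over the flattened $7 \times 7$ ResNet feature map, so "neighbors" are adjacent positions in the flattened spatial sequence.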

Intended Uses

  • Primary Use: Generating preliminary descriptive reports for medical imaging datasets.
  • Research: Investigating graph-based spatial reasoning in medical image captioning.

Limitations

  • The model may generate repetitive phrases or hallucinate findings not present in the image.
  • The inference logic relies on a greedy decoder; beam search could improve results.
  • Performance is strictly tied to the ROCOv2 dataset distribution; generalizability to other clinical domains may vary.

Training Data

The model was trained on the ROCOv2-radiology dataset.

  • Dataset Split:
    • Train: 56,963 images
    • Validation: 2,999 images
  • Image Modality: Radiology (X-ray, CT, MRI, etc.)
  • Preprocessing: Images resized to $224 \times 224$, normalized using ImageNet statistics, and augmented with random horizontal flips and rotations ($\pm 10^\circ$).
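As a quick illustration of the normalization step, each channel value $v \in [0, 1]$ is mapped to $(v - \mu) / \sigma$ using the ImageNet channel statistics; a plain-Python sketch of what `transforms.Normalize` computes per pixel:

```python
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

def normalize_pixel(rgb):
    """Normalize one RGB pixel (values already scaled to [0, 1])
    with ImageNet channel statistics."""
    return [(v - m) / s for v, m, s in zip(rgb, IMAGENET_MEAN, IMAGENET_STD)]

# A mid-gray pixel lands slightly above zero in every channel:
print(normalize_pixel([0.5, 0.5, 0.5]))
```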

Training Procedure

The model was trained for 15 epochs in a TPU environment using cross-entropy loss (PyTorch's `CrossEntropyLoss`).
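For reference, the per-token cross-entropy reduces to the negative log-probability that the softmax assigns to the target token. A minimal plain-Python sketch (illustration only, not the actual training loop):

```python
import math

def token_cross_entropy(logits, target):
    """Cross-entropy for one token: -log softmax(logits)[target]."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return -(logits[target] - log_sum_exp)

# A confident, correct prediction yields a low loss:
print(token_cross_entropy([5.0, 0.0, 0.0], target=0))
```

During training this is averaged over all non-padding caption tokens, with targets shifted one position relative to the decoder inputs (teacher forcing).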

Hyperparameters

| Parameter | Value |
| --- | --- |
| Batch Size | 200 |
| Learning Rate | $1 \times 10^{-4}$ |
| Optimizer | AdamW |
| Epochs | 15 |
| Max Sequence Length | 80 |
| Embedding Dimension | 512 |
| Feed-Forward Dimension | 1024 |
| Decoder Layers | 4 |
| Attention Heads | 8 |
| Graph Window Size | 3 |
| Graph Propagation Steps | 3 |

Training Results

The model achieved a steady decrease in training loss and perplexity over 15 epochs.

  • Final Training Loss: $2.4942$
  • Final Perplexity: $12.1121$

| Epoch | Training Loss | Perplexity |
| --- | --- | --- |
| 1 | 6.0415 | 420.53 |
| 5 | 3.3029 | 27.19 |
| 10 | 2.7978 | 16.41 |
| 15 | 2.4942 | 12.11 |
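Perplexity is the exponential of the mean cross-entropy loss, so the two columns are mutually consistent:

```python
import math

# Reported (loss, perplexity) pairs from the table above
epochs = [(6.0415, 420.53), (3.3029, 27.19), (2.7978, 16.41), (2.4942, 12.11)]
for loss, ppl in epochs:
    # exp(loss) reproduces the reported perplexity to rounding precision
    assert abs(math.exp(loss) - ppl) < 0.01 * ppl
    print(f"loss={loss:.4f}  exp(loss)={math.exp(loss):.2f}  reported={ppl}")
```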

Evaluation

Inference Sample

Below is a comparison between the ground truth caption and the model's generation on a validation set image.

Ground Truth: "Figure 1: PICC outlined by the arrows."

Generated Caption: "radiograph of the abdomen showing the presence of a large, well - defined, oval, calcified, and calcifications in the distal part of the right lower quadrant ( arrows )."

How to Use

To use this model, install torch, torchvision, transformers, and pillow. The model architecture must be defined in your environment before loading the weights.

import torch
import torch.nn as nn
from transformers import BertTokenizer, AutoModel
from PIL import Image
from torchvision import transforms

# 1. Define the Model Architecture (must match the training script)
class GraphPropagationLayer(nn.Module):
    def __init__(self, embed_dim, window_size=3, max_steps=3):
        super().__init__()
        self.window_size = window_size
        self.max_steps = max_steps
        self.msg_linear = nn.Linear(embed_dim, embed_dim)
        self.activation = nn.Hardtanh()

    def forward(self, x):
        # x: (batch, num_tokens, embed_dim)
        B, L, D = x.shape
        pad = self.window_size // 2
        for _ in range(self.max_steps):
            # Circularly shifted copies of the sequence give each token
            # a view of its neighbors within the window.
            neighbors = []
            for i in range(-pad, pad + 1):
                neighbors.append(torch.roll(x, shifts=i, dims=1))
            stacked = torch.stack(neighbors, dim=2)
            aggregated = stacked.mean(dim=2)
            messages = self.msg_linear(aggregated)
            x = self.activation(x + messages)
        return x

class ResNetGraphCaptioner(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_layers, num_heads):
        super().__init__()
        self.encoder = AutoModel.from_pretrained("microsoft/resnet-50", trust_remote_code=True)
        self.encoder.config.pooling = None
        encoder_dim = self.encoder.config.hidden_sizes[-1]

        self.vis_proj = nn.Linear(encoder_dim, embed_dim)
        self.graph_layer = GraphPropagationLayer(embed_dim, window_size=3, max_steps=3)

        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_encoder = nn.Parameter(torch.randn(1, 2000, embed_dim))

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=embed_dim, nhead=num_heads, dim_feedforward=1024, batch_first=True
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.fc_out = nn.Linear(embed_dim, vocab_size)

    def forward(self, images, captions):
        # Encode: (B, C, H, W) feature map -> (B, H*W, C) token sequence
        outputs = self.encoder(pixel_values=images)
        B, C, H, W = outputs.last_hidden_state.shape
        visual_features = outputs.last_hidden_state.permute(0, 2, 3, 1).reshape(B, H * W, C)
        visual_tokens = self.vis_proj(visual_features)
        memory = self.graph_layer(visual_tokens)

        # Decode with a causal mask so each position attends only to earlier tokens
        caption_emb = self.embedding(captions)
        seq_len = caption_emb.size(1)
        caption_emb = caption_emb + self.pos_encoder[:, :seq_len, :]

        tgt_mask = torch.triu(
            torch.ones((seq_len, seq_len)) * float('-inf'), diagonal=1
        ).to(images.device)
        output = self.decoder(tgt=caption_emb, memory=memory, tgt_mask=tgt_mask)
        return self.fc_out(output)

# 2. Load Checkpoint
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Initialize the model with the specific config used during training
model = ResNetGraphCaptioner(
    vocab_size=tokenizer.vocab_size,
    embed_dim=512,
    num_layers=4,
    num_heads=8,
)

# Load weights (assuming you have the checkpoint file locally)
checkpoint_path = "erfanasghariyan/resnet50_roco_radiocaptioner_big/roco_best_model.bin"
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.to(device)
model.eval()

# 3. Inference Function (greedy decoding)
def generate_caption(image_path, model, tokenizer, max_len=80):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    image = Image.open(image_path).convert('RGB')
    img_tensor = transform(image).unsqueeze(0).to(device)

    input_ids = [tokenizer.cls_token_id]

    with torch.no_grad():
        for _ in range(max_len):
            tensor_input = torch.tensor([input_ids], device=device)
            outputs = model(img_tensor, tensor_input)
            next_token = outputs[0, -1, :].argmax().item()

            if next_token == tokenizer.sep_token_id:
                break
            input_ids.append(next_token)

    return tokenizer.decode(input_ids[1:], skip_special_tokens=True)

# Example Usage
# caption = generate_caption("path/to/xray.jpg", model, tokenizer)
# print(caption)
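As the Limitations section notes, the decoder above is greedy: it commits to the single most likely token at every step. A width-$k$ beam search instead keeps the $k$ highest-scoring partial captions. Below is a minimal, model-agnostic sketch; the `step_fn` interface is hypothetical, not this model's API (wiring it up would mean calling the model on the current prefix and taking `log_softmax` over the last position):

```python
import math

def beam_search(step_fn, bos, eos, beam_width=3, max_len=10):
    """Generic beam search. step_fn(seq) -> {token: log_prob} for the
    next token given a partial sequence (hypothetical interface)."""
    beams = [([bos], 0.0)]  # (sequence, cumulative log-probability)
    finished = []
    for _ in range(max_len):
        candidates = []
        for seq, score in beams:
            for tok, logp in step_fn(seq).items():
                hyp = (seq + [tok], score + logp)
                (finished if tok == eos else candidates).append(hyp)
        if not candidates:
            break
        candidates.sort(key=lambda h: h[1], reverse=True)
        beams = candidates[:beam_width]  # keep only the k best prefixes
    pool = finished or beams
    return max(pool, key=lambda h: h[1])[0]

# Toy distribution where greedy (beam_width=1) takes the locally best
# first token and misses the globally better sequence.
def toy_step(seq):
    table = {
        (0,): {1: math.log(0.50), 2: math.log(0.49)},
        (0, 1): {3: math.log(0.20)},
        (0, 2): {3: math.log(0.90)},
    }
    return table.get(tuple(seq), {3: 0.0})

print(beam_search(toy_step, bos=0, eos=3, beam_width=1))  # greedy: [0, 1, 3]
print(beam_search(toy_step, bos=0, eos=3, beam_width=2))  # beam:   [0, 2, 3]
```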