# DocExplainer / explainer.py
from transformers import PretrainedConfig, PreTrainedModel, AutoProcessor, SiglipModel
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download


class ExplainerConfig(PretrainedConfig):
    model_type = "explainer"

    def __init__(self, base_model_name="google/siglip2-giant-opt-patch16-384",
                 hidden_dim=768, giant=True, **kwargs):
        self.base_model_name = base_model_name
        self.hidden_dim = hidden_dim
        self.giant = giant
        super().__init__(**kwargs)
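

# A minimal sketch of overriding the defaults, e.g. to use a smaller SigLIP
# backbone (the checkpoint name here is illustrative, not taken from this repo):
#   config = ExplainerConfig(base_model_name="google/siglip-base-patch16-384",
#                            hidden_dim=768, giant=False)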


class SigLIPBBoxRegressor(nn.Module):
    def __init__(self, siglip_model, hidden_dim=768, giant=True):
        super().__init__()
        self.siglip = siglip_model
        vision_dim = self.siglip.vision_model.config.hidden_size
        text_dim = self.siglip.text_model.config.hidden_size
        if giant:
            # For the giant checkpoint, text_embeds come out 1536-dim, which the
            # text config's hidden_size does not reflect.
            text_dim = 1536

        # Feature fusion layers
        self.vision_projector = nn.Sequential(
            nn.Linear(vision_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
        )
        self.text_projector = nn.Sequential(
            nn.Linear(text_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
        )

        # Cross-modal fusion
        self.fusion_layer = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
        )

        # Separate regression heads for the two box corners
        self.topleft_regressor = nn.Sequential(
            nn.Linear(hidden_dim // 2, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 2),  # (x1, y1)
        )
        self.bottomright_regressor = nn.Sequential(
            nn.Linear(hidden_dim // 2, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 2),  # (x2, y2)
        )

    def forward(self, pixel_values, input_ids):
        # The SigLIP backbone is kept frozen: no gradients flow through it,
        # so only the projection/fusion/regression heads above are trained.
        with torch.no_grad():
            outputs = self.siglip(pixel_values=pixel_values, input_ids=input_ids, return_dict=True)

        # Extract pooled features
        vision_features = outputs.image_embeds.float()
        text_features = outputs.text_embeds.float()

        # Project features into a shared dimension
        vision_proj = self.vision_projector(vision_features)
        text_proj = self.text_projector(text_features)

        # Fuse modalities
        fused = torch.cat([vision_proj, text_proj], dim=1)
        fused_features = self.fusion_layer(fused)

        # Predict the two corners and concatenate to (x1, y1, x2, y2)
        topleft_pred = self.topleft_regressor(fused_features)
        bottomright_pred = self.bottomright_regressor(fused_features)
        return torch.cat([topleft_pred, bottomright_pred], dim=1)
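
    # Shape sketch (assuming the 384x384 giant checkpoint and max_length=64,
    # as used by Explainer.predict below):
    #   pixel_values: (B, 3, 384, 384)
    #   input_ids:    (B, 64)
    #   returns:      (B, 4) as (x1, y1, x2, y2)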


class Explainer(PreTrainedModel):
    config_class = ExplainerConfig

    def __init__(self, config):
        super().__init__(config)
        self.siglip_model = SiglipModel.from_pretrained(config.base_model_name)
        # Pass the config through so hidden_dim/giant are not silently ignored.
        self.bbox_regressor = SigLIPBBoxRegressor(
            self.siglip_model, hidden_dim=config.hidden_dim, giant=config.giant
        )
        self.processor = AutoProcessor.from_pretrained(config.base_model_name, use_fast=True)

    def forward(self, pixel_values=None, input_ids=None):
        return self.bbox_regressor(pixel_values, input_ids)

    def predict(self, image, text, device="cuda"):
        self.to(device)
        self.eval()
        inputs = self.processor(
            text=text,
            images=image,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=64,
        )
        # Match the backbone's dtype instead of hard-coding half precision,
        # so fp32 checkpoints do not hit a dtype mismatch in the vision tower.
        pixel_values = inputs["pixel_values"].to(device=device, dtype=self.dtype)
        input_ids = inputs["input_ids"].to(device)
        with torch.no_grad():
            pred_bbox = self.forward(pixel_values, input_ids)
        return pred_bbox[0].cpu().numpy().tolist()
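
    # Hypothetical helper, not part of the original API: assuming the model was
    # trained on corner coordinates normalized to [0, 1], this rescales a
    # prediction to pixel space. Drop it if the model emits absolute pixels.
    @staticmethod
    def scale_bbox(bbox, width, height):
        x1, y1, x2, y2 = bbox
        return [x1 * width, y1 * height, x2 * width, y2 * height]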

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        config = kwargs.pop("config", None)
        if config is None:
            # Use the model's own config class so the custom fields
            # (base_model_name, hidden_dim, giant) are preserved.
            config = ExplainerConfig.from_pretrained(pretrained_model_name_or_path)
        model = cls(config)
        checkpoint_path = hf_hub_download(
            repo_id=pretrained_model_name_or_path,
            filename="pytorch_model.bin",
        )
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        model.siglip_model.load_state_dict(checkpoint["siglip_model"])
        model.bbox_regressor.load_state_dict(checkpoint["bbox_regressor"])
        return model
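

if __name__ == "__main__":
    # Minimal usage sketch; the repo id and image path are placeholders to
    # replace with the actual Hub repository and a real document image.
    from PIL import Image

    model = Explainer.from_pretrained("your-username/DocExplainer")  # placeholder repo id
    page = Image.open("page.png").convert("RGB")                     # placeholder image
    print(model.predict(page, "Where is the invoice total?"))        # [x1, y1, x2, y2]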