from transformers import PretrainedConfig, PreTrainedModel, AutoProcessor, SiglipModel
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download


class ExplainerConfig(PretrainedConfig):
    model_type = "explainer"

    def __init__(self, base_model_name="google/siglip2-giant-opt-patch16-384",
                 hidden_dim=768, giant=True, **kwargs):
        self.base_model_name = base_model_name
        self.hidden_dim = hidden_dim
        self.giant = giant
        super().__init__(**kwargs)


class SigLIPBBoxRegressor(nn.Module):
    """Regresses a bounding box from frozen SigLIP image/text embeddings."""

    def __init__(self, siglip_model, hidden_dim=768, giant=True):
        super().__init__()
        self.siglip = siglip_model
        vision_dim = self.siglip.vision_model.config.hidden_size
        text_dim = self.siglip.text_model.config.hidden_size
        if giant:
            # The giant checkpoint's pooled text embeddings are 1536-dim,
            # which differs from the text tower's hidden size.
            text_dim = 1536

        # Per-modality projection layers
        self.vision_projector = nn.Sequential(
            nn.Linear(vision_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
        )
        self.text_projector = nn.Sequential(
            nn.Linear(text_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
        )

        # Cross-modal fusion
        self.fusion_layer = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
        )

        # Separate heads for the two box corners
        self.topleft_regressor = nn.Sequential(
            nn.Linear(hidden_dim // 2, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 2),  # (x1, y1)
        )
        self.bottomright_regressor = nn.Sequential(
            nn.Linear(hidden_dim // 2, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 2),  # (x2, y2)
        )

    def forward(self, pixel_values, input_ids):
        # The SigLIP backbone stays frozen; only the heads receive gradients.
        with torch.no_grad():
            outputs = self.siglip(pixel_values=pixel_values, input_ids=input_ids, return_dict=True)

        # Extract pooled features (cast to fp32 for the regression heads)
        vision_features = outputs.image_embeds.float()
        text_features = outputs.text_embeds.float()

        # Project features
        vision_proj = self.vision_projector(vision_features)
        text_proj = self.text_projector(text_features)

        # Fuse modalities
        fused = torch.cat([vision_proj, text_proj], dim=1)
        fused_features = self.fusion_layer(fused)

        # Predict bbox corners
        topleft_pred = self.topleft_regressor(fused_features)
        bottomright_pred = self.bottomright_regressor(fused_features)
        return torch.cat([topleft_pred, bottomright_pred], dim=1)


class Explainer(PreTrainedModel):
    config_class = ExplainerConfig

    def __init__(self, config):
        super().__init__(config)
        self.siglip_model = SiglipModel.from_pretrained(config.base_model_name)
        # Forward the config's head settings instead of relying on the defaults.
        self.bbox_regressor = SigLIPBBoxRegressor(
            self.siglip_model, hidden_dim=config.hidden_dim, giant=config.giant
        )
        self.processor = AutoProcessor.from_pretrained(config.base_model_name, use_fast=True)

    def forward(self, pixel_values=None, input_ids=None):
        return self.bbox_regressor(pixel_values, input_ids)

    def predict(self, image, text, device="cuda"):
        self.to(device)
        self.eval()
        inputs = self.processor(
            text=text,
            images=image,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=64,
        )
        # Match the model's dtype rather than hard-coding half precision.
        pixel_values = inputs["pixel_values"].to(device=device, dtype=self.dtype)
        input_ids = inputs["input_ids"].to(device)
        with torch.no_grad():
            pred_bbox = self.forward(pixel_values, input_ids)
        return pred_bbox[0].cpu().numpy().tolist()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        config = kwargs.pop("config", None)
        if config is None:
            config = ExplainerConfig.from_pretrained(pretrained_model_name_or_path)
        model = cls(config)
        # The checkpoint stores the two sub-modules under separate keys.
        checkpoint_path = hf_hub_download(
            repo_id=pretrained_model_name_or_path,
            filename="pytorch_model.bin",
        )
        checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
        model.siglip_model.load_state_dict(checkpoint["siglip_model"])
        model.bbox_regressor.load_state_dict(checkpoint["bbox_regressor"])
        return model
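

# ---------------------------------------------------------------------------
# Minimal usage sketch. Assumptions (not part of the module above): the repo id
# "your-username/explainer" and the image path "example.jpg" are placeholders,
# and the hub repo contains a "pytorch_model.bin" holding the "siglip_model"
# and "bbox_regressor" state dicts that from_pretrained expects.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from PIL import Image

    model = Explainer.from_pretrained("your-username/explainer")  # placeholder repo id
    image = Image.open("example.jpg").convert("RGB")              # placeholder image

    # predict() returns [x1, y1, x2, y2] from the two corner heads.
    bbox = model.predict(image, "the dog on the left", device="cuda")
    print(bbox)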