|
|
from transformers import PretrainedConfig, PreTrainedModel, AutoProcessor, SiglipModel |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
class ExplainerConfig(PretrainedConfig):
    """Configuration object for the ``Explainer`` model.

    Args:
        base_model_name: Hugging Face identifier of the SigLIP checkpoint
            (``Explainer`` loads both its model and processor from this name).
        hidden_dim: Stored on the config; intended as the projection width of
            the bounding-box regression head.
        giant: Stored on the config; intended to select the 1536-wide text
            input used for the "giant" backbone.
        **kwargs: Passed through unchanged to ``PretrainedConfig``.
    """

    model_type = "explainer"

    def __init__(
        self,
        base_model_name="google/siglip2-giant-opt-patch16-384",
        hidden_dim=768,
        giant=True,
        **kwargs,
    ):
        # Custom fields are assigned (in this order) before the base-class
        # initializer consumes the remaining keyword arguments.
        self.base_model_name, self.hidden_dim, self.giant = (
            base_model_name,
            hidden_dim,
            giant,
        )
        super().__init__(**kwargs)
|
|
|
|
|
class SigLIPBBoxRegressor(nn.Module):
    """Regress two 2-value outputs from SigLIP image and text embeddings.

    The wrapped SigLIP model is called without gradient tracking. Its
    ``image_embeds`` and ``text_embeds`` are cast to float32, projected to
    ``hidden_dim``, concatenated, passed through a fusion MLP that narrows to
    ``hidden_dim // 2``, and finally fed to two independent heads
    (``topleft_regressor`` and ``bottomright_regressor``) whose outputs are
    concatenated.

    Args:
        siglip_model: Module exposing ``vision_model.config.hidden_size`` and
            ``text_model.config.hidden_size``; when called with
            ``pixel_values``, ``input_ids`` and ``return_dict=True`` it must
            return an object with ``image_embeds`` and ``text_embeds``.
        hidden_dim: Output width of the two projection layers.
        giant: If true, the text projection expects 1536 input features
            instead of ``text_model.config.hidden_size``.
    """

    def __init__(self, siglip_model, hidden_dim=768, giant=True):
        super().__init__()
        self.siglip = siglip_model

        in_vision = self.siglip.vision_model.config.hidden_size
        text_hidden = self.siglip.text_model.config.hidden_size
        # NOTE(review): hard-coded width for the "giant" checkpoint; presumably
        # its text embeddings are projected to 1536 rather than hidden_size —
        # confirm against the backbone config before changing.
        in_text = 1536 if giant else text_hidden

        # Attribute names and Sequential layouts define the checkpoint keys;
        # construction order fixes the random-initialisation order.
        self.vision_projector = self._projector(in_vision, hidden_dim)
        self.text_projector = self._projector(in_text, hidden_dim)
        self.fusion_layer = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
        )
        self.topleft_regressor = self._corner_head(hidden_dim // 2)
        self.bottomright_regressor = self._corner_head(hidden_dim // 2)

    @staticmethod
    def _projector(in_features, out_features):
        """Build a Linear -> ReLU -> Dropout(0.1) projection."""
        return nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.ReLU(),
            nn.Dropout(0.1),
        )

    @staticmethod
    def _corner_head(in_features):
        """Build an MLP head mapping fused features to two output values."""
        return nn.Sequential(
            nn.Linear(in_features, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 2),
        )

    def forward(self, pixel_values, input_ids):
        """Return the top-left head output followed by the bottom-right one.

        Both head outputs are concatenated along dim 1 (four values per row
        when the embeddings are 2-D, which is assumed — TODO confirm).
        """
        with torch.no_grad():
            siglip_out = self.siglip(
                pixel_values=pixel_values, input_ids=input_ids, return_dict=True
            )

        # `.float()` makes the head run in float32 whatever dtype the
        # backbone produced.
        joint = torch.cat(
            [
                self.vision_projector(siglip_out.image_embeds.float()),
                self.text_projector(siglip_out.text_embeds.float()),
            ],
            dim=1,
        )
        shared = self.fusion_layer(joint)
        corners = [self.topleft_regressor(shared), self.bottomright_regressor(shared)]
        return torch.cat(corners, dim=1)
|
|
|
|
|
class Explainer(PreTrainedModel):
    """SigLIP backbone plus a bounding-box regression head.

    The backbone and its processor are both loaded from
    ``config.base_model_name``; the head is a ``SigLIPBBoxRegressor`` built
    around that backbone.
    """

    config_class = ExplainerConfig

    def __init__(self, config):
        super().__init__(config)
        self.siglip_model = SiglipModel.from_pretrained(config.base_model_name)
        # Build the head from the config so non-default hidden_dim / giant
        # settings are honoured. The fallbacks match SigLIPBBoxRegressor's own
        # defaults, so a config lacking these fields behaves as before.
        self.bbox_regressor = SigLIPBBoxRegressor(
            self.siglip_model,
            hidden_dim=getattr(config, "hidden_dim", 768),
            giant=getattr(config, "giant", True),
        )
        self.processor = AutoProcessor.from_pretrained(config.base_model_name, use_fast=True)

    def forward(self, pixel_values=None, input_ids=None):
        """Delegate to the regression head; see ``SigLIPBBoxRegressor.forward``."""
        return self.bbox_regressor(pixel_values, input_ids)

    def predict(self, image, text, device="cuda"):
        """Run inference on one image/text input and return the first output row.

        Moves the model to ``device`` and switches it to eval mode (both
        persist after the call).

        Args:
            image: Image input accepted by ``self.processor``.
            text: Text input accepted by ``self.processor``; padded/truncated
                to 64 tokens.
            device: Device for the model and the processed inputs.

        Returns:
            The first row of the model output as a Python list of floats.
        """
        self.to(device)
        self.eval()
        inputs = self.processor(
            text=text,
            images=image,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=64,
        )
        # Cast to the backbone's actual parameter dtype instead of forcing
        # fp16: a float32 backbone would otherwise fail with a dtype mismatch,
        # and an fp16 backbone still gets fp16 input.
        backbone_dtype = next(self.siglip_model.parameters()).dtype
        pixel_values = inputs["pixel_values"].to(device=device, dtype=backbone_dtype)
        input_ids = inputs["input_ids"].to(device)
        with torch.no_grad():
            pred_bbox = self.forward(pixel_values, input_ids)
        return pred_bbox[0].cpu().numpy().tolist()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Construct an ``Explainer`` and load its saved ``pytorch_model.bin``.

        Args:
            pretrained_model_name_or_path: Hub repo id, or a local directory
                containing ``pytorch_model.bin`` (and ``config.json`` unless
                ``config`` is given).
            **kwargs: Only ``config`` (a ready config instance) is used; other
                keyword arguments are ignored.

        Returns:
            The model with the checkpoint's ``siglip_model`` and
            ``bbox_regressor`` state dicts loaded.
        """
        import os

        config = kwargs.pop("config", None)
        if config is None:
            # Use the model's own config class so ExplainerConfig defaults
            # apply to any field missing from config.json.
            config = cls.config_class.from_pretrained(pretrained_model_name_or_path)

        model = cls(config)

        local_checkpoint = os.path.join(str(pretrained_model_name_or_path), "pytorch_model.bin")
        if os.path.isfile(local_checkpoint):
            checkpoint_path = local_checkpoint
        else:
            checkpoint_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename="pytorch_model.bin",
            )

        # SECURITY NOTE(review): torch.load unpickles the file, which can run
        # arbitrary code. Only load checkpoints from trusted repositories
        # (consider weights_only=True on torch versions that support it).
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        model.siglip_model.load_state_dict(checkpoint["siglip_model"])
        model.bbox_regressor.load_state_dict(checkpoint["bbox_regressor"])
        return model
|
|
|