AlessioChenn committed on
Commit 6a27d7e · verified · 1 parent: f2d0a8e

Update explainer.py

Files changed (1): explainer.py (+14 -4)
explainer.py CHANGED
@@ -22,6 +22,7 @@ class SigLIPBBoxRegressor(nn.Module):
         text_dim = self.siglip.text_model.config.hidden_size
         if giant: text_dim = 1536
 
+        # Feature fusion layers
         self.vision_projector = nn.Sequential(
             nn.Linear(vision_dim, hidden_dim),
             nn.ReLU(),
@@ -32,6 +33,8 @@ class SigLIPBBoxRegressor(nn.Module):
             nn.ReLU(),
             nn.Dropout(0.1)
         )
+
+        # Cross-modal fusion
         self.fusion_layer = nn.Sequential(
             nn.Linear(hidden_dim*2, hidden_dim),
             nn.ReLU(),
@@ -46,7 +49,7 @@ class SigLIPBBoxRegressor(nn.Module):
             nn.Dropout(0.1),
             nn.Linear(256, 128),
             nn.ReLU(),
-            nn.Linear(128, 2),
+            nn.Linear(128, 2),  # (x1, y1)
         )
         self.bottomright_regressor = nn.Sequential(
             nn.Linear(hidden_dim//2, 256),
@@ -54,22 +57,30 @@ class SigLIPBBoxRegressor(nn.Module):
             nn.Dropout(0.1),
             nn.Linear(256, 128),
             nn.ReLU(),
-            nn.Linear(128, 2),
+            nn.Linear(128, 2),  # (x2, y2)
         )
 
     def forward(self, pixel_values, input_ids):
         with torch.no_grad():
             outputs = self.siglip(pixel_values=pixel_values, input_ids=input_ids, return_dict=True)
+
+        # Extract pooled features
         vision_features = outputs.image_embeds.float()
         text_features = outputs.text_embeds.float()
+
+        # Project features
 
         vision_proj = self.vision_projector(vision_features)
         text_proj = self.text_projector(text_features)
+
+        # Fuse modalities
         fused = torch.cat([vision_proj, text_proj], dim=1)
         fused_features = self.fusion_layer(fused)
-
+
+        # Predict bbox
         topleft_pred = self.topleft_regressor(fused_features)
         bottomright_pred = self.bottomright_regressor(fused_features)
+
         return torch.cat([topleft_pred, bottomright_pred], dim=1)
 
 class Explainer(PreTrainedModel):
@@ -104,7 +115,6 @@ class Explainer(PreTrainedModel):
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
-        # Load config automatically (HF passes `config` here sometimes)
         config = kwargs.pop("config", None)
         if config is None:
             config = PretrainedConfig.from_pretrained(pretrained_model_name_or_path)
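
For orientation, here is a minimal inference sketch of what the updated forward pass returns. It is not part of the commit; the SigLIP checkpoint name, the availability of default constructor arguments for SigLIPBBoxRegressor, and the processor usage are assumptions.

import torch
from PIL import Image
from transformers import AutoProcessor

from explainer import SigLIPBBoxRegressor

# Assumed base checkpoint; the model may wrap a different SigLIP variant.
processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
model = SigLIPBBoxRegressor()  # hypothetical: assumes usable default args
model.eval()

image = Image.open("example.jpg").convert("RGB")  # any RGB image
# SigLIP text towers expect max-length padding.
inputs = processor(text=["the object to localize"], images=image,
                   padding="max_length", return_tensors="pt")

with torch.no_grad():
    # forward() concatenates the two 2-dim heads into shape (batch, 4)
    boxes = model(pixel_values=inputs["pixel_values"],
                  input_ids=inputs["input_ids"])

print(boxes.shape)  # torch.Size([1, 4]): (x1, y1) then (x2, y2)

The design choice the new comments document: instead of a single 4-dim head, the model keeps separate MLP branches for the top-left and bottom-right corners and concatenates their outputs at the end of forward().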