import torch


class DoubleCrossAttentionFusion(torch.nn.Module):
    def __init__(self, hidden_dim=768, num_heads=8, dropout=0.1):
        super().__init__()
        # 1. Per-modality normalization.
        self.norm_rgb = torch.nn.LayerNorm(hidden_dim)
        self.norm_depth = torch.nn.LayerNorm(hidden_dim)
        # 2. Cross-attention.
        self.cross_attn_depth = torch.nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.cross_attn_rgb = torch.nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        # 3. Mixing.
        self.mixer = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim * 2, hidden_dim),
            torch.nn.GELU(),
            torch.nn.Dropout(dropout),
        )
        # 4. Output normalization.
        self.out_norm = torch.nn.LayerNorm(hidden_dim)

    def forward(self, rgb_features, depth_features):
        # 1. Normalize inputs.
        rgb = self.norm_rgb(rgb_features)
        depth = self.norm_depth(depth_features)
        # 2a. Cross-attention (depth -> rgb): depth queries attend to RGB tokens.
        attn_out_depth, _ = self.cross_attn_depth(
            query=depth, key=rgb, value=rgb, need_weights=False
        )
        # 2b. Cross-attention (rgb -> depth): RGB queries attend to depth tokens.
        attn_out_rgb, _ = self.cross_attn_rgb(
            query=rgb, key=depth, value=depth, need_weights=False
        )
        # 3a. Residual connections around each attention branch.
        depth_attn = depth + attn_out_depth
        rgb_attn = rgb + attn_out_rgb
        # 3b. Mixing: concatenate both streams and project back to hidden_dim.
        fused = self.mixer(torch.cat([depth_attn, rgb_attn], dim=-1))
        # 4. Output normalization.
        return self.out_norm(fused)
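

# Minimal usage sketch to show the expected tensor shapes. The batch size,
# token count, and dimensions below are illustrative assumptions, not values
# taken from the original source.
if __name__ == "__main__":
    fusion = DoubleCrossAttentionFusion(hidden_dim=768, num_heads=8)
    rgb_tokens = torch.randn(2, 196, 768)    # (batch, tokens, hidden_dim)
    depth_tokens = torch.randn(2, 196, 768)  # same shape as the RGB stream
    fused = fusion(rgb_tokens, depth_tokens)
    print(fused.shape)  # torch.Size([2, 196, 768])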