import math

import torch


class SiameseProjector(torch.nn.Module):
    """Projector with a localisation branch/head and a contrastive branch/head.

    Token features go through a three-layer localisation MLP. That MLP's output
    is (a) reshaped into a square spatial grid and fed to a single-channel
    convolutional probe (the localisation head), and (b) fed to a second
    three-layer MLP (the contrastive branch/head).

    Args:
        inner_features: Width of every linear layer and the channel count fed
            to the probe. Required; the ``None`` default is kept only for
            signature compatibility.
        act_layer: Activation module class. It is instantiated once and the
            same instance is reused between all layers of both branches.
        probe_kernel_size: Kernel size of the localisation probe (default 3,
            matching the original hard-coded value).
        probe_padding: Zero padding of the localisation probe. The default 0
            matches the original behaviour: with Conv2d's default stride and
            dilation, each spatial side shrinks by ``probe_kernel_size - 1``.
            Use ``(probe_kernel_size - 1) // 2`` with an odd kernel to keep the
            token-grid size.

    Raises:
        TypeError: if ``inner_features`` is ``None``.
    """

    def __init__(self, inner_features=None, act_layer=torch.nn.GELU, *,
                 probe_kernel_size=3, probe_padding=0):
        super().__init__()
        if inner_features is None:
            # Previously this surfaced as an opaque error from nn.Linear.
            raise TypeError("SiameseProjector requires an integer inner_features")
        self.inner_features = inner_features
        self.act_fcn = act_layer()

        # Localisation branch.
        self.input = torch.nn.Linear(in_features=inner_features, out_features=inner_features)
        self.projection = torch.nn.Linear(in_features=inner_features, out_features=inner_features)
        self.output = torch.nn.Linear(in_features=inner_features, out_features=inner_features)

        # Contrastive branch.
        self.input_con = torch.nn.Linear(in_features=inner_features, out_features=inner_features)
        self.projection_con = torch.nn.Linear(in_features=inner_features, out_features=inner_features)
        self.output_con = torch.nn.Linear(in_features=inner_features, out_features=inner_features)

        # Localisation head: one output channel per spatial location.
        self.probe = torch.nn.Conv2d(
            in_channels=inner_features,
            out_channels=1,
            kernel_size=probe_kernel_size,
            padding=probe_padding,
        )

    def _tokens_to_grid(self, tokens):
        """Reshape (batch, num_tokens, inner_features) to (batch, inner_features, side, side).

        Raises:
            ValueError: if ``num_tokens`` is not a perfect square.
        """
        batch, num_tokens = tokens.shape[0], tokens.shape[1]
        # Exact integer square root; avoids float rounding of ``n ** 0.5``.
        side = math.isqrt(num_tokens)
        if side * side != num_tokens:
            raise ValueError(
                f"number of tokens ({num_tokens}) must be a perfect square "
                "to form the localisation grid"
            )
        # NOTE(review): assumes tokens are stored in row-major grid order — confirm with caller.
        # The permuted view is non-contiguous, so reshape (not view) is required.
        return tokens.permute(0, 2, 1).reshape(batch, self.inner_features, side, side)

    def _localise(self, x):
        """Run the localisation MLP and probe; return (token features, seg map)."""
        x = self.act_fcn(self.input(x))
        x = self.act_fcn(self.projection(x))
        x = self.output(x)
        seg = self.probe(self._tokens_to_grid(x))
        return x, seg

    def forward(self, x):
        """Compute contrastive features and the localisation map.

        Args:
            x: Tensor of shape (batch, num_tokens, inner_features) with
                ``num_tokens`` a perfect square (required by the grid reshape).

        Returns:
            ``(feat, seg)``: ``feat`` has the same shape as ``x``; ``seg`` has
            shape (batch, 1, H, W) where H = W depends on the probe's kernel
            size and padding.
        """
        loc, seg = self._localise(x)
        # NOTE(review): the contrastive branch consumes the localisation
        # branch's output rather than the raw input ``x``; preserved as-is —
        # confirm this chaining is intended.
        y = self.act_fcn(self.input_con(loc))
        y = self.act_fcn(self.projection_con(y))
        feat = self.output_con(y)
        return feat, seg

    def forward_segmentation(self, x):
        """Compute only the localisation map (same ``seg`` as ``forward``)."""
        _, seg = self._localise(x)
        return seg