Spaces:
Running
Running
| import torch | |
class SiameseProjector(torch.nn.Module):
    """Two-branch MLP projector over token sequences with a convolutional probe.

    Every linear layer maps ``inner_features -> inner_features``. The
    localisation branch output is laid out as a square spatial grid and fed to
    ``probe``, a single-output-channel 3x3 convolution. The contrastive branch
    produces per-token features.

    Args:
        inner_features: Width of every linear layer and the number of input
            channels of ``probe``. Required despite the ``None`` default.
        act_layer: Activation class. It is instantiated once, and that single
            instance is reused at every activation site in both branches.

    Raises:
        TypeError: If ``inner_features`` is not provided.
    """

    def __init__(self, inner_features=None, act_layer=torch.nn.GELU):
        super().__init__()
        if inner_features is None:
            # Fail early with a clear message instead of an opaque error
            # from inside torch.nn.Linear (same exception type as before).
            raise TypeError("SiameseProjector requires `inner_features` (an int).")
        self.inner_features = inner_features
        self.act_fcn = act_layer()
        # Submodule attribute names double as state_dict keys; keep them
        # stable so existing checkpoints still load.
        # Localisation branch.
        self.input = torch.nn.Linear(in_features=inner_features, out_features=inner_features)
        self.projection = torch.nn.Linear(in_features=inner_features, out_features=inner_features)
        self.output = torch.nn.Linear(in_features=inner_features, out_features=inner_features)
        # Contrastive branch.
        self.input_con = torch.nn.Linear(in_features=inner_features, out_features=inner_features)
        self.projection_con = torch.nn.Linear(in_features=inner_features, out_features=inner_features)
        self.output_con = torch.nn.Linear(in_features=inner_features, out_features=inner_features)
        # Localisation head. No padding, so an (s, s) grid produces an
        # (s - 2, s - 2) map.
        self.probe = torch.nn.Conv2d(in_channels=inner_features, out_channels=1, kernel_size=3)

    def _localise(self, x):
        """Run the localisation MLP: Linear -> act -> Linear -> act -> Linear."""
        x = self.act_fcn(self.input(x))
        x = self.act_fcn(self.projection(x))
        return self.output(x)

    def _segment(self, x):
        """Reshape a token sequence into a square grid and apply ``probe``.

        Args:
            x: Rank-3 tensor (batch, tokens, inner_features) whose token count
                is a perfect square.

        Returns:
            Tensor of shape (batch, 1, s - 2, s - 2), where s * s == tokens.

        Raises:
            RuntimeError: If the token count is not a perfect square.
        """
        batch, num_tokens = x.shape[0], x.shape[1]
        side = int(round(num_tokens ** 0.5))
        if side * side != num_tokens:
            raise RuntimeError(
                f"Expected a square number of tokens to form a spatial grid, got {num_tokens}."
            )
        # (B, N, C) -> (B, C, N) -> (B, C, s, s): channels-first for Conv2d.
        grid = x.permute(0, 2, 1).reshape(batch, self.inner_features, side, side)
        return self.probe(grid)

    def forward(self, x):
        """Compute contrastive features and the localisation map.

        Args:
            x: Rank-3 tensor (batch, tokens, inner_features) with a
                perfect-square token count.

        Returns:
            ``(feat, seg)``:
                - ``feat`` has the same shape as ``x``.
                - ``seg`` has shape (batch, 1, s - 2, s - 2), with s = sqrt(tokens).

        Raises:
            RuntimeError: If the token count is not a perfect square.
        """
        loc = self._localise(x)
        seg = self._segment(loc)
        # NOTE(review): the contrastive branch consumes the *localisation
        # output*, not the raw input. If the two branches were meant to run in
        # parallel on `x`, this is a bug. It is kept as-is because changing it
        # would alter the behaviour of trained weights; confirm intent.
        y = self.act_fcn(self.input_con(loc))
        y = self.act_fcn(self.projection_con(y))
        # Contrastive head.
        feat = self.output_con(y)
        return feat, seg

    def forward_segmentation(self, x):
        """Compute only the localisation map (the ``seg`` output of ``forward``).

        Args:
            x: Rank-3 tensor (batch, tokens, inner_features) with a
                perfect-square token count.

        Returns:
            Tensor of shape (batch, 1, s - 2, s - 2), with s = sqrt(tokens).

        Raises:
            RuntimeError: If the token count is not a perfect square.
        """
        return self._segment(self._localise(x))