# RADAR-demo/models/projector.py
import torch
class SiameseProjector(torch.nn.Module):
    """Two-branch (Siamese) projector: a localisation branch with a convolutional
    segmentation probe, and a contrastive branch producing projection features."""

    def __init__(self, inner_features=None, act_layer=torch.nn.GELU):
        super().__init__()
        self.inner_features = inner_features
        self.act_fcn = act_layer()
        # Localisation branch.
        self.input = torch.nn.Linear(in_features=inner_features, out_features=inner_features)
        self.projection = torch.nn.Linear(in_features=inner_features, out_features=inner_features)
        self.output = torch.nn.Linear(in_features=inner_features, out_features=inner_features)
        # Contrastive branch.
        self.input_con = torch.nn.Linear(in_features=inner_features, out_features=inner_features)
        self.projection_con = torch.nn.Linear(in_features=inner_features, out_features=inner_features)
        self.output_con = torch.nn.Linear(in_features=inner_features, out_features=inner_features)
        # Localisation head: 3x3 conv (no padding) mapping token features to a
        # single-channel localisation map.
        self.probe = torch.nn.Conv2d(in_channels=inner_features, out_channels=1, kernel_size=3)
    def forward(self, x):
        """Return (contrastive features, localisation map) for token features x of
        shape (B, N, inner_features), where N is assumed to be a perfect square."""
        # Localisation branch.
        x = self.input(x)
        x = self.act_fcn(x)
        x = self.projection(x)
        x = self.act_fcn(x)
        x = self.output(x)
        # Localisation head: reshape the token sequence (B, N, C) into a square
        # feature map (B, C, sqrt(N), sqrt(N)) and probe it with the 3x3 conv.
        side = int(x.shape[1] ** 0.5)
        seg = self.probe(x.permute(0, 2, 1).reshape(x.shape[0], self.inner_features, side, side))
        # Contrastive branch (fed with the localisation-branch output).
        y = self.input_con(x)
        y = self.act_fcn(y)
        y = self.projection_con(y)
        y = self.act_fcn(y)
        # Contrastive head.
        feat = self.output_con(y)
        return feat, seg
    def forward_segmentation(self, x):
        """Return only the localisation map, skipping the contrastive branch."""
        # Localisation branch.
        x = self.input(x)
        x = self.act_fcn(x)
        x = self.projection(x)
        x = self.act_fcn(x)
        x = self.output(x)
        # Localisation head (same reshape as in forward()).
        side = int(x.shape[1] ** 0.5)
        seg = self.probe(x.permute(0, 2, 1).reshape(x.shape[0], self.inner_features, side, side))
        return seg
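

# Minimal usage sketch (not part of the original upload): the batch size, token
# count, and feature width below are illustrative assumptions; in practice the
# token features would presumably come from the RADAR backbone encoder.
if __name__ == "__main__":
    projector = SiameseProjector(inner_features=768)
    tokens = torch.randn(2, 196, 768)  # (B, N, C) with N = 14 * 14, a perfect square
    feat, seg = projector(tokens)
    # feat: (2, 196, 768) contrastive features; seg: (2, 1, 12, 12) localisation map
    # (14 - 2 = 12 per side, since the 3x3 probe uses no padding).
    print(feat.shape, seg.shape)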