File size: 2,026 Bytes
7e08bf1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import torch

class SiameseProjector(torch.nn.Module):
    """Token projector with a localisation head and a contrastive head.

    The input is first passed through a three-layer MLP (the localisation
    branch).  Its output feeds two consumers:

    * a 3x3 convolution (``probe``) applied to the tokens rearranged into a
      square 2-D grid, yielding a single-channel map;
    * a second three-layer MLP (the contrastive branch), yielding per-token
      features.

    All linear layers map ``inner_features -> inner_features``.

    Args:
        inner_features: Width of every linear layer and number of input
            channels of the probe convolution.  Must be provided; the
            ``None`` default only exists for signature compatibility.
        act_layer: Zero-argument callable returning the activation module.
            A single instance is created and shared by all hidden layers.

    Raises:
        TypeError: If ``inner_features`` is left as ``None``.
    """

    def __init__(self, inner_features = None, act_layer = torch.nn.GELU):
        super().__init__()

        # Fail early with a readable message instead of letting
        # torch.nn.Linear choke on a None dimension.
        if inner_features is None:
            raise TypeError(
                "SiameseProjector requires an integer 'inner_features'; got None"
            )

        self.inner_features = inner_features
        self.act_fcn = act_layer()

        # NOTE: attribute names and registration order are kept stable so
        # that existing checkpoints and seeded initialisation still match.

        # Localisation branch.
        self.input = torch.nn.Linear(in_features=inner_features, out_features=inner_features)
        self.projection = torch.nn.Linear(in_features=inner_features, out_features=inner_features)
        self.output = torch.nn.Linear(in_features=inner_features, out_features=inner_features)

        # Contrastive branch.
        self.input_con = torch.nn.Linear(in_features=inner_features, out_features=inner_features)
        self.projection_con = torch.nn.Linear(in_features=inner_features, out_features=inner_features)
        self.output_con = torch.nn.Linear(in_features=inner_features, out_features=inner_features)

        # Localisation head (no padding: the output map is 2 pixels smaller
        # than the token grid along each spatial axis).
        self.probe = torch.nn.Conv2d(in_channels=inner_features, out_channels=1, kernel_size=3)

    def _localise(self, x):
        """Run the localisation MLP (Linear-act-Linear-act-Linear) on ``x``."""
        x = self.act_fcn(self.input(x))
        x = self.act_fcn(self.projection(x))
        return self.output(x)

    def _segment(self, x):
        """Apply the probe convolution to ``x`` viewed as a square token grid.

        ``x`` is indexed as ``(batch, tokens, inner_features)``.  The token
        axis is folded into a ``side x side`` grid with
        ``side = int(tokens ** 0.5)``; if ``tokens`` is not a perfect square
        the reshape fails.
        """
        batch, tokens = x.shape[0], x.shape[1]
        side = int(tokens ** 0.5)
        # (batch, tokens, channels) -> (batch, channels, side, side)
        grid = x.permute(0, 2, 1).reshape(batch, self.inner_features, side, side)
        return self.probe(grid)

    def forward(self, x):
        """Compute contrastive features and the localisation map.

        Args:
            x: Tensor indexed as ``(batch, tokens, inner_features)`` where
                ``tokens`` is a perfect square.

        Returns:
            A tuple ``(feat, seg)``: ``feat`` has the same shape as ``x``;
            ``seg`` has shape ``(batch, 1, side - 2, side - 2)`` with
            ``side = int(tokens ** 0.5)``.
        """
        # Localisation branch.
        x = self._localise(x)

        # Localisation head.
        seg = self._segment(x)

        # Contrastive branch (consumes the localisation branch output).
        y = self.act_fcn(self.input_con(x))
        y = self.act_fcn(self.projection_con(y))

        # Contrastive head.
        feat = self.output_con(y)

        return feat, seg

    def forward_segmentation(self, x):
        """Compute only the localisation map, skipping the contrastive branch.

        Produces the same ``seg`` tensor as :meth:`forward` for the same
        input, without evaluating the contrastive layers.
        """
        return self._segment(self._localise(x))