import torch
import torchvision
import torch.nn as nn
import numpy as np

from utils.hrnet import hrnet_w32


class Encoder(nn.Module):
    """Backbone wrapper that selects between a Swin-B and an HRNet encoder.

    Args:
        encoder: Backbone name, either ``'swin'`` or ``'hrnet'``.
        pretrained: Whether the backbone is built with pretrained weights.
            Applies to both backbones.

    Raises:
        NotImplementedError: If ``encoder`` is not a supported name.
    """

    def __init__(self, encoder='hrnet', pretrained=True):
        super(Encoder, self).__init__()
        if encoder == 'swin':
            # Swin Transformer encoder. Honour `pretrained` instead of always
            # loading the default weights, so pretrained=False builds a
            # randomly initialised model and triggers no weight download.
            weights = 'DEFAULT' if pretrained else None
            self.encoder = torchvision.models.swin_b(weights=weights)
            # Swap the torchvision head for a GELU so the module emits
            # activated features rather than the head's original output.
            self.encoder.head = nn.GELU()
        elif encoder == 'hrnet':
            # HRNet encoder provided by the project's utils.hrnet module.
            self.encoder = hrnet_w32(pretrained=pretrained)
        else:
            raise NotImplementedError('Encoder not implemented')

    def forward(self, x):
        """Run the selected backbone on ``x`` and return its output."""
        out = self.encoder(x)
        return out


class Self_Attn(nn.Module):
    """Attention layer built from three 1x1 ``Conv1d`` projections.

    Args:
        in_dim: Input channels of the query/key/value projections.
        out_dim: Output channels of the query/key/value projections.
    """

    def __init__(self, in_dim, out_dim):
        super(Self_Attn, self).__init__()
        self.channel_in = in_dim

        self.query_conv = nn.Conv1d(in_channels=in_dim, out_channels=out_dim, kernel_size=1)
        self.key_conv = nn.Conv1d(in_channels=in_dim, out_channels=out_dim, kernel_size=1)
        self.value_conv = nn.Conv1d(in_channels=in_dim, out_channels=out_dim, kernel_size=1)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v):
        """Attend from ``q`` to ``k`` and apply the result to ``v``.

        Every input is a 3-D tensor that is permuted with ``(0, 2, 1)`` before
        its projection, so the last dimension of ``q``, ``k`` and ``v`` must
        equal ``in_dim``.

        Args:
            q: Query tensor of shape ``(B, C, in_dim)``.
            k: Key tensor of shape ``(B, C, in_dim)``.
            v: Value tensor of shape ``(B, Cv, in_dim)``.

        Returns:
            Tensor of shape ``(B, C, in_dim)`` (the shape of ``q``). The final
            ``view`` requires ``out_dim * Cv == C * in_dim``.
        """
        batchsize, C, height = q.size()

        # (B, C, in_dim) -> (B, in_dim, C) -> (B, out_dim, C)
        proj_query = self.query_conv(q.permute(0, 2, 1))
        proj_key = self.key_conv(k.permute(0, 2, 1))

        # energy / attention: (B, out_dim, out_dim), normalised over last dim.
        energy = torch.bmm(proj_query, proj_key.permute(0, 2, 1))
        attention = self.softmax(energy)

        # (B, Cv, in_dim) -> (B, in_dim, Cv) -> (B, out_dim, Cv)
        proj_value = self.value_conv(v.permute(0, 2, 1))

        # (B, out_dim, Cv), then reshaped back to the query's shape.
        out = torch.bmm(attention, proj_value)
        out = out.view(batchsize, C, height)

        # Scaling is applied to the attended output (after the softmax),
        # not to the attention logits.
        out = out / np.sqrt(self.channel_in)
        return out


class Cross_Att(nn.Module):
    """Bidirectional cross attention between two feature tensors.

    Args:
        in_dim: Feature size fed to both attention modules and the layer norm.
        out_dim: Projection size of both attention modules.
    """

    def __init__(self, in_dim, out_dim):
        super(Cross_Att, self).__init__()
        self.cross_attn_1 = Self_Attn(in_dim, out_dim)
        self.cross_attn_2 = Self_Attn(in_dim, out_dim)
        # Normalises over the trailing (1, in_dim) dimensions of the output.
        self.layer_norm = nn.LayerNorm([1, in_dim])

    def forward(self, sem_seg, part_seg):
        """Fuse ``sem_seg`` and ``part_seg`` with attention in both directions.

        Args:
            sem_seg: First feature tensor; queries the second one.
            part_seg: Second feature tensor; queries the first one.

        Returns:
            Layer-normalised element-wise product of the two attention outputs.
        """
        cross1 = self.cross_attn_1(sem_seg, part_seg, part_seg)
        # The reverse direction has its own module; routing both directions
        # through cross_attn_1 would leave cross_attn_2's parameters unused.
        cross2 = self.cross_attn_2(part_seg, sem_seg, sem_seg)

        out = cross1 * cross2
        out = self.layer_norm(out)
        return out


class Decoder(nn.Module):
    """Transposed-convolution upsampling head ending in a channel softmax.

    Each stage is ``ConvTranspose2d(kernel_size=3, stride=2, padding=1,
    output_padding=1)`` followed by ``BatchNorm2d``, which doubles the spatial
    size. Stages are separated by ``ReLU`` and the last one is followed by
    ``Softmax`` over dim 1. ``'swin'`` uses three stages, ``'hrnet'`` two.

    Args:
        in_dim: Number of input channels.
        out_dim: Number of output channels.
        encoder: Name of the encoder this decoder is paired with,
            ``'swin'`` or ``'hrnet'``.

    Raises:
        NotImplementedError: If ``encoder`` is not a supported name.
    """

    # Number of 2x upsampling stages per encoder type.
    _NUM_STAGES = {'swin': 3, 'hrnet': 2}

    def __init__(self, in_dim, out_dim, encoder='hrnet'):
        super(Decoder, self).__init__()
        self.out_dim = out_dim

        num_stages = self._NUM_STAGES.get(encoder)
        if num_stages is None:
            raise NotImplementedError('Decoder not implemented')

        layers = []
        channels = in_dim
        for stage in range(num_stages):
            layers.append(
                nn.ConvTranspose2d(channels, out_dim, kernel_size=3, stride=2,
                                   padding=1, output_padding=1)
            )
            layers.append(nn.BatchNorm2d(out_dim))
            # ReLU only between stages; the final stage feeds the softmax.
            if stage < num_stages - 1:
                layers.append(nn.ReLU())
            channels = out_dim
        layers.append(nn.Softmax(1))

        # Module order (and therefore state_dict indices) matches an explicit
        # nn.Sequential(conv, bn, relu, conv, bn, ..., softmax) definition.
        self.upsample = nn.Sequential(*layers)

    def forward(self, x):
        """Upsample ``x`` and return per-channel softmax maps."""
        out = self.upsample(x)
        return out


class Classifier(nn.Module):
    """Two-layer MLP with a sigmoid output.

    Args:
        in_dim: Size of the input feature vector.
        out_dim: Number of sigmoid outputs per sample.
    """

    def __init__(self, in_dim, out_dim=6890):
        super(Classifier, self).__init__()
        self.out_dim = out_dim

        self.classifier = nn.Sequential(
            nn.Linear(in_dim, 4096, bias=True),
            nn.ReLU(),
            nn.Linear(4096, out_dim, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return sigmoid scores for ``x`` flattened to ``(-1, out_dim)``."""
        out = self.classifier(x)
        return out.reshape(-1, self.out_dim)