import torch.nn as nn import torch from math import log2 import torch.nn.functional as F from torch import autograd class DCDiscriminator(nn.Module): ''' DC Discriminator class. Args: in_dim (int): input dimension n_feat (int): features of final hidden layer img_size (int): input image size ''' def __init__(self, in_dim=1, out_dim=1, n_feat=512, img_size=256, last_bias=False): super().__init__() self.in_dim = in_dim self.out_dim = out_dim n_layers = int(log2(img_size) - 2) self.blocks = nn.ModuleList( [nn.Conv2d( in_dim, int(n_feat / (2 ** (n_layers - 1))), 4, 2, 1, bias=False)] + [nn.Conv2d( int(n_feat / (2 ** (n_layers - i))), int(n_feat / (2 ** (n_layers - 1 - i))), 4, 2, 1, bias=False) for i in range(1, n_layers)]) self.conv_out = nn.Conv2d(n_feat, out_dim, 4, 1, 0, bias=last_bias) self.actvn = nn.LeakyReLU(0.2, inplace=True) def forward(self, x): batch_size = x.shape[0] if x.shape[1] != self.in_dim: import ipdb; ipdb.set_trace() x = x[:, :self.in_dim] for layer in self.blocks: x = self.actvn(layer(x)) out = self.conv_out(x) out = out.reshape(batch_size, self.out_dim) return out # class ADADiscriminator(DCDiscriminator): # def __init__(self, aug, aug_p, **kwargs): # super().__init__(**kwargs) # self.aug = build_from_config(aug) # self.aug.p.copy_(torch.tensor(aug_p, dtype=torch.float32)) # self.resolution = kwargs['img_size'] # def get_resolution(self): # return self.resolution # def forward(self, x, **kwargs): # x = self.aug(x) # return super().forward(x, **kwargs) # class ADADiscriminatorView(ADADiscriminator): # def __init__(self, out_dim_position, out_dim_latent, **kwargs): # self.out_dim_position = out_dim_position # self.out_dim_latent = out_dim_latent # super().__init__(**kwargs) def bce_loss_target(d_out, target): targets = d_out.new_full(size=d_out.size(), fill_value=target) loss = F.binary_cross_entropy_with_logits(d_out, targets) return loss.mean() def compute_grad2(d_out, x_in): batch_size = x_in.size(0) grad_dout = autograd.grad( outputs=d_out.sum(), inputs=x_in, create_graph=True, retain_graph=True, only_inputs=True )[0] grad_dout2 = grad_dout.pow(2) assert(grad_dout2.size() == x_in.size()) reg = grad_dout2.reshape(batch_size, -1).sum(1) return reg.mean()