-
Notifications
You must be signed in to change notification settings - Fork 61
/
ssgan_32.py
executable file
·115 lines (92 loc) · 3.75 KB
/
ssgan_32.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""
Implementation of SSGAN for image size 32.
"""
import torch
import torch.nn as nn
from torch_mimicry.modules import SNLinear
from torch_mimicry.modules.resblocks import DBlockOptimized, DBlock, GBlock
from torch_mimicry.nets.ssgan import ssgan_base
class SSGANGenerator32(ssgan_base.SSGANBaseGenerator):
    r"""
    ResNet backbone generator for SSGAN at image size 32.

    A linear layer maps the noise vector to a (ngf, bottom_width, bottom_width)
    feature map. Three upsampling residual blocks, batch norm + ReLU and a
    final 3x3 convolution with tanh then turn it into an image.

    Attributes:
        nz (int): Noise dimension for upsampling.
        ngf (int): Variable controlling generator feature map sizes.
        bottom_width (int): Starting width for upsampling generator output to an image.
        loss_type (str): Name of loss to use for GAN loss (forwarded to the
            base class via ``**kwargs``).
        ss_loss_scale (float): Self-supervised loss scale for generator
            (forwarded to the base class via ``**kwargs``).
    """
    def __init__(self, nz=128, ngf=256, bottom_width=4, **kwargs):
        super().__init__(nz=nz, ngf=ngf, bottom_width=bottom_width, **kwargs)

        # Build the layers. All sizes are read back from the attributes set by
        # the base class, so every layer agrees on the same channel count.
        self.l1 = nn.Linear(self.nz, (self.bottom_width**2) * self.ngf)
        self.block2 = GBlock(self.ngf, self.ngf, upsample=True)
        self.block3 = GBlock(self.ngf, self.ngf, upsample=True)
        self.block4 = GBlock(self.ngf, self.ngf, upsample=True)
        self.b5 = nn.BatchNorm2d(self.ngf)
        # Final 3x3 conv (stride 1, padding 1) projecting to 3 output channels.
        # Uses self.ngf (not the raw ``ngf`` argument) to stay consistent with
        # b5/block4, which consume self.ngf channels.
        self.c5 = nn.Conv2d(self.ngf, 3, 3, 1, padding=1)
        self.activation = nn.ReLU(True)

        # Initialise the weights
        nn.init.xavier_uniform_(self.l1.weight.data, 1.0)
        nn.init.xavier_uniform_(self.c5.weight.data, 1.0)

    def forward(self, x):
        r"""
        Feedforwards a batch of noise vectors into a batch of fake images.

        Args:
            x (Tensor): A batch of noise vectors of shape (N, nz).

        Returns:
            Tensor: A batch of fake images of shape (N, 3, H, W) with values
            in [-1, 1] (tanh output). H and W depend on how much each
            upsampling GBlock enlarges the map. Presumably it doubles it,
            which gives 32x32 for bottom_width=4 (TODO confirm against GBlock).
        """
        h = self.l1(x)
        # Reshape the flat projection into a (N, ngf, bottom_width, bottom_width) map.
        h = h.view(x.shape[0], -1, self.bottom_width, self.bottom_width)
        h = self.block2(h)
        h = self.block3(h)
        h = self.block4(h)
        h = self.b5(h)
        h = self.activation(h)
        h = torch.tanh(self.c5(h))

        return h
class SSGANDiscriminator32(ssgan_base.SSGANBaseDiscriminator):
    r"""
    SSGAN discriminator with a ResNet backbone, for image size 32.

    Alongside the real/fake logit it has an auxiliary linear head that
    predicts the rotation class of each input image. Both heads read the same
    globally sum-pooled features.

    Attributes:
        ndf (int): Variable controlling discriminator feature map sizes.
        loss_type (str): Name of loss to use for GAN loss (forwarded to the
            base class via ``**kwargs``).
        ss_loss_scale (float): Self-supervised loss scale for discriminator
            (forwarded to the base class via ``**kwargs``).
    """
    def __init__(self, ndf=128, **kwargs):
        super().__init__(ndf=ndf, **kwargs)

        width = self.ndf

        # Residual trunk: an input block taking 3 channels, one block with
        # downsampling, then two blocks built with downsample=False.
        self.block1 = DBlockOptimized(3, width)
        self.block2 = DBlock(width, width, downsample=True)
        self.block3 = DBlock(width, width, downsample=False)
        self.block4 = DBlock(width, width, downsample=False)

        # GAN head: a single real/fake logit.
        self.l5 = SNLinear(width, 1)
        # Self-supervised head: one logit per rotation class.
        self.l_y = SNLinear(width, self.num_classes)

        # Xavier-uniform initialisation (gain 1.0), GAN head first, then rotation head.
        for head in (self.l5, self.l_y):
            nn.init.xavier_uniform_(head.weight.data, 1.0)

        self.activation = nn.ReLU(True)

    def forward(self, x):
        r"""
        Runs a batch of real or fake images through the discriminator.

        Args:
            x (Tensor): A batch of images of shape (N, C, H, W).

        Returns:
            Tensor: A batch of GAN logits of shape (N, 1).
            Tensor: A batch of rotation class logits of shape (N, num_classes).
        """
        features = x
        for block in (self.block1, self.block2, self.block3, self.block4):
            features = block(features)
        features = self.activation(features)

        # Global sum pooling over both spatial axes.
        pooled = features.sum(dim=(2, 3))

        return self.l5(pooled), self.l_y(pooled)