[PyTorch] Fix GAN #1400

Merged
1 commit merged on Aug 26, 2020
73 changes: 24 additions & 49 deletions chapter_generative-adversarial-networks/gan.md
@@ -56,54 +56,35 @@ npx.set_np()
from d2l import torch as d2l
import torch
from torch import nn
from torch.utils.data import DataLoader
```

## Generate some "real" data

Since this is going to be the world's lamest example, we simply generate data drawn from a Gaussian.

```{.python .input}
X = np.random.normal(size=(1000, 2))
A = np.array([[1, 2], [-0.1, 0.5]])
b = np.array([1, 2])
data = X.dot(A) + b
```

```{.python .input}
#@tab pytorch
X = torch.normal(0.0, 1, (1000, 2))
A = torch.tensor([[1, 2], [-0.1, 0.5]])
b = torch.tensor([1, 2])
data = torch.mm(X, A) + b
#@tab all
X = d2l.normal(0.0, 1, (1000, 2))
A = d2l.tensor([[1, 2], [-0.1, 0.5]])
b = d2l.tensor([1, 2])
data = d2l.matmul(X, A) + b
```

Let us see what we got. This should be a Gaussian shifted in some rather arbitrary way with mean $b$ and covariance matrix $A^TA$.
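As a quick check of that claim: each row of `X` is drawn from a standard Gaussian, so for a sample $x$ the transformed point $xA + b$ satisfies

$$\mathbb{E}[xA + b] = \mathbb{E}[x]A + b = b, \qquad \mathrm{Cov}(xA + b) = A^\top\,\mathbb{E}[x^\top x]\,A = A^\top A,$$

since $x$ has zero mean and identity covariance.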

```{.python .input}
#@tab all
d2l.set_figsize()
d2l.plt.scatter(data[:100, 0].asnumpy(), data[:100, 1].asnumpy());
print(f'The covariance matrix is\n{np.dot(A.T, A)}')
```

```{.python .input}
#@tab pytorch
d2l.set_figsize()
d2l.plt.scatter(data[:100, 0].numpy(), data[:100, 1].numpy());
print(f'The covariance matrix is\n{torch.mm(A.T, A)}')
d2l.plt.scatter(d2l.numpy(data[:100, 0]), d2l.numpy(data[:100, 1]));
print(f'The covariance matrix is\n{d2l.matmul(A.T, A)}')
```

```{.python .input}
#@tab all
batch_size = 8
data_iter = d2l.load_array((data,), batch_size)
```

```{.python .input}
#@tab pytorch
batch_size = 8
data_iter = DataLoader(data, batch_size=batch_size)
```
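The merged data-loading path relies on `d2l.load_array`, whose implementation is not part of this diff. As a rough sketch (an assumption about the helper, not the library's actual code), a PyTorch version could look like the following; note that `TensorDataset` yields a tuple per sample even when it wraps a single tensor, which is why the training loop later unpacks each batch as `(X,)`:

```python
from torch.utils import data

def load_array(data_arrays, batch_size, is_train=True):
    """Construct a PyTorch data iterator over one or more tensors (sketch)."""
    # `TensorDataset` packs the tensors sample-wise; iterating the resulting
    # `DataLoader` therefore yields tuples such as `(X,)` for a single array.
    dataset = data.TensorDataset(*data_arrays)
    return data.DataLoader(dataset, batch_size, shuffle=is_train)
```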

## Generator

Our generator network will be the simplest network possible - a single-layer linear model. This is because we will be driving that linear network with a Gaussian data generator. Hence, it literally only needs to learn the parameters to fake things perfectly.
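The generator's definition itself is collapsed in this diff, but given the description above and the `latent_dim = 2` used later, a single-layer linear generator could be as small as this sketch (an illustration, not necessarily the exact code in `gan.md`):

```python
# Sketch: map a 2-d latent sample Z to a 2-d "fake" data point with one linear layer.
net_G = nn.Sequential(nn.Linear(2, 2))
```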
@@ -142,7 +123,8 @@ net_D = nn.Sequential(
First we define a function to update the discriminator.

```{.python .input}
def update_D(X, Z, net_D, net_G, loss, trainer_D): #@save
#@save
def update_D(X, Z, net_D, net_G, loss, trainer_D):
"""Update discriminator."""
batch_size = X.shape[0]
ones = np.ones((batch_size,), ctx=X.ctx)
@@ -161,11 +143,12 @@ def update_D(X, Z, net_D, net_G, loss, trainer_D): #@save

```{.python .input}
#@tab pytorch
def update_D(X, Z, net_D, net_G, loss, trainer_D): #@save
#@save
def update_D(X, Z, net_D, net_G, loss, trainer_D):
"""Update discriminator."""
batch_size = X.shape[0]
ones = torch.ones((batch_size, 1))
zeros = torch.zeros((batch_size, 1))
ones = torch.ones((batch_size, 1), device=X.device)
zeros = torch.zeros((batch_size, 1), device=X.device)
trainer_D.zero_grad()
real_Y = net_D(X)
fake_X = net_G(Z)
@@ -181,7 +164,8 @@ def update_D(X, Z, net_D, net_G, loss, trainer_D): #@save
The generator is updated similarly. Here we reuse the cross-entropy loss but change the label of the fake data from $0$ to $1$.

```{.python .input}
def update_G(Z, net_D, net_G, loss, trainer_G): #@save
#@save
def update_G(Z, net_D, net_G, loss, trainer_G):
"""Update generator."""
batch_size = Z.shape[0]
ones = np.ones((batch_size,), ctx=Z.ctx)
@@ -198,16 +182,17 @@ def update_G(Z, net_D, net_G, loss, trainer_G): #@save

```{.python .input}
#@tab pytorch
def update_G(Z, net_D, net_G, loss, trainer_G): #@save
#@save
def update_G(Z, net_D, net_G, loss, trainer_G):
"""Update generator."""
batch_size = Z.shape[0]
ones = torch.ones((batch_size, 1))
ones = torch.ones((batch_size, 1), device=Z.device)
trainer_G.zero_grad()
# We could reuse `fake_X` from `update_D` to save computation
fake_X = net_G(Z)
# Recomputing `fake_Y` is needed since `net_D` is changed
fake_Y = net_D(fake_X)
loss_G=loss(fake_Y,ones)
loss_G = loss(fake_Y,ones)
loss_G.backward()
trainer_G.step()
return loss_G
@@ -255,13 +240,11 @@ def train(net_D, net_G, data_iter, num_epochs, lr_D, lr_G, latent_dim, data):
```{.python .input}
#@tab pytorch
def train(net_D, net_G, data_iter, num_epochs, lr_D, lr_G, latent_dim, data):
loss = nn.BCEWithLogitsLoss()
loss = nn.BCEWithLogitsLoss(reduction='sum')
for w in net_D.parameters():
nn.init.normal_(w, 0, 0.02)
for w in net_G.parameters():
nn.init.normal_(w, 0, 0.02)
net_D.zero_grad()
net_G.zero_grad()
trainer_D = torch.optim.Adam(net_D.parameters(), lr=lr_D)
trainer_G = torch.optim.Adam(net_G.parameters(), lr=lr_G)
animator = d2l.Animator(xlabel='epoch', ylabel='loss',
@@ -272,11 +255,9 @@ def train(net_D, net_G, data_iter, num_epochs, lr_D, lr_G, latent_dim, data):
# Train one epoch
timer = d2l.Timer()
metric = d2l.Accumulator(3) # loss_D, loss_G, num_examples
for X in data_iter:
for (X,) in data_iter:
batch_size = X.shape[0]
Z = torch.normal(0, 1, size=(batch_size, latent_dim))
trainer_D.zero_grad()
trainer_G.zero_grad()
metric.add(update_D(X, Z, net_D, net_G, loss, trainer_D),
update_G(Z, net_D, net_G, loss, trainer_G),
batch_size)
@@ -297,13 +278,7 @@ def train(net_D, net_G, data_iter, num_epochs, lr_D, lr_G, latent_dim, data):
Now we specify the hyperparameters to fit the Gaussian distribution.

```{.python .input}
lr_D, lr_G, latent_dim, num_epochs = 0.05, 0.005, 2, 20
train(net_D, net_G, data_iter, num_epochs, lr_D, lr_G,
latent_dim, d2l.numpy(data[:100]))
```

```{.python .input}
#@tab pytorch
#@tab all
lr_D, lr_G, latent_dim, num_epochs = 0.05, 0.005, 2, 20
train(net_D, net_G, data_iter, num_epochs, lr_D, lr_G,
latent_dim, d2l.numpy(data[:100]))
2 changes: 1 addition & 1 deletion d2l/mxnet.py
@@ -2544,7 +2544,7 @@ def update_D(X, Z, net_D, net_G, loss, trainer_D):


# Defined in file: ./chapter_generative-adversarial-networks/gan.md
def update_G(Z, net_D, net_G, loss, trainer_G): #@save
def update_G(Z, net_D, net_G, loss, trainer_G):
"""Update generator."""
batch_size = Z.shape[0]
ones = np.ones((batch_size,), ctx=Z.ctx)
36 changes: 35 additions & 1 deletion d2l/torch.py
@@ -679,7 +679,7 @@ def __init__(self, vocab_size, num_hiddens, device,
self.init_state, self.forward_fn = init_state, forward

def __call__(self, X, state):
X = F.one_hot(X.T.long(), self.vocab_size).type(torch.float32)
X = F.one_hot(X.T, self.vocab_size).type(torch.float32)
return self.forward_fn(X, state, self.params)

def begin_state(self, batch_size, device):
@@ -1219,6 +1219,40 @@ def bbox_to_rect(bbox, color):
fill=False, edgecolor=color, linewidth=2)


# Defined in file: ./chapter_generative-adversarial-networks/gan.md
def update_D(X, Z, net_D, net_G, loss, trainer_D):
"""Update discriminator."""
batch_size = X.shape[0]
ones = torch.ones((batch_size, 1), device=X.device)
zeros = torch.zeros((batch_size, 1), device=X.device)
trainer_D.zero_grad()
real_Y = net_D(X)
fake_X = net_G(Z)
# Do not need to compute gradient for `net_G`, detach it from
# computing gradients.
fake_Y = net_D(fake_X.detach())
loss_D = (loss(real_Y, ones) + loss(fake_Y, zeros)) / 2
loss_D.backward()
trainer_D.step()
return loss_D


# Defined in file: ./chapter_generative-adversarial-networks/gan.md
def update_G(Z, net_D, net_G, loss, trainer_G):
"""Update generator."""
batch_size = Z.shape[0]
ones = torch.ones((batch_size, 1), device=Z.device)
trainer_G.zero_grad()
# We could reuse `fake_X` from `update_D` to save computation
fake_X = net_G(Z)
# Recomputing `fake_Y` is needed since `net_D` is changed
fake_Y = net_D(fake_X)
loss_G = loss(fake_Y,ones)
loss_G.backward()
trainer_G.step()
return loss_G


# Alias defined in config.ini

