```python
from d2l import tensorflow as d2l
import tensorflow as tf
```

Generative Adversarial Networks (Goodfellow et al., 2014) train a generator $G$ and a discriminator $D$ in a minimax game:
$$\min_G \max_D \; \mathbb{E}_{x \sim p_{\text{data}}} [\log D(x)] + \mathbb{E}_{z \sim p_z} [\log(1 - D(G(z)))].$$
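At the global optimum of this game, G’s distribution $p_G$ equals $p_{\text{data}}$, and the best the discriminator can do is output $D(x) = 1/2$ everywhere. No likelihood and no MCMC are needed, just two networks trained against each other.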
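For a fixed generator, the inner maximization has a closed form (Goodfellow et al., 2014). Pointwise, maximizing $a \log D + b \log(1 - D)$ with $a = p_{\text{data}}(x)$ and $b = p_G(x)$ gives $D^*(x) = p_{\text{data}}(x) / (p_{\text{data}}(x) + p_G(x))$, which equals $1/2$ exactly when $p_G = p_{\text{data}}$. The NumPy sketch below checks this numerically with a grid search. The helper `best_discriminator_value` and the density values are illustrative assumptions, not part of the training code.

```python
import numpy as np

def best_discriminator_value(p_data, p_g, num_points=100_001):
    """Grid-search D in (0, 1) maximizing p_data*log(D) + p_g*log(1 - D)."""
    d = np.linspace(1e-6, 1 - 1e-6, num_points)
    objective = p_data * np.log(d) + p_g * np.log(1 - d)
    return d[np.argmax(objective)]

# Compare the grid-search optimum with the closed form D* = p_data / (p_data + p_g)
for p_data, p_g in [(0.3, 0.1), (0.2, 0.2), (0.05, 0.4)]:
    numeric = best_discriminator_value(p_data, p_g)
    closed = p_data / (p_data + p_g)
    print(f"p_data={p_data}, p_g={p_g}: grid {numeric:.4f}, closed form {closed:.4f}")
```

The middle case, where both densities are equal, is the equilibrium situation: the optimal discriminator sits at $1/2$.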
Pipeline: noise $z$ → generator → fake samples; the discriminator classifies real data versus generated samples.
This deck demos a tiny GAN on a 2D Gaussian. The next deck (DCGAN) generates real images.
Import backend utilities and define the plotting helper used to watch the 2D distribution during training:
The “real” data is a 2D Gaussian, so success is visible: generated points should eventually match the same tilted elliptical cloud:
Training batches are iid draws from the target Gaussian. The discriminator only sees samples, not the analytic density:
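One way to produce such batches is to shuffle a fixed array of samples once per epoch and slice it into minibatches. The sketch below does this in NumPy rather than with the notebook's own loader. The helper name `iterate_minibatches`, the stand-in samples, and the batch size are assumptions for illustration.

```python
import numpy as np

def iterate_minibatches(data, batch_size, rng):
    """Yield shuffled minibatches of rows from `data`; one full pass is one epoch."""
    indices = rng.permutation(len(data))
    for start in range(0, len(data), batch_size):
        yield data[indices[start:start + batch_size]]

rng = np.random.default_rng(0)
samples = rng.normal(size=(20, 2))  # stand-in for samples of the target Gaussian
batches = list(iterate_minibatches(samples, batch_size=8, rng=rng))
print([b.shape for b in batches])
```

The last batch is smaller when the batch size does not divide the dataset size. The training loop reads `X.shape[0]` per batch for this reason, rather than assuming a fixed size.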
The covariance matrix of the target distribution is

$$\begin{pmatrix} 1.01 & 1.95 \\ 1.95 & 4.25 \end{pmatrix}.$$
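A Gaussian with this covariance can be produced by an affine map of standard normal noise. If each row is $x = zA + b$ with $z \sim \mathcal{N}(0, I)$, then $\operatorname{Cov}(x) = A^\top A$, and the offset $b$ shifts only the mean. The sketch below assumes the hypothetical choice $A = \begin{pmatrix} 1 & 2 \\ -0.1 & 0.5 \end{pmatrix}$. This choice reproduces the printed matrix exactly, since $1^2 + 0.1^2 = 1.01$, $2 - 0.05 = 1.95$, and $2^2 + 0.5^2 = 4.25$. The offset $b$ is also an assumption. The code compares the analytic covariance with a sample estimate.

```python
import numpy as np

rng = np.random.default_rng(0)
A = np.array([[1.0, 2.0], [-0.1, 0.5]])  # hypothetical mixing matrix
b = np.array([1.0, 2.0])                 # hypothetical offset (shifts the mean only)

Z = rng.normal(size=(100_000, 2))        # rows are iid N(0, I) draws
X = Z @ A + b                            # affine map applied row by row: z A + b

analytic_cov = A.T @ A
empirical_cov = np.cov(X, rowvar=False)
print(np.round(analytic_cov, 2))
print(np.round(empirical_cov, 2))
```

Because the target is itself an affine function of Gaussian noise, even a purely linear generator is expressive enough to match it in principle.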
The scatter plot of these samples is the generator's visual target. The plots produced during training should show the generated samples moving toward this shape.
Generator, a tiny MLP: latent $z$ → 2D point. It maps the prior distribution $p_z$ to (hopefully) the data distribution:
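Discriminator, a tiny MLP binary classifier: 2D point → a single score for P(real). The loss in `train` is built with `from_logits=True`, so the network should output a raw logit. The sigmoid that turns it into P(real) is applied inside the loss: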
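Because the loss is configured with `from_logits=True`, the discriminator's raw output is treated as a logit $a$, and the sigmoid $\sigma(a) = 1/(1 + e^{-a})$ is applied inside the loss. A common numerically stable way to evaluate the per-example binary cross-entropy from a logit is $\max(a, 0) - a y + \log(1 + e^{-|a|})$. This equals $-[y \log \sigma(a) + (1 - y) \log(1 - \sigma(a))]$ but never takes the log of a probability that has underflowed to 0. The NumPy sketch below compares the two forms using the hypothetical helpers `bce_naive` and `bce_with_logits`. It illustrates the idea; it is not Keras's internal implementation.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def bce_naive(y, a):
    """BCE computed from probabilities; can hit log(0) for large |a|."""
    p = sigmoid(a)
    return -(y * np.log(p) + (1 - y) * np.log(1 - p))

def bce_with_logits(y, a):
    """Numerically stable BCE computed directly from logits."""
    return np.maximum(a, 0) - a * y + np.log1p(np.exp(-np.abs(a)))

y = np.array([1.0, 0.0, 1.0, 0.0])
a = np.array([2.0, -1.5, 0.0, 3.0])
print(bce_naive(y, a))
print(bce_with_logits(y, a))  # same values for moderate logits

# A real example scored with a very negative logit: the naive form returns inf
with np.errstate(divide="ignore", over="ignore"):
    print(bce_naive(np.array([1.0]), np.array([-800.0])))
print(bce_with_logits(np.array([1.0]), np.array([-800.0])))  # finite: 800.0
```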
Discriminator step, for each batch: label real samples 1 and generated samples 0, compute the binary cross-entropy, and take one optimizer step on $D$ only:

```python
def update_D(X, Z, net_D, net_G, loss, optimizer_D):
    """Update discriminator."""
    batch_size = tf.shape(X)[0]
    ones = tf.ones((batch_size,))  # Labels corresponding to real data
    zeros = tf.zeros((batch_size,))  # Labels corresponding to fake data
    # `net_G` gets no gradient in this step, so generate fakes outside the tape
    fake_X = net_G(Z)
    with tf.GradientTape() as tape:
        real_Y = net_D(X)
        fake_Y = net_D(fake_X)
        # Scale by batch_size to match PyTorch's BCEWithLogitsLoss(reduction='sum');
        # dividing by 2 averages the real and fake terms
        loss_D = (loss(ones, tf.reshape(real_Y, [-1])) + loss(
            zeros, tf.reshape(fake_Y, [-1]))) * tf.cast(
            batch_size, tf.float32) / 2
    grads_D = tape.gradient(loss_D, net_D.trainable_variables)
    optimizer_D.apply_gradients(zip(grads_D, net_D.trainable_variables))
    return loss_D
```

Generator step: regenerate fakes from the same noise batch `Z` and score them with the just-updated discriminator. Then update $G$ to maximize $\log D(G(z))$, which means minimizing the BCE with the fakes labeled as real. This “non-saturating” form gives stronger gradients early in training than directly minimizing $\log(1 - D(G(z)))$.
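To see why the non-saturating form helps, write $D(G(z)) = \sigma(a)$ with logit $a$. Then $\frac{\partial}{\partial a} \log(1 - \sigma(a)) = -\sigma(a)$, which vanishes when the discriminator confidently rejects fakes ($\sigma(a) \approx 0$, typical early in training). In contrast, $\frac{\partial}{\partial a} \bigl(-\log \sigma(a)\bigr) = -(1 - \sigma(a))$ stays close to $-1$ in that regime. The NumPy sketch below evaluates both gradient magnitudes at a few example logits. It is an illustration only, not part of the training code.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

# Example discriminator logits on generated samples; negative means "looks fake"
logits = np.array([-6.0, -3.0, 0.0, 3.0])
d_fake = sigmoid(logits)

# Derivative w.r.t. the logit of each generator objective (both are minimized)
grad_saturating = -d_fake            # d/da of log(1 - sigmoid(a))
grad_non_saturating = -(1 - d_fake)  # d/da of -log(sigmoid(a))

for a, g_sat, g_ns in zip(logits, grad_saturating, grad_non_saturating):
    print(f"logit {a:+.1f}: |saturating| {abs(g_sat):.4f}, "
          f"|non-saturating| {abs(g_ns):.4f}")
```

At the most negative logit, the saturating gradient is tiny while the non-saturating one is close to 1. That difference is what lets the generator make progress before it fools the discriminator.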
```python
def update_G(Z, net_D, net_G, loss, optimizer_G):
    """Update generator."""
    batch_size = tf.shape(Z)[0]
    ones = tf.ones((batch_size,))  # Fakes are labeled "real" for the generator loss
    with tf.GradientTape() as tape:
        # Regenerate fakes inside the tape so gradients flow back to `net_G`;
        # the `fake_X` from `update_D` was computed outside any tape
        fake_X = net_G(Z)
        # Recomputing `fake_Y` is needed since `net_D` was just updated
        fake_Y = net_D(fake_X)
        # Scale by batch_size to match PyTorch's BCEWithLogitsLoss(reduction='sum')
        loss_G = loss(ones, tf.reshape(fake_Y, [-1])) * tf.cast(
            batch_size, tf.float32)
    grads_G = tape.gradient(loss_G, net_G.trainable_variables)
    optimizer_G.apply_gradients(zip(grads_G, net_G.trainable_variables))
    return loss_G
```

Training alternates one discriminator step and one generator step per batch. The losses are useful diagnostics, but the sample plot is the clearest signal that the generator distribution is moving in the right direction:
```python
def train(net_D, net_G, data_iter, num_epochs, lr_D, lr_G, latent_dim, data):
    loss = tf.keras.losses.BinaryCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.SUM)
    for w in net_D.trainable_variables:
        w.assign(tf.random.normal(mean=0, stddev=0.02, shape=w.shape))
    for w in net_G.trainable_variables:
        w.assign(tf.random.normal(mean=0, stddev=0.02, shape=w.shape))
    optimizer_D = tf.keras.optimizers.Adam(learning_rate=lr_D)
    optimizer_G = tf.keras.optimizers.Adam(learning_rate=lr_G)
    # Wrap the per-batch updates in `@tf.function` *here* (rather than
    # decorating `update_D` / `update_G` in the d2l package) so the
    # traced graph closes over the concrete `net_D`, `net_G`,
    # `optimizer_*`, `loss` from this scope. That eliminates retraces
    # caused by passing those Python objects as @tf.function arguments.
    # This matters more for deep DCGAN-style nets than for this toy 2-D
    # one, but the speedup applies to both.
    @tf.function(reduce_retracing=True)
    def step_D(X, Z):
        return update_D(X, Z, net_D, net_G, loss, optimizer_D)

    @tf.function(reduce_retracing=True)
    def step_G(Z):
        return update_G(Z, net_D, net_G, loss, optimizer_G)

    animator = d2l.Animator(
        xlabel="epoch", ylabel="loss", xlim=[1, num_epochs], nrows=2,
        figsize=(5, 5), legend=["discriminator", "generator"])
    animator.fig.subplots_adjust(hspace=0.3)
    for epoch in range(num_epochs):
        # Train one epoch
        timer = d2l.Timer()
        metric = d2l.Accumulator(3)  # loss_D, loss_G, num_examples
        for (X,) in data_iter:
            batch_size = X.shape[0]
            Z = tf.random.normal(
                mean=0, stddev=1, shape=(batch_size, latent_dim))
            metric.add(step_D(X, Z),
                       step_G(Z),
                       batch_size)
        # Visualize generated examples
        Z = tf.random.normal(mean=0, stddev=1, shape=(100, latent_dim))
        fake_X = net_G(Z)
        animator.axes[1].cla()
        animator.axes[1].scatter(data[:, 0], data[:, 1])
        animator.axes[1].scatter(fake_X[:, 0], fake_X[:, 1])
        animator.axes[1].legend(["real", "generated"])
        # Show the per-example losses
        loss_D, loss_G = metric[0] / metric[2], metric[1] / metric[2]
        animator.add(epoch + 1, (loss_D, loss_G))
    print(f'loss_D {loss_D:.3f}, loss_G {loss_G:.3f}, '
          f'{metric[2] / timer.stop():.1f} examples/sec')
```

The final generated cloud should overlap the target Gaussian. If all samples collapse into a small region, the generator has suffered mode collapse instead of matching the distribution:
```
loss_D 0.693, loss_G 0.694, 1385.1 examples/sec
```
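Both reported losses are close to $\ln 2 \approx 0.693$. The losses are averaged per example, so this is the binary cross-entropy obtained when the discriminator outputs $D = 1/2$; for a real example, $-\ln \tfrac{1}{2} = \ln 2$. In other words, the discriminator can no longer tell real from generated points, which is consistent with the equilibrium described above. Losses near $\ln 2$ alone do not prove the distributions match, though, since a weak discriminator can also hover near $1/2$. A complementary check is to compare simple statistics of the two clouds. The sketch below uses a hypothetical helper, `moment_gap`, to compare means and covariances in NumPy. Its arrays are synthetic stand-ins, not outputs of the run above.

```python
import numpy as np

def moment_gap(real, fake):
    """Euclidean gap between means and Frobenius gap between covariances."""
    mean_gap = np.linalg.norm(real.mean(axis=0) - fake.mean(axis=0))
    cov_gap = np.linalg.norm(np.cov(real, rowvar=False) - np.cov(fake, rowvar=False))
    return mean_gap, cov_gap

rng = np.random.default_rng(0)
scale, shift = np.array([1.0, 2.0]), np.array([1.0, 2.0])  # synthetic stand-in target
real = rng.normal(size=(5000, 2)) * scale + shift
matched = rng.normal(size=(5000, 2)) * scale + shift  # drawn from the same distribution
collapsed = real.mean(axis=0) + 0.01 * rng.normal(size=(5000, 2))  # all mass near one point

print("matched:  ", moment_gap(real, matched))
print("collapsed:", moment_gap(real, collapsed))
```

The collapsed cloud has almost exactly the right mean, so the mean gap alone would miss the failure. The covariance gap exposes it.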