
Generative Adversarial Networks


Generative Adversarial Networks (Goodfellow et al., 2014) train a generator G and a discriminator D in a minimax game:

\min_G \max_D \; \mathbb{E}_{x \sim p_{\text{data}}} [\log D(x)] + \mathbb{E}_{z \sim p_z} [\log(1 - D(G(z)))].

  • D tries to distinguish real samples from fake.
  • G tries to produce samples that fool D.

At the global optimum of this game, G’s distribution matches the data distribution and D outputs 1/2 everywhere. No likelihood, no MCMC: just two networks playing against each other.
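For a fixed G, the inner maximization has a closed form: D^*(x) = p_data(x) / (p_data(x) + p_g(x)) (Goodfellow et al., 2014). Substituting it back gives -\log 4 + 2\,\mathrm{JSD}(p_{\text{data}} \| p_g), which is smallest exactly when p_g = p_data. In that case D^* = 1/2. A minimal NumPy check on a toy 4-point space (the two distributions are made up for illustration):

```python
import numpy as np

def optimal_D(p_data, p_g):
    """Pointwise optimal discriminator for a fixed generator."""
    return p_data / (p_data + p_g)

def game_value(D, p_data, p_g):
    """E_{x~p_data}[log D(x)] + E_{x~p_g}[log(1 - D(x))] on a discrete space."""
    return np.sum(p_data * np.log(D)) + np.sum(p_g * np.log(1 - D))

def kl(a, b):
    return np.sum(a * np.log(a / b))

def jsd(p, q):
    """Jensen-Shannon divergence (natural log)."""
    m = 0.5 * (p + q)
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

# Toy distributions on 4 points (illustrative values, not from the deck)
p_data = np.array([0.1, 0.4, 0.3, 0.2])
p_g = np.array([0.25, 0.25, 0.25, 0.25])

D_star = optimal_D(p_data, p_g)
D_other = np.array([0.5, 0.7, 0.4, 0.6])  # some other discriminator
print(game_value(D_star, p_data, p_g) >= game_value(D_other, p_data, p_g))  # True

# Value at D* equals -log 4 + 2 * JSD(p_data || p_g)
print(np.isclose(game_value(D_star, p_data, p_g),
                 -np.log(4) + 2 * jsd(p_data, p_g)))  # True

# At equilibrium (p_g == p_data): D* = 1/2 everywhere, value = -log 4
D_eq = optimal_D(p_data, p_data)
print(D_eq)  # [0.5 0.5 0.5 0.5]
print(game_value(D_eq, p_data, p_data), -np.log(4))  # both about -1.386
```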

The GAN architecture

Noise z goes into the generator, which outputs samples; the discriminator receives both generated samples and real data and tries to tell them apart.

This deck demos a tiny GAN on a 2D Gaussian. The next deck (DCGAN) generates real images.

Setup

Import JAX, Flax, Optax, and the d2l utilities used to plot the 2D distribution during training:

%matplotlib inline
from d2l import jax as d2l
import jax
from jax import numpy as jnp
from flax import linen as nn
import optax
import numpy as np

The “real” data is a 2D Gaussian, so success is visible: generated points should eventually match the same tilted elliptical cloud:

X = jax.random.normal(jax.random.PRNGKey(0), (1000, 2))
A = jnp.array([[1, 2], [-0.1, 0.5]])
b = jnp.array([1, 2])
data = jnp.matmul(X, A) + b
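Each row of `data` is x = zA + b with z ~ N(0, I), so its mean is b and its covariance is AᵀA, the same matrix the deck prints when inspecting the samples. A quick NumPy sanity check with many more samples (the sample count and seed are arbitrary choices):

```python
import numpy as np

A = np.array([[1.0, 2.0], [-0.1, 0.5]])
b = np.array([1.0, 2.0])

rng = np.random.default_rng(0)
z = rng.standard_normal((200_000, 2))  # z ~ N(0, I)
x = z @ A + b                          # same affine map as `data` above

emp_mean = x.mean(axis=0)
emp_cov = np.cov(x, rowvar=False)      # rows are samples, columns are coordinates
print(emp_mean)  # close to b = [1, 2]
print(emp_cov)   # close to A.T @ A = [[1.01, 1.95], [1.95, 4.25]]
```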

Inspecting Real Samples

Plot the first 100 real points and print the target covariance AᵀA. The scatter plot is the visual target for the generator; later training plots should move the generated samples toward this shape:

d2l.set_figsize()
d2l.plt.scatter(np.array(data[:100, 0]), np.array(data[:100, 1]));
print(f'The covariance matrix is\n{jnp.matmul(A.T, A)}')

The covariance matrix is
[[1.01 1.95]
 [1.95 4.25]]

Sampling Real Batches

Training batches are minibatches of 8 points taken from the 1000 samples above, each of which is an iid draw from the target Gaussian. The discriminator only sees these samples, never the analytic density:

batch_size = 8
data_iter = d2l.load_array((data,), batch_size)
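Exactly how `d2l.load_array` batches the data depends on the d2l version. Assuming it reshuffles the stored points each epoch (as it does for training data in d2l's other backends), one epoch is a random permutation of the 1000 points cut into 125 batches of 8, so every point is seen once per epoch. A hypothetical NumPy stand-in (`iterate_minibatches` is our name, not a d2l function):

```python
import numpy as np

def iterate_minibatches(data, batch_size, rng):
    """Hypothetical stand-in for a shuffling data loader: yields one epoch."""
    perm = rng.permutation(len(data))  # fresh shuffle each epoch
    for start in range(0, len(data), batch_size):
        yield data[perm[start:start + batch_size]]

rng = np.random.default_rng(0)
data_np = rng.standard_normal((1000, 2))
batches = list(iterate_minibatches(data_np, 8, rng))
print(len(batches))  # 125
```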

Generator

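A single linear layer: latent z → 2D output. It maps the prior N(0, I) toward (hopefully) the data distribution; since an affine map of a Gaussian is again Gaussian, this is expressive enough to match the target exactly: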

class Generator(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(2)(x)

net_G = Generator()
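The linear generator does not have to recover A itself. With kernel K and bias b, G(z) = zK + b has covariance KᵀK, and for any orthogonal Q the kernel QA gives (QA)ᵀ(QA) = AᵀQᵀQA = AᵀA. Many different weight settings therefore produce the same output distribution. A NumPy check (the rotation angle is arbitrary):

```python
import numpy as np

A = np.array([[1.0, 2.0], [-0.1, 0.5]])
theta = 0.7  # arbitrary rotation angle
Q = np.array([[np.cos(theta), -np.sin(theta)],
              [np.sin(theta),  np.cos(theta)]])

# Covariance of z @ K + b for z ~ N(0, I) is K.T @ K
cov_target = A.T @ A
cov_rotated = (Q @ A).T @ (Q @ A)
print(np.allclose(cov_target, cov_rotated))  # True: same distribution
print(np.allclose(Q @ A, A))                 # False: a different kernel
```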

Discriminator

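Tiny MLP: 2D point → one logit l, with D(x) = σ(l) = P(real). The network itself has no sigmoid; `optax.sigmoid_binary_cross_entropy` applies it inside the loss. Standard binary classifier: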

class Discriminator(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.tanh(nn.Dense(5)(x))
        x = nn.tanh(nn.Dense(3)(x))
        return nn.Dense(1)(x)

net_D = Discriminator()
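Working on logits lets the loss avoid computing σ(l) and then taking its log, which breaks down for large |l|. A NumPy sketch of the standard stable formula for binary cross-entropy with logits, max(l, 0) − l·y + log(1 + e^{−|l|}) (our own reimplementation for illustration, not Optax's source):

```python
import numpy as np

def bce_with_logits(logits, labels):
    """-[y log sigmoid(l) + (1 - y) log(1 - sigmoid(l))], computed stably."""
    return np.maximum(logits, 0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))

def bce_naive(logits, labels):
    """Sigmoid first, then log: fine for moderate logits only."""
    p = 1 / (1 + np.exp(-logits))
    return -(labels * np.log(p) + (1 - labels) * np.log(1 - p))

l = np.array([-3.0, 0.0, 2.5])
y = np.array([0.0, 1.0, 1.0])
print(np.allclose(bce_with_logits(l, y), bce_naive(l, y)))  # True

# A confidently wrong logit: the stable form stays finite (naive gives inf)
print(bce_with_logits(np.array([40.0]), np.array([0.0])))  # [40.]
```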

Discriminator Update

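For each batch, generate fakes G(z) and take real x, then update D to maximize \log D(x) + \log(1 - D(G(z))). In code this means minimizing binary cross-entropy with label 1 for real points and 0 for fakes:
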
from functools import partial

@partial(jax.jit, static_argnames=('net_D', 'net_G', 'optimizer_D'))
def update_D(X, Z, net_D, net_G, params_D, params_G, loss_fn, opt_state_D,
             optimizer_D):
    """Update discriminator."""
    batch_size = X.shape[0]
    ones = jnp.ones((batch_size,))
    zeros = jnp.zeros((batch_size,))
    # Do not need to compute gradient for `net_G`
    fake_X = net_G.apply(params_G, Z)
    def loss_D_fn(params_D):
        real_Y = net_D.apply(params_D, X).squeeze()
        fake_Y = net_D.apply(params_D, fake_X).squeeze()
        loss_D = (jnp.sum(optax.sigmoid_binary_cross_entropy(real_Y, ones)) +
                  jnp.sum(optax.sigmoid_binary_cross_entropy(fake_Y, zeros))
                  ) / 2
        return loss_D
    loss_D, grads_D = jax.value_and_grad(loss_D_fn)(params_D)
    updates, opt_state_D = optimizer_D.update(grads_D, opt_state_D, params_D)
    params_D = optax.apply_updates(params_D, updates)
    return loss_D, params_D, opt_state_D
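Since BCE(l, 1) = −log σ(l) and BCE(l, 0) = −log(1 − σ(l)), minimizing `loss_D` is the same as maximizing the minibatch estimate V̂ of the minimax objective. Because `update_D` sums over the batch and halves, `loss_D` = −(B/2)·V̂ for batch size B. A NumPy check with random logits (the helper names are ours):

```python
import numpy as np

def log_sigmoid(l):
    """log sigmoid(l) = -log(1 + exp(-l)), computed stably."""
    return -np.logaddexp(0.0, -l)

def value_estimate(real_logits, fake_logits):
    """Minibatch estimate of E[log D(x)] + E[log(1 - D(G(z)))]."""
    # log(1 - sigmoid(l)) == log sigmoid(-l)
    return log_sigmoid(real_logits).mean() + log_sigmoid(-fake_logits).mean()

def loss_D_like_update_D(real_logits, fake_logits):
    """Mirrors update_D: summed BCE with real=1, fake=0, then halved."""
    bce_real = -log_sigmoid(real_logits)   # BCE(l, 1)
    bce_fake = -log_sigmoid(-fake_logits)  # BCE(l, 0)
    return (bce_real.sum() + bce_fake.sum()) / 2

rng = np.random.default_rng(0)
B = 8
real_logits, fake_logits = rng.normal(size=B), rng.normal(size=B)
print(np.isclose(loss_D_like_update_D(real_logits, fake_logits),
                 -B / 2 * value_estimate(real_logits, fake_logits)))  # True
```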

Generator Update

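Regenerate fakes from the same Z and update G to maximize \log D(G(z)), i.e., minimize BCE against label 1 (the “non-saturating” form). It gives stronger gradients early in training than directly minimizing \log(1-D(G(z))), which flattens out when D confidently rejects the fakes: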

@partial(jax.jit, static_argnames=('net_D', 'net_G', 'optimizer_G'))
def update_G(Z, net_D, net_G, params_D, params_G, loss_fn, opt_state_G,
             optimizer_G):
    """Update generator."""
    batch_size = Z.shape[0]
    ones = jnp.ones((batch_size,))
    def loss_G_fn(params_G):
        # We could reuse `fake_X` from `update_D` to save computation
        fake_X = net_G.apply(params_G, Z)
        # Recompute `fake_Y`: `params_D` was just updated by `update_D`
        fake_Y = net_D.apply(params_D, fake_X).squeeze()
        loss_G = jnp.sum(optax.sigmoid_binary_cross_entropy(fake_Y, ones))
        return loss_G
    loss_G, grads_G = jax.value_and_grad(loss_G_fn)(params_G)
    updates, opt_state_G = optimizer_G.update(grads_G, opt_state_G, params_G)
    params_G = optax.apply_updates(params_G, updates)
    return loss_G, params_G, opt_state_G
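The gradient claim can be made concrete by differentiating both generator objectives with respect to the discriminator's logit l on a fake sample, where D = σ(l): d/dl log(1 − σ(l)) = −σ(l), while d/dl[−log σ(l)] = −(1 − σ(l)). When D confidently rejects fakes (σ(l) ≈ 0), the first gradient vanishes and the second stays near −1. A NumPy illustration (the logit values are arbitrary):

```python
import numpy as np

def sigmoid(l):
    return 1 / (1 + np.exp(-l))

# Logits D assigns to fakes; very negative means "confidently fake"
l = np.array([-8.0, -4.0, 0.0, 4.0])

grad_saturating = -sigmoid(l)            # d/dl of log(1 - sigmoid(l)), minimized by G
grad_non_saturating = -(1 - sigmoid(l))  # d/dl of -log sigmoid(l), i.e. BCE(l, 1) as in update_G

for li, gs, gn in zip(l, grad_saturating, grad_non_saturating):
    print(f"l={li:+.0f}  saturating={gs:+.4f}  non-saturating={gn:+.4f}")
```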

Training loop

Alternate one discriminator step and one generator step. The losses are useful diagnostics, but the sample plot is the clearest signal that the generator distribution is moving in the right direction:

def train(net_D, net_G, data_iter, num_epochs, lr_D, lr_G, latent_dim, data):
    key = jax.random.PRNGKey(42)
    key, key_D, key_G = jax.random.split(key, 3)
    # Initialize parameters (D takes 2D points, G takes latent vectors)
    params_D = net_D.init(key_D, jnp.ones((1, 2)))
    params_G = net_G.init(key_G, jnp.ones((1, latent_dim)))
    # Reinitialize with normal(0, 0.02). Use one subkey per leaf so that
    # different parameter tensors aren't drawn from the same RNG state.
    leaves_D, treedef_D = jax.tree_util.tree_flatten(params_D)
    keys_D = jax.random.split(key_D, len(leaves_D))
    params_D = jax.tree_util.tree_unflatten(
        treedef_D,
        [jax.random.normal(k, p.shape) * 0.02
         for k, p in zip(keys_D, leaves_D)])
    leaves_G, treedef_G = jax.tree_util.tree_flatten(params_G)
    keys_G = jax.random.split(key_G, len(leaves_G))
    params_G = jax.tree_util.tree_unflatten(
        treedef_G,
        [jax.random.normal(k, p.shape) * 0.02
         for k, p in zip(keys_G, leaves_G)])
    optimizer_D = optax.adam(lr_D)
    optimizer_G = optax.adam(lr_G)
    opt_state_D = optimizer_D.init(params_D)
    opt_state_G = optimizer_G.init(params_G)
    animator = d2l.Animator(xlabel='epoch', ylabel='loss',
                            xlim=[1, num_epochs], nrows=2, figsize=(5, 5),
                            legend=['discriminator', 'generator'])
    animator.fig.subplots_adjust(hspace=0.3)
    for epoch in range(num_epochs):
        # Train one epoch
        timer = d2l.Timer()
        metric = d2l.Accumulator(3)  # loss_D, loss_G, num_examples
        for (X,) in data_iter:
            X = jnp.array(X)
            batch_size = X.shape[0]
            key, subkey = jax.random.split(key)
            Z = jax.random.normal(subkey, (batch_size, latent_dim))
            loss_D, params_D, opt_state_D = update_D(
                X, Z, net_D, net_G, params_D, params_G, None,
                opt_state_D, optimizer_D)
            loss_G, params_G, opt_state_G = update_G(
                Z, net_D, net_G, params_D, params_G, None,
                opt_state_G, optimizer_G)
            metric.add(loss_D, loss_G, batch_size)
        # Visualize generated examples
        key, subkey = jax.random.split(key)
        Z = jax.random.normal(subkey, (100, latent_dim))
        fake_X = np.array(net_G.apply(params_G, Z))
        animator.axes[1].cla()
        animator.axes[1].scatter(data[:, 0], data[:, 1])
        animator.axes[1].scatter(fake_X[:, 0], fake_X[:, 1])
        animator.axes[1].legend(['real', 'generated'])
        # Show the losses
        loss_D, loss_G = metric[0]/metric[2], metric[1]/metric[2]
        animator.add(epoch + 1, (loss_D, loss_G))
    print(f'loss_D {loss_D:.3f}, loss_G {loss_G:.3f}, '
          f'{metric[2] / timer.stop():.1f} examples/sec')
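Because the target is Gaussian, the scatter plot can be backed by a number: compare the first two moments of generated points with the real ones. A hypothetical helper (`moment_gap` is not part of d2l) that could be called on `fake_X` and `data` after each epoch; the "collapsed" sample below is synthetic, made up to mimic mode collapse:

```python
import numpy as np

def moment_gap(real, fake):
    """Hypothetical diagnostic: distance between sample means and covariances."""
    mean_gap = np.linalg.norm(real.mean(axis=0) - fake.mean(axis=0))
    cov_gap = np.linalg.norm(np.cov(real, rowvar=False) - np.cov(fake, rowvar=False))  # Frobenius norm
    return mean_gap, cov_gap

rng = np.random.default_rng(0)
A = np.array([[1.0, 2.0], [-0.1, 0.5]])
b = np.array([1.0, 2.0])
real = rng.standard_normal((5000, 2)) @ A + b
good_fake = rng.standard_normal((5000, 2)) @ A + b
collapsed = np.tile(b, (5000, 1)) + 0.01 * rng.standard_normal((5000, 2))

print(moment_gap(real, good_fake))  # both gaps small
print(moment_gap(real, collapsed))  # mean gap small, covariance gap large
```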

Training Run

The final generated cloud should overlap the target Gaussian. If the samples instead shrink to a small region, the generator has collapsed (mode collapse) rather than matching the distribution:

lr_D, lr_G, latent_dim, num_epochs = 0.05, 0.005, 2, 20
train(net_D, net_G, data_iter, num_epochs, lr_D, lr_G,
      latent_dim, d2l.numpy(data[:100]))

loss_D 0.693, loss_G 0.693, 2147.5 examples/sec
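Both losses sit at 0.693 ≈ log 2, which is what a discriminator that cannot tell real from fake produces: if every logit is 0, then D = σ(0) = 1/2, each BCE term equals −log(1/2) = log 2, and dividing the summed losses by the number of examples (as `train` does) gives log 2 for both `loss_D` and `loss_G`. Equal losses are consistent with a matched distribution but do not prove it; the scatter plot remains the real check. A NumPy reproduction of that normalization for one batch of 8:

```python
import numpy as np

B = 8
logits = np.zeros(B)                   # D(x) = sigmoid(0) = 1/2 for every input
bce_1 = np.log1p(np.exp(-logits))      # BCE(l, 1) = -log sigmoid(l)
bce_0 = np.log1p(np.exp(logits))       # BCE(l, 0) = -log(1 - sigmoid(l))

loss_D = (bce_1.sum() + bce_0.sum()) / 2  # as in update_D
loss_G = bce_1.sum()                      # as in update_G
print(loss_D / B, loss_G / B, np.log(2))  # all about 0.693
```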

Recap

  • GAN = generator + discriminator playing a minimax game.
  • Equilibrium: G’s distribution = data distribution, D’s output is 1/2 everywhere.
  • Notoriously tricky to train: mode collapse, vanishing gradients early on, training instability.
  • Later work attacks these problems: WGAN and WGAN-GP change the objective, while StyleGAN and BigGAN improve architectures and scale; the core minimax idea stays.