Neural Style Transfer

Neural style transfer (Gatys, Ecker, Bethge 2015): combine the content of one image with the style of another. No model training — just iterative optimization of pixel values against a loss defined over a frozen pretrained CNN.

Content + style → synthesized image.

The key insight

In a pretrained ImageNet CNN:

  • Deeper layer activations capture content.
  • Gram matrices of activations capture style (textures, brush strokes, color palette).

Define a loss matching both; optimize over the synthesized image’s pixels.

Pipeline: forward pass extracts content + style features; backprop into pixels.
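Here is a toy version of this pipeline, built on loud assumptions. A fixed random matrix `W` stands in for the frozen CNN, the "image" is a flat 16-element vector, and only a feature-matching (content-style) loss is used. Gradient descent updates the pixels `x` while `W` is never modified. This is an illustrative sketch, not part of the d2l code.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-in for a frozen pretrained CNN: a fixed linear map
W = rng.normal(size=(8, 16)) / 4.0
W_frozen = W.copy()                       # kept only to show W never changes

def features(x):
    """'Forward pass': map a 16-pixel image vector to 8 features."""
    return W @ x

target = features(rng.normal(size=16))    # features of a toy "content" image
x = np.zeros(16)                          # synthesized image: the only variable

def loss(x):
    return np.mean((features(x) - target) ** 2)

start = loss(x)
lr = 1.0
for step in range(500):
    # Hand-derived backprop: d/dx mean((Wx - t)^2) = 2 W^T (Wx - t) / len(t)
    grad = 2.0 * W.T @ (features(x) - target) / target.size
    x -= lr * grad                        # update the pixels, not W
print(f"loss: {start:.4f} -> {loss(x):.6f}")
```

The real method works the same way. The only differences are that `features` is VGG-19 and the gradient comes from automatic differentiation instead of a hand-written formula.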

Loading content and style

%matplotlib inline
from d2l import jax as d2l
import jax
from jax import numpy as jnp
import numpy as np
from PIL import Image
import tensorflow as tf

d2l.set_figsize()
content_img = Image.open('../img/rainier.jpg').convert('RGB')  # drop any alpha channel
d2l.plt.imshow(content_img);

style_img = Image.open('../img/autumn-oak.jpg').convert('RGB')
d2l.plt.imshow(style_img);

Preprocessing

ImageNet mean/std normalization in, inverse on the way out:

rgb_mean = jnp.array([0.485, 0.456, 0.406])
rgb_std = jnp.array([0.229, 0.224, 0.225])

def preprocess(img, image_shape):
    img = img.resize((image_shape[1], image_shape[0]))  # PIL resize is (w, h)
    img = np.array(img, dtype=np.float32) / 255.0  # (H, W, C)
    img = img.transpose(2, 0, 1)  # (C, H, W)
    img = (img - np.array(rgb_mean).reshape(3, 1, 1)) / np.array(
        rgb_std).reshape(3, 1, 1)
    return jnp.expand_dims(jnp.array(img), axis=0)

def postprocess(img):
    img = np.array(img[0])  # (C, H, W)
    img = np.clip(img.transpose(1, 2, 0) * np.array(rgb_std) +
                  np.array(rgb_mean), 0, 1)
    return img
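Here is a quick sanity check that the two transforms are inverses, written as a NumPy-only sketch. It takes an already-loaded (H, W, 3) array in [0, 1] instead of a PIL image. The names `normalize` and `denormalize` are hypothetical mirrors of `preprocess` and `postprocess`, minus the resize.

```python
import numpy as np

rgb_mean = np.array([0.485, 0.456, 0.406])
rgb_std = np.array([0.229, 0.224, 0.225])

def normalize(img_hwc):
    """(H, W, 3) floats in [0, 1] -> (1, 3, H, W), ImageNet-normalized."""
    x = (img_hwc - rgb_mean) / rgb_std      # broadcasts over the channel axis
    return x.transpose(2, 0, 1)[None]       # HWC -> CHW, then add batch axis

def denormalize(x_nchw):
    """Inverse of normalize, clipped to the displayable range [0, 1]."""
    img = x_nchw[0].transpose(1, 2, 0) * rgb_std + rgb_mean
    return np.clip(img, 0, 1)

img = np.random.default_rng(0).random((4, 5, 3))   # toy 4x5 RGB image
x = normalize(img)
print(x.shape)                                     # (1, 3, 4, 5)
print(np.allclose(denormalize(x), img))            # True
```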

Pretrained VGG-19 feature extractor

Style is a multi-scale phenomenon, so we match it across several VGG-19 layers (Conv1_1, Conv2_1, Conv3_1, Conv4_1, Conv5_1). Content is matched at a single deeper layer, the last convolution of the fourth block (Conv4_4):

# Load pretrained VGG-19 weights via Keras (avoids a PyTorch dependency)
pretrained_net = tf.keras.applications.VGG19(
    weights='imagenet', include_top=False)
style_layers, content_layers = [0, 5, 10, 19, 28], [25]
# These indices refer to the torchvision VGG-19 `features` list, which
# interleaves Conv, ReLU, and MaxPool layers. Keras VGG-19 numbers layers
# differently (an InputLayer first, ReLU fused into each Conv2D, so the
# extracted features are post-ReLU), so we remap here.
_torch_to_tf = {0: 1, 5: 4, 10: 7, 19: 12, 25: 15, 28: 17}
style_layers = [_torch_to_tf[i] for i in style_layers]
content_layers = [_torch_to_tf[i] for i in content_layers]
# Skip the InputLayer (index 0); keep layers 1..max_needed
_vgg_layers = pretrained_net.layers
net = _vgg_layers[1:max(content_layers + style_layers) + 1]
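The remapping table can be derived rather than hard-coded. The sketch below rebuilds both layer orderings from the standard VGG-19 configuration. The Keras-style `block{b}_conv{k}` labels match, as far as we know, the names Keras gives VGG-19's conv layers. The `input` label and the torchvision-style names are illustrative labels chosen here. Running it recovers `_torch_to_tf`.

```python
# VGG-19 ("configuration E"): conv output channels, with 'M' marking max-pooling
VGG19_CFG = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M',
             512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M']

def torch_layer_names(cfg):
    """Entries of a torchvision-style `features` list: Conv then ReLU, plus pools."""
    names, block, k = [], 1, 1
    for v in cfg:
        if v == 'M':
            names.append(f'pool{block}')
            block, k = block + 1, 1
        else:
            names += [f'conv{block}_{k}', f'relu{block}_{k}']
            k += 1
    return names

def keras_layer_names(cfg):
    """Keras-style layers: an input layer, then Conv2D (ReLU fused) and pools."""
    names, block, k = ['input'], 1, 1
    for v in cfg:
        if v == 'M':
            names.append(f'block{block}_pool')
            block, k = block + 1, 1
        else:
            names.append(f'block{block}_conv{k}')
            k += 1
    return names

torch_names = torch_layer_names(VGG19_CFG)
keras_names = keras_layer_names(VGG19_CFG)

def torch_to_keras(i):
    """Map a torchvision conv index to the index of the same conv in Keras."""
    name = torch_names[i]
    assert name.startswith('conv'), f'index {i} is {name}, not a conv'
    block, k = name[len('conv'):].split('_')
    return keras_names.index(f'block{block}_conv{k}')

used = [0, 5, 10, 19, 25, 28]
mapping = {i: torch_to_keras(i) for i in used}
print([torch_names[i] for i in used])
print(mapping)
```

The output also confirms that torchvision index 25 is the fourth convolution of block 4 (`block4_conv4` on the Keras side).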

Feature extractor (cont.)

def extract_features(X_tf, content_layers, style_layers):
    """Extract content and style features using TF VGG-19.

    X_tf is a TF tensor in (N, H, W, C) layout after VGG preprocessing."""
    contents = []
    styles = []
    # net starts at TF layer index 1 (InputLayer skipped)
    for i, layer in enumerate(net, start=1):
        X_tf = layer(X_tf)
        if i in style_layers:
            # TF output is (N,H,W,C) -> convert to (N,C,H,W) for loss fns
            styles.append(tf.transpose(X_tf, (0, 3, 1, 2)))
        if i in content_layers:
            contents.append(tf.transpose(X_tf, (0, 3, 1, 2)))
    return contents, styles

def _to_vgg_input(X_nchw):
    """Convert an ImageNet-normalized (N,C,H,W) image to Keras VGG-19 input.

    Undoes the normalization, rescales to [0, 255], moves channels last, and
    applies Keras' VGG-19 preprocessing (RGB -> BGR, mean subtraction)."""
    X_nhwc = tf.transpose(X_nchw, (0, 2, 3, 1))
    X_raw = (X_nhwc * tf.constant(np.array(rgb_std).reshape(1,1,1,3),
                                   dtype=tf.float32)
             + tf.constant(np.array(rgb_mean).reshape(1,1,1,3),
                           dtype=tf.float32)) * 255.0
    return tf.keras.applications.vgg19.preprocess_input(X_raw)

def get_contents(image_shape):
    content_X = preprocess(content_img, image_shape)
    content_X_tf = _to_vgg_input(tf.constant(np.array(content_X)))
    contents_Y, _ = extract_features(content_X_tf, content_layers,
                                     style_layers)
    return content_X, contents_Y

def get_styles(image_shape):
    style_X = preprocess(style_img, image_shape)
    style_X_tf = _to_vgg_input(tf.constant(np.array(style_X)))
    _, styles_Y = extract_features(style_X_tf, content_layers, style_layers)
    return style_X, styles_Y
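To make the input conversion concrete, here is a NumPy mirror of `_to_vgg_input`. It assumes that Keras' VGG-19 `preprocess_input` uses "caffe" mode: flip RGB to BGR, then subtract fixed per-channel means, without dividing by a standard deviation. The mean constants below are the values Keras uses for that mode, to our knowledge. `to_vgg_input_np` is a hypothetical helper for illustration only.

```python
import numpy as np

rgb_mean = np.array([0.485, 0.456, 0.406])
rgb_std = np.array([0.229, 0.224, 0.225])
# Per-channel BGR means subtracted by caffe-mode preprocessing (assumed values)
CAFFE_BGR_MEAN = np.array([103.939, 116.779, 123.68])

def to_vgg_input_np(x_nchw):
    """ImageNet-normalized (N, C, H, W) RGB -> (N, H, W, C) BGR in [0, 255]
    minus the caffe channel means."""
    x = x_nchw.transpose(0, 2, 3, 1)                 # NCHW -> NHWC
    raw_rgb = (x * rgb_std + rgb_mean) * 255.0       # undo normalization
    bgr = raw_rgb[..., ::-1]                         # RGB -> BGR
    return bgr - CAFFE_BGR_MEAN                      # zero-center, no scaling

# A pixel equal to the ImageNet mean (all zeros after normalization)
x = np.zeros((1, 3, 1, 1))
print(to_vgg_input_np(x)[0, 0, 0])   # approximately [-0.409 -0.499 -0.005]
```

The two sets of ImageNet means agree to within half an intensity level, so an "average" pixel lands near zero in both conventions.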

Content loss

Mean squared error between the synthesized-image and content-image features at the content layer:

def content_loss(Y_hat, Y):
    return jnp.square(Y_hat - jax.lax.stop_gradient(Y)).mean()
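Written out with batch size 1, let \hat{Y} and Y be the synthesized and content features of shape c \times h \times w. The loss and its gradient are as follows. `stop_gradient` marks Y as a constant, so only \hat{Y} (and, through the network, the pixels) receives gradient:

```latex
\mathcal{L}_\text{content} = \frac{1}{chw} \sum_{k,i,j} \left(\hat{Y}_{kij} - Y_{kij}\right)^2,
\qquad
\frac{\partial \mathcal{L}_\text{content}}{\partial \hat{Y}_{kij}} = \frac{2}{chw}\left(\hat{Y}_{kij} - Y_{kij}\right).
```

Because the synthesized image starts as the content image, this term and its gradient are exactly zero at the first step. The style and TV terms drive the early updates.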

Style loss

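Mean squared error between Gram matrices of features at each style layer. Flatten a layer's output into F \in \mathbb{R}^{c \times hw}, with one row per channel. The Gram matrix G = F F^\top / (c\,hw) holds the pairwise channel correlations and discards where in the image each feature occurs: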

def gram(X):
    # Assumes batch size 1: X has shape (1, c, h, w), so n = h * w
    num_channels, n = X.shape[1], d2l.size(X) // X.shape[1]
    X = d2l.reshape(X, (num_channels, n))  # F: one row per channel
    return d2l.matmul(X, d2l.transpose(X)) / (num_channels * n)

def style_loss(Y_hat, gram_Y):
    return jnp.square(gram(Y_hat) - jax.lax.stop_gradient(gram_Y)).mean()
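The Gram matrix sums over all spatial positions. As a result, shuffling the positions of a feature map leaves it unchanged, while the content loss does notice the shuffle. Here is a small NumPy check, where `gram_np` is a hypothetical mirror of `gram` (batch size 1 assumed):

```python
import numpy as np

def gram_np(X):
    """NumPy mirror of `gram` for a (1, C, H, W) feature map."""
    c = X.shape[1]
    n = X.size // c
    F = X.reshape(c, n)                 # one row per channel
    return F @ F.T / (c * n)            # C x C channel correlations

rng = np.random.default_rng(0)
X = rng.normal(size=(1, 4, 3, 5))

# Permute spatial positions: same values per channel, different layout
perm = rng.permutation(15)
X_shuffled = X.reshape(1, 4, 15)[:, :, perm].reshape(1, 4, 3, 5)

print(np.allclose(gram_np(X), gram_np(X_shuffled)))   # True: style unchanged
print(np.mean((X - X_shuffled) ** 2) > 0)             # True: content differs
```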

Total variation loss

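Penalizes the mean absolute difference between vertically and horizontally adjacent pixels, which suppresses high-frequency noise and keeps the synthesized image smooth: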

def tv_loss(Y_hat):
    return 0.5 * (d2l.reduce_mean(
        d2l.abs(Y_hat[:, :, 1:, :] - Y_hat[:, :, :-1, :])) +
                  d2l.reduce_mean(
        d2l.abs(Y_hat[:, :, :, 1:] - Y_hat[:, :, :, :-1])))
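The two extremes show what the TV term rewards. A constant image scores 0, and a 0/1 checkerboard, where every neighboring pair differs by 1, scores 1. `tv_loss_np` is a hypothetical NumPy mirror of `tv_loss`:

```python
import numpy as np

def tv_loss_np(Y):
    """NumPy mirror of `tv_loss` for an (N, C, H, W) array."""
    dh = np.abs(Y[:, :, 1:, :] - Y[:, :, :-1, :]).mean()   # vertical neighbors
    dw = np.abs(Y[:, :, :, 1:] - Y[:, :, :, :-1]).mean()   # horizontal neighbors
    return 0.5 * (dh + dw)

flat = np.full((1, 3, 4, 4), 0.7)                  # constant image
i, j = np.indices((4, 4))
checker = np.broadcast_to((i + j) % 2, (1, 3, 4, 4)).astype(float)  # 0/1 checkerboard

print(tv_loss_np(flat))      # 0.0
print(tv_loss_np(checker))   # 1.0
```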

Combined loss

\mathcal{L} = \alpha\, \mathcal{L}_\text{content} + \beta\, \mathcal{L}_\text{style} + \gamma\, \mathcal{L}_\text{tv}.

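The relative weights set the look of the result. In the code below, `content_weight`, `style_weight`, and `tv_weight` play the roles of \alpha, \beta, and \gamma. A large \beta pushes the output toward a painterly texture, while a small \beta keeps it closer to the photograph.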

content_weight, style_weight, tv_weight = 1, 1e4, 10

def compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram):
    # Calculate the content, style, and total variation losses respectively
    contents_l = [content_loss(Y_hat, Y) * content_weight for Y_hat, Y in zip(
        contents_Y_hat, contents_Y)]
    styles_l = [style_loss(Y_hat, Y) * style_weight for Y_hat, Y in zip(
        styles_Y_hat, styles_Y_gram)]
    tv_l = tv_loss(X) * tv_weight
    # Add up all the losses
    l = sum(styles_l + contents_l + [tv_l])
    return contents_l, styles_l, tv_l, l
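Here is the objective expanded term by term, as `compute_loss` assembles it. The notation is introduced only for illustration:

- \mathcal{C} and \mathcal{S} are the content and style layer sets.
- \hat{Y}^{(l)} are the features of the synthesized image X at layer l.
- Y_\text{c}^{(l)} and Y_\text{s}^{(l)} are the features of the content and style images.
- G is the normalized Gram matrix.
- Each mean runs over all elements.

```latex
\mathcal{L}(X) = \alpha \sum_{l \in \mathcal{C}} \operatorname{mean}\Big(\big(\hat{Y}^{(l)} - Y_\text{c}^{(l)}\big)^2\Big)
 + \beta \sum_{l \in \mathcal{S}} \operatorname{mean}\Big(\big(G(\hat{Y}^{(l)}) - G(Y_\text{s}^{(l)})\big)^2\Big)
 + \gamma\, \tfrac{1}{2}\Big(\operatorname{mean}\lvert X_{i+1,j} - X_{i,j}\rvert + \operatorname{mean}\lvert X_{i,j+1} - X_{i,j}\rvert\Big)
```

With the defaults above, \alpha = 1, \beta = 10^4, \gamma = 10. The large \beta compensates for Gram-matrix differences being small in magnitude after the 1/(c\,hw) normalization.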

Initializing the synthesized image

Start from the content image. Starting from random noise also works but usually needs more iterations. The synthesized image is the optimization variable, and the network parameters stay frozen:

# The synthesized image is a plain array (no Module needed); `train` later
# wraps it in a tf.Variable. The `lr` argument is unused here.
def get_inits(X, lr, styles_Y):
    # Initialize synthesized image to the content image
    gen_img = jnp.array(X, dtype=jnp.float32)
    styles_Y_gram = [gram(jnp.array(np.array(Y))) for Y in styles_Y]
    return gen_img, styles_Y_gram
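For the noise alternative mentioned above, here is a minimal sketch of an initializer. The name `get_noise_init` is hypothetical. Drawing unit-variance Gaussian noise in the ImageNet-normalized space is an assumption, chosen because normalized pixel values already have roughly unit scale.

```python
import numpy as np

def get_noise_init(shape, seed=0):
    """Hypothetical alternative to the content initialization in get_inits:
    standard-normal noise in the ImageNet-normalized (N, C, H, W) space."""
    rng = np.random.default_rng(seed)
    return rng.standard_normal(shape).astype(np.float32)

gen_img = get_noise_init((1, 3, 300, 450))
print(gen_img.shape, gen_img.dtype)   # (1, 3, 300, 450) float32
```

`train` extracts its content targets separately, so passing `jnp.array(gen_img)` as its first argument in place of `content_X` should start the optimization from noise. Expect a large initial TV loss.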

Optimization loop

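Adam optimizes the synthesized image itself; L-BFGS is a common alternative. The CNN stays frozen, and gradients flow through the VGG features back to the pixels. Because the VGG-19 weights live in TensorFlow, the training step re-implements the three losses above in TF so that `tf.GradientTape` can differentiate through the network. `_tf_gram` mirrors `gram`, and the JAX loss functions document the same computation: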

def _tf_gram(X):
    """Gram matrix for a (N,C,H,W) TF tensor."""
    num_channels = tf.shape(X)[1]
    n = tf.cast(tf.reduce_prod(tf.shape(X)) // num_channels, tf.float32)
    X_flat = tf.reshape(X, (tf.shape(X)[0], num_channels, -1))
    return tf.matmul(X_flat, tf.transpose(X_flat, (0, 2, 1))) / (
        tf.cast(num_channels, tf.float32) * n)

def train(X, contents_Y, styles_Y, lr, num_epochs, lr_decay_epoch):
    X, styles_Y_gram = get_inits(X, lr, styles_Y)
    # Pre-extract content / style targets ONCE outside the loop and keep them
    # as TF constants so the per-step graph never does host round-trips.
    contents_Y_tf = [tf.constant(np.array(y), dtype=tf.float32)
                     for y in contents_Y]
    styles_Y_gram_tf = [tf.constant(np.array(g), dtype=tf.float32)
                        for g in styles_Y_gram]

    animator = d2l.Animator(xlabel='epoch', ylabel='loss',
                            xlim=[10, num_epochs],
                            legend=['content', 'style', 'TV'],
                            ncols=2, figsize=(7, 2.5))
    # Use TF for gradient computation since VGG features are in TF
    X_tf = tf.Variable(np.array(X))
    tf_optimizer = tf.keras.optimizers.Adam(learning_rate=lr)

    # Compile a single train_step into a TF graph so the per-step Python
    # overhead (and the per-step TF dispatch) is fused into one graph call.
    @tf.function(reduce_retracing=True)
    def train_step(X_tf):
        with tf.GradientTape() as tape:
            X_vgg = _to_vgg_input(X_tf)
            contents_Y_hat_tf, styles_Y_hat_tf = extract_features(
                X_vgg, content_layers, style_layers)
            contents_l_vec = tf.stack([
                tf.reduce_mean(tf.square(yh - y)) * content_weight
                for yh, y in zip(contents_Y_hat_tf, contents_Y_tf)])
            styles_l_vec = tf.stack([
                tf.reduce_mean(tf.square(_tf_gram(yh) - g)) * style_weight
                for yh, g in zip(styles_Y_hat_tf, styles_Y_gram_tf)])
            tv_l = 0.5 * (
                tf.reduce_mean(tf.abs(X_tf[:, :, 1:, :] - X_tf[:, :, :-1, :]))
                + tf.reduce_mean(tf.abs(X_tf[:, :, :, 1:] - X_tf[:, :, :, :-1]))
            ) * tv_weight
            total_loss = (tf.reduce_sum(contents_l_vec)
                          + tf.reduce_sum(styles_l_vec) + tv_l)
        grads = tape.gradient(total_loss, X_tf)
        tf_optimizer.apply_gradients([(grads, X_tf)])
        return contents_l_vec, styles_l_vec, tv_l

    for epoch in range(num_epochs):
        contents_l_vec, styles_l_vec, tv_l = train_step(X_tf)
        # Learning rate decay
        if (epoch + 1) % lr_decay_epoch == 0:
            scale = 0.8 ** ((epoch + 1) // lr_decay_epoch)
            tf_optimizer.learning_rate.assign(lr * scale)
        if (epoch + 1) % 10 == 0:
            animator.axes[1].imshow(postprocess(
                jnp.array(X_tf.numpy())))
            animator.add(epoch + 1,
                         [float(tf.reduce_sum(contents_l_vec)),
                          float(tf.reduce_sum(styles_l_vec)),
                          float(tv_l)])
    return jnp.array(X_tf.numpy())
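The learning-rate decay inside `train` is a staircase schedule. After every `lr_decay_epoch` steps, the rate is reset to `lr * 0.8**k`, so step `e` (0-based) uses `lr * 0.8**(e // lr_decay_epoch)`. The helper `lr_at_step` below is a hypothetical restatement of that rule. In Optax, `optax.exponential_decay` with `staircase=True` should express the same schedule.

```python
def lr_at_step(step, lr=0.3, lr_decay_epoch=50, factor=0.8):
    """Learning rate used at 0-based step `step` by the loop in `train`."""
    return lr * factor ** (step // lr_decay_epoch)

for step in [0, 49, 50, 100, 499]:
    print(step, round(lr_at_step(step), 4))
```

With the settings used below (0.3, 500 steps, decay every 50), the last step runs at `0.3 * 0.8**9`, about 0.04.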

Optimization result

After a few hundred iterations, the content layout should remain recognizable while colors and local textures move toward the style image. The three plotted losses are weighted differently, so compare their trends rather than their raw magnitudes:

image_shape = (300, 450)  # (height, width); preprocess swaps to PIL's (w, h)
content_X, contents_Y = get_contents(image_shape)
_, styles_Y = get_styles(image_shape)
output = train(content_X, contents_Y, styles_Y, 0.3, 500, 50)

Recap

  • Style transfer = optimize pixels to minimize a content loss + a Gram-matrix style loss + TV smoothness loss.
  • The CNN is frozen; we backprop into the image, not the weights.
  • Multi-layer style matching is what gives the recognizable texture-on-content look.
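  • Later variants such as feedforward style networks (one forward pass per image) and AdaIN keep the idea of matching content and style statistics, but they replace per-image optimization with a trained network, so inference is much faster. Diffusion models have also been applied to style transfer.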