Fine-Tuning

You’ll rarely train a vision model from scratch. Transfer learning — start from weights pretrained on a big dataset (ImageNet) and adapt to your small one — is the default recipe.

Fine-tuning: pretrained backbone + new task-specific head.

The standard recipe

  1. Take a pretrained network (ResNet, ViT, etc.).
  2. Replace the output layer with a head for your task.
  3. Optionally freeze early layers; train the rest.
  4. Small LR on the pretrained part, larger LR on the new head.

Setup

%matplotlib inline
import os
from d2l import jax as d2l
import jax
from jax import numpy as jnp
import flax.linen as fnn
import optax
import flaxmodels as fm
import numpy as np
import tensorflow as tf  # only used for tf.data input pipeline

A tiny binary classification dataset (hot dog / not hot dog) — too small to train a CNN from scratch, perfect for transfer learning:

d2l.DATA_HUB['hotdog'] = (d2l.DATA_URL + 'hotdog.zip', 
                         'fba480ffa8aa7e0febbb511d181409f899b9baa5')

data_dir = d2l.download_extract('hotdog')
# Load images as (PIL.Image, label) lists for compatibility with show_images
from PIL import Image as _PILImage
import pathlib

def _load_image_folder(path):
    """Load images from a directory with class subfolders, returning
    a list of (PIL.Image, class_index) tuples."""
    path = pathlib.Path(path)
    class_names = sorted([p.name for p in path.iterdir() if p.is_dir()])
    class_to_idx = {c: i for i, c in enumerate(class_names)}
    items = []
    for cls in class_names:
        for img_path in sorted((path / cls).iterdir()):
            try:
                img = _PILImage.open(str(img_path)).convert('RGB')
                items.append((img, class_to_idx[cls]))
            except Exception:
                continue
    return items

train_imgs = _load_image_folder(os.path.join(data_dir, 'train'))
test_imgs = _load_image_folder(os.path.join(data_dir, 'test'))
hotdogs = [train_imgs[i][0] for i in range(8)]
not_hotdogs = [train_imgs[-i - 1][0] for i in range(8)]
d2l.show_images(hotdogs + not_hotdogs, 2, 8, scale=1.4);

Augmentation pipelines

Standard ImageNet-style augmentation: a random crop plus a random horizontal flip for training, and a deterministic center crop for evaluation. Match the preprocessing convention that the pretrained model expects:

# Image preprocessing. We use `tf.image` ops so the pipeline can run
# inside `tf.data.Dataset.map`. The ImageNet RGB mean/std normalization
# matches the preprocessing expected by the `flaxmodels` pretrained
# ResNet-18 weights.
IMG_SIZE = 224
_IMAGENET_MEAN = tf.constant([0.485, 0.456, 0.406], dtype=tf.float32)
_IMAGENET_STD  = tf.constant([0.229, 0.224, 0.225], dtype=tf.float32)

def _normalize(x):
    return (tf.cast(x, tf.float32) / 255.0 - _IMAGENET_MEAN) / _IMAGENET_STD

def train_preprocess(x):
    # `x` is a (256, 256, 3) float32 RGB image with values in [0, 255].
    x = tf.image.random_crop(x, size=(IMG_SIZE, IMG_SIZE, 3))
    x = tf.image.random_flip_left_right(x)
    return _normalize(x)

def test_preprocess(x):
    # Deterministic center crop (or zero-pad, if x is smaller) to IMG_SIZE x IMG_SIZE.
    x = tf.image.resize_with_crop_or_pad(x, IMG_SIZE, IMG_SIZE)
    return _normalize(x)
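To see what the evaluation path does numerically, here is a NumPy sketch of it. It is not part of the `tf.data` pipeline: `center_crop_np` and `normalize_np` are hypothetical helpers, and the input is assumed to be a 256×256 RGB image with values in [0, 255], as stated in the `train_preprocess` comment. The crop starts 16 pixels in from each edge, since (256 − 224) / 2 = 16, and an image filled with the ImageNet mean color should normalize to (almost) zero.

```python
import numpy as np

# Same ImageNet RGB statistics as _IMAGENET_MEAN / _IMAGENET_STD above.
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def center_crop_np(x, size=224):
    """Return the central size x size window of an (H, W, C) image."""
    h, w = x.shape[:2]
    top, left = (h - size) // 2, (w - size) // 2
    return x[top:top + size, left:left + size]

def normalize_np(x):
    """Map [0, 255] pixels to [0, 1], then standardize each RGB channel."""
    return (x.astype(np.float32) / 255.0 - MEAN) / STD

# A 256x256 image whose every pixel is the ImageNet mean color.
img = np.broadcast_to(MEAN * 255.0, (256, 256, 3)).astype(np.float32)
out = normalize_np(center_crop_np(img))
print(out.shape, float(np.abs(out).max()))
```

If the model were instead built with its own normalization turned on, the mean/std step would be applied twice, which is the kind of mismatch the paragraph above warns about.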

Inspect the pretrained head

The source model was trained for 1000 ImageNet classes. Its convolutional body is reusable; the final classifier is task-specific and will be replaced:

# Load a pretrained ResNet-18 (Flax) with ImageNet weights via flaxmodels.
# `normalize=False` because `_normalize` already applies the ImageNet mean/std.
pretrained_net = fm.ResNet18(output='logits', pretrained='imagenet',
                             normalize=False)
# Initialize to materialize parameters (and download/cache the pretrained
# weights). The `init` call returns both `params` and `batch_stats`.
_init_key = jax.random.PRNGKey(0)
_dummy = jnp.zeros((1, IMG_SIZE, IMG_SIZE, 3), dtype=jnp.float32)
pretrained_vars = pretrained_net.init(_init_key, _dummy, train=False)

Replace the task head

Create a target model with the same pretrained backbone and a randomly initialized 2-way classifier for hot dog vs. not hot dog:

# The 1000-way ImageNet classifier is the final Dense layer of the network.
pretrained_vars['params']['Dense_0']['kernel'].shape
(512, 1000)

Discriminative learning rates

Let \theta_b denote the pretrained backbone parameters and \theta_h the parameters of the new head. Use a small step size on \theta_b and a larger one on \theta_h:

\eta_b = \eta,\qquad \eta_h = 10\eta.
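Spelled out for one plain SGD step (a simplifying assumption: the optimizer inside the training helper may also use momentum or weight decay), the two step sizes give:

\theta_b \leftarrow \theta_b - \eta\,\nabla_{\theta_b}\mathcal{L},\qquad \theta_h \leftarrow \theta_h - 10\eta\,\nabla_{\theta_h}\mathcal{L}.

For example, with \eta = 10^{-4} the backbone moves with step size 10^{-4} while the head moves with 10^{-3}.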

# Pretrained ResNet-18 backbone + a fresh 2-way classification head.
# The backbone returns the dictionary of intermediate activations; we use
# the final 7x7x512 feature map (`block4_1`) and globally average-pool it.
class FineTuneResNet18(fnn.Module):
    num_classes: int = 2
    @fnn.compact
    def __call__(self, x, train: bool):
        backbone = fm.ResNet18(output='activations', pretrained='imagenet',
                               normalize=False)
        # Keep ImageNet BatchNorm statistics fixed for the tiny target set.
        feats = backbone(x, train=False)['block4_1']  # (B, 7, 7, 512)
        feats = jnp.mean(feats, axis=(1, 2))          # global avg pool -> (B, 512)
        logits = fnn.Dense(self.num_classes,
                           kernel_init=fnn.initializers.glorot_uniform(),
                           name='classifier')(feats)
        return logits

finetune_net = FineTuneResNet18(num_classes=2)
# Initialize the wrapper. The backbone sub-module loads ImageNet weights
# from the flaxmodels checkpoint during init; only the new `classifier`
# Dense layer is randomly initialized.
finetune_vars = finetune_net.init(jax.random.PRNGKey(1), _dummy, train=False)

Training helper

The `train_fine_tuning` helper hides the framework details: parameter groups, optimizer construction, metric logging, and the `param_group` switch (turned off for the scratch baseline). The four-step pattern is:

  • build the pretrained backbone and new head;
  • assign a small learning rate to backbone parameters;
  • assign a larger learning rate to the randomly initialized head;
  • train and compare against a scratch baseline.

Run fine-tuning

With matched ImageNet preprocessing and a small base LR, the pretrained model should reach useful accuracy within a few epochs. The point is not only a better final score: fine-tuning needs far less data and compute than training the same network from scratch.

print('fine-tuned model')
finetune_vars = train_fine_tuning(finetune_net, finetune_vars, 1e-4)
fine-tuned model
epoch 1, loss 0.834, train acc 0.508, test acc 0.688
epoch 2, loss 0.566, train acc 0.715, test acc 0.772
epoch 3, loss 0.459, train acc 0.796, test acc 0.824
epoch 4, loss 0.391, train acc 0.836, test acc 0.862
epoch 5, loss 0.354, train acc 0.853, test acc 0.880

From-scratch baseline

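Same architecture, no pretraining. On this small dataset it trails the fine-tuned model in test accuracy at every epoch and finishes at 0.829 vs. 0.880 (loss 0.508 vs. 0.354), which illustrates why transfer learning is the default: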

print('scratch baseline')
# Train from scratch: same architecture but with random weights.
class ScratchResNet18(fnn.Module):
    num_classes: int = 2
    @fnn.compact
    def __call__(self, x, train: bool):
        backbone = fm.ResNet18(output='activations', pretrained=None,
                               normalize=False)
        feats = backbone(x, train=train)['block4_1']
        feats = jnp.mean(feats, axis=(1, 2))
        return fnn.Dense(self.num_classes,
                         kernel_init=fnn.initializers.glorot_uniform(),
                         name='classifier')(feats)

scratch_net = ScratchResNet18(num_classes=2)
scratch_vars = scratch_net.init(jax.random.PRNGKey(2), _dummy, train=False)
scratch_vars = train_fine_tuning(scratch_net, scratch_vars, 5e-4,
                                 param_group=False)
scratch baseline
epoch 1, loss 0.736, train acc 0.520, test acc 0.643
epoch 2, loss 0.625, train acc 0.686, test acc 0.764
epoch 3, loss 0.579, train acc 0.757, test acc 0.803
epoch 4, loss 0.545, train acc 0.778, test acc 0.824
epoch 5, loss 0.508, train acc 0.797, test acc 0.829

What to vary

The natural ablations are: freeze more or fewer layers, change the backbone/head learning-rate ratio, and reuse the output-layer weights of ImageNet's own “hotdog” class (index 934) when initializing the new head.

# Freeze the pretrained ResNet-18 backbone so that only the new `classifier`
# head is updated: route every other parameter through `optax.set_to_zero()`,
# which zeroes its updates. For example, modify `train_fine_tuning` to use:
#   optax.multi_transform(
#       {'head': optax.sgd(lr * 10, momentum=0.9),
#        'base': optax.set_to_zero()},
#       labels)
# where `labels` mirrors the params tree, marking `classifier` leaves 'head'
# and all other leaves 'base'.
# The pretrained classifier maps 512-dim features to 1000 ImageNet classes.
weight = pretrained_vars['params']['Dense_0']['kernel']  # Shape: (512, 1000)
hotdog_w = weight[:, 934]  # ImageNet-1k class 934 is "hotdog, hot dog, red hot"
hotdog_w.shape
(512,)
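One way to use `hotdog_w`, sketched under assumptions: copy it into the column of the new 2-way head that corresponds to class 0 (`hotdog` sorts before `not-hotdog`, so `_load_image_folder` assigns it index 0) and keep the random initialization for the other column. `init_head_from_hotdog` is a hypothetical helper, and a random matrix stands in for the real `(512, 1000)` ImageNet kernel so the snippet runs on its own.

```python
import jax
import jax.numpy as jnp

def init_head_from_hotdog(head_kernel, imagenet_kernel, hotdog_idx=934, target_col=0):
    """Hypothetical helper: return a copy of the (512, 2) `head_kernel` whose
    `target_col` column holds ImageNet's output weights for class `hotdog_idx`."""
    return head_kernel.at[:, target_col].set(imagenet_kernel[:, hotdog_idx])

k_src, k_head = jax.random.split(jax.random.PRNGKey(0))
# Random stand-in for pretrained_vars['params']['Dense_0']['kernel'], shape (512, 1000).
imagenet_kernel = jax.random.normal(k_src, (512, 1000))
# Fresh 2-way head kernel with the same Glorot-uniform init as the `classifier` layer.
head_kernel = jax.nn.initializers.glorot_uniform()(k_head, (512, 2))
new_kernel = init_head_from_hotdog(head_kernel, imagenet_kernel)
print(new_kernel.shape)
```

In the actual model, `imagenet_kernel` would be `weight` from the cell above, and the result would replace `finetune_vars['params']['classifier']['kernel']` before training. Whether this initialization helps on the hot-dog data is exactly the ablation suggested above.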

Recap

  • Transfer learning: pretrained backbone + new head; almost always beats from-scratch on small / medium datasets.
  • Use a small LR on the backbone (10×–100× smaller than the head LR): pretrained features need only small adjustments.
  • Match input preprocessing (mean/std normalization, input size, or model-specific preprocess_input) to what the pretrained model expects.
  • Modern variants: feature-extractor mode (freeze everything but head), full fine-tune (everything trains), parameter-efficient methods (LoRA, adapters).