%matplotlib inline
import os
from d2l import jax as d2l
import jax
from jax import numpy as jnp
import flax.linen as fnn
import optax
import flaxmodels as fm
import numpy as np
import tensorflow as tf  # only used for the tf.data / tf.image input pipeline

You’ll rarely train a vision model from scratch. Transfer learning — starting from weights pretrained on a big dataset (ImageNet) and adapting them to your small one — is the default recipe.
Fine-tuning: pretrained backbone + new task-specific head.
A tiny binary classification dataset (hot dog / not hot dog) — too small to train a CNN from scratch, perfect for transfer learning:
# Load images as (PIL.Image, label) lists for compatibility with show_images
from PIL import Image as _PILImage
import pathlib
def _load_image_folder(path):
    """Load images from a directory that has one subfolder per class.

    Class indices are assigned by sorting the subfolder names, and the
    files inside each class folder are visited in sorted order, so the
    result is deterministic for a given directory tree.

    Args:
        path: Directory (str or path-like) whose immediate subdirectories
            are the classes.

    Returns:
        A list of ``(PIL.Image, class_index)`` tuples, each image converted
        to RGB. Entries that are not regular files, or that cannot be
        opened or decoded, are skipped.
    """
    root = pathlib.Path(path)
    class_names = sorted(p.name for p in root.iterdir() if p.is_dir())
    items = []
    for class_index, cls in enumerate(class_names):
        for img_path in sorted((root / cls).iterdir()):
            # Nested directories and other non-file entries are not images.
            if not img_path.is_file():
                continue
            try:
                # `convert` returns an independent in-memory image, so the
                # underlying file handle can be released right away.
                with _PILImage.open(str(img_path)) as img:
                    items.append((img.convert('RGB'), class_index))
            except Exception:
                # Deliberate best-effort: stray non-image or corrupt files
                # in the dataset folder must not abort the whole load.
                continue
    return items
# Eagerly load both splits as lists of (PIL.Image, class_index) tuples.
# NOTE(review): `data_dir` is not defined in the visible part of this file;
# it must be set (dataset download/extract step) before this runs — confirm.
train_imgs = _load_image_folder(os.path.join(data_dir, 'train'))
test_imgs = _load_image_folder(os.path.join(data_dir, 'test'))

# Standard ImageNet recipe — random resized crop + flip for training, center
# crop for eval. Match the preprocessing convention that the pretrained model
# expects:
# Image preprocessing. We use `tf.image` ops so the pipeline can run
# inside `tf.data.Dataset.map`. The ImageNet RGB mean/std normalization
# matches the preprocessing expected by the `flaxmodels` pretrained
# ResNet-18 weights (and the PyTorch/MXNet tabs).
# Side length, in pixels, of the square crop fed to the network.
IMG_SIZE = 224
# Per-channel (R, G, B) statistics; `_normalize` applies them after the
# pixel values have been scaled from [0, 255] to [0, 1].
_IMAGENET_MEAN = tf.constant([0.485, 0.456, 0.406], dtype=tf.float32)
_IMAGENET_STD = tf.constant([0.229, 0.224, 0.225], dtype=tf.float32)
def _normalize(x):
    """Scale pixels from [0, 255] to [0, 1], then standardize each channel.

    The result is ``(x / 255 - mean) / std`` using the module-level ImageNet
    RGB mean and standard deviation, computed in float32.
    """
    scaled = tf.cast(x, tf.float32) / 255.0
    centered = scaled - _IMAGENET_MEAN
    return centered / _IMAGENET_STD
def train_preprocess(x):
    """Training-time augmentation followed by ImageNet normalization.

    `x` is expected to be a (256, 256, 3) float32 RGB image with values in
    [0, 255]. A random IMG_SIZE x IMG_SIZE window is cropped, flipped
    left-right at random, and then normalized.
    """
    window = tf.image.random_crop(x, size=(IMG_SIZE, IMG_SIZE, 3))
    flipped = tf.image.random_flip_left_right(window)
    return _normalize(flipped)
def test_preprocess(x):
    """Deterministic evaluation transform.

    Centrally crops (or zero-pads) `x` to IMG_SIZE x IMG_SIZE and then
    applies the same ImageNet normalization as the training pipeline.
    """
    x = tf.image.resize_with_crop_or_pad(x, IMG_SIZE, IMG_SIZE)
    return _normalize(x)


# The source model was trained for 1000 ImageNet classes. Its convolutional
# body is reusable; the final classifier is task-specific and will be
# replaced:
# Load a pretrained ResNet-18 (Flax) with ImageNet weights via flaxmodels.
# NOTE(review): `normalize=False` presumably because the input pipeline
# already applies the ImageNet mean/std normalization — confirm.
pretrained_net = fm.ResNet18(output='logits', pretrained='imagenet',
                             normalize=False)
# Initialize to materialize parameters (and download/cache the pretrained
# weights). The `init` call returns both `params` and `batch_stats`.
_init_key = jax.random.PRNGKey(0)
_dummy = jnp.zeros((1, IMG_SIZE, IMG_SIZE, 3), dtype=jnp.float32)
pretrained_vars = pretrained_net.init(_init_key, _dummy, train=False)

# Create a target model with the same pretrained backbone and a randomly
# initialized 2-way classifier for hot dog vs. not hot dog:
(512, 1000)
Let $\theta_b$ be the pretrained backbone parameters and $\theta_h$ the parameters of the new head. Use a small step on $\theta_b$ and a larger one on $\theta_h$:
$$\eta_b = \eta,\qquad \eta_h = 10\eta.$$
# Pretrained ResNet-18 backbone + a fresh 2-way classification head.
# The backbone returns the dictionary of intermediate activations; we use
# the final 7x7x512 feature map (`block4_1`) and globally average-pool it.
class FineTuneResNet18(fnn.Module):
    """ImageNet-pretrained ResNet-18 feature extractor with a new linear head.

    Attributes:
        num_classes: Number of output logits produced by the `classifier`
            Dense layer.
    """
    num_classes: int = 2

    @fnn.compact
    def __call__(self, x, train: bool):
        """Return `(B, num_classes)` logits for a batch of images `x`.

        `train` is accepted for interface parity with the training helper
        but is not forwarded: the backbone always runs with `train=False`.
        """
        resnet = fm.ResNet18(output='activations', pretrained='imagenet',
                             normalize=False)
        # Keep ImageNet BatchNorm statistics fixed for the tiny target set.
        activations = resnet(x, train=False)
        last_block = activations['block4_1']  # (B, 7, 7, 512)
        pooled = jnp.mean(last_block, axis=(1, 2))  # global avg pool -> (B, 512)
        head = fnn.Dense(self.num_classes,
                         kernel_init=fnn.initializers.glorot_uniform(),
                         name='classifier')
        return head(pooled)
finetune_net = FineTuneResNet18(num_classes=2)
# Initialize the wrapper. The backbone sub-module loads ImageNet weights
# from the flaxmodels checkpoint during init; only the new `classifier`
# Dense layer is randomly initialized.
finetune_vars = finetune_net.init(jax.random.PRNGKey(1), _dummy, train=False)

# The helper hides framework details: parameter groups, optimizer
# construction, metric logging, and the scratch/fine-tune switch. The
# four-step pattern is:
With matched ImageNet preprocessing and a small base LR, the pretrained model should reach useful accuracy quickly. The point is not just a better final score; it is much less data and compute than training the same network cold.
fine-tuned model
epoch 1, loss 0.834, train acc 0.508, test acc 0.688
epoch 2, loss 0.566, train acc 0.715, test acc 0.772
epoch 3, loss 0.459, train acc 0.796, test acc 0.824
epoch 4, loss 0.391, train acc 0.836, test acc 0.862
epoch 5, loss 0.354, train acc 0.853, test acc 0.880
Same architecture, no pretraining. It ends up worse on this small dataset (0.829 vs. 0.880 test accuracy after five epochs in the runs shown here), which illustrates why transfer learning is the default:
# Print a header so the epoch log that follows can be told apart from the
# fine-tuned run's log.
print('scratch baseline')
# Train from scratch: same architecture but with random weights.
class ScratchResNet18(fnn.Module):
    """ResNet-18 without pretrained weights plus a linear classification head.

    Attributes:
        num_classes: Number of output logits produced by the `classifier`
            Dense layer.
    """
    num_classes: int = 2

    @fnn.compact
    def __call__(self, x, train: bool):
        """Return `(B, num_classes)` logits; `train` is forwarded to the backbone."""
        resnet = fm.ResNet18(output='activations', pretrained=None,
                             normalize=False)
        activations = resnet(x, train=train)
        # Global average pool over the two spatial axes of the last block.
        pooled = jnp.mean(activations['block4_1'], axis=(1, 2))
        head = fnn.Dense(self.num_classes,
                         kernel_init=fnn.initializers.glorot_uniform(),
                         name='classifier')
        return head(pooled)
scratch_net = ScratchResNet18(num_classes=2)
scratch_vars = scratch_net.init(jax.random.PRNGKey(2), _dummy, train=False)
scratch_vars = train_fine_tuning(scratch_net, scratch_vars, 5e-4,
                                 param_group=False)
# Output:
# scratch baseline
epoch 1, loss 0.736, train acc 0.520, test acc 0.643
epoch 2, loss 0.625, train acc 0.686, test acc 0.764
epoch 3, loss 0.579, train acc 0.757, test acc 0.803
epoch 4, loss 0.545, train acc 0.778, test acc 0.824
epoch 5, loss 0.508, train acc 0.797, test acc 0.829
The natural ablations are: freeze more or fewer layers, change the backbone/head learning-rate ratio, and compare against the source ImageNet “hotdog” class weights.
# Freeze the pretrained ResNet-18 backbone; only the new `classifier` head
# is updated, because the updates of every other parameter are set to
# zero. For example, modify `train_fine_tuning` to use:
#     optax.multi_transform(
#         {'head': optax.sgd(lr * 10, momentum=0.9),
#          'base': optax.set_to_zero()},
#         labels)

[…] preprocess_input) to what the pretrained model expects.