import collections
from d2l import jax as d2l
import jax
from jax import numpy as jnp
from flax import linen as nn
import optax
import numpy as np
import tensorflow as tf
import math
import os
import pandas as pd
import shutil

A capstone deck: assemble everything from the chapter (augmentation, fine-tuning, modern CNN architectures) and take a Kaggle competition. CIFAR-10 has been done to death, but it’s the right size for a teaching example — small enough to fit in memory, big enough that augmentation and ensembling matter.
Kaggle CIFAR-10 competition page.
Tiny demo subset for the book; swap in the full dataset for the actual competition:
# Register the tiny demo archive under the 'cifar10_tiny' key of the d2l data
# hub: download URL first, then the hex digest string that accompanies it
# (presumably a checksum used by the downloader -- confirm against d2l).
d2l.DATA_HUB['cifar10_tiny'] = (
    d2l.DATA_URL + 'kaggle_cifar10_tiny.zip',
    '2068874e4b9a9f0fb07ebe0ad2b29754449ccacd',
)

# If you use the full dataset downloaded for the Kaggle competition, set
# `demo` to False; the data is then read from '../data/cifar-10/'.
demo = True
data_dir = (d2l.download_extract('cifar10_tiny') if demo
            else '../data/cifar-10/')

# Kaggle ships everything in one folder; standard torchvision-style training
# expects train/<class>/img.png. Build that layout from the labels.csv:
def read_csv_labels(fname):
    """Read `fname` to return a filename to label dictionary.

    The first line is treated as the column header and skipped. Every
    remaining non-blank line must hold exactly two comma-separated fields,
    `name,label`. Blank lines (for example a trailing empty line at the end
    of the file) are ignored instead of aborting the whole read.

    Args:
        fname: Path of the CSV file to read.

    Returns:
        A dict mapping each name to its label. It is empty when the file
        contains no data rows.

    Raises:
        ValueError: If a non-blank data line does not split into exactly
            two comma-separated fields.
    """
    name_to_label = {}
    with open(fname, 'r') as f:
        # Skip the file header line (column name). The default argument
        # keeps a completely empty file from raising StopIteration.
        next(f, None)
        for raw_line in f:
            record = raw_line.rstrip()
            if not record:
                # Nothing to parse on a blank line.
                continue
            name, label = record.split(',')
            name_to_label[name] = label
    return name_to_label
# Parse the training labels file found in `data_dir`, then report how many
# labelled examples and how many distinct classes it contains.
labels = read_csv_labels(os.path.join(data_dir, 'trainLabels.csv'))
print('# training examples:', len(labels))
print('# classes:', len(set(labels.values())))
# Recorded notebook output:
#   # training examples: 1000
#   # classes: 10
def copyfile(filename, target_dir):
    """Copy `filename` into `target_dir`, creating the directory if needed.

    The copy keeps the source file's base name. Copying the same file twice
    simply overwrites the earlier copy.
    """
    os.makedirs(target_dir, exist_ok=True)
    destination = os.path.join(target_dir, os.path.basename(filename))
    shutil.copy(filename, destination)
def reorg_train_valid(data_dir, labels, valid_ratio):
    """Split the validation set out of the original training set.

    Every file listed under `<data_dir>/train` is copied into
    `<data_dir>/train_valid_test/train_valid/<label>`. In directory-listing
    order, the first `n_valid_per_label` files of each label are also copied
    into `.../valid/<label>`; all later files of that label go into
    `.../train/<label>` instead.

    Args:
        data_dir: Dataset root containing the `train` folder.
        labels: Mapping from file name without extension to class label.
        valid_ratio: Fraction of the rarest class's size to hold out per
            class for validation.

    Returns:
        The number of validation examples taken per class, i.e.
        `max(1, floor(rarest_class_size * valid_ratio))`.
    """
    # Size of the class with the fewest examples in the training labels.
    _, fewest = collections.Counter(labels.values()).most_common()[-1]
    n_valid_per_label = max(1, math.floor(fewest * valid_ratio))

    source_dir = os.path.join(data_dir, 'train')
    split_root = os.path.join(data_dir, 'train_valid_test')
    # How many files of each label have already gone to the validation set;
    # a Counter reads as 0 for labels not seen yet.
    valid_taken = collections.Counter()
    for train_file in os.listdir(source_dir):
        label = labels[train_file.partition('.')[0]]
        src = os.path.join(source_dir, train_file)
        copyfile(src, os.path.join(split_root, 'train_valid', label))
        goes_to_valid = valid_taken[label] < n_valid_per_label
        subset = 'valid' if goes_to_valid else 'train'
        copyfile(src, os.path.join(split_root, subset, label))
        if goes_to_valid:
            valid_taken[label] += 1
    return n_valid_per_label


# Standard recipe — random crop, flip, normalize for train; just normalize
# for eval:
# Three-element mean and standard deviation that the transform functions
# subtract / divide by after scaling pixel values by 1/255. The three entries
# broadcast over the last (channel) axis of a 32x32x3 image; presumably in
# RGB order -- confirm against the image loader.
CIFAR_MEAN = np.array([0.4914, 0.4822, 0.4465], dtype=np.float32)
CIFAR_STD = np.array([0.2023, 0.1994, 0.2010], dtype=np.float32)
def transform_train_fn(image, label):
    """Augment and normalize a single training example.

    The image is cast to float32, resized to 40x40, randomly cropped to
    32x32x3, randomly mirrored left-right, scaled by 1/255 and finally
    standardized with CIFAR_MEAN / CIFAR_STD. The label is returned
    unchanged.
    """
    enlarged = tf.image.resize(tf.cast(image, tf.float32), [40, 40])
    cropped = tf.image.random_crop(enlarged, size=[32, 32, 3])
    mirrored = tf.image.random_flip_left_right(cropped)
    unit_range = mirrored / 255.0
    return (unit_range - CIFAR_MEAN) / CIFAR_STD, label
def transform_test_fn(image, label):
    """Normalize a single evaluation example without any augmentation.

    The image is cast to float32, scaled by 1/255 and standardized with
    CIFAR_MEAN / CIFAR_STD. The label is returned unchanged.
    """
    unit_range = tf.cast(image, tf.float32) / 255.0
    standardized = (unit_range - CIFAR_MEAN) / CIFAR_STD
    return standardized, label


# Folder-based dataset + the augmentation pipelines:
def _load_image_folder_tf(folder_path):
    """Load a class-per-subfolder image directory as an unbatched dataset.

    Args:
        folder_path: Directory handed to Keras' directory loader; labels
            are inferred from its subdirectories.

    Returns:
        A `tf.data.Dataset` of `(image, label)` pairs, one per image, in the
        loader's unshuffled order. Images are resized to 32x32 and labels
        are integer class indices.

        NOTE(review): the Keras documentation describes the images yielded
        by `image_dataset_from_directory` as float32 tensors (pixel values
        still in the 0-255 range), not uint8. The transform functions cast
        to float32 and divide by 255 themselves, so either dtype works
        downstream.
    """
    return tf.keras.utils.image_dataset_from_directory(
        folder_path,
        label_mode='int',
        image_size=(32, 32),
        # batch_size=None yields individual examples, so batching can be
        # applied after the per-example transforms.
        batch_size=None,
        shuffle=False,
    )
# Read the four reorganized splits, in the order train, train_valid, valid,
# test, from their class-subfolder directories.
train_ds, train_valid_ds, valid_ds, test_ds = (
    _load_image_folder_tf(os.path.join(data_dir, 'train_valid_test', split))
    for split in ('train', 'train_valid', 'valid', 'test'))

# NOTE(review): `batch_size` is not defined anywhere in the visible source;
# it has to be assigned before this point.

# Training pipelines: augment each example, shuffle, form full batches only,
# and prefetch.
train_iter = (
    train_ds
    .map(transform_train_fn, num_parallel_calls=tf.data.AUTOTUNE)
    .shuffle(10000)
    .batch(batch_size, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE)
)
train_valid_iter = (
    train_valid_ds
    .map(transform_train_fn, num_parallel_calls=tf.data.AUTOTUNE)
    .shuffle(10000)
    .batch(batch_size, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE)
)

# Evaluation pipelines: normalization only, no shuffling. Validation drops
# the final partial batch; the test pipeline keeps it so that every test
# image receives a prediction.
valid_iter = (
    valid_ds
    .map(transform_test_fn, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(batch_size, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE)
)
test_iter = (
    test_ds
    .map(transform_test_fn, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(batch_size, drop_remainder=False)
    .prefetch(tf.data.AUTOTUNE)
)

# No transfer learning this time — CIFAR-10 is small enough to train from
# scratch. The core unit is the same residual block from the ResNet chapter:
# two 3×3 convs plus an identity or 1×1 projection shortcut.
Four residual stages progressively downsample the image and widen channels. Global average pooling removes spatial dimensions; the final dense layer emits 10 class logits:
Across frameworks, get_net returns the same contract: input minibatches of CIFAR-10 images, output logits with shape (batch, 10), and cross-entropy as the training loss.
class Residual(nn.Module):
    """Residual block: two 3x3 convolutions added to a shortcut branch."""

    # Number of output channels of every convolution in the block.
    num_channels: int
    # When True, the shortcut passes through a 1x1 convolution instead of
    # being the unmodified input.
    use_1x1conv: bool = False
    # Stride (used for both spatial axes) of the first 3x3 convolution and of
    # the 1x1 shortcut convolution.
    strides: int = 1

    @nn.compact
    def __call__(self, X, training=False):
        """Apply the block to `X`.

        `training` selects batch statistics (True) or running averages
        (False) inside the batch-normalization layers.
        """
        stride = (self.strides, self.strides)
        use_running_average = not training

        # Main branch: conv -> BN -> ReLU -> conv -> BN. Submodules of the
        # same class are created in the same order as before (3x3 conv,
        # 3x3 conv, optional 1x1 conv).
        main = nn.Conv(self.num_channels, kernel_size=(3, 3),
                       strides=stride, padding='SAME')(X)
        main = nn.BatchNorm(use_running_average=use_running_average)(main)
        main = nn.relu(main)
        main = nn.Conv(self.num_channels, kernel_size=(3, 3),
                       padding='SAME')(main)
        main = nn.BatchNorm(use_running_average=use_running_average)(main)

        # Shortcut branch: the input itself, or a strided 1x1 projection.
        shortcut = X
        if self.use_1x1conv:
            shortcut = nn.Conv(self.num_channels, kernel_size=(1, 1),
                               strides=stride)(X)
        return nn.relu(main + shortcut)
class ResNet18(nn.Module):
    """Convolutional stem, four stages of two residual blocks, then a
    global-average-pooled dense classifier."""

    # Length of the logits vector produced by the final dense layer.
    num_classes: int = 10

    @nn.compact
    def __call__(self, X, training=False):
        """Return class logits for `X`; `training` is forwarded to every
        batch-normalization layer and residual block."""
        # Stem: 3x3 convolution with stride 1, batch norm, ReLU.
        stem = nn.Conv(64, kernel_size=(3, 3), strides=(1, 1),
                       padding='SAME')(X)
        stem = nn.BatchNorm(use_running_average=not training)(stem)
        features = nn.relu(stem)

        # One row per residual block: (channels, use_1x1conv, strides).
        # Stages 2-4 open with a strided projection block.
        block_plan = (
            (64, False, 1), (64, False, 1),      # Block 1
            (128, True, 2), (128, False, 1),     # Block 2
            (256, True, 2), (256, False, 1),     # Block 3
            (512, True, 2), (512, False, 1),     # Block 4
        )
        for channels, project, stride in block_plan:
            block = Residual(channels, use_1x1conv=project, strides=stride)
            features = block(features, training=training)

        pooled = jnp.mean(features, axis=(1, 2))  # Global average pooling
        return nn.Dense(self.num_classes)(pooled)
def get_net():
    """Build the model used in this section: a ResNet18 with 10 outputs."""
    model = ResNet18(num_classes=10)
    return model
def loss_fn(logits, labels):
    """Softmax cross-entropy between `logits` and integer class `labels`.

    The optax result is returned as-is; no reduction is applied here.
    """
    losses = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=labels)
    return losses


# SGD with momentum + weight decay + LR step decay is the classic small-image
# vision recipe. The long helper mainly adapts that recipe to each framework,
# so teach the invariant loop:
Use the validation split for model selection. Training loss should decline smoothly; validation accuracy is the signal for whether augmentation and the learning-rate schedule are helping rather than just fitting the train set.
train loss 0.703, train acc 0.742, valid acc 0.438
1623.4 examples/sec
Run on the test set, write a Kaggle-format CSV:
# Retrain on the combined train+valid split (no validation iterator is
# passed), predict every test image, and write the Kaggle submission file.
# NOTE(review): `train` and the hyperparameters `num_epochs`, `lr`, `wd`,
# `lr_period` and `lr_decay` are not defined in the visible source; they
# must be provided before this point.
net, preds = get_net(), []
variables = train(net, train_valid_iter, None, num_epochs, lr, wd, lr_period,
                  lr_decay)

for X, _ in test_iter:
    X_jax = jnp.array(X.numpy())  # Already NHWC from tf.data
    y_hat = net.apply(variables, X_jax, training=False)
    preds.extend(np.array(y_hat.argmax(axis=-1)))

# Get class names from the train_valid dataset directory
class_names = sorted(os.listdir(
    os.path.join(data_dir, 'train_valid_test', 'train_valid')))

# `test_iter` keeps its final partial batch, so `preds` holds exactly one
# entry per test image; counting it avoids decoding the whole test set a
# second time. Ids are ordered as strings ('1', '10', '100', ...) so that
# they line up with the predictions -- this assumes the test files are named
# by id and that the directory loader lists them in lexicographic order
# (see the Keras docs on shuffle=False).
sorted_ids = sorted(range(1, len(preds) + 1), key=str)
df = pd.DataFrame({'id': sorted_ids, 'label': preds})
df['label'] = df['label'].apply(lambda x: class_names[x])
df.to_csv('submission.csv', index=False)

# Recorded notebook output:
#   train loss 0.719, train acc 0.725
#   2118.2 examples/sec