Dog Breed Identification (ImageNet Dogs) on Kaggle

Kaggle Dog Breed

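A second Kaggle capstone: ImageNet Dogs (120 fine-grained breeds). The big difference from CIFAR-10: this dataset is a subset of ImageNet, so a ResNet pretrained on ImageNet has already learned features for almost exactly these classes. Transfer learning is therefore the right play; in this version the pretrained backbone stays frozen and only a small new classifier head is trained.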

Kaggle “Dog Breed Identification” page.

from d2l import jax as d2l
import jax
from jax import numpy as jnp
from flax import linen as nn
import optax
import numpy as np
import tensorflow as tf
import os
E0524 02:33:44.330891 3930 cuda_executor.cc:1206] [0] Failed to allocate device memory: INTERNAL: [0] Failed to allocate 9.41GiB (10100251136 bytes) of ...
E0524 02:33:44.331318 3930 cuda_executor.cc:1206] [0] Failed to allocate device memory: INTERNAL: [0] Failed to allocate 8.47GiB (9090225152 bytes) of ...

Downloading

# Register the tiny demo archive with d2l's downloader as (URL, expected hash).
d2l.DATA_HUB['dog_tiny'] = (
    d2l.DATA_URL + 'kaggle_dog_tiny.zip',
    '0cb91d09b814ecdc07b50f31f8dcad3e81d6a86d',
)

# Keep `demo = True` to use the small bundled subset. Set it to `False` to
# train on the full Kaggle competition data, which must already be extracted
# to ../data/dog-breed-identification.
demo = True
data_dir = (d2l.download_extract('dog_tiny') if demo
            else os.path.join('..', 'data', 'dog-breed-identification'))

Organizing the dataset

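Same idea as CIFAR-10 — reorganize the Kaggle layout into one subfolder per class (`train_valid_test/<split>/<class>/<image>`, with test images under `test/unknown/`) so a directory-based loader can infer labels from folder names: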

def reorg_dog_data(data_dir, valid_ratio):
    """Reorganize the raw Kaggle files into per-class split folders.

    Reads ``labels.csv`` from ``data_dir`` and delegates to d2l's helpers:
    ``reorg_train_valid`` builds the train / valid / train_valid splits using
    ``valid_ratio``, and ``reorg_test`` lays out the unlabeled test images.
    """
    labels_csv = os.path.join(data_dir, 'labels.csv')
    d2l.reorg_train_valid(data_dir, d2l.read_csv_labels(labels_csv),
                          valid_ratio)
    d2l.reorg_test(data_dir)


# Fraction of labelled images held out for validation (per-class rounding is
# decided by d2l.reorg_train_valid).
valid_ratio = 0.1
# Small batches for the tiny demo subset, larger ones for the full data.
batch_size = 128 if not demo else 32
reorg_dog_data(data_dir, valid_ratio)

Augmentation

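ImageNet-style augmentation: resize to 256×256 and take a random 224×224 crop, random horizontal flip, color jitter (brightness, contrast, saturation), and the same input preprocessing (`resnet50.preprocess_input`) that the pretrained backbone expects. Evaluation images get a deterministic center crop instead: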

def transform_train_fn(image, label):
    """Augment one training image and apply ResNet50 input preprocessing.

    Steps, in order:
    1. Cast to float32 and resize to 256x256.
    2. Take a random 224x224 crop and a random horizontal flip.
    3. Apply random brightness / contrast / saturation jitter.
    4. Clip back to [0, 255] and apply
       ``tf.keras.applications.resnet50.preprocess_input``.

    The label is passed through unchanged.

    Note: this is resize-then-random-crop, not a true random *resized* crop
    (no random scale / aspect-ratio sampling).
    """
    x = tf.image.resize(tf.cast(image, tf.float32), [256, 256])
    x = tf.image.random_crop(x, size=[224, 224, 3])
    x = tf.image.random_flip_left_right(x)
    # Photometric jitter; pixel values are on a 0..255 scale here, so the
    # brightness delta is expressed in those units.
    x = tf.image.random_brightness(x, max_delta=0.4 * 255)
    x = tf.image.random_contrast(x, lower=0.6, upper=1.4)
    x = tf.image.random_saturation(x, lower=0.6, upper=1.4)
    # Jitter can push values outside the valid pixel range.
    x = tf.clip_by_value(x, 0.0, 255.0)
    return tf.keras.applications.resnet50.preprocess_input(x), label
def transform_test_fn(image, label):
    """Deterministic eval preprocessing for one image.

    Resizes to 256x256, center-crops 224x224 and applies
    ``tf.keras.applications.resnet50.preprocess_input``. The label is passed
    through unchanged.
    """
    resized = tf.image.resize(tf.cast(image, tf.float32), [256, 256])
    cropped = tf.image.resize_with_crop_or_pad(resized, 224, 224)
    return tf.keras.applications.resnet50.preprocess_input(cropped), label

Data loaders

def _load_image_folder_tf(folder_path):
    """Build an unbatched, unshuffled tf.data.Dataset from a class-per-folder tree.

    Each element is one ``(image, int_label)`` pair with the image resized to
    256x256. Labels are integer indices assigned from the subfolder names.
    """
    return tf.keras.utils.image_dataset_from_directory(
        folder_path,
        label_mode='int',
        image_size=(256, 256),
        batch_size=None,
        shuffle=False,
    )

# Raw (un-augmented, unbatched) datasets, one per split produced by
# reorg_dog_data.
train_ds = _load_image_folder_tf(
    os.path.join(data_dir, 'train_valid_test', 'train'))
train_valid_ds = _load_image_folder_tf(
    os.path.join(data_dir, 'train_valid_test', 'train_valid'))
valid_ds = _load_image_folder_tf(
    os.path.join(data_dir, 'train_valid_test', 'valid'))
test_ds = _load_image_folder_tf(
    os.path.join(data_dir, 'train_valid_test', 'test'))

# Training pipelines: augment, shuffle, then drop the last partial batch so
# every training step sees a full batch.
# NOTE(review): after the map each element is a 224x224x3 float32 image
# (~0.6 MB), so shuffle(10000) can buffer up to ~6 GB of host memory on the
# full dataset; lower the buffer size if memory is tight.
train_iter = (train_ds.map(transform_train_fn, num_parallel_calls=tf.data.AUTOTUNE)
              .shuffle(10000).batch(batch_size, drop_remainder=True)
              .prefetch(tf.data.AUTOTUNE))
train_valid_iter = (train_valid_ds.map(transform_train_fn,
                    num_parallel_calls=tf.data.AUTOTUNE)
                    .shuffle(10000).batch(batch_size, drop_remainder=True)
                    .prefetch(tf.data.AUTOTUNE))
# Evaluation pipelines keep the final partial batch. Dropping it would
# silently exclude held-out images from the validation loss; the loss helpers
# in this file divide by the number of examples actually seen.
valid_iter = (valid_ds.map(transform_test_fn, num_parallel_calls=tf.data.AUTOTUNE)
              .batch(batch_size, drop_remainder=False)
              .prefetch(tf.data.AUTOTUNE))
test_iter = (test_ds.map(transform_test_fn, num_parallel_calls=tf.data.AUTOTUNE)
             .batch(batch_size, drop_remainder=False)
             .prefetch(tf.data.AUTOTUNE))

Frozen ImageNet features

This competition is close to ImageNet, so we reuse a pretrained ResNet as a frozen feature extractor and train only a small 120-way breed classifier:

# Frozen feature extractor: Keras ResNet50 with ImageNet weights and its
# 1000-way classification layer kept (include_top=True).
# classifier_activation=None makes it emit raw logits rather than softmax
# probabilities. This mirrors the PyTorch version, where the pretrained
# torchvision ResNet's 1000 ImageNet logits feed the custom dog-breed head.
_resnet_for_features = tf.keras.applications.ResNet50(
    include_top=True,
    weights='imagenet',
    input_shape=(224, 224, 3),
    classifier_activation=None,
)
# Freeze every backbone weight; only the Flax head is trained.
_resnet_for_features.trainable = False

class OutputNet(nn.Module):
    """Classifier head mapping frozen backbone features to breed logits.

    Architecture: Dense(256) -> ReLU -> Dense(num_classes). The head has no
    dropout or normalization, so ``training`` is accepted for call-site
    compatibility but does not affect the output.
    """
    # Number of output logits (120 dog breeds by default).
    num_classes: int = 120

    @nn.compact
    def __call__(self, x, training=False):
        hidden = nn.relu(nn.Dense(256)(x))
        return nn.Dense(self.num_classes)(hidden)

def get_net():
    """Return ``(features_net, output_net)`` for one training run.

    ``features_net`` is the shared, frozen module-level ResNet50 (the same
    object on every call). ``output_net`` is a freshly constructed 120-way
    ``OutputNet`` module definition.
    """
    head = OutputNet(num_classes=120)
    return _resnet_for_features, head

Head loss and validation

Only the custom output network receives gradients. The validation loss is computed through the same frozen features, so it measures whether the dog-breed head is generalizing:

def loss_fn(logits, labels):
    """Unreduced softmax cross-entropy for integer class labels.

    Returns one loss per example (shape ``logits.shape[:-1]``); callers in
    this file sum it themselves.
    """
    per_example = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=labels)
    return per_example

def extract_features(features_net, X_batch):
    """Run one batch through the frozen feature extractor.

    Args:
      features_net: the frozen backbone (the Keras ResNet50 from ``get_net``
        in this file), called as ``features_net(X_batch, training=False)``.
      X_batch: one preprocessed input batch (a NumPy array at the call sites
        in this file).

    Returns:
      A ``jnp`` array holding the backbone outputs for ``X_batch``.
    """
    # Keras documents Model.predict as intended for large inputs, not for use
    # inside per-batch loops; for a single batch, calling the model directly
    # is faster. training=False keeps inference-mode behavior, matching
    # precompute_features.
    out = features_net(X_batch, training=False)
    return jnp.asarray(np.asarray(out))

def precompute_features(features_net, data_iter):
    """Cache frozen-backbone outputs for a whole dataset as JAX arrays.

    Calls ``features_net(batch, training=False)`` once per ``(features,
    labels)`` batch from ``data_iter``. It then concatenates all outputs and
    labels along axis 0. Training can then iterate the small head over the
    cached arrays without calling back into TF on every step.

    Returns:
      ``(feats, labels)`` as ``jnp`` arrays.
    """
    per_batch = [(features_net(X, training=False).numpy(), y.numpy())
                 for X, y in data_iter]
    feats = jnp.array(np.concatenate([f for f, _ in per_batch], axis=0))
    labels = jnp.array(np.concatenate([y for _, y in per_batch], axis=0))
    return feats, labels

def evaluate_loss_from_feats(feats, labels, output_net, variables,
                             batch_size):
    """Mean per-example loss of the head over cached features.

    Walks ``feats`` / ``labels`` in contiguous chunks of ``batch_size`` rows
    (the last chunk may be shorter) and applies ``output_net`` in eval mode.
    It accumulates the summed loss and returns total loss / number of rows.
    """
    total, count = 0.0, 0
    num_rows = feats.shape[0]
    for start in range(0, num_rows, batch_size):
        stop = start + batch_size
        y_chunk = labels[start:stop]
        logits = output_net.apply(variables, feats[start:stop],
                                  training=False)
        total += float(loss_fn(logits, y_chunk).sum())
        count += int(y_chunk.shape[0])
    return total / count

def evaluate_loss(data_iter, features_net, output_net, variables):
    """Mean per-example loss of the head over a TF batch iterator.

    For each ``(images, labels)`` batch, features are extracted with the
    frozen backbone on the fly (``extract_features``). The head is then
    applied in eval mode and the summed loss is accumulated. Returns total
    loss / total number of labels seen.
    """
    total, count = 0.0, 0
    for X, y in data_iter:
        logits = output_net.apply(
            variables, extract_features(features_net, X.numpy()),
            training=False)
        total += float(loss_fn(logits, jnp.array(y.numpy())).sum())
        count += len(y)
    return total / count

Training function

The helper is mostly framework bookkeeping. The training structure is:

  • precompute frozen ImageNet features;
  • train the 120-way head with cross-entropy;
  • report validation loss on the held-out validation images (same 120 breeds, unseen images);
  • repeat on all training data before writing the submission file.

That is the practical transfer-learning tradeoff: far less memory and time, while keeping most ImageNet visual knowledge.

Train

Expect validation loss to be the useful curve here; with 120 fine-grained classes, accuracy is noisy on the book's tiny demo subset. On the full competition data, train longer and tune the head and the augmentation strength.

# Hyperparameters for the head: lr (learning rate) and wd (presumably weight
# decay). lr_period / lr_decay are forwarded to `train`, presumably as a
# step-decay learning-rate schedule (confirm against its definition).
num_epochs = 10
lr = 1e-4
wd = 1e-4
lr_period, lr_decay = 2, 0.9

# First run: fit on the `train` split while monitoring the `valid` split.
features_net, output_net = get_net()
variables = train(features_net, output_net, train_iter, valid_iter,
                  num_epochs, lr, wd, lr_period, lr_decay)

train loss 3.760, valid loss 4.146
2826.7 examples/sec

Submit predictions

Write one probability vector per test image. The CSV has image id plus 120 breed probabilities, so the final layer must stay aligned with the competition’s class order:

# Retrain on train + valid (no held-out monitoring), then predict class
# probabilities for every test image.
features_net, output_net = get_net()
variables = train(features_net, output_net, train_valid_iter, None,
                  num_epochs, lr, wd, lr_period, lr_decay)

preds = []
# All test images sit under the single 'unknown' folder, so their labels
# carry no information and are ignored.
for data, _ in test_iter:
    feats = extract_features(features_net, data.numpy())
    logits = output_net.apply(variables, feats, training=False)
    preds.extend(np.asarray(jax.nn.softmax(logits, axis=-1)))

# Column order must match the integer labels the head was trained on. The
# loader assigned those from its own class list, so read the names back from
# it instead of re-listing the directory (which could pick up stray files).
class_names = train_valid_ds.class_names
# Row ids: the image files the (unshuffled) test loader actually read, in the
# same order as `preds`. Fall back to a sorted directory listing if the
# dataset does not expose `file_paths`.
test_paths = getattr(test_ds, 'file_paths', None)
if test_paths is None:
    test_paths = sorted(os.listdir(
        os.path.join(data_dir, 'train_valid_test', 'test', 'unknown')))
ids = [os.path.splitext(os.path.basename(p))[0] for p in test_paths]
# zip() would silently truncate on a mismatch and misalign submission rows.
if len(ids) != len(preds):
    raise ValueError(f'{len(ids)} test ids but {len(preds)} predictions; '
                     'submission rows would be misaligned')
with open('submission.csv', 'w') as f:
    f.write('id,' + ','.join(class_names) + '\n')
    for image_id, probs in zip(ids, preds):
        f.write(image_id + ',' + ','.join(str(p) for p in probs) + '\n')

train loss 3.697
5403.8 examples/sec

Recap

  • ImageNet Dogs ⊂ ImageNet → fine-tuning a pretrained CNN crushes from-scratch training.
  • Standard recipe: pretrained backbone, new 120-way head, ImageNet-scale augmentation, ImageNet-compatible preprocessing.
  • Same shape as the CIFAR-10 deck; only the dataset and the choice “train from scratch vs fine-tune” differ.
  • The general lesson: when your task is close to the pretraining domain, transfer learning beats everything.