from d2l import jax as d2l
import jax
from jax import numpy as jnp
from flax import linen as nn
import optax
import numpy as np
import tensorflow as tf
import osA second Kaggle capstone: ImageNet Dogs (120 fine-grained breeds). The big difference from CIFAR-10: this is a subset of ImageNet, so a pretrained ResNet already knows almost everything about these classes. Fine-tuning is the right play.
Kaggle “Dog Breed Identification” page.
E0524 02:33:44.330891 3930 cuda_executor.cc:1206] [0] Failed to allocate device memory: INTERNAL: [0] Failed to allocate 9.41GiB (10100251136 bytes) of ...
E0524 02:33:44.331318 3930 cuda_executor.cc:1206] [0] Failed to allocate device memory: INTERNAL: [0] Failed to allocate 8.47GiB (9090225152 bytes) of ...
d2l.DATA_HUB['dog_tiny'] = (
    d2l.DATA_URL + 'kaggle_dog_tiny.zip',
    '0cb91d09b814ecdc07b50f31f8dcad3e81d6a86d',
)

# `demo = True` downloads and extracts the small sample archive registered
# above. Set it to `False` to use the full Kaggle competition dataset instead;
# that dataset is not downloaded here, it is read from
# ../data/dog-breed-identification.
demo = True
data_dir = (d2l.download_extract('dog_tiny') if demo
            else os.path.join('..', 'data', 'dog-breed-identification'))

# Same idea as CIFAR-10 — reshuffle the Kaggle layout into
# train/<class>/img.jpg so a standard class-subfolder image loader can read it:
ImageNet-style augmentation: resize to 256×256 and take a random 224×224 crop, random horizontal flip, color jitter (brightness, contrast, saturation), and the same input preprocessing (`resnet50.preprocess_input`) that the pretrained backbone expects:
def transform_train_fn(image, label):
    """Apply training-time augmentation to one image, then normalize it.

    Steps:
    1. Resize to 256x256 and take a random 224x224 crop.
    2. Random horizontal flip.
    3. Color jitter: brightness, contrast and saturation factors drawn from
       [0.6, 1.4].
    4. Clip to [0, 255].
    5. Apply the Keras ResNet50 input preprocessing.

    Args:
        image: One HxWx3 image tensor. Pixel values are expected in [0, 255]
            (the raw output of `image_dataset_from_directory`).
        label: Integer class label, passed through unchanged.

    Returns:
        ``(image, label)`` where ``image`` is a 224x224x3 float32 tensor
        processed by ``tf.keras.applications.resnet50.preprocess_input``.
    """
    image = tf.cast(image, tf.float32)
    # Scale augmentation: fixed 256x256 resize, then a random 224x224 crop.
    # This is NOT a torchvision-style RandomResizedCrop, which would also
    # randomize the crop's area and aspect ratio. The loader already yields
    # 256x256 images, so the resize is a safeguard for other inputs.
    image = tf.image.resize(image, [256, 256])
    image = tf.image.random_crop(image, size=[224, 224, 3])
    image = tf.image.random_flip_left_right(image)
    # Brightness jitter as a multiplicative factor in [0.6, 1.4]. This matches
    # the contrast/saturation factors below and torchvision's
    # ColorJitter(brightness=0.4). `tf.image.random_brightness` would instead
    # ADD a shift of up to +/-0.4 * 255, i.e. 40% of the full pixel range.
    image = image * tf.random.uniform([], 0.6, 1.4)
    image = tf.image.random_contrast(image, lower=0.6, upper=1.4)
    image = tf.image.random_saturation(image, lower=0.6, upper=1.4)
    # Jitter can push values outside the valid pixel range.
    image = tf.clip_by_value(image, 0.0, 255.0)
    return tf.keras.applications.resnet50.preprocess_input(image), label


def transform_test_fn(image, label):
    """Deterministic evaluation preprocessing: resize, center-crop, normalize.

    Args:
        image: One HxWx3 image tensor (pixel values expected in [0, 255]).
        label: Integer class label, passed through unchanged.

    Returns:
        ``(image, label)`` where ``image`` is a 224x224x3 float32 tensor
        processed by ``tf.keras.applications.resnet50.preprocess_input``.
    """
    image = tf.cast(image, tf.float32)
    image = tf.image.resize(image, [256, 256])
    # `resize_with_crop_or_pad` keeps the central 224x224 window of the
    # 256x256 image, i.e. a center crop.
    image = tf.image.resize_with_crop_or_pad(image, 224, 224)
    return tf.keras.applications.resnet50.preprocess_input(image), label


def _load_image_folder_tf(folder_path):
    """Load a ``<folder_path>/<class_name>/<image>`` tree as a tf.data.Dataset.

    The dataset yields unbatched ``(image, label)`` pairs in a fixed
    (unshuffled) order. Every image is resized to 256x256. Labels are integer
    class indices that Keras assigns from the sorted subdirectory names.
    """
    return tf.keras.utils.image_dataset_from_directory(
        folder_path, label_mode='int', image_size=(256, 256),
        batch_size=None, shuffle=False)
train_ds = _load_image_folder_tf(
    os.path.join(data_dir, 'train_valid_test', 'train'))
train_valid_ds = _load_image_folder_tf(
    os.path.join(data_dir, 'train_valid_test', 'train_valid'))
valid_ds = _load_image_folder_tf(
    os.path.join(data_dir, 'train_valid_test', 'valid'))
test_ds = _load_image_folder_tf(
    os.path.join(data_dir, 'train_valid_test', 'test'))


def _make_iter(ds, transform_fn, batch_size, *, shuffle, drop_remainder):
    """Build a batched, prefetched input pipeline from an unbatched dataset.

    `transform_fn` is mapped over `ds` in parallel. When `shuffle` is true, the
    transformed examples pass through a 10000-element shuffle buffer before
    batching. `drop_remainder` controls whether a final short batch is kept.
    """
    ds = ds.map(transform_fn, num_parallel_calls=tf.data.AUTOTUNE)
    if shuffle:
        ds = ds.shuffle(10000)
    ds = ds.batch(batch_size, drop_remainder=drop_remainder)
    return ds.prefetch(tf.data.AUTOTUNE)


# Training pipelines drop the last partial batch so every step sees a full
# batch.
train_iter = _make_iter(train_ds, transform_train_fn, batch_size,
                        shuffle=True, drop_remainder=True)
train_valid_iter = _make_iter(train_valid_ds, transform_train_fn, batch_size,
                              shuffle=True, drop_remainder=True)
# Evaluation pipelines keep the final partial batch so that every validation
# and test image is scored. The loss helpers divide the summed loss by the
# number of examples actually seen, so a short last batch is averaged
# correctly.
valid_iter = _make_iter(valid_ds, transform_test_fn, batch_size,
                        shuffle=False, drop_remainder=False)
test_iter = _make_iter(test_ds, transform_test_fn, batch_size,
                       shuffle=False, drop_remainder=False)

# This competition is close to ImageNet, so we reuse a pretrained ResNet as a
# frozen feature extractor and train only a small 120-way breed classifier:
# Frozen feature extractor: a pretrained Keras ResNet50 whose top layer emits
# raw ImageNet logits (`classifier_activation=None`, so no softmax). Those
# 1000 logits are the features fed to the small Flax head below. This mirrors
# the PyTorch tab, where torchvision's pretrained ResNet emits 1000 ImageNet
# logits before the custom dog-breed head.
_resnet_for_features = tf.keras.applications.ResNet50(
    include_top=True,
    weights='imagenet',
    input_shape=(224, 224, 3),
    classifier_activation=None,
)
_resnet_for_features.trainable = False


class OutputNet(nn.Module):
    """Trainable two-layer MLP head: features -> 256 ReLU units -> logits."""

    num_classes: int = 120

    @nn.compact
    def __call__(self, x, training=False):
        # `training` is accepted for API symmetry only; this head has no
        # layers whose behavior depends on train/eval mode.
        hidden = nn.Dense(256)(x)
        hidden = nn.relu(hidden)
        return nn.Dense(self.num_classes)(hidden)


def get_net():
    """Return ``(frozen_backbone, output_head)``; only the head is trained."""
    return _resnet_for_features, OutputNet(num_classes=120)

# Only the custom output network receives gradients. The validation loss is
# computed through the same frozen features, so it measures whether the
# dog-breed head is generalizing:
def loss_fn(logits, labels):
    """Unreduced softmax cross-entropy against integer class labels.

    Returns one loss value per example; callers sum or average it themselves.
    """
    per_example = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=labels)
    return per_example
def extract_features(features_net, X_batch):
    """Run one batch through the frozen backbone and return a JAX array.

    The model is called directly in inference mode (``training=False``)
    rather than via ``Model.predict``. Keras documents ``predict`` as meant
    for bulk inference, not for per-batch calls inside a loop. Calling it
    here, once per batch, rebuilds its input pipeline on every call. A direct
    call also matches how `precompute_features` invokes the backbone.

    Args:
        features_net: The frozen feature extractor (a Keras model or any
            callable accepting ``training=False``).
        X_batch: A batch of preprocessed images (e.g. a NumPy array).

    Returns:
        The backbone outputs for the batch as a ``jnp`` array.
    """
    feats = features_net(X_batch, training=False)
    # np.asarray accepts both TF eager tensors and NumPy arrays.
    return jnp.asarray(np.asarray(feats))
def precompute_features(features_net, data_iter):
    """Cache frozen-backbone features for a whole dataset as JAX arrays.

    Every batch from `data_iter` goes through `features_net` once, in
    inference mode. The per-batch results are concatenated along the batch
    axis. Training can then loop over the small classifier head using these
    cached features, without round-tripping between JAX and TF on every step.

    Note: if `data_iter` applies random augmentation, the cache freezes a
    single augmented view of each image.

    Returns:
        ``(features, labels)``, both ``jnp`` arrays whose first axis is the
        total number of examples.
    """
    feat_chunks = []
    label_chunks = []
    for batch_images, batch_labels in data_iter:
        feat_chunks.append(features_net(batch_images, training=False).numpy())
        label_chunks.append(batch_labels.numpy())
    all_feats = jnp.array(np.concatenate(feat_chunks, axis=0))
    all_labels = jnp.array(np.concatenate(label_chunks, axis=0))
    return all_feats, all_labels
def evaluate_loss_from_feats(feats, labels, output_net, variables,
                             batch_size):
    """Mean per-example loss of the head over pre-extracted features.

    `feats` and `labels` are sliced into consecutive chunks of `batch_size`
    along the first axis; the last chunk may be shorter. The summed loss is
    divided by the number of examples evaluated.
    """
    total_loss, total_count = 0.0, 0
    for start in range(0, feats.shape[0], batch_size):
        stop = start + batch_size
        chunk_feats, chunk_labels = feats[start:stop], labels[start:stop]
        chunk_logits = output_net.apply(variables, chunk_feats,
                                        training=False)
        total_loss += float(loss_fn(chunk_logits, chunk_labels).sum())
        total_count += int(chunk_labels.shape[0])
    return total_loss / total_count
def evaluate_loss(data_iter, features_net, output_net, variables):
    """Mean per-example loss of the head over a batched image iterator.

    Each batch of images goes through the frozen backbone (via
    `extract_features`) and then the output head. The summed loss is divided
    by the total number of examples, so a short final batch is weighted
    correctly.
    """
    loss_total = 0.0
    example_count = 0
    for images, batch_labels in data_iter:
        batch_feats = extract_features(features_net, images.numpy())
        targets = jnp.array(batch_labels.numpy())
        batch_logits = output_net.apply(variables, batch_feats,
                                        training=False)
        loss_total += float(loss_fn(batch_logits, targets).sum())
        example_count += len(batch_labels)
    return loss_total / example_count

# The helper is mostly framework bookkeeping. The training structure is:
That is the practical transfer-learning tradeoff: far less memory and time, while keeping most ImageNet visual knowledge.
Expect validation loss to be the useful curve here; with 120 fine-grained classes, top-line accuracy can be noisy on the tiny book subset. On the full competition data, train longer and tune the head/augmentation strength.
train loss 3.760, valid loss 4.146
2826.7 examples/sec
Write one probability vector per test image. The CSV has image id plus 120 breed probabilities, so the final layer must stay aligned with the competition’s class order:
features_net, output_net = get_net()
# Final fit on train + valid combined; no validation iterator is passed.
variables = train(features_net, output_net, train_valid_iter, None,
                  num_epochs, lr, wd, lr_period, lr_decay)

# One softmax probability vector per test image, in the loader's file order.
preds = []
for data, label in test_iter:
    feats = extract_features(features_net, data.numpy())
    logits = output_net.apply(variables, feats, training=False)
    output = jax.nn.softmax(logits, axis=-1)
    preds.extend(np.array(output))

# CSV columns must follow the label indices the loader assigned. Prefer the
# class names Keras recorded on the dataset. Otherwise fall back to the sorted
# subdirectory names, skipping stray files such as `.DS_Store` (Keras ignores
# them, so a raw `os.listdir` would shift the columns).
class_names = getattr(train_valid_ds, 'class_names', None)
if class_names is None:
    _train_valid_dir = os.path.join(data_dir, 'train_valid_test',
                                    'train_valid')
    class_names = sorted(
        name for name in os.listdir(_train_valid_dir)
        if os.path.isdir(os.path.join(_train_valid_dir, name)))

# Row ids must follow the order in which `test_ds` yields images. Use the
# loader's own file list when available; otherwise use the sorted directory
# listing.
_test_paths = getattr(test_ds, 'file_paths', None)
if _test_paths is None:
    ids = sorted(os.listdir(
        os.path.join(data_dir, 'train_valid_test', 'test', 'unknown')))
else:
    ids = [os.path.basename(path) for path in _test_paths]

# `zip` would silently drop rows on a length mismatch and shift ids against
# predictions, so fail loudly instead.
if len(ids) != len(preds):
    raise ValueError(
        f'Found {len(ids)} test images but {len(preds)} predictions; '
        'submission rows would be misaligned.')

with open('submission.csv', 'w', encoding='utf-8') as f:
    f.write('id,' + ','.join(class_names) + '\n')
    for file_name, output in zip(ids, preds):
        image_id = os.path.splitext(file_name)[0]
        f.write(image_id + ',' + ','.join(str(num) for num in output) + '\n')
# train loss 3.697
5403.8 examples/sec