The Image Classification Dataset

Fashion-MNIST as a reusable dataset

Fashion-MNIST is the workhorse dataset for the rest of this chapter:

  • 10 classes (T-shirt / trouser / pullover / …) of 28×28 grayscale images; 60,000 training and 10,000 test images (the test split serves as our validation set).
  • Drop-in replacement for MNIST — same shape, same API, harder.
  • We’ll wrap it in a DataModule so every classifier we build can reuse the same loaders.

Dataset setup

Imports and the FashionMNIST DataModule shell:

%matplotlib inline
from d2l import jax as d2l
import jax
from jax import numpy as jnp
import numpy as np
import time
import tensorflow as tf
import tensorflow_datasets as tfds

d2l.use_svg_display()
class FashionMNIST(d2l.DataModule):
    """The Fashion-MNIST dataset."""
    def __init__(self, batch_size=64, resize=(28, 28)):
        super().__init__()
        self.save_hyperparameters()
        self.train, self.val = tf.keras.datasets.fashion_mnist.load_data()

Instantiate Fashion-MNIST

Instantiate (resizing to 32×32 to match later ConvNet inputs):

data = FashionMNIST(resize=(32, 32))
len(data.train[0]), len(data.val[0])
(60000, 10000)

What does one example look like?

data.train is an (images, labels) pair of NumPy arrays. Each raw image is a 28×28 array with no channel axis, and each label is an integer:

data.train[0][0].shape
(28, 28)

A raw 28×28 grayscale image. The channel axis and the resize to 32×32 are only applied later, inside the dataloader.

Human-readable labels

The dataset stores labels as integers 0–9. A small helper turns each label into its English name (T-shirt, Trouser, Pullover, …):

@d2l.add_to_class(FashionMNIST)
def text_labels(self, indices):
    """Return text labels."""
    labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
              'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [labels[int(i)] for i in indices]
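
To check the mapping without loading the dataset, the same lookup can be written as a standalone function. This is only a sketch: `FASHION_LABELS` and `decode_labels` are illustrative names introduced here, not part of d2l.

```python
# Standalone version of the label lookup (names are hypothetical).
FASHION_LABELS = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                  'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']

def decode_labels(indices):
    """Map integer class indices (0-9) to their text labels."""
    # int(i) lets the function accept NumPy or JAX scalars as well as ints.
    return [FASHION_LABELS[int(i)] for i in indices]

names = decode_labels([0, 9, 2])
print(names)
```

The `int(i)` cast matters in practice because labels coming out of a dataloader are array scalars rather than Python integers.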

Minibatches

Wrap the framework dataloader so train and val each yield batches in the same shape:

@d2l.add_to_class(FashionMNIST)
def get_dataloader(self, train):
    data = self.train if train else self.val
    process = lambda X, y: (tf.expand_dims(X, axis=3) / 255,
                            tf.cast(y, dtype='int32'))
    resize_fn = lambda X, y: (tf.image.resize_with_pad(X, *self.resize), y)
    shuffle_buf = len(data[0]) if train else 1
    # `drop_remainder=train` keeps every training batch the same shape:
    # JAX retraces a `@jax.jit`'d step function for each new input shape.
    return tfds.as_numpy(
        tf.data.Dataset.from_tensor_slices(process(*data)).shuffle(
            shuffle_buf).batch(self.batch_size,
                               drop_remainder=train).map(resize_fn))
X, y = next(iter(data.train_dataloader()))
print(X.shape, X.dtype, y.shape, y.dtype)
(64, 32, 32, 1) float32 (64,) int32

A batch of 64 32×32 grayscale images plus 64 integer labels.
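
The preprocessing and batching arithmetic can be mirrored in plain NumPy to see where these shapes come from. The sketch below is an approximation of the pipeline, not the loader itself: `preprocess` and `num_batches` are hypothetical helpers, the input arrays are synthetic stand-ins for the raw data, and the resize step is left out.

```python
import numpy as np

def preprocess(X, y):
    """Add a trailing channel axis, scale pixels to [0, 1], cast labels."""
    X = X[..., np.newaxis].astype(np.float32) / 255.0
    return X, y.astype(np.int32)

def num_batches(num_examples, batch_size, drop_remainder):
    """Batches per epoch, with or without the final partial batch."""
    full, rest = divmod(num_examples, batch_size)
    return full if drop_remainder or rest == 0 else full + 1

# Synthetic stand-in for raw uint8 images and labels (not the real dataset).
rng = np.random.default_rng(0)
X_raw = rng.integers(0, 256, size=(100, 28, 28), dtype=np.uint8)
y_raw = rng.integers(0, 10, size=(100,), dtype=np.uint8)

X, y = preprocess(X_raw, y_raw)
print(X.shape, X.dtype, y.dtype)

# Training drops the last partial batch; validation keeps it.
print(num_batches(60000, 64, drop_remainder=True))
print(num_batches(10000, 64, drop_remainder=False))
```

With a batch size of 64, dropping the remainder discards the last 32 training examples of each epoch. Because the training set is reshuffled, a different 32 are skipped each time.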

Throughput sanity check

Time one full pass over the training loader. If loading a batch takes longer than the model's training step, the loader becomes the bottleneck:

tic = time.time()
for X, y in data.train_dataloader():
    continue
f'{time.time() - tic:.2f} sec'
'0.69 sec'
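
A raw time is easier to compare across batch sizes when it is converted to examples per second. The helper below is a minimal sketch: `time_epoch` is a hypothetical name, and the loader here is a list of dummy batches rather than the real dataloader.

```python
import time

def time_epoch(loader):
    """Iterate once over `loader`; return (seconds, examples per second)."""
    tic = time.perf_counter()
    num_examples = 0
    for X, y in loader:
        num_examples += len(y)  # count examples via the label batch
    elapsed = time.perf_counter() - tic
    # Guard against a zero reading on very fast loops.
    return elapsed, num_examples / max(elapsed, 1e-9)

# Stand-in loader: 10 batches of 64 dummy (features, labels) pairs.
dummy_loader = [([0.0] * 64, [0] * 64) for _ in range(10)]
seconds, rate = time_epoch(dummy_loader)
print(f'{seconds:.4f} sec, {rate:.0f} examples/sec')
```

Assuming the loader yields `(X, y)` pairs as above, the same call should work on `data.train_dataloader()`.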

Visualization helpers

A grid plotter we’ll reuse for spot-checks:

def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):
    """Plot a list of images."""
    raise NotImplementedError

A visualize method on the dataset takes one batch and labels each tile with its class name:

@d2l.add_to_class(FashionMNIST)
def visualize(self, batch, nrows=1, ncols=8, labels=None):
    X, y = batch
    if not labels:
        labels = self.text_labels(y)
    d2l.show_images(jnp.squeeze(X), nrows, ncols, titles=labels)

batch = next(iter(data.val_dataloader()))
data.visualize(batch)
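
The `show_images` body above is a stub, so here is one way the grid layout itself could be produced. This is a sketch under stated assumptions, not d2l's implementation: `tile_images` is a hypothetical helper that packs images of shape (H, W) into one 2-D array, without titles or axes, and the input is synthetic.

```python
import numpy as np

def tile_images(imgs, num_rows, num_cols):
    """Pack up to num_rows * num_cols images of shape (H, W) into one array."""
    imgs = np.asarray(imgs)[:num_rows * num_cols]
    n, h, w = imgs.shape
    grid = np.zeros((num_rows * num_cols, h, w), dtype=imgs.dtype)
    grid[:n] = imgs  # cells without an image stay blank (zeros)
    # (rows, cols, H, W) -> (rows, H, cols, W) -> (rows * H, cols * W)
    grid = grid.reshape(num_rows, num_cols, h, w).transpose(0, 2, 1, 3)
    return grid.reshape(num_rows * h, num_cols * w)

# Synthetic batch of 8 images, 32x32 each (stand-in for jnp.squeeze(X)).
imgs = np.arange(8 * 32 * 32, dtype=np.float32).reshape(8, 32, 32)
canvas = tile_images(imgs, num_rows=1, num_cols=8)
print(canvas.shape)
```

The resulting array could be handed to a single image-display call such as matplotlib's `imshow` with a grayscale colormap.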

Recap

  • Fashion-MNIST: 10 classes, 28×28 grayscale, harder than MNIST.
  • A DataModule subclass owns the train and validation dataloaders, label decoding, and a visualize helper.
  • Always sanity-check throughput — slow I/O caps training speed.
  • Same data API drives every model in this chapter.