%matplotlib inline
from d2l import jax as d2l
import jax
from jax import numpy as jnp
import numpy as np
import time
import tensorflow as tf
import tensorflow_datasets as tfds
d2l.use_svg_display()
Fashion-MNIST is the workhorse dataset for the rest of this chapter: we wrap it in a DataModule so every classifier we build can reuse the same loaders.
Imports and the FashionMNIST DataModule shell:
Instantiate (resizing to 32×32 to match later ConvNet inputs):
(60000, 10000)
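Each stored training image is a 28×28 array paired with an integer label; the images are kept at their native size here:
(28, 28)
A single grayscale image with no channel axis yet. The resize to 32×32 and the trailing channel axis are applied later, when the data loader builds batches.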
The dataset stores labels as integers 0–9. A small helper turns each label into its English name (T-shirt, Trouser, Pullover, …):
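A minimal sketch of such a helper, written as a standalone function. The class names follow the standard Fashion-MNIST label order; the names `FASHION_MNIST_LABELS` and `text_labels` and the lowercase spelling are assumptions made for illustration.

```python
# Standard Fashion-MNIST class names, indexed by integer label 0-9.
FASHION_MNIST_LABELS = [
    't-shirt', 'trouser', 'pullover', 'dress', 'coat',
    'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot',
]

def text_labels(indices):
    """Map an iterable of integer labels to their class names."""
    # int(i) lets the helper accept NumPy or JAX scalars as well as Python ints.
    return [FASHION_MNIST_LABELS[int(i)] for i in indices]

print(text_labels([0, 1, 2, 9]))
# ['t-shirt', 'trouser', 'pullover', 'ankle boot']
```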
Wrap the framework dataloader so train and val each yield batches in the same shape:
@d2l.add_to_class(FashionMNIST)
def get_dataloader(self, train):
    data = self.train if train else self.val
    # Add a trailing channel axis, scale pixels to [0, 1], cast labels to int32.
    process = lambda X, y: (tf.expand_dims(X, axis=3) / 255,
                            tf.cast(y, dtype='int32'))
    resize_fn = lambda X, y: (tf.image.resize_with_pad(X, *self.resize), y)
    # Shuffle over the whole training set; a buffer of 1 keeps validation order.
    shuffle_buf = len(data[0]) if train else 1
    # `drop_remainder=train` keeps every training batch the same shape:
    # a `@jax.jit`'d step function is retraced for each new input shape.
    return tfds.as_numpy(
        tf.data.Dataset.from_tensor_slices(process(*data)).shuffle(
            shuffle_buf).batch(self.batch_size,
                               drop_remainder=train).map(resize_fn))

Time one full epoch through the loader. Slow loading can bottleneck training as much as the model itself:
'0.69 sec'
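The timing pattern can be sketched without the real dataset. Here `batches` is a hypothetical stand-in for the loader that yields only index ranges, and the batch size of 64 is an assumption; the dataset sizes are the 60000 and 10000 reported above. The sketch also shows how dropping the last partial batch changes the batch count.

```python
import time

def batches(num_examples, batch_size, drop_remainder):
    """Yield (start, stop) index ranges, mimicking a batched loader."""
    for start in range(0, num_examples, batch_size):
        stop = min(start + batch_size, num_examples)
        if drop_remainder and stop - start < batch_size:
            return  # skip the final partial batch
        yield start, stop

def time_one_epoch(loader):
    """Iterate once over the loader; return (number of batches, seconds)."""
    tic = time.perf_counter()
    num_batches = 0
    for _ in loader:
        num_batches += 1
    return num_batches, time.perf_counter() - tic

n_train, elapsed = time_one_epoch(batches(60000, 64, drop_remainder=True))
n_val, _ = time_one_epoch(batches(10000, 64, drop_remainder=False))
print(f'{n_train} train batches in {elapsed:.2f} sec; {n_val} val batches')
```

With these assumed sizes, training yields 937 full batches (the last 32 examples are dropped), while validation yields 157 batches, the last of which holds 16 examples.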
A grid plotter we’ll reuse for spot-checks:
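One way such a grid plotter might look, as a minimal sketch using matplotlib. The name `show_image_grid` and its parameters are illustrative assumptions, not a confirmed library API, and the usage below draws random stand-in images rather than real Fashion-MNIST data.

```python
import matplotlib.pyplot as plt
import numpy as np

def show_image_grid(imgs, num_rows, num_cols, titles=None, scale=1.5):
    """Plot up to num_rows * num_cols grayscale images in a grid."""
    figsize = (num_cols * scale, num_rows * scale)
    # squeeze=False guarantees a 2-D array of axes, even for a 1x1 grid.
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize, squeeze=False)
    axes = axes.flatten()
    for i, ax in enumerate(axes):
        if i >= len(imgs):
            ax.axis('off')  # blank out tiles that have no image
            continue
        # Drop a size-1 channel axis so (H, W, 1) images plot as (H, W).
        ax.imshow(np.squeeze(np.asarray(imgs[i])), cmap='gray')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        if titles is not None:
            ax.set_title(titles[i])
    return axes

# Usage with random stand-in images shaped like one resized batch slice.
rng = np.random.default_rng(0)
fake_batch = rng.random((8, 32, 32, 1))
axes = show_image_grid(fake_batch, 2, 4,
                       titles=[f'img {i}' for i in range(8)])
```

Returning the flattened axes lets a caller adjust individual tiles afterwards, for example to recolor the title of a misclassified image.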
Bound to the dataset as a method that pulls one batch and labels each tile with the class name:
The FashionMNIST DataModule subclass owns the framework’s train_dataloader and val_dataloader, label decoding, and a visualize helper.