Fashion-MNIST is the workhorse dataset for the rest of this chapter, so we wrap it in a `DataModule` that every classifier we build can reuse for its data loaders.

Imports and the `FashionMNIST` `DataModule` shell:

```python
%matplotlib inline
import time
from d2l import torch as d2l
import torch
import torchvision
from torchvision import transforms

# Render matplotlib figures as SVG in the notebook
d2l.use_svg_display()
```
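```python
class FashionMNIST(d2l.DataModule):
    """The Fashion-MNIST dataset."""
    def __init__(self, batch_size=64, resize=(28, 28)):
        # The d2l.DataModule base class sets self.root and self.num_workers
        super().__init__()
        self.save_hyperparameters()
        # Resize each PIL image, then convert it to a float tensor of shape
        # (C, H, W) with pixel values scaled to [0, 1]
        trans = transforms.Compose([transforms.Resize(resize),
                                    transforms.ToTensor()])
        # 60,000 training and 10,000 test images, downloaded on first use
        self.train = torchvision.datasets.FashionMNIST(
            root=self.root, train=True, transform=trans, download=True)
        self.val = torchvision.datasets.FashionMNIST(
            root=self.root, train=False, transform=trans, download=True)
```

Instantiate the dataset, resizing images to 32×32 to match the input size of later ConvNets. The training and validation splits hold 60,000 and 10,000 images, respectively: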
(60000, 10000)
Each training example is a pair: a (C, H, W) image tensor and an integer label. The shape of the first training image:
torch.Size([1, 32, 32])
That is a 1×32×32 grayscale image: a single channel, resized to 32×32 pixels.
The dataset stores labels as integers 0–9. A small helper turns each label into its English name (T-shirt, Trouser, Pullover, …):
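A minimal sketch of such a helper, written as a standalone function rather than a method on `FashionMNIST`. The class names follow the official Fashion-MNIST label order (0 = T-shirt/top, …, 9 = Ankle boot); the names `FASHION_MNIST_CLASSES` and `text_labels` are illustrative.

```python
# Label i maps to FASHION_MNIST_CLASSES[i] (official Fashion-MNIST order)
FASHION_MNIST_CLASSES = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                         'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']

def text_labels(indices):
    """Map integer labels (Python ints or 0-d tensors) to English class names."""
    return [FASHION_MNIST_CLASSES[int(i)] for i in indices]

print(text_labels([0, 1, 2, 9]))
# ['t-shirt', 'trouser', 'pullover', 'ankle boot']
```

Calling `int(i)` lets the same helper accept a batch of labels straight from a tensor `y`.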
Wrap PyTorch's `DataLoader` so that the training and validation splits are loaded the same way and yield batches of the same shape:
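A minimal sketch of that wrapping, using a hypothetical `ToyDataModule` filled with random tensors so it runs offline. One `get_dataloader(train)` method serves both splits, and `train_dataloader` / `val_dataloader` simply call it. Shuffling only the training split is a design choice assumed here, not something the text above specifies.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

class ToyDataModule:
    """Hypothetical stand-in for FashionMNIST: same loader logic, fake data."""
    def __init__(self, batch_size=64, num_workers=0):
        self.batch_size = batch_size
        self.num_workers = num_workers
        # Random 1x32x32 "images" with integer labels in [0, 10)
        self.train = TensorDataset(torch.rand(200, 1, 32, 32),
                                   torch.randint(0, 10, (200,)))
        self.val = TensorDataset(torch.rand(50, 1, 32, 32),
                                 torch.randint(0, 10, (50,)))

    def get_dataloader(self, train):
        data = self.train if train else self.val
        # Assumption: shuffle the training split only; keep validation order fixed
        return DataLoader(data, batch_size=self.batch_size, shuffle=train,
                          num_workers=self.num_workers)

    def train_dataloader(self):
        return self.get_dataloader(train=True)

    def val_dataloader(self):
        return self.get_dataloader(train=False)

dm = ToyDataModule(batch_size=64)
X, y = next(iter(dm.train_dataloader()))
print(X.shape, y.shape)  # torch.Size([64, 1, 32, 32]) torch.Size([64])
```

Because both splits go through the same method, any change (batch size, worker count) applies to training and validation alike.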
Time one full pass over the training loader. If reading data is slow, it can bottleneck training just as a slow model can:
'2.87 sec'
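A minimal sketch of the same measurement on a synthetic in-memory dataset (an assumption, so the number will not match the timing above). It iterates over every batch once and reports the wall-clock time, using `time.perf_counter` for interval timing.

```python
import time
import torch
from torch.utils.data import DataLoader, TensorDataset

# Assumption: 1,000 random 1x32x32 images stand in for the training split
dataset = TensorDataset(torch.rand(1000, 1, 32, 32),
                        torch.randint(0, 10, (1000,)))
loader = DataLoader(dataset, batch_size=64, shuffle=True)

tic = time.perf_counter()
num_batches = 0
for X, y in loader:
    num_batches += 1  # only fetch batches; a real epoch would also run the model
elapsed = time.perf_counter() - tic
print(f'{num_batches} batches in {elapsed:.2f} sec')
```

With 1,000 examples and a batch size of 64 the loop sees 16 batches, the last one partial.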
A grid plotter we’ll reuse for spot-checks:
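A minimal sketch of such a grid plotter with matplotlib, assuming each image is a 2-D array (for Fashion-MNIST, drop the channel dimension first). The name `show_images` and its parameters are illustrative.

```python
import matplotlib.pyplot as plt
import numpy as np

def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):
    """Plot a num_rows x num_cols grid of 2-D images; return the flat axes array."""
    figsize = (num_cols * scale, num_rows * scale)
    # squeeze=False keeps axes 2-D even for a single row or column
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize, squeeze=False)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        ax.imshow(np.asarray(img), cmap='gray')
        ax.axis('off')  # hide ticks and frame for a clean grid
        if titles:
            ax.set_title(titles[i])
    return axes

axes = show_images(np.random.rand(8, 32, 32), 2, 4,
                   titles=[str(i) for i in range(8)])
```

Returning the axes makes it easy to tweak individual tiles after plotting.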
We bind it to the dataset as a `visualize` method that takes one batch and titles each tile with its class name:
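A minimal sketch of attaching a method to a class after it has been defined. The `add_to_class` decorator and `ToyFashionData` class here are hypothetical, standalone re-implementations (d2l ships a helper of that name, but this is not its code). To stay runnable without a display, this `visualize` returns the tile titles instead of plotting them.

```python
# Hypothetical decorator: attach the decorated function to Class as a method
def add_to_class(Class):
    def wrapper(obj):
        setattr(Class, obj.__name__, obj)
        return obj  # keep the module-level name usable too
    return wrapper

class ToyFashionData:
    """Hypothetical stand-in that only knows how to decode labels."""
    classes = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
               'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']

    def text_labels(self, indices):
        return [self.classes[int(i)] for i in indices]

@add_to_class(ToyFashionData)
def visualize(self, batch, nrows=1, ncols=8, labels=None):
    """Return one title per grid tile for a batch (X, y).

    A plotting version would pass X and these titles to the grid plotter.
    """
    X, y = batch
    if not labels:
        labels = self.text_labels(y)
    return labels[:nrows * ncols]

data = ToyFashionData()
print(data.visualize(([None] * 3, [0, 9, 1]), ncols=3))
# ['t-shirt', 'ankle boot', 'trouser']
```

Binding methods this way lets a notebook grow a class cell by cell instead of redefining it each time.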
In short, the `FashionMNIST` subclass of `DataModule` owns the framework's training and validation dataloaders, label decoding, and a `visualize` helper.