%matplotlib inline
from d2l import jax as d2l
import jax
from jax import numpy as jnp
import numpy as np
import random
import tensorflow as tf
import tensorflow_datasets as tfds
Before we train a model we need data. For pedagogy, we’ll synthesize it: known weights, known noise, and a guaranteed correct answer to compare against:
\mathbf{y} = \mathbf{X} \mathbf{w} + b + \boldsymbol{\epsilon}, \quad \boldsymbol{\epsilon} \sim \mathcal{N}(0, \sigma^2 I).
This chapter builds a DataModule to generate the synthetic batches.
A DataModule subclass that draws features and computes labels in __init__:
class SyntheticRegressionData(d2l.DataModule):
    """Synthetic data for linear regression."""
    def __init__(self, w, b, noise=0.01, num_train=1000, num_val=1000,
                 batch_size=32, key=jax.random.PRNGKey(0)):
        super().__init__()
        self.save_hyperparameters()
        n = num_train + num_val
        key1, key2 = jax.random.split(key)
        self.X = jax.random.normal(key1, (n, w.shape[0]))
        noise = jax.random.normal(key2, (n, 1)) * noise
        self.y = d2l.matmul(self.X, d2l.reshape(w, (-1, 1))) + b + noise
Instantiate with the true w = [2, -3.4], b = 4.2:
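As a standalone illustration of the same recipe without the d2l base class, the sketch below draws features and labels in a plain function and calls it with these true parameters. The function name make_synthetic_data and its argument names are hypothetical; only the jax.random calls mirror the class above. It also checks that the residual y - (Xw + b) has a standard deviation close to the chosen noise level.

```python
import jax
from jax import numpy as jnp


def make_synthetic_data(w, b, noise=0.01, n=2000, key=None):
    """Return (X, y) with y = X w + b + Gaussian noise of std `noise`.

    Hypothetical helper for illustration; not part of d2l.
    """
    if key is None:
        key = jax.random.PRNGKey(0)
    # Independent keys for the features and for the label noise.
    key_x, key_eps = jax.random.split(key)
    X = jax.random.normal(key_x, (n, w.shape[0]))
    eps = jax.random.normal(key_eps, (n, 1)) * noise
    y = X @ w.reshape(-1, 1) + b + eps
    return X, y


true_w = jnp.array([2.0, -3.4])
true_b = 4.2
X, y = make_synthetic_data(true_w, true_b)

# Subtracting the noise-free prediction leaves only the injected noise.
residual = y - (X @ true_w.reshape(-1, 1) + true_b)
print('features:', X[0], '\nlabel:', y[0])
print('X shape:', X.shape, 'y shape:', y.shape)
print('residual std:', float(residual.std()))
```

The key is created inside the function rather than as a default argument, so no JAX call runs at definition time. That is a design choice of this sketch, not a change to the class above.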
Each row of features is a vector in \mathbb{R}^2; the corresponding label is a scalar:
features: [ 1.0040143 -0.9063372]
label: [9.265151]
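A quick sanity check on this printed example, sketched in plain Python: since the label is w·x + b plus a small noise term, recomputing w·x + b from the printed features should land close to the printed label. The variable names are chosen here for illustration.

```python
# Values copied from the printed example above.
features = [1.0040143, -0.9063372]
label = 9.265151

# True parameters used to generate the data.
w = [2.0, -3.4]
b = 4.2

# Noise-free label w . x + b, and what remains once it is subtracted.
clean = sum(wi * xi for wi, xi in zip(w, features)) + b
residual = label - clean

print(f'noise-free label: {clean:.6f}')
print(f'residual (noise): {residual:.6f}')
```

The noise-free value is about 9.2896, so the residual is about -0.024, small compared with the label itself.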
get_dataloader shuffles the training indices (validation order is left unchanged), then yields minibatches of at most batch_size examples:
@d2l.add_to_class(SyntheticRegressionData)
def get_dataloader(self, train):
    if train:
        indices = list(range(0, self.num_train))
        # The examples are read in random order
        random.shuffle(indices)
    else:
        indices = list(range(self.num_train, self.num_train+self.num_val))
    for i in range(0, len(indices), self.batch_size):
        batch_indices = d2l.tensor(indices[i: i+self.batch_size])
        yield self.X[batch_indices], self.y[batch_indices]
For real work, wrap the features and labels in an existing data-loading library, which handles concerns such as worker processes, prefetching, and pinned memory. JAX does not ship one, so this version borrows TensorFlow’s:
@d2l.add_to_class(d2l.DataModule)
def get_tensorloader(self, tensors, train, indices=slice(0, None)):
    tensors = tuple(a[indices] for a in tensors)
    # Use TensorFlow Datasets and tf.data; JAX and Flax do not provide
    # any data-loading functionality. `drop_remainder=train` keeps every
    # *training* minibatch the same shape, so a `@jax.jit`'d step
    # function is compiled once instead of being compiled again for the
    # smaller last batch (recompilation costs more when batch shapes
    # vary from step to step, as with variable-length sequences).
    # Validation keeps the remainder so that every example is evaluated.
    # A shuffle buffer of 1 leaves the validation order unchanged.
    shuffle_buffer = tensors[0].shape[0] if train else 1
    return tfds.as_numpy(
        tf.data.Dataset.from_tensor_slices(tensors).shuffle(
            buffer_size=shuffle_buffer
        ).batch(self.batch_size, drop_remainder=train))
From the caller’s point of view, the iteration protocol is identical:
X shape: (32, 2)
y shape: (32, 1)
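To make the effect of drop_remainder on batch counts concrete, here is a minimal pure-Python stand-in for the shuffle-then-batch behaviour. It is a sketch, not the tf.data implementation: the generator minibatch_indices and its parameters are hypothetical, and it yields index lists rather than tensors.

```python
import random


def minibatch_indices(num_examples, batch_size, shuffle, drop_remainder):
    """Yield lists of example indices, one list per minibatch."""
    indices = list(range(num_examples))
    if shuffle:
        random.shuffle(indices)
    for start in range(0, num_examples, batch_size):
        batch = indices[start:start + batch_size]
        if drop_remainder and len(batch) < batch_size:
            return  # skip the short final batch
        yield batch


# Training: shuffled, remainder dropped. Validation: ordered, remainder kept.
train_sizes = [len(idx) for idx in minibatch_indices(1000, 32, True, True)]
val_sizes = [len(idx) for idx in minibatch_indices(1000, 32, False, False)]
print(len(train_sizes), set(train_sizes))
print(len(val_sizes), val_sizes[-1])
```

With 1000 examples and a batch size of 32, 1000 = 31 × 32 + 8. The training loader therefore yields 31 full batches, while the validation loader yields 32 batches, the last holding 8 examples.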
Synthetic data comes with a known w, b you can compare against later.
DataModule subclasses encapsulate “where do batches come from?” once, so the answer is reusable across models.