%matplotlib inline
from d2l import tensorflow as d2l
import tensorflow as tf
import random
Before we train a model, we need data. For teaching purposes we synthesize it: the weights and the noise level are known, so there is a ground-truth answer to compare the learned model against. The labels are generated as
\mathbf{y} = \mathbf{X} \mathbf{w} + b + \boldsymbol{\epsilon}, \quad \boldsymbol{\epsilon} \sim \mathcal{N}(0, \sigma^2 I).
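The generative process can be sketched without any framework. The version below uses only Python's standard library instead of TensorFlow. The helper name `synthesize` and its arguments are hypothetical and chosen to mirror the parameters used later (`noise` plays the role of σ, the noise standard deviation).

```python
import random

def synthesize(w, b, n, noise=0.01, seed=0):
    """Draw n examples with y = x.w + b + eps, eps ~ N(0, noise^2).

    Hypothetical helper: a pure-Python stand-in for the TensorFlow code,
    returning features as a list of rows and labels as a list of floats.
    """
    rng = random.Random(seed)  # seeded so the draw is reproducible
    X, y = [], []
    for _ in range(n):
        # Each feature is drawn independently from a standard normal
        x = [rng.gauss(0.0, 1.0) for _ in w]
        # Noise-free linear response plus Gaussian noise with std `noise`
        clean = sum(xi * wi for xi, wi in zip(x, w)) + b
        X.append(x)
        y.append(clean + rng.gauss(0.0, noise))
    return X, y

X, y = synthesize(w=[2.0, -3.4], b=4.2, n=1000)
# Residuals against the true parameters recover the sampled noise
residuals = [yi - (sum(xi * wi for xi, wi in zip(x, [2.0, -3.4])) + 4.2)
             for x, yi in zip(X, y)]
print(len(X), len(X[0]))  # 1000 2
print(max(abs(r) for r in residuals) < 0.1)  # True: noise std is 0.01
```

Because the true w and b are known, the residuals are exactly the injected noise, which is what makes synthetic data useful for checking a training procedure.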
In this section we write a DataModule subclass that generates the synthetic data. Its __init__ draws the features and computes the labels:
class SyntheticRegressionData(d2l.DataModule):
    """Synthetic data for linear regression."""
    def __init__(self, w, b, noise=0.01, num_train=1000, num_val=1000,
                 batch_size=32):
        super().__init__()
        self.save_hyperparameters()
        n = num_train + num_val
        # Features: n examples, one column per weight, drawn from N(0, 1)
        self.X = tf.random.normal((n, w.shape[0]))
        # Gaussian noise with standard deviation `noise`
        noise = tf.random.normal((n, 1)) * noise
        # Labels y = Xw + b + noise, with w reshaped to a column vector
        self.y = d2l.matmul(self.X, d2l.reshape(w, (-1, 1))) + b + noise
Instantiate it with the true parameters w = [2, -3.4] and b = 4.2:
Each example's features form a vector in \mathbb{R}^2, and its label is a scalar (stored as a tensor of shape (1,)):
features: tf.Tensor([-1.3305755 1.3084109], shape=(2,), dtype=float32)
label: tf.Tensor([-2.9095104], shape=(1,), dtype=float32)
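As a quick sanity check, plugging the printed features into the noise-free part of the model, with the true w = [2, -3.4] and b = 4.2, should give a value within a few multiples of σ = 0.01 of the label. The numbers below are copied from the output above; plain Python is used so the arithmetic is easy to follow.

```python
# Features and label copied from the printed example above
x = [-1.3305755, 1.3084109]
label = -2.9095104

w_true, b_true = [2.0, -3.4], 4.2
# Noise-free prediction x.w + b
clean = sum(xi * wi for xi, wi in zip(x, w_true)) + b_true
print(round(clean, 4))          # -2.9097
print(round(label - clean, 4))  # 0.0002, the sampled noise for this example
```

The gap of about 0.0002 is well within one standard deviation of the noise, as expected.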
For training, get_dataloader shuffles the example indices; for validation it keeps them in order. It then yields minibatches of batch_size examples (the last one may be smaller):
@d2l.add_to_class(SyntheticRegressionData)
def get_dataloader(self, train):
    if train:
        indices = list(range(0, self.num_train))
        # The examples are read in random order
        random.shuffle(indices)
    else:
        indices = list(range(self.num_train, self.num_train+self.num_val))
    for i in range(0, len(indices), self.batch_size):
        j = tf.constant(indices[i : i+self.batch_size])
        yield tf.gather(self.X, j), tf.gather(self.y, j)
In practice, use the framework's built-in data pipeline instead of a hand-written Python loop. In TensorFlow that is tf.data, which handles shuffling and batching and can also prefetch and preprocess in parallel:
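@d2l.add_to_class(d2l.DataModule)
def get_tensorloader(self, tensors, train, indices=slice(0, None)):
    # Select the requested rows (e.g. the training or validation split)
    tensors = tuple(a[indices] for a in tensors)
    # Shuffle over the whole split for training; a buffer of 1 keeps order
    shuffle_buffer = tensors[0].shape[0] if train else 1
    return tf.data.Dataset.from_tensor_slices(tensors).shuffle(
        buffer_size=shuffle_buffer).batch(self.batch_size)
From the caller's point of view the iteration protocol is unchanged, and a minibatch has the expected shapes: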
X shape: (32, 2)
y shape: (32, 1)
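The shuffle_buffer choice in get_tensorloader matters because tf.data's shuffle fills a buffer of buffer_size elements, emits a randomly chosen element from it, and refills the freed slot from the input. A buffer as large as the split therefore gives a full random permutation, while a buffer of 1 preserves the original order, which is what we want for validation. The function below (the name buffered_shuffle is hypothetical) is a simplified pure-Python model of that behavior, not TensorFlow's implementation.

```python
import random

def buffered_shuffle(items, buffer_size, seed=0):
    """Simplified model of a shuffle buffer (hypothetical helper).

    Fill a buffer with up to buffer_size items, then repeatedly emit a
    random buffered item and replace it with the next input item.
    """
    rng = random.Random(seed)
    it = iter(items)
    buffer = []
    # Initial fill: consume at most buffer_size items from the input
    for item in it:
        buffer.append(item)
        if len(buffer) == buffer_size:
            break
    out = []
    for item in it:
        k = rng.randrange(len(buffer))
        out.append(buffer[k])
        buffer[k] = item  # refill the emptied slot from the input
    # Drain whatever is left, still in random order
    while buffer:
        k = rng.randrange(len(buffer))
        out.append(buffer.pop(k))
    return out

data = list(range(10))
print(buffered_shuffle(data, buffer_size=1) == data)           # True
print(sorted(buffered_shuffle(data, buffer_size=10)) == data)  # True
```

With a buffer smaller than the dataset, each element can only move a limited distance from its original position, which is why the training split uses a buffer covering every example.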
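Because we generated the data ourselves, we know the true w and b and can compare the learned parameters against them later. A DataModule subclass answers “where do batches come from?” in one place, and the same data module can then be reused across models.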
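To illustrate the pattern, here is a framework-free sketch. MiniDataModule and ListRegressionData are hypothetical stand-ins, not d2l's DataModule: we assume the base class exposes train_dataloader and val_dataloader, and a subclass only has to say how batches are produced. The subclass follows the same split and batching rules as the hand-written get_dataloader (shuffled training indices, sequential validation indices).

```python
import random

class MiniDataModule:
    """Hypothetical base class: subclasses decide where batches come from."""
    def get_dataloader(self, train):
        raise NotImplementedError

    def train_dataloader(self):
        return self.get_dataloader(train=True)

    def val_dataloader(self):
        return self.get_dataloader(train=False)


class ListRegressionData(MiniDataModule):
    """Hypothetical subclass holding features X and labels y as Python lists."""
    def __init__(self, X, y, num_train, batch_size=32, seed=0):
        self.X, self.y = X, y
        self.num_train, self.batch_size = num_train, batch_size
        self.rng = random.Random(seed)

    def get_dataloader(self, train):
        if train:
            indices = list(range(self.num_train))
            self.rng.shuffle(indices)  # random order for training
        else:
            indices = list(range(self.num_train, len(self.X)))  # fixed order
        for i in range(0, len(indices), self.batch_size):
            batch = indices[i:i + self.batch_size]
            yield [self.X[j] for j in batch], [self.y[j] for j in batch]


# A training loop only needs the two dataloader methods
data = ListRegressionData(X=[[float(i), 1.0] for i in range(100)],
                          y=[float(i) for i in range(100)], num_train=70)
sizes = [len(Xb) for Xb, _ in data.train_dataloader()]
print(sizes)  # [32, 32, 6]
first_val_X, _ = next(iter(data.val_dataloader()))
print(first_val_X[0])  # [70.0, 1.0]
```

Swapping in a different data source means writing a new get_dataloader; code that calls train_dataloader and val_dataloader stays the same.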