Known ground truth

Synthetic Regression Data

Synthetic data with a known answer

Before we train a model we need data. For pedagogy, we synthesize it ourselves. Because we choose the weights and the noise, the correct answer is known in advance and can be compared against whatever the model learns:

\mathbf{y} = \mathbf{X} \mathbf{w} + b + \boldsymbol{\epsilon}, \quad \boldsymbol{\epsilon} \sim \mathcal{N}(0, \sigma^2 I).
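
Because the true parameters are known, any estimator can be checked against them. The minimal sketch below is only an illustration, under two stated swaps: NumPy's random generator stands in for the chapter's JAX PRNG, and closed-form least squares stands in for the iterative training a model would actually use. It draws data from the model above and recovers w and b.

```python
import numpy as np

rng = np.random.default_rng(0)  # NumPy stand-in for the chapter's JAX PRNG
w_true, b_true, sigma = np.array([2.0, -3.4]), 4.2, 0.01
n = 1000

X = rng.standard_normal((n, 2))
y = X @ w_true + b_true + sigma * rng.standard_normal(n)  # y = Xw + b + eps

# Append a column of ones so the bias is estimated as one more weight.
X1 = np.hstack([X, np.ones((n, 1))])
theta, *_ = np.linalg.lstsq(X1, y, rcond=None)
w_hat, b_hat = theta[:2], theta[2]
print(w_hat, b_hat)  # both within a small tolerance of w_true and b_true
```

With σ = 0.01 and 1000 examples, the estimates land very close to [2, -3.4] and 4.2, which is exactly the kind of comparison a known ground truth makes possible.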

This chapter:

  • Subclass DataModule to generate the synthetic dataset.
  • Roll a hand-written minibatch sampler (to see how it works).
  • Swap in a library dataloader, TensorFlow’s tf.data, since JAX ships none (the version we’ll actually use).

Synthetic data module

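A DataModule subclass that draws standard-normal features and computes noisy labels in __init__. The noise argument is the standard deviation σ; the first num_train rows serve as the training set and the remaining num_val rows as the validation set: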

%matplotlib inline
from d2l import jax as d2l
import jax
from jax import numpy as jnp
import numpy as np
import random
import tensorflow as tf
import tensorflow_datasets as tfds

class SyntheticRegressionData(d2l.DataModule):
    """Synthetic data for linear regression."""
    def __init__(self, w, b, noise=0.01, num_train=1000, num_val=1000,
                 batch_size=32, key=jax.random.PRNGKey(0)):
        super().__init__()
        self.save_hyperparameters()  # stores w, b, noise, num_train, ... on self
        n = num_train + num_val
        # Split the key so features and noise come from independent streams.
        key1, key2 = jax.random.split(key)
        self.X = jax.random.normal(key1, (n, w.shape[0]))  # (n, d) features
        eps = jax.random.normal(key2, (n, 1)) * noise      # (n, 1) noise, std = noise
        # Reshape w to a (d, 1) column so X @ w is (n, 1), matching eps.
        self.y = d2l.matmul(self.X, d2l.reshape(w, (-1, 1))) + b + eps
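
The reshape of w into a column is not cosmetic. This NumPy sketch (NumPy stands in for jax.numpy, which follows the same broadcasting rules) compares the label shapes with and without it, using the class's default sizes:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 2000                                   # num_train + num_val
X = rng.standard_normal((n, 2))            # (n, 2) features
w = np.array([2.0, -3.4])                  # shape (2,), as passed to the class
noise = 0.01 * rng.standard_normal((n, 1)) # (n, 1), like the class's noise draw

good = X @ w.reshape(-1, 1) + 4.2 + noise  # (n, 2) @ (2, 1) -> (n, 1)
bad = X @ w + 4.2 + noise                  # (n,) + (n, 1) broadcasts to (n, n)
print(good.shape, bad.shape)               # (2000, 1) (2000, 2000)
```

Without the reshape, X @ w is a flat (n,) vector, and adding the (n, 1) noise silently broadcasts to an n-by-n matrix instead of raising an error.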

Instantiate with the true w = [2, -3.4], b = 4.2:

data = SyntheticRegressionData(w=d2l.tensor([2, -3.4]), b=4.2)

Inspecting one example

Each row of data.X is a feature vector in \mathbb{R}^2, and the matching row of data.y holds its label, a scalar stored as a length-1 array because y has shape (n, 1):

print('features:', data.X[0],'\nlabel:', data.y[0])
features: [ 1.0040143 -0.9063372] 
label: [9.265151]
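
As a sanity check, the printed example can be plugged into the model by hand. This NumPy sketch uses only the values printed above and the true parameters; the residual is this example's realized noise ε, and it should be small next to the label, since noise=0.01 sets its standard deviation.

```python
import numpy as np

x = np.array([1.0040143, -0.9063372])  # first feature row printed above
label = 9.265151                        # its printed label
w_true, b_true = np.array([2.0, -3.4]), 4.2

noiseless = x @ w_true + b_true         # Xw + b for this row, without epsilon
residual = label - noiseless            # this row's realized noise epsilon
print(round(float(noiseless), 4), round(float(residual), 4))  # 9.2896 -0.0244
```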

A handwritten dataloader

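For training, get_dataloader shuffles the indices before reading them; validation examples are read in their stored order. Either way it yields minibatches of batch_size examples, and the last one may be smaller: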

@d2l.add_to_class(SyntheticRegressionData)
def get_dataloader(self, train):
    if train:
        indices = list(range(0, self.num_train))
        # The examples are read in random order
        random.shuffle(indices)
    else:
        # Validation examples are read in their stored order
        indices = list(range(self.num_train, self.num_train+self.num_val))
    for i in range(0, len(indices), self.batch_size):
        # The final slice may hold fewer than batch_size indices
        batch_indices = d2l.tensor(indices[i: i+self.batch_size])
        yield self.X[batch_indices], self.y[batch_indices]

X, y = next(iter(data.train_dataloader()))
print('X shape:', X.shape, '\ny shape:', y.shape)
X shape: (32, 2) 
y shape: (32, 1)

Easy to follow, but inefficient: it keeps the whole dataset in memory, assembles every minibatch in a Python loop, and does no prefetching or parallel loading.
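
The same logic can be written with one permutation per epoch instead of a shuffled Python list. In the NumPy sketch below, the helper name permutation_batches is hypothetical. It is still a Python loop with no prefetching, so it does not replace a real input pipeline; it mainly makes two properties easy to check: every example is visited exactly once per epoch, and the final batch holds 1000 mod 32 = 8 examples.

```python
import numpy as np

def permutation_batches(X, y, batch_size, rng):
    """Hypothetical helper: one shuffled pass over (X, y) in minibatches."""
    perm = rng.permutation(len(X))              # a fresh shuffle on every call
    for start in range(0, len(X), batch_size):
        idx = perm[start:start + batch_size]    # the final slice may be short
        yield X[idx], y[idx]

rng = np.random.default_rng(0)
X = np.arange(1000, dtype=float).reshape(-1, 1)  # row i stores the value i
y = 2 * X                                         # each label is tied to its row

sizes, seen = [], []
for Xb, yb in permutation_batches(X, y, 32, rng):
    assert np.array_equal(yb, 2 * Xb)             # features and labels stay paired
    sizes.append(len(Xb))
    seen.extend(int(v) for v in Xb[:, 0])

print(len(sizes), sizes[-1])                      # 32 8
```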

The framework dataloader

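For real work, hand the tensors to an optimized input pipeline. JAX and Flax provide no data loading utilities, so we borrow TensorFlow’s tf.data, which handles shuffling and batching (and can prefetch in the background if asked); tfds.as_numpy converts its output to NumPy arrays: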

@d2l.add_to_class(d2l.DataModule)
def get_tensorloader(self, tensors, train, indices=slice(0, None)):
    tensors = tuple(a[indices] for a in tensors)
    # Use Tensorflow Datasets & Dataloader. JAX or Flax do not provide
    # any dataloading functionality. `drop_remainder=train` keeps every
    # *training* minibatch the same shape. `jax.jit` compiles once per
    # distinct input shape, so a smaller final batch would force one extra
    # trace and compile of a jitted training step; dropping it avoids that
    # at the cost of skipping num_train % batch_size examples per epoch.
    # Validation keeps its partial last batch so every example is scored.
    # A buffer as large as the data gives a full shuffle; size 1 keeps order.
    shuffle_buffer = tensors[0].shape[0] if train else 1
    return tfds.as_numpy(
        tf.data.Dataset.from_tensor_slices(tensors).shuffle(
            buffer_size=shuffle_buffer
        ).batch(self.batch_size, drop_remainder=train))
@d2l.add_to_class(SyntheticRegressionData)
def get_dataloader(self, train):
    i = slice(0, self.num_train) if train else slice(self.num_train, None)
    return self.get_tensorloader((self.X, self.y), train, i)
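
The compile-once behavior that motivates drop_remainder can be observed directly. In this sketch, the loss function mse and the counter trace_count are illustrative names, not part of the chapter's code. A Python side effect inside a jax.jit-decorated function runs only while JAX traces it, so the counter records how many compilations were triggered.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # illustrative counter, not part of the chapter's code

@jax.jit
def mse(w, X, y):
    global trace_count
    trace_count += 1  # Python side effect: executes only during tracing
    return jnp.mean((X @ w - y) ** 2)

w = jnp.zeros((2, 1))
full = (jnp.ones((32, 2)), jnp.ones((32, 1)))  # a full training batch
last = (jnp.ones((8, 2)), jnp.ones((8, 1)))    # the 1000 % 32 = 8 leftover examples

mse(w, *full)
mse(w, *full)   # same shapes: reuses the cached compilation
mse(w, *last)   # new shape: traced and compiled again
mse(w, *last)   # cached from now on
print(trace_count)  # 2
```

With drop_remainder=True the training loop only ever sees the (32, 2) shape, so the second compilation never happens, at the cost of skipping 8 randomly chosen training examples each epoch.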

Same minibatch interface

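From the caller’s point of view, iteration works exactly as before; the only difference is that tfds.as_numpy yields NumPy arrays rather than JAX arrays: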

X, y = next(iter(data.train_dataloader()))
print('X shape:', X.shape, '\ny shape:', y.shape)
X shape: (32, 2) 
y shape: (32, 1)

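Calling len() on the loader reports the number of batches per epoch, which is convenient for progress bars. Because drop_remainder=True discards the final partial training batch, the count is 1000 // 32 = 31 rather than 32: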

len(data.train_dataloader())
31
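
The count follows from simple arithmetic. The helper below is a hypothetical illustration of how drop_remainder changes the number of batches, assuming the dataset size is known up front (as it is for a dataset built with from_tensor_slices):

```python
import math

def num_batches(n, batch_size, drop_remainder):
    """Hypothetical helper: batches per epoch for n examples."""
    if drop_remainder:
        return n // batch_size          # the partial final batch is discarded
    return math.ceil(n / batch_size)    # the partial final batch is kept

print(num_batches(1000, 32, drop_remainder=True))   # 31
print(num_batches(1000, 32, drop_remainder=False))  # 32
```

Keeping the remainder, as the handwritten loader does, gives 32 batches, the last one holding 8 examples.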

Recap

  • Synthetic data → ground-truth w, b you can compare against later.
  • DataModule subclasses encapsulate “where do batches come from?” once, reusable across models.
  • Hand-rolled iterator vs. tf.data loader: same iteration protocol, but the tf.data version wins on speed and ergonomics.