Layers and Modules

Modules as Building Blocks

Modern networks aren’t flat stacks. ResNet-152 has 152 layers, organized into a handful of repeating patterns. Transformers stack 12, 24, or even 96 structurally identical blocks. Writing them one layer at a time would be miserable.

The module abstraction (nn.Module in PyTorch, flax.linen.Module in Flax on top of JAX) lets you define a pattern once and nest it inside other patterns. A module can be a single layer, a block of layers, or the whole model; all three are subclasses of the same base class.

Modules compose recursively

Layers compose into modules; modules compose into models.

What every module must do

The framework asks five things of every module:

  1. Take input via forward(x).
  2. Return output (possibly a different shape).
  3. Compute gradients of output w.r.t. input (autograd does this for free).
  4. Store and expose its parameters.
  5. Initialize them (or accept user init).

Subclass nn.Module, write __init__ + forward, and the base class supplies the bookkeeping automatically.
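
As a quick check, the five requirements can be observed on a built-in layer. The sketch below uses nn.Linear with arbitrarily chosen sizes (4 inputs, 3 outputs); the variable names are illustrative and not part of any API.

```python
import torch
from torch import nn

layer = nn.Linear(4, 3)                     # sizes chosen only for illustration
x = torch.rand(2, 4, requires_grad=True)    # track gradients w.r.t. the input

y = layer(x)                # 1. calling the module invokes forward(x)
print(y.shape)              # 2. output shape differs from the input: (2, 3)

y.sum().backward()          # 3. autograd computes gradients, including w.r.t. x
print(x.grad.shape)         # same shape as x: (2, 4)

# 4. parameters are stored on the module and exposed by name
param_shapes = {name: tuple(p.shape) for name, p in layer.named_parameters()}
print(param_shapes)         # weight is (3, 4), bias is (3,)

# 5. the constructor already initialized them; a user can override the init
nn.init.zeros_(layer.bias)
print(layer.bias)
```

None of the bookkeeping in steps 3 to 5 was written by hand; it comes from the base class and autograd.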

The simple way: nn.Sequential

For a linear chain of layers, nn.Sequential does everything. Construct, call, done:

import torch
from torch import nn
from torch.nn import functional as F
net = nn.Sequential(nn.LazyLinear(256), nn.ReLU(), nn.LazyLinear(10))

X = torch.rand(2, 20)
net(X).shape
torch.Size([2, 10])

Sequential is itself a module. Internally it registers its children in an ordered dictionary keyed by position ("0", "1", ...), and forward walks them in order. “An ordered collection of layers, run in sequence”: that’s all.
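
A small sketch of what the container looks like from the outside. It assumes fixed layer sizes (nn.Linear rather than the lazy variants) purely to keep the example simple; the names seq and child_names are illustrative.

```python
import torch
from torch import nn

seq = nn.Sequential(nn.Linear(20, 256), nn.ReLU(), nn.Linear(256, 10))

# Children are registered in order under the names "0", "1", "2".
child_names = [name for name, _ in seq.named_children()]
print(child_names)            # ['0', '1', '2']

print(len(seq))               # 3
print(seq[0])                 # the first Linear layer

# Running the children by hand reproduces what calling seq does.
X = torch.rand(2, 20)
out = X
for child in seq:
    out = child(out)
print(torch.allclose(out, seq(X)))
```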

Hand-rolled MLP module

Sequential is good when the topology is a chain. For anything else, define your own subclass. The pattern: name sub-modules in __init__, write forward to use them:

class MLP(nn.Module):
    def __init__(self):
        # Call the constructor of the parent class nn.Module to perform
        # the necessary initialization
        super().__init__()
        self.hidden = nn.LazyLinear(256)
        self.out = nn.LazyLinear(10)

    # Define the forward propagation of the model, that is, how to return the
    # required model output based on the input X
    def forward(self, X):
        return self.out(F.relu(self.hidden(X)))

The two attributes self.hidden and self.out aren’t ordinary fields: assigning a Module to an attribute of another Module registers it as a child. From then on, for an instance net = MLP():

  • net.parameters() includes both layers’ weights/biases.
  • net.to('cuda') moves both to GPU.
  • net.state_dict() gives a flat dict of every parameter (and any registered buffer), keyed by name.

Total user code: ~6 lines.
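
To see the registration at work, the sketch below repeats the MLP definition so it runs on its own, performs one forward pass (the lazy layers only learn their input size then), and lists what the base class tracked. The input shape (2, 20) is an assumption carried over from the earlier example.

```python
import torch
from torch import nn
from torch.nn import functional as F

class MLP(nn.Module):          # same definition as above, repeated to run standalone
    def __init__(self):
        super().__init__()
        self.hidden = nn.LazyLinear(256)
        self.out = nn.LazyLinear(10)

    def forward(self, X):
        return self.out(F.relu(self.hidden(X)))

net = MLP()
net(torch.rand(2, 20))         # one forward pass so the lazy layers infer input sizes

# Attribute names become the prefixes of the parameter names.
shapes = {name: tuple(p.shape) for name, p in net.named_parameters()}
print(shapes)
# {'hidden.weight': (256, 20), 'hidden.bias': (256,),
#  'out.weight': (10, 256), 'out.bias': (10,)}

state_keys = list(net.state_dict().keys())
print(state_keys)
```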

Building a Sequential ourselves

What does nn.Sequential actually do? Almost nothing. A simplified version fits in a handful of lines:

class MySequential(nn.Module):
    def __init__(self, *args):
        super().__init__()
        for idx, module in enumerate(args):
            self.add_module(str(idx), module)

    def forward(self, X):
        for module in self.children():
            X = module(X)
        return X

Plug it in and the API is identical to the framework’s:

net = MySequential(nn.LazyLinear(256), nn.ReLU(), nn.LazyLinear(10))
net(X).shape
torch.Size([2, 10])

The “magic” is just using add_module() so children get registered, then a for loop in forward.
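
Why registration matters becomes clear when it is skipped. In the sketch below (class names PlainList and Registered are hypothetical), the same two layers are stored once in a plain Python list and once in nn.ModuleList, which registers each element the way add_module() does.

```python
import torch
from torch import nn

class PlainList(nn.Module):
    def __init__(self):
        super().__init__()
        # A plain Python list: nn.Module does not look inside it.
        self.layers = [nn.Linear(4, 4), nn.Linear(4, 2)]

    def forward(self, X):
        for layer in self.layers:
            X = layer(X)
        return X

class Registered(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.ModuleList registers each element as a child.
        self.layers = nn.ModuleList([nn.Linear(4, 4), nn.Linear(4, 2)])

    def forward(self, X):
        for layer in self.layers:
            X = layer(X)
        return X

n_plain = len(list(PlainList().parameters()))
n_registered = len(list(Registered().parameters()))
print(n_plain, n_registered)   # 0 4
```

Both classes compute the same forward pass, but an optimizer built from PlainList().parameters() would have nothing to train.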

forward is just Python

This is the superpower of the module abstraction: forward is normal Python. Use loops, conditionals, random tensors, anything you’d write in numpy:

class FixedHiddenMLP(nn.Module):
    def __init__(self):
        super().__init__()
        # Fixed random weights: a plain tensor rather than an nn.Parameter, so
        # no gradient is computed for it and it stays constant during training
        self.rand_weight = torch.rand((20, 20))
        self.linear = nn.LazyLinear(20)

    def forward(self, X):
        X = self.linear(X)
        X = F.relu(X @ self.rand_weight + 1)
        # Reuse the fully connected layer. This is equivalent to two fully
        # connected layers that share the same parameters
        X = self.linear(X)
        # Control flow
        while X.abs().sum() > 1:
            X /= 2
        return X.sum()

The while loop, the fixed rand_weight, and even calling self.linear twice (parameter sharing!) all work, and autograd tracks gradients through whichever path actually ran:

net = FixedHiddenMLP()
net(X)
tensor(-0.0418, grad_fn=<SumBackward0>)
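
One caveat about the fixed weight: as a plain tensor attribute it is invisible to the module's bookkeeping, so it is not saved by state_dict() and not moved by .to(). A hedged variant is sketched below. The class name BufferedMLP is hypothetical, nn.Linear(20, 20) stands in for the lazy layer, and the while loop is left out for brevity; the point is only register_buffer and the shared layer.

```python
import torch
from torch import nn
from torch.nn import functional as F

class BufferedMLP(nn.Module):   # hypothetical name, not from the text above
    def __init__(self):
        super().__init__()
        # A buffer is part of the module's state but is not a parameter,
        # so optimizers never update it.
        self.register_buffer("rand_weight", torch.rand(20, 20))
        self.linear = nn.Linear(20, 20)

    def forward(self, X):
        X = self.linear(X)
        X = F.relu(X @ self.rand_weight + 1)
        X = self.linear(X)      # same layer again: shared parameters
        return X.sum()

net = BufferedMLP()
loss = net(torch.rand(2, 20))
loss.backward()

param_names = [name for name, _ in net.named_parameters()]
print(param_names)                          # ['linear.weight', 'linear.bias']
print("rand_weight" in net.state_dict())    # True: saved with the model
print(net.rand_weight.requires_grad)        # False: no gradient is tracked
print(net.linear.weight.grad.shape)         # one gradient, summed over both uses
```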

Composition: modules all the way down

Modules nest to any depth. A NestMLP holds a Sequential; a top-level Sequential holds a NestMLP + a Linear + a FixedHiddenMLP:

class NestMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.LazyLinear(64), nn.ReLU(),
                                 nn.LazyLinear(32), nn.ReLU())
        self.linear = nn.LazyLinear(16)

    def forward(self, X):
        return self.linear(self.net(X))

chimera = nn.Sequential(NestMLP(), nn.LazyLinear(20), FixedHiddenMLP())
chimera(X)
tensor(0.3213, grad_fn=<SumBackward0>)

The framework recursively walks this tree to find every parameter. Every modern architecture is built this way: ResNet = blocks of ResBlocks of conv+BN+ReLU. Transformer = blocks of attention+FFN. Same recursion every time.
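
The recursive walk can be observed directly. The sketch below builds a small nested model (layer sizes are arbitrary assumptions) and uses named_modules() and parameters(), both of which traverse the whole tree.

```python
import torch
from torch import nn

block = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
model = nn.Sequential(block, nn.Sequential(nn.Linear(8, 4), nn.ReLU()), nn.Linear(4, 1))

# named_modules() yields the root (empty name) and every descendant, depth first.
module_names = [name for name, _ in model.named_modules()]
print(module_names)
# ['', '0', '0.0', '0.1', '1', '1.0', '1.1', '2']

# parameters() recurses the same way, so counting is a one-liner.
n_params = sum(p.numel() for p in model.parameters())
print(n_params)   # (8*8 + 8) + (8*4 + 4) + (4*1 + 1) = 113
```

The dotted names ("0.1", "1.0") are the same prefixes that appear in state_dict() keys, which is how a flat dictionary can describe a nested model.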

Recap

  • A module is one Python class that represents a layer, a block, or a whole model.
  • Children assigned to attributes are auto-registered for parameter tracking, device placement, serialization.
  • Sequential is a tiny module that runs its children in order; for arbitrary topologies, subclass and write forward.
  • forward is plain Python: control flow, parameter sharing, and fixed (untrained) tensors are all welcome.
  • Modules compose recursively; that recursion is what lets a network like ResNet-152 be defined from a few block types rather than one layer at a time.