Custom Layers

Flax’s linen API (imported here as nn) ships many built-in layers, but occasionally you need one the framework doesn’t have: a new architecture, an unusual normalization, a custom block.

Writing one is straightforward: subclass nn.Module and implement __call__. Two flavors:

  • Stateless: pure transforms. Just implement __call__.
  • Stateful: your own Dense layer, a low-rank weight, etc. Declare learnable tensors with self.param.

The custom layer composes with built-ins automatically: nn.Sequential, init/apply, jax.jit, jax.grad, and checkpointing via nn.remat.

Stateless layer: a centering operator

Subtract the mean of the input (taken over all of its elements) from every entry. Nothing to learn, just a pure transform:

from d2l import jax as d2l
from flax import linen as nn
import jax
from jax import numpy as jnp
class CenteredLayer(nn.Module):
    def __call__(self, X):
        return X - X.mean()

Standalone use:

layer = CenteredLayer()
layer(d2l.tensor([1.0, 2, 3, 4, 5]))
Array([-2., -1.,  0.,  1.,  2.], dtype=float32)

The output mean is zero by construction (up to floating-point round-off).

Composes with built-ins

Drop the custom layer into a Sequential like any other:

net = nn.Sequential([nn.Dense(128), CenteredLayer()])
Y, _ = net.init_with_output(d2l.get_key(), jax.random.uniform(d2l.get_key(),
                                                              (4, 8)))
Y.mean()
Array(3.7252903e-09, dtype=float32)

nn.Sequential can’t tell CenteredLayer apart from nn.Dense: both are just nn.Modules. The tiny nonzero mean above (about 4e-9) is float32 round-off, not a bug.

Stateful layer: hand-rolled Linear

Implement a fully connected layer from scratch. The one important step: declare each learnable tensor with self.param, which registers it in the 'params' collection that init creates and apply consumes:

class MyDense(nn.Module):
    in_units: int
    units: int

    def setup(self):
        self.weight = self.param('weight', nn.initializers.normal(stddev=1),
                                 (self.in_units, self.units))
        self.bias = self.param('bias', nn.initializers.zeros, (self.units,))

    def __call__(self, X):
        linear = jnp.matmul(X, self.weight) + self.bias
        return nn.relu(linear)
dense = MyDense(5, 3)
params = dense.init(d2l.get_key(), jnp.zeros((3, 5)))
params
{'params': {'weight': Array([[-0.2730784 , -2.0026138 ,  1.2090734 ],
         [-0.24928978,  1.4607671 ,  0.9950771 ],
         [ 0.00302258, -0.63909173,  0.6674626 ],
         [-0.02182669,  0.961296  ,  0.26424628],
         [-0.6597475 ,  0.4217269 , -0.09541126]], dtype=float32),
  'bias': Array([0., 0., 0.], dtype=float32)}}

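What self.param buys you

After dense = MyDense(5, 3) and the dense.init call above:

  • params['params']['weight'] and params['params']['bias'] are the tracked parameters. Flax modules hold no state themselves; the values live in this pytree.
  • The same pytree is what you differentiate with jax.grad and feed to an optimizer.
  • It can be serialized (for example with flax.serialization) or placed on a device with jax.device_put.

All for free, just by calling self.param in setup.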

Test drive

dense.apply(params, jax.random.uniform(d2l.get_key(),
                                       (2, 5)))
Array([[0.        , 1.2433242 , 1.4437162 ],
       [0.        , 0.4386982 , 0.80651015]], dtype=float32)

Stack two MyDense layers with the same nn.Sequential plumbing as built-in layers:

net = nn.Sequential([MyDense(64, 8), MyDense(8, 1)])
Y, _ = net.init_with_output(d2l.get_key(), jax.random.uniform(d2l.get_key(),
                                                              (2, 64)))
Y
Array([[0.      ],
       [2.143576]], dtype=float32)

When to write a custom layer

Real-world cases that justify a custom layer:

  • Novel architectural blocks: gated linear units, factorized weight matrices, low-rank parameterizations (LoRA).
  • Custom normalization: group norm with non-standard groups, layer-norm variants.
  • Tied/shared weights with structure: embedding and output projection sharing in language models.
  • Frozen “buffers”: running statistics in BatchNorm, position-specific masks. Declare them with self.variable in a collection other than 'params' (nn.BatchNorm uses 'batch_stats'); they are not trained but still travel with the module’s variables (saved, moved to devices, etc.).

Recap

  • Custom layer = nn.Module subclass with a __call__.
  • Stateless: just implement __call__. Stateful: declare learnable tensors with self.param.
  • Use self.variable in a non-'params' collection for non-trainable state that should still travel with the module.
  • Composes with built-in layers exactly like a built-in. No special handling.
  • The escape hatch when the standard layer zoo doesn’t cover what you actually need.