```python
from d2l import jax as d2l
from flax import linen as nn
import jax
from jax import numpy as jnp
```

torch.nn ships 100+ layers, and flax.linen covers the common ones too, but occasionally (a new architecture, an unusual normalization, a custom block) you need one the framework doesn't have.
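Writing one is straightforward: subclass nn.Module and define __call__, which plays the role of PyTorch's forward. There are two flavors:
- Stateless: no parameters; all the logic lives in __call__.
- Stateful: the layer owns learnable weights (a dense layer, a low-rank weight, etc.). Declare each learnable array with self.param, the Flax counterpart of nn.Parameter.

A custom layer composes with the built-ins automatically: it works inside nn.Sequential, its parameters appear in the tree returned by init, and the model can be wrapped in jax.jit or rematerialized with nn.remat for checkpointing.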
Stateless example: a layer that subtracts each input row's mean from that row. There is nothing to learn; it is a pure transform.
The custom layer drops into nn.Sequential like any built-in layer.
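Stateful example: a fully connected layer implemented from scratch. The one important step is to declare each learnable array with self.param, which registers it in the module's 'params' collection so that init creates it and training can update it. Hyperparameters (in_units, units) are dataclass fields, and the parameters are created in setup: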
```python
class MyDense(nn.Module):
    in_units: int  # number of input features
    units: int     # number of output features

    def setup(self):
        # self.param registers each array in the 'params' collection
        self.weight = self.param('weight', nn.initializers.normal(stddev=1),
                                 (self.in_units, self.units))
        self.bias = self.param('bias', nn.initializers.zeros, (self.units,))

    def __call__(self, X):
        linear = jnp.matmul(X, self.weight) + self.bias
        return nn.relu(linear)
```

Initializing MyDense(5, 3) returns a nested dict of parameters (the exact values depend on the random key):

```
{'params': {'weight': Array([[-0.2730784 , -2.0026138 ,  1.2090734 ],
       [-0.24928978,  1.4607671 ,  0.9950771 ],
       [ 0.00302258, -0.63909173,  0.6674626 ],
       [-0.02182669,  0.961296  ,  0.26424628],
       [-0.6597475 ,  0.4217269 , -0.09541126]], dtype=float32),
  'bias': Array([0., 0., 0.], dtype=float32)}}
```
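What self.param buys you (it plays the role of PyTorch's nn.Parameter). After dense = MyDense(5, 3) and variables = dense.init(key, x):
- weight and bias are registered under variables['params'], as the output above shows;
- that tree is what jax.grad differentiates and what an optimizer updates;
- the same tree can be saved (for example with flax.serialization.to_bytes) or moved to an accelerator with jax.device_put.

All of this follows from declaring the arrays with self.param in setup.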
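Applied to a batch of two 5-dimensional inputs, the layer returns a (2, 3) output; ReLU sets the negative entries to zero:

```
Array([[0.        , 1.2433242 , 1.4437162 ],
       [0.        , 0.4386982 , 0.80651015]], dtype=float32)
```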
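Real-world cases that justify a custom layer are the ones mentioned at the start: a new architecture, an unusual normalization, or a custom block.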
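For non-trainable state that should still travel with the module (saved and moved alongside the parameters), Flax uses self.variable with a collection other than 'params', for example 'batch_stats'; this plays the role of PyTorch's register_buffer.

In short: a custom layer is an nn.Module subclass with a __call__ method.
- Stateless: all the logic lives in __call__.
- Stateful: declare learnable arrays with self.param.
- Non-trainable state that should travel with the module: self.variable in a collection other than 'params'.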