Custom Layers

tf.keras.layers ships a large catalog of layers, but occasionally (a new architecture, an unusual normalization, a custom block) you need one the framework doesn’t have.

Writing one is straightforward: subclass tf.keras.Model (or tf.keras.layers.Layer) and override call. Two flavors:

  • Stateless: pure transforms. Just override call.
  • Stateful: your own Dense, a low-rank weight, etc. Create learnable tensors with self.add_weight.

The custom layer composes with built-ins automatically: Sequential, trainable_variables, get_weights(), saving.

Stateless layer: a centering operator

Subtract the mean of the whole input from every element. Nothing to learn, so this is a pure transform:

from d2l import tensorflow as d2l
import tensorflow as tf

class CenteredLayer(tf.keras.Model):
    def __init__(self):
        super().__init__()

    def call(self, X):
        # No axis argument: the mean is taken over all elements of X
        return X - tf.reduce_mean(X)

Standalone use:

layer = CenteredLayer()
layer(d2l.tensor([1.0, 2, 3, 4, 5]))
<tf.Tensor: shape=(5,), dtype=float32, numpy=array([-2., -1.,  0.,  1.,  2.], dtype=float32)>

The output mean is (numerically) zero — by construction.
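The layer above subtracts a single mean computed over the whole input, since tf.reduce_mean is called without an axis. If each row should instead be centered on its own mean, a variant could look like the following sketch. RowCenteredLayer is a hypothetical name introduced here for illustration, not part of the original example:

```python
import tensorflow as tf

class RowCenteredLayer(tf.keras.layers.Layer):
    """Hypothetical variant: subtract each row's own mean."""

    def call(self, X):
        # keepdims=True keeps the reduced axis, so the per-row mean
        # broadcasts back against the columns of X
        return X - tf.reduce_mean(X, axis=-1, keepdims=True)

X = tf.constant([[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]])
Y = RowCenteredLayer()(X)
row_means = tf.reduce_mean(Y, axis=-1)  # one (numerically) zero mean per row
```

For 1-D inputs such as the example above, the two versions agree; they differ only once the input has more than one row.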

Composes with built-ins

Drop the custom layer into a Sequential like any other:

net = tf.keras.Sequential([tf.keras.layers.Dense(128), CenteredLayer()])
Y = net(tf.random.uniform((4, 8)))
tf.reduce_mean(Y)
<tf.Tensor: shape=(), dtype=float32, numpy=9.313225746154785e-10>

Keras treats CenteredLayer like Dense or any other built-in layer: they are all subclasses of tf.keras.layers.Layer.

Stateful layer: hand-rolled Linear

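Implement a fully-connected layer from scratch. The one important step: create learnable tensors with self.add_weight so Keras registers them for training. Doing this in build, rather than in __init__, lets the layer infer its input dimension from the first input it sees. Note that this version also applies a ReLU to its output: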

class MyDense(tf.keras.Model):
    def __init__(self, units):
        super().__init__()
        self.units = units

    def build(self, X_shape):
        # Runs once, on the first call; X_shape[-1] is the input dimension
        self.weight = self.add_weight(name='weight',
            shape=[X_shape[-1], self.units],
            initializer=tf.random_normal_initializer())
        self.bias = self.add_weight(
            name='bias', shape=[self.units],
            initializer=tf.zeros_initializer())

    def call(self, X):
        linear = tf.matmul(X, self.weight) + self.bias
        return tf.nn.relu(linear)

dense = MyDense(3)
# The first call triggers build, which creates the weights
dense(tf.random.uniform((2, 5)))
dense.get_weights()
[array([[ 1.3681618e-02,  4.1976105e-02,  6.7144625e-02],
        [ 9.5281027e-02, -2.0827046e-03,  7.0420615e-02],
        [-7.2408892e-02,  4.1485440e-02,  4.1537687e-02],
        [ 1.7760841e-02,  1.7952685e-03, -4.4985678e-02],
        [ 8.0492093e-05, -8.4409781e-02,  4.2490412e-02]], dtype=float32),
 array([0., 0., 0.], dtype=float32)]

What add_weight buys you

After dense = MyDense(3) has been called on an input (which triggers build):

  • dense.weight and dense.bias are tracked variables.
  • dense.trainable_variables yields both; pass them to the optimizer.
  • get_weights() returns them; save_weights() writes them to disk.

All for free, just by creating the variables with add_weight in build.
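To make the tracking concrete, here is a minimal sketch of one gradient step on a hand-rolled dense layer. It repeats a condensed copy of MyDense so that it runs on its own; the target tensor, the squared-error loss, and the learning rate are made-up choices for illustration only:

```python
import tensorflow as tf

class MyDense(tf.keras.layers.Layer):
    """Condensed copy of the layer above, so this sketch is self-contained."""

    def __init__(self, units):
        super().__init__()
        self.units = units

    def build(self, X_shape):
        self.weight = self.add_weight(
            name='weight', shape=[X_shape[-1], self.units],
            initializer='random_normal')
        self.bias = self.add_weight(
            name='bias', shape=[self.units], initializer='zeros')

    def call(self, X):
        return tf.nn.relu(tf.matmul(X, self.weight) + self.bias)

dense = MyDense(3)
X = tf.random.uniform((2, 5))
target = tf.ones((2, 3))  # made-up regression target
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

dense(X)  # first call runs build and creates the variables

with tf.GradientTape() as tape:
    loss = tf.reduce_mean(tf.square(dense(X) - target))

# Both variables created by add_weight show up here, in creation order
variables = dense.trainable_variables
grads = tape.gradient(loss, variables)
optimizer.apply_gradients(zip(grads, variables))
```

Nothing in the training step names weight or bias explicitly; the layer reports them through trainable_variables because they were created with add_weight.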

Test drive

dense(tf.random.uniform((2, 5)))
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[0.        , 0.01131513, 0.10846559],
       [0.03849415, 0.        , 0.11945204]], dtype=float32)>

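Stack two MyDense layers with the same Sequential plumbing as built-in layers. The output can be exactly zero, as it is here, because the ReLU in call clips negative pre-activations: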

net = tf.keras.models.Sequential([MyDense(8), MyDense(1)])
net(tf.random.uniform((2, 64)))
<tf.Tensor: shape=(2, 1), dtype=float32, numpy=
array([[0.],
       [0.]], dtype=float32)>

When to write a custom layer

Real-world cases that justify a custom layer:

  • Novel architectural blocks — gated linear units, factorized weight matrices, low-rank parameterizations (LoRA).
  • Custom normalization — group norm with non-standard groups, layer-norm variants.
  • Tied/shared weights with structure — embedding + output projection sharing in language models.
  • Frozen “buffers”: running statistics in BatchNorm, position-specific masks. Create them with add_weight(..., trainable=False) for non-trainable tensors that should still travel with the layer (returned by get_weights() and saved with it, but excluded from gradient updates).
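The last case can be sketched as follows. RunningMean is a hypothetical layer invented for this illustration (it is not a Keras built-in): it keeps an exponential running mean of its inputs in a non-trainable variable and subtracts it. The momentum parameter and the update rule are assumptions chosen to keep the example small:

```python
import tensorflow as tf

class RunningMean(tf.keras.layers.Layer):
    """Hypothetical layer: center inputs with a tracked running mean."""

    def __init__(self, momentum=0.9):
        super().__init__()
        self.momentum = momentum

    def build(self, X_shape):
        # trainable=False: stored with the layer, ignored by the optimizer
        self.running_mean = self.add_weight(
            name='running_mean', shape=[X_shape[-1]],
            initializer='zeros', trainable=False)

    def call(self, X, training=False):
        if training:
            batch_mean = tf.reduce_mean(X, axis=0)
            # Exponential moving average of the per-feature batch mean
            self.running_mean.assign(
                self.momentum * self.running_mean
                + (1 - self.momentum) * batch_mean)
        return X - self.running_mean

layer = RunningMean(momentum=0.5)
X = tf.constant([[2.0, 4.0], [4.0, 8.0]])
Y = layer(X, training=True)    # running_mean becomes [1.5, 3.0]
saved = layer.get_weights()    # the running mean is part of the layer's weights
```

The variable appears in get_weights() but not in trainable_variables, so it is saved with the layer while gradient updates leave it alone.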

Recap

  • Custom layer = tf.keras.Model (or tf.keras.layers.Layer) subclass with a call method.
  • Stateless: just override call. Stateful: create learnable tensors with self.add_weight, typically in build.
  • Use add_weight(..., trainable=False) for non-trainable state that should still travel with the layer.
  • Composes with built-in layers exactly as a built-in would. No special handling.
  • The escape hatch when the standard layer zoo doesn’t cover what you actually need.