from d2l import tensorflow as d2l
import tensorflow as tf

torch.nn ships 100+ layers, but occasionally (a new architecture, an unusual normalization, a custom block) you need one the framework doesn’t have.
Writing one is trivial: subclass nn.Module, override forward. Two flavors:
Stateless: no parameters, just a forward.
Stateful: your own Linear, low-rank weight, etc. Wrap learnable tensors in nn.Parameter.
The custom layer composes with built-ins automatically: Sequential, parameters(), to(device), checkpointing.
Subtract the row-wise mean from each input. Nothing to learn — pure transform:
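A stateless layer of this kind might look like the sketch below. It is written in PyTorch to match the nn.Module vocabulary used in the prose; the class name CenteredLayer and the sample input are illustrative assumptions, not taken from the original.

```python
import torch
from torch import nn


class CenteredLayer(nn.Module):
    """Stateless layer (name is illustrative): no parameters, only a forward."""

    def forward(self, X):
        # Subtract each row's mean from that row; keepdim=True keeps the
        # mean broadcastable against X.
        return X - X.mean(dim=-1, keepdim=True)


layer = CenteredLayer()
X = torch.tensor([[1.0, 2.0, 3.0], [4.0, 6.0, 8.0]])
Y = layer(X)
print(Y)  # rows become [-1, 0, 1] and [-2, 0, 2]
print(len(list(layer.parameters())))  # 0: nothing to learn
```

Because the layer defines no nn.Parameter, parameters() is empty and an optimizer has nothing to update for it.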
Drop the custom layer into a Sequential like any other:
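As a rough illustration of that composition, the sketch below places a row-centering layer after a built-in nn.Linear. The layer sizes (8 and 16), the batch size, and the name CenteredLayer are arbitrary choices made for this example.

```python
import torch
from torch import nn


class CenteredLayer(nn.Module):
    """Illustrative stateless layer: subtracts each row's mean."""

    def forward(self, X):
        return X - X.mean(dim=-1, keepdim=True)


# The custom layer sits in a Sequential exactly like a built-in one.
net = nn.Sequential(nn.Linear(8, 16), CenteredLayer())

out = net(torch.rand(4, 8))
row_means = out.mean(dim=-1)
print(out.shape)   # torch.Size([4, 16])
print(row_means)   # each entry is zero up to floating-point error
```

Checking that every output row has (approximately) zero mean is a cheap sanity test that the custom layer ran as the last step of the pipeline.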
Implement a fully-connected layer from scratch. The one important step: wrap learnable tensors in nn.Parameter so they’re auto-registered for training:
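class MyDense(tf.keras.Model):
    def __init__(self, units):
        super().__init__()
        self.units = units  # number of output features

    def build(self, X_shape):
        # Called once the input shape is known. add_weight registers each
        # tensor as a trainable variable of the layer.
        self.weight = self.add_weight(name='weight',
            shape=[X_shape[-1], self.units],
            initializer=tf.random_normal_initializer())
        self.bias = self.add_weight(
            name='bias', shape=[self.units],
            initializer=tf.zeros_initializer())

    def call(self, X):
        # Affine transform followed by ReLU.
        linear = tf.matmul(X, self.weight) + self.bias
        return tf.nn.relu(linear)

The layer's weights once it has been built (here a 5x3 weight matrix and a length-3 bias):
[array([[ 1.3681618e-02, 4.1976105e-02, 6.7144625e-02],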
[ 9.5281027e-02, -2.0827046e-03, 7.0420615e-02],
[-7.2408892e-02, 4.1485440e-02, 4.1537687e-02],
[ 1.7760841e-02, 1.7952685e-03, -4.4985678e-02],
[ 8.0492093e-05, -8.4409781e-02, 4.2490412e-02]], dtype=float32),
array([0., 0., 0.], dtype=float32)]
What nn.Parameter buys you
After linear = MyLinear(5, 3):
linear.weight and linear.bias are tracked parameters.
linear.parameters() yields both; feed them to the optimizer.
state_dict() saves them; linear.to('cuda') moves them.
All for free, just by declaring nn.Parameter in __init__.
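The points above can be checked with a small sketch. MyLinear is not defined in the text, so the definition here is an assumption: a dense layer with a ReLU, mirroring MyDense, with constructor arguments taken to be (input features, output features).

```python
import torch
from torch import nn


class MyLinear(nn.Module):
    """Assumed definition: dense layer with ReLU, parameters declared in __init__."""

    def __init__(self, in_units, units):
        super().__init__()
        # Wrapping in nn.Parameter is what registers the tensors with the module.
        self.weight = nn.Parameter(torch.randn(in_units, units))
        self.bias = nn.Parameter(torch.zeros(units))

    def forward(self, X):
        return torch.relu(X @ self.weight + self.bias)


linear = MyLinear(5, 3)

names = [name for name, _ in linear.named_parameters()]
print(names)                      # ['weight', 'bias']
print(list(linear.state_dict()))  # same two keys, ready to be saved

# parameters() feeds straight into an optimizer.
optimizer = torch.optim.SGD(linear.parameters(), lr=0.1)

out = linear(torch.rand(2, 5))
print(out.shape)                  # torch.Size([2, 3])
```

Storing a plain torch.Tensor as an attribute instead of an nn.Parameter would still compute the same forward pass, but the tensor would not appear in parameters() or state_dict(), so it would be neither trained nor saved.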
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[0. , 0.01131513, 0.10846559],
[0.03849415, 0. , 0.11945204]], dtype=float32)>
Real-world cases that justify a custom layer:
register_buffer for non-trainable tensors that should still travel with the module (saved, moved to GPU, etc.).
A custom layer is an nn.Module subclass with a forward.
Stateless: just forward. Stateful: wrap learnable tensors in nn.Parameter.
register_buffer for non-trainable state that should still travel with the module.
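To make the parameter/buffer distinction concrete, here is a minimal sketch. The ScaledShift layer, its fixed scale of 0.5, and the attribute names are hypothetical and exist only for this example.

```python
import torch
from torch import nn


class ScaledShift(nn.Module):
    """Hypothetical layer: fixed (non-trainable) scale plus a learnable shift."""

    def __init__(self, num_features):
        super().__init__()
        # Trainable: appears in parameters() and receives gradients.
        self.shift = nn.Parameter(torch.zeros(num_features))
        # Not trainable, but saved in state_dict() and moved by .to(device).
        self.register_buffer("scale", torch.full((num_features,), 0.5))

    def forward(self, X):
        return X * self.scale + self.shift


layer = ScaledShift(3)

param_names = [name for name, _ in layer.named_parameters()]
state_keys = list(layer.state_dict())
print(param_names)               # ['shift']: the buffer is not a parameter
print(sorted(state_keys))        # ['scale', 'shift']: both are saved
print(layer.scale.requires_grad) # False

out = layer(torch.ones(2, 3))    # every entry is 0.5
```

A buffer suits values such as running statistics or fixed masks: the optimizer never touches them, yet they stay with the module when it is saved or moved to another device.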