Modern networks aren’t flat stacks. ResNet-152 has 152 layers with weights, almost all of them convolutional, organized into a handful of repeating block patterns. Transformers stack 12, 24, or even 96 structurally identical blocks. Writing them out one layer at a time would be miserable.

```python
import tensorflow as tf
```
The module abstraction (nn.Module in PyTorch, tf.keras.Model in TensorFlow, flax.linen.Module in JAX) handles the recursion. A module can be a single layer, a block of layers, or the whole model; all three share the same base class.
Layers compose into modules; modules compose into models.
The framework asks five things of every module:
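- ingest input data as the arguments of its forward method, forward(x);
- produce an output by returning a value from forward (its shape may differ from the input's);
- compute the gradient of its output with respect to its input, which automatic differentiation normally handles;
- store the parameters that forward needs and give access to them;
- initialize those parameters as needed.

In practice you subclass the base class (nn.Module in PyTorch, tf.keras.Model in Keras), write __init__ plus forward (Keras calls it call), and the base class, together with the framework's autodiff, supplies the rest of the bookkeeping automatically.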
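The bookkeeping is less magical than it sounds. Below is a toy, framework-free sketch in plain Python: `Parameter`, `Module`, `Scale`, and `TwoLayer` are hypothetical classes, not PyTorch's or Keras's API, and a single float stands in for each weight tensor. The base class intercepts attribute assignment to register parameters and children, so the subclasses only write `__init__` and `forward`:

```python
class Parameter:
    """Hypothetical stand-in for a trainable tensor; a float is enough here."""
    def __init__(self, value):
        self.value = value


class Module:
    """Toy base class (not a real framework's API) doing the bookkeeping."""
    def __init__(self):
        # Bypass our own __setattr__ so these registries are not registered.
        object.__setattr__(self, "_params", {})
        object.__setattr__(self, "_children", {})

    def __setattr__(self, name, value):
        # Assigning a Parameter or a Module to an attribute registers it.
        if isinstance(value, Parameter):
            self._params[name] = value
        elif isinstance(value, Module):
            self._children[name] = value
        object.__setattr__(self, name, value)

    def parameters(self):
        # Own parameters first, then recurse into every registered child.
        yield from self._params.values()
        for child in self._children.values():
            yield from child.parameters()

    def __call__(self, x):
        # Calling a module runs the forward method the subclass wrote.
        return self.forward(x)


class Scale(Module):
    """Hypothetical one-weight layer: y = w * x + b on plain floats."""
    def __init__(self, w, b):
        super().__init__()
        self.w = Parameter(w)
        self.b = Parameter(b)

    def forward(self, x):
        return self.w.value * x + self.b.value


class TwoLayer(Module):
    """User code: only __init__ and forward."""
    def __init__(self):
        super().__init__()
        self.hidden = Scale(2.0, 1.0)
        self.out = Scale(3.0, 0.0)

    def forward(self, x):
        return self.out(max(0.0, self.hidden(x)))  # ReLU between the layers


net = TwoLayer()
print(net(1.0))                     # (2*1 + 1) -> ReLU -> 3*3 + 0 = 9.0
print(len(list(net.parameters())))  # 4: w and b of each Scale layer
```

Real frameworks track far more (devices, dtypes, training mode, gradients), but the core trick, hooking attribute assignment and recursing over children, is the same idea.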
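For a linear chain of layers, a Sequential container (nn.Sequential in PyTorch, tf.keras.Sequential in Keras) does everything: construct it from a list of layers, call it on a batch, done. Calling one whose last layer has 10 units on a batch of 2 examples returns an output of shape: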
TensorShape([2, 10])
Sequential is itself a module. Internally it stores its children in a list, and its forward pass walks that list in order. “List of layers, run them in sequence”: that’s all there is to it.
Sequential is good when the topology is a chain. For anything else, define your own subclass. The pattern: create and name the sub-modules in __init__, then write forward (call in Keras) to use them:
```python
class MLP(tf.keras.Model):
    def __init__(self):
        # Call the constructor of the parent class tf.keras.Model to perform
        # the necessary initialization
        super().__init__()
        self.hidden = tf.keras.layers.Dense(units=256, activation=tf.nn.relu)
        self.out = tf.keras.layers.Dense(units=10)

    # Define the forward propagation of the model, that is, how to return the
    # required model output based on the input X
    def call(self, X):
        return self.out(self.hidden(X))


net = MLP()
X = tf.random.uniform((2, 20))  # a batch of 2 examples with 20 features
net(X).shape  # TensorShape([2, 10]); the first call also creates the weights
```

The two attributes self.hidden and self.out aren’t ordinary fields: assigning a layer to an attribute of a module registers it as a child. From this moment on:
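- net.parameters() includes both layers’ weights and biases (Keras: net.trainable_variables, populated once the first call has built the layers).
- net.to('cuda') moves both to the GPU (TensorFlow instead places variables on an available GPU automatically, or under a tf.device scope).
- net.state_dict() gives a flat dict of every parameter (Keras: net.get_weights() returns a flat list of arrays).

Total user code for the model: about 6 lines.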
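How does a nested object turn into a flat dict? By recursion with dotted name prefixes. A toy sketch in plain Python, assuming hypothetical `Parameter`, `Module`, `Dense`, and `MLP` classes (not any framework's API) with one float per parameter; for brevity it finds children by scanning each instance's attributes with `vars()` rather than registering them on assignment:

```python
class Parameter:
    """Hypothetical stand-in for a weight tensor; a float is enough here."""
    def __init__(self, value):
        self.value = value


class Module:
    """Toy base class, not a real framework's API."""
    def state_dict(self, prefix=""):
        flat = {}
        # vars(self) preserves assignment order, so keys follow __init__ order.
        for name, attr in vars(self).items():
            if isinstance(attr, Parameter):
                flat[prefix + name] = attr.value
            elif isinstance(attr, Module):
                # Recurse into the child, extending the dotted prefix.
                flat.update(attr.state_dict(prefix + name + "."))
        return flat


class Dense(Module):
    def __init__(self, weight, bias):
        self.weight = Parameter(weight)
        self.bias = Parameter(bias)


class MLP(Module):
    def __init__(self):
        self.hidden = Dense(0.5, 0.1)
        self.out = Dense(2.0, -1.0)


sd = MLP().state_dict()
print(sd)
# {'hidden.weight': 0.5, 'hidden.bias': 0.1, 'out.weight': 2.0, 'out.bias': -1.0}
```

The dotted keys are what make the flat dict reversible: each key records the path from the root module down to the parameter.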
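What does a Sequential container actually do? Almost nothing. Its implementation fits in about four lines: a constructor that stores the child layers in order, and a forward pass that feeds each child’s output into the next.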
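Here is that idea as a toy sketch in plain Python. `Sequential` below is a hypothetical stand-in, not the PyTorch or Keras class, and ordinary functions play the role of layers; in a Keras subclass the same loop would sit in the `call` method:

```python
class Sequential:
    """Toy container (hypothetical, not a library class): children in a list."""
    def __init__(self, *layers):
        self.layers = list(layers)   # keep the children in order

    def __call__(self, x):
        for layer in self.layers:    # feed each output into the next layer
            x = layer(x)
        return x


# Plain functions standing in for layers.
def add_three(x):
    return x + 3


def double(x):
    return 2 * x


def relu(x):
    return max(0, x)


net = Sequential(add_three, double, relu)
print(net(-5))  # (-5 + 3) * 2 = -4, then ReLU gives 0
print(net(1))   # (1 + 3) * 2 = 8
```

Because the container is itself callable, it can be a child of another container: `Sequential(net, add_three)(1)` evaluates to 11.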
forward is just Python

This is the superpower of the module abstraction: forward (call in Keras) is normal Python. Use loops, conditionals, fixed random tensors, anything you’d write in NumPy:
```python
class FixedHiddenMLP(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.flatten = tf.keras.layers.Flatten()
        # Random weight parameters created with tf.constant are not updated
        # during training (i.e., constant parameters)
        self.rand_weight = tf.constant(tf.random.uniform((20, 20)))
        self.dense = tf.keras.layers.Dense(20, activation=tf.nn.relu)

    def build(self, input_shape):
        self.flatten.build(input_shape)
        self.dense.build((input_shape[0], 20))
        super().build(input_shape)

    def call(self, inputs):
        X = self.flatten(inputs)
        X = self.dense(X)
        # Use the created constant parameters, as well as the relu and
        # matmul functions
        X = tf.nn.relu(tf.matmul(X, self.rand_weight) + 1)
        # Reuse the same fully connected layer: both calls share one set of
        # parameters, as if two layers had their weights tied
        X = self.dense(X)
        # Control flow: halve X until the sum of its absolute values is <= 1
        while tf.reduce_sum(tf.math.abs(X)) > 1:
            X /= 2
        return tf.reduce_sum(X)

    def compute_output_shape(self, input_shape):
        return (input_shape[0],)


net = FixedHiddenMLP()
net(X)
```

The while loop, the fixed rand_weight, and even calling self.dense twice (parameter sharing!) all work, and gradients flow correctly through every one of them (rand_weight itself is a constant, so it is never updated). The result is a scalar whose exact value depends on the random weights:
<tf.Tensor: shape=(), dtype=float32, numpy=0.9014120101928711>
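Why do gradients still come out right when one layer object is used in two places? Because a shared weight receives the sum of the gradient contributions from every use, which is just the multivariable chain rule. A minimal plain-Python check, assuming a hypothetical one-weight layer y = w * x applied twice, so the network computes w * (w * x) = w**2 * x and the exact gradient is 2 * w * x:

```python
def shared_net(w, x):
    """Hypothetical network applying the same scale layer y = w * x twice."""
    h = w * x       # first use of the shared weight
    return w * h    # second use of the same weight


w, x, eps = 1.5, 2.0, 1e-6

# Gradient contribution from each use, holding the other use fixed:
first_use = w * x    # d/dw1 of w2 * (w1 * x), evaluated at w1 = w2 = w
second_use = w * x   # d/dw2 of w2 * (w1 * x)
analytic = first_use + second_use  # contributions add: 2 * w * x = 6.0

# Central finite difference of the whole network with respect to w.
numeric = (shared_net(w + eps, x) - shared_net(w - eps, x)) / (2 * eps)

print(analytic)                        # 6.0
print(abs(numeric - analytic) < 1e-6)  # True
```

Autodiff frameworks implement exactly this accumulation: each use of a shared variable adds its term into that variable's gradient.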
Modules nest to any depth. A NestMLP holds a Sequential; a top-level Sequential holds a NestMLP, a Dense layer, and a FixedHiddenMLP:
```python
class NestMLP(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.net = tf.keras.Sequential()
        self.net.add(tf.keras.layers.Dense(64, activation=tf.nn.relu))
        self.net.add(tf.keras.layers.Dense(32, activation=tf.nn.relu))
        self.dense = tf.keras.layers.Dense(16, activation=tf.nn.relu)

    def call(self, inputs):
        return self.dense(self.net(inputs))


chimera = tf.keras.Sequential()
chimera.add(NestMLP())
chimera.add(tf.keras.layers.Dense(20))
chimera.add(FixedHiddenMLP())
chimera(X)
```

<tf.Tensor: shape=(), dtype=float32, numpy=0.905106782913208>
The framework recursively walks this tree to find every parameter. Most modern architectures are built this way: ResNet = stages of residual blocks of conv + BN + ReLU; Transformer = a stack of attention + FFN blocks. Same recursion every time.
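The walk itself is ordinary recursion. A toy sketch, assuming hypothetical `Layer` and `Block` classes that record only parameter counts (a dense layer from n_in to n_out features has n_in * n_out weights plus n_out biases), arranged like the chimera above and assuming 20 input features:

```python
class Layer:
    """Hypothetical leaf module; stores parameter sizes, not actual tensors."""
    def __init__(self, n_in, n_out):
        self.params = {"weight": n_in * n_out, "bias": n_out}
        self.children = {}


class Block:
    """Hypothetical container: named child modules and no own parameters."""
    def __init__(self, **children):
        self.params = {}
        self.children = children


def named_parameters(module, prefix=""):
    """Recursively walk the tree, yielding (dotted_name, size) pairs."""
    for name, size in module.params.items():
        yield prefix + name, size
    for name, child in module.children.items():
        yield from named_parameters(child, prefix + name + ".")


# Same nesting as the chimera: NestMLP, then Dense(20), then FixedHiddenMLP.
# The constant rand_weight is not a parameter, so the tail holds only Dense(20).
nest = Block(net=Block(l0=Layer(20, 64), l1=Layer(64, 32)), dense=Layer(32, 16))
chimera = Block(nest=nest, dense=Layer(16, 20), tail=Block(dense=Layer(20, 20)))

for name, size in named_parameters(chimera):
    print(name, size)   # first line: nest.net.l0.weight 1280
print(sum(size for _, size in named_parameters(chimera)))  # 4712
```

Nothing in `named_parameters` depends on how deep the tree is, which is why the same few lines cover a two-layer MLP and a hundred-block network alike.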
- Sequential is a 4-line module that runs its children in order; for arbitrary topologies, subclass and write forward.
- forward is plain Python: control flow, parameter sharing, and fixed buffers are all welcome.