import tensorflow as tf
import numpy as np

Two real problems training pipelines hit constantly:
Both reduce to “save the parameters, recreate the model elsewhere”. The crucial split:
architecture ──> Python code (committed to git)
parameters ──> on-disk file (the .pt / .ckpt / .safetensors)
You save the state, not the class. To resurrect: import the same class, instantiate, then load_state_dict. This section covers both halves of the workflow.
First the building block: torch.save / torch.load work on any tensor, list of tensors, or dict thereof:
Lists and dicts work the same — perfect for grouping related tensors together (e.g. weights of one block, plus its running statistics):
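With the NumPy import at the top, the NumPy counterpart of this step is np.save / np.load. Below is a minimal sketch of the list case. It assumes two float arrays and a hypothetical file name, xy-files.npy. Loading and unpacking yields the tuple shown next:

```python
import numpy as np

# Two 1-D float arrays standing in for a block's tensors.
x = np.arange(4, dtype=np.float64)
y = np.zeros(4)

# Stack the list into one (2, 4) float array so the file needs no pickle.
np.save("xy-files.npy", np.stack([x, y]))

# Unpacking the loaded (2, 4) array recovers the two original rows.
x2, y2 = np.load("xy-files.npy")
print((x2, y2))
```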
(array([0., 1., 2., 3.]), array([0., 0., 0., 0.]))
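The next sketch shows the dict case with torch.save / torch.load themselves. The key names and the file name block.pt are hypothetical. Passing weights_only=True, which recent PyTorch releases accept, limits unpickling to tensors and plain containers:

```python
import torch

# Group one block's tensors under descriptive (hypothetical) keys.
block = {
    "weight": torch.randn(4, 3),
    "running_mean": torch.zeros(4),
}

# torch.save serializes the whole dict into a single file.
torch.save(block, "block.pt")

# torch.load returns a dict with the same keys and tensor values.
restored = torch.load("block.pt", weights_only=True)
for name, tensor in restored.items():
    print(name, tuple(tensor.shape))
```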
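class MLP(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.flatten = tf.keras.layers.Flatten()
        # Hidden layer: 256 units with ReLU.
        self.hidden = tf.keras.layers.Dense(units=256, activation=tf.nn.relu)
        # Output layer: 10 units, no activation (logits).
        self.out = tf.keras.layers.Dense(units=10)

    def call(self, inputs):
        x = self.flatten(inputs)
        x = self.hidden(x)
        return self.out(x)

net = MLP()
X = tf.random.uniform((2, 20))
Y = net(X)  # the first call builds the weights (Keras creates them lazily)

Keras models like the one above have no state_dict; their nearest equivalents are get_weights() and save_weights(). Every PyTorch nn.Module exposes a state_dict(), an ordered dict mapping parameter paths to tensor values: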
{
'hidden.weight': Tensor (256, 20),
'hidden.bias': Tensor (256,),
'output.weight': Tensor (10, 256),
'output.bias': Tensor (10,),
}
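This listing describes a PyTorch module, not the Keras class above. A minimal PyTorch counterpart reproduces those keys and shapes. It assumes nn.Linear layers named hidden and output, chosen to match the listing. It skips Flatten because the inputs are already flat:

```python
import torch
from torch import nn

class MLP(nn.Module):
    """Hypothetical PyTorch counterpart of the Keras MLP: 20 -> 256 -> 10."""

    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(20, 256)  # keys: hidden.weight, hidden.bias
        self.output = nn.Linear(256, 10)  # keys: output.weight, output.bias

    def forward(self, x):
        return self.output(torch.relu(self.hidden(x)))

net = MLP()
# Keys follow attribute names; nn.Linear stores weight as (out, in).
for name, tensor in net.state_dict().items():
    print(name, tuple(tensor.shape))
```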
The keys come from the module tree (self.hidden → hidden). Save this dict, not the module:
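Saving then takes a single call. Here is a minimal sketch with the same hypothetical PyTorch MLP; the file name mlp.params is an assumption:

```python
import torch
from torch import nn

class MLP(nn.Module):
    # Hypothetical PyTorch MLP matching the listing: 20 -> 256 -> 10.
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(20, 256)
        self.output = nn.Linear(256, 10)

    def forward(self, x):
        return self.output(torch.relu(self.hidden(x)))

net = MLP()
# Persist only the parameter dict; the class itself stays in Python code.
torch.save(net.state_dict(), "mlp.params")
```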
Build a fresh model with the same Python class, then call load_state_dict. The dict’s keys must match the new model’s parameter paths:
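Here is a minimal sketch of that round trip with the same hypothetical PyTorch MLP. It saves mlp.params itself so that it runs standalone. Identical outputs on the same input confirm that the parameters came across intact:

```python
import torch
from torch import nn

class MLP(nn.Module):
    # Hypothetical PyTorch MLP matching the listing: 20 -> 256 -> 10.
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(20, 256)
        self.output = nn.Linear(256, 10)

    def forward(self, x):
        return self.output(torch.relu(self.hidden(x)))

net = MLP()
torch.save(net.state_dict(), "mlp.params")

# A fresh instance starts with its own random weights...
clone = MLP()
# ...until the saved dict is loaded (strict=True by default: keys must match).
clone.load_state_dict(torch.load("mlp.params", weights_only=True))

X = torch.rand(2, 20)
print(torch.equal(clone(X), net(X)))  # True
```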
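- Save the state_dict, not the module object. Pickling the module ties the file to today's Python class; refactor that class and old checkpoints break.
- load_state_dict(d, strict=False) tolerates missing and extra keys, reporting them in its return value instead of raising. This is useful for partial loading (e.g. swapping a pretrained head), but keys that are present must still have matching shapes.
- safetensors stores the same flat name-to-tensor mapping as a state_dict but without pickle, so loading a file cannot execute arbitrary code. It is the HuggingFace standard and is widely supported by modern libraries.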
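The head-swap case can be sketched as follows. The sketch reuses the hypothetical PyTorch MLP with a configurable output size, and the 5-class target is an assumption. Because strict=False still rejects shape mismatches, the old output.* entries are filtered out before loading:

```python
import torch
from torch import nn

class MLP(nn.Module):
    # Hypothetical PyTorch MLP with a configurable head size.
    def __init__(self, num_outputs=10):
        super().__init__()
        self.hidden = nn.Linear(20, 256)
        self.output = nn.Linear(256, num_outputs)

    def forward(self, x):
        return self.output(torch.relu(self.hidden(x)))

pretrained = MLP(num_outputs=10)
# Keep only the body; the old 10-way head would not fit the new model.
body = {k: v for k, v in pretrained.state_dict().items()
        if not k.startswith("output.")}

finetune = MLP(num_outputs=5)  # hypothetical new task with 5 classes
result = finetune.load_state_dict(body, strict=False)
print(result.missing_keys)     # the freshly initialized head
print(result.unexpected_keys)  # nothing extra was supplied
```

The missing output.* parameters keep their random initialization, ready to be trained on the new task.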
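Full checkpoint: save more than the weights. Storing the optimizer's state_dict, the epoch counter and the RNG state alongside the model's state_dict lets you resume training after a crash. Bit-exact resumption additionally needs deterministic kernels and the same data order.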
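Here is a minimal sketch of such a checkpoint. The dict keys, the SGD optimizer, the stand-in nn.Linear model, the epoch value and the file name ckpt.pt are all assumptions rather than a required format:

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(20, 10)  # stand-in for any nn.Module
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# One training step so the optimizer has momentum buffers worth saving.
loss = model(torch.rand(8, 20)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

epoch = 3  # hypothetical progress counter
torch.save(
    {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch,
        "rng": torch.get_rng_state(),
    },
    "ckpt.pt",
)

# Resuming: rebuild the objects, then restore every piece of saved state.
ckpt = torch.load("ckpt.pt", weights_only=True)
model2 = nn.Linear(20, 10)
optimizer2 = torch.optim.SGD(model2.parameters(), lr=0.1, momentum=0.9)
model2.load_state_dict(ckpt["model"])
optimizer2.load_state_dict(ckpt["optimizer"])
torch.set_rng_state(ckpt["rng"])
start_epoch = ckpt["epoch"] + 1
```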
torch.save / torch.load for tensors and dicts; state_dict() / load_state_dict() for modules; safetensors for the no-pickle story.