from d2l import jax as d2l
import flax
from flax import linen as nn
from flax.training import checkpoints
import jax
from jax import numpy as jnp
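import os
Two real problems training pipelines hit constantly: keeping a trained model to reuse later (or on another machine), and resuming a long training run after a crash.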
Both reduce to “save the parameters, recreate the model elsewhere”. The crucial split:
architecture ──> Python code (committed to git)
parameters ──> on-disk file (the .pt / .ckpt / .safetensors)
You save the state, not the class. To resurrect: import the same class, instantiate, then load_state_dict. This section covers both halves of the workflow.
First the building block: torch.save / torch.load work on any tensor, list of tensors, or dict thereof:
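The prose names the PyTorch calls. In the JAX setting this section imports, `jax.numpy.save` and `jax.numpy.load` handle the single-array case using NumPy's `.npy` format. The following is a minimal sketch; the temporary directory is only there so nothing lingers on disk:

```python
import os
import tempfile

import jax.numpy as jnp

x = jnp.arange(4)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "x-file.npy")
    jnp.save(path, x)          # writes NumPy's .npy format
    x2 = jnp.load(path)        # reads it back as a JAX array

print(x2)  # [0 1 2 3]
```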
Lists and dicts work the same way, which is handy for grouping related tensors (e.g. the weights of one block plus its running statistics):
(Array([0., 1., 2., 3.], dtype=float32),
Array([0., 0., 0., 0.], dtype=float32))
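The list output above can be reproduced in JAX with the sketch below, which also covers the dict case. Two details are worth knowing. First, NumPy stacks a list of equal-shape arrays into a single array on save, which is why the result unpacks row by row. Second, using an `.npz` archive for the dict is a choice made here rather than something the section prescribes: each key becomes one named array inside the zip file, so no pickling is needed.

```python
import os
import tempfile

import jax.numpy as jnp
import numpy as np

x = jnp.arange(4, dtype=jnp.float32)
y = jnp.zeros(4)
mydict = {"x": x, "y": y}

with tempfile.TemporaryDirectory() as tmp:
    # A list of equal-shape arrays is stacked into one (2, 4) array on save,
    # so it can be unpacked row by row after loading.
    list_path = os.path.join(tmp, "xy-files.npy")
    jnp.save(list_path, [x, y])
    x2, y2 = jnp.load(list_path)

    # A dict of named arrays maps onto an .npz archive, one entry per key.
    dict_path = os.path.join(tmp, "mydict.npz")
    np.savez(dict_path, **mydict)
    with np.load(dict_path) as archive:
        mydict2 = {k: jnp.asarray(archive[k]) for k in archive.files}

print(x2, y2)
print(sorted(mydict2))  # ['x', 'y']
```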
Every Module exposes state_dict(), an ordered dict that maps parameter paths to tensor values:
{
'hidden.weight': Tensor (256, 20),
'hidden.bias': Tensor (256,),
'output.weight': Tensor (10, 256),
'output.bias': Tensor (10,),
}
The keys come from the module tree (self.hidden → hidden). Save this dict, not the module:
'/home/smola/d2l-neu/_notebooks/jax/chapter_builders-guide/ckpt_dir/checkpoint_1'
Build a fresh model with the same Python class, then call load_state_dict. The dict’s keys must match the new model’s parameter paths:
- Save the state_dict, not the module object. Pickling the module ties the file to today’s Python class; refactor the class and old checkpoints break.
- load_state_dict(d, strict=False) ignores missing and unexpected keys, which is useful for partial loading (e.g. swapping in a new head on a pretrained model).
- Beyond state_dict: safetensors stores the same name-to-tensor mapping without pickle, so loading a file cannot execute arbitrary code. It is Hugging Face's format and is now widely supported across libraries.
A full checkpoint saves more than the weights. It also stores training state such as the optimizer state and the step counter, which lets you resume training bit-exactly after a crash.
In short: torch.save / torch.load for tensors and dicts; state_dict() / load_state_dict() for modules; safetensors for the no-pickle story.