File I/O

Saving model state

Two real problems training pipelines hit constantly:

  • Crash recovery: a 12-hour training run dies in hour 9. Did we just lose 9 hours of compute?
  • Deployment: the model trains on a research box but must serve from a production cluster, possibly in a different language or runtime.

Both reduce to “save the parameters, recreate the model elsewhere”. The crucial split:

   architecture  ──>  Python code  (committed to git)
   parameters    ──>  on-disk file (the .pt / .ckpt / .safetensors)

You save the state, not the class. To resurrect: import the same class, instantiate, then load_state_dict. This section covers both halves of the workflow.

Saving and loading raw tensors

First the building block. The code here uses JAX, where jnp.save / jnp.load play the role of PyTorch's torch.save / torch.load. Start with a single array:

from d2l import jax as d2l
import flax
from flax import linen as nn
from flax.training import checkpoints
import jax
from jax import numpy as jnp
import os
x = jnp.arange(4)
jnp.save('x-file.npy', x)
x2 = jnp.load('x-file.npy', allow_pickle=True)
x2
Array([0, 1, 2, 3], dtype=int32)

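allow_pickle=True lets the loader unpickle object arrays, and unpickling can execute arbitrary code, so use it only on files you trust. A plain numeric array such as x loads without it. PyTorch guards against the same risk: torch.load defaults to weights_only=True since version 2.6.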
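To make the pickle risk concrete, here is a minimal sketch using only Python's standard pickle module. The Payload class is a hypothetical name invented for this illustration, and the evaluated expression is deliberately harmless; a hostile file could put anything in its place.

```python
import pickle

class Payload:
    # __reduce__ tells pickle how to rebuild the object at load time.
    # Here it asks the loader to call eval(...) on a string chosen by
    # whoever created the file.
    def __reduce__(self):
        return (eval, ("6 * 7",))

blob = pickle.dumps(Payload())   # bytes that could be shipped as a "checkpoint"
result = pickle.loads(blob)      # loading alone evaluates the expression
print(result)                    # 42, not a Payload instance
```

The loader never checks what the callable is, which is why pickle-based formats should only be opened when the source of the file is trusted.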

Containers of tensors

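Lists and dicts go through the same two calls, which is handy for grouping related tensors (e.g. the weights of one block plus its running statistics). Two caveats show up in the outputs below: a list of same-shaped arrays is stored as one stacked array, so x comes back as float32, and a dict is pickled into a 0-d object array, so call mydict2.item() to get the dict back: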

y = jnp.zeros(4)
jnp.save('xy-files.npy', [x, y])
x2, y2 = jnp.load('xy-files.npy', allow_pickle=True)
(x2, y2)
(Array([0., 1., 2., 3.], dtype=float32),
 Array([0., 0., 0., 0.], dtype=float32))
mydict = {'x': x, 'y': y}
jnp.save('mydict.npy', mydict)
mydict2 = jnp.load('mydict.npy', allow_pickle=True)
mydict2
array({'x': Array([0, 1, 2, 3], dtype=int32), 'y': Array([0., 0., 0., 0.], dtype=float32)},
      dtype=object)
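The dict above only round-trips because allow_pickle=True is set. One pickle-free alternative is NumPy's .npz archive, which stores each named array as its own entry. The sketch below uses plain NumPy rather than jax.numpy, and the temporary path is an arbitrary choice for the example.

```python
import os
import tempfile
import numpy as np

x = np.arange(4, dtype=np.int32)
y = np.zeros(4, dtype=np.float32)

path = os.path.join(tempfile.mkdtemp(), "mydict.npz")
np.savez(path, x=x, y=y)            # one named entry per array, no pickle involved

with np.load(path) as archive:      # allow_pickle keeps its default of False
    restored = {name: archive[name] for name in archive.files}

print(sorted(restored))                          # ['x', 'y']
print(restored["x"].dtype, restored["y"].dtype)  # int32 float32
```

Each array keeps its own dtype and shape here, unlike the list case, where the arrays were stacked into one.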

A model to save

class MLP(nn.Module):
    def setup(self):
        self.hidden = nn.Dense(256)
        self.output = nn.Dense(10)

    def __call__(self, x):
        return self.output(nn.relu(self.hidden(x)))

net = MLP()
X = jax.random.normal(d2l.get_key(), (2, 20))
Y, params = net.init_with_output(d2l.get_key(), X)

state_dict() — the canonical interface

Every PyTorch Module exposes a state_dict(), an ordered dict mapping parameter paths to tensor values:

{
  'hidden.weight':  Tensor (256, 20),
  'hidden.bias':    Tensor (256,),
  'output.weight':  Tensor (10, 256),
  'output.bias':    Tensor (10,),
}

The keys come from the module tree (self.hidden → 'hidden'). Save this dict, not the module. In the Flax code used here, the params tree returned by init_with_output plays that role:

checkpoints.save_checkpoint(os.path.abspath('ckpt_dir'), params, step=1,
                            overwrite=True)
'/home/smola/d2l-neu/_notebooks/jax/chapter_builders-guide/ckpt_dir/checkpoint_1'
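The Flax parameter tree saved above is a nested dict rather than a flat one, but the same path-style keys can be derived from it. The sketch below is plain Python: the flatten helper is a hypothetical name, and shape tuples stand in for the arrays. It assumes the usual Flax Dense layout, where the weight matrix is called kernel and has shape (inputs, outputs).

```python
def flatten(tree, prefix=""):
    """Flatten a nested dict into {'a.b.c': leaf} with dot-joined paths."""
    flat = {}
    for name, value in tree.items():
        path = f"{prefix}.{name}" if prefix else name
        if isinstance(value, dict):
            flat.update(flatten(value, path))   # recurse into submodules
        else:
            flat[path] = value                  # leaf: a parameter
    return flat

# Shape tuples stand in for arrays; the layout mirrors the MLP defined earlier.
variables = {
    "params": {
        "hidden": {"kernel": (20, 256), "bias": (256,)},
        "output": {"kernel": (256, 10), "bias": (10,)},
    }
}

for path, shape in flatten(variables["params"]).items():
    print(path, shape)
```

Flax provides a comparable utility, flax.traverse_util.flatten_dict, for real parameter trees.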

Loading: instantiate, then load

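Build a fresh model from the same Python class, then load the saved values into it (load_state_dict in PyTorch). The saved keys must match the new model's parameter paths. In Flax, restore_checkpoint takes a target whose tree structure the saved values are restored into: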

clone = MLP()
cloned_params = flax.core.freeze(checkpoints.restore_checkpoint(
    os.path.abspath('ckpt_dir'), target=params))

Sanity check: the same architecture with the same weights should produce bit-identical outputs on the same input:

Y_clone = clone.apply(cloned_params, X)
Y_clone == Y
Array([[ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True]], dtype=bool)

Best practices

  • Save state_dict, not the module object. Pickling the module ties the file to today’s Python class; refactor and old checkpoints break.
  • Keep the model class in your code repo. The file is useless without the matching architecture definition.
  • Strict vs non-strict: load_state_dict(d, strict=False) ignores missing/extra keys — useful for partial loading (e.g. swapping a pretrained head).
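The strict versus non-strict distinction can be sketched in a few lines of plain Python. The load_params helper below is a hypothetical illustration of the idea, not PyTorch's implementation, and strings stand in for tensors.

```python
def load_params(model_params, saved, strict=True):
    """Copy values from `saved` into a copy of `model_params`, matching by key."""
    missing = [k for k in model_params if k not in saved]
    unexpected = [k for k in saved if k not in model_params]
    if strict and (missing or unexpected):
        raise KeyError(f"missing={missing}, unexpected={unexpected}")
    merged = dict(model_params)
    # Only keys the model actually has are overwritten.
    merged.update({k: v for k, v in saved.items() if k in model_params})
    return merged, missing, unexpected

fresh = {"hidden.weight": "init", "head.weight": "init"}   # new model, new head
saved = {"hidden.weight": "pretrained", "output.weight": "pretrained"}

merged, missing, unexpected = load_params(fresh, saved, strict=False)
print(merged)       # hidden.weight is restored, head.weight keeps its init value
print(missing)      # ['head.weight']
print(unexpected)   # ['output.weight']
```

Returning the missing and unexpected keys, rather than dropping them silently, makes it easier to notice a typo in a layer name when loading non-strictly.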

Beyond state_dict

  • safetensors: the same name-to-tensor mapping as a state_dict, but stored without pickle, so loading carries no arbitrary-code-execution risk. Developed by Hugging Face and widely supported by current libraries.

  • Full checkpoint — save more than weights:

    {'model':     net.state_dict(),
     'optimizer': opt.state_dict(),
     'epoch':     epoch,
     'rng_state': torch.get_rng_state()}

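    Lets you resume training where it left off after a crash. Bit-exact resumption additionally needs a deterministic data order and deterministic kernels.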
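A minimal sketch of the save-and-resume cycle, using only the standard library so it runs anywhere. The names write_checkpoint and read_checkpoint are hypothetical, a list of floats stands in for the model weights, Python's random module stands in for the framework RNG, and optimizer state is omitted for brevity. Pickle is used only because the file is one we wrote ourselves.

```python
import os
import pickle
import random
import tempfile

def write_checkpoint(path, state):
    """Write to a temp file, then rename, so a crash never leaves a half-written file."""
    tmp = path + ".tmp"
    with open(tmp, "wb") as f:
        pickle.dump(state, f)
    os.replace(tmp, path)   # atomic when tmp and path are on the same filesystem

def read_checkpoint(path):
    with open(path, "rb") as f:
        return pickle.load(f)

path = os.path.join(tempfile.mkdtemp(), "ckpt.pkl")

rng = random.Random(0)
weights = [0.0, 0.0]
for epoch in range(3):
    weights = [w + rng.random() for w in weights]   # stand-in for a training step
write_checkpoint(path, {"model": weights, "epoch": 3, "rng_state": rng.getstate()})
expected_next = rng.random()   # what an uninterrupted run would draw next

# After a "crash": rebuild everything from the file alone.
state = read_checkpoint(path)
resumed_rng = random.Random()
resumed_rng.setstate(state["rng_state"])
matches = resumed_rng.random() == expected_next
print(state["epoch"], matches)   # 3 True
```

Restoring the RNG state is what makes the resumed run draw the same random numbers as an uninterrupted one. In JAX the equivalent is saving the explicit PRNG key alongside the parameters.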

Recap

  • Save the state, recreate the architecture in code.
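  • jnp.save / jnp.load (torch.save / torch.load in PyTorch) for tensors and dicts; save_checkpoint / restore_checkpoint on the params tree (state_dict() / load_state_dict() in PyTorch) for modules.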
  • Always sanity-check by running a known input through before-and-after-load and comparing.
  • Production: HuggingFace safetensors for the no-pickle story.
  • Full checkpoint: save model + optimizer + epoch + RNG state, in one dict.