File I/O

Saving model state

Two real problems training pipelines hit constantly:

  • Crash recovery: a 12-hour training run dies 9 hours in. Did we just lose 9 hours of compute?
  • Deployment: the model trains on a research box but needs to serve from a production cluster, possibly in a different language or runtime.

Both reduce to “save the parameters, recreate the model elsewhere”. The crucial split:

   architecture  ──>  Python code  (committed to git)
   parameters    ──>  on-disk file (the .pt / .ckpt / .safetensors)

You save the state, not the class. To resurrect: import the same class, instantiate, then load_state_dict. This section covers both halves of the workflow.

Saving and loading raw tensors

First the building block: torch.save / torch.load work on any tensor, list of tensors, or dict thereof:

import torch
from torch import nn

x = torch.arange(4)
torch.save(x, 'x-file')      # write the tensor to disk
x2 = torch.load('x-file')    # read it back
x2

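weights_only=True is the default for torch.load since PyTorch 2.6 (early 2025). Pickle-based checkpoints can execute arbitrary code when loaded, so the restricted unpickler only accepts tensors, primitive containers, and other allowlisted types.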
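To see why unrestricted pickle loading is risky, here is a minimal standard-library sketch (no PyTorch involved). The Payload class is hypothetical and deliberately harmless: its __reduce__ method tells pickle to call eval on a string when the bytes are loaded, which is the same hook a malicious checkpoint could use to run any code.

```python
import pickle

class Payload:
    # Hypothetical "malicious" object: __reduce__ tells pickle which callable
    # to invoke, and with which arguments, when the bytes are loaded.
    def __reduce__(self):
        return (eval, ("6 * 7",))

blob = pickle.dumps(Payload())   # what an attacker could ship as a "checkpoint"
result = pickle.loads(blob)      # loading runs eval("6 * 7")
print(result)                    # prints 42, not a Payload instance
```

Nothing in the loading code asked for eval to run; the file content decided that. A restricted unpickler refuses callables that are not on its allowlist, which is the idea behind weights_only=True.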

Containers of tensors

Lists and dicts work the same — perfect for grouping related tensors together (e.g. weights of one block, plus its running statistics):

import torch

x, y = torch.arange(4), torch.zeros(4)
torch.save([x, y], 'x-files')        # a list of tensors
x2, y2 = torch.load('x-files')
(x2, y2)
mydict = {'x': x, 'y': y}
torch.save(mydict, 'mydict')         # a dict of tensors
mydict2 = torch.load('mydict')
mydict2

A model to save

import torch
from torch import nn

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(20, 256)   # 20 input features, 256 hidden units
        self.output = nn.Linear(256, 10)

    def forward(self, x):
        return self.output(torch.relu(self.hidden(x)))

net = MLP()
X = torch.rand(2, 20)
Y = net(X)

state_dict() — the canonical interface

Every Module exposes a state_dict() — an ordered dict mapping parameter paths to tensor values:

{
  'hidden.weight':  Tensor (256, 20),
  'hidden.bias':    Tensor (256,),
  'output.weight':  Tensor (10, 256),
  'output.bias':    Tensor (10,),
}

The keys come from the module tree (self.hidden → 'hidden'). Save this dict, not the module:

torch.save(net.state_dict(), 'mlp.params')
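The key-to-shape mapping listed above can be checked by iterating over state_dict() directly. The sketch below assumes a PyTorch version of the MLP with 20 input features; the shapes dict and bn_keys are illustrative names. It also shows that persistent buffers, such as BatchNorm running statistics, appear in the dict alongside the learnable parameters.

```python
import torch
from torch import nn

class MLP(nn.Module):  # assumed PyTorch layout of the MLP in this section
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(20, 256)
        self.output = nn.Linear(256, 10)

    def forward(self, x):
        return self.output(torch.relu(self.hidden(x)))

# Map every state_dict key to the shape of its tensor
shapes = {name: tuple(t.shape) for name, t in MLP().state_dict().items()}
for name, shape in shapes.items():
    print(f"{name:<14} {shape}")

# Buffers are saved too: BatchNorm stores its running statistics here
bn_keys = list(nn.BatchNorm1d(4).state_dict().keys())
print(bn_keys)
```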

Loading: instantiate, then load

Build a fresh model with the same Python class, then call load_state_dict. The dict’s keys must match the new model’s parameter paths:

clone = MLP()
clone.load_state_dict(torch.load('mlp.params'))   # copy saved tensors into the new instance
clone.eval()

Sanity check — same architecture + same weights produces bit-identical outputs:

Y_clone = clone(X)
Y_clone == Y
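The elementwise comparison above yields a tensor of booleans. For an automated check, a single True/False is easier to assert on. A minimal sketch, assuming PyTorch 1.13 or later (for the weights_only argument): roundtrip_matches is a hypothetical helper that saves a state_dict to an in-memory buffer, loads it into a fresh instance, and compares outputs with torch.equal.

```python
import io
import torch
from torch import nn

def roundtrip_matches(make_model, example_input):
    """Save and reload a model's state_dict, then compare outputs exactly."""
    model = make_model().eval()
    buffer = io.BytesIO()                      # in-memory stand-in for a file
    torch.save(model.state_dict(), buffer)
    buffer.seek(0)

    clone = make_model().eval()                # fresh instance, different random init
    clone.load_state_dict(torch.load(buffer, weights_only=True))
    with torch.no_grad():
        return torch.equal(model(example_input), clone(example_input))

def make_mlp():
    return nn.Sequential(nn.Linear(20, 256), nn.ReLU(), nn.Linear(256, 10))

ok = roundtrip_matches(make_mlp, torch.rand(2, 20))
print(ok)
```

torch.equal requires identical shapes and values, so it is stricter than torch.allclose. That is appropriate here because both models run the same operations on the same device.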

Best practices

  • Save state_dict, not the module object. Pickling the module ties the file to today’s Python class; refactor and old checkpoints break.
  • Keep the model class in your code repo. The file is useless without the matching architecture definition.
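  • Strict vs non-strict: load_state_dict(d, strict=False) tolerates missing and unexpected keys and reports them in its return value, which is useful for partial loading (e.g. swapping a pretrained head). Shape mismatches on matching keys still raise an error.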
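Partial loading with strict=False can be sketched as follows. Classifier, backbone, and head are hypothetical names chosen for illustration. The old head is filtered out before loading because its shape would not fit the new head, and strict=False does not skip shape mismatches.

```python
import torch
from torch import nn

class Classifier(nn.Module):  # hypothetical model: shared backbone, task-specific head
    def __init__(self, num_classes):
        super().__init__()
        self.backbone = nn.Linear(20, 256)
        self.head = nn.Linear(256, num_classes)

    def forward(self, x):
        return self.head(torch.relu(self.backbone(x)))

pretrained = Classifier(num_classes=10)
finetune = Classifier(num_classes=3)     # new task, differently shaped head

# Keep only the backbone tensors; the (10, 256) head would not fit a (3, 256) head
backbone_only = {k: v for k, v in pretrained.state_dict().items()
                 if not k.startswith("head.")}

result = finetune.load_state_dict(backbone_only, strict=False)
print("missing:", result.missing_keys)         # keys the model has but the dict lacks
print("unexpected:", result.unexpected_keys)   # keys the dict has but the model lacks
```

Checking missing_keys and unexpected_keys after a non-strict load is worthwhile: a typo in a key prefix would otherwise silently leave layers at their random initialization.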

Beyond state_dict

  • safetensors: a flat name-to-tensor file format, like a serialized state_dict but without pickle, so no arbitrary-code-execution risk. Maintained by Hugging Face and widely used across its ecosystem.

  • Full checkpoint — save more than weights:

    {'model':     net.state_dict(),
     'optimizer': opt.state_dict(),
     'epoch':     epoch,
     'rng_state': torch.get_rng_state()}

    Lets you resume training after a crash. Bit-exact resumption also needs the CUDA RNG state, the data-loader position, and deterministic kernels.
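A full-checkpoint round trip might look like the sketch below. save_checkpoint and load_checkpoint are hypothetical helpers, not PyTorch APIs. Writing to a temporary file and renaming it with os.replace is an added precaution, so that a crash during saving cannot leave a half-written checkpoint in place. Only the CPU RNG state is handled here.

```python
import os
import tempfile
import torch
from torch import nn

def save_checkpoint(path, model, opt, epoch):
    ckpt = {"model": model.state_dict(),
            "optimizer": opt.state_dict(),
            "epoch": epoch,
            "rng_state": torch.get_rng_state()}   # CPU generator only
    tmp = path + ".tmp"
    torch.save(ckpt, tmp)
    os.replace(tmp, path)    # atomic rename: the old file stays valid until this succeeds

def load_checkpoint(path, model, opt):
    ckpt = torch.load(path, weights_only=True)
    model.load_state_dict(ckpt["model"])
    opt.load_state_dict(ckpt["optimizer"])
    torch.set_rng_state(ckpt["rng_state"])
    return ckpt["epoch"]

with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "ckpt.pt")
    model = nn.Linear(4, 2)
    opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    model(torch.rand(3, 4)).sum().backward()
    opt.step()                                    # creates momentum buffers
    save_checkpoint(path, model, opt, epoch=3)
    expected = torch.rand(2)                      # first random draw after saving

    # "After the crash": rebuild everything, then restore
    model2 = nn.Linear(4, 2)
    opt2 = torch.optim.SGD(model2.parameters(), lr=0.1, momentum=0.9)
    start_epoch = load_checkpoint(path, model2, opt2) + 1
    resumed = torch.rand(2)                       # same draw, since the RNG was restored
```

The optimizer must be constructed over the new model's parameters before load_state_dict is called on it, because the saved optimizer state refers to parameters by position, not by name.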

Recap

  • Save the state, recreate the architecture in code.
  • torch.save / torch.load for tensors and dicts; state_dict() / load_state_dict() for modules.
  • Always sanity-check by running a known input through before-and-after-load and comparing.
  • Production: HuggingFace safetensors for the no-pickle story.
  • Full checkpoint: save model + optimizer + epoch + RNG state, in one dict.