File I/O

Saving model state

Two real problems training pipelines hit constantly:

  • Crash recovery: a 12-hour training run dies nine hours in. Did we just lose 9 hours of compute?
  • Deployment: the model trains on a research box but must serve from a production cluster, possibly in a different language or runtime.

Both reduce to “save the parameters, recreate the model elsewhere”. The crucial split:

   architecture  ──>  Python code  (committed to git)
   parameters    ──>  on-disk file (the .pt / .ckpt / .safetensors)

You save the state, not the class. To resurrect: import the same class, instantiate, then load_state_dict. This section covers both halves of the workflow.

Saving and loading raw tensors

First the building block: torch.save / torch.load work on any tensor, list of tensors, or dict thereof:

import torch
from torch import nn
from torch.nn import functional as F
x = torch.arange(4)
torch.save(x, 'x-file')
x2 = torch.load('x-file', weights_only=True)
x2
tensor([0, 1, 2, 3])

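weights_only=True tells torch.load to unpickle only tensors, primitive types, and containers of them. Checkpoints are pickle files, and unpickling can execute arbitrary code, so this flag guards against malicious files. It has been the default since PyTorch 2.6.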
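The sketch below shows what the flag guards against. It round-trips two objects through an in-memory buffer (io.BytesIO, which torch.save and torch.load both accept) and reloads each one with weights_only=True. The Config class and the try_load helper are hypothetical and exist only for illustration. The exact error message varies between PyTorch versions, so the sketch checks only the exception type.

```python
import io
import pickle
import torch

class Config:  # hypothetical non-tensor object, for illustration only
    def __init__(self, lr):
        self.lr = lr

def try_load(obj):
    """Save obj to an in-memory buffer, then reload it with weights_only=True."""
    buf = io.BytesIO()
    torch.save(obj, buf)
    buf.seek(0)  # rewind before reading back
    try:
        torch.load(buf, weights_only=True)
        return "loaded"
    except pickle.UnpicklingError:  # raised for globals outside the allowlist
        return "rejected"

print(try_load({'w': torch.ones(2)}))  # a dict of tensors is allowed
print(try_load(Config(lr=0.1)))        # an arbitrary class is refused
```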

Containers of tensors

Lists and dicts work the same — perfect for grouping related tensors together (e.g. weights of one block, plus its running statistics):

y = torch.zeros(4)
torch.save([x, y], 'x-files')
x2, y2 = torch.load('x-files', weights_only=True)
(x2, y2)
(tensor([0, 1, 2, 3]), tensor([0., 0., 0., 0.]))
mydict = {'x': x, 'y': y}
torch.save(mydict, 'mydict')
mydict2 = torch.load('mydict', weights_only=True)
mydict2
{'x': tensor([0, 1, 2, 3]), 'y': tensor([0., 0., 0., 0.])}

A model to save

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.LazyLinear(256)
        self.output = nn.LazyLinear(10)

    def forward(self, x):
        return self.output(F.relu(self.hidden(x)))

net = MLP()
X = torch.randn(size=(2, 20))
Y = net(X)

state_dict() — the canonical interface

Every Module exposes a state_dict(): an ordered dict that maps the dotted path of each parameter (and of each buffer, such as batch-norm running statistics) to its tensor value:

{
  'hidden.weight':  Tensor (256, 20),
  'hidden.bias':    Tensor (256,),
  'output.weight':  Tensor (10, 256),
  'output.bias':    Tensor (10,),
}

The keys come from the module tree: the attribute self.hidden contributes the prefix hidden, so its weight is stored as hidden.weight. Save this dict, not the module:

torch.save(net.state_dict(), 'mlp.params')
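To inspect that structure directly, you can iterate over the dict. This sketch repeats the MLP definition so that it runs on its own. The extra BatchNorm1d line is an illustration and not part of the MLP. It shows that state_dict also carries buffers such as running statistics, in addition to trainable parameters.

```python
import torch
from torch import nn
from torch.nn import functional as F

class MLP(nn.Module):  # same architecture as above
    def __init__(self):
        super().__init__()
        self.hidden = nn.LazyLinear(256)
        self.output = nn.LazyLinear(10)

    def forward(self, x):
        return self.output(F.relu(self.hidden(x)))

net = MLP()
net(torch.randn(2, 20))  # one forward pass materializes the lazy shapes

for name, tensor in net.state_dict().items():
    print(name, tuple(tensor.shape))
# hidden.weight (256, 20)
# hidden.bias (256,)
# output.weight (10, 256)
# output.bias (10,)

# Buffers appear after parameters, e.g. BatchNorm's running statistics:
print(list(nn.BatchNorm1d(4).state_dict()))
# ['weight', 'bias', 'running_mean', 'running_var', 'num_batches_tracked']
```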

Loading: instantiate, then load

Build a fresh model with the same Python class, then call load_state_dict. The dict’s keys must match the new model’s parameter paths:

clone = MLP()
clone.load_state_dict(torch.load('mlp.params', weights_only=True))
clone.eval()
MLP(
  (hidden): LazyLinear(in_features=0, out_features=256, bias=True)
  (output): LazyLinear(in_features=0, out_features=10, bias=True)
)

Sanity check — same architecture + same weights produces bit-identical outputs:

Y_clone = clone(X)
Y_clone == Y
tensor([[True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True]])
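This sanity check can be packaged as a reusable helper. The following is a minimal sketch that relies on a few assumptions. The names roundtrip_matches and make_mlp are hypothetical. The model uses eager nn.Linear layers instead of LazyLinear, so its shapes are fixed at construction. Everything runs on the CPU, where the same weights and inputs produce identical results. The helper serializes to memory rather than to a file, and it puts both models in eval mode so that layers such as dropout behave deterministically.

```python
import io
import torch
from torch import nn

def roundtrip_matches(make_model, model, x):
    """Save model's state_dict to memory, load it into a fresh instance
    built by make_model, and report whether both give identical outputs."""
    buf = io.BytesIO()
    torch.save(model.state_dict(), buf)
    buf.seek(0)  # rewind before reading back
    clone = make_model()
    clone.load_state_dict(torch.load(buf, weights_only=True))
    # Note: this puts the caller's model into eval mode as a side effect.
    model.eval()
    clone.eval()
    with torch.no_grad():
        return torch.equal(model(x), clone(x))

def make_mlp():
    # Eager layers with fixed shapes, mirroring the MLP above.
    return nn.Sequential(nn.Linear(20, 256), nn.ReLU(), nn.Linear(256, 10))

model = make_mlp()
x = torch.randn(2, 20)
print(roundtrip_matches(make_mlp, model, x))  # True
```

torch.equal returns a single Python bool, so it is easier to assert on than the elementwise == comparison above.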

Best practices

  • Save state_dict, not the module object. Pickling the module ties the file to today’s Python class; refactor and old checkpoints break.
  • Keep the model class in your code repo. The file is useless without the matching architecture definition.
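  • Strict vs non-strict: load_state_dict(d, strict=False) tolerates missing and unexpected keys and reports them in its return value. This is useful for partial loading (e.g. swapping a pretrained head). It still raises on shape mismatches for keys present in both.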
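Here is a sketch of partial loading with strict=False. It assumes two hypothetical models, Source and Target, which share a hidden layer but have differently named and differently sized heads. strict=False does not forgive shape mismatches, so the sketch first filters the state dict down to the shared keys and only then loads it. It then inspects the missing_keys and unexpected_keys fields of the value that load_state_dict returns.

```python
import torch
from torch import nn

# Hypothetical models: same 'hidden' layer, different heads.
class Source(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(20, 256)
        self.output = nn.Linear(256, 10)   # pretrained 10-class head

class Target(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(20, 256)
        self.head = nn.Linear(256, 3)      # new 3-class head, different name and size

src = Source()

# Strict loading (the default) rejects the mismatched key sets.
try:
    Target().load_state_dict(src.state_dict())
    strict_failed = False
except RuntimeError:
    strict_failed = True
print(strict_failed)  # True

# Partial loading: keep only the shared backbone keys, then allow missing ones.
tgt = Target()
backbone = {k: v for k, v in src.state_dict().items() if k.startswith('hidden.')}
result = tgt.load_state_dict(backbone, strict=False)
print(result.missing_keys)     # ['head.weight', 'head.bias']
print(result.unexpected_keys)  # []
print(torch.equal(tgt.hidden.weight, src.hidden.weight))  # True
```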

Beyond state_dict

  • safetensors: stores a flat dict of named tensors, like a state_dict, but without pickle, so loading cannot execute arbitrary code. It is the default weight format across Hugging Face libraries.

  • Full checkpoint — save more than weights:

    {'model':     net.state_dict(),
     'optimizer': opt.state_dict(),
     'epoch':     epoch,
     'rng_state': torch.get_rng_state()}

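    Lets you resume training where it left off after a crash. On GPU pipelines, bit-exact resumption additionally needs the CUDA RNG state, the data loader's position, and deterministic kernels.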
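The full-checkpoint idea can be sketched as a pair of helpers. This sketch makes several assumptions. The names save_checkpoint, load_checkpoint, make_model_and_opt and train_step are hypothetical. The dict keys follow the layout shown above. The model is a tiny CPU-only nn.Linear trained with Adam, and its only source of randomness is the global CPU RNG. To simulate a crash, the sketch rebuilds the model and optimizer from scratch and resumes from the checkpoint. It then checks that the resumed run reaches the same weights as an uninterrupted one.

```python
import io
import torch
from torch import nn

def save_checkpoint(f, model, opt, epoch):
    """Bundle everything needed to resume into one dict."""
    torch.save({'model': model.state_dict(),
                'optimizer': opt.state_dict(),
                'epoch': epoch,
                'rng_state': torch.get_rng_state()}, f)

def load_checkpoint(f, model, opt):
    """Restore model, optimizer and CPU RNG in place; return the next epoch."""
    ckpt = torch.load(f, weights_only=True)
    model.load_state_dict(ckpt['model'])
    opt.load_state_dict(ckpt['optimizer'])
    torch.set_rng_state(ckpt['rng_state'])
    return ckpt['epoch'] + 1

def make_model_and_opt():
    model = nn.Linear(4, 1)
    return model, torch.optim.Adam(model.parameters(), lr=0.1)

def train_step(model, opt):
    x = torch.randn(8, 4)  # the "batch" is drawn from the global CPU RNG
    loss = model(x).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()

torch.manual_seed(0)
model, opt = make_model_and_opt()
train_step(model, opt)                     # epoch 0
buf = io.BytesIO()
save_checkpoint(buf, model, opt, epoch=0)
train_step(model, opt)                     # the uninterrupted run continues

# Simulated crash: rebuild from scratch, then resume from the checkpoint.
buf.seek(0)
model2, opt2 = make_model_and_opt()
start = load_checkpoint(buf, model2, opt2)
train_step(model2, opt2)
print(start)                                     # 1
print(torch.equal(model.weight, model2.weight))  # True
```

Restoring the RNG state matters here. Without torch.set_rng_state, the resumed step would draw a different batch, and the two runs would diverge even though the weights and Adam moments were restored correctly.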

Recap

  • Save the state, recreate the architecture in code.
  • torch.save / torch.load for tensors and dicts; state_dict() / load_state_dict() for modules.
  • Always sanity-check by running a known input through before-and-after-load and comparing.
  • Production: Hugging Face's safetensors format for weight files that avoid pickle.
  • Full checkpoint: save model + optimizer + epoch + RNG state, in one dict.