File I/O

Saving model state

Two real problems training pipelines hit constantly:

  • Crash recovery — a 12-hour training run dies 9 hours in. Did we just lose 9 hours of compute?
  • Deployment — model trains on a research box; needs to serve from a production cluster, possibly in a different language or runtime.

Both reduce to “save the parameters, recreate the model elsewhere”. The crucial split:

   architecture  ──>  Python code  (committed to git)
   parameters    ──>  on-disk file (the .pt / .ckpt / .safetensors)

You save the state, not the class. To resurrect: import the same class, instantiate, then load_state_dict. This section covers both halves of the workflow.

Saving and loading raw tensors

First the building block: torch.save / torch.load work on any tensor, list of tensors, or dict thereof:

import torch

x = torch.arange(4)
torch.save(x, 'x-file.pt')     # serialize one tensor to disk
x2 = torch.load('x-file.pt')   # read it back
x2
tensor([0, 1, 2, 3])

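Since PyTorch 2.6, torch.load defaults to weights_only=True. Checkpoint files are pickles, and unpickling can execute arbitrary code, so this default restricts loading to tensors, primitive types, and containers of them (plus any classes you explicitly allowlist).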
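To see why that restriction matters, here is a plain-Python sketch using the standard-library pickle module rather than PyTorch's loader. Any object can define __reduce__ to name a callable that pickle.loads will invoke while rebuilding it. The Payload class is hypothetical, and the harmless eval call stands in for whatever an attacker might run:

```python
import os
import pickle

class Payload:
    """Hypothetical attacker-controlled object (illustration only)."""
    def __reduce__(self):
        # Instead of describing the object, tell pickle to call eval(...)
        return (eval, ("__import__('os').getcwd()",))

blob = pickle.dumps(Payload())   # the bytes now encode "call eval with this string"
loaded = pickle.loads(blob)      # eval runs here, during loading
print(type(loaded).__name__, loaded == os.getcwd())   # str True
```

Nothing in that file is a tensor, yet loading it ran code. The weights_only loader only reconstructs allowlisted types, so it refuses a file that asks for eval.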

Containers of tensors

Lists and dicts work the same — perfect for grouping related tensors together (e.g. weights of one block, plus its running statistics):

y = torch.zeros(4)
torch.save([x, y], 'xy-files.pt')        # a list of tensors
x2, y2 = torch.load('xy-files.pt')
(x2, y2)
(tensor([0, 1, 2, 3]), tensor([0., 0., 0., 0.]))
mydict = {'x': x, 'y': y}
torch.save(mydict, 'mydict.pt')          # a dict of named tensors
mydict2 = torch.load('mydict.pt')
mydict2
{'x': tensor([0, 1, 2, 3]), 'y': tensor([0., 0., 0., 0.])}
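torch.save and torch.load also accept file-like objects, which helps when a checkpoint travels over a network or into object storage instead of local disk. A minimal sketch with an in-memory io.BytesIO buffer; the nested block dict is a hypothetical grouping of a weight and its running statistics:

```python
import io
import torch

# Hypothetical block: one weight matrix plus its running statistics.
block = {
    'weight': torch.randn(3, 4),
    'stats': {'mean': torch.zeros(4), 'var': torch.ones(4)},
    'steps': [torch.tensor(0), torch.tensor(1)],
}

buf = io.BytesIO()            # an in-memory "file"
torch.save(block, buf)        # torch.save writes to any writable file-like object
buf.seek(0)                   # rewind before reading back
# weights_only passed explicitly so older PyTorch versions behave the same
restored = torch.load(buf, weights_only=True)

print(torch.equal(restored['weight'], block['weight']))   # True
print(restored['stats']['var'])                            # tensor([1., 1., 1., 1.])
```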

A model to save

from torch import nn
from torch.nn import functional as F

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(20, 256)   # 20 input features -> 256 hidden units
        self.output = nn.Linear(256, 10)   # 256 hidden units -> 10 outputs

    def forward(self, x):
        return self.output(F.relu(self.hidden(x)))

net = MLP()
X = torch.randn(size=(2, 20))
Y = net(X)

state_dict() — the canonical interface

Every nn.Module exposes state_dict(), an ordered dict mapping dotted parameter (and buffer) paths to tensor values:

{
  'hidden.weight':  Tensor (256, 20),
  'hidden.bias':    Tensor (256,),
  'output.weight':  Tensor (10, 256),
  'output.bias':    Tensor (10,),
}

The keys come from the module tree: the attribute self.hidden becomes the prefix hidden., so its weight and bias become hidden.weight and hidden.bias. Save this dict, not the module:

torch.save(net.state_dict(), 'mlp.pt')
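To check which keys a model will produce, and that buffers travel along with parameters, you can print names and shapes directly. This sketch redefines the section's MLP so it runs on its own; the BatchNorm model is a hypothetical extra that shows buffers:

```python
import torch
from torch import nn
from torch.nn import functional as F

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(20, 256)
        self.output = nn.Linear(256, 10)

    def forward(self, x):
        return self.output(F.relu(self.hidden(x)))

# Attribute names become key prefixes: self.hidden -> 'hidden.*'
shapes = {k: tuple(v.shape) for k, v in MLP().state_dict().items()}
print(shapes)
# {'hidden.weight': (256, 20), 'hidden.bias': (256,), 'output.weight': (10, 256), 'output.bias': (10,)}

# BatchNorm's running statistics are buffers, not parameters,
# but they are still part of the state you must save.
bn_net = nn.Sequential(nn.Linear(20, 8), nn.BatchNorm1d(8))
print(list(bn_net.state_dict().keys()))
# ['0.weight', '0.bias', '1.weight', '1.bias', '1.running_mean', '1.running_var', '1.num_batches_tracked']
```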

Loading: instantiate, then load

Build a fresh model from the same Python class, then call load_state_dict. The dict’s keys must match the new model’s parameter paths, and each tensor’s shape must match as well:

clone = MLP()
clone.load_state_dict(torch.load('mlp.pt'))
clone.eval()   # inference mode (matters for layers like dropout and batch norm)

Sanity check — same architecture + same weights produces bit-identical outputs:

Y_clone = clone(X)
Y_clone == Y
tensor([[True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True]])
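The same check can be packaged as a reusable helper. A sketch under stated assumptions: roundtrip_matches and make_model are hypothetical names, the file goes to a temporary directory, and eval() is called on both models so that layers such as dropout cannot make the two forward passes differ:

```python
import os
import tempfile

import torch
from torch import nn

def roundtrip_matches(make_model, model, x):
    """Save model's state_dict, load it into make_model(), compare outputs exactly."""
    model.eval()  # note: switches the caller's model to inference mode
    with tempfile.TemporaryDirectory() as d:
        path = os.path.join(d, 'model.pt')
        torch.save(model.state_dict(), path)
        clone = make_model()
        clone.load_state_dict(torch.load(path, weights_only=True))
    clone.eval()
    with torch.no_grad():
        return torch.equal(model(x), clone(x))

def make_model():
    # Hypothetical architecture; any zero-argument constructor works.
    return nn.Sequential(nn.Linear(20, 256), nn.ReLU(), nn.Linear(256, 10))

print(roundtrip_matches(make_model, make_model(), torch.randn(2, 20)))   # True
```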

Best practices

  • Save state_dict, not the module object. Pickling the module ties the file to today’s Python class; refactor and old checkpoints break.
  • Keep the model class in your code repo. The file is useless without the matching architecture definition.
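  • Strict vs non-strict: load_state_dict(d, strict=False) ignores missing and unexpected keys (it returns them so you can inspect them), which is useful for partial loading, e.g. reusing a pretrained backbone under a new head. Shape mismatches on keys present in both still raise an error, so drop a resized head’s entries from d first.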
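A sketch of partial loading, assuming a hypothetical Classifier with a shared backbone and a task-specific head that is resized from 10 to 3 classes:

```python
import torch
from torch import nn

class Classifier(nn.Module):
    """Hypothetical model: reusable backbone + task-specific head."""
    def __init__(self, num_classes):
        super().__init__()
        self.backbone = nn.Linear(20, 256)
        self.head = nn.Linear(256, num_classes)

    def forward(self, x):
        return self.head(torch.relu(self.backbone(x)))

pretrained = Classifier(num_classes=10)
state = pretrained.state_dict()

# strict=False alone is not enough: the head's shapes differ, which still raises.
mismatch_raised = False
try:
    Classifier(num_classes=3).load_state_dict(state, strict=False)
except RuntimeError:
    mismatch_raised = True   # "size mismatch for head.weight ..."

# Drop the head's entries, then load the rest non-strictly.
backbone_only = {k: v for k, v in state.items() if not k.startswith('head.')}
model = Classifier(num_classes=3)
result = model.load_state_dict(backbone_only, strict=False)

print(mismatch_raised)         # True
print(result.missing_keys)     # ['head.weight', 'head.bias']
print(result.unexpected_keys)  # []
print(torch.equal(model.backbone.weight, pretrained.backbone.weight))   # True
```

The returned missing_keys confirm that only the head was left at its fresh initialization.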

Beyond state_dict

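  • safetensors — stores the same flat mapping of names to tensors as a state_dict, but without pickle, so loading cannot execute arbitrary code. It comes from Hugging Face and is the default format for save_pretrained in the transformers library.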

  • Full checkpoint — save more than weights:

    {'model':     net.state_dict(),
     'optimizer': opt.state_dict(),
     'epoch':     epoch,
     'rng_state': torch.get_rng_state()}

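    Restoring all four lets you resume from the last saved epoch instead of from scratch. Bit-exact resumption needs more, for example the CUDA RNG state and the data loader’s position.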
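The checkpoint dict above can be wrapped in a pair of helpers. A minimal sketch: save_checkpoint, load_checkpoint and train_epoch are hypothetical names, the toy linear model and SGD settings are arbitrary, and only the CPU RNG is captured (CUDA generators would need their own entries):

```python
import os
import tempfile

import torch
from torch import nn

def save_checkpoint(path, model, opt, epoch):
    """Hypothetical helper: bundle everything needed to resume into one dict."""
    torch.save({'model': model.state_dict(),
                'optimizer': opt.state_dict(),
                'epoch': epoch,
                'rng_state': torch.get_rng_state()}, path)

def load_checkpoint(path, model, opt):
    """Hypothetical helper: restore in place, return the next epoch to run."""
    ckpt = torch.load(path, weights_only=True)  # plain containers of tensors/scalars
    model.load_state_dict(ckpt['model'])
    opt.load_state_dict(ckpt['optimizer'])      # e.g. SGD momentum buffers
    torch.set_rng_state(ckpt['rng_state'])      # CPU generator only
    return ckpt['epoch'] + 1

def train_epoch(model, opt, X, y):
    opt.zero_grad()
    loss = nn.functional.mse_loss(model(X), y)
    loss.backward()
    opt.step()

torch.manual_seed(0)
X, y = torch.randn(8, 4), torch.randn(8, 1)
model = nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
train_epoch(model, opt, X, y)                   # epoch 0 finishes

with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, 'ckpt.pt')
    save_checkpoint(path, model, opt, epoch=0)
    next_draw = torch.rand(3)                   # what the original run draws next

    # Simulate a crash: rebuild model and optimizer from code, then resume.
    model2 = nn.Linear(4, 1)
    opt2 = torch.optim.SGD(model2.parameters(), lr=0.1, momentum=0.9)
    start_epoch = load_checkpoint(path, model2, opt2)
    resumed_draw = torch.rand(3)

print(start_epoch, torch.equal(next_draw, resumed_draw))   # 1 True
```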

Recap

  • Save the state, recreate the architecture in code.
  • torch.save / torch.load for tensors and dicts; state_dict() / load_state_dict() for modules.
  • Always sanity-check by running a known input through before-and-after-load and comparing.
  • Production: Hugging Face safetensors for the no-pickle story.
  • Full checkpoint: save model + optimizer + epoch + RNG state, in one dict.