A tensor is an n-dimensional array of numbers, the fundamental data structure for everything that follows in this book. It behaves much like NumPy's ndarray, but it can run on GPUs and is differentiable.
In this section: how to create, reshape, index, operate on, and share memory with tensors.
Two imports wire up JAX and its NumPy-compatible tensor library, jax.numpy:
import jax
from jax import numpy as jnp
A 1-D tensor of 12 evenly spaced integers, 0 through 11, serves as our running example:
Array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], dtype=int32)
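The code cell behind this output was lost in extraction. Here is a minimal sketch that rebuilds the vector, assuming it came from jnp.arange (the name x_float is illustrative only):

```python
from jax import numpy as jnp

# arange(n) yields the integers 0, 1, ..., n-1
# (int32 by default, unless JAX's 64-bit mode is enabled)
x = jnp.arange(12)

# Passing a dtype gives the same values as floats instead
x_float = jnp.arange(12, dtype=jnp.float32)
print(x)
```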
Two attributes you’ll reach for constantly:
.size: the total number of elements
.shape: the size along each axis (a tuple)
12
(12,)
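Here is a sketch of the two lookups. It rebuilds the running example as x, which is an assumed name. In JAX, the element count is the .size attribute, not a method:

```python
from jax import numpy as jnp

x = jnp.arange(12)
n = x.size      # total number of elements
s = x.shape     # one entry per axis
print(n, s)     # 12 (12,)
```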
reshape rearranges the same elements into a different shape; the total number of elements must stay the same.
Array([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]], dtype=int32)
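A 12-element vector becomes a $3 \times 4$ matrix. The elements keep their row-major order; only the shape changes. JAX arrays are immutable, so reshape returns a new array and leaves the original vector as it was.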
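Here is a short sketch of reshape on the running example, with x and X as assumed names. Passing -1 for one axis lets JAX infer that axis's length. A target shape whose sizes do not multiply to 12 is rejected. The exact exception type is an assumption, so the sketch catches it broadly:

```python
from jax import numpy as jnp

x = jnp.arange(12)
X = x.reshape(3, 4)           # 3 rows, 4 columns, filled in row-major order
X_auto = x.reshape(-1, 4)     # -1 asks JAX to infer the row count (3)

# The element count must be preserved: 12 elements cannot fill a 5 x 3 shape.
try:
    x.reshape(5, 3)
    reshaped_ok = True
except (TypeError, ValueError):   # exception type assumed; caught broadly
    reshaped_ok = False
```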
Constant fills take a shape tuple — any rank, any size:
Array([[[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]],
[[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]]], dtype=float32)
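The fill above can be reproduced with jnp.zeros and a rank-3 shape tuple. A minimal sketch, with Z as an illustrative name:

```python
from jax import numpy as jnp

Z = jnp.zeros((2, 3, 4))    # rank 3: two blocks of 3 x 4, float32 by default
print(Z.shape)              # (2, 3, 4)
```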
For weight initialization, jax.random.normal draws each element independently from $\mathcal{N}(0, 1)$. JAX has no global random state, so every call takes an explicit PRNG key:
Array([[ 1.6226422 , 2.0252647 , -0.43359444, -0.07861735],
[ 0.1760909 , -0.97208923, -0.49529874, 0.4943786 ],
[ 0.6643493 , -0.9501635 , 2.1795304 , -1.9551506 ]], dtype=float32)
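Here is a sketch of the random fill using JAX's key-based API. The seed 0 is an arbitrary assumption, so the printed numbers will differ from the sample above. The sketch shows two things: the same key always reproduces the same draw, and jax.random.split yields fresh keys.

```python
import jax

key = jax.random.PRNGKey(0)                # seed 0 is an arbitrary choice
W = jax.random.normal(key, (3, 4))         # i.i.d. N(0, 1) samples

# Same key, same numbers: JAX randomness is a pure function of the key.
W_again = jax.random.normal(key, (3, 4))

# For new numbers, split the key instead of reusing it.
key, subkey = jax.random.split(key)
W_new = jax.random.normal(subkey, (3, 4))
```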
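jnp.ones, jnp.full(shape, fill_value), jnp.eye(n), jnp.empty, and the *_like(x) variants (such as zeros_like and ones_like) round out the family. JAX has no uninitialized memory, so jnp.empty simply returns zeros.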
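Here is the rest of the family in one place. This is a minimal sketch, and the variable names are illustrative:

```python
from jax import numpy as jnp

O = jnp.ones((2, 3))          # all ones
F = jnp.full((2, 3), 7.0)     # every entry set to fill_value
I = jnp.eye(3)                # 3 x 3 identity matrix
E = jnp.empty((2, 3))         # JAX has no uninitialized memory: these are zeros
L = jnp.zeros_like(F)         # same shape and dtype as F, filled with zeros
```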
For exact control, pass a (nested) Python list to jnp.array. It follows the same row-major convention as NumPy:
Array([[2, 1, 4, 3],
[1, 2, 3, 4],
[4, 3, 2, 1]], dtype=int32)
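Here is the matrix above written as a nested list. This sketch assumes the cell used jnp.array and the name Y:

```python
from jax import numpy as jnp

# Outer list = rows, inner lists = the entries of each row (row-major, as in NumPy)
Y = jnp.array([[2, 1, 4, 3],
               [1, 2, 3, 4],
               [4, 3, 2, 1]])
```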
Standard NumPy-style indexing:
X[-1]: the last row
X[1:3]: rows 1 and 2 (the stop index 3 is exclusive)
(Array([ 8, 9, 10, 11], dtype=int32),
Array([[ 4, 5, 6, 7],
[ 8, 9, 10, 11]], dtype=int32))
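This sketch covers the two slices above plus a few more forms that follow the same rules. It assumes X is the $3 \times 4$ reshape of jnp.arange(12):

```python
from jax import numpy as jnp

X = jnp.arange(12).reshape(3, 4)

last_row = X[-1]        # 8, 9, 10, 11
middle = X[1:3]         # rows 1 and 2
col1 = X[:, 1]          # column 1 of every row: 1, 5, 9
elem = X[1, 2]          # a single element: 6
block = X[0:2, 1:3]     # rows 0-1, columns 1-2
```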
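Assignment is where JAX departs from NumPy. JAX arrays are immutable, so X[1, 2] = 17 raises a TypeError. The functional form X.at[1, 2].set(17) returns a copy with that entry replaced: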
Array([[ 0, 1, 2, 3],
[ 4, 5, 17, 7],
[ 8, 9, 10, 11]], dtype=int32)
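In JAX, writes go through the .at property. This minimal sketch assumes X is the $3 \times 4$ matrix from before. It shows that:

- plain item assignment fails;
- .at[...].set(...) leaves X itself unchanged;
- slices and other update methods such as .add work the same way.

```python
from jax import numpy as jnp

X = jnp.arange(12).reshape(3, 4)

# Direct item assignment is rejected because JAX arrays are immutable.
try:
    X[1, 2] = 17
    assigned = True
except TypeError:
    assigned = False

X2 = X.at[1, 2].set(17)       # new array with one entry changed
X3 = X.at[0:2, :].set(12)     # slices work too: rows 0 and 1 become 12
X4 = X.at[1, 2].add(100)      # other updates: add, multiply, min, max, ...
```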
Most common math functions, such as jnp.exp, apply elementwise: same shape in, same shape out. Here it is applied to the 12-element running example:
Array([1.0000000e+00, 2.7182817e+00, 7.3890562e+00, 2.0085537e+01,
5.4598148e+01, 1.4841316e+02, 4.0342880e+02, 1.0966332e+03,
2.9809580e+03, 8.1030840e+03, 2.2026467e+04, 5.9874145e+04], dtype=float32)
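Here is a sketch of the elementwise call. It assumes the output came from jnp.exp applied to the running example cast to float32:

```python
from jax import numpy as jnp

x = jnp.arange(12, dtype=jnp.float32)
e = jnp.exp(x)        # e**0, e**1, ..., e**11, one result per element
print(e.shape)        # (12,): same shape in, same shape out
```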
The arithmetic operators are overloaded — +, -, *, /, ** all run elementwise:
(Array([ 3., 4., 6., 10.], dtype=float32),
Array([-1., 0., 2., 6.], dtype=float32),
Array([ 2., 4., 8., 16.], dtype=float32),
Array([0.5, 1. , 2. , 4. ], dtype=float32),
Array([ 1., 4., 16., 64.], dtype=float32))
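The five results pin down the operands. x = [1, 2, 4, 8] and y = [2, 2, 2, 2] are the only values consistent with both x + y and x - y. Here is a sketch under that inference; results is an illustrative name:

```python
from jax import numpy as jnp

x = jnp.array([1.0, 2, 4, 8])
y = jnp.array([2.0, 2, 2, 2])

# Each operator pairs up entries by position
results = (x + y, x - y, x * y, x / y, x ** y)
```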
jnp.concatenate glues tensors along an existing axis. Pick the axis with the axis argument:
axis=0 → stack rows (more rows out)
axis=1 → stack columns (wider matrix out)
(Array([[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.],
[ 2., 1., 4., 3.],
[ 1., 2., 3., 4.],
[ 4., 3., 2., 1.]], dtype=float32),
Array([[ 0., 1., 2., 3., 2., 1., 4., 3.],
[ 4., 5., 6., 7., 1., 2., 3., 4.],
[ 8., 9., 10., 11., 4., 3., 2., 1.]], dtype=float32))
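This sketch shows both joins. It assumes X is jnp.arange(12) reshaped to $3 \times 4$ as float32, and Y is the nested-list matrix from earlier. It also contrasts jnp.stack, which adds a new axis instead of growing an existing one:

```python
from jax import numpy as jnp

X = jnp.arange(12, dtype=jnp.float32).reshape(3, 4)
Y = jnp.array([[2.0, 1, 4, 3], [1, 2, 3, 4], [4, 3, 2, 1]])

rows = jnp.concatenate((X, Y), axis=0)   # shape (6, 4)
cols = jnp.concatenate((X, Y), axis=1)   # shape (3, 8)

# Contrast: jnp.stack creates a new leading axis instead of extending one
stacked = jnp.stack((X, Y))              # shape (2, 3, 4)
```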
Comparison operators broadcast and return a boolean tensor of the same shape — useful for masking entries that satisfy a condition:
Array([[False, True, False, True],
[False, False, False, False],
[False, False, False, False]], dtype=bool)
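This sketch shows the comparison and a typical use of the resulting mask. It assumes X is jnp.arange(12) as float32 reshaped to $3 \times 4$, and Y is the nested-list matrix shown earlier. jnp.where keeps entries where the mask is True:

```python
from jax import numpy as jnp

X = jnp.arange(12, dtype=jnp.float32).reshape(3, 4)
Y = jnp.array([[2.0, 1, 4, 3], [1, 2, 3, 4], [4, 3, 2, 1]])

mask = X == Y                      # boolean array, same shape as X
n_equal = int(mask.sum())          # True counts as 1, so this is 2
kept = jnp.where(mask, X, 0.0)     # keep matching entries, zero elsewhere
big = X > 5                        # comparing with a scalar broadcasts too
```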
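When tensors of different shapes meet in an elementwise operation, their shapes are compared axis by axis starting from the last one. The smaller operand is then virtually expanded to the common shape, without copying data.
The rule: an axis of size 1 (or an axis missing from the shorter shape) stretches to match the other operand; all other sizes must be equal.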
(Array([[0],
[1],
[2]], dtype=int32),
Array([[0, 1]], dtype=int32))
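This sketch reproduces the two operands above and adds them; a and b are assumed names. jnp.broadcast_shapes applies the shape rule on its own. Incompatible sizes raise an error; the exact exception type is an assumption, so the sketch catches it broadly:

```python
from jax import numpy as jnp

a = jnp.arange(3).reshape(3, 1)   # column vector, shape (3, 1)
b = jnp.arange(2).reshape(1, 2)   # row vector, shape (1, 2)

c = a + b    # a stretches along axis 1, b along axis 0: result shape (3, 2)

# The shape rule alone, without computing any values
out_shape = jnp.broadcast_shapes((3, 1), (1, 2))

# Sizes that differ and are not 1 cannot be broadcast
try:
    jnp.arange(3) + jnp.arange(4)
    compatible = True
except (TypeError, ValueError):   # exception type assumed; caught broadly
    compatible = False
```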
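Running an operation such as Y = Y + X allocates a new tensor for the result and rebinds the name Y to it. This matters a lot when Y is gigabytes. Comparing id(Y) before and after the assignment shows the change: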
False
id(Y) == before is False: Y now points at a brand-new buffer.
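Here is a sketch of the check behind this output, with X and Y as assumed inputs:

```python
from jax import numpy as jnp

X = jnp.arange(12, dtype=jnp.float32).reshape(3, 4)
Y = jnp.ones((3, 4), dtype=jnp.float32)

before = id(Y)
Y = Y + X                   # builds a new array, then rebinds the name Y
same_object = id(Y) == before
print(same_object)          # False
```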
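In NumPy or PyTorch, you would pre-allocate the output and write into it with Z[:] = X + Y. If the old value of X is no longer needed, you could also update in place with X += Y. Neither works that way in JAX:

- JAX arrays are immutable, so Z[:] = ... raises a TypeError.
- X += Y is only shorthand for X = X + Y, which still allocates.
- The functional update Z.at[:].set(X + Y) also returns a new array.

Real buffer reuse happens under jax.jit. XLA plans memory for the whole compiled function, and the donate_argnums option lets an output reuse an input's buffer.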
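This minimal sketch shows both ideas: the functional .at[...].set(...) update, and buffer donation through the donate_argnums option of jax.jit. add_into is a hypothetical helper name. On backends without donation support, JAX warns and allocates as usual, so the results are the same either way:

```python
import functools

import jax
from jax import numpy as jnp

X = jnp.arange(12, dtype=jnp.float32).reshape(3, 4)
Y = jnp.ones((3, 4), dtype=jnp.float32)

# Functional "slice assignment": Z is a new array; Y is not modified.
Z = Y.at[:].set(X + Y)

# Hypothetical helper. donate_argnums=0 tells jit that argument 0 (y) may be
# overwritten, so XLA is allowed to place the result in y's buffer.
@functools.partial(jax.jit, donate_argnums=0)
def add_into(y, x):
    return y + x

Y2 = add_into(Y, X)
Y = Y2   # the donated input may have been deleted; keep using the result
```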
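Tensors and NumPy ndarrays convert in both directions with np.asarray and jnp.asarray. On CPU the conversion can sometimes avoid a copy. For arrays on a GPU, it always involves a transfer between device and host memory: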
(numpy.ndarray, jaxlib._jax.ArrayImpl)
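This sketch shows the round trip. It also uses .item() to pull a single value out as a Python scalar. The variable names are illustrative:

```python
import jax
import numpy as np
from jax import numpy as jnp

X = jnp.arange(12, dtype=jnp.float32).reshape(3, 4)

A = np.asarray(X)       # JAX -> NumPy (device-to-host transfer if X is on a GPU)
B = jnp.asarray(A)      # NumPy -> JAX (placed on the default device)

a = jnp.array(3.5)      # a 0-d (scalar) array
py_float = a.item()     # plain Python float
also_float = float(a)   # built-in conversion for a single-element array
```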
jnp.arange / jnp.zeros / jnp.ones / jax.random.normal / jnp.array(list): create.
.shape, .size, reshape: inspect and reorganize.
[i, j], [a:b, c:d]: read slices; .at[...].set(...): write them.
+ - * / **, jnp.concatenate, ==, sum: elementwise ops, joins, comparisons, reductions.
np.asarray(x) / .item(): leave the tensor world.