Data Manipulation

Tensor Basics

A tensor is an n-dimensional array of numbers — the fundamental data structure for everything that follows in this book.

  • Like a NumPy ndarray, but GPU-accelerated and differentiable.
  • 1-D tensor → vector, 2-D → matrix, n-D → general tensor.
  • All four frameworks expose nearly identical tensor APIs.

In this section: how to create, reshape, index, operate on, and share memory with tensors.

Getting Started

A single import wires up the framework’s tensor library:

import jax
from jax import numpy as jnp

A 1-D tensor of 12 evenly spaced integers, our running example:

x = jnp.arange(12)
x
Array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11], dtype=int32)
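
A minimal sketch of the other arange forms, assuming JAX's default configuration (32-bit types). The names ints, floats, and evens are illustrative only:

```python
import jax.numpy as jnp

# Integer arguments give an integer array.
ints = jnp.arange(12)

# Ask for floats explicitly with dtype.
floats = jnp.arange(12, dtype=jnp.float32)

# start, stop (exclusive), and step work as in NumPy.
evens = jnp.arange(0, 12, 2)

print(ints.dtype, floats.dtype, evens)
```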

Shape and size

Two attributes you’ll reach for constantly:

  • .size: the total number of elements
  • .shape: the size along each axis (a tuple)
x.size
12
x.shape
(12,)
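
The two attributes are linked: size is the product of the entries of shape. A small sketch on a hypothetical 3-D array (the name T is illustrative):

```python
import math
import jax.numpy as jnp

T = jnp.zeros((2, 3, 4))

print(T.shape)  # size along each axis
print(T.ndim)   # number of axes
print(T.size)   # total number of elements

# size always equals the product of the axis lengths.
size_matches = T.size == math.prod(T.shape)
print(size_matches)
```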

Reshaping

reshape rearranges the same elements into a different shape; the total number of elements is preserved.

X = x.reshape(3, 4)
X
Array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]], dtype=int32)

A 12-element vector becomes a 3\times 4 matrix. Values and their row-major order are unchanged. JAX arrays do not expose strides, so whether the buffer is copied is a backend detail.
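
Because the element count is fixed, one axis length can be left for reshape to infer by passing -1. A short sketch (A and B are illustrative names):

```python
import jax.numpy as jnp

x = jnp.arange(12)

# -1 means "work this axis out from the total size".
A = x.reshape(3, -1)   # 12 / 3 = 4 columns
B = x.reshape(-1, 4)   # 12 / 4 = 3 rows

print(A.shape, B.shape)
```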

Filled and random tensors

Constant fills take a shape tuple — any rank, any size:

jnp.zeros((2, 3, 4))
Array([[[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]],

       [[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]]], dtype=float32)

For weight initialization, jax.random.normal draws each element independently from \mathcal{N}(0, 1):

# Every random function in JAX takes an explicit key. Feeding the
# same key to the same function always yields the same sample
jax.random.normal(jax.random.PRNGKey(0), (3, 4))
Array([[ 1.6226422 ,  2.0252647 , -0.43359444, -0.07861735],
       [ 0.1760909 , -0.97208923, -0.49529874,  0.4943786 ],
       [ 0.6643493 , -0.9501635 ,  2.1795304 , -1.9551506 ]],      dtype=float32)
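
To get a different sample, the usual pattern is to split the key rather than reuse it. A minimal sketch, with illustrative variable names:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Reusing a key reproduces the sample exactly.
same_1 = jax.random.normal(key, (3, 4))
same_2 = jax.random.normal(key, (3, 4))

# Splitting yields a new key to carry forward and a subkey to consume.
key, subkey = jax.random.split(key)
fresh = jax.random.normal(subkey, (3, 4))

print(jnp.array_equal(same_1, same_2))
print(jnp.array_equal(same_1, fresh))
```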

ones, full(shape, value), eye(n), empty, and the *_like(x) variants round out the family. In JAX, empty returns a zero-filled array, because XLA cannot return uninitialized memory.
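
A quick sketch of a few of these constructors; the shapes and the fill value 7.0 are arbitrary choices for illustration:

```python
import jax.numpy as jnp

ones = jnp.ones((2, 3))            # all ones
sevens = jnp.full((2, 3), 7.0)     # constant fill
identity = jnp.eye(3)              # 3x3 identity matrix
like = jnp.zeros_like(sevens)      # zeros with sevens' shape and dtype

print(ones, sevens, identity, like, sep="\n")
```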

Tensors from Python lists

For exact control, pass a (nested) list literal — same row-major convention as NumPy:

jnp.array([[2, 1, 4, 3], [1, 2, 3, 4], [4, 3, 2, 1]])
Array([[2, 1, 4, 3],
       [1, 2, 3, 4],
       [4, 3, 2, 1]], dtype=int32)

Reading

Standard NumPy-style indexing:

  • X[-1] — the last row
  • X[1:3] — rows 1 and 2 (3 is exclusive)
X[-1], X[1:3]
(Array([ 8,  9, 10, 11], dtype=int32),
 Array([[ 4,  5,  6,  7],
        [ 8,  9, 10, 11]], dtype=int32))
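
Indices can also be combined across axes. A sketch on the same 3\times 4 matrix (elem, col, and block are illustrative names):

```python
import jax.numpy as jnp

X = jnp.arange(12).reshape(3, 4)

elem = X[1, 2]       # single element: row 1, column 2 -> 6
col = X[:, 0]        # first column -> [0, 4, 8]
block = X[:2, 1:3]   # rows 0-1, columns 1-2 -> [[1, 2], [5, 6]]

print(elem, col, block, sep="\n")
```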

Writing

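JAX arrays are immutable, so X[1, 2] = 17 is not allowed. The .at[...] property takes the same index syntax and returns an updated copy: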

# JAX arrays are immutable. jax.numpy.ndarray.at index
# update operators create a new array with the corresponding
# modifications made
X_new_1 = X.at[1, 2].set(17)
X_new_1
Array([[ 0,  1,  2,  3],
       [ 4,  5, 17,  7],
       [ 8,  9, 10, 11]], dtype=int32)

A slice inside .at[...] sets multiple elements at once:

X_new_2 = X_new_1.at[:2, :].set(12)
X_new_2
Array([[12, 12, 12, 12],
       [12, 12, 12, 12],
       [ 8,  9, 10, 11]], dtype=int32)
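
A short sketch confirming that the original array is untouched by an update, and showing .add as one of the other update methods on .at[...] (X_set and X_add are illustrative names):

```python
import jax.numpy as jnp

X = jnp.arange(12).reshape(3, 4)

X_set = X.at[1, 2].set(17)     # new array with one element replaced
X_add = X.at[0, :].add(100)    # new array with 100 added to row 0

print(X[1, 2], X_set[1, 2])    # 6 17: X itself is unchanged
print(X_add[0])                # [100 101 102 103]
```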

Elementwise

Most common math is applied elementwise — same shape in, same shape out.

jnp.exp(x)
Array([1.0000000e+00, 2.7182817e+00, 7.3890562e+00, 2.0085537e+01,
       5.4598148e+01, 1.4841316e+02, 4.0342880e+02, 1.0966332e+03,
       2.9809580e+03, 8.1030840e+03, 2.2026467e+04, 5.9874145e+04],      dtype=float32)

The arithmetic operators are overloaded — +, -, *, /, ** all run elementwise:

x = jnp.array([1.0, 2, 4, 8])
y = jnp.array([2, 2, 2, 2])
x + y, x - y, x * y, x / y, x ** y
(Array([ 3.,  4.,  6., 10.], dtype=float32),
 Array([-1.,  0.,  2.,  6.], dtype=float32),
 Array([ 2.,  4.,  8., 16.], dtype=float32),
 Array([0.5, 1. , 2. , 4. ], dtype=float32),
 Array([ 1.,  4., 16., 64.], dtype=float32))

Concatenation

jnp.concatenate glues tensors along an existing axis. Pick the axis with axis:

  • axis=0 → stack rows (more rows out)
  • axis=1 → stack columns (wider matrix out)
X = jnp.arange(12, dtype=jnp.float32).reshape((3, 4))
Y = jnp.array([[2.0, 1, 4, 3], [1, 2, 3, 4], [4, 3, 2, 1]])
jnp.concatenate((X, Y), axis=0), jnp.concatenate((X, Y), axis=1)
(Array([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.],
        [ 2.,  1.,  4.,  3.],
        [ 1.,  2.,  3.,  4.],
        [ 4.,  3.,  2.,  1.]], dtype=float32),
 Array([[ 0.,  1.,  2.,  3.,  2.,  1.,  4.,  3.],
        [ 4.,  5.,  6.,  7.,  1.,  2.,  3.,  4.],
        [ 8.,  9., 10., 11.,  4.,  3.,  2.,  1.]], dtype=float32))

Comparisons and reductions

Comparison operators broadcast and return a boolean tensor of the same shape — useful for masking entries that satisfy a condition:

X == Y
Array([[False,  True, False,  True],
       [False, False, False, False],
       [False, False, False, False]], dtype=bool)
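
One way to use such a mask, sketched with jnp.where: keep entries where the condition holds and substitute a fill value (0.0 here, an arbitrary choice) elsewhere.

```python
import jax.numpy as jnp

X = jnp.arange(12, dtype=jnp.float32).reshape((3, 4))
Y = jnp.array([[2.0, 1, 4, 3], [1, 2, 3, 4], [4, 3, 2, 1]])

mask = X == Y

# True counts as 1, so summing a mask counts the matches.
n_equal = mask.sum()

# Keep X where the mask is True, 0.0 elsewhere.
masked = jnp.where(mask, X, 0.0)

print(n_equal)   # 2
print(masked)
```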

sum, mean, max, … collapse one or more axes. Without an axis= argument the whole tensor reduces to a scalar:

X.sum()
Array(66., dtype=float32)
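
With an axis argument, only that axis is collapsed. A sketch on the same matrix; keepdims=True is shown as an option that retains the reduced axis with length 1:

```python
import jax.numpy as jnp

X = jnp.arange(12, dtype=jnp.float32).reshape((3, 4))

col_sums = X.sum(axis=0)                    # collapse rows -> shape (4,)
row_sums = X.sum(axis=1)                    # collapse columns -> shape (3,)
row_means = X.mean(axis=1, keepdims=True)   # shape (3, 1)

print(col_sums)   # [12. 15. 18. 21.]
print(row_sums)   # [ 6. 22. 38.]
print(row_means.shape)
```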

Broadcasting

When tensors of different shapes meet, the smaller one is virtually expanded along missing dimensions — no data copy.

The rule: dimensions of size 1 stretch; everything else must match.

a = jnp.arange(3).reshape((3, 1))
b = jnp.arange(2).reshape((1, 2))
a, b
(Array([[0],
        [1],
        [2]], dtype=int32),
 Array([[0, 1]], dtype=int32))
a + b
Array([[0, 1],
       [1, 2],
       [2, 3]], dtype=int32)

A 3\times 1 + 1\times 2 becomes a 3\times 2 matrix.
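
Shapes are compared axis by axis, starting from the trailing one. The sketch below spells out the implicit expansion with jnp.broadcast_to and checks that it matches a + b (a_full, b_full, and the other result names are illustrative):

```python
import jax.numpy as jnp

a = jnp.arange(3).reshape((3, 1))
b = jnp.arange(2).reshape((1, 2))

# The shape both operands are stretched to.
target = jnp.broadcast_shapes(a.shape, b.shape)

# Explicit version of what a + b does implicitly.
a_full = jnp.broadcast_to(a, target)   # columns repeated
b_full = jnp.broadcast_to(b, target)   # rows repeated
explicit = a_full + b_full
implicit = a + b

print(target)
print(jnp.array_equal(explicit, implicit))
```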

The hidden cost of Y = Y + X

Every assignment of an arithmetic expression allocates a new tensor. This matters a lot when Y occupies gigabytes:

before = id(Y)
Y = Y + X
id(Y) == before
False

id(Y) == before is False: Y now points at a brand-new buffer.

In-place operations

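Frameworks with mutable tensors let you pre-allocate the output and write into it with Z[:] = .... JAX has no equivalent:

# JAX arrays do not allow in-place operations

Accordingly, X[:] = X + Y raises a TypeError, and X += Y simply rebinds X to a newly allocated array. Inside jax.jit-compiled functions, XLA can often reuse buffers on its own.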
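
A sketch of how the mutable-tensor idioms behave in JAX. The flag names assignment_failed and rebound are illustrative only:

```python
import jax.numpy as jnp

X = jnp.arange(12, dtype=jnp.float32).reshape((3, 4))
Y = jnp.ones((3, 4))

# Slice assignment is rejected because JAX arrays are immutable.
try:
    X[:] = X + Y
    assignment_failed = False
except TypeError:
    assignment_failed = True

# += is accepted, but it rebinds the name X to a new array.
before = id(X)
X += Y
rebound = id(X) != before

print(assignment_failed, rebound)
```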

NumPy round-trip

Tensors convert to and from NumPy ndarrays with a single call. In JAX, jax.device_get transfers the array to host memory and jax.device_put places it back on the default device:

A = jax.device_get(X)
B = jax.device_put(A)
type(A), type(B)
(numpy.ndarray, jaxlib._jax.ArrayImpl)
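
The NumPy-style converters work too. A sketch using np.asarray and jnp.asarray, plus np.array to obtain an independent, writable copy (A_copy is an illustrative name):

```python
import numpy as np
import jax.numpy as jnp

X = jnp.arange(12, dtype=jnp.float32).reshape((3, 4))

A = np.asarray(X)    # JAX array -> NumPy ndarray
B = jnp.asarray(A)   # NumPy ndarray -> JAX array

# np.array copies by default, so edits to the copy do not reach X.
A_copy = np.array(X)
A_copy[0, 0] = 99.0

print(type(A).__name__, X[0, 0], A_copy[0, 0])
```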

A size-1 tensor unwraps to a Python scalar with .item(), float(x), or int(x):

a = jnp.array(3.5)
a, a.item(), float(a), int(a)
(Array(3.5, dtype=float32, weak_type=True), 3.5, 3.5, 3)

Recap

  • arange / zeros / ones / jax.random.normal / jnp.array(list): create.
  • .shape, .size, reshape: inspect / reorganize.
  • [i, j], [a:b, c:d]: read slices; .at[...].set(...) writes into a new array.
  • + - * / **, concatenate, ==, sum: elementwise ops, joins, comparisons, reductions.
  • Broadcasting stretches mismatched shapes; JAX arrays are immutable, so every update returns a new array.
  • jax.device_get / .item(): leave the tensor world.