Linear Algebra

Linear Algebra Toolkit

The minimum linear-algebra vocabulary every chapter that follows assumes:

  • Scalars / vectors / matrices / tensors — the four ranks.
  • Arithmetic — element-wise, with broadcasting.
  • Reductions — sum, mean, along chosen axes.
  • Products — dot, matrix-vector, matrix-matrix.
  • Norms: \ell_1, \ell_2, Frobenius.

Each piece comes with a one-liner of code so you can see the API.

Scalars

Scalars are rank-0 tensors — a single number with all the usual arithmetic operators:

import jax.numpy as jnp
x = jnp.array(3.0)
y = jnp.array(2.0)

x + y, x * y, x / y, x**y
(Array(5., dtype=float32, weak_type=True),
 Array(6., dtype=float32, weak_type=True),
 Array(1.5, dtype=float32, weak_type=True),
 Array(9., dtype=float32, weak_type=True))
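A quick way to confirm that a scalar really is rank 0 is to inspect its shape and number of axes. This minimal sketch uses only the standard `.shape` and `.ndim` attributes:

```python
import jax.numpy as jnp

x = jnp.array(3.0)
# A scalar has no axes: an empty shape tuple and zero dimensions.
print(x.shape, x.ndim)   # () 0
# float() extracts the value as a plain Python number.
print(float(x))          # 3.0
```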

Vectors

A vector is a 1-D array of scalars:

x = jnp.arange(3)
x
Array([0, 1, 2], dtype=int32)

Element access uses standard indexing:

x[2]
Array(2, dtype=int32)

Length and shape

The length of a vector is its number of elements:

len(x)
3

For higher-rank tensors len() is just shape[0]. Use .shape when you need every axis:

x.shape
(3,)
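On a rank-3 tensor the difference between len() and .shape becomes visible. A small sketch, reusing the 2 × 3 × 4 example shape from later in this section:

```python
import jax.numpy as jnp

X = jnp.arange(24).reshape(2, 3, 4)
# len() reports only the size of the leading axis...
print(len(X))      # 2
# ...while .shape lists every axis and .ndim counts them.
print(X.shape)     # (2, 3, 4)
print(X.ndim)      # 3
```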

Matrices

A matrix is a rank-2 tensor — m rows × n columns:

A = jnp.arange(6).reshape(3, 2)
A
Array([[0, 1],
       [2, 3],
       [4, 5]], dtype=int32)

The transpose swaps rows and columns: the same data with the two axes exchanged, so entry (i, j) of A becomes entry (j, i) of A.T:

A.T
Array([[0, 2, 4],
       [1, 3, 5]], dtype=int32)
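The index swap can be checked directly. This sketch confirms that the shape flips, that a single entry moves from (i, j) to (j, i), and that transposing twice returns the original matrix:

```python
import jax.numpy as jnp

A = jnp.arange(6).reshape(3, 2)
# Transposing swaps the shape...
print(A.shape, A.T.shape)   # (3, 2) (2, 3)
# ...entry (2, 1) of A lands at (1, 2) of A.T...
assert A[2, 1] == A.T[1, 2]
# ...and transposing twice gives back the original matrix.
assert jnp.array_equal(A.T.T, A)
```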

Symmetric matrices

A matrix is symmetric when it equals its own transpose:

\mathbf{A} = \mathbf{A}^\top.

A = jnp.array([[1, 2, 3], [2, 0, 4], [3, 4, 5]])
A == A.T
Array([[ True,  True,  True],
       [ True,  True,  True],
       [ True,  True,  True]], dtype=bool)

Symmetric matrices come up often: covariance matrices and Gram matrices such as \mathbf{X}^\top\mathbf{X} are always symmetric.
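For the Gram case the reason is one line of algebra, (\mathbf{X}^\top\mathbf{X})^\top = \mathbf{X}^\top\mathbf{X}. A minimal sketch, with an arbitrary 3 × 2 matrix X chosen only for illustration:

```python
import jax.numpy as jnp

# Any matrix X yields a symmetric Gram matrix X^T X.
X = jnp.arange(6, dtype=jnp.float32).reshape(3, 2)
G = X.T @ X          # entries: [[20, 26], [26, 35]]
assert jnp.array_equal(G, G.T)
```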

Higher-rank tensors

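The naming generalizes: a rank-n tensor has n axes. A 3-D tensor is a stack of matrices; a single RGB image is one (height × width × channels, or channels × height × width in PyTorch). Batching images adds a fourth axis: batch × height × width × channels in TensorFlow, batch × channels × height × width in PyTorch. A small 3-D example: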

jnp.arange(24).reshape(2, 3, 4)
Array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]],

       [[12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]]], dtype=int32)
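To make the image layouts concrete, here is a sketch with a hypothetical batch of 8 blank 32 × 32 RGB images (the names `images_nhwc` and `images_nchw` are illustrative). `jnp.transpose` with an explicit axis order converts channels-last to channels-first:

```python
import jax.numpy as jnp

# Hypothetical batch: 8 RGB images of 32x32 pixels, channels-last (TensorFlow-style).
images_nhwc = jnp.zeros((8, 32, 32, 3))
# Reorder axes to channels-first (PyTorch-style): batch x channels x height x width.
images_nchw = jnp.transpose(images_nhwc, (0, 3, 1, 2))
print(images_nhwc.ndim, images_nchw.shape)   # 4 (8, 3, 32, 32)
# Indexing out one image drops the batch axis, leaving a rank-3 tensor.
print(images_nhwc[0].shape)                  # (32, 32, 3)
```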

Element-wise arithmetic

Two tensors of the same shape combine element-wise:

A = jnp.arange(6, dtype=jnp.float32).reshape(2, 3)
B = A
A, A + B
(Array([[0., 1., 2.],
        [3., 4., 5.]], dtype=float32),
 Array([[ 0.,  2.,  4.],
        [ 6.,  8., 10.]], dtype=float32))

The element-wise product of matrices is the Hadamard product \mathbf{A} \odot \mathbf{B}:

A * B
Array([[ 0.,  1.,  4.],
       [ 9., 16., 25.]], dtype=float32)
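Entry by entry, (\mathbf{A} \odot \mathbf{B})_{ij} = a_{ij} b_{ij}. The sketch below checks that definition with explicit loops, using a second matrix B filled with 2.0 (an arbitrary choice) so the two operands differ:

```python
import jax.numpy as jnp

A = jnp.arange(6, dtype=jnp.float32).reshape(2, 3)
B = jnp.full((2, 3), 2.0)
# Hadamard product: entry (i, j) of the result is A[i, j] * B[i, j].
H = A * B
for i in range(A.shape[0]):
    for j in range(A.shape[1]):
        assert H[i, j] == A[i, j] * B[i, j]
# This is not the matrix product: A @ B would fail here,
# since (2, 3) @ (2, 3) has mismatched inner dimensions.
```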

Scalar–tensor arithmetic

A scalar broadcasts to every element of a tensor:

a = 2
X = jnp.arange(24).reshape(2, 3, 4)
a + X, (a * X).shape
(Array([[[ 2,  3,  4,  5],
         [ 6,  7,  8,  9],
         [10, 11, 12, 13]],
 
        [[14, 15, 16, 17],
         [18, 19, 20, 21],
         [22, 23, 24, 25]]], dtype=int32),
 (2, 3, 4))
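A scalar is the simplest case of broadcasting. More generally, shapes are compared from the trailing axis, and an axis of size 1 (or a missing leading axis) is stretched to match. A sketch with a row vector and a column vector added to the same 2 × 3 matrix:

```python
import jax.numpy as jnp

X = jnp.arange(6).reshape(2, 3)
row = jnp.array([10, 20, 30])      # shape (3,): repeated down both rows
col = jnp.array([[100], [200]])    # shape (2, 1): repeated across all three columns
X_plus_row = X + row               # [[10, 21, 32], [13, 24, 35]]
X_plus_col = X + col               # [[100, 101, 102], [203, 204, 205]]
print(X_plus_row.shape, X_plus_col.shape)   # (2, 3) (2, 3)
```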

Reductions: sum

The sum \sum_i x_i collapses every element into one scalar:

x = jnp.arange(3, dtype=jnp.float32)
x, x.sum()
(Array([0., 1., 2.], dtype=float32), Array(3., dtype=float32))

The same call works for any rank; by default it sums over all axes:

A.shape, A.sum()
((2, 3), Array(15., dtype=float32))

Reducing along an axis

To collapse only one or some axes, pass axis=:

A.shape, A.sum(axis=0).shape
((2, 3), (3,))
A.shape, A.sum(axis=1).shape
((2, 3), (2,))

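axis=0 sums down the rows, leaving one total per column (shape (3,)); axis=1 sums across the columns, leaving one total per row (shape (2,)). Either way, the reduced axis disappears and the rank drops by one.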
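Spelling out the two reductions by hand makes the axis convention easy to remember. This sketch rebuilds both results from explicit slices of the same 2 × 3 matrix:

```python
import jax.numpy as jnp

A = jnp.arange(6, dtype=jnp.float32).reshape(2, 3)
# axis=0: add the rows together, giving one total per column.
col_totals = A[0, :] + A[1, :]
# axis=1: add the columns together, giving one total per row.
row_totals = A[:, 0] + A[:, 1] + A[:, 2]
assert jnp.array_equal(A.sum(axis=0), col_totals)   # [3., 5., 7.]
assert jnp.array_equal(A.sum(axis=1), row_totals)   # [3., 12.]
```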

Reducing all axes

A tuple of axes reduces over each of them:

A.sum(axis=(0, 1)) == A.sum()
Array(True, dtype=bool)

For a rank-2 tensor, axis=(0, 1) covers every axis, so it matches the default sum().

Mean

The mean is \bar x = \frac{1}{n} \sum_i x_i. Compute it with the built-in mean() or as sum() divided by the element count, size:

A.mean(), A.sum() / A.size
(Array(2.5, dtype=float32), Array(2.5, dtype=float32))

And along a single axis:

A.mean(axis=0), A.sum(axis=0) / A.shape[0]
(Array([1.5, 2.5, 3.5], dtype=float32), Array([1.5, 2.5, 3.5], dtype=float32))
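The same identity holds along axis=1, where the divisor is the number of columns, A.shape[1]. The sketch also shows that mean() on an integer array already returns a floating-point result (float32 under JAX's default 32-bit mode), so no manual cast is needed:

```python
import jax.numpy as jnp

A = jnp.arange(6, dtype=jnp.float32).reshape(2, 3)
# Row means: each row's sum divided by the number of columns.
row_means = A.sum(axis=1) / A.shape[1]
assert jnp.allclose(A.mean(axis=1), row_means)   # [1., 4.]

# An integer input is promoted to floating point by mean().
ints = jnp.arange(6).reshape(2, 3)
print(ints.mean(), ints.mean().dtype)
```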

Non-reducing sum (keepdims)

Set keepdims=True to preserve the reduced axis (size 1) so broadcasting still works:

sum_A = A.sum(axis=1, keepdims=True)
sum_A, sum_A.shape
(Array([[ 3.],
        [12.]], dtype=float32),
 (2, 1))

Now A / sum_A divides every row by its sum — common normalization:

A / sum_A
Array([[0.        , 0.33333334, 0.6666667 ],
       [0.25      , 0.33333334, 0.41666666]], dtype=float32)
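Two properties are worth checking: each normalized row now sums to 1, and dropping keepdims breaks the broadcast, because a (2,) array is aligned against the 3 columns rather than the 2 rows. A sketch (the exact exception type JAX raises for mismatched shapes has varied across versions, so both TypeError and ValueError are caught):

```python
import jax.numpy as jnp

A = jnp.arange(6, dtype=jnp.float32).reshape(2, 3)
# keepdims=True gives shape (2, 1), which broadcasts across each row of A.
P = A / A.sum(axis=1, keepdims=True)
# Every row of the normalized matrix sums to 1 (up to float rounding).
assert jnp.allclose(P.sum(axis=1), 1.0)

# Without keepdims the sums have shape (2,), which cannot broadcast against (2, 3).
try:
    A / A.sum(axis=1)
    broadcast_failed = False
except (TypeError, ValueError):
    broadcast_failed = True
print(broadcast_failed)   # True
```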

Cumulative sum

cumsum(axis=k) keeps the axis but reports a running total — useful for time-series and prefix sums:

A.cumsum(axis=0)
Array([[0., 1., 2.],
       [3., 5., 7.]], dtype=float32)
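The prefix-sum use case in a sketch: once the running totals are computed, the sum of any contiguous range is a difference of two of them. The vector x and the range bounds i, j are arbitrary illustrative choices:

```python
import jax.numpy as jnp

x = jnp.array([3.0, 1.0, 4.0, 1.0, 5.0])
c = jnp.cumsum(x)          # running totals: [3., 4., 8., 9., 14.]
# The last running total is the full sum...
assert c[-1] == x.sum()
# ...and the sum of x[i:j] is c[j-1] minus the total before index i.
i, j = 1, 4
range_sum = c[j - 1] - (c[i - 1] if i > 0 else 0.0)
assert range_sum == x[i:j].sum()   # 1 + 4 + 1 = 6
```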

Dot product

\mathbf{x}^\top \mathbf{y} = \sum_i x_i y_i — element-wise multiply, then sum:

y = jnp.ones(3, dtype=jnp.float32)
x, y, jnp.dot(x, y)
(Array([0., 1., 2.], dtype=float32),
 Array([1., 1., 1.], dtype=float32),
 Array(3., dtype=float32))

Equivalently, multiply element-wise and sum:

jnp.sum(x * y)
Array(3., dtype=float32)
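The @ operator gives a third spelling. One common reading of the dot product: when the weights are non-negative and sum to 1, \mathbf{x}^\top \mathbf{w} is a weighted average of the entries of \mathbf{x}. A sketch with hypothetical weights w = (0.2, 0.3, 0.5):

```python
import jax.numpy as jnp

x = jnp.arange(3, dtype=jnp.float32)     # [0., 1., 2.]
w = jnp.array([0.2, 0.3, 0.5])           # hypothetical weights summing to 1
# Three spellings of the same dot product.
d1 = jnp.dot(x, w)
d2 = jnp.sum(x * w)
d3 = x @ w
assert jnp.allclose(d1, d2) and jnp.allclose(d1, d3)
# Weighted average: 0*0.2 + 1*0.3 + 2*0.5 = 1.3
assert jnp.allclose(d1, 1.3)
```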

Matrix products

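For \mathbf{A} \in \mathbb{R}^{m \times n} and \mathbf{x} \in \mathbb{R}^n, \mathbf{A}\mathbf{x} is a length-m vector: one dot product per row of \mathbf{A}. It is the core of a fully connected layer's forward pass, which makes it one of the most ubiquitous operations in deep learning.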

A.shape, x.shape, jnp.matmul(A, x)
((2, 3), (3,), Array([ 5., 14.], dtype=float32))
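The "one dot product per row" description can be checked by building the result row by row:

```python
import jax.numpy as jnp

A = jnp.arange(6, dtype=jnp.float32).reshape(2, 3)
x = jnp.arange(3, dtype=jnp.float32)
# One dot product per row of A, stacked into a length-m vector.
by_rows = jnp.stack([jnp.dot(A[i], x) for i in range(A.shape[0])])
assert jnp.allclose(by_rows, A @ x)   # [5., 14.]
```

In practice the loop is only for illustration; a single `A @ x` call is what you would use.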

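For \mathbf{A} \in \mathbb{R}^{m \times n} and \mathbf{B} \in \mathbb{R}^{n \times k}, \mathbf{AB} is k matrix-vector products \mathbf{A}\mathbf{b}_j (one per column of \mathbf{B}) stitched into an m \times k matrix; equivalently, m \cdot k row-by-column dot products: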

B = jnp.ones((3, 4))
jnp.matmul(A, B)
Array([[ 3.,  3.,  3.,  3.],
       [12., 12., 12., 12.]], dtype=float32)
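Both views of the matrix product can be verified explicitly. This sketch swaps the all-ones B for a 3 × 4 matrix of distinct values so that each column contributes something different:

```python
import jax.numpy as jnp

A = jnp.arange(6, dtype=jnp.float32).reshape(2, 3)    # m x n = 2 x 3
B = jnp.arange(12, dtype=jnp.float32).reshape(3, 4)   # n x k = 3 x 4
# Column j of AB is the matrix-vector product A @ B[:, j].
by_columns = jnp.stack([A @ B[:, j] for j in range(B.shape[1])], axis=1)
# Entry (i, j) of AB is the dot product of row i of A with column j of B.
by_entries = jnp.array([[jnp.dot(A[i], B[:, j]) for j in range(B.shape[1])]
                        for i in range(A.shape[0])])
assert jnp.allclose(by_columns, A @ B)
assert jnp.allclose(by_entries, A @ B)
```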

Norms

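Norms measure the size of a vector or matrix. The \ell_2 norm (Euclidean length) is the workhorse of optimization; the \ell_1 norm and, for matrices, the Frobenius norm are the other two used constantly:

\|\mathbf{x}\|_2 = \sqrt{\sum_i x_i^2}, \qquad \|\mathbf{x}\|_1 = \sum_i |x_i|, \qquad \|\mathbf{X}\|_\text{F} = \sqrt{\sum_{i,j} x_{ij}^2}.

For a vector, jnp.linalg.norm computes \ell_2 by default: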

u = jnp.array([3.0, -4.0])
jnp.linalg.norm(u)
Array(5., dtype=float32)
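The \ell_2 formula translates directly into code, and its square is the dot product of the vector with itself, a relation that shows up whenever a squared-error loss is expanded. A sketch:

```python
import jax.numpy as jnp

u = jnp.array([3.0, -4.0])
# Square, sum, square root: sqrt(9 + 16) = 5.
l2 = jnp.sqrt(jnp.sum(u ** 2))
assert jnp.allclose(l2, jnp.linalg.norm(u))
# The squared l2 norm equals u . u = 25.
assert jnp.allclose(l2 ** 2, jnp.dot(u, u))
```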

As a loss or penalty, \ell_1 is less sensitive to outliers than the squared \ell_2 norm and tends to promote sparsity:

jnp.linalg.norm(u, ord=1)  # same as jnp.abs(u).sum()
Array(7., dtype=float32)

For matrices, the Frobenius norm is the \ell_2 norm of the flattened matrix; it is also what jnp.linalg.norm computes by default for 2-D input:

jnp.linalg.norm(jnp.ones((4, 9)))
Array(6., dtype=float32)
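The "flattened \ell_2" description can be checked three ways on a matrix whose entries are not all equal: the default 2-D norm, the vector norm of the reshaped entries, and the formula written out by hand. `ord="fro"` names the Frobenius norm explicitly:

```python
import jax.numpy as jnp

X = jnp.arange(6, dtype=jnp.float32).reshape(2, 3)
fro = jnp.linalg.norm(X)                  # default for 2-D input: Frobenius
flat = jnp.linalg.norm(X.reshape(-1))     # l2 norm of the flattened entries
manual = jnp.sqrt(jnp.sum(X ** 2))        # sqrt(0 + 1 + 4 + 9 + 16 + 25) = sqrt(55)
assert jnp.allclose(fro, flat) and jnp.allclose(fro, manual)
assert jnp.allclose(jnp.linalg.norm(X, ord="fro"), fro)
```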

Recap

  • Scalars / vectors / matrices / tensors are ranks 0 / 1 / 2 / n.
  • Element-wise ops, scalar broadcasting, Hadamard product (*).
  • Reductions: sum, mean, with axis= and keepdims=.
  • Products: jnp.dot for vectors; jnp.matmul / @ for matrix-vector and matrix-matrix.
  • Norms: \ell_1, \ell_2, Frobenius.

Most deep-learning math compiles down to this short list.