%matplotlib inline
from d2l import jax as d2l
from IPython import display
import jax
from jax import numpy as jnp
A square matrix \mathbf{A} has eigenvalue \lambda and eigenvector \mathbf{v} \neq \mathbf{0} when
\mathbf{A}\mathbf{v} = \lambda \mathbf{v}.
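Geometrically, \mathbf{A} stretches \mathbf{v} by a factor of \lambda (flipping it when \lambda < 0) but doesn’t rotate it off its line. If \mathbf{A} is diagonalizable, then \mathbf{A} = \mathbf{V}\mathbf{\Lambda}\mathbf{V}^{-1}, where the columns of \mathbf{V} are eigenvectors and \mathbf{\Lambda} is the diagonal matrix of eigenvalues: a basis change in which the action is just stretching along axes.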
Why we care: matrix powers \mathbf{A}^t = \mathbf{V}\mathbf{\Lambda}^t\mathbf{V}^{-1} are governed by \lambda^t. Repeated application of \mathbf{A} tends to align a generic input with the dominant eigenvector, the one whose eigenvalue has the largest magnitude. That’s the heart of vanishing/exploding gradients in RNNs, of PageRank, and of many iterative solvers.
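To make the claim about matrix powers concrete, here is a minimal sketch that computes \mathbf{A}^t from the eigendecomposition and compares it with repeated multiplication. The 2x2 matrix and the exponent `t = 5` are illustrative choices, and the tolerance assumes JAX's default single precision.

```python
import jax.numpy as jnp

# Small illustrative matrix with eigenvalues 1 and 4.
A = jnp.array([[2.0, 1.0], [2.0, 3.0]])

# Eigendecomposition A = V diag(lam) V^{-1}; JAX returns complex arrays.
lam, V = jnp.linalg.eig(A)

t = 5  # illustrative exponent
# Only the eigenvalues are raised to the power t.
A_t_eig = (V @ jnp.diag(lam ** t) @ jnp.linalg.inv(V)).real
# Reference result from repeated matrix multiplication.
A_t_ref = jnp.linalg.matrix_power(A, t)

print(A_t_ref)
print(bool(jnp.allclose(A_t_eig, A_t_ref, rtol=1e-3)))
```

The entries of `A_t_ref` are dominated by the 4^t term, which is why the smaller eigenvalue quickly stops mattering as t grows.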
Use a small matrix so the geometry is visible: applying \mathbf{A} to an eigenvector changes scale but not direction.
jnp.linalg.eig(jnp.array([[2, 1], [2, 3]], dtype=jnp.float64))
EigResult(eigenvalues=Array([1.0000001+0.j, 4. +0.j], dtype=complex64), eigenvectors=Array([[-0.70710677+0.j, -0.4472136 +0.j],
[ 0.70710677+0.j, -0.89442724+0.j]], dtype=complex64))
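The returned pairs can be checked directly against the definition \mathbf{A}\mathbf{v} = \lambda \mathbf{v}. The following is a small sketch under the assumption that each column of the `eigenvectors` array pairs with the eigenvalue at the same index; the names `residual` and `cosine` are introduced here for illustration.

```python
import jax.numpy as jnp

A = jnp.array([[2.0, 1.0], [2.0, 3.0]])
lam, V = jnp.linalg.eig(A)

# A @ V should equal V with column i scaled by lam[i].
lhs = A.astype(V.dtype) @ V
rhs = V * lam  # broadcasting scales column i by lam[i]
residual = float(jnp.max(jnp.abs(lhs - rhs)))
print(residual)

# Direction check: A v stays parallel to v, so the cosine between them is 1
# (both eigenvalues of this matrix are positive).
v = V[:, 1].real
Av = A @ v
cosine = float(jnp.dot(Av, v) / (jnp.linalg.norm(Av) * jnp.linalg.norm(v)))
print(cosine)
```

A residual near machine precision and a cosine of about 1 confirm that the matrix only rescales its eigenvectors.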
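Cheap eigenvalue bounds without computing them (the Gershgorin circle theorem): every eigenvalue lies in the union of the disks in the complex plane centered at a_{ii} with radius r_i = \sum_{j \ne i} |a_{ij}|. Useful for stability arguments, and when the off-diagonal entries are small the eigenvalues sit close to the diagonal entries: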
Array([9.080341 +0.j, 0.99228513+0.j, 4.95394 +0.j, 2.9734364 +0.j], dtype=complex64)
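The disk bound can be sketched in a few lines. The matrix below is a hypothetical, diagonally dominant example chosen for illustration (the matrix behind the output above is not shown), and `gershgorin_disks` is a helper name introduced here, not a library function.

```python
import jax.numpy as jnp

# Hypothetical matrix with small off-diagonal entries (an assumption).
A = jnp.array([[1.0, 0.1, 0.1, 0.1],
               [0.1, 3.0, 0.2, 0.3],
               [0.1, 0.2, 5.0, 0.5],
               [0.1, 0.3, 0.5, 9.0]])

def gershgorin_disks(M):
    """Return (centers, radii): centers a_ii, radii sum_{j != i} |a_ij|."""
    centers = jnp.diag(M)
    radii = jnp.sum(jnp.abs(M), axis=1) - jnp.abs(centers)
    return centers, radii

centers, radii = gershgorin_disks(A)
eigvals = jnp.linalg.eigvals(A)

# Distance from every eigenvalue to every disk center: |lambda - a_ii|.
dist = jnp.abs(eigvals[:, None] - centers[None, :])
# Each eigenvalue must fall inside at least one disk (small float tolerance).
in_some_disk = jnp.any(dist <= radii[None, :] + 1e-5, axis=1)
print(centers, radii, in_some_disk)
```

Computing the disks costs only a row sum per row, which is what makes the bound attractive when a full eigendecomposition is too expensive.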
Power iteration: keep multiplying by \mathbf{A}. When one eigenvalue \lambda_1 strictly dominates in magnitude, the direction converges to the leading eigenvector; in either case the norm grows like |\lambda_1|^t:
# Random k-by-k Gaussian matrix (k = 5, matching the 5x5 output below)
k = 5
A = jax.random.normal(jax.random.PRNGKey(42), (k, k), dtype=jnp.float64)
A
Array([[-0.02830462, 0.46713185, 0.29570296, 0.15354592, -0.12403282],
[ 0.21692315, -1.440879 , 0.7558599 , 0.52140963, 0.9101704 ],
[-0.3844966 , 1.1398233 , 1.4457862 , 1.0809066 , -0.05629321],
[ 0.9095945 , 0.55734617, 0.21905719, -1.4485087 , 0.7641875 ],
[-0.24154697, -1.179381 , -1.9389184 , 0.35626462, -0.24111967]], dtype=float32)
# Calculate the sequence of norms after repeatedly applying `A`
v_in = jax.random.normal(jax.random.PRNGKey(1), (k, 1), dtype=jnp.float64)
norm_list = [float(jnp.linalg.norm(v_in))]
for i in range(1, 100):
    v_in = A @ v_in
    norm_list.append(float(jnp.linalg.norm(v_in)))
d2l.plot(jnp.arange(0, 100), norm_list, 'Iteration', 'Value')
After repeated multiplication, normalize the vector to read off the direction; the ratio of successive norms, \|\mathbf{A}\mathbf{v}\| / \|\mathbf{v}\|, estimates the magnitude of the dominant eigenvalue.
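A minimal sketch of normalized power iteration follows. It uses a small matrix with a single real dominant eigenvalue (an illustrative choice, since the method needs one eigenvalue to strictly dominate in magnitude); `power_iteration`, `num_steps`, and the starting vector are names and values assumed for this example.

```python
import jax.numpy as jnp

# Illustrative matrix with eigenvalues 1 and 4; eigenvector for 4 is (1, 2).
A = jnp.array([[2.0, 1.0], [2.0, 3.0]])

def power_iteration(M, v0, num_steps=50):
    """Estimate |lambda_1| and a unit-length leading eigenvector of M."""
    v = v0 / jnp.linalg.norm(v0)
    estimate = 0.0
    for _ in range(num_steps):
        w = M @ v
        estimate = jnp.linalg.norm(w)  # equals ||M v|| / ||v|| since ||v|| = 1
        v = w / estimate               # renormalize to avoid overflow
    return float(estimate), v

lam_est, v_est = power_iteration(A, jnp.array([1.0, 0.0]))
print(lam_est, v_est)
```

Renormalizing at every step keeps the vector at unit length, so the iteration can run for many steps without the norm blowing up or vanishing. The norm ratio recovers only the magnitude of the eigenvalue; a Rayleigh quotient would be one way to recover the sign as well.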
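# Assumed definition: `norm_eigs` holds the eigenvalue magnitudes of `A`, sorted ascending
eigs = jnp.linalg.eigvals(A).tolist()
norm_eigs = sorted(abs(x) for x in eigs)
print(f'norms of eigenvalues: {norm_eigs}')
norms of eigenvalues: [0.0797363817691803, 1.085721558445927, 1.085721558445927, 1.867778381089393, 1.867778381089393]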
# Rescale the matrix `A`
A = A / norm_eigs[-1]
# Do the same experiment again
v_in = jax.random.normal(jax.random.PRNGKey(2), (k, 1), dtype=jnp.float64)
norm_list = [float(jnp.linalg.norm(v_in))]
for i in range(1, 100):
    v_in = A @ v_in
    norm_list.append(float(jnp.linalg.norm(v_in)))
d2l.plot(jnp.arange(0, 100), norm_list, 'Iteration', 'Value')