Information Theory

Information Theory for Learning

Information theory (Shannon, 1948) provides the vocabulary for several quantities that recur throughout deep learning:

  • Self-information I(x) = -\log p(x): the surprise of observing x.
  • Entropy H(X) = -\mathbb{E}[\log p(X)]: the expected surprise under a distribution.
  • Cross-entropy H(p, q) = -\mathbb{E}_{p}[\log q]: the quantity minimized by the standard classification loss.
  • KL divergence D_{KL}(p \| q) = H(p, q) - H(p): the expected number of extra bits needed to encode samples from p with a code optimized for q.
  • Mutual information I(X; Y) = H(Y) - H(Y \mid X): how much knowing X reduces uncertainty about Y.

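Because H(p, q) = H(p) + D_{KL}(p \| q) and H(p) does not depend on the model q, minimizing the cross-entropy loss is equivalent to minimizing the KL divergence from the true distribution p to the predicted distribution q.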
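The identity H(p, q) = H(p) + D_{KL}(p \| q) can be checked numerically. This is a minimal sketch written with NumPy (whose API jax.numpy mirrors); the helper names and the two distributions are illustrative assumptions, not part of the original text.

```python
import numpy as np

def entropy_bits(p):
    """H(p) = -sum_x p(x) log2 p(x), skipping outcomes with p(x) = 0."""
    p = np.asarray(p, dtype=float)
    nz = p > 0
    return float(np.sum(p[nz] * np.log2(1.0 / p[nz])))

def cross_entropy_bits(p, q):
    """H(p, q) = -sum_x p(x) log2 q(x); assumes q(x) > 0 wherever p(x) > 0."""
    p, q = np.asarray(p, dtype=float), np.asarray(q, dtype=float)
    nz = p > 0
    return float(np.sum(p[nz] * np.log2(1.0 / q[nz])))

def kl_bits(p, q):
    """D_KL(p || q) = sum_x p(x) log2(p(x) / q(x)); same support assumption."""
    p, q = np.asarray(p, dtype=float), np.asarray(q, dtype=float)
    nz = p > 0
    return float(np.sum(p[nz] * np.log2(p[nz] / q[nz])))

# Illustrative distributions over three outcomes.
p = np.array([0.5, 0.25, 0.25])   # "true" distribution
q = np.array([0.25, 0.25, 0.5])   # model distribution

print(cross_entropy_bits(p, q))             # 1.75
print(entropy_bits(p) + kl_bits(p, q))      # 1.5 + 0.25 = 1.75
```

Only the KL term depends on q, which is why the entropy of p acts as the constant.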

Self-information

Rare events carry more information than common ones. The log base only chooses the unit: bits for base 2, nats for base e.

import jax
from jax import numpy as jnp
import numpy as np

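def nansum(x):
    # Sum while skipping NaNs. 0 * log(0) evaluates to 0 * -inf = NaN, so this
    # implements the convention 0 log 0 = 0 used by the entropy-style sums below.
    return jnp.nansum(x)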

def self_information(p):
    return -jnp.log2(jnp.array(p)).item()

self_information(1 / 64)
6.0
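Switching the log base only rescales the result: a value in bits times ln 2 gives the same quantity in nats. A small sketch in NumPy, with a hypothetical `surprisal` helper that takes the base as a parameter:

```python
import numpy as np

def surprisal(p, base=2.0):
    """Self-information -log_base(p): base 2 gives bits, base e gives nats."""
    return float(-np.log(p) / np.log(base))

bits = surprisal(1 / 64)            # 6 bits (up to floating-point rounding)
nats = surprisal(1 / 64, np.e)      # 6 ln 2, about 4.159 nats
print(bits, nats)
print(bits * np.log(2))             # converting bits to nats multiplies by ln 2
```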

Entropy

H(X) = -\sum_x p(x) \log p(x). Over n outcomes it is largest for the uniform distribution, where it equals \log n, and it is zero for a point mass:

def entropy(p):
    return nansum(- p * jnp.log2(p))

entropy(jnp.array([0.1, 0.5, 0.1, 0.3]))
Array(1.6854753, dtype=float32)
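The two extremes can be checked on four outcomes: the uniform distribution reaches log2 4 = 2 bits, a point mass gives 0, and the example distribution above falls in between. A NumPy sketch (the helper name is an assumption):

```python
import numpy as np

def entropy_bits(p):
    p = np.asarray(p, dtype=float)
    nz = p > 0                                  # skip outcomes with p = 0 (0 log 0 = 0)
    return float(np.sum(p[nz] * np.log2(1.0 / p[nz])))  # log2(1/p) = -log2 p

n = 4
uniform = np.full(n, 1 / n)
point_mass = np.array([1.0, 0.0, 0.0, 0.0])
skewed = np.array([0.1, 0.5, 0.1, 0.3])

print(entropy_bits(uniform))      # 2.0 == log2(4), the maximum for 4 outcomes
print(entropy_bits(point_mass))   # 0.0
print(entropy_bits(skewed))       # about 1.6855, between the two extremes
```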

Joint and conditional entropy

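Joint entropy H(X, Y) = -\sum_{x,y} p(x, y) \log p(x, y) measures the uncertainty of the pair; conditional entropy H(Y \mid X) = -\sum_{x,y} p(x, y) \log p(y \mid x) measures the uncertainty left in Y once X is known. The two are linked by the chain rule H(X, Y) = H(X) + H(Y \mid X):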

def joint_entropy(p_xy):
    joint_ent = -p_xy * jnp.log2(p_xy)
    # nansum skips the NaNs from 0 * log 0 terms, so they count as 0
    out = nansum(joint_ent)
    return out

joint_entropy(jnp.array([[0.1, 0.5], [0.1, 0.3]]))
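Array(1.6854753, dtype=float32)

def conditional_entropy(p_xy, p_x):
    # p_xy[i, j] = p(x_i, y_j); p_x[i] = p(x_i) is the row marginal.
    # Dividing each row by its marginal gives p(y | x).
    p_y_given_x = p_xy / p_x[:, None]
    cond_ent = -p_xy * jnp.log2(p_y_given_x)
    # nansum skips the NaNs from 0 * log 0 terms, so they count as 0
    out = nansum(cond_ent)
    return out

conditional_entropy(jnp.array([[0.1, 0.5], [0.1, 0.3]]),
                    jnp.array([0.6, 0.4]))

The result is about 0.7145 bits, which equals H(X, Y) - H(X) = 1.6855 - 0.9710, as the chain rule requires.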
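The chain rule can also be verified directly on a small joint distribution. This NumPy sketch assumes rows index X and columns index Y; the table values and the helper name are illustrative.

```python
import numpy as np

def entropy_bits(p):
    """Entropy in bits of an array of probabilities (any shape); 0 log 0 = 0."""
    p = np.asarray(p, dtype=float).ravel()
    p = p[p > 0]
    return float(np.sum(p * np.log2(1.0 / p)))

# Joint distribution p(x, y): rows index X, columns index Y.
p_xy = np.array([[0.1, 0.5],
                 [0.1, 0.3]])
p_x = p_xy.sum(axis=1)              # marginal p(x) = [0.6, 0.4]

# H(Y | X) = -sum_{x,y} p(x, y) log2 p(y | x), with p(y | x) = p(x, y) / p(x).
# Assumes every entry of p_xy is positive, as it is here.
p_y_given_x = p_xy / p_x[:, None]   # divide each row by its own marginal
h_y_given_x = float(np.sum(p_xy * np.log2(1.0 / p_y_given_x)))

print(entropy_bits(p_xy))                   # H(X, Y), about 1.6855
print(entropy_bits(p_x) + h_y_given_x)      # H(X) + H(Y | X), same value
```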

Mutual information

I(X; Y) = \sum_{x,y} p(x, y) \log \frac{p(x, y)}{p(x) p(y)} = H(X) - H(X \mid Y) = H(X) + H(Y) - H(X, Y) measures how much information X and Y share. It is symmetric, non-negative, and zero iff X and Y are independent:

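def mutual_information(p_xy, p_x, p_y):
    # p_xy[i, j] = p(x_i, y_j); p_x and p_y are the row and column marginals.
    # jnp.outer(p_x, p_y) is the joint X and Y would have if independent.
    p = p_xy / jnp.outer(p_x, p_y)
    mutual = p_xy * jnp.log2(p)
    # nansum skips the NaNs from 0 * log 0 terms, so they count as 0
    out = nansum(mutual)
    return out

mutual_information(jnp.array([[0.1, 0.5], [0.1, 0.3]]),
                   jnp.array([0.6, 0.4]), jnp.array([0.2, 0.8]))

The result is about 0.0074 bits, matching H(X) + H(Y) - H(X, Y) = 0.9710 + 0.7219 - 1.6855: these X and Y are nearly independent.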
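The stated properties can be seen on two extreme joint tables: a product of marginals (independent, I = 0) and a table where Y is a copy of a fair bit X (I = H(X) = 1 bit). A NumPy sketch, assuming rows index X and columns index Y; the function name is hypothetical.

```python
import numpy as np

def mutual_information_bits(p_xy):
    """I(X; Y) = sum p(x,y) log2[p(x,y) / (p(x) p(y))]; rows index X, columns Y."""
    p_xy = np.asarray(p_xy, dtype=float)
    p_x = p_xy.sum(axis=1)
    p_y = p_xy.sum(axis=0)
    indep = np.outer(p_x, p_y)          # joint X and Y would have if independent
    nz = p_xy > 0                       # 0 log 0 = 0
    return float(np.sum(p_xy[nz] * np.log2(p_xy[nz] / indep[nz])))

independent = np.outer([0.6, 0.4], [0.2, 0.8])   # product of marginals
dependent = np.array([[0.5, 0.0],
                      [0.0, 0.5]])               # Y is a copy of X

print(mutual_information_bits(independent))      # about 0, up to rounding
print(mutual_information_bits(dependent))        # 1.0 bit, equal to H(X)
print(mutual_information_bits(dependent.T))      # symmetric: same value
```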

KL divergence

D_{KL}(p \| q) = \sum_x p(x) \log \frac{p(x)}{q(x)} \ge 0. Asymmetric (not a metric); zero iff p = q:

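def kl_divergence(p, q):
    # Elementwise p(x) log2(p(x) / q(x)); nansum drops NaN terms such as 0 * log 0
    kl = p * jnp.log2(p / q)
    out = nansum(kl)
    # For valid distributions the sum is already >= 0. abs() is not part of the
    # definition; it only forces a non-negative number for the unnormalized
    # inputs used in the example that follows.
    return jnp.abs(out).item()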
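On small, properly normalized distributions the two properties are easy to see: the divergence of a distribution from itself is 0, and swapping the arguments gives a different value. A NumPy sketch with illustrative distributions and a hypothetical helper name:

```python
import numpy as np

def kl_bits(p, q):
    """D_KL(p || q) in bits; assumes q(x) > 0 wherever p(x) > 0."""
    p, q = np.asarray(p, dtype=float), np.asarray(q, dtype=float)
    nz = p > 0                              # terms with p(x) = 0 contribute 0
    return float(np.sum(p[nz] * np.log2(p[nz] / q[nz])))

p = np.array([0.5, 0.5])
q = np.array([0.9, 0.1])

print(kl_bits(p, p))   # 0.0: identical distributions
print(kl_bits(p, q))   # about 0.737 bits
print(kl_bits(q, p))   # about 0.531 bits: swapping the arguments changes the value
```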

Examples

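The experiment below draws 10,000 samples each from N(0, 1), N(-1, 1), and N(1, 1), sorts them, and passes the sorted vectors p, q1, and q2 to kl_divergence as if they were probability vectors. They are not: they are unnormalized and contain negative entries, and positions where p and q differ in sign produce NaN logs that nansum silently drops. The absolute numbers are therefore not true KL divergences, and the experiment only illustrates asymmetry qualitatively: shifting q by -1 or +1 gives similar values of D_{KL}(p \| q), whereas swapping the arguments changes the value by roughly 40%.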

key = jax.random.PRNGKey(1)
keys = jax.random.split(key, 3)

tensor_len = 10000
p = jax.random.normal(keys[0], (tensor_len, ))
q1 = jax.random.normal(keys[1], (tensor_len, )) - 1
q2 = jax.random.normal(keys[2], (tensor_len, )) + 1

p = jnp.sort(p)
q1 = jnp.sort(q1)
q2 = jnp.sort(q2)
kl_pq1 = kl_divergence(p, q1)
kl_pq2 = kl_divergence(p, q2)
similar_percentage = abs(kl_pq1 - kl_pq2) / ((kl_pq1 + kl_pq2) / 2) * 100

kl_pq1, kl_pq2, similar_percentage
(8675.0947265625, 8916.708984375, 2.7468957905922875)

kl_q2p = kl_divergence(q2, p)
differ_percentage = abs(kl_q2p - kl_pq2) / ((kl_q2p + kl_pq2) / 2) * 100

kl_q2p, differ_percentage
(13366.1494140625, 39.935993400195905)
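A more conventional way to estimate D_{KL} from samples is to draw x from p and average \log p(x) - \log q(x). For Gaussians the estimate can be compared with the closed form \log(\sigma_2/\sigma_1) + (\sigma_1^2 + (\mu_1 - \mu_2)^2)/(2\sigma_2^2) - 1/2 (in nats). The sketch below uses NumPy with hypothetical helper names; N(0, 1) and N(1, 2^2) are illustrative choices. Unequal variances are used because KL between two Gaussians with equal variance is symmetric in its arguments, which is also why the ±1 shifts above give similar values.

```python
import numpy as np

def gaussian_logpdf(x, mu, sigma):
    """Natural-log density of N(mu, sigma^2) evaluated at x."""
    return -0.5 * np.log(2 * np.pi * sigma**2) - (x - mu) ** 2 / (2 * sigma**2)

def kl_gaussian(mu1, s1, mu2, s2):
    """Closed-form D_KL(N(mu1, s1^2) || N(mu2, s2^2)) in nats."""
    return np.log(s2 / s1) + (s1**2 + (mu1 - mu2) ** 2) / (2 * s2**2) - 0.5

def kl_monte_carlo(mu1, s1, mu2, s2, n, rng):
    """Estimate E_{x ~ N(mu1, s1^2)}[log p(x) - log q(x)] from n samples."""
    x = rng.normal(mu1, s1, size=n)
    return float(np.mean(gaussian_logpdf(x, mu1, s1) - gaussian_logpdf(x, mu2, s2)))

rng = np.random.default_rng(0)
# p = N(0, 1), q = N(1, 2^2)
print(kl_monte_carlo(0, 1, 1, 2, 200_000, rng), kl_gaussian(0, 1, 1, 2))  # both about 0.443
print(kl_monte_carlo(1, 2, 0, 1, 200_000, rng), kl_gaussian(1, 2, 0, 1))  # both about 1.307
```

Unlike the sorted-sample experiment, every quantity here is a valid log-density, so the estimates converge to the true divergences as n grows.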

Formal definitions

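Entropy, cross-entropy, and KL differ in which distribution supplies the expectation and which log-probability is scored: H(p) = -\mathbb{E}_{p}[\log p], H(p, q) = -\mathbb{E}_{p}[\log q], and D_{KL}(p \| q) = \mathbb{E}_{p}[\log p - \log q] = H(p, q) - H(p). The function below computes the empirical cross-entropy, the mean of -\log q(y_i) over labeled examples, using the natural log (nats) rather than the base-2 log used above.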

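def cross_entropy(y_hat, y):
    # y_hat: (n, num_classes) predicted probabilities; y: (n,) integer labels.
    # Pick each row's probability of its true class and average -log of it.
    ce = -jnp.log(y_hat[jnp.arange(len(y_hat)), y])
    return ce.mean()

labels = jnp.array([0, 2])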
preds = jnp.array([[0.3, 0.6, 0.1], [0.2, 0.3, 0.5]])

cross_entropy(preds, labels)
Array(0.94856, dtype=float32)

Cross-entropy in classification

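In multi-class classification, the data distribution for example i is one-hot on its true class y_i, and the model distribution q(\cdot \mid x_i) is a softmax over logits. A one-hot distribution has zero entropy, so the cross-entropy equals the KL divergence and reduces to the negative log-likelihood (NLL) of the true class, averaged over the n examples:

\mathcal{L} = -\frac{1}{n} \sum_{i=1}^{n} \log q(y_i \mid x_i).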
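Because a one-hot distribution has zero entropy, its cross-entropy with q equals both D_{KL} and -\log q(y). A single-example NumPy check, with illustrative probabilities:

```python
import numpy as np

q = np.array([0.2, 0.3, 0.5])      # model's predicted distribution for one example
y = 2                              # true class index
p = np.eye(3)[y]                   # one-hot "data" distribution: [0, 0, 1]

nz = p > 0                         # only the true class has p > 0
cross_entropy = -np.sum(p[nz] * np.log(q[nz]))
entropy_p = np.sum(p[nz] * np.log(1.0 / p[nz]))   # entropy of a point mass: 0
kl = np.sum(p[nz] * np.log(p[nz] / q[nz]))

print(cross_entropy, -np.log(q[y]))   # both about 0.693 nats = -log 0.5
print(entropy_p, kl)                  # 0.0, and KL equals the cross-entropy
```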

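import optax

# optax expects unnormalized logits. Each row of preds sums to 1, so
# softmax(log(preds)) == preds and log-probabilities work as logits here.
loss = optax.softmax_cross_entropy_with_integer_labels(
    jnp.log(preds), labels).mean()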
loss
Array(0.94856, dtype=float32)
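A softmax cross-entropy on logits is usually computed through a numerically stable log-softmax (the log-sum-exp trick). The NumPy sketch below, with hypothetical helper names and not optax's actual source, reproduces the value above and shows that adding a constant to every logit in a row leaves the loss unchanged.

```python
import numpy as np

def log_softmax(z):
    """Numerically stable log-softmax along the last axis (log-sum-exp trick)."""
    z_max = z.max(axis=-1, keepdims=True)
    return z - z_max - np.log(np.sum(np.exp(z - z_max), axis=-1, keepdims=True))

def softmax_cross_entropy(logits, labels):
    """Mean negative log-likelihood of integer labels under softmax(logits)."""
    log_probs = log_softmax(logits)
    return float(-log_probs[np.arange(len(labels)), labels].mean())

preds = np.array([[0.3, 0.6, 0.1], [0.2, 0.3, 0.5]])
labels = np.array([0, 2])

# Rows of preds sum to 1, so softmax(log(preds)) == preds and the loss matches
# the probability-space cross-entropy, about 0.9486.
print(softmax_cross_entropy(np.log(preds), labels))
# Shifting every logit in a row by the same constant leaves softmax unchanged.
print(softmax_cross_entropy(np.log(preds) + 5.0, labels))
```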

Recap

  • Entropy: expected surprise; KL: extra bits; cross-entropy: KL + entropy.
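  • Standard deep-learning classification minimizes cross-entropy, which is equivalent to minimizing D_{KL}(\hat{p} \| q) from the empirical data distribution \hat{p} to the model q.
  • Mutual information appears in InfoNCE and other contrastive-learning objectives and in the information bottleneck (IB) principle, among other places.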