Information Theory

Information Theory for Learning

Information theory (Shannon, 1948) supplies the vocabulary for the loss functions and representation-learning objectives used in deep learning:

  • Self-information I(x) = -\log p(x): the surprise of observing x.
  • Entropy H(X) = -\mathbb{E}[\log p(X)]: the expected surprise of a distribution.
  • Cross-entropy H(p, q) = -\mathbb{E}_{p}[\log q]: what we minimize during classification.
  • KL divergence D_{KL}(p \| q) = H(p, q) - H(p): the “extra bits” needed to encode samples from p with a code optimized for q.
  • Mutual information I(X; Y): how much knowing X reduces uncertainty about Y.

Cross-entropy loss equals the KL divergence between the true and predicted distributions up to the additive constant H(p), which does not depend on the model.
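
The identity behind that statement, H(p, q) = H(p) + D_{KL}(p \| q), can be checked numerically. The sketch below uses plain Python rather than PyTorch; the helper names and the two distributions p and q are made-up values chosen only for illustration.

```python
import math

def entropy_bits(p):
    # H(p) = -sum_x p(x) log2 p(x), skipping zero-probability outcomes
    return -sum(pi * math.log2(pi) for pi in p if pi > 0)

def cross_entropy_bits(p, q):
    # H(p, q) = -sum_x p(x) log2 q(x)
    return -sum(pi * math.log2(qi) for pi, qi in zip(p, q) if pi > 0)

def kl_bits(p, q):
    # D_KL(p || q) = sum_x p(x) log2(p(x) / q(x))
    return sum(pi * math.log2(pi / qi) for pi, qi in zip(p, q) if pi > 0)

p = [0.7, 0.2, 0.1]   # hypothetical "true" distribution
q = [0.5, 0.3, 0.2]   # hypothetical model distribution

lhs = cross_entropy_bits(p, q)
rhs = entropy_bits(p) + kl_bits(p, q)
print(math.isclose(lhs, rhs))  # True
```

Because H(p) is fixed by the data, lowering the cross-entropy with respect to q lowers the KL term by the same amount.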

Self-information

Rare events carry more information than common ones. The log base only chooses the unit: bits for base 2, nats for base e.

import torch
from torch.nn import NLLLoss

def nansum(x):
    # Sum only the non-NaN entries, so that 0 * log(0) = NaN terms are dropped.
    # Recent PyTorch releases also provide this as `torch.nansum`.
    return x[~torch.isnan(x)].sum()

def self_information(p):
    return -torch.log2(torch.tensor(p)).item()

self_information(1 / 64)
6.0
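
To make the role of the log base concrete, here is a small plain-Python sketch. The `base` parameter is an illustrative addition that the original `self_information` does not have. It also checks that information adds for independent events, since their probabilities multiply.

```python
import math

def self_information(p, base=2.0):
    # -log_base(p): base 2 gives bits, base e gives nats
    return -math.log(p) / math.log(base)

bits = self_information(1 / 64)            # about 6 bits
nats = self_information(1 / 64, math.e)    # the same event measured in nats

# An independent fair coin flip (p = 1/2) adds one more bit
both = self_information((1 / 64) * (1 / 2))

print(math.isclose(nats, bits * math.log(2)))  # True
print(math.isclose(both, bits + 1.0))          # True
```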

Entropy

H(X) = -\sum_x p(x) \log p(x). Maximum at uniform distribution; zero at point masses:

def entropy(p):
    # Per-outcome terms -p(x) log2 p(x); a zero probability yields NaN
    terms = -p * torch.log2(p)
    # `nansum` sums only the non-NaN terms
    out = nansum(terms)
    return out

entropy(torch.tensor([0.1, 0.5, 0.1, 0.3]))
tensor(1.6855)
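
The two extremes mentioned above can be checked directly. This is a plain-Python sketch; `entropy_bits` is an illustrative helper and the three distributions are example values.

```python
import math

def entropy_bits(p):
    # Skip zero-probability outcomes, using the convention 0 * log 0 = 0
    return -sum(pi * math.log2(pi) for pi in p if pi > 0)

uniform = [0.25] * 4
skewed = [0.1, 0.5, 0.1, 0.3]
point_mass = [1.0, 0.0, 0.0, 0.0]

h_uniform = entropy_bits(uniform)     # log2(4) = 2 bits, the maximum for 4 outcomes
h_skewed = entropy_bits(skewed)       # about 1.6855 bits
h_point = entropy_bits(point_mass)    # 0 bits: the outcome is certain

print(h_point < h_skewed < h_uniform)  # True
```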

Joint and conditional entropy

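The joint entropy H(X, Y) measures the uncertainty of the pair, and the conditional entropy H(Y \mid X) measures what remains uncertain about Y once X is known. The two are linked by the chain rule H(X, Y) = H(X) + H(Y \mid X):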

def joint_entropy(p_xy):
    # Same formula as `entropy`, applied to every cell of the joint table
    joint_ent = -p_xy * torch.log2(p_xy)
    # `nansum` sums only the non-NaN terms
    out = nansum(joint_ent)
    return out

joint_entropy(torch.tensor([[0.1, 0.5], [0.1, 0.3]]))
tensor(1.6855)

def conditional_entropy(p_xy, p_x):
    # p_x broadcasts along the last dimension, so columns of p_xy index X
    p_y_given_x = p_xy/p_x
    cond_ent = -p_xy * torch.log2(p_y_given_x)
    # `nansum` sums only the non-NaN terms
    out = nansum(cond_ent)
    return out

# The joint sums to 1 and its column sums match p_x
conditional_entropy(torch.tensor([[0.1, 0.5], [0.1, 0.3]]),
                    torch.tensor([0.2, 0.8]))
tensor(0.9635)
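
The chain rule can be checked numerically. The sketch below is plain Python, reuses the joint table from the joint-entropy example, and assumes that columns index X and rows index Y; the helper name `H` is illustrative.

```python
import math

def H(probs):
    # Entropy in bits of a flat list of probabilities
    return -sum(p * math.log2(p) for p in probs if p > 0)

# Assumed layout: rows index Y, columns index X
p_xy = [[0.1, 0.5],
        [0.1, 0.3]]
p_x = [sum(col) for col in zip(*p_xy)]   # column sums: marginal of X

h_joint = H([p for row in p_xy for p in row])
h_x = H(p_x)

# H(Y | X) = -sum_{x,y} p(x, y) log2 p(y | x), with p(y | x) = p(x, y) / p(x)
h_y_given_x = -sum(
    p * math.log2(p / p_x[j])
    for row in p_xy
    for j, p in enumerate(row)
    if p > 0
)

print(math.isclose(h_joint, h_x + h_y_given_x))  # True
```

Deriving the marginal from the table itself guarantees that the joint and the marginal are consistent, which the identity requires.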

Mutual information

I(X; Y) = H(X) - H(X \mid Y) = H(X) + H(Y) - H(X, Y) — how much X and Y share. Symmetric, non-negative, zero iff independent:

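def mutual_information(p_xy, p_x, p_y):
    # p_x has shape (n_x,) and p_y has shape (n_y, 1), so p_x * p_y
    # broadcasts to the product of marginals in the same layout as p_xy
    p = p_xy / (p_x * p_y)
    mutual = p_xy * torch.log2(p)
    # `nansum` sums only the non-NaN terms
    out = nansum(mutual)
    return out

# Columns index X (sums 0.2, 0.8); rows index Y (sums 0.6, 0.4)
mutual_information(torch.tensor([[0.1, 0.5], [0.1, 0.3]]),
                   torch.tensor([0.2, 0.8]), torch.tensor([[0.6], [0.4]]))
tensor(0.0074)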
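
The entropy form of the definition, I(X; Y) = H(X) + H(Y) - H(X, Y), gives an independent way to compute mutual information. The plain-Python sketch below derives both marginals from the joint table, so they cannot disagree with it; the function names and the second table are illustrative.

```python
import math

def H(probs):
    # Entropy in bits of a flat list of probabilities
    return -sum(p * math.log2(p) for p in probs if p > 0)

def mutual_information_bits(p_xy):
    # I(X; Y) = H(X) + H(Y) - H(X, Y), with marginals taken from the table
    p_rows = [sum(row) for row in p_xy]
    p_cols = [sum(col) for col in zip(*p_xy)]
    return H(p_rows) + H(p_cols) - H([p for row in p_xy for p in row])

dependent = [[0.1, 0.5],
             [0.1, 0.3]]
# Outer product of the marginals [0.6, 0.4] and [0.2, 0.8]:
# independent by construction
independent = [[0.6 * 0.2, 0.6 * 0.8],
               [0.4 * 0.2, 0.4 * 0.8]]

mi_dep = mutual_information_bits(dependent)     # small but positive
mi_ind = mutual_information_bits(independent)   # zero up to rounding

print(mi_dep > 0, abs(mi_ind) < 1e-12)  # True True
```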

KL divergence

D_{KL}(p \| q) = \sum_x p(x) \log \frac{p(x)}{q(x)} \ge 0. Asymmetric (not a metric); zero iff p = q:

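def kl_divergence(p, q):
    # p and q are expected to be probability vectors over the same outcomes
    kl = p * torch.log2(p / q)
    out = nansum(kl)
    # For valid distributions the sum is already non-negative,
    # so abs() only matters for unnormalized inputs
    return out.abs().item()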
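
A small discrete case shows the asymmetry and the zero at p = q. This is a plain-Python sketch; `kl_bits` and the two distributions are illustrative values, not part of the original code.

```python
import math

def kl_bits(p, q):
    # D_KL(p || q) in bits; terms with p(x) = 0 contribute nothing
    return sum(pi * math.log2(pi / qi) for pi, qi in zip(p, q) if pi > 0)

p = [0.5, 0.4, 0.1]
q = [0.8, 0.1, 0.1]

forward = kl_bits(p, q)   # about 0.461 bits
reverse = kl_bits(q, p)   # about 0.342 bits
same = kl_bits(p, p)      # exactly 0

print(forward > reverse > same)  # True
```

Swapping the arguments changes which distribution weights the log-ratio, which is why the two directions disagree.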

Examples

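The example draws 10,000 samples each from N(0, 1), N(-1, 1), and N(1, 1), sorts them, and passes the sorted samples to kl_divergence. The inputs are raw samples rather than normalized probability vectors, so the outputs are not KL divergences in the strict sense; read them only as a rough comparison of how the score changes when the arguments are swapped.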

torch.manual_seed(1)

tensor_len = 10000
p = torch.normal(0, 1, (tensor_len, ))
q1 = torch.normal(-1, 1, (tensor_len, ))
q2 = torch.normal(1, 1, (tensor_len, ))

p = torch.sort(p)[0]
q1 = torch.sort(q1)[0]
q2 = torch.sort(q2)[0]
kl_pq1 = kl_divergence(p, q1)
kl_pq2 = kl_divergence(p, q2)
similar_percentage = abs(kl_pq1 - kl_pq2) / ((kl_pq1 + kl_pq2) / 2) * 100

kl_pq1, kl_pq2, similar_percentage
(8582.0341796875, 8828.3095703125, 2.8290698237936858)

kl_q2p = kl_divergence(q2, p)
differ_percentage = abs(kl_q2p - kl_pq2) / ((kl_q2p + kl_pq2) / 2) * 100

kl_q2p, differ_percentage
(14130.125, 46.18621024399691)
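
For reference, the sampled example above can be compared with the closed-form KL divergence between two univariate Gaussians, which is a standard result. The function name `kl_gaussians_nats` is an illustrative choice, not a library call.

```python
import math

def kl_gaussians_nats(mu1, sigma1, mu2, sigma2):
    # D_KL( N(mu1, sigma1^2) || N(mu2, sigma2^2) ), in nats
    return (math.log(sigma2 / sigma1)
            + (sigma1 ** 2 + (mu1 - mu2) ** 2) / (2 * sigma2 ** 2)
            - 0.5)

kl_p_q1 = kl_gaussians_nats(0.0, 1.0, -1.0, 1.0)
kl_p_q2 = kl_gaussians_nats(0.0, 1.0, 1.0, 1.0)
kl_q2_p = kl_gaussians_nats(1.0, 1.0, 0.0, 1.0)
kl_p_q2_bits = kl_p_q2 / math.log(2)   # convert nats to bits

# Unequal variances break the symmetry
kl_wide_narrow = kl_gaussians_nats(0.0, 2.0, 0.0, 1.0)
kl_narrow_wide = kl_gaussians_nats(0.0, 1.0, 0.0, 2.0)

print(kl_p_q1, kl_p_q2, kl_q2_p)  # 0.5 0.5 0.5
```

All three mean-shift cases come out to 0.5 nats (about 0.72 bits), so for equal-variance Gaussians the divergence happens to be symmetric. The large magnitudes and the apparent asymmetry in the sampled example are therefore most likely an artifact of passing raw samples to kl_divergence. Genuine asymmetry appears once the variances differ, as the last two values show.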

Formal definitions

Entropy, cross-entropy, and KL differ by which distribution supplies the expectation and which log-probability is scored.

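def cross_entropy(y_hat, y):
    # Select the predicted probability of the true class in each row;
    # natural log, so the loss is measured in nats
    ce = -torch.log(y_hat[range(len(y_hat)), y])
    return ce.mean()

labels = torch.tensor([0, 2])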
preds = torch.tensor([[0.3, 0.6, 0.1], [0.2, 0.3, 0.5]])

cross_entropy(preds, labels)
tensor(0.9486)

Cross-entropy in classification

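In multi-class classification, the data distribution is one-hot on the true class and the model distribution is a softmax output. Cross-entropy then reduces to the negative log-likelihood (NLL) of the true class, summed over examples (the code averages instead, which only rescales the loss):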

\mathcal{L} = -\sum_i \log q(y_i \mid x_i).

# PyTorch's `nn.CrossEntropyLoss` combines `nn.LogSoftmax()` and `nn.NLLLoss()`.
# `preds` already holds probabilities, so we take their log and apply
# `NLLLoss` directly.
nll_loss = NLLLoss()
loss = nll_loss(torch.log(preds), labels)
loss
tensor(0.9486)
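
In practice a model emits logits rather than probabilities, and the log-softmax is computed with the log-sum-exp trick for numerical stability. The plain-Python sketch below illustrates that idea under stated assumptions: the function names are hypothetical, and the logits are chosen as the logs of the probabilities used above so that the softmax recovers them.

```python
import math

def log_softmax(logits):
    # Subtract the max before exponentiating so exp() cannot overflow
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_z for z in logits]

def cross_entropy_from_logits(batch_logits, labels):
    # Mean negative log-probability of the true class, in nats
    losses = [-log_softmax(z)[y] for z, y in zip(batch_logits, labels)]
    return sum(losses) / len(losses)

# Hypothetical logits: logs of the probabilities from the example above
logits = [[math.log(0.3), math.log(0.6), math.log(0.1)],
          [math.log(0.2), math.log(0.3), math.log(0.5)]]
loss = cross_entropy_from_logits(logits, [0, 2])   # about 0.9486

# Large logits stay finite thanks to the max subtraction
stable = log_softmax([1000.0, 1000.0])
```

Working in log space avoids computing a probability and then taking its log, which is where underflow and overflow would otherwise creep in.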

Recap

  • Entropy: expected surprise; KL: extra bits; cross-entropy: KL + entropy.
  • Most DL classification = minimizing cross-entropy = minimizing KL to the empirical distribution.
  • Mutual information appears in InfoNCE / contrastive learning and in the information bottleneck (IB) principle, among other places.