%matplotlib inline
from d2l import jax as d2l
import math
import jax
from jax import numpy as jnp
import numpy as np
import tensorflow as tf
d2l.use_svg_display()
Naive Bayes: the simplest probabilistic classifier. Apply Bayes’ rule:
P(y \mid \mathbf{x}) \propto P(y) \prod_i P(x_i \mid y).
The “naive” part is the assumption that features are conditionally independent given the class. This is wrong in general (pixels of an image are obviously correlated), but the model is fast, requires little data, and is a useful starting point.
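To make the factorization concrete, here is a toy sketch with two classes and three binary features. All numbers and the names `prior`, `p_on`, and `x` are made up for illustration; NumPy is used for brevity.

```python
import numpy as np

# Hypothetical toy model: 2 classes, 3 binary features (values are invented).
prior = np.array([0.5, 0.5])            # P(y)
p_on = np.array([[0.9, 0.8, 0.1],       # P(x_i = 1 | y = 0)
                 [0.2, 0.3, 0.7]])      # P(x_i = 1 | y = 1)
x = np.array([1, 1, 0])                 # observed feature vector

# Bernoulli likelihood of each feature under each class, shape (2, 3).
per_feature = np.where(x == 1, p_on, 1 - p_on)

# Naive Bayes: multiply the per-feature terms, then weight by the prior.
joint = prior * per_feature.prod(axis=1)   # proportional to P(y | x)
posterior = joint / joint.sum()            # normalize over the classes
```

The normalization step is only needed if actual probabilities are wanted; the predicted class is the argmax of `joint` either way.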
This deck applies it to MNIST digit classification with binarized pixels.
Binarize pixels so each pixel can be modeled as a Bernoulli random variable conditioned on the digit class.
Inspect the binarized digits before fitting: the class templates are recognizable, but neighboring pixels are clearly dependent.
(train_images, train_labels), (test_images, test_labels) = (
    tf.keras.datasets.mnist.load_data())
train_images = np.floor(train_images / 128).astype(np.float32)
test_images = np.floor(test_images / 128).astype(np.float32)
train_labels = train_labels.astype(np.int32)
test_labels = test_labels.astype(np.int32)
For each class y and pixel i, estimate P(x_i = 1 \mid y) from the training set, with Laplace smoothing to avoid zeros.
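A common form of the smoothed estimate is add-one (Laplace) smoothing. The exact constants are an assumption here, since the source does not show them; with two outcomes per pixel, one pseudo-count is added for each outcome:

```latex
\hat{P}(x_i = 1 \mid y) = \frac{n_{iy} + 1}{n_y + 2}
```

Here n_y is the number of training images with label y and n_{iy} is the number of those images in which pixel i is on. Both symbols are introduced only for this illustration.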
(np.int32(4), numpy.int32)
((28, 28, 28), (28,))
Training is counting, not gradient descent: estimate class priors and per-pixel likelihoods directly from the labeled examples.
Array([0.09871667, 0.11236667, 0.0993 , 0.10218333, 0.09736667,
0.09035 , 0.09863333, 0.10441667, 0.09751666, 0.09915 ], dtype=float32)
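The counting step described above can be sketched on toy data. This is a minimal sketch, not the deck's own cell: the 2x2 "images", the labels, and the helper names `n_x` and `n_y` are invented, and the add-one smoothing constants are an assumption. NumPy is used; the same operations exist in `jax.numpy`.

```python
import numpy as np

# Hypothetical toy data: 6 binarized "images" of 2x2 pixels, 3 classes.
images = np.array([
    [[1, 0], [0, 0]],
    [[1, 1], [0, 0]],
    [[0, 1], [0, 1]],
    [[0, 1], [1, 1]],
    [[0, 0], [1, 1]],
    [[1, 0], [1, 1]],
], dtype=np.float32)
labels = np.array([0, 0, 1, 1, 2, 2])
num_classes = 3

# Class priors: fraction of examples carrying each label.
n_y = np.bincount(labels, minlength=num_classes).astype(np.float32)
P_y = n_y / n_y.sum()

# Per-class pixel counts: how often each pixel is on within each class.
n_x = np.zeros((num_classes,) + images.shape[1:], dtype=np.float32)
for y in range(num_classes):
    n_x[y] = images[labels == y].sum(axis=0)

# Add-one smoothing (assumed form): one pseudo-count for "on", one for "off".
P_xy = (n_x + 1) / (n_y + 2).reshape(num_classes, 1, 1)
```

Because of the pseudo-counts, every entry of `P_xy` lies strictly between 0 and 1, so taking logs later is safe.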
Training stores only the class priors and per-class pixel probabilities. Prediction multiplies the prior by the per-pixel likelihood terms; in practice this is done as a sum in log-space.
Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)
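An all-zero result like the one above is what floating-point underflow looks like: a product of 784 factors below one quickly falls out of the float32 range. A minimal sketch with made-up factors (every pixel contributing 0.5, an assumption chosen only for illustration):

```python
import numpy as np

# Hypothetical likelihood terms: 784 pixels, each contributing a factor of 0.5.
terms = np.full(784, 0.5, dtype=np.float32)

naive = terms.prod()            # 0.5**784 is far below the float32 range
stable = np.log(terms).sum()    # same quantity in log-space: 784 * log(0.5)
```

The naive product collapses to zero, while the log-space sum stays a moderate negative number that can still be compared across classes.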
Summing logs instead of multiplying probabilities avoids underflow:
log_P_xy = jnp.log(P_xy)
log_P_xy_neg = jnp.log(1 - P_xy)
log_P_y = jnp.log(P_y)
def bayes_pred_stable(x):
    x = jnp.expand_dims(x, axis=0)  # (28, 28) -> (1, 28, 28)
    # Per-pixel Bernoulli log-likelihood under each class: (10, 28, 28)
    p_xy = log_P_xy * x + log_P_xy_neg * (1 - x)
    p_xy = p_xy.reshape(10, -1).sum(axis=1)  # log p(x|y)
    return p_xy + log_P_y  # log p(x|y) + log p(y)
py = bayes_pred_stable(image)
py
Array([-268.97256, -301.70444, -245.19516, -218.87387, -193.457 ,
-206.09085, -292.52264, -114.62566, -220.33133, -163.17844], dtype=float32)
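The log scores are unnormalized, but they are enough to pick a class, and they can be turned into a posterior by subtracting the maximum before exponentiating. A sketch using the ten values printed above; the names `pred`, `shifted`, and `posterior` are introduced here for illustration:

```python
import numpy as np

# Log scores copied from the output above: log p(x|y) + log p(y) per class.
py = np.array([-268.97256, -301.70444, -245.19516, -218.87387, -193.457,
               -206.09085, -292.52264, -114.62566, -220.33133, -163.17844])

pred = int(py.argmax())         # class with the largest log score

# Subtract the max before exponentiating so exp() stays in range.
shifted = py - py.max()
posterior = np.exp(shifted) / np.exp(shifted).sum()
```

With these scores the prediction is class 7, and nearly all posterior mass lands on it because the runner-up is almost 50 log units lower. Such overconfident posteriors are typical when many dependent pixels are treated as independent evidence.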
The accuracy is useful mostly as a sanity check: on images, the conditional-independence assumption leaves visible performance on the table.
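One way the accuracy check could be computed is to score a whole batch at once and compare against the labels. This is a sketch under assumptions: `predict_batch` is a hypothetical helper, and the two-class model and 1x2 images are invented so the example stays self-contained.

```python
import numpy as np

def predict_batch(images, log_P_xy, log_P_xy_neg, log_P_y):
    """Hypothetical helper: most likely class for each binarized image.

    images: (n, h, w) array of 0/1 pixels; the log tables have shape (c, h, w).
    """
    x = images[:, None]                                   # (n, 1, h, w)
    scores = log_P_xy * x + log_P_xy_neg * (1 - x)        # (n, c, h, w)
    scores = scores.reshape(len(images), log_P_y.shape[0], -1).sum(axis=2)
    return (scores + log_P_y).argmax(axis=1)              # (n,)

# Invented 2-class model on 1x2 images: class 0 favors the left pixel.
P_xy = np.array([[[0.9, 0.1]], [[0.1, 0.9]]])
P_y = np.array([0.5, 0.5])
images = np.array([[[1, 0]], [[0, 1]], [[0, 1]]], dtype=np.float64)
labels = np.array([0, 1, 0])    # the last label disagrees with the model

preds = predict_batch(images, np.log(P_xy), np.log(1 - P_xy), np.log(P_y))
accuracy = float((preds == labels).mean())
```

On the toy data two of the three predictions match, so `accuracy` is 2/3. On MNIST the same comparison would run over the test images and labels.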