%matplotlib inline
from d2l import torch as d2l
import math
import torch
import torchvision
d2l.use_svg_display()
Naive Bayes is one of the simplest probabilistic classifiers. Combining Bayes' rule with the independence assumption described below gives:
P(y \mid \mathbf{x}) \propto P(y) \prod_i P(x_i \mid y).
The “naive” part is the assumption that features are conditionally independent given the class. That is wrong in general (neighboring pixels of an image are obviously correlated), but the model is fast, needs little data, and is a useful starting point.
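A minimal plain-Python sketch of the formula above, with two classes and three binary features. The priors and likelihoods are made-up numbers, and naive_bayes_posterior is a hypothetical helper for illustration, not part of d2l.

```python
def naive_bayes_posterior(x, prior, p_on):
    """Posterior P(y | x) for binary features under the naive Bayes assumption.

    x     -- list of 0/1 feature values
    prior -- prior[y] = P(y)
    p_on  -- p_on[y][i] = P(x_i = 1 | y)
    """
    scores = []
    for y, p_y in enumerate(prior):
        score = p_y
        for x_i, p in zip(x, p_on[y]):
            # Conditional independence: one factor per feature, multiplied together.
            score *= p if x_i == 1 else 1 - p
        scores.append(score)
    total = sum(scores)  # P(x): the normalizing constant hidden by the proportionality
    return [s / total for s in scores]


# Made-up numbers: two classes, three binary features.
prior = [0.5, 0.5]
p_on = [[0.9, 0.2, 0.5],   # class 0
        [0.1, 0.7, 0.5]]   # class 1
posterior = naive_bayes_posterior([1, 0, 1], prior, p_on)
print(posterior)
```

With these numbers class 0 gets a posterior of about 0.96. The third feature is equally likely under both classes, so it multiplies both scores by the same factor and does not change the ranking.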
This deck applies it to MNIST digit classification with binarized pixels.
Binarize pixels so each pixel can be modeled as a Bernoulli random variable conditioned on the digit class.
Inspect the binarized digits before fitting: the class templates are recognizable, but neighboring pixels are clearly dependent.
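# ToTensor scales 8-bit pixels to [0, 1]; floor(x * 255 / 128) thresholds at
# intensity 128 (half of the 0-255 range); squeeze drops the channel dimension.
data_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    lambda x: torch.floor(x * 255 / 128).squeeze(dim=0)
])
mnist_train = torchvision.datasets.MNIST(
    root='./temp', train=True, transform=data_transform, download=True)
mnist_test = torchvision.datasets.MNIST(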
    root='./temp', train=False, transform=data_transform, download=True)
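On the original integer scale, the binarization above amounts to a simple threshold: mathematically, floor((p / 255) * 255 / 128) = p // 128 for an 8-bit intensity p. The sketch below uses that integer form to avoid float rounding right at 128; binarize is a hypothetical helper.

```python
def binarize(pixels):
    """Map 8-bit intensities (0-255) to {0, 1}: 1 iff intensity >= 128.

    Integer form of floor(x * 255 / 128) with x = p / 255, the ToTensor-scaled value.
    """
    return [p // 128 for p in pixels]


row = [0, 37, 127, 128, 200, 255]
print(binarize(row))  # [0, 0, 0, 1, 1, 1]
```

Each binarized pixel is then a 0/1 outcome, which is what lets it be modeled as a Bernoulli variable per class.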
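For each class y and pixel i, estimate P(x_i = 1 \mid y) from the training set. With Laplace (add-one) smoothing, so that no estimate is exactly 0 or 1:
P(x_i = 1 \mid y) = \frac{n_{iy} + 1}{n_y + 2},
where n_y is the number of training images of class y and n_{iy} is the number of those images with pixel i set to 1.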
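A small worked example of the smoothed estimate, using exact fractions and hypothetical counts for one pixel in a class with 5 training images. smoothed_p_on is an illustrative helper, not a d2l function.

```python
from fractions import Fraction


def smoothed_p_on(n_on, n_class):
    """Add-one (Laplace) estimate of P(x_i = 1 | y) for a binary pixel.

    n_on    -- training images of class y with pixel i set to 1
    n_class -- training images of class y
    """
    return Fraction(n_on + 1, n_class + 2)


# Hypothetical counts for one pixel in a class with 5 training images.
n_class = 5
for n_on in (0, 2, 5):
    raw = Fraction(n_on, n_class)          # unsmoothed frequency
    smooth = smoothed_p_on(n_on, n_class)  # always strictly between 0 and 1
    print(f"on in {n_on}/{n_class}: raw {raw}, smoothed {smooth}")
```

The raw frequencies 0 and 1 would make log P(x_i = 1 | y) or log(1 - P(x_i = 1 | y)) infinite, so a single unusual pixel could veto a class; the smoothed values 1/7 and 6/7 keep both logs finite.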
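Quick checks on the loaded data: a label is a plain Python int (the sample shown is 4), and a batch of 28 binarized images stacks into a tensor of shape (28, 28, 28) with 28 labels.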
Training is counting, not gradient descent: estimate class priors and per-pixel likelihoods directly from the labeled examples.
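A minimal counting-based fit, sketched in plain Python on flattened 0/1 images. The function name fit_bernoulli_nb and the toy data are hypothetical; the likelihood uses the same add-one formula as above.

```python
def fit_bernoulli_nb(images, labels, num_classes):
    """Estimate P(y) and smoothed P(x_i = 1 | y) by counting.

    images -- list of flattened 0/1 pixel lists, all the same length
    labels -- list of integer class labels in range(num_classes)
    """
    num_pixels = len(images[0])
    n_y = [0] * num_classes                                # images per class
    n_on = [[0] * num_pixels for _ in range(num_classes)]  # pixel-on counts per class
    for x, y in zip(images, labels):
        n_y[y] += 1
        for i, x_i in enumerate(x):
            n_on[y][i] += x_i
    prior = [n / len(labels) for n in n_y]
    # Add-one smoothing: (n_iy + 1) / (n_y + 2).
    p_on = [[(n_on[y][i] + 1) / (n_y[y] + 2) for i in range(num_pixels)]
            for y in range(num_classes)]
    return prior, p_on


# Toy data: four 2-pixel "images", two classes (hypothetical).
images = [[1, 0], [1, 1], [0, 1], [0, 1]]
labels = [0, 0, 1, 1]
prior, p_on = fit_bernoulli_nb(images, labels, num_classes=2)
print(prior, p_on)
```

One pass over the data is all the training there is; no parameters are tuned iteratively.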
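Estimated class priors P(y) for digits 0 through 9, all between about 0.09 and 0.11, i.e. close to uniform:
tensor([0.0987, 0.1124, 0.0993, 0.1022, 0.0974, 0.0904, 0.0986, 0.1044, 0.0975,
        0.0992])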
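Training stores only the class priors and the per-class pixel probabilities. Prediction multiplies the prior by one likelihood term per pixel (784 terms for a 28 × 28 image). Evaluated directly as a product, the result underflows to zero for every class on a test image:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])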
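A stand-alone illustration of the failure, in plain Python with made-up per-pixel factors for two competing classes (product_and_log is a hypothetical helper). Python floats are double precision; PyTorch's default float32 underflows even sooner.

```python
import math


def product_and_log(p, n):
    """Multiply n copies of p directly, and also sum their logs."""
    product, log_score = 1.0, 0.0
    for _ in range(n):
        product *= p              # direct product: underflows for large n
        log_score += math.log(p)  # log-space sum: stays finite
    return product, log_score


n_pixels = 28 * 28
# Made-up per-pixel likelihood factors for two competing classes.
for name, p in [("class A", 0.1), ("class B", 0.2)]:
    product, log_score = product_and_log(p, n_pixels)
    print(name, product, round(log_score, 1))
```

Both direct products come out as 0.0, so the classes cannot be ranked; in log-space class B is clearly ahead (about -1261.8 versus -1805.2).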
Summing log-probabilities instead of multiplying probabilities avoids the underflow:
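log_P_xy = torch.log(P_xy)          # log P(x_i = 1 | y), shape (10, 28, 28)
log_P_xy_neg = torch.log(1 - P_xy)  # log P(x_i = 0 | y); finite thanks to smoothing
log_P_y = torch.log(P_y)            # log P(y), shape (10,)

def bayes_pred_stable(x):
    x = x.unsqueeze(0)  # (28, 28) -> (1, 28, 28), broadcasts over the 10 classes
    p_xy = log_P_xy * x + log_P_xy_neg * (1 - x)  # per-pixel log-likelihoods
    p_xy = p_xy.reshape(10, -1).sum(axis=1)  # log p(x|y)
    return p_xy + log_P_y  # unnormalized log-posterior: log p(x|y) + log p(y)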
py = bayes_pred_stable(image)
py
tensor([-268.9725, -301.7044, -245.1951, -218.8738, -193.4570, -206.0909,
        -292.5226, -114.6257, -220.3313, -163.1784])
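The largest entry (index 7) is the predicted digit. To turn these unnormalized log-scores into posterior probabilities, a minimal sketch using the log-sum-exp trick on the values printed above; log_softmax here is a hand-written helper (PyTorch offers torch.logsumexp and torch.log_softmax for tensors).

```python
import math


def log_softmax(log_scores):
    """Normalize unnormalized log-posteriors with the log-sum-exp trick."""
    m = max(log_scores)  # shift so the largest term becomes exp(0) = 1
    log_z = m + math.log(sum(math.exp(s - m) for s in log_scores))
    return [s - log_z for s in log_scores]


# Unnormalized log p(x|y) + log p(y) values printed above for one test image.
py = [-268.9725, -301.7044, -245.1951, -218.8738, -193.4570, -206.0909,
      -292.5226, -114.6257, -220.3313, -163.1784]
pred = max(range(len(py)), key=py.__getitem__)
posterior = [math.exp(lp) for lp in log_softmax(py)]
print(pred, posterior[pred])
```

Shifting by the maximum matters in single precision: exp(-114.6) underflows in float32, so normalizing the raw exponentials would give 0/0. The resulting posterior for digit 7 is essentially 1.0, an extreme confidence typical of naive Bayes, which counts correlated pixels as if they were independent evidence.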
Test-set accuracy is useful mostly as a sanity check: on images, the conditional-independence assumption leaves noticeable accuracy on the table compared with classifiers that can model correlations between pixels.
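Computing accuracy is just log-space prediction plus counting matches. A minimal plain-Python sketch with a hypothetical 2-pixel, 2-class model and a tiny labeled test set; predict and accuracy are illustrative helpers, not d2l functions.

```python
import math


def predict(x, log_prior, p_on):
    """Return argmax_y of log P(y) + sum_i log P(x_i | y) for a 0/1 pixel list x."""
    best_y, best_score = None, -math.inf
    for y, lp in enumerate(log_prior):
        score = lp + sum(math.log(p if x_i else 1 - p)
                         for x_i, p in zip(x, p_on[y]))
        if score > best_score:
            best_y, best_score = y, score
    return best_y


def accuracy(images, labels, log_prior, p_on):
    """Fraction of images whose predicted class matches the label."""
    correct = sum(predict(x, log_prior, p_on) == y for x, y in zip(images, labels))
    return correct / len(labels)


# Hypothetical model and test set.
log_prior = [math.log(0.5), math.log(0.5)]
p_on = [[0.75, 0.5], [0.25, 0.75]]
test_images = [[1, 0], [0, 1], [1, 1], [0, 0]]
test_labels = [0, 1, 1, 1]
print(accuracy(test_images, test_labels, log_prior, p_on))
```

Here the image [1, 1] is predicted as class 0 but labeled 1, so accuracy is 0.75. For a whole test set it is cheaper to stack N flattened images into a 0/1 matrix X and score every class at once as X (log θ)ᵀ + (1 - X)(log(1 - θ))ᵀ + log π, where θ holds P(x_i = 1 | y) and π holds P(y) (notation assumed here); in PyTorch that is two matrix multiplications followed by an argmax over the class dimension.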