
Deep Factorization Machines

DeepFM

DeepFM (Guo et al., 2017) combines a factorization machine and a deep MLP that share a single embedding table.

  • FM branch: a linear term plus pairwise bilinear interactions (same as the previous deck).
  • Deep branch: concatenate all field embeddings and feed them to an MLP, which captures high-order nonlinear interactions that the bilinear FM term misses.
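The FM branch's pairwise term does not need an explicit loop over field pairs. The minimal pure-Python sketch below (no PyTorch; the function names and the random 4-field, 3-factor embeddings are hypothetical) checks the identity sum_{i<j} <v_i, v_j> = 0.5 * sum_f [(sum_i v_{i,f})^2 - sum_i v_{i,f}^2], which turns O(m^2 k) work into O(m k) for m fields and k factors. It assumes one active value per categorical field, so every input value x_i is 1 and drops out of the products.

```python
import itertools
import random

def pairwise_naive(emb):
    """Sum of <v_i, v_j> over all field pairs i < j, O(m^2 k)."""
    return sum(
        sum(a * b for a, b in zip(emb[i], emb[j]))
        for i, j in itertools.combinations(range(len(emb)), 2)
    )

def pairwise_fm_trick(emb):
    """Same quantity via 0.5 * sum_f[(sum_i v_if)^2 - sum_i v_if^2], O(m k)."""
    k = len(emb[0])
    total = 0.0
    for f in range(k):
        col = [v[f] for v in emb]           # factor f across all fields
        total += sum(col) ** 2 - sum(c * c for c in col)
    return 0.5 * total

random.seed(0)
# Hypothetical: 4 fields, each with a 3-dimensional embedding.
emb = [[random.uniform(-1, 1) for _ in range(3)] for _ in range(4)]
print(pairwise_naive(emb), pairwise_fm_trick(emb))
```

Both numbers agree up to floating-point rounding; the second form is what the FM branch computes in practice.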

Final prediction: \sigma(\hat y^{(FM)} + \hat y^{(DNN)}), with both branches trained end to end. After 2017 this became a widely used template for CTR models: explicit interaction terms plus learned nonlinear feature mixing.

Architecture

Shared embeddings feed both the FM head and the deep MLP head:

DeepFM architecture: shared field embeddings feed both an FM branch and a deep MLP branch.

\hat y = \sigma(\hat y^{(FM)} + \hat y^{(DNN)})
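For one impression with m categorical fields, each with a single active value (so every input x_i = 1), the two terms can be written out explicitly. The notation is ours, not the slide's: \mathbf{e}_i \in \mathbb{R}^k is the shared embedding of field i's active value, w_{x_i} its first-order weight, and w_0 a global bias. The DeepFM class in the implementation also multiplies the first-order sum by the learnable weight of its linear_layer, whose bias plays the role of w_0.

```latex
\hat y^{(FM)} = w_0 + \sum_{i=1}^{m} w_{x_i}
  + \sum_{i=1}^{m}\sum_{j=i+1}^{m} \langle \mathbf{e}_i, \mathbf{e}_j \rangle
  = w_0 + \sum_{i=1}^{m} w_{x_i}
  + \frac{1}{2}\sum_{f=1}^{k}\Big[\Big(\sum_{i=1}^{m} e_{i,f}\Big)^2 - \sum_{i=1}^{m} e_{i,f}^2\Big]

\hat y^{(DNN)} = \mathrm{MLP}\big([\mathbf{e}_1; \mathbf{e}_2; \dots; \mathbf{e}_m]\big),
\qquad [\,\cdot\,;\,\cdot\,] \text{ denotes concatenation into an } mk\text{-vector}
```

The second form of the FM term is the one the forward pass computes as 0.5 * (square_of_sum - sum_of_square).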

Implementation

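import torch
from torch import nn

class DeepFM(nn.Module):
    def __init__(self, field_dims, num_factors, mlp_dims, drop_rate=0.1):
        super().__init__()
        num_inputs = int(sum(field_dims))
        # One embedding table shared by the FM branch and the deep branch
        self.embedding = nn.Embedding(num_inputs, num_factors)
        # First-order (linear) weight for every feature value
        self.fc = nn.Embedding(num_inputs, 1)
        self.linear_layer = nn.Linear(1, 1)
        # The MLP input is the concatenation of all field embeddings
        input_dim = self.embed_output_dim = len(field_dims) * num_factors
        mlp_layers = []
        for dim in mlp_dims:
            mlp_layers.append(nn.Linear(input_dim, dim))
            mlp_layers.append(nn.ReLU())
            mlp_layers.append(nn.Dropout(p=drop_rate))
            input_dim = dim
        mlp_layers.append(nn.Linear(input_dim, 1))
        self.mlp = nn.Sequential(*mlp_layers)

    def forward(self, x):
        # x: (batch, num_fields) feature indices into the shared table
        embed_x = self.embedding(x)                 # (batch, num_fields, num_factors)
        # FM pairwise term: 0.5 * [(sum of v)^2 - sum of v^2], summed over factors
        square_of_sum = embed_x.sum(dim=1) ** 2
        sum_of_square = (embed_x ** 2).sum(dim=1)
        inputs = embed_x.reshape(-1, self.embed_output_dim)
        x = self.linear_layer(self.fc(x).sum(dim=1)) \
            + 0.5 * (square_of_sum - sum_of_square).sum(dim=1, keepdim=True) \
            + self.mlp(inputs)
        # Logits of shape (batch, 1); the sigmoid is applied by BCEWithLogitsLoss
        return x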
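DeepFM allocates a single nn.Embedding with int(sum(field_dims)) rows, so each field's local ids have to be shifted into a disjoint block of rows before self.embedding(x) is called; otherwise ids from different fields would collide. The slides do not show where this happens, and we assume d2l.CTRDataset adds these offsets. The pure-Python sketch below shows the shift; field_offsets, to_global_ids, and the field sizes are hypothetical.

```python
from itertools import accumulate

def field_offsets(field_dims):
    """Start row of each field's block in one table of size sum(field_dims)."""
    return [0] + list(accumulate(field_dims))[:-1]

def to_global_ids(local_ids, field_dims):
    """Map per-field ids (0 <= id < field_dims[f]) to rows of the shared table."""
    offs = field_offsets(field_dims)
    for f, (i, d) in enumerate(zip(local_ids, field_dims)):
        if not 0 <= i < d:
            raise ValueError(f"id {i} out of range for field {f} of size {d}")
    return [i + o for i, o in zip(local_ids, offs)]

field_dims = [3, 5, 2]                         # hypothetical field cardinalities
print(field_offsets(field_dims))               # [0, 3, 8]
print(to_global_ids([2, 0, 1], field_dims))    # [2, 3, 9]
```

Because the offsets come from a cumulative sum, the largest global id is sum(field_dims) - 1, which matches the embedding table size.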

Same CTR pipeline as the FM deck; only the model changes:

import os
import torch
from torch import nn
# Assumes a d2l build that provides CTRDataset and the 'ctr' dataset entry
from d2l import torch as d2l

batch_size = 2048
data_dir = d2l.download_extract('ctr')
train_data = d2l.CTRDataset(os.path.join(data_dir, 'train.csv'))
# Reuse the training-set feature mapping so test ids line up with train ids
test_data = d2l.CTRDataset(os.path.join(data_dir, 'test.csv'),
                           feat_mapper=train_data.feat_mapper,
                           defaults=train_data.defaults)
field_dims = train_data.field_dims
train_iter = torch.utils.data.DataLoader(
    train_data, shuffle=True, drop_last=True, batch_size=batch_size,
    num_workers=d2l.get_dataloader_workers())
test_iter = torch.utils.data.DataLoader(
    test_data, shuffle=False, drop_last=True, batch_size=batch_size,
    num_workers=d2l.get_dataloader_workers())
devices = d2l.try_all_gpus()
net = DeepFM(field_dims, num_factors=10, mlp_dims=[30, 20, 10])

def init_weights(m):
    # Xavier-uniform init for all linear and embedding weights (biases keep defaults)
    if isinstance(m, (nn.Linear, nn.Embedding)):
        nn.init.xavier_uniform_(m.weight)

net.apply(init_weights)
lr, num_epochs = 0.01, 30
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
# The model outputs logits, so the sigmoid lives inside the loss;
# reduction='none' returns per-example losses
loss = nn.BCEWithLogitsLoss(reduction='none')
d2l.train_ch13(net, train_iter, test_iter, loss, optimizer, num_epochs, devices)

loss 0.012, train acc 0.995, test acc 0.939
260216.3 examples/sec on [device(type='cuda', index=0)]
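DeepFM's forward returns a raw score (logit) rather than \sigma(\cdot), and the sigmoid from the prediction formula is folded into nn.BCEWithLogitsLoss. The pure-Python sketch below shows why that pairing is preferred. The helper names are hypothetical, and the closed form in bce_with_logits is the standard numerically stable rewrite of binary cross-entropy on a logit, not a copy of PyTorch's internals.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def bce_naive(z, y):
    """-[y log p + (1 - y) log(1 - p)] with p = sigmoid(z)."""
    p = sigmoid(z)
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

def bce_with_logits(z, y):
    """Same loss computed directly from the logit z; stable for large |z|."""
    return max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))

for z, y in [(-2.0, 0.0), (0.5, 1.0), (3.0, 0.0)]:
    print(z, y, bce_naive(z, y), bce_with_logits(z, y))

# For z = 40 and y = 0, sigmoid(z) rounds to exactly 1.0 in float64, so
# bce_naive would call log(0) and fail; the logit form still returns about 40.
print(bce_with_logits(40.0, 0.0))
```

Both functions agree wherever the naive one is defined, but only the logit form survives confident, wrong predictions, which is why the model skips the final sigmoid.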

Recap

  • DeepFM = FM (low-order) + deep MLP (high-order), sharing the same embedding table.
  • Same input format as FM; one extra branch.
  • A member of the wide/deep interaction-model family: explicit low-order terms plus a nonlinear feature mixer and a sigmoid head.
  • Unlike retrieval architectures that score independently encoded users and items, DeepFM fuses all impression features before scoring one candidate.
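To put "one extra branch" in numbers, the sketch below counts the parameters of the DeepFM class from the implementation, split by component. The helper name and the 20-field, 1,000-values-per-field setting are hypothetical; only num_factors=10 and mlp_dims=[30, 20, 10] come from the training code.

```python
def deepfm_param_counts(field_dims, num_factors, mlp_dims):
    """Parameter counts for the DeepFM class, split by component."""
    n = sum(field_dims)
    counts = {
        "embedding": n * num_factors,  # shared by the FM and deep branches
        "fc": n,                       # first-order weights, one per feature value
        "linear_layer": 1 + 1,         # nn.Linear(1, 1): weight + bias
    }
    mlp = 0
    in_dim = len(field_dims) * num_factors  # concatenated field embeddings
    for out_dim in list(mlp_dims) + [1]:    # hidden layers, then the 1-unit output
        mlp += in_dim * out_dim + out_dim   # weights + biases (ReLU/Dropout have none)
        in_dim = out_dim
    counts["mlp"] = mlp
    return counts

# Hypothetical setting: 20 fields with 1,000 distinct values each.
counts = deepfm_param_counts([1000] * 20, num_factors=10, mlp_dims=[30, 20, 10])
print(counts)
```

With these sizes the MLP adds 6,871 parameters next to 200,000 in the shared embedding table, so the deep branch is cheap relative to the table both branches already share.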