batch_size = 2048
data_dir = d2l.download_extract('ctr')
train_data = d2l.CTRDataset(os.path.join(data_dir, 'train.csv'))
# The test set reuses the training set's `feat_mapper` and `defaults`,
# presumably so test features are encoded the same way as training features.
test_data = d2l.CTRDataset(os.path.join(data_dir, 'test.csv'),
                           feat_mapper=train_data.feat_mapper,
                           defaults=train_data.defaults)
field_dims = train_data.field_dims
# Training drops the ragged final batch so every optimization step sees
# exactly `batch_size` examples.
train_iter = torch.utils.data.DataLoader(
    train_data, shuffle=True, drop_last=True, batch_size=batch_size,
    num_workers=d2l.get_dataloader_workers())
# Evaluation keeps the final partial batch: with drop_last=True up to
# batch_size - 1 test examples would be silently excluded from the metric.
test_iter = torch.utils.data.DataLoader(
    test_data, shuffle=False, drop_last=False, batch_size=batch_size,
    num_workers=d2l.get_dataloader_workers())
# DeepFM over the training set's field dimensions. `num_factors` and
# `mlp_dims` are presumably the factor-embedding size and the deep
# component's hidden widths; verify against the DeepFM definition.
net = DeepFM(field_dims, num_factors=10, mlp_dims=[30, 20, 10])
# Devices used for training (see d2l.try_all_gpus).
devices = d2l.try_all_gpus()
def init_weights(m):
    """Xavier-uniform initialize the weight of a Linear or Embedding module.

    Intended to be passed to ``net.apply``, which calls it on every
    submodule. Modules of any other type are left untouched, and biases are
    never modified.

    Args:
        m: A ``torch.nn.Module``.
    """
    # isinstance (rather than an exact type comparison) also covers subclasses.
    if isinstance(m, (nn.Linear, nn.Embedding)):
        nn.init.xavier_uniform_(m.weight)
# Run the custom initializer over every submodule of the network.
net.apply(init_weights)

# Optimization hyperparameters.
lr = 0.01
num_epochs = 30

# Unreduced (per-element) BCE on raw logits; the reduction is presumably
# performed inside d2l.train_ch13 — confirm.
loss = nn.BCEWithLogitsLoss(reduction='none')
optimizer = torch.optim.Adam(net.parameters(), lr=lr)

d2l.train_ch13(net, train_iter, test_iter, loss, optimizer,
               num_epochs, devices)