def _tf_gram(X):
    """Compute a normalised, per-sample Gram matrix of a 4-D TF tensor.

    The input is flattened to ``(dim0, dim1, -1)`` and multiplied by its own
    transpose over the last two axes, giving a ``(dim0, dim1, dim1)`` result.
    The product is divided by ``dim1 * (total_elements // dim1)``; note that
    ``total_elements`` spans the whole tensor, so the leading (batch) axis is
    included in the normaliser.

    Args:
        X: TF tensor laid out as (N, C, H, W). The denominator is float32,
            so X is presumably float32 as well -- TODO confirm with callers.

    Returns:
        TF tensor of shape (N, C, C).
    """
    dims = tf.shape(X)
    channels = dims[1]
    # Element count per channel over the entire tensor (batch included).
    per_channel = tf.cast(tf.reduce_prod(dims) // channels, tf.float32)
    flat = tf.reshape(X, (dims[0], channels, -1))
    flat_t = tf.transpose(flat, (0, 2, 1))
    denom = tf.cast(channels, tf.float32) * per_channel
    return tf.matmul(flat, flat_t) / denom
def train(X, contents_Y, styles_Y, lr, num_epochs, lr_decay_epoch):
    """Optimise the synthesised image with Adam on content, style and TV loss.

    The image is held in a ``tf.Variable`` and updated by a single
    ``tf.function``-compiled step. Every 10 epochs the current image and the
    three loss values are pushed to a ``d2l.Animator``.

    Relies on names defined outside this function: ``get_inits``,
    ``extract_features``, ``_to_vgg_input``, ``postprocess``,
    ``content_layers``, ``style_layers``, ``content_weight``,
    ``style_weight`` and ``tv_weight``.

    Args:
        X: Initial image, forwarded to ``get_inits``.
        contents_Y: Iterable of content targets, one per content layer.
        styles_Y: Style features, forwarded to ``get_inits`` which returns
            the Gram-matrix targets.
        lr: Initial learning rate for Adam (also forwarded to ``get_inits``).
        num_epochs: Number of optimisation steps to run.
        lr_decay_epoch: Every this many epochs the learning rate is reset to
            ``lr * 0.8 ** k``, where ``k`` is the number of decays so far.

    Returns:
        The optimised image as a ``jnp`` array.
    """
    X, styles_Y_gram = get_inits(X, lr, styles_Y)

    def _as_f32_const(value):
        # Materialise once as a float32 TF constant so the compiled step
        # only ever touches TF tensors.
        return tf.constant(np.array(value), dtype=tf.float32)

    # Targets are fixed for the whole run; convert them before the loop.
    contents_Y_tf = list(map(_as_f32_const, contents_Y))
    styles_Y_gram_tf = list(map(_as_f32_const, styles_Y_gram))

    animator = d2l.Animator(
        xlabel='epoch',
        ylabel='loss',
        xlim=[10, num_epochs],
        legend=['content', 'style', 'TV'],
        ncols=2,
        figsize=(7, 2.5),
    )

    # The trainable image lives in TF because the feature extractor is TF.
    X_tf = tf.Variable(np.array(X))
    tf_optimizer = tf.keras.optimizers.Adam(learning_rate=lr)

    # One graph call per step: forward pass, losses, gradient, Adam update.
    @tf.function(reduce_retracing=True)
    def train_step(image):
        with tf.GradientTape() as tape:
            contents_Y_hat_tf, styles_Y_hat_tf = extract_features(
                _to_vgg_input(image), content_layers, style_layers)

            # Per-layer content loss: weighted mean squared error.
            content_terms = []
            for y_hat, target in zip(contents_Y_hat_tf, contents_Y_tf):
                mse = tf.reduce_mean(tf.square(y_hat - target))
                content_terms.append(mse * content_weight)
            contents_l_vec = tf.stack(content_terms)

            # Per-layer style loss: weighted MSE between Gram matrices.
            style_terms = []
            for y_hat, gram_target in zip(styles_Y_hat_tf, styles_Y_gram_tf):
                mse = tf.reduce_mean(tf.square(_tf_gram(y_hat) - gram_target))
                style_terms.append(mse * style_weight)
            styles_l_vec = tf.stack(style_terms)

            # Total-variation loss: mean absolute difference of neighbouring
            # entries along axis 2 and axis 3.
            diff_axis2 = image[:, :, 1:, :] - image[:, :, :-1, :]
            diff_axis3 = image[:, :, :, 1:] - image[:, :, :, :-1]
            tv_l = 0.5 * (
                tf.reduce_mean(tf.abs(diff_axis2))
                + tf.reduce_mean(tf.abs(diff_axis3))
            ) * tv_weight

            total_loss = (tf.reduce_sum(contents_l_vec)
                          + tf.reduce_sum(styles_l_vec) + tv_l)

        grads = tape.gradient(total_loss, image)
        tf_optimizer.apply_gradients([(grads, image)])
        return contents_l_vec, styles_l_vec, tv_l

    for step in range(1, num_epochs + 1):
        contents_l_vec, styles_l_vec, tv_l = train_step(X_tf)

        # Step decay: recompute the rate from the base lr at each boundary.
        if step % lr_decay_epoch == 0:
            tf_optimizer.learning_rate.assign(
                lr * 0.8 ** (step // lr_decay_epoch))

        # Only pull values back to the host every 10th epoch.
        if step % 10 != 0:
            continue
        animator.axes[1].imshow(postprocess(jnp.array(X_tf.numpy())))
        animator.add(step, [
            float(tf.reduce_sum(contents_l_vec)),
            float(tf.reduce_sum(styles_l_vec)),
            float(tv_l),
        ])

    return jnp.array(X_tf.numpy())