class BiRNN(d2l.Classifier):
    """Bidirectional-LSTM text classifier.

    Token ids are embedded, passed through a stack of bidirectional LSTM
    layers, and the encoder outputs at the first and last time steps are
    concatenated and projected by a dense layer to two output scores.

    Args:
        vocab_size: Number of distinct token ids accepted by the embedding.
        embed_size: Dimensionality of each token embedding.
        num_hiddens: LSTM units per direction in every layer.
        num_layers: Number of stacked bidirectional LSTM layers; must be >= 1.
        **kwargs: Forwarded unchanged to ``d2l.Classifier.__init__``.

    Raises:
        ValueError: If ``num_layers`` is less than 1.
    """

    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
                 **kwargs):
        if num_layers < 1:
            # The previous "num_layers - 1 layers + 1 final layer"
            # construction silently built one layer for num_layers <= 0;
            # fail loudly so a misconfiguration is not hidden.
            raise ValueError(
                f"num_layers must be at least 1, got {num_layers}")
        super().__init__(**kwargs)
        # NOTE(review): mask_zero is left at its default (False), so padding
        # positions are not masked and outputs[:, -1, :] in call() may come
        # from padding tokens -- confirm this is intended.
        self.embedding = keras.layers.Embedding(vocab_size, embed_size)
        # Every layer returns the full sequence: intermediate layers need it
        # as input to the next layer, and call() reads both the first and
        # the last time step of the final layer's output.
        self.encoder = keras.Sequential([
            keras.layers.Bidirectional(
                keras.layers.LSTM(num_hiddens, return_sequences=True))
            for _ in range(num_layers)
        ])
        # Linear projection (no activation) to two unnormalized scores.
        self.decoder = keras.layers.Dense(2)

    def call(self, inputs, training=False):
        """Return two unnormalized class scores per input sequence.

        Args:
            inputs: Integer token ids, presumably (batch_size, num_steps).
            training: Forwarded to the encoder.

        Returns:
            Tensor of shape (batch_size, 2) from the final Dense layer.
        """
        embeddings = self.embedding(inputs)
        # (batch_size, num_steps, 2 * num_hiddens): Bidirectional's default
        # merge_mode="concat" joins the forward and backward outputs.
        outputs = self.encoder(embeddings, training=training)
        # Summarize each sequence by its encoder output at the first and at
        # the last time step -> (batch_size, 4 * num_hiddens).
        encoding = tf.concat([outputs[:, 0, :], outputs[:, -1, :]], axis=1)
        return self.decoder(encoding)