from d2l import jax as d2l
from flax import linen as nn

Translation, summarization, and dialogue all map a variable-length input to a variable-length output, with no positional alignment between the two.
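The decoder is a conditional language model, $P(y_t \mid y_{<t}, \mathrm{enc}(x))$: it predicts each target token $y_t$ from the previously generated tokens $y_{<t}$ and the encoding of the source sequence $x$.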
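In the encoder–decoder architecture, an intermediate state passed from the encoder to the decoder lets the model handle inputs and outputs of arbitrary, independent lengths.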
The encoder interface has one method: it reads a variable-length input sequence. Each concrete implementation chooses the architecture (an RNN now, a Transformer later):
The decoder interface has two methods. `init_state` packs the encoder output into a state object that the decoder consumes; `__call__` takes the next input token(s) plus that state and returns the output (e.g., logits) together with the updated state:
class Decoder(nn.Module):
    """The base decoder interface for the encoder--decoder architecture.

    Abstract: every method here only raises ``NotImplementedError``;
    concrete decoders subclass this and override all of them.
    """

    def setup(self):
        # Abstract hook; concrete decoders must override it.
        raise NotImplementedError

    # Later there can be additional arguments (e.g., length excluding padding)
    def init_state(self, enc_all_outputs, *args):
        """Build the initial decoder state from the encoder's outputs.

        Args:
            enc_all_outputs: whatever the paired encoder returned.
            *args: extra positional inputs; ``EncoderDecoder`` forwards the
                same extras it passed to the encoder.

        Returns:
            The state object later passed as ``state`` to ``__call__``.
        """
        raise NotImplementedError

    def __call__(self, X, state, training=False):
        """Decode ``X`` conditioned on ``state``.

        ``training`` is part of the signature because ``EncoderDecoder``
        calls decoders as ``decoder(dec_X, dec_state, training=...)``; the
        base interface now advertises that keyword (defaulting to False) so
        it matches how decoders are actually invoked.

        Args:
            X: decoder input.
            state: the state produced by ``init_state``.
            training: whether the call is part of training.

        Returns:
            An indexable result whose element 0 is the decoder output
            (``EncoderDecoder`` keeps only ``[0]``); presumably
            ``(outputs, updated_state)`` -- confirm against concrete decoders.
        """
        raise NotImplementedError


# EncoderDecoder runs the encoder once on the source, hands its output to
# init_state, then drives the decoder with the target tokens (teacher forcing
# during training):
class EncoderDecoder(d2l.Classifier):
    """The base class for the encoder--decoder architecture.

    Wires an encoder and a decoder together: the source is encoded once,
    the encoder outputs seed the decoder state via ``init_state``, and the
    decoder is then applied to ``dec_X`` with that state.
    """

    encoder: nn.Module  # applied to enc_X; its result feeds decoder.init_state
    decoder: nn.Module  # applied to dec_X with the state from init_state

    def __call__(self, enc_X, dec_X, *args, training=False):
        """Run the encoder--decoder pass and return only the decoder output.

        Args:
            enc_X: input to the encoder.
            dec_X: input to the decoder.
            *args: extra positional inputs, forwarded unchanged to both
                ``self.encoder`` and ``self.decoder.init_state``.
            training: forwarded as a keyword to both encoder and decoder.

        Returns:
            Element 0 of the decoder's return value (the decoder output);
            any remaining elements are discarded.
        """
        encoded = self.encoder(enc_X, *args, training=training)
        initial_state = self.decoder.init_state(encoded, *args)
        decoder_result = self.decoder(dec_X, initial_state, training=training)
        # Only the decoder output is returned; the rest is dropped.
        return decoder_result[0]