from d2l import torch as d2l
from torch import nn
Translation, summarization, and dialogue all map a variable-length input sequence to a variable-length output sequence, with no positional alignment between the two.
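The decoder is a conditional language model, $P(y_t \mid y_{<t}, \text{enc}(x))$: it predicts each target token $y_t$ from the previously produced tokens $y_{<t}$ and the encoded input $\text{enc}(x)$.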
In the encoder–decoder architecture, an intermediate state sits between the two halves, which lets the input and the output have arbitrary, independent lengths.
The encoder interface has a single method, which reads a variable-length input. Each concrete implementation chooses the architecture (an RNN for now, a Transformer later):
The decoder interface has two methods. `init_state` packs the encoder output into a state object that the decoder consumes; `forward` takes the next input token together with the state and returns the logits and the updated state:
class Decoder(nn.Module):
    """The base decoder interface for the encoder--decoder architecture.

    Concrete decoders must override both ``init_state`` and ``forward``;
    the base versions only raise ``NotImplementedError``.
    """

    def __init__(self):
        super().__init__()

    def init_state(self, enc_all_outputs, *args):
        """Build the initial decoder state from the encoder outputs.

        Args:
            enc_all_outputs: Whatever the paired encoder returned.
            *args: Extra information forwarded by the caller. Later
                implementations may use this, for example, for source
                lengths excluding padding.

        Returns:
            The state object that ``forward`` consumes.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        # Name the concrete class so a missing override is easy to locate.
        raise NotImplementedError(
            f"{type(self).__name__} must implement init_state()")

    def forward(self, X, state):
        """Decode ``X`` given the current ``state``.

        Implementations are expected to return the decoder output
        together with the updated state.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement forward()")

# EncoderDecoder runs the encoder once on the source, hands its output to
# init_state, then drives the decoder with the target tokens (teacher
# forcing during training).
class EncoderDecoder(d2l.Classifier):
    """Combine an encoder and a decoder into a single model.

    The encoder consumes the source input once. Its outputs seed the decoder
    state through ``decoder.init_state``, and the decoder is then run on the
    decoder-side input.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_X, dec_X, *args):
        """Run one full encoder--decoder pass.

        Args:
            enc_X: Source-side input handed to the encoder.
            dec_X: Decoder input (the target tokens when teacher forcing).
            *args: Extra arguments forwarded unchanged to both
                ``self.encoder`` and ``self.decoder.init_state``.

        Returns:
            Element ``0`` of the decoder's result, i.e. the decoder output;
            the updated state is discarded.
        """
        memory = self.encoder(enc_X, *args)
        initial_state = self.decoder.init_state(memory, *args)
        decoded = self.decoder(dec_X, initial_state)
        return decoded[0]