
The Encoder–Decoder Architecture


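Machine translation, summarization, and dialogue all map a variable-length input sequence to a variable-length output sequence, with no position-by-position alignment between the two. The encoder–decoder architecture splits this mapping into two components: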

  • Encoder: reads the source sequence and encodes it into a state.
  • Decoder: reads that state plus the target tokens produced so far, and predicts the next target token.

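The decoder is therefore a conditional language model: it models P(y_t \mid y_{<t}, \text{enc}(x)), the distribution of the next target token y_t given the earlier target tokens y_{<t} and the encoded source x.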
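Applying the chain rule over target positions shows how these per-step predictions combine into a probability for the whole output. Here y = (y_1, \ldots, y_{T'}) denotes a target sequence of length T', a notation assumed for this sketch:

```latex
P(y_1, \ldots, y_{T'} \mid x) = \prod_{t=1}^{T'} P\bigl(y_t \mid y_{<t}, \text{enc}(x)\bigr),
\qquad
\log P(y \mid x) = \sum_{t=1}^{T'} \log P\bigl(y_t \mid y_{<t}, \text{enc}(x)\bigr)
```

Training typically maximizes the log form. At t = 1 the prefix y_{<t} is empty, which is commonly handled by feeding a beginning-of-sequence token as the first decoder input.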

The architecture

Because the encoder and decoder communicate only through this intermediate state, the input and output sequences can have arbitrary and independent lengths.

Setup

from d2l import tensorflow as d2l
import tensorflow as tf

Encoder interface

The encoder interface has a single method, call, which reads a variable-length input sequence X. Concrete subclasses choose the architecture (an RNN for now, a Transformer later):

class Encoder(tf.keras.layers.Layer):
    """The base encoder interface for the encoder--decoder architecture."""
    def __init__(self):
        super().__init__()

    # Later there can be additional arguments (e.g., length excluding padding)
    def call(self, X, *args):
        raise NotImplementedError
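To see what a concrete subclass has to provide, here is a framework-free sketch that mirrors the interface with plain Python and NumPy instead of Keras. The stand-in base class, the toy MeanEmbeddingEncoder, and its choice to return per-step outputs plus their mean are all hypothetical, for illustration only; the point is that sources of different lengths yield a summary of the same shape.

```python
import numpy as np


class Encoder:
    """Framework-free stand-in for the Keras-based Encoder interface (assumption)."""

    def __call__(self, X, *args):
        # Mimic Keras: calling the object dispatches to `call`
        return self.call(X, *args)

    def call(self, X, *args):
        raise NotImplementedError


class MeanEmbeddingEncoder(Encoder):
    """Hypothetical toy encoder: embed tokens and also return their mean."""

    def __init__(self, vocab_size, embed_size, seed=0):
        rng = np.random.default_rng(seed)
        self.embedding = rng.normal(size=(vocab_size, embed_size))

    def call(self, X, *args):
        # X: (batch_size, num_steps) integer token ids
        outputs = self.embedding[X]     # (batch_size, num_steps, embed_size)
        summary = outputs.mean(axis=1)  # (batch_size, embed_size): fixed shape
        return outputs, summary


enc = MeanEmbeddingEncoder(vocab_size=10, embed_size=4)
short_out, short_sum = enc(np.array([[1, 2, 3]]))
long_out, long_sum = enc(np.array([[1, 2, 3, 4, 5, 6, 7]]))
print(short_out.shape, long_out.shape)  # (1, 3, 4) (1, 7, 4)
print(short_sum.shape, long_sum.shape)  # (1, 4) (1, 4)
```

The per-step outputs keep the source length, while the summary does not; which of the two a real decoder needs is decided by its init_state.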

Decoder interface

The decoder interface has two methods. init_state converts the encoder output (plus any extra arguments) into the state the decoder consumes; call takes the decoder input tokens X together with the current state and returns the output (logits) along with the updated state:

class Decoder(tf.keras.layers.Layer):
    """The base decoder interface for the encoder--decoder architecture."""
    def __init__(self):
        super().__init__()

    # Later there can be additional arguments (e.g., length excluding padding)
    def init_state(self, enc_all_outputs, *args):
        raise NotImplementedError

    def call(self, X, state):
        raise NotImplementedError
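The decoder contract can be sketched the same framework-free way: init_state turns whatever the encoder returned into a state, and call maps decoder inputs plus state to per-step logits and a new state. The stand-in base class, ToyDecoder, its context-adding rule, and the fake encoder outputs are hypothetical; a real decoder (e.g. an RNN) would compute its state update recurrently.

```python
import numpy as np


class Decoder:
    """Framework-free stand-in for the Keras-based Decoder interface (assumption)."""

    def __call__(self, X, state):
        # Mimic Keras: calling the object dispatches to `call`
        return self.call(X, state)

    def init_state(self, enc_all_outputs, *args):
        raise NotImplementedError

    def call(self, X, state):
        raise NotImplementedError


class ToyDecoder(Decoder):
    """Hypothetical decoder: target embedding + context vector -> logits."""

    def __init__(self, vocab_size, hidden_size, seed=0):
        rng = np.random.default_rng(seed)
        self.embedding = rng.normal(size=(vocab_size, hidden_size))
        self.W_out = rng.normal(size=(hidden_size, vocab_size))

    def init_state(self, enc_all_outputs, *args):
        # Assume the encoder returned (per-step outputs, fixed-shape summary)
        _, summary = enc_all_outputs
        return summary  # (batch_size, hidden_size)

    def call(self, X, state):
        # X: (batch_size, num_steps) target-side token ids
        hidden = self.embedding[X] + state[:, None, :]  # broadcast context over steps
        logits = hidden @ self.W_out                    # (batch_size, num_steps, vocab_size)
        new_state = hidden[:, -1, :]                    # last step's hidden as next state
        return logits, new_state


# Fake encoder outputs: batch of 2 sources of length 5, hidden size 4
enc_all_outputs = (np.zeros((2, 5, 4)), np.ones((2, 4)))
dec = ToyDecoder(vocab_size=10, hidden_size=4)
state = dec.init_state(enc_all_outputs)
logits, state = dec(np.array([[1, 3, 7], [1, 2, 0]]), state)
print(logits.shape, state.shape)  # (2, 3, 10) (2, 4)
```

Returning the state separately from the logits is what would let the same decoder run one token at a time at prediction time, feeding each new state back into the next call.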

Wiring them together

EncoderDecoder runs the encoder once on the source, hands its output to init_state, then drives the decoder with the target tokens (teacher forcing during training):

class EncoderDecoder(d2l.Classifier):
    """The base class for the encoder--decoder architecture."""
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def call(self, enc_X, dec_X, *args, training=None):
        enc_all_outputs = self.encoder(enc_X, *args, training=training)
        dec_state = self.decoder.init_state(enc_all_outputs, *args)
        # Return decoder output only
        return self.decoder(dec_X, dec_state, training=training)[0]
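As a rough end-to-end illustration, the sketch below wires two toy components through a stand-in EncoderDecoder and builds the teacher-forcing decoder input by shifting the target right behind a beginning-of-sequence token. Everything here is an assumption for illustration: the stand-in drops the d2l.Classifier base class and the training flag, ToyEncoder, ToyDecoder, and shift_right are hypothetical, and the token ids (0 for padding, 1 for <bos>, 2 for <eos>) are made up. Source and target lengths deliberately differ.

```python
import numpy as np


class EncoderDecoder:
    """Framework-free stand-in for EncoderDecoder (no d2l.Classifier base)."""

    def __init__(self, encoder, decoder):
        self.encoder = encoder
        self.decoder = decoder

    def __call__(self, enc_X, dec_X, *args):
        enc_all_outputs = self.encoder(enc_X, *args)
        dec_state = self.decoder.init_state(enc_all_outputs, *args)
        return self.decoder(dec_X, dec_state)[0]  # decoder output only


class ToyEncoder:
    """Hypothetical encoder: embeddings plus their mean as a fixed-shape summary."""

    def __init__(self, vocab_size, hidden_size, rng):
        self.embedding = rng.normal(size=(vocab_size, hidden_size))

    def __call__(self, X, *args):
        outputs = self.embedding[X]
        return outputs, outputs.mean(axis=1)


class ToyDecoder:
    """Hypothetical decoder: per-step logits from target embedding + context."""

    def __init__(self, vocab_size, hidden_size, rng):
        self.embedding = rng.normal(size=(vocab_size, hidden_size))
        self.W_out = rng.normal(size=(hidden_size, vocab_size))

    def init_state(self, enc_all_outputs, *args):
        return enc_all_outputs[1]  # use the summary as the state

    def __call__(self, X, state):
        hidden = self.embedding[X] + state[:, None, :]
        return hidden @ self.W_out, state


def shift_right(Y, bos_id):
    """Teacher forcing: decoder input is <bos> followed by Y minus its last token."""
    bos = np.full((Y.shape[0], 1), bos_id, dtype=Y.dtype)
    return np.concatenate([bos, Y[:, :-1]], axis=1)


rng = np.random.default_rng(0)
vocab_size, hidden_size, bos_id = 12, 8, 1
model = EncoderDecoder(ToyEncoder(vocab_size, hidden_size, rng),
                       ToyDecoder(vocab_size, hidden_size, rng))
src = np.array([[4, 5, 6, 7, 2], [8, 9, 2, 0, 0]])  # (batch=2, src_len=5)
tgt = np.array([[3, 10, 11, 2], [6, 2, 0, 0]])      # (batch=2, tgt_len=4)
dec_X = shift_right(tgt, bos_id)
logits = model(src, dec_X)
print(dec_X[0].tolist())  # [1, 3, 10, 11]
print(logits.shape)       # (2, 4, 12)
```

During training, the loss would compare these logits against tgt itself (masking padding), so each position is scored on the token one step ahead of its input.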

Recap

  • Encoder–decoder splits a seq2seq model into two components connected by a state: the state has a fixed shape, while the input and output sequences can vary in length.
  • The decoder is just a conditional language model: the same P(y_t \mid y_{<t}, \cdot) factorization, with the encoder output as extra context.
  • Concrete implementations only need to override the encoder's call, the decoder's call, and the decoder's init_state, which turns the encoder output into a state.
  • The same scaffold will host RNN-based seq2seq, attention, and Transformers in the chapters ahead.