This is the second part of _Cracking the Annotated Transformer_, in which we conclude the topic by walking through the training of the transformer. For details of the architecture, refer to [[Cracking the Annotated Transformer - Part I]].

# Training objective

To understand the transformer, let us first be clear about the training objective. In the original paper, the model tries to predict the next token given the previous sequence of tokens. Therefore, we can simply write the objective function as

$\max_\theta \log \prod_{i=1}^{l} p(y_i \mid \mathbf{y}_{<i}, \theta)$

The goal is to maximise the likelihood of generating each next token given the tokens that precede it, from left to right. Let us revisit the architecture below.

![[transformer_decoder_details.excalidraw.light.svg]]
%%[[transformer_decoder_details.excalidraw.md|🖋 Edit in Excalidraw]], and the [[transformer_decoder_details.excalidraw.dark.svg|dark exported image]]%%

Input and output come as a pair of sequences. In an English-German translation task, for example, the input is the original English sentence and the output is the translated German sentence. The input goes through the encoder (on the left), which produces an encoded tensor that serves as a memory of the input. On the decoder side, we want to predict the next token, so it is important to mask future tokens from being attended to. This is what the first masked multi-head attention block is for. The second attention block attends to the encoder memory, which supplies both $K$ and $V$, while the decoder states act as $Q$. Simply put, the decoder generates each token based on both the memory and the tokens to its left.

# Batch object for data loading

Based on the objective, the batch object is defined to hold both input and output data.

```python
class Batch:
    """Object for holding a batch of data with mask during training."""

    def __init__(self, src, tgt=None, pad=2):  # 2 = <blank>
        self.src = src
        self.src_mask = (src != pad).unsqueeze(-2)
        if tgt is not None:
            self.tgt = tgt[:, :-1]   # decoder input: drop the last token
            self.tgt_y = tgt[:, 1:]  # labels: shifted one position to the right
            self.tgt_mask = self.make_std_mask(self.tgt, pad)
            self.ntokens = (self.tgt_y != pad).data.sum()

    @staticmethod
    def make_std_mask(tgt, pad):
        "Create a mask to hide padding and future words."
        tgt_mask = (tgt != pad).unsqueeze(-2)
        tgt_mask = tgt_mask & subsequent_mask(tgt.size(-1)).type_as(
            tgt_mask.data
        )
        return tgt_mask
```

The `src_mask` masks out all padding tokens. We can also see that `tgt` and `tgt_y` are shifted by one. With a standard left-to-right mask `tgt_mask`, the decoder generates an output from `tgt` that is supposed to match `tgt_y` (the labels). The following is an example illustration.

![[transformer_target_example.excalidraw.light.svg]]
%%[[transformer_target_example.excalidraw.md|🖋 Edit in Excalidraw]], and the [[transformer_target_example.excalidraw.dark.svg|dark exported image]]%%
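The `make_std_mask` method above relies on the `subsequent_mask` helper from the annotated transformer. As a reminder, a minimal sketch of it looks roughly like this:

```python
import torch

def subsequent_mask(size):
    "Mask out future positions: True where attention is allowed."
    attn_shape = (1, size, size)
    mask = torch.triu(torch.ones(attn_shape), diagonal=1).type(torch.uint8)
    return mask == 0

# For example, subsequent_mask(3)[0] is the lower-triangular boolean matrix
# tensor([[ True, False, False],
#         [ True,  True, False],
#         [ True,  True,  True]])
# so position i can only attend to positions j <= i.
```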
# Loss function and regularisation

Now we know the objective is to predict the next token; we still need a proper loss function. The [annotated-transformer](https://nlp.seas.harvard.edu/annotated-transformer/) chose a [[Probability basics#Kullback-Leibler distance|KL-divergence]] loss against a smoothed target distribution instead of a one-hot target distribution. In addition to dropout, [[Label smoothing]] is thus also used for regularisation. The implementation is as below.

```python
class LabelSmoothing(nn.Module):
    "Implement label smoothing."

    def __init__(self, size, padding_idx, smoothing=0.0):
        super(LabelSmoothing, self).__init__()
        self.criterion = nn.KLDivLoss(reduction="sum")
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        self.true_dist = None

    def forward(self, x, target):
        assert x.size(1) == self.size
        true_dist = x.data.clone()
        # Spread the smoothing mass over all tokens except the target and padding.
        true_dist.fill_(self.smoothing / (self.size - 2))
        true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        true_dist[:, self.padding_idx] = 0
        mask = torch.nonzero(target.data == self.padding_idx)
        if mask.dim() > 0:
            true_dist.index_fill_(0, mask.squeeze(), 0.0)
        self.true_dist = true_dist
        return self.criterion(x, true_dist.clone().detach())
```

Note that the input `x` to `nn.KLDivLoss` should be in log-space, i.e. the model output should go through a log-softmax before being fed to this criterion.

# Optimiser and learning rate scheduler

The last bit of the trainer is the optimiser and the learning rate scheduler. There is no surprise in the optimiser: a standard Adam with $\beta_{1}=0.9$, $\beta_{2}=0.98$ and $\epsilon=10^{-9}$. The learning rate, however, first increases linearly for the `warmup` steps and then decreases proportionally to the inverse square root of the step number. The implementation uses a `LambdaLR` scheduler with a rate function defined as below.

```python
def rate(step, model_size, factor, warmup):
    """
    We have to default the step to 1 for the LambdaLR function
    to avoid raising zero to a negative power.
    """
    if step == 0:
        step = 1
    return factor * (
        model_size ** (-0.5) * min(step ** (-0.5), step * warmup ** (-1.5))
    )
```
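To see how the pieces fit together, here is a minimal sketch of wiring `rate` into `torch.optim.lr_scheduler.LambdaLR`. The `nn.Linear` stand-in model, the base learning rate of 1.0 and `factor=1.0` are illustrative assumptions; the betas, epsilon and `warmup=4000` follow the original paper.

```python
import torch
from torch import nn
from torch.optim.lr_scheduler import LambdaLR

model = nn.Linear(8, 8)  # stand-in module; replace with the actual transformer
model_size, factor, warmup = 512, 1.0, 4000  # 4000 warmup steps as in the paper

# Adam with the betas/epsilon from the paper; base lr of 1.0 so that the
# multiplier returned by rate() becomes the effective learning rate.
optimizer = torch.optim.Adam(
    model.parameters(), lr=1.0, betas=(0.9, 0.98), eps=1e-9
)
scheduler = LambdaLR(
    optimizer, lr_lambda=lambda step: rate(step, model_size, factor, warmup)
)

for step in range(5):  # in a real loop: forward pass, loss.backward(), then
    optimizer.step()   # optimizer.step() and scheduler.step() once per batch
    scheduler.step()
    print(scheduler.get_last_lr())
```

Printing `scheduler.get_last_lr()` for a few steps shows the warmup behaviour: the learning rate climbs linearly until `warmup` steps and then decays as the inverse square root of the step number.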