In this second part of _Cracking the Annotated Transformer_, we conclude the topic by walking through how the transformer is trained. For details of the architecture, refer to [[Cracking the Annotated Transformer - Part I]].
# Training objective
To understand the transformer, let us first be clear about the training objective. In the original paper, the model predicts the next token given the preceding sequence of tokens. We can therefore write the objective function as $\max_\theta \log \prod_{i=1}^{l} p(y_i \mid \mathbf{y}_{<i}, \theta)$, where $l$ is the length of the target sequence.
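For a concrete instance, a target sequence of three tokens factorises by the chain rule of probability as
$$\log p(y_1, y_2, y_3 \mid \theta) = \log p(y_1 \mid \theta) + \log p(y_2 \mid y_1, \theta) + \log p(y_3 \mid y_1, y_2, \theta),$$
and the model is trained to maximise this sum over all positions.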
The goal is to maximise the likelihood of each next token given the tokens to its left, from left to right. Let us revisit the architecture below,
![[transformer_decoder_details.excalidraw.light.svg]]
%%[[transformer_decoder_details.excalidraw.md|🖋 Edit in Excalidraw]], and the [[transformer_decoder_details.excalidraw.dark.svg|dark exported image]]%%
Input and output come as a pair of sequences; in an English-German translation task, for example, the input is the original English sentence and the output is the translated German sentence. The input goes through the encoder (on the left) and produces an encoded tensor that acts as a memory of the input. On the decoder side, we want to predict the next token, so it is important to mask future tokens from being attended to; this is exactly what the first masked multi-head attention block does. The second attention block is cross-attention over the encoder memory: the memory supplies the keys ($K$) and values ($V$), while the decoder states supply the queries ($Q$). Simply put, the decoder generates each token based on both the memory and the tokens to its left.
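To make the causal mask concrete, here is a minimal sketch of a `subsequent_mask` helper in the spirit of the one from Part I (which the `Batch` class below relies on); it returns `True` exactly for the positions a token is allowed to attend to:
```python
import torch


def subsequent_mask(size):
    "Boolean mask of shape (1, size, size): position i may attend to positions <= i."
    attn_shape = (1, size, size)
    # Everything strictly above the diagonal is a "future" position.
    future = torch.triu(torch.ones(attn_shape), diagonal=1).type(torch.uint8)
    return future == 0


print(subsequent_mask(3))
# tensor([[[ True, False, False],
#          [ True,  True, False],
#          [ True,  True,  True]]])
```
Row $i$ of the mask is `True` up to and including column $i$, so token $i$ can only look at itself and the tokens to its left.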
# Batch object for data loading
Based on the objective, the batch object is defined to hold both input and output data.
```python
class Batch:
    """Object for holding a batch of data with mask during training."""

    def __init__(self, src, tgt=None, pad=2):  # 2 = <blank>
        self.src = src
        self.src_mask = (src != pad).unsqueeze(-2)
        if tgt is not None:
            self.tgt = tgt[:, :-1]
            self.tgt_y = tgt[:, 1:]
            self.tgt_mask = self.make_std_mask(self.tgt, pad)
            self.ntokens = (self.tgt_y != pad).data.sum()

    @staticmethod
    def make_std_mask(tgt, pad):
        "Create a mask to hide padding and future words."
        tgt_mask = (tgt != pad).unsqueeze(-2)
        tgt_mask = tgt_mask & subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data)
        return tgt_mask
```
The `src_mask` masks out the padding tokens in the source. Note that `tgt` and `tgt_y` are the same sequence shifted by one position: the decoder reads `tgt` under the standard left-to-right mask `tgt_mask`, and its output is expected to match `tgt_y` (the labels). Following is an example illustration.
![[transformer_target_example.excalidraw.light.svg]]
%%[[transformer_target_example.excalidraw.md|🖋 Edit in Excalidraw]], and the [[transformer_target_example.excalidraw.dark.svg|dark exported image]]%%
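To make the shift concrete, here is a small hypothetical usage sketch of the `Batch` class above; the token ids and the padding id `2` are made up for illustration:
```python
import torch

# A single source/target pair, padded with 2.
src = torch.tensor([[0, 7, 8, 9, 1, 2]])
tgt = torch.tensor([[0, 4, 5, 6, 1, 2]])

batch = Batch(src, tgt, pad=2)
print(batch.tgt)             # tensor([[0, 4, 5, 6, 1]])  decoder input
print(batch.tgt_y)           # tensor([[4, 5, 6, 1, 2]])  labels, shifted by one
print(batch.ntokens)         # tensor(4)                  non-padding label tokens
print(batch.tgt_mask.shape)  # torch.Size([1, 5, 5])      padding + causal mask
```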
# Loss function and regularisation
Now we know the objective is to predict the next token. For the loss function, the [annotated-transformer](https://nlp.seas.harvard.edu/annotated-transformer/) uses a [[Probability basics#Kullback-Leibler distance|KL-divergence]] loss against a smoothed target distribution rather than a one-hot target. In addition to dropout, [[Label smoothing]] thus serves as a second form of regularisation. The implementation is as below,
```python
import torch
import torch.nn as nn


class LabelSmoothing(nn.Module):
    "Implement label smoothing."

    def __init__(self, size, padding_idx, smoothing=0.0):
        super(LabelSmoothing, self).__init__()
        self.criterion = nn.KLDivLoss(reduction="sum")
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        self.true_dist = None

    def forward(self, x, target):
        assert x.size(1) == self.size
        true_dist = x.data.clone()
        true_dist.fill_(self.smoothing / (self.size - 2))
        true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        true_dist[:, self.padding_idx] = 0
        mask = torch.nonzero(target.data == self.padding_idx)
        if mask.dim() > 0:
            true_dist.index_fill_(0, mask.squeeze(), 0.0)
        self.true_dist = true_dist
        return self.criterion(x, true_dist.clone().detach())
```
Note that the input `x` to `nn.KLDivLoss` is expected to be in log-space, i.e. the output of a `log_softmax` over the vocabulary.
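A quick, illustrative sanity check of `LabelSmoothing` with a toy vocabulary of size 5 and padding index 0 (all numbers are made up) might look like:
```python
import torch
import torch.nn.functional as F

crit = LabelSmoothing(size=5, padding_idx=0, smoothing=0.1)

# Three predictions over a 5-token vocabulary, converted to log-probabilities.
logits = torch.randn(3, 5)
log_probs = F.log_softmax(logits, dim=-1)
target = torch.tensor([2, 1, 0])  # the last position is a padding target

loss = crit(log_probs, target)
print(loss)            # scalar KL-divergence, summed over the batch
print(crit.true_dist)  # smoothed targets; the padding row is all zeros
```
Each row of `true_dist` places `confidence` on the correct token, spreads the remaining mass uniformly over the other non-padding tokens, and rows that correspond to padding targets are zeroed out entirely.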
# Optimiser and learning rate scheduler
The last piece of the trainer is the optimiser and the learning rate scheduler. There is no surprise in the optimiser: a standard Adam with $\beta_1=0.9$, $\beta_2=0.98$ and $\epsilon=10^{-9}$. The learning rate, however, first increases linearly over the `warmup` steps and then decays proportionally to the inverse square root of the step number, i.e. $lrate = d_{\text{model}}^{-0.5}\cdot\min(step^{-0.5},\ step\cdot warmup^{-1.5})$. The implementation uses a `LambdaLR` scheduler with a rate function defined as below,
```python
def rate(step, model_size, factor, warmup):
    """
    We have to default the step to 1 for the LambdaLR function
    to avoid zero being raised to a negative power.
    """
    if step == 0:
        step = 1
    return factor * (
        model_size ** (-0.5) * min(step ** (-0.5), step * warmup ** (-1.5))
    )
```
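Putting the pieces together, a hypothetical wiring of optimiser and scheduler could look like the sketch below; the values `d_model=512`, `factor=1.0` and `warmup=4000` follow the original paper but are otherwise illustrative, and the `torch.nn.Linear` module merely stands in for the transformer built in Part I:
```python
import torch
from torch.optim.lr_scheduler import LambdaLR

d_model = 512  # model hidden size, as in the original paper
model = torch.nn.Linear(d_model, d_model)  # stand-in for the transformer from Part I

optimizer = torch.optim.Adam(
    model.parameters(), lr=1.0, betas=(0.9, 0.98), eps=1e-9
)
lr_scheduler = LambdaLR(
    optimizer,
    # LambdaLR multiplies the base lr (1.0 here) by whatever rate() returns.
    lr_lambda=lambda step: rate(step, d_model, factor=1.0, warmup=4000),
)

# In the training loop: call optimizer.step(), then lr_scheduler.step().
```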