Intro
Let’s explore how LayerNorm is handled, as one of the layers in the model. We begin with the PyTorch documentation for LayerNorm. LayerNorm originates from the seminal paper by Ba et al. (2016) and was integrated into the Transformer architecture by Vaswani et al. in their renowned paper “Attention is All You Need.” GPT-2 adopted a similar architecture to the Transformer but notably shifted the position of LayerNorm, now referred to as the pre-normalization version. In this version, the residual path of the Transformer remains clean, with LayerNorm positioned as the initial layer of each block, leading to improved training stability.
Upon inspecting the PyTorch implementation of LayerNorm, you’ll likely notice the absence of the actual equation implementation. This is because it’s deeply embedded within the codebase, obscured behind a dynamic dispatcher, possibly in auto-generated CUDA code (for the curious, see layer_norm.cpp and layer_norm_kernel.cu). PyTorch prioritizes efficiency, which justifies this design choice. However, for our purposes, we want to understand LayerNorm, so we start by implementing it manually using simpler PyTorch operations. Although less efficient than using a LayerNorm module directly, this approach is algorithmically instructive.
LayerNorm Implementation
Forward pass
Here’s a direct implementation of LayerNorm’s mathematics using basic PyTorch operations:
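Something along these lines works (a minimal sketch; the class and variable names here are illustrative):

```python
import torch

class LayerNorm:

    @staticmethod
    def forward(x, w, b):
        # x: input activations of shape (B,T,C); w, b: weight and bias of shape (C,)
        eps = 1e-5
        B, T, C = x.size()
        mean = x.sum(-1, keepdim=True) / C           # (B,T,1) mean over the channels
        xshift = x - mean                            # (B,T,C) zero-centered input
        var = (xshift**2).sum(-1, keepdim=True) / C  # (B,T,1) variance (no bias correction)
        rstd = (var + eps) ** -0.5                   # (B,T,1) reciprocal standard deviation
        norm = xshift * rstd                         # (B,T,C) normalized activations
        out = norm * w + b                           # (B,T,C) scale and shift
        cache = (x, w, mean, rstd)                   # stash what the backward pass will need
        return out, cache
```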
The activation tensors in the residual path of the Transformer during training are 3-dimensional arrays (tensors) of shape B,T,C. B is the batch size, T is time, and C is channels. For example, B = 8, T = 1024, C = 768 is one setting you might see for the smallest (124 million parameter) GPT-2 model.
We can forward this layer with some random numbers:
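For example, with small toy sizes (the numbers are just for illustration; requires_grad=True is set so that we can ask PyTorch autograd for gradients in a moment):

```python
B, T, C = 2, 3, 4
x = torch.randn(B, T, C, requires_grad=True)
w = torch.randn(C, requires_grad=True)
b = torch.randn(C, requires_grad=True)
out, cache = LayerNorm.forward(x, w, b)
print(out.shape)  # torch.Size([2, 3, 4]), i.e. (B,T,C)
```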
What we get out is the tensor out
, also of shape B,T,C
, where each C-dimensional “fibre” of activations (as we call them) is normalized and then scaled and at the end also shifted by the weights and biases of this layer. Notice that, importantly, we also return a variable cache
, which is a tuple of the input activations x
, the weights w
, the mean mean
, and the reciprocal standard deviation rstd
. These are all variables we need during the backward pass.
Backward pass
PyTorch can of course do the backward pass of this layer for us with its Autograd. Let’s do that first:
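Roughly like this, continuing from the snippet above (wei is just an illustrative name for the random projection weights):

```python
wei = torch.randn_like(out)   # random weights for the fake loss
fakeloss = (out * wei).sum()  # collapse everything into one scalar
fakeloss.backward()           # autograd populates x.grad, w.grad, b.grad
print(x.grad.shape, w.grad.shape, b.grad.shape)
```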
You see here that we created a fakeloss
, which simply takes a (random) weighted combination of all the outputs of our layernorm. All this is doing is projecting all of the B,T,C
numbers into a single scalar value (loss), so that we have a single output of our “computational graph”. Typically this would be the loss of the model, but here we’re just doing a fake loss. We then call backward()
on this scalar, and PyTorch will compute all the gradients for us on all the inputs to this graph - i.e. the input activations x
, the weights w
, and the biases b
. If you don’t know too much about autograd, I’d encourage you to watch Karpathy’s micrograd video, where he builds a tiny autograd engine over scalar values. So the magic of PyTorch autograd is that after we call .backward
, it will populate the .grad
attribute of all the tensors that have requires_grad=True
with the gradients of the loss with respect to that tensor. These gradients are telling us the slope of the loss for all of the input numbers in x,w,b. Therefore, the shape of x.grad
, w.grad
, and b.grad
are exactly the same as the shape of x
, w
, and b
.
But we don’t want to use PyTorch Autograd. We want to do the backward pass manually. So we take out pen and paper and write out the expression for LayerNorm. The forward pass has the following mathematical form:

$$
\hat{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}}, \qquad \text{out} = w \odot \hat{x} + b
$$

where $\odot$ is elementwise multiplication, $\mu$ is the mean (over the channel dimension), $\sigma^2$ is the variance, and $\epsilon$ is a small constant to avoid division by zero. Remembering the rules of differentiation from calculus, we now want to derive the gradients. When you work through the differentiation, you’ll notice that the expressions simplify analytically and you can move the terms around and simplify the expression somewhat. So you don’t have to manually backward every individual line in the forward pass. In particular, with $d\hat{x} = dout \odot w$ and $\hat{x}$ the normalized input (the norm variable in the code), we get:

$$
dx = \frac{1}{\sqrt{\sigma^2 + \epsilon}} \left( d\hat{x} - \frac{1}{C}\sum_{c} d\hat{x} - \hat{x} \cdot \frac{1}{C}\sum_{c} \left( d\hat{x} \odot \hat{x} \right) \right)
$$

$$
dw = \sum_{b,t} dout \odot \hat{x}, \qquad db = \sum_{b,t} dout
$$
So given the gradients on every individual output number stored in dout
, and the cache
from the forward pass, we can now backward through this layer into the inputs, to continue the chain rule of the backward pass. So now we can do our own backward pass and see that they match (the errors are tiny):
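Here is a sketch of what that manual backward pass can look like, written as a standalone function for clarity (the formulas mirror the derivation above; the dout of our fake loss is simply wei):

```python
def layernorm_backward(dout, cache):
    x, w, mean, rstd = cache
    # recompute the normalized input; we chose not to keep it in the cache
    norm = (x - mean) * rstd                               # (B,T,C)
    # bias and weight gradients: sum over all (B,T) positions
    db = dout.sum((0, 1))                                  # (C,)
    dw = (dout * norm).sum((0, 1))                         # (C,)
    # input gradient
    dnorm = dout * w                                       # (B,T,C)
    dx = dnorm - dnorm.mean(-1, keepdim=True) - norm * (dnorm * norm).mean(-1, keepdim=True)
    dx = dx * rstd                                         # (B,T,C)
    return dx, dw, db

dx, dw, db = layernorm_backward(wei, cache)  # dout of the fake loss is just wei
print((x.grad - dx).abs().max())             # all of these should be tiny, e.g. ~1e-7
print((w.grad - dw).abs().max())
print((b.grad - db).abs().max())
```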
Notice one more thing. Inside the backward pass we recomputed the variable norm
. We already calculated this variable in the forward pass but then we threw it away! Couldn’t we have made this also be a part of the cache
and save this recompute? Actually, we very well could and you’d of course get the exact same results. The amount of stuff we save into our cache
is completely up to us. We didn’t even have to save mean
and rstd
either, and we could have recomputed them in the backward pass. The difference is that mean
and rstd
are very small, only of shape B,T
, whereas norm
is of shape B,T,C
. So this is simply a tradeoff between memory and compute. By not keeping norm
in the cache, we are saving memory, but we are trading it off for a bit of compute later in the backward pass. This is very common in all the layers, and you’ll see that different implementations of various layers in deep learning frameworks may all have different “checkpointing settings”. Yes, confusingly enough, this is called checkpointing and has nothing to do with saving the model weights to disk. It’s about saving intermediate variables in the forward pass to save compute in the backward pass.
Okay so that’s the version with PyTorch tensors. Now we have to move this to C
and get rid of the Tensor abstraction. Before I give you the full implementation of the forward pass, a brief word on Tensors. What are Tensors? They are
- a 1D block of memory called Storage that holds the raw data, and
- a View over that storage that holds its shape.

The PyTorch Internals blog post could be helpful here. So for example, if we have the 3D tensor:
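(Using torch.arange here is just an illustrative choice, so that each element’s value equals its offset into Storage.)

```python
a = torch.arange(24).view(2, 3, 4)  # values 0..23, shape (2,3,4)
print(a)
```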
This is a 2x3x4 Tensor, but its underlying memory is just one single 1D array of size 2*3*4=24. The View is just a shape over this 1D array. So now when we index into this PyTorch tensor, for example a[1,2,3], PyTorch computes the offset into the 1D array as 1*3*4 + 2*4 + 3 = 23, and returns the value at that offset. The general formula is that if you want to retrieve any element b,t,c
, you compute the offset into Storage as b*T*C + t*C + c
. So for example:
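(Continuing with the illustrative tensor from above:)

```python
B, T, C = 2, 3, 4
b, t, c = 1, 2, 3
print(a[b, t, c])                          # index via the View
print(a.flatten()[b * T * C + t * C + c])  # index the flat 1D Storage directly
```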
Both of these print the same value (tensor(23) for the illustrative tensor above). So in this way, we know how to access all the individual elements, and how to offset all the pointers. Notice in particular that the channel dimension is the innermost dimension. So as we increase the offset by 1, we are traversing the channel dimension. This is important to consider for the memory layout of our C implementation.
C implementation
Forward pass
The equivalent forward pass in C becomes:
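Here is a sketch of what that can look like (the function signature and variable names are illustrative; it follows the b*T*C + t*C + c layout described above):

```c
#include <math.h>

void layernorm_forward(float* out, float* mean, float* rstd,
                       const float* inp, const float* weight, const float* bias,
                       int B, int T, int C) {
    float eps = 1e-5f;
    for (int b = 0; b < B; b++) {
        for (int t = 0; t < T; t++) {
            // seek to the input position inp[b,t,:]
            const float* x = inp + b * T * C + t * C;
            // calculate the mean over the C channels
            float m = 0.0f;
            for (int i = 0; i < C; i++) { m += x[i]; }
            m = m / C;
            // calculate the variance (without any bias correction)
            float v = 0.0f;
            for (int i = 0; i < C; i++) {
                float xshift = x[i] - m;
                v += xshift * xshift;
            }
            v = v / C;
            // the reciprocal standard deviation
            float s = 1.0f / sqrtf(v + eps);
            // seek to the output position out[b,t,:], then normalize, scale and shift
            float* out_bt = out + b * T * C + t * C;
            for (int i = 0; i < C; i++) {
                float n = s * (x[i] - m);            // normalize
                out_bt[i] = n * weight[i] + bias[i]; // scale and shift
            }
            // cache the mean and rstd for the backward pass
            mean[b * T + t] = m;
            rstd[b * T + t] = s;
        }
    }
}
```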
You’ll see how I offset the pointer to the inp[b,t]
, and then you know that the next C
elements are the channels of that position in (batch, time). And the backward pass:
Backward pass
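Here is a sketch of the corresponding backward pass (again the signature is illustrative; note that every gradient is accumulated with +=, which is the point discussed next):

```c
void layernorm_backward(float* dinp, float* dweight, float* dbias,
                        const float* dout, const float* inp, const float* weight,
                        const float* mean, const float* rstd,
                        int B, int T, int C) {
    for (int b = 0; b < B; b++) {
        for (int t = 0; t < T; t++) {
            // seek to the (b,t) position in the relevant tensors
            const float* dout_bt = dout + b * T * C + t * C;
            const float* inp_bt = inp + b * T * C + t * C;
            float* dinp_bt = dinp + b * T * C + t * C;
            float mean_bt = mean[b * T + t];
            float rstd_bt = rstd[b * T + t];

            // first pass: two reductions over the channels
            float dnorm_mean = 0.0f;
            float dnorm_norm_mean = 0.0f;
            for (int i = 0; i < C; i++) {
                float norm_bti = (inp_bt[i] - mean_bt) * rstd_bt; // recompute norm
                float dnorm_i = weight[i] * dout_bt[i];
                dnorm_mean += dnorm_i;
                dnorm_norm_mean += dnorm_i * norm_bti;
            }
            dnorm_mean = dnorm_mean / C;
            dnorm_norm_mean = dnorm_norm_mean / C;

            // second pass: accumulate all the gradients
            for (int i = 0; i < C; i++) {
                float norm_bti = (inp_bt[i] - mean_bt) * rstd_bt;
                float dnorm_i = weight[i] * dout_bt[i];
                dbias[i] += dout_bt[i];               // gradient into the bias
                dweight[i] += norm_bti * dout_bt[i];  // gradient into the weight
                float dval = dnorm_i;                 // term 1
                dval -= dnorm_mean;                   // term 2
                dval -= norm_bti * dnorm_norm_mean;   // term 3
                dval *= rstd_bt;                      // final scale
                dinp_bt[i] += dval;                   // gradient into the input
            }
        }
    }
}
```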
One additional detail to note is that we always accumulate (+=) into the gradients. We never use = and we never use *=. This is important stylistically, because if you have one variable used multiple times in a graph, the backward pass gradients always add up. In this repo this is not important because we don’t have exotic branching, but it’s proper. So during training we always first do zero_grad to set all the gradients to zero, and then we accumulate into them during the backward pass.
Difference between LayerNorm and RMSNorm
One more note on the differences between training and inference. Some of you may know RMSNorm or may have heard of it.
- Algorithmic Difference: One key difference between LayerNorm and RMSNorm lies in their algorithms. RMSNorm doesn’t calculate or subtract the mean from the input; instead, it normalizes solely by the root mean square (RMS) of the input. In other words, it normalizes using the RMS, not the standard deviation, since there’s no mean subtraction involved. This simplification has gained popularity because it works just as well, if not slightly better. Moreover, unlike LayerNorm, RMSNorm doesn’t incorporate a bias in its computation; it only includes a weight for scaling after normalization (see the C sketch after this list).
- Inference Simplification: In terms of implementation, there’s a notable difference in how inference is handled. While LayerNorm typically operates with a batch dimension (B), assuming a batch size greater than 1, RMSNorm can simplify the inference process by assuming a batch size of 1. This streamlines the codebase, removing the need for loops iterating over the batch dimension.
- Absence of Time Dimension in Individual Layers: Another difference arises in handling the time dimension (T) within individual layers during inference. While LayerNorm may loop over time within each layer during training, calculating the normalization across all time steps, RMSNorm simplifies this during inference. Token generation occurs sequentially, with each token predicted at time t fed into the forward pass at time t+1. Therefore, you won’t find loops iterating over the time dimension within individual layers in an inference implementation of RMSNorm.
- No Memory of Intermediate Calculations: Lastly, during inference, RMSNorm doesn’t retain intermediate calculations or a cache. This is because there’s no backward pass during inference, so there’s no need to keep track of intermediate variables, resulting in significantly lower memory consumption compared to training. Accordingly, there’s no backward function for RMSNorm, since there’s no backward pass during inference.
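To make the contrast concrete, here is a sketch of what an inference-time RMSNorm can look like in C, under the assumptions above (a single token’s activations of size C, no batch or time loops, no bias, nothing cached; the function and variable names are illustrative):

```c
#include <math.h>

// normalize one token's activations x[0..C-1] into o, scaling by weight;
// no mean subtraction, no bias, and nothing is cached for a backward pass
void rmsnorm_forward(float* o, const float* x, const float* weight, int C) {
    float eps = 1e-5f;
    // mean of squares
    float ss = 0.0f;
    for (int i = 0; i < C; i++) { ss += x[i] * x[i]; }
    ss = ss / C;
    // reciprocal root mean square
    float r = 1.0f / sqrtf(ss + eps);
    // normalize and scale
    for (int i = 0; i < C; i++) {
        o[i] = weight[i] * (r * x[i]);
    }
}
```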
Conclusion
This was just the LayerNorm. We go through the exact same process for all the other GPT layers. Most of the other layers are actually easier than LayerNorm. Hope that helps!