BitNet: Scaling 1-bit Transformers for LLMs
Background
- LLMs are essentially stacks of N layers of optimized matrices with floating-point weights.
- Generating text with an LLM means doing many matrix multiplications, and floating-point operations are expensive in terms of energy consumption.
- The numbers in these matrices have a certain precision (32-bit, 16-bit, 8-bit, etc.), and the energy cost varies with that precision.
- As LLMs grow in size, training and deploying them raises environmental concerns due to high energy consumption.
What did the authors try to accomplish?
- This paper proposes a Transformer architecture called BitNet, which uses 1-bit weights (+1 or -1).
- Usually, training and quantization are done separately (post-training quantization). In this paper, the authors propose doing the quantization during training.
- As a result, the memory footprint is reduced and inference is expected to be faster.
What are the key elements of the approach?
- The authors propose a new linear layer, BitLinear, with binary weights (+1 or -1) instead of conventional floating-point weights.
- LayerNorm is introduced before activation quantization to keep the model stable; it centers the activations to zero mean.
- $\gamma$ and $\beta$ are scaling factors computed per layer (not learned): $\gamma$ comes from the activations (their absolute maximum) and $\beta$ from the weights. Activations are quantized roughly as
$$ \bar{x} = \text{Clip}(x \times \frac{1}{\gamma}, -1, 1) $$
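To make the clipping formula concrete, here is a minimal PyTorch sketch of the activation path, assuming $\gamma$ is the absolute maximum of the activation tensor (absmax quantization). The function name is mine, and the rescaling into the full b-bit integer range used in the paper is omitted.

```python
import torch
import torch.nn.functional as F

def quantize_activations(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Simplified absmax activation quantization as in the formula above:
    LayerNorm, then scale by 1/gamma (gamma = max |x|), then clip to [-1, 1].
    The paper additionally rescales into the b-bit integer range; omitted here."""
    x = F.layer_norm(x, x.shape[-1:])          # LayerNorm before quantization (zero mean)
    gamma = x.abs().max().clamp(min=eps)       # per-tensor absmax scale
    return torch.clamp(x / gamma, -1.0, 1.0)   # Clip(x * 1/gamma, -1, 1)


x_bar = quantize_activations(torch.randn(4, 16))
assert x_bar.abs().max() <= 1.0
```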
Techniques used in BitLinear:
- Large learning rate: because the weights are binarized, a small update to the latent weights often doesn't change the 1-bit weights at all, so small learning rates have little effect.
- Straight-through estimator (STE) to approximate the gradient during backpropagation: the non-differentiable Sign and Clip functions are treated as the identity in the backward pass, so gradients flow straight through to the full-precision latent weights (see the sketch after this list).
- Mixed precision training: gradients and optimizer states are kept in high precision, while the weights and activations in the BitLinear layers are low precision; a high-precision latent copy of the weights is maintained for the updates.
- Weights are centered to zero mean before binarization, i.e. the mean $\alpha$ of the weight matrix is subtracted before applying the Sign function.
- A scaling factor $\beta$ is applied after binarization to reduce the L2 error between the original and binarized weights.
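Putting these techniques together, here is a minimal PyTorch sketch of a BitLinear-style forward pass (my illustration, not the paper's code): zero-mean centering, sign binarization with a straight-through estimator, and $\beta$ rescaling, with the latent weights kept in full precision. Activation quantization from the previous block is left out to keep it short.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BitLinearSketch(nn.Module):
    """Illustrative sketch of a BitLinear-style layer (not the official code):
    zero-mean centering -> sign binarization with a straight-through estimator
    -> beta rescaling. Latent weights stay in full precision so the optimizer
    can accumulate small updates (the mixed-precision idea)."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w = self.weight
        alpha = w.mean()
        w_centered = w - alpha               # center weights to zero mean before binarization
        with torch.no_grad():
            w_bin = torch.sign(w_centered)   # binarize to +1 / -1
            w_bin[w_bin == 0] = 1.0          # map exact zeros to +1 so weights stay strictly binary
        # Straight-through estimator: the forward pass uses the binary weights,
        # the backward pass treats binarization as the identity, so gradients
        # flow straight through to the full-precision latent weights.
        w_ste = w_centered + (w_bin - w_centered).detach()
        beta = w.abs().mean()                # scaling factor reducing L2 error vs. original weights
        return F.linear(x, w_ste) * beta


# One forward/backward pass to show that gradients reach the latent weights.
layer = BitLinearSketch(16, 8)
out = layer(torch.randn(4, 16))
out.sum().backward()
assert layer.weight.grad is not None
```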
What can we use/remember from the paper?
- If this approach proves scalable and stable on larger datasets, it would significantly reduce the cost of using LLMs.
- A new type of compute device specifically optimized for addition could take advantage of this approach, since matrix multiplication with +1/-1 weights reduces to additions and subtractions (as sketched below).
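As a tiny illustration (not from the paper) of why addition-optimized hardware helps: with strictly +1/-1 weights, a matrix-vector product is just sums and differences of activations, with no floating-point multiplications.

```python
import numpy as np

def matvec_pm1(w_sign: np.ndarray, x: np.ndarray) -> np.ndarray:
    """Matrix-vector product when the weights are strictly +1 / -1:
    each output element is a sum minus a sum of activations."""
    return np.array([x[row == 1].sum() - x[row == -1].sum() for row in w_sign])


# Tiny check against the ordinary floating-point matmul.
w = np.sign(np.random.randn(4, 8))
w[w == 0] = 1.0
x = np.random.randn(8)
assert np.allclose(matvec_pm1(w, x), w @ x)
```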