Kuzmin, van Baalen et al. (2024) explore how the choice of FP8 format (i.e., how the bits are split between mantissa and exponent) affects the accuracy of various neural network architectures (Transformer encoders, convolutional networks, etc.) on benchmark tasks. The authors conclude that for post-training quantization (PTQ), FP8 achieves better accuracy than INT8 across a variety of networks. Further, for networks with severe activation outliers, such as Transformers, allocating more bits to the exponent benefits accuracy.
Overview of Floating Point Numbers and the Precision/Range Tradeoff
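A signed \(N\)-bit floating point representation with an \(m\)-bit mantissa and an \(e\)-bit exponent (so \(N = 1 + e + m\)) is the following subset of the real numbers: \[ F = \left\{(-1)^s2^{p-b}\left(1 + \sum_{i=1}^m\frac{d_i}{2^i}\right) \bigg\lvert d_i \in \{0, 1\}, s \in \{0, 1\}, p \in \{0,1,\cdots,2^e-1\}\right\} \] Here \(b \in \mathbb{Z}\) is called the "exponent bias"; IEEE 754 formats use \(b = 2^{e-1}-1\) (e.g., \(b = 7\) for a 4-bit exponent), but other values can be chosen. Subtracting the bias from \(p\) allows for negative exponents. This definition covers only normal numbers: real formats additionally reserve some bit patterns for zero, subnormal numbers, and special values such as NaN.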
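As a quick sanity check of the definition, the sketch below enumerates \(F\) by brute force for a small format, applying the formula term by term. The function name `enumerate_fp` is illustrative, and the bias \(b = 1\) in the usage line is an assumption chosen to match the 4-bit system of Example 1.

```python
from itertools import product


def enumerate_fp(e: int, m: int, b: int) -> list[float]:
    """Enumerate the set F defined above for an e-bit exponent and m-bit mantissa.

    Every exponent value p in {0, ..., 2^e - 1} is treated as a normal number,
    exactly as in the formula, so zero, subnormals, infinities and NaN are absent.
    """
    values = []
    for s in (0, 1):                                # sign bit
        for p in range(2 ** e):                     # exponent field
            for d in product((0, 1), repeat=m):     # mantissa bits d_1 ... d_m
                significand = 1 + sum(d_i / 2 ** i for i, d_i in enumerate(d, start=1))
                values.append((-1) ** s * 2.0 ** (p - b) * significand)
    return sorted(values)


# 4-bit example: 1 sign bit, e = 1 exponent bit, m = 2 mantissa bits, bias b = 1
F = enumerate_fp(e=1, m=2, b=1)
print(len(F), F)
```

Because the formula has no special cases, all \(2^N = 16\) bit patterns map to distinct values (from -1.75 to 1.75), and 0 is not among them; real formats such as E4M3 reserve some patterns, as the next example shows.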
Example 1: Consider a 4-bit signed FP number system with 1 exponent bit, 2 mantissa bits, and exponent bias \(b = 1\). The number \(1011\) in this system represents \((-1)^1 2^{0-1} \left(1 + \frac{1}{2} + \frac{1}{2^2} \right) = -0.875\): \[ \underbrace{1}_{\text{sign bit}}\quad \underbrace{0}_{\text{exponent}}\quad \underbrace{11}_{\text{mantissa}} \]
Example 2: Let's calculate the maximum and minimum values for E4M3, one of the two formats proposed in FP8 Formats for Deep Learning and supported by H100 GPUs (along with E5M2). E4M3 has 1 sign bit, 4 exponent bits, 3 mantissa bits, and an exponent bias \(b = 7\).
- Maximum: The largest value of the 4-bit exponent field is \(p = 2^4-1 = 15\); subtracting the bias gives \(p-b = 8\). The largest mantissa value occurs when all mantissa bits are 1. However, E4M3 defines the pattern with all exponent and mantissa bits set to 1 as NaN, so the largest usable mantissa has bits \(110\), that is, \(1 + \frac{1}{2} + \frac{1}{4} = 1.75\). The maximum value represented in E4M3 is therefore \(1.75 \times 2^8 = 448\).
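- Minimum: The pattern with all exponent and mantissa bits set to 0 is defined as 0, and the exponent field \(p = 0\) is reserved for zero and subnormal numbers. The smallest positive normal number therefore has all mantissa bits set to 0 (making the significand 1) and only the least significant exponent bit set to 1, giving \(p = 1\) and \(p-b = -6\). This value is \(2^{-6} = 0.015625\). (Subnormals, which drop the implicit leading 1, extend the representable range further down to \(2^{-9}\).)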
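Both examples can be checked with a small decoder. In this sketch, `decode` applies only the normal-number formula for \(F\) to a bit string, while `decode_e4m3` adds the E4M3 special cases: the single NaN pattern, and an exponent field of 0, where the implicit leading 1 is dropped and the exponent is fixed at \(1 - b\) (zero and subnormals). Both function names are hypothetical, and this is an illustration rather than a reference decoder.

```python
import math


def decode(bits: str, e: int, m: int, b: int) -> float:
    """Decode 'sign | exponent | mantissa' bits with the normal-number formula only."""
    assert len(bits) == 1 + e + m
    s = int(bits[0])
    p = int(bits[1:1 + e], 2)
    significand = 1 + sum(int(d) / 2 ** i for i, d in enumerate(bits[1 + e:], start=1))
    return (-1) ** s * 2.0 ** (p - b) * significand


def decode_e4m3(bits: str) -> float:
    """Decode an 8-bit E4M3 pattern (bias 7), including zero, subnormals and NaN."""
    assert len(bits) == 8
    sign = -1.0 if bits[0] == "1" else 1.0
    p, mant = int(bits[1:5], 2), int(bits[5:], 2)
    if p == 0b1111 and mant == 0b111:
        return math.nan                              # S.1111.111 is E4M3's only NaN
    if p == 0:
        return sign * 2.0 ** (1 - 7) * (mant / 8)    # zero (mant == 0) or subnormal
    return sign * 2.0 ** (p - 7) * (1 + mant / 8)    # normal number


print(decode("1011", e=1, m=2, b=1))  # Example 1: -0.875
print(decode_e4m3("01111110"))        # largest finite value: 448.0
print(decode_e4m3("00001000"))        # smallest positive normal: 0.015625
print(decode_e4m3("00000001"))        # smallest positive subnormal: 0.001953125
```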
The authors note:
Floating point numbers can be seen as a uniform m-bit grid between two consecutive (integer) powers of two \(2^a, 2^{a+1}\). The distance between grid points in the range \(\left[2^a, 2^{a+1}\right]\) is \(2^{a-m}\).
To understand this, fix \(p\) and \(b\) and define \(a = p-b\). The \(m\) mantissa bits can represent \(2^m\) different values (each bit can take two values, giving \(2 \times 2 \times \cdots \times 2 = 2^m\) combinations). The smallest significand \(1 + \sum_{i=1}^m\frac{d_i}{2^i}\) is \(1\), when \(d_i = 0 \quad \forall i\), and the largest is \(2\left(1-\frac{1}{2^{m+1}}\right) = 2 - 2^{-m}\), when \(d_i = 1 \quad \forall i\). The smallest and largest floating point numbers with this exponent are therefore \(2^a\) and \(2^{a+1}\left(1-\frac{1}{2^{m+1}}\right) < 2^{a+1}\), respectively. Thus, within \(\left[2^a, 2^{a+1}\right)\) there are \(2^m\) representable values (including \(2^a\)). This is what the authors refer to as the "m-bit grid". Lastly, the interval \(\left[2^a, 2^{a+1}\right]\) has length \(2^a\). Consecutive significands differ by \(2^{-m}\), so the \(2^m\) grid points together with \(2^{a+1}\) (the first point of the next grid) split this interval into \(2^m\) equal segments, each of length \(2^a / 2^m = 2^{a-m}\).
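The argument can also be checked numerically. The sketch below builds the \(2^m\) values in one interval \([2^a, 2^{a+1})\) directly from the mantissa bits, without assuming uniform spacing, and then measures the gaps between neighbors. `binade_values` is an illustrative name.

```python
from itertools import product


def binade_values(a: int, m: int) -> list[float]:
    """All values 2^a * (1 + sum_i d_i / 2^i) over every choice of m mantissa bits."""
    return sorted(
        2.0 ** a * (1 + sum(d_i / 2 ** i for i, d_i in enumerate(d, start=1)))
        for d in product((0, 1), repeat=m)
    )


m = 3
for a in (-2, 0, 3):
    vals = binade_values(a, m)
    gaps = {y - x for x, y in zip(vals, vals[1:])}  # set of distinct spacings
    # Expect 2^m points starting at 2^a, a single common gap 2^(a-m),
    # and the last point exactly one gap below 2^(a+1)
    print(a, len(vals), vals[0], gaps, vals[-1] + 2.0 ** (a - m) == 2.0 ** (a + 1))
```

For \(m = 3\) each line reports 8 points and a single gap (0.03125, 0.125 and 1.0 for \(a = -2, 0, 3\)), i.e., \(2^{a-3}\).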
The precision/range tradeoff (with \(N = 1 + e + m\) fixed, every bit given to the exponent is taken from the mantissa, and vice versa):
- A larger number of exponent bits \(e\) allows a wider range of exponents \(p - b\), enabling the representation of a wider range of magnitudes (e.g., larger floating point numbers).
- A larger number of mantissa bits \(m\) allows more representable numbers (\(2^m\)) within any given range \([2^a, 2^{a+1})\), thereby increasing precision.
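To see the tradeoff numerically, this sketch tabulates every split of \(N = 8\) bits from E2M5 to E5M2, reporting the largest and smallest positive values of the set \(F\) and the grid spacing relative to \(2^a\). It assumes an IEEE-style bias \(b = 2^{e-1}-1\) and ignores the encodings real formats reserve, so, for example, it reports 480 rather than E4M3's actual maximum of 448.

```python
N = 8  # total bits: 1 sign + e exponent + m mantissa


def tradeoff() -> list[tuple[str, float, float, float]]:
    """Max, min and relative grid spacing of F for each exponent/mantissa split of N bits.

    Assumes the simplified set F with an IEEE-style bias b = 2^(e-1) - 1; real
    formats reserve some encodings, so their exact limits differ slightly.
    """
    rows = []
    for e in range(2, 6):
        m = N - 1 - e
        b = 2 ** (e - 1) - 1
        largest = 2.0 ** ((2 ** e - 1) - b) * (2 - 2.0 ** -m)  # p = 2^e - 1, all d_i = 1
        smallest = 2.0 ** -b                                    # p = 0, all d_i = 0
        rel_step = 2.0 ** -m                                    # 2^(a-m) / 2^a
        rows.append((f"E{e}M{m}", largest, smallest, rel_step))
    return rows


for name, largest, smallest, rel_step in tradeoff():
    print(f"{name}: max={largest:g}  min={smallest:g}  relative step={rel_step:g}")
```

In this table each extra exponent bit roughly squares the dynamic range (max/min), while doubling the spacing between grid points relative to their magnitude.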
Based on this, the authors conclude:
... compared to integer formats, floating point formats have more precision close to zero ... and less precision away from zero.
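The quoted claim can be illustrated by comparing gaps between neighboring representable values. The sketch below lists the non-negative finite E4M3 values (bias 7, with subnormals) and compares them with a uniform INT8-style grid whose step is chosen so that 127 steps reach E4M3's maximum of 448; matching the two ranges this way is an assumption made only for the comparison, and `e4m3_nonnegative_values` is a hypothetical helper.

```python
def e4m3_nonnegative_values() -> list[float]:
    """Sorted non-negative finite E4M3 values (bias 7), including zero and subnormals."""
    vals = set()
    for p in range(16):
        for mant in range(8):
            if p == 15 and mant == 7:
                continue                                   # S.1111.111 encodes NaN
            if p == 0:
                vals.add(2.0 ** -6 * mant / 8)             # zero and subnormals
            else:
                vals.add(2.0 ** (p - 7) * (1 + mant / 8))  # normal numbers
    return sorted(vals)


fp = e4m3_nonnegative_values()
fp_gaps = [y - x for x, y in zip(fp, fp[1:])]
int_step = 448 / 127  # uniform INT8-style grid scaled to the same maximum (assumption)

print(f"E4M3 gap next to zero: {fp_gaps[0]:.6f}")   # 2^-9
print(f"E4M3 gap next to 448:  {fp_gaps[-1]:.1f}")  # 448 - 416
print(f"INT8 gap everywhere:   {int_step:.4f}")
n_int_below_1 = sum(k * int_step < 1 for k in range(128))
print(f"values below 1: E4M3 {sum(v < 1 for v in fp)}, INT8 {n_int_below_1}")
```

Under this scaling, the E4M3 grid is far finer than the uniform grid near zero (56 values below 1 versus just 0) and far coarser near the maximum (a gap of 32 versus about 3.5).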
FP8 Quantization
NVIDIA's FP8 Primer illustrates floating point quantization using four formats (FP8 E4M3 and E5M2, FP16 and BF16). I used the authors' FP8 quantization scheme to replicate NVIDIA's examples in this spreadsheet. Here is the authors' intuition behind FP8 quantization:
- Say a positive number \(x \in \mathbb{R}\) (the sign is handled separately by the sign bit) lies within the range \([2^a, 2^{a+1})\) for some integer \(a\). In a fixed precision floating point system (mantissa of \(m\) bits), \(x\) is represented by its nearest point on the m-bit grid between \(2^a\) and \(2^{a+1}\), and that is the point we want to find.
- To find \(a\), note that \(2^a \le x < 2^{a+1} \implies a \le \log_2 x < a+1 \implies a = \lfloor \log_2 x \rfloor\). For an m-bit grid on \([2^a, 2^{a+1})\), the distance between consecutive grid points is \(s = 2^{a-m} = 2^{\lfloor \log_2 x \rfloor - m}\) (refer to the previous section).
- To represent \(x\) in our floating point number system, we first quantize \(\frac{x}{s}\) to an integer by computing \(\big\lfloor \frac{x}{s}\big\rceil\), where \(\lfloor\cdot\rceil\) denotes rounding to the nearest integer. Since \(2^a \le x < 2^{a+1}\) and \(s = 2^{a-m}\), we have \(2^m \le \frac{x}{s} < 2^{m+1}\), so \(\big\lfloor \frac{x}{s}\big\rceil = 2^m + n\) for some \(n \in \{0,1,\cdots, 2^m\}\). In other words, \(\big\lfloor \frac{x}{s}\big\rceil\) indexes a point on our m-bit grid (the case \(n = 2^m\) corresponds to \(x\) rounding up to \(2^{a+1}\), the first point of the next grid).
- Finally, we map this integer back to the real line by scaling it by \(s\): \(x^{(q)} = s\big\lfloor \frac{x}{s}\big\rceil = 2^a + ns\), which is the floating point number representing \(x\).
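The steps above translate almost line for line into code. This is a minimal sketch for positive \(x\) only: it uses `math.frexp` to obtain \(\lfloor \log_2 x \rfloor\) exactly (avoiding floating point error in `log2` near powers of two), and it deliberately omits clipping to the format's maximum value and the special handling a real FP8 format needs below its smallest normal number. `fp_quantize` is a hypothetical name, not the authors' code.

```python
import math


def fp_quantize(x: float, m: int) -> float:
    """Round x > 0 to the nearest point of the m-bit floating point grid.

    a = floor(log2 x), s = 2^(a - m), x_q = s * round(x / s).
    No clipping to a maximum value and no subnormal handling (sketch only).
    """
    if not x > 0:
        raise ValueError("this sketch handles x > 0 only; apply the sign separately")
    _, k = math.frexp(x)  # x = f * 2^k with 0.5 <= f < 1, hence floor(log2 x) = k - 1
    a = k - 1
    s = 2.0 ** (a - m)    # grid spacing in [2^a, 2^(a+1))
    # round() returns the nearest integer, breaking ties toward the even integer
    return s * round(x / s)


for x in (0.1, 2.125, 3.3, 3.9, 300.0):
    print(x, "->", fp_quantize(x, m=3))
```

With \(m = 3\), inputs in E4M3's normal range land on E4M3 values: 3.3 becomes 3.25, 300 becomes 288, and 3.9 rounds up to 4.0, the first point of the next grid. Python's `round` breaks ties toward the even integer; since \(2^m\) is even, this picks the grid point whose last mantissa bit is 0 (so 2.125 becomes 2.0).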