In our last deep dive, we dissected the `int8` implementation of Softmax. Now, we’ll turn our attention to another critical component of the Transformer architecture: the Gaussian Error Linear Unit, or GELU.
GELU is the modern default activation function, outperforming ReLU in models like BERT, GPT, and Vision Transformers. It introduces a probabilistic non-linearity: instead of hard-gating inputs at zero, it weights each input by the probability that a standard Gaussian variable falls below it.
The mathematical definition is smooth and elegant, but it’s a hardware designer’s worst enemy. It relies on the Gaussian cumulative distribution function, \(\Phi(x)\).
\[\text{GELU}(x) = x \cdot \Phi(x)\]

A common high-performance approximation uses the tanh function, which is itself a transcendental function and thus very expensive to compute directly.
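For reference, that tanh form (the same `approximate="tanh"` variant PyTorch implements, and the one the runnable example at the end of this post compares against) is:

\[\text{GELU}(x) \approx 0.5\,x\left(1 + \tanh\left(\sqrt{\tfrac{2}{\pi}}\,\bigl(x + 0.044715\,x^{3}\bigr)\right)\right)\]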
Just like with Softmax, a naive floating-point implementation of this would be a major performance bottleneck. So, how does a real AI accelerator compute GELU using only simple integer operations? Let’s uncover the secrets behind this integer-based wizardry. 🧙‍♂️
This exploration is another key step in our journey to build a complete linear algebra library, demonstrating how to quantize complex, non-linear functions for `int16` hardware.
## The Big Idea: Approximate, Don’t Calculate
The core secret to integer GELU is to abandon the exact formula and instead build a cheap piecewise approximation. The smooth GELU curve is replaced by a few simpler pieces (such as quadratics), joined end to end, that are easy to compute with integers.
The goal is to create a function that is “close enough” to the real GELU but can be calculated using only the fast operations available on a chip: addition, multiplication, and bit-shifts.
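To make “only fast operations” concrete, here is a tiny sketch of the fixed-point trick that underlies everything that follows: a multiplication by a floating-point constant becomes an integer multiply followed by a bit-shift. The function name and the example constant 0.8413 are made up purely for illustration.

```python
# Illustrative fixed-point scaling: approximate y = x * 0.8413 using only
# integer multiply and shift. 0.8413 * 2**15 ≈ 27568, so multiplying by
# 27568 and shifting right by 15 bits is (almost) the same thing.
def fixed_point_scale(x_int: int) -> int:
    MULTIPLIER = 27568  # round(0.8413 * 2**15)
    SHIFT = 15
    return (x_int * MULTIPLIER) >> SHIFT

print(fixed_point_scale(8192))   # 6892
print(round(8192 * 0.8413))      # 6892 (float reference)
```

Every “magic number” in the sections below is, at heart, a pre-computed multiplier or shift of exactly this kind.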
## Secret #1: The “Compiler” Finds the Magic Numbers
The first step in this process doesn’t happen during inference. It happens once, during model compilation. A function, let’s call it a coefficient generator, acts as a “compiler.” Its job is to analyze the number system (i.e., the quantization scales) and pre-calculate a set of integer “magic numbers” or coefficients.
These coefficients (`b_int`, `c_int`, `out_multiplier`, `shift_int`, etc.) are the heart of the trick. They bake all the complex floating-point math from the GELU formula and the quantization scaling into simple integers that the hardware can use.
```python
# The "Compiler" stage:
# This function is run once to prepare for the integer calculation.
def generate_gelu_coefficients(input_scalar, out_scalar):
    # ... complex floating point math using logarithms and constants ...
    # ... to calculate optimal integer coefficients and shifts ...
    # ... that will best approximate the GELU curve for the given scales.
    b_int = ...
    c_int = ...
    out_multiplier = ...
    shift_int = ...
    final_out_shift = ...
    return (b_int, c_int, out_multiplier, shift_int, final_out_shift)
```
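The coefficient math above is deliberately left abstract, since the real generator is specific to the GELU curve and the chosen scales. But the general recipe for turning a single floating-point factor into an integer multiplier plus a shift is standard fixed-point practice. Here is a minimal sketch of that recipe; `quantize_multiplier`, the 15-bit width, and the example factor 0.5585 are my own illustration, not this compiler’s actual code.

```python
import math

def quantize_multiplier(real_multiplier: float, bits: int = 15):
    """Split a positive float factor into (int_multiplier, right_shift) such that
    (x * int_multiplier) >> right_shift ~= x * real_multiplier."""
    if real_multiplier <= 0:
        raise ValueError("expected a positive multiplier")
    # real_multiplier = mantissa * 2**exponent, with 0.5 <= mantissa < 1
    mantissa, exponent = math.frexp(real_multiplier)
    int_multiplier = round(mantissa * (1 << bits))  # mantissa as a `bits`-bit fixed-point value
    right_shift = bits - exponent                   # undoes both the mantissa scaling and the exponent
    return int_multiplier, right_shift

m, s = quantize_multiplier(0.5585)   # an arbitrary example factor
print(m, s)                          # 18301 15
print((8192 * m) >> s)               # 4575  (8192 * 0.5585 ≈ 4575.2)
```

A coefficient generator typically does this for every scale factor in the GELU formula, folds the input and output quantization scales into those factors, and rounds everything to the integer widths the hardware supports.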
This stage also cleverly calculates a maximum safe input value (`input_max`) by testing where the integer approximation would overflow. This prevents corrupted results during the actual calculation.
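The article doesn’t show that search, but the idea is easy to sketch. Assuming a hypothetical `simulate_pipeline(q)` helper that replays the integer GELU steps for input `q` and yields every intermediate product and sum, a compiler could do something like:

```python
INT32_MIN, INT32_MAX = -2**31, 2**31 - 1

def find_input_max(simulate_pipeline, int16_max=32767):
    """Largest input magnitude whose intermediate values all fit in the
    32-bit accumulators assumed by the hardware (a sketch, not real code)."""
    safe_max = 0
    for q in range(int16_max + 1):
        if all(INT32_MIN <= v <= INT32_MAX for v in simulate_pipeline(q)):
            safe_max = q
        else:
            break  # first overflow found; keep the last safe value
    return safe_max
```

The runnable example later in this post simply hard-codes `input_max = 32767`, but a real compiler might run a check like this against the generated coefficients.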
## Secret #2: The “Hardware” Performs the Integer-Only Dance
The function that performs the core computation is the blueprint for the actual hardware circuit. It takes an integer input and the pre-calculated magic coefficients and performs the GELU approximation using only fast integer math.
Here’s a simplified breakdown of the hardware pipeline it represents:
- **Safety First (Clamp):** The input integer (`in_data`) is first clipped to the pre-calculated `input_max` to guarantee the calculation won’t overflow.
- **Piecewise Logic (Minimum):** The hardware determines which “piece” of the approximation to use. A common technique is to use the `minimum` of the input’s absolute value and a coefficient (`b_int_pos`). This is a clever way of saying: “If the input’s absolute value is less than `b_int_pos`, use one formula. If it’s larger, use a different (saturated) value.” This creates the bend in the GELU curve.
- **Quadratic Approximation (The Bend):** For the primary part of the curve, a simple quadratic function `(input + b)^2` is used. This is extremely fast for hardware to compute: just one addition and one multiplication. The bit-shift `>> c_int_shift` is used to scale the result correctly.
- **Final Assembly (Multiply and Shift):** The final result is assembled by combining the result of the approximation with the original input, mimicking the `x * sigmoid(...)` structure of GELU. This is again done with a fast multiplication and a final bit-shift (`>> out_shift`) to produce the final `int16` output.
This sequence of simple integer operations successfully approximates the complex GELU function without a single floating-point calculation.
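Before we get to the full tensor-based script in the next section, here is the same four-step pipeline written as plain scalar Python. The function name and signature are placeholders that mirror the coefficient names used later; the point is only to show how few operations are involved.

```python
def int_gelu_scalar(q, b_int_pos, c_int, c_int_shift, shift_int,
                    out_multiplier, out_shift, input_max):
    """Process one integer input using only clamps, adds, multiplies, and shifts."""
    # 1. Safety clamp
    q = max(-input_max, min(q, input_max))
    # 2. Piecewise logic: saturate |q| at b_int_pos
    t = min(abs(q), b_int_pos) - b_int_pos
    # 3. Quadratic approximation of the sigmoid-like factor
    sig = ((t * t) >> c_int_shift) + c_int
    sig = (-sig if q < 0 else sig) + shift_int
    # 4. Final assembly: q * sigmoid_approx, rescaled by a multiply and a shift
    q_adj = (q * out_multiplier) >> 15
    return (sig * q_adj) >> out_shift
```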
## Bringing it to Life: A Runnable Example
To make these concepts easier to understand and experiment with, here is a simplified, self-contained Python script.
This script demonstrates the core “hardware” logic. We will hard-code the magic coefficients that the “compiler” would normally generate. This lets us focus on how the integer calculation works.
The script will:
- Define the true floating-point GELU function for comparison.
- Define our simplified `int16` GELU approximation.
- Take a sample float tensor and quantize it to `int16`.
- Run the integer GELU.
- Dequantize the result and compare it to the true float output.
```python
import torch
import numpy as np

# --- For Comparison: The real GELU function ---
def float_gelu(x):
    """Standard GELU function using floating point math."""
    return torch.nn.functional.gelu(x, approximate="tanh")

# --- Our simplified integer-only GELU approximation ---
def int16_gelu_approximation(x_int):
    """
    A simplified integer-only GELU that mimics the hardware logic.
    We hard-code the 'magic numbers' for this educational example.
    """
    # These coefficients would be generated by the "compiler" step
    # based on the specific quantization scales.
    b_int_pos = np.int16(3665)        # Threshold for the piecewise approximation
    c_int = np.int16(-1605)           # Quadratic adjustment coefficient
    c_int_shift = np.int32(6)         # Bit-shift for the quadratic part
    shift_int = np.int16(17861)       # Additive shift to approximate sigmoid
    out_multiplier = np.int16(18300)  # Final multiplier
    out_shift = np.int32(18)          # Final bit-shift for scaling

    # 1. Safety Clamp (using a reasonable hard-coded max)
    input_max = np.int16(32767)
    x_clamped = torch.clamp(x_int, -int(input_max), int(input_max))

    # 2. Piecewise Logic
    x_abs = torch.abs(x_clamped)
    min_int = torch.minimum(x_abs, torch.tensor(b_int_pos, dtype=torch.int16))

    # 3. Quadratic Approximation
    # Note: We use 32-bit integers for intermediate calculations to avoid overflow
    min_int_adj = min_int.to(torch.int32) - b_int_pos
    square = (min_int_adj * min_int_adj) >> c_int_shift
    sigmoid_base = square + c_int

    # Apply the sign and the additive shift to approximate the sigmoid part of GELU
    sign = torch.where(x_clamped < 0, -1, 1).to(torch.int32)
    sigmoid_int = sigmoid_base * sign + shift_int

    # 4. Final Assembly (x * sigmoid_approx)
    x_adj_32 = (x_clamped.to(torch.int32) * out_multiplier) >> 15
    output_32 = (sigmoid_int * x_adj_32) >> out_shift

    # Clip to final int16 range
    return torch.clamp(output_32, -32768, 32767).to(torch.int16)

if __name__ == "__main__":
    # --- Setup ---
    # Create a sample float tensor
    float_input = torch.tensor([-3.0, -1.0, 0.0, 0.5, 1.0, 2.0, 4.0])

    # Define our quantization scale.
    # This scale maps the float values to the int16 range [-32768, 32767]
    # For example, a float value of 1.0 becomes the integer 8192.
    scale = 2**13  # 8192.0

    # --- Run the comparison ---
    print("--- Running GELU Comparison ---")
    print(f"Original Floats:\n{float_input}\n")

    # 1. Calculate the true result using floating-point math
    float_output = float_gelu(float_input)
    print(f"True Float GELU Output:\n{float_output.numpy()}\n")

    # 2. Quantize our input from float to int16
    # (saturate to the int16 range so a value like 4.0 * 8192 = 32768 doesn't wrap around)
    quantized_input = torch.clamp((float_input * scale).round(), -32768, 32767).to(torch.int16)
    print(f"Quantized Input (int16):\n{quantized_input.numpy()}\n")

    # 3. Run our integer-only GELU approximation
    quantized_output = int16_gelu_approximation(quantized_input)
    print(f"Quantized Output (int16):\n{quantized_output.numpy()}\n")

    # 4. Dequantize the result back to float to see how close we got
    reconstructed_float_output = quantized_output.to(torch.float32) / scale
    print(f"Reconstructed Float Output (from int math):\n{reconstructed_float_output.numpy()}\n")

    # 5. Calculate the error
    error = torch.mean(torch.abs(float_output - reconstructed_float_output))
    print("--- Finished ---")
    print(f"\nAverage Absolute Error: {error.item():.6f}")
```