[EdgeAI - Part 3]: The Integer Trick - How Hardware Efficiently Approximates GELU


In our last deep dive, we dissected the int8 implementation of Softmax. Now, we’ll turn our attention to another critical component of the Transformer architecture: the Gaussian Error Linear Unit, or GELU.

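GELU has become the default activation function in Transformer models such as BERT, GPT, and Vision Transformers, where it has typically outperformed ReLU. It introduces a probabilistic non-linearity: rather than gating each input hard at zero by its sign, as ReLU does, it weights the input by the probability that a standard normal random variable falls below it.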

The mathematical definition is smooth and elegant, but it’s a hardware designer’s worst enemy: it relies on the Gaussian cumulative distribution function, \(\Phi(x)\), which has no closed form in elementary functions.

\[\text{GELU}(x) = x \cdot \Phi(x)\]
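For reference, \(\Phi\) is usually expressed through the error function, which is itself defined by an integral:

\[\Phi(x) = \frac{1}{2}\left[1 + \operatorname{erf}\left(\frac{x}{\sqrt{2}}\right)\right], \qquad \operatorname{erf}(z) = \frac{2}{\sqrt{\pi}} \int_0^{z} e^{-t^2}\, dt\]

Substituting gives \(\text{GELU}(x) = \tfrac{1}{2}\, x \left[1 + \operatorname{erf}\left(x/\sqrt{2}\right)\right]\), the form into which an approximation of erf can be plugged.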

A common, faster approximation replaces \(\Phi(x)\) with a tanh expression. However, tanh is itself a transcendental function, so it is still expensive to compute directly in hardware.

\[\text{GELU}(x) \approx 0.5 \cdot x \cdot \left(1 + \tanh\left[\sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3\right)\right]\right)\]
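To see how close the tanh form is, here is a minimal Python sketch (standard library only) that evaluates the exact definition, writing \(\Phi(x)\) as \(\tfrac{1}{2}[1 + \operatorname{erf}(x/\sqrt{2})]\), alongside the tanh approximation on a grid and reports the largest gap. The grid of [-6, 6] in steps of 0.01 is an arbitrary choice for illustration.

```python
import math

def gelu_exact(x: float) -> float:
    # x * Phi(x), with the Gaussian CDF written via the error function
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_tanh(x: float) -> float:
    # The tanh approximation from the formula above
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))

grid = [i / 100 for i in range(-600, 601)]
max_diff = max(abs(gelu_exact(x) - gelu_tanh(x)) for x in grid)
print(f"max |exact - tanh| on [-6, 6]: {max_diff:.6f}")
```

The gap stays below 10^-3 on this grid, which is why the tanh form is a popular drop-in. For integer hardware the problem is tanh itself, not its accuracy.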

Just like with Softmax, a naive floating-point implementation of this would be a major performance bottleneck. So, how does a real AI accelerator compute GELU using only simple integer operations? Let’s uncover the secrets behind this integer-based wizardry. 🧙‍♂️

This exploration is another key step in our journey to build a complete linear algebra library, demonstrating how to quantize complex, non-linear functions for int16 hardware.


## The Big Idea: Approximate, Don’t Calculate

The core secret to integer GELU is to abandon the exact formula and instead build a cheap piecewise approximation. The smooth GELU curve is replaced by a few simple pieces, such as a quadratic segment that flattens into constant (saturated) regions, all of which are easy to compute with integers.

The goal is to create a function that is “close enough” to the real GELU but can be calculated using only the fast operations available on a chip: addition, multiplication, and bit-shifts.
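As a concrete instance of the idea, here is a floating-point sketch of one such piecewise approximation. The article doesn’t spell out which pieces its hardware uses; the clipped quadratic below, with constants a = -0.2888 and b = -1.769, is an assumption borrowed from I-BERT’s integer GELU. It approximates erf with a quadratic whose vertex sits exactly where it meets a flat piece at ±1, then plugs that into GELU(x) = ½x[1 + erf(x/√2)].

```python
import math

# Assumed constants (borrowed from I-BERT's integer GELU, not stated in this article)
A, B = -0.2888, -1.769

def erf_poly(u: float) -> float:
    """Piecewise erf: a quadratic for |u| < -B, flat at +/-1 beyond it."""
    clipped = min(abs(u), -B)
    return math.copysign(A * (clipped + B) ** 2 + 1.0, u)

def gelu_poly(x: float) -> float:
    # GELU(x) = x * Phi(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
    return 0.5 * x * (1.0 + erf_poly(x / math.sqrt(2.0)))

def gelu_exact(x: float) -> float:
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

grid = [i / 100 for i in range(-600, 601)]
max_err = max(abs(gelu_poly(x) - gelu_exact(x)) for x in grid)
print(f"max |poly - exact| on [-6, 6]: {max_err:.4f}")
```

Every operation here (abs, min, one add, one square, a sign flip) has a cheap integer counterpart, which is what makes this shape attractive for hardware. The price is accuracy: the maximum error on this grid is on the order of 10^-2.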


## Secret #1: The “Compiler” Finds the Magic Numbers

The first step in this process doesn’t happen during inference. It happens once, during model compilation. A function, let’s call it a coefficient generator, acts as a “compiler.” Its job is to analyze the number system (i.e., the quantization scales) and pre-calculate a set of integer “magic numbers” or coefficients.

These coefficients (b_int, c_int, out_multiplier, shift_int, etc.) are the heart of the trick. They bake all the complex floating-point math from the GELU formula and the quantization scaling into simple integers that the hardware can use.

```python
# The "Compiler" stage:
# This function is run once to prepare for the integer calculation.
def generate_gelu_coefficients(input_scalar, out_scalar):
    # ... complex floating point math using logarithms and constants ...
    # ... to calculate optimal integer coefficients and shifts ...
    # ... that will best approximate the GELU curve for the given scales.

    b_int = ...
    c_int = ...
    out_multiplier = ...
    shift_int = ...
    final_out_shift = ...

    return (b_int, c_int, out_multiplier, shift_int, final_out_shift)
```
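The generator’s body is elided above, so its exact method isn’t shown. As a purely hypothetical stand-in, the sketch below, derive_gelu_coefficients, derives int16 constants for a power-of-two input scale. It assumes the clipped-quadratic erf fit from I-BERT (a = -0.2888, b = -1.769) and a two-stage output multiply, (x * out_multiplier) >> 15 followed by >> out_shift. It does not use logarithms and should not be read as the article’s generator.

```python
import math

A, B = -0.2888, -1.769  # assumed erf fit: sgn(u) * (A * (min(|u|, -B) + B)**2 + 1)

def derive_gelu_coefficients(frac_bits: int = 13, square_shift: int = 14):
    """Hypothetical 'compiler' for an int16 input whose scale is 2**-frac_bits."""
    # Integer units per unit of u = x / sqrt(2), the argument of erf
    inv_scale_u = (1 << frac_bits) * math.sqrt(2.0)
    # |x| (in integer units) beyond which the quadratic is clipped to a flat 1.0
    b_int_pos = round(-B * inv_scale_u)
    # (q - b_int_pos)**2 >> square_shift carries a real-valued weight of
    # |A| * 2**square_shift / inv_scale_u**2; c_int is the integer that means 1.0
    c_int = round(inv_scale_u ** 2 / (abs(A) * 2 ** square_shift))
    # Output = x * (1 + erf) / 2, with (1 + erf) stored in units of 1 / c_int, so the
    # combined rescale must be 1 / (2 * c_int) = out_multiplier / 2**(15 + out_shift)
    out_shift = 15
    out_multiplier = round(2 ** (15 + out_shift) / (2 * c_int))
    # Every constant must fit in a positive int16 register
    for name, value in [("b_int_pos", b_int_pos), ("c_int", c_int),
                        ("out_multiplier", out_multiplier)]:
        if not 0 < value <= 32767:
            raise ValueError(f"{name}={value} does not fit in int16; adjust the shifts")
    return b_int_pos, c_int, square_shift, out_multiplier, out_shift

print(derive_gelu_coefficients())
```

With the default Q13 scale it returns (20494, 28366, 14, 18927, 15). The range check is the key design point: if a shift is chosen too small, a constant no longer fits in an int16 register, and the generator refuses rather than silently wrapping.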

This stage also cleverly calculates a maximum safe input value (input_max) by testing where the integer approximation would overflow. This prevents corrupted results during the actual calculation.
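Here is one way such an overflow test could look, as a hypothetical sketch. find_input_max assumes a pipeline shaped like the final stage of the example later in this post: the input is first scaled as (x * out_multiplier) >> 15 and then multiplied by a sigmoid term no larger than sigmoid_max, with both products held in int32. Because Python integers never overflow, the host can compute each product exactly and stop at the first input magnitude that would leave the int32 range.

```python
INT32_MAX = 2**31 - 1

def find_input_max(out_multiplier: int, sigmoid_max: int, acc_shift: int = 15) -> int:
    """Largest int16 magnitude whose intermediate products stay within int32 (hypothetical)."""
    input_max = 0
    for q in range(32768):
        first = q * out_multiplier                  # product before the >> acc_shift
        second = sigmoid_max * (first >> acc_shift)  # product with the sigmoid term
        if first > INT32_MAX or second > INT32_MAX:
            break  # both products grow with q, so every larger q overflows too
        input_max = q
    return input_max

# Illustrative constants only
print(find_input_max(out_multiplier=30000, sigmoid_max=100000))
```

A linear scan over 32768 values is cheap at compile time; a binary search or a closed-form bound would also work, because both products grow monotonically with the input magnitude.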


## Secret #2: The “Hardware” Performs the Integer-Only Dance

The function that performs the core computation is the blueprint for the actual hardware circuit. It takes an integer input and the pre-calculated magic coefficients and performs the GELU approximation using only fast integer math.

Here’s a simplified breakdown of the hardware pipeline it represents:

  1. Safety First (Clamp): The input integer (in_data) is first clipped to the pre-calculated input_max to guarantee the calculation won’t overflow.

  2. Piecewise Logic (Minimum): The hardware determines which “piece” of the approximation to use. A common technique is to take the minimum of the input’s absolute value and a coefficient (b_int_pos). This is a compact way of saying: “If the input’s absolute value is less than b_int_pos, use the curved formula. If it’s larger, use a fixed (saturated) value.” This separates the curved region near zero from the flat regions on either side, where GELU(x) approaches x for large positive inputs and 0 for large negative ones.

  3. Quadratic Approximation (The Bend): For the curved part, a simple quadratic (|x| + b)^2 is used, where |x| is the clipped magnitude from step 2 and b = -b_int_pos. This is extremely fast for hardware to compute: just one addition and one multiplication. The bit-shift >> c_int_shift then rescales the squared value, which carries twice as many fractional bits as its input, back into a usable range.

  4. Final Assembly (Multiply and Shift): The final result is assembled by combining the result of the approximation with the original input, mimicking the x * sigmoid(...) structure of GELU. This is again done with a fast multiplication and a final bit-shift (>> out_shift) to produce the final int16 output.

This sequence of simple integer operations successfully approximates the complex GELU function without a single floating-point calculation.
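The shifts in steps 3 and 4 are ordinary fixed-point rescaling. Here is a minimal sketch, assuming the Q13 format used in the example below (1.0 stored as 8192). Multiplying two Q13 numbers yields 26 fractional bits, so a right shift by 13 returns to Q13, and multiplying by a constant that is not a power of two is done with an integer multiplier followed by a shift. The helper names are hypothetical.

```python
import math

FRAC_BITS = 13  # Q13: real value = integer / 2**13

def to_q13(x: float) -> int:
    return round(x * (1 << FRAC_BITS))

def from_q13(q: int) -> float:
    return q / (1 << FRAC_BITS)

def q13_mul(a: int, b: int) -> int:
    # a * b carries 26 fractional bits; >> 13 drops back to Q13 (rounds toward -inf)
    return (a * b) >> FRAC_BITS

def quantize_multiplier(c: float, shift: int = 15) -> tuple[int, int]:
    # Represent a real constant c as m / 2**shift, so x * c ~= (x * m) >> shift
    return round(c * (1 << shift)), shift

def mul_by_const(q: int, m: int, shift: int) -> int:
    return (q * m) >> shift

print(from_q13(q13_mul(to_q13(1.5), to_q13(-0.75))))
m, s = quantize_multiplier(1 / math.sqrt(2))
print(m, from_q13(mul_by_const(to_q13(1.0), m, s)))
```

This is exactly the role of the out_multiplier and shift pair: a real-valued scale factor is folded into one integer multiply and one shift.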


## Bringing it to Life: A Runnable Example

To make these concepts easier to understand and experiment with, here is a simplified, self-contained Python script.

This script demonstrates the core “hardware” logic. We will hard-code the magic coefficients that the “compiler” would normally generate. This lets us focus on how the integer calculation works.

The script will:

  1. Define the true floating-point GELU function for comparison.
  2. Define our simplified int16 GELU approximation.
  3. Take a sample float tensor, quantize it to int16.
  4. Run the integer GELU.
  5. Dequantize the result and compare it to the true float output.

```python
import torch

# --- For comparison: the floating-point GELU ---
def float_gelu(x):
    """Standard GELU (tanh form) using floating-point math."""
    return torch.nn.functional.gelu(x, approximate="tanh")

# --- Our simplified integer-only GELU approximation ---
def int16_gelu_approximation(x_int):
    """
    A simplified integer-only GELU that mimics the hardware logic.
    Input and output are int16 with scale 2**13 (1.0 <-> 8192); intermediates are int32.
    We hard-code the 'magic numbers' for this educational example.
    """
    # These coefficients would be generated by the "compiler" step from the
    # quantization scale. The values below are derived for scale 2**13 from the
    # clipped-quadratic erf fit used by I-BERT:
    #   erf(u) ~= sgn(u) * (a * (min(|u|, -b) + b)**2 + 1), a = -0.2888, b = -1.769,
    # with u = x / sqrt(2), so that GELU(x) = x * (1 + erf(u)) / 2.
    b_int_pos = 20494        # round(1.769 * 8192 * sqrt(2)): |x| where erf saturates
    c_int_shift = 14         # Bit-shift that rescales the 32-bit square
    c_int = 28366            # round(2**13 / 0.2888): the integer that represents 1.0
    shift_int = c_int        # Additive term for the "1 +" in 1 + erf(u)
    out_multiplier = 18927   # round(2**30 / (2 * c_int)), applied as (x * m) >> 15
    out_shift = 15           # Final bit-shift for scaling

    # 1. Safety clamp: |x| <= 32767, so abs() cannot overflow
    input_max = 32767
    x32 = torch.clamp(x_int.to(torch.int32), -input_max, input_max)

    # 2. Piecewise logic: clip |x| at the saturation threshold
    x_abs = torch.abs(x32)
    min_int = torch.minimum(x_abs, torch.tensor(b_int_pos, dtype=torch.int32))

    # 3. Quadratic approximation of |erf(u)|, in units of 1 / c_int
    #    (32-bit intermediates: 20494**2 still fits in int32)
    min_int_adj = min_int - b_int_pos
    square = (min_int_adj * min_int_adj) >> c_int_shift
    erf_abs = c_int - square                       # in [0, c_int]

    # Apply the sign and add 1.0: sigmoid_int ~= (1 + erf(u)) * c_int, in [0, 2 * c_int]
    sign = 1 - 2 * (x32 < 0).to(torch.int32)       # -1 for negative inputs, else +1
    sigmoid_int = sign * erf_abs + shift_int

    # 4. Final assembly: x * (1 + erf(u)) / 2 = x * Phi(x)
    x_adj_32 = (x32 * out_multiplier) >> 15        # ~= x * 2**15 / (2 * c_int)
    output_32 = (sigmoid_int * x_adj_32) >> out_shift

    # Clip to the int16 output range
    return torch.clamp(output_32, -32768, 32767).to(torch.int16)


if __name__ == "__main__":
    # --- Setup ---
    # Create a sample float tensor
    float_input = torch.tensor([-3.0, -1.0, 0.0, 0.5, 1.0, 2.0, 4.0])

    # Define our quantization scale: a float value of 1.0 becomes the integer 8192.
    # int16 then covers roughly [-4.0, 4.0), so 4.0 saturates to 32767.
    scale = 2**13  # 8192

    # --- Run the comparison ---
    print("--- Running GELU Comparison ---")
    print(f"Original Floats:\n{float_input}\n")

    # 1. Calculate the true result using floating-point math
    float_output = float_gelu(float_input)
    print(f"True Float GELU Output:\n{float_output.numpy()}\n")

    # 2. Quantize our input from float to int16
    #    (clamp first: 4.0 * 8192 = 32768 would overflow int16)
    quantized_input = (float_input * scale).round().clamp(-32768, 32767).to(torch.int16)
    print(f"Quantized Input (int16):\n{quantized_input.numpy()}\n")

    # 3. Run our integer-only GELU approximation
    quantized_output = int16_gelu_approximation(quantized_input)
    print(f"Quantized Output (int16):\n{quantized_output.numpy()}\n")

    # 4. Dequantize the result back to float to see how close we got
    reconstructed_float_output = quantized_output.to(torch.float32) / scale
    print(f"Reconstructed Float Output (from int math):\n{reconstructed_float_output.numpy()}\n")

    # 5. Calculate the error
    error = torch.mean(torch.abs(float_output - reconstructed_float_output))
    print("--- Finished ---")
    print(f"\nAverage Absolute Error: {error.item():.6f}")
```