In an article recently posted to the Meta Research website, researchers proposed a fused kernel implementation with SplitK work decomposition for W4A16 quantized inference. The implementation accelerates the skinny matrix-matrix multiplications prevalent in foundation model inference workloads.
Background
Although several studies have been performed to design efficient general matrix multiplication (GEMM) kernels, only a few have focused on designing high-performance kernels for memory-bound computations, specifically the kernels common in inference. Matrix multiplications for inference are typically memory bound because small batch sizes create skinny matmuls in which the small dimension is the batch size; at these shapes, performance is limited by memory throughput rather than by graphics processing unit (GPU) compute.
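To see why such shapes are bandwidth limited, a quick back-of-the-envelope arithmetic-intensity estimate helps (the shapes below are illustrative llama-style values, not figures from the article):

```python
# Arithmetic intensity (FLOPs per byte moved) of C = A @ B with
# A of shape (M, K) and B of shape (K, N), both stored in FP16.
def arithmetic_intensity(M, N, K, bytes_per_elem=2):
    flops = 2 * M * N * K                                   # multiply + add per MAC
    bytes_moved = bytes_per_elem * (M * K + K * N + M * N)  # read A, B; write C
    return flops / bytes_moved

for M in (1, 16, 4096):
    print(M, round(arithmetic_intensity(M, N=4096, K=4096), 1))
# M = 1 or 16 yields roughly 1-16 FLOPs/byte, far below the hundreds of
# FLOPs/byte an A100 or H100 can sustain, so the GPU stalls on memory;
# a square M = 4096 matmul reaches ~1365 FLOPs/byte and is compute bound.
```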
The proposed approach
In this study, researchers presented a fused matrix multiplication kernel for W4A16 quantized inference, using a SplitK work decomposition to perform GEMM and dequantization in a single fused kernel. Specifically, they proposed an optimized Triton kernel implementation for quantized matrix-matrix multiplications in memory-bound inference workloads.
The SplitK work decomposition with atomic reductions was leveraged in place of the conventional data-parallel (DP) decomposition and integrated with dequantization to provide a single-step fused dequant and matrix multiply kernel. Researchers implemented the kernel in Triton, as it offers cross-hardware compatibility and an easy-to-use interface. Experiments focused on M = 1–16, corresponding to batch sizes of 1–16, a typical range in large language model (LLM) inference, and were performed on both NVIDIA A100 and NVIDIA H100 GPUs.
Researchers primarily compared the performance of the traditional DP kernel against that of the SplitK kernel; the fused int4 dequantization kernel with a modified decomposition strategy was the key contribution of this study.
The methodology
The kernel received a quantized, packed int32 weight matrix, an FP16 activation matrix, and zero and scale parameters as input, the latter two being used to dequantize the weight matrix. Bitwise operations were used to unpack, shift, and scale the weight matrix.
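As an illustration of that unpacking step, the sketch below recovers eight int4 values from one int32 with shifts and masks and applies a scale and zero point. This is a minimal NumPy sketch under an assumed GPTQ-style nibble packing, not the article's Triton code:

```python
import numpy as np

# Eight 4-bit weights packed into one int32 (nibbles 0..7, for illustration).
packed = np.array([0x76543210], dtype=np.uint32)
scale, zero = 0.05, 8                            # assumed per-group quant params

shifts = np.arange(8, dtype=np.uint32) * 4       # bit offset of each nibble
int4_vals = (packed[:, None] >> shifts) & 0xF    # unpack via shift + mask
weights = ((int4_vals.astype(np.int32) - zero) * scale).astype(np.float16)

print(int4_vals)  # [[0 1 2 3 4 5 6 7]]
print(weights)    # dequantized FP16 weights: [[-0.4 -0.35 ... -0.05]]
```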
Although the proposed implementation was tailored to GPTQ-style int4 quantization, the method is general purpose and can be applied to other n-bit quantization schemes. The dequantized weight matrix was subsequently fed into the optimized SplitK GEMM.
The kernel in this study was based on the Triton FP16 SplitK implementation. SplitK launched extra thread blocks along the k dimension to calculate partial sums; the partial results from every thread block were then summed using an atomic reduction. Thus, the fused kernel performed dequantization, GEMM, and atomic reduction in a single kernel launch, yielding significant performance improvements.
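A minimal Triton sketch of this decomposition is shown below, patterned after Triton's public SplitK matmul. It keeps FP16 inputs and omits masking and the fused dequantization (a comment marks where the int4 unpacking shown earlier would go); all block sizes are illustrative, and shapes are assumed to be multiples of the block sizes:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def splitk_gemm_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                       stride_am, stride_ak, stride_bk, stride_bn,
                       stride_cm, stride_cn,
                       BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                       BLOCK_K: tl.constexpr, SPLIT_K: tl.constexpr):
    pid = tl.program_id(0)
    pid_k = tl.program_id(1)                  # extra grid axis along k
    pid_m = pid // tl.cdiv(N, BLOCK_N)
    pid_n = pid % tl.cdiv(N, BLOCK_N)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # Each program accumulates only its 1/SPLIT_K slice of the k dimension.
    for _ in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)                   # fused int4 dequant would go here
        acc += tl.dot(a, b)
        a_ptrs += BLOCK_K * SPLIT_K * stride_ak
        b_ptrs += BLOCK_K * SPLIT_K * stride_bk

    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.atomic_add(c_ptrs, acc)                # reduce partial sums across pid_k

def splitk_matmul(a, b, split_k=4, BLOCK_M=16, BLOCK_N=64, BLOCK_K=64):
    M, K = a.shape
    _, N = b.shape
    c = torch.zeros((M, N), device=a.device, dtype=torch.float32)  # atomics target
    grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), split_k)
    splitk_gemm_kernel[grid](a, b, c, M, N, K,
                             a.stride(0), a.stride(1), b.stride(0), b.stride(1),
                             c.stride(0), c.stride(1),
                             BLOCK_M, BLOCK_N, BLOCK_K, split_k)
    return c.to(a.dtype)
```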
SplitK effectively decomposed the work into finer-grained partitions than the conventional DP block style, which balanced the use of streaming multiprocessor (SM) resources more evenly. With the proposed kernel, this resulted in a 61% increase in waves per SM on an A100 compared with standard DP block tiling. The finer-grained decomposition was enabled through the use of atomic operations.
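The effect on launch granularity can be seen with simple arithmetic (assumed shapes and block sizes; the article reports only the 61% waves-per-SM figure):

```python
import math

# Thread blocks launched for a skinny matmul, DP vs. SplitK (illustrative).
M, N = 16, 4096
BLOCK_M, BLOCK_N, SPLIT_K = 16, 64, 4

dp_blocks = math.ceil(M / BLOCK_M) * math.ceil(N / BLOCK_N)  # 64 blocks
splitk_blocks = dp_blocks * SPLIT_K                          # 256 blocks
print(dp_blocks, splitk_blocks)
# With 64 blocks, fewer blocks exist than the A100's 108 SMs, so some SMs
# sit idle; SplitK's 256 blocks give every SM work and hide memory latency.
```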
Researchers leveraged the atomic add function so that each thread block, after completing its inner accumulation step, could safely write its partial result, one contribution to the final output, directly to C output memory. Because other thread blocks working on different k subsets updated the same C output tile, atomic adds guaranteed that each block had exclusive access while writing its result and that every update operated on the latest value.
However, on the A100, steady performance degradation was observed with rising matrix size as the SplitK parameter was increased from 4 to 16, because each thread block had to wait longer to obtain exclusive write access to the single shared output buffer. Moving from the A100 to the H100, the proposed SplitK kernel showed proportionate performance gains relative to the DP block-based approach.
Research findings
The proposed fused kernel implementation successfully performed GEMM and dequantization through a SplitK atomic reduction. Researchers benchmarked m, n, and k values relevant to llama-style inference inputs and observed an average speedup of 65% on the A100 and 124% on the H100, with a peak improvement of 295%, compared with conventional blocked data-parallelization strategies.
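A minimal sketch of how such a comparison could be timed with Triton's built-in benchmarking helper is shown below, reusing the splitk_matmul sketch from earlier and using torch.matmul as a stand-in baseline (the article benchmarked against its own DP Triton kernel, and the shapes here are assumed llama-style values):

```python
import torch
import triton

def bench(M, N, K):
    # Requires a CUDA GPU; do_bench returns the measured runtime in milliseconds.
    a = torch.randn((M, K), device="cuda", dtype=torch.float16)
    b = torch.randn((K, N), device="cuda", dtype=torch.float16)
    t_base = triton.testing.do_bench(lambda: torch.matmul(a, b))
    t_sk = triton.testing.do_bench(lambda: splitk_matmul(a, b))
    print(f"M={M} N={N} K={K}: {100 * (t_base / t_sk - 1):+.0f}% vs baseline")

for M in (16,):            # the unmasked sketch above requires M % BLOCK_M == 0;
    bench(M, 4096, 4096)   # the article swept batch sizes M = 1-16
```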
These speedups were driven by the finer-grained work decomposition of the SplitK algorithm, which increased SM occupancy; higher occupancy improved global memory throughput through latency hiding and reduced wave quantization inefficiency.
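For intuition on the wave quantization point, consider how fully the launched blocks fill whole waves of SMs (a simplified model with one resident block per SM; numbers are illustrative, not from the article):

```python
import math

NUM_SMS = 108                    # SM count of an A100
for blocks in (64, 256):         # the DP vs. SplitK grid sizes from above
    waves = blocks / NUM_SMS
    last_wave_eff = blocks / (math.ceil(waves) * NUM_SMS)
    print(f"{blocks} blocks -> {waves:.2f} waves, utilization {last_wave_eff:.0%}")
# 64 blocks cannot even fill one wave (59% of SMs busy); 256 blocks span
# ~2.37 waves at 79% utilization, reducing wave quantization waste.
```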
Overall, the findings of this study demonstrated the effectiveness of the proposed approach and the potential for current and future iterations of this fused kernel to serve the varied hardware/software stacks found in workloads across the industry.
In the future, StreamK, the natural successor to SplitK, could be explored: StreamK enables an even finer-grained work decomposition than SplitK, which could lead to additional performance improvements for GEMM workloads.