Introduction

The previous post optimized single-precision GEMM using CUDA cores, reaching 89% of cuBLAS performance. But much of modern deep learning runs in half precision, and modern GPUs have dedicated tensor cores that deliver 4-16× higher throughput than CUDA cores for matrix operations.

This post implements half-precision GEMM (HGEMM) using NVIDIA’s WMMA (Warp Matrix Multiply-Accumulate) API, progressing from a naive tensor core kernel to 72% of cuBLAS performance on an A100. WMMA provides a high-level abstraction over tensor cores that handles the hardware complexity while still exposing the key optimization opportunities. The benchmark code is available in this repository.

A follow-up post will explore the lower-level PTX mma.sync instructions to close the remaining gap.

Figure 1: HGEMM optimization progression on A100. We start at 43 TFLOPS and reach 196 TFLOPS (72% of cuBLAS).

Tensor Cores and the WMMA API

Tensor cores are specialized hardware units that compute an entire small matrix multiply-accumulate in one instruction. The fundamental operation is:

\[D = A \times B + C\]

where A, B, C, and D are small matrix tiles. The supported tile dimensions and data types depend on the GPU architecture. NVIDIA provides support across multiple precisions:

  • FP16 / BF16: Half-precision for inference and training
  • TF32: 19-bit format for FP32-like range with tensor core throughput (Ampere+)
  • FP64: Double precision tensor cores (A100+, limited throughput)
  • INT8 / INT4: Integer formats for quantized inference

Each precision supports specific matrix shapes (m×n×k combinations). For example, FP16 on Ampere supports 16×16×16, 32×8×16, and 8×32×16 tile shapes. See the CUDA Programming Guide for the complete list of supported element types and sizes.

This post focuses on FP16 with 16×16×16 tiles, a common configuration for square matrices. On A100, each SM contains four 3rd-generation tensor cores, and a single wmma::mma_sync call performs 16×16×16 = 4096 multiply-accumulate operations (8192 FLOPs) per warp.

The WMMA API abstracts tensor cores through fragments: opaque containers that hold matrix tiles distributed across a warp’s 32 threads.

#include <mma.h>
using namespace nvcuda;

// Declare fragments
wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, wmma::row_major> a_frag;
wmma::fragment<wmma::matrix_b, 16, 16, 16, __half, wmma::row_major> b_frag;
wmma::fragment<wmma::accumulator, 16, 16, 16, __half> acc_frag;

// Load tiles from memory
wmma::load_matrix_sync(a_frag, A_ptr, lda);
wmma::load_matrix_sync(b_frag, B_ptr, ldb);
wmma::fill_fragment(acc_frag, 0.0f);

// Execute tensor core MMA
wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag);

// Store result
wmma::store_matrix_sync(C_ptr, acc_frag, ldc, wmma::mem_row_major);

The compiler maps these operations to the underlying tensor core instructions. Our job is to keep the tensor cores fed by optimizing memory access patterns; that is where nearly all of the performance in the rest of this post comes from.
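
Before diving into the tiled kernels, it helps to see the whole flow in one place. The following is a deliberately naive sketch (not one of the benchmarked kernels): each warp computes a single 16×16 tile of C directly from global memory, assuming row-major matrices with M, N, K divisible by 16 and one warp per output tile, e.g. dim3 block(32, 4) and dim3 grid(N / 16, M / 64).

// (uses <mma.h> and the nvcuda namespace from the snippet above)
__global__ void hgemm_wmma_naive(int M, int N, int K,
                                 const __half *A, const __half *B, __half *C)
{
    // One warp per 16x16 output tile: threadIdx.x spans the warp,
    // threadIdx.y and blockIdx.* select which tile this warp owns.
    const int warpN = blockIdx.x;                              // tile column
    const int warpM = blockIdx.y * blockDim.y + threadIdx.y;   // tile row

    wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, __half, wmma::row_major> b_frag;
    wmma::fragment<wmma::accumulator, 16, 16, 16, __half> acc_frag;
    wmma::fill_fragment(acc_frag, __float2half(0.0f));

    // March along K one 16-wide slab at a time, straight from global memory.
    for (int k = 0; k < K; k += 16) {
        wmma::load_matrix_sync(a_frag, A + warpM * 16 * K + k, K);
        wmma::load_matrix_sync(b_frag, B + k * N + warpN * 16, N);
        wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag);
    }

    wmma::store_matrix_sync(C + warpM * 16 * N + warpN * 16, acc_frag, N,
                            wmma::mem_row_major);
}

Nothing here is staged in shared memory and no data is reused between warps, which is exactly what the block tiling below addresses.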

Kernel 1: Block Tiling Baseline

The first kernel adapts the SGEMM block tiling strategy for tensor cores. Each threadblock computes a $BM \times BN$ output tile; within the block, each warp computes a $WM \times WN$ sub-tile using multiple $16 \times 16$ WMMA operations.

Figure 2: Three-level tiling hierarchy for tensor core GEMM. A threadblock computes a BM×BN output tile, each warp handles a WM×WN sub-tile, and each WMMA instruction processes 16×16 elements.

The diagram above shows the high-level structure. Let’s trace through the math with concrete numbers:

TILING HIERARCHY (BM=128, BN=128, WM=32, WN=32)
═══════════════════════════════════════════════════════════════

BLOCK TILE (128 × 128)
──────────────────────
                      BN = 128
        ◄────────────────────────────────────►
        ┌────────┬────────┬────────┬────────┐
        │ Warp 0 │ Warp 1 │ Warp 2 │ Warp 3 │  32
        ├────────┼────────┼────────┼────────┤
        │ Warp 4 │ Warp 5 │ Warp 6 │ Warp 7 │  32
 BM=128 ├────────┼────────┼────────┼────────┤
        │ Warp 8 │ Warp 9 │Warp 10 │Warp 11 │  32
        ├────────┼────────┼────────┼────────┤
        │Warp 12 │Warp 13 │Warp 14 │Warp 15 │  32
        └────────┴────────┴────────┴────────┘
            32       32       32       32

        → 16 warps = 512 threads per block
        → Each warp owns a 32×32 output region


WARP TILE (32 × 32)                    ◄── zooming into one warp
───────────────────
              WN = 32
        ◄────────────────►
        ┌───────┬────────┐
        │ MMA   │  MMA   │  16
 WM=32  ├───────┼────────┤
        │ MMA   │  MMA   │  16
        └───────┴────────┘
           16       16

        → 4 MMA operations per warp
        → Each MMA: 16×16×16 = 4096 FMAs


FRAGMENT DISTRIBUTION                  ◄── inside one MMA
─────────────────────
   A 16×16 fragment holds 256 values, but no single
   thread sees the whole matrix:

   Thread 0  → elements [?, ?, ?, ?, ?, ?, ?, ?]  ┐
   Thread 1  → elements [?, ?, ?, ?, ?, ?, ?, ?]  │
      ...                                         ├─ 32 threads × 8 = 256 ✓
   Thread 31 → elements [?, ?, ?, ?, ?, ?, ?, ?]  ┘

   The layout is opaque. WMMA handles distribution internally.
   We just call load_matrix_sync() and mma_sync().

═══════════════════════════════════════════════════════════════

The key insight: we never manipulate individual matrix elements. WMMA fragments are opaque containers. We load them from shared memory, execute MMA operations, and store the results. The hardware handles the complex data distribution across threads.

The kernel structure is shown below. To keep things focused, I’ve moved the shared memory loading and epilogue logic into helper functions (loadTileA_scalar, loadTileB_scalar, and epilogueAndStore). The full implementations are in the repository.

template <int BM, int BN, int BK, int WM, int WN>
__global__ void wmma_block_tiling(
    int M, int N, int K, __half alpha, 
    const __half *A, const __half *B, __half beta, __half *C)
{
    constexpr int MMA_M = 16, MMA_N = 16, MMA_K = 16;
    constexpr int MMA_M_TILES = WM / MMA_M;
    constexpr int MMA_N_TILES = WN / MMA_N;
    constexpr int WARPS_M = BM / WM;
    constexpr int WARPS_N = BN / WN;

    __shared__ __half As[BM * BK];
    __shared__ __half Bs[BK * BN];

    constexpr int NUM_THREADS = WARPS_M * WARPS_N * 32;

    const uint tid = threadIdx.x;
    const uint warpId = tid / 32;
    const uint warpM = warpId / WARPS_N;
    const uint warpN = warpId % WARPS_N;

    // Advance the global pointers to this block's tiles (bounds handling omitted)
    A += blockIdx.y * BM * K;
    B += blockIdx.x * BN;
    C += blockIdx.y * BM * N + blockIdx.x * BN;

    // Initialize accumulators
    wmma::fragment<wmma::accumulator, 16, 16, 16, __half> acc[MMA_M_TILES][MMA_N_TILES];
    for (int m = 0; m < MMA_M_TILES; ++m)
        for (int n = 0; n < MMA_N_TILES; ++n)
            wmma::fill_fragment(acc[m][n], __float2half(0.0f));
    
    for (int tileK = 0; tileK < K; tileK += BK) {
        // Load tiles with scalar loads (one element per thread)
        loadTileA_scalar<BM, BK, NUM_THREADS>(A + tileK, As, K, tid);
        loadTileB_scalar<BK, BN, NUM_THREADS>(B + tileK * N, Bs, N, tid);
        __syncthreads();

        // Process BK elements in MMA_K chunks
        for (int innerK = 0; innerK < BK; innerK += MMA_K) {
            wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, wmma::row_major> 
                a_frag[MMA_M_TILES];
            wmma::fragment<wmma::matrix_b, 16, 16, 16, __half, wmma::row_major> 
                b_frag[MMA_N_TILES];
            
            // Load A fragments for this warp
            for (int m = 0; m < MMA_M_TILES; ++m) {
                const __half *As_ptr = &As[(warpM * WM + m * MMA_M) * BK + innerK];
                wmma::load_matrix_sync(a_frag[m], As_ptr, BK);
            }

            // Load B fragments for this warp
            for (int n = 0; n < MMA_N_TILES; ++n) {
                const __half *Bs_ptr = &Bs[innerK * BN + warpN * WN + n * MMA_N];
                wmma::load_matrix_sync(b_frag[n], Bs_ptr, BN);
            }

            // Tensor core MMA
            for (int m = 0; m < MMA_M_TILES; ++m)
                for (int n = 0; n < MMA_N_TILES; ++n)
                    wmma::mma_sync(acc[m][n], a_frag[m], b_frag[n], acc[m][n]);
        }
        __syncthreads();
    }
    
    // Store results
    epilogueAndStore(acc, C, N, alpha, beta, warpM, warpN);
}
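
The one place where we do touch individual elements is the epilogue: accumulator fragments expose num_elements and an x[] array precisely so element-wise post-processing (here, alpha/beta scaling) can be applied before the store. Below is a hedged sketch of what epilogueAndStore might look like; the repository version may differ in details, and it assumes C already points at this block's output tile.

template <int MMA_M_TILES, int MMA_N_TILES>
__device__ void epilogueAndStore(
    wmma::fragment<wmma::accumulator, 16, 16, 16, __half> (&acc)[MMA_M_TILES][MMA_N_TILES],
    __half *C, int N, __half alpha, __half beta, uint warpM, uint warpN)
{
    constexpr int WM = MMA_M_TILES * 16;
    constexpr int WN = MMA_N_TILES * 16;

    for (int m = 0; m < MMA_M_TILES; ++m) {
        for (int n = 0; n < MMA_N_TILES; ++n) {
            __half *C_ptr = C + (warpM * WM + m * 16) * N + warpN * WN + n * 16;

            // Load the existing C tile. Element i of one accumulator fragment maps to
            // the same matrix position as element i of another, so per-element blending is safe.
            wmma::fragment<wmma::accumulator, 16, 16, 16, __half> c_frag;
            wmma::load_matrix_sync(c_frag, C_ptr, N, wmma::mem_row_major);

            // acc = alpha * acc + beta * C
            for (int i = 0; i < c_frag.num_elements; ++i)
                acc[m][n].x[i] = __hfma(alpha, acc[m][n].x[i], __hmul(beta, c_frag.x[i]));

            wmma::store_matrix_sync(C_ptr, acc[m][n], N, wmma::mem_row_major);
        }
    }
}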

The loadTileA_scalar function reveals our first bottleneck. Each thread loads just one __half element at a time:

template <int BM, int BK, int NUM_THREADS>
__device__ void loadTileA_scalar(const __half *A, __half *As, int K, uint tid)
{
    constexpr int TOTAL_ELEMENTS = BM * BK;
    constexpr int ELEMENTS_PER_THREAD = TOTAL_ELEMENTS / NUM_THREADS;
    
    for (int i = 0; i < ELEMENTS_PER_THREAD; ++i) {
        uint idx = tid + i * NUM_THREADS;
        uint row = idx / BK;
        uint col = idx % BK;
        As[row * BK + col] = A[row * K + col];  // Single __half load
    }
}

The bottleneck is scalar memory loads. Each thread loads one 2-byte __half per instruction, while a single float4 load moves 8 halves (16 bytes). We are issuing 8× more load instructions than necessary and using only 1/8 of the per-instruction load width.

Kernel               Time (N=4096)   TFLOPS   % cuBLAS
01_WMMABlockTiling   3.23 ms         42.6     16%

Kernel 2: Vectorized Loads

FP16 elements are only 2 bytes. To maximize memory bandwidth, we use float4 loads (16 bytes = 8 half elements) for coalesced 128-byte transactions:

template <int BM, int BK, int NUM_THREADS>
__device__ void loadTileA_vec4(const __half *A, __half *As, int K, uint tid)
{
    constexpr int TOTAL_VEC = (BM * BK) / 8;  // 8 halves per float4
    constexpr int VEC_PER_THREAD = TOTAL_VEC / NUM_THREADS;
    
    for (int i = 0; i < VEC_PER_THREAD; ++i) {
        uint idx = tid + i * NUM_THREADS;
        uint row = idx / (BK / 8);
        uint col8 = idx % (BK / 8);
        
        // Load 8 halves at once via float4
        float4 val = reinterpret_cast<const float4*>(&A[row * K + col8 * 8])[0];
        reinterpret_cast<float4*>(&As[row * BK + col8 * 8])[0] = val;
    }
}

Vectorization nearly doubles performance by utilizing full memory transaction width.

Kernel               Time (N=4096)   TFLOPS   % cuBLAS
01_WMMABlockTiling   3.23 ms         42.6     16%
02_WMMAVectorized    1.85 ms         74.3     27%

Kernel 3: Asynchronous Copy

On SM80+ architectures (Ampere and later), cp.async copies directly from global to shared memory without staging through registers. This reduces register pressure and enables better latency hiding. It also sets us up nicely for multi-stage pipelining, which we’ll implement later.

template <int BM, int BK, int NUM_THREADS>
__device__ void loadTileA_async(const __half *A, __half *As, int K, uint tid)
{
    constexpr int VEC_PER_THREAD = (BM * BK) / 8 / NUM_THREADS;

    for (int i = 0; i < VEC_PER_THREAD; ++i) {
        uint idx = tid + i * NUM_THREADS;
        uint row = idx / (BK / 8);
        uint col8 = idx % (BK / 8);
        
        __pipeline_memcpy_async(
            &As[row * BK + col8 * 8],
            &A[row * K + col8 * 8],
            sizeof(float4)
        );
    }
}

The pipeline API commits and waits for async copies:

loadTileA_async<BM, BK, NUM_THREADS>(A, As, K, tid);
loadTileB_async<BK, BN, NUM_THREADS>(B, Bs, N, tid);
__pipeline_commit();
__pipeline_wait_prior(0);
__syncthreads();

Kernel               Time (N=4096)   TFLOPS   % cuBLAS
02_WMMAVectorized    1.85 ms         74.3     27%
03_WMMAAsync         1.66 ms         82.6     30%

Kernel 4: Shared Memory Padding

Shared memory is organized into 32 banks of 4 bytes each. When multiple threads access addresses that map to the same bank, the accesses serialize (bank conflicts). With FP16, 8 consecutive elements span 16 bytes = 4 banks, and the 16×16 WMMA tile layout creates systematic conflicts.

Since we don’t know exactly how WMMA accesses shared memory internally (the fragment layout is opaque), padding is an educated guess: shifting successive rows into different banks generally reduces conflicts. We pad by 8 elements (16 bytes), which preserves the 16-byte alignment required for our vectorized cp.async copies while staggering successive rows across banks.

constexpr int SMEM_PAD = 8;
constexpr int A_STRIDE = BK + SMEM_PAD;  // 32 + 8 = 40
constexpr int B_STRIDE = BN + SMEM_PAD;

__shared__ __half As[BM * A_STRIDE];
__shared__ __half Bs[BK * B_STRIDE];

The trade-off is wasted shared memory. For a BM=128, BK=32 tile, we go from 128×32×2 = 8KB to 128×40×2 = 10KB. That’s 25% overhead per tile. Multiply by 2 matrices and multiple pipeline stages, and it adds up. A better alternative is XOR-based swizzling, which eliminates bank conflicts without padding overhead. We can’t use swizzle patterns with WMMA (the API doesn’t expose the required control), but we’ll explore this in the PTX mma.sync follow-up.
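
For concreteness, a back-of-the-envelope shared memory budget under the configuration used here (the constants mirror the snippet above; this is illustrative arithmetic, not code from the repository):

constexpr int BM = 128, BN = 128, BK = 32, SMEM_PAD = 8;
constexpr int A_STRIDE = BK + SMEM_PAD;                       // 40 halves per A row
constexpr int B_STRIDE = BN + SMEM_PAD;                       // 136 halves per B row

constexpr int A_TILE_BYTES = BM * A_STRIDE * sizeof(__half);  // 10,240 B (8,192 B unpadded)
constexpr int B_TILE_BYTES = BK * B_STRIDE * sizeof(__half);  //  8,704 B (8,192 B unpadded)
// One stage: ~18.5 KB padded vs 16 KB unpadded. With the 2-stage pipeline added in
// Kernel 5 this becomes ~37 KB, still under the 48 KB static shared memory limit.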

The tile loader now stores into the padded layout:

// Store with padded stride
__pipeline_memcpy_async(
    &As[row * A_STRIDE + col8 * 8],  // Padded stride
    &A[row * K + col8 * 8],           // Natural stride
    sizeof(float4)
);

Fragment loads also use the padded stride:

wmma::load_matrix_sync(a_frag[m], As_ptr, A_STRIDE);  // Lead dimension = padded stride

This reduces bank conflicts and nearly doubles our performance:

Kernel               Time (N=4096)   TFLOPS   % cuBLAS
03_WMMAAsync         1.66 ms         82.6     30%
04_WMMAPadded        0.89 ms         154.5    57%

Kernel 5: Multi-Stage Pipeline

With single buffering, computation waits for memory. A multi-stage pipeline keeps STAGES-1 loads in flight while computing on the current tile. The following diagrams show how pipelining overlaps memory and compute:

NO PIPELINING

Memory and compute are serialized. Compute stalls waiting for each load.

Cycle:        1     2     3     4     5     6     7     8     9    ...
             ──────────────────────────────────────────────────────────
Memory:      [L0]        [L1]        [L2]        [L3]        [L4]
Compute:           [C0]        [C1]        [C2]        [C3]        ...


2-STAGE PIPELINE

Prologue loads 1 tile. Main loop overlaps next load with current compute.

Cycle:        1     2     3     4     5     6     7
             ────────────────────────────────────────────
Memory:      [L0]  [L1]  [L2]  [L3]  [L4]  [L5]
Compute:           [C0]  [C1]  [C2]  [C3]  [C4]  [C5]
             
             └pro┘ └────────── main loop ───────────┘


3-STAGE PIPELINE

Prologue loads 2 tiles. Deeper buffer hides more memory latency.

Cycle:        1     2     3     4     5     6     7     8
             ─────────────────────────────────────────────────
Memory:      [L0]  [L1]  [L2]  [L3]  [L4]  [L5]
Compute:                 [C0]  [C1]  [C2]  [C3]  [C4]  [C5]
             
             └─prologue─┘ └─────────── main loop ───────────┘

The kernel maintains circular buffers and tracks which stage to compute vs. load:

__shared__ __half As[STAGES][BM * A_STRIDE];
__shared__ __half Bs[STAGES][BK * B_STRIDE];

// Prologue: fill pipeline
for (int s = 0; s < STAGES - 1 && s < numTiles; ++s) {
    loadTileA_async_padded<BM, BK, A_STRIDE, NUM_THREADS>(A + s * BK, As[s], K, tid);
    loadTileB_async_padded<BK, BN, B_STRIDE, NUM_THREADS>(B + s * BK * N, Bs[s], N, tid);
    __pipeline_commit();
}

// Main loop
int loadTile = STAGES - 1;
for (int tile = 0; tile < numTiles; ++tile) {
    int computeStage = tile % STAGES;

    // Prefetch the tile STAGES-1 ahead, if any remain
    if (loadTile < numTiles) {
        int loadStage = loadTile % STAGES;
        loadTileA_async_padded(..., As[loadStage], ...);
        loadTileB_async_padded(..., Bs[loadStage], ...);
        ++loadTile;
    }
    // Commit every iteration (an empty commit is legal), so that wait_prior()
    // below always finds exactly STAGES-1 newer batches in flight.
    __pipeline_commit();

    // Wait for compute stage (STAGES-1 loads remain in flight)
    __pipeline_wait_prior(STAGES - 1);
    __syncthreads();

    // Compute on As[computeStage], Bs[computeStage]
    // ...
}

Kernel               Time (N=4096)   TFLOPS   % cuBLAS
04_WMMAPadded        0.89 ms         154.5    57%
05_WMMAMultistage    0.73 ms         189.0    69%

Kernel 6: Fragment Double Buffering

Multi-stage pipelining overlaps global to shared transfers. We can also overlap shared to register transfers by double-buffering the WMMA fragments:

// Double-buffered fragments
wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, wmma::row_major> a_frag[2][MMA_M_TILES];
wmma::fragment<wmma::matrix_b, 16, 16, 16, __half, wmma::row_major> b_frag[2][MMA_N_TILES];

int frag_load = 0;
int frag_compute = 1;

// Prologue: load first fragments
for (int m = 0; m < MMA_M_TILES; ++m)
    wmma::load_matrix_sync(a_frag[frag_load][m], As_ptr_k0, A_STRIDE);
for (int n = 0; n < MMA_N_TILES; ++n)
    wmma::load_matrix_sync(b_frag[frag_load][n], Bs_ptr_k0, B_STRIDE);

// Main inner loop
for (int innerK = 0; innerK < BK; innerK += MMA_K) {
    frag_load ^= 1;      // Swap buffers
    frag_compute ^= 1;

    // Load NEXT fragments (overlapped with compute)
    if (innerK + MMA_K < BK) {
        for (int m = 0; m < MMA_M_TILES; ++m)
            wmma::load_matrix_sync(a_frag[frag_load][m], As_ptr_next, A_STRIDE);
        for (int n = 0; n < MMA_N_TILES; ++n)
            wmma::load_matrix_sync(b_frag[frag_load][n], Bs_ptr_next, B_STRIDE);
    }

    // Compute with CURRENT fragments
    for (int m = 0; m < MMA_M_TILES; ++m)
        for (int n = 0; n < MMA_N_TILES; ++n)
            wmma::mma_sync(acc[m][n], a_frag[frag_compute][m], 
                           b_frag[frag_compute][n], acc[m][n]);
}

Kernel                Time (N=4096)   TFLOPS   % cuBLAS
05_WMMAMultistage     0.73 ms         189.0    69%
06_WMMADoubleBuffer   0.72 ms         189.8    70%

The gains from fragment double buffering are minimal here because the inner loop is fully unrolled and the compiler already schedules loads well.

Kernel 7: Dynamic Shared Memory

Static shared memory is limited to 48KB on A100. For larger tiles or more pipeline stages, we use dynamic shared memory (up to 164KB on A100):

extern __shared__ __half smem[];
__half* As = smem;
__half* Bs = smem + STAGES * A_STAGE_SIZE;

At launch, configure and allocate:

cudaFuncSetAttribute(
    wmma_dynsmem<BM, BN, BK, WM, WN, STAGES>,
    cudaFuncAttributeMaxDynamicSharedMemorySize,
    SMEM_SIZE
);

wmma_dynsmem<<<grid, block, SMEM_SIZE>>>(...);

This enables larger configurations that wouldn’t fit in static shared memory.
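
The derivation of SMEM_SIZE isn’t shown above; a plausible sketch, following the layout of the smem pointer arithmetic and checking against the device’s opt-in limit (the attribute query is standard CUDA; the constants are the ones assumed earlier):

constexpr size_t A_STAGE_SIZE = BM * A_STRIDE;   // halves per A stage
constexpr size_t B_STAGE_SIZE = BK * B_STRIDE;   // halves per B stage
constexpr size_t SMEM_SIZE = STAGES * (A_STAGE_SIZE + B_STAGE_SIZE) * sizeof(__half);

// Sanity check against the device's opt-in limit (164 KB per block on A100).
int maxSmemOptin = 0;
cudaDeviceGetAttribute(&maxSmemOptin, cudaDevAttrMaxSharedMemoryPerBlockOptin, /*device=*/0);
if (SMEM_SIZE > (size_t)maxSmemOptin) { /* fall back to a smaller configuration */ }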

Kernel                Time (N=4096)   TFLOPS   % cuBLAS
06_WMMADoubleBuffer   0.72 ms         189.8    70%
07_WMMADynSmem        0.74 ms         186.8    69%

We don’t see any benefit at this size since we’re already using efficient configurations within the 48KB static limit. The real benefit comes in the final kernel where autotuning can explore configurations that require more shared memory.

Kernel 8: Final Optimizations

The final kernel combines all optimizations plus additional refinements:

Zig-zag MMA order: When iterating over MMA tiles, we alternate the N direction on odd M rows. This maximizes fragment reuse: the last B fragment used at the end of one row is the first one needed at the start of the next.

Access pattern (MMA_M_TILES=4, MMA_N_TILES=4):

Sequential order:         Zig-zag order:
m=0: n=0,1,2,3            m=0: n=0,1,2,3
m=1: n=0,1,2,3            m=1: n=3,2,1,0  ← reversed
m=2: n=0,1,2,3            m=2: n=0,1,2,3
m=3: n=0,1,2,3            m=3: n=3,2,1,0  ← reversed

With zig-zag: the last B fragment used at the end of m=0 is b_frag[3].
The first B fragment needed at the start of m=1 is also b_frag[3].
It's already in registers!

for (int m = 0; m < MMA_M_TILES; ++m) {
    for (int n = 0; n < MMA_N_TILES; ++n) {
        int n_idx = (m % 2) ? (MMA_N_TILES - 1 - n) : n;  // Zig-zag
        wmma::mma_sync(acc[m][n_idx], a_frag[frag_compute][m], 
                       b_frag[frag_compute][n_idx], acc[m][n_idx]);
    }
}

Vectorized epilogue: Reuse A/B shared memory for C tile, then copy to global with float4 stores for better coalescing.
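
A hedged sketch of the idea (identifiers reuse names from earlier kernels; alpha/beta handling is omitted, C is assumed to point at this block's output tile, and the repository implementation differs in detail):

// 1) Once the K loop is done, the A/B staging buffers are free; reuse the
//    dynamic shared memory allocation as a BM x BN staging tile for C.
__half *C_smem = smem;

for (int m = 0; m < MMA_M_TILES; ++m)
    for (int n = 0; n < MMA_N_TILES; ++n)
        wmma::store_matrix_sync(
            &C_smem[(warpM * WM + m * 16) * BN + warpN * WN + n * 16],
            acc[m][n], BN, wmma::mem_row_major);
__syncthreads();

// 2) The whole block copies the tile to global memory with float4 stores (8 halves each).
for (int i = tid; i < BM * BN / 8; i += NUM_THREADS) {
    const int row  = i / (BN / 8);
    const int col8 = i % (BN / 8);
    const float4 v = reinterpret_cast<const float4 *>(&C_smem[row * BN + col8 * 8])[0];
    reinterpret_cast<float4 *>(&C[row * N + col8 * 8])[0] = v;
}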

Block swizzling: Remap block indices to improve L2 cache locality. Instead of using a 2D grid, we use a 1D grid and remap the linear block index to (blockM, blockN) coordinates. Blocks are processed in groups of GROUP_SIZE_M rows before advancing to the next column group:

STANDARD 2D GRID (row-major execution)

                        blockIdx.x →
                 0       1       2       3
              ┌───────┬───────┬───────┬───────┐
            0 │   0   │   1   │   2   │   3   │
              ├───────┼───────┼───────┼───────┤
            1 │   4   │   5   │   6   │   7   │
blockIdx.y    ├───────┼───────┼───────┼───────┤
     ↓      2 │   8   │   9   │  10   │  11   │
              ├───────┼───────┼───────┼───────┤
            3 │  12   │  13   │  14   │  15   │
              └───────┴───────┴───────┴───────┘

Execution: 0→1→2→3→4→5→... (row by row)
Problem: Block 4 needs B[:,0] but it may be evicted by B[:,1], B[:,2], B[:,3].


SWIZZLED 1D GRID (GROUP_SIZE_M=2)

                        blockIdx.x →
                  0       1       2       3
              ┌───────┬───────*───────┬───────┐
            0 │   0   │   2   *   4   │   6   │
              ├───────┼───────*───────┼───────┤
            1 │   1   │   3   *   5   │   7   │
blockIdx.y    * * * * * * * * * * * * * * * * *
     ↓      2 │   8   │  10   *  12   │  14   │
              ├───────┼───────*───────┼───────┤
            3 │   9   │  11   *  13   │  15   │
              └───────┴───────*───────┴───────┘

Execution: 0→1→2→3→4→5→... (down columns within each group)
Blocks 0,1 share B[:,0]. Blocks 2,3 share B[:,1]. Better L2 reuse.

if constexpr (USE_SWIZZLE) {
    const uint num_blocks_m = (M + BM - 1) / BM;
    const uint num_blocks_n = (N + BN - 1) / BN;
    const uint num_blocks_in_group = GROUP_SIZE_M * num_blocks_n;
    
    const uint bid = blockIdx.x;  // 1D block index
    const uint group_id = bid / num_blocks_in_group;
    const uint first_block_m = group_id * GROUP_SIZE_M;
    const uint group_size_m = min(num_blocks_m - first_block_m, (uint)GROUP_SIZE_M);
    
    blockM = first_block_m + (bid % group_size_m);
    blockN = (bid % num_blocks_in_group) / group_size_m;
}
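
The launch-side counterpart is a 1D grid with one block per output tile (illustrative; the kernel recovers blockM and blockN from blockIdx.x as shown above):

const uint num_blocks_m = (M + BM - 1) / BM;
const uint num_blocks_n = (N + BN - 1) / BN;
dim3 grid(num_blocks_m * num_blocks_n);   // 1D grid, remapped inside the kernel
dim3 block(NUM_THREADS);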

Autotuning

Tile sizes, pipeline stages, and swizzle parameters significantly affect performance. We autotune across configurations:

[Autotune] Testing 9 configurations on 4096x4096x4096...
  128x128x32_64x64_S2_swfalse_g8         0.700 ms  196.43 TFLOPS
  128x128x32_64x64_S3_swfalse_g8         0.788 ms  174.47 TFLOPS
  256x128x32_64x64_S3_swfalse_g8         0.907 ms  151.49 TFLOPS
  128x128x32_64x64_S2_swtrue_g8          0.698 ms  197.03 TFLOPS
  128x128x32_64x64_S3_swtrue_g8          0.786 ms  174.97 TFLOPS
  256x128x32_64x64_S3_swtrue_g8          0.907 ms  151.50 TFLOPS
  128x128x32_64x64_S2_swtrue_g16         0.697 ms  197.21 TFLOPS  <- Best
  128x128x32_64x64_S3_swtrue_g16         0.785 ms  175.06 TFLOPS
  256x128x32_64x64_S3_swtrue_g16         0.907 ms  151.47 TFLOPS
[Autotune] Best: 128x128x32_64x64_S2_swtrue_g16 (0.697 ms, 197.21 TFLOPS)

Key findings from autotuning on A100:

  • 2-stage pipelines beat 3-stage: deeper pipelines increase register pressure without benefit
  • Swizzle helps at large sizes: 197 vs 196 TFLOPS at N=4096, bigger gains at N=8192
  • 128×128 blocks dominate: 256×128 blocks underperform despite larger tiles
  • GROUP_SIZE_M=8 or 16 both work well depending on matrix size
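
Each configuration is timed the usual way: a few warm-up launches followed by CUDA-event timing averaged over several iterations. A minimal sketch of such a harness (the launch wrapper is a hypothetical std::function; the repository's autotuner may be structured differently):

#include <cuda_runtime.h>
#include <functional>

float timeConfigMs(const std::function<void()> &launch, int warmup = 5, int iters = 20)
{
    for (int i = 0; i < warmup; ++i) launch();   // warm-up (clocks, caches)

    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);

    cudaEventRecord(start);
    for (int i = 0; i < iters; ++i) launch();
    cudaEventRecord(stop);
    cudaEventSynchronize(stop);

    float ms = 0.0f;
    cudaEventElapsedTime(&ms, start, stop);
    cudaEventDestroy(start);
    cudaEventDestroy(stop);
    return ms / iters;   // TFLOPS = 2.0 * M * N * K / (avg_ms * 1e-3) / 1e12
}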

Results

Benchmarked on NVIDIA A100-SXM4-40GB with N=4096:

Kernel                Time (ms)   TFLOPS   % cuBLAS
cuBLAS (FP16)         0.51        272.1    100%
01_WMMABlockTiling    3.23        42.6     16%
02_WMMAVectorized     1.85        74.3     27%
03_WMMAAsync          1.66        82.6     30%
04_WMMAPadded         0.89        154.5    57%
05_WMMAMultistage     0.73        189.0    69%
06_WMMADoubleBuffer   0.72        189.8    70%
07_WMMADynSmem        0.74        186.8    69%
08_WMMAFinal          0.70        195.7    72%

Figure 3: HGEMM optimization progression on A100.

Performance across matrix sizes:

Size    cuBLAS       Our Kernel   % cuBLAS
1024    113 TFLOPS   73 TFLOPS    65%
2048    178 TFLOPS   154 TFLOPS   87%
4096    272 TFLOPS   196 TFLOPS   72%
8192    285 TFLOPS   201 TFLOPS   70%
16384   281 TFLOPS   171 TFLOPS   61%

Efficiency varies with matrix size. Our kernel is most competitive at N=2048 (87% of cuBLAS), where the tile configuration matches the problem size well, and the gap widens at very large sizes (N=16384), where cuBLAS’s more advanced techniques dominate.

The Gap to cuBLAS

The remaining ~30% gap comes from techniques beyond WMMA’s abstraction level:

  • PTX mma.sync instructions: Direct control over tensor core operations, enabling finer instruction scheduling and register management
  • XOR-based swizzling: Shared memory addressing patterns that eliminate bank conflicts without the 25% padding overhead we’re paying
  • Architecture-specific tuning: cuBLAS has years of per-GPU, per-size optimization tables that we can’t match with a simple autotuner

WMMA is a good starting point, but extracting the last bits of performance requires dropping to PTX.

Conclusion

Optimizing tensor core GEMM follows the same principles as CUDA core optimization, with additional considerations:

  1. Use tensor cores: 4-16× theoretical speedup over CUDA cores for supported precisions
  2. Vectorize memory access: float4 loads for FP16, matching 128-byte transaction size
  3. Eliminate bank conflicts: Padding shared memory rows by 8 FP16 elements (or use swizzle in PTX)
  4. Pipeline memory transfers: Multi-stage async pipeline keeps tensor cores fed
  5. Autotune configurations: Optimal parameters vary significantly by matrix size and GPU architecture

WMMA provides a clean abstraction that captures 72% of cuBLAS performance on A100. The remaining gap requires lower-level techniques: PTX mma.sync for finer instruction scheduling and XOR-based swizzling to eliminate padding overhead. These will be explored in the follow-up post.

References