Introduction

The previous post optimized single-precision GEMM using CUDA cores, reaching 89% of cuBLAS performance. But modern deep learning runs on FP16, and modern GPUs have dedicated tensor cores that deliver 4-16× higher throughput than CUDA cores for matrix operations.

This post implements half-precision GEMM (HGEMM) using NVIDIA’s WMMA (Warp Matrix Multiply-Accumulate) API, progressing from a naive tensor core kernel to 80% of cuBLAS performance on an A100. WMMA provides a high-level abstraction over tensor cores that handles the hardware complexity while still exposing the key optimization opportunities. The benchmark code is available in this repository.

A follow-up post will explore the lower-level PTX mma.sync instructions to close the remaining gap.

Figure 1: HGEMM optimization progression on A100.

Tensor Cores and the WMMA API

Tensor cores are specialized hardware units that execute small matrix multiply-accumulate operations as warp-wide instructions, at far higher throughput than the regular FP32/FP16 arithmetic pipelines. The fundamental operation is:

\[D = A \times B + C\]

where A, B, C, and D are small matrix tiles. The supported tile dimensions and data types depend on the GPU architecture. NVIDIA provides support across multiple precisions:

  • FP16 / BF16: Half-precision for inference and training
  • TF32: 19-bit format for FP32-like range with tensor core throughput (Ampere+)
  • FP64: Double precision tensor cores (A100+, limited throughput)
  • INT8 / INT4: Integer formats for quantized inference

Each precision supports specific matrix shapes (m×n×k combinations). For example, FP16 on Ampere supports 16×16×16, 32×8×16, and 8×32×16 tile shapes. See the CUDA Programming Guide for the complete list of supported element types and sizes.

This post focuses on FP16 with 16×16×16 tiles, the most commonly used WMMA shape. On A100, each SM contains four 3rd generation tensor cores, and a single warp-wide wmma::mma_sync call computes 16×16×16 = 4096 multiply-add operations (8192 FLOPs).
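The D = A × B + C operation and its FLOP count can be made concrete with a plain CPU reference for one 16×16×16 tile. This is a sketch. It uses float throughout, whereas the tensor core takes FP16 inputs, and the names (mma_reference, TileA, TileB, TileC) are hypothetical.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Hypothetical CPU reference for one m16n16k16 step: D = A * B + C.
// Float is used everywhere for clarity; tensor cores take FP16 inputs and
// accumulate in FP16 or FP32 depending on the accumulator fragment type.
constexpr int MMA_M = 16, MMA_N = 16, MMA_K = 16;

using TileA = std::array<float, MMA_M * MMA_K>;  // row-major 16x16
using TileB = std::array<float, MMA_K * MMA_N>;  // row-major 16x16
using TileC = std::array<float, MMA_M * MMA_N>;  // row-major 16x16

// Computes d = a * b + c and returns the number of multiply-adds performed.
std::int64_t mma_reference(const TileA &a, const TileB &b, const TileC &c, TileC &d) {
    std::int64_t fmas = 0;
    for (int i = 0; i < MMA_M; ++i) {
        for (int j = 0; j < MMA_N; ++j) {
            float acc = c[i * MMA_N + j];
            for (int k = 0; k < MMA_K; ++k) {
                acc += a[i * MMA_K + k] * b[k * MMA_N + j];
                ++fmas;
            }
            d[i * MMA_N + j] = acc;
        }
    }
    return fmas;
}
```

The returned count is 4096, which matches the 16×16×16 figure above. Doubling it gives the 8192 FLOPs attributed to one warp-wide mma_sync.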

The WMMA API abstracts tensor cores through fragments: opaque containers that hold matrix tiles distributed across a warp’s 32 threads.

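#include <cuda_fp16.h>
#include <mma.h>
using namespace nvcuda;

// Illustrative kernel (name is ours): launch with one warp (32 threads) to
// compute C = A * B for a single 16x16x16 tile. Pointers must be 256-bit
// aligned, and lda/ldb/ldc must be multiples of 8 for __half.
__global__ void wmma_single_tile(const __half *A_ptr, const __half *B_ptr,
                                 __half *C_ptr, int lda, int ldb, int ldc)
{
    // Declare fragments
    wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, __half, wmma::row_major> b_frag;
    wmma::fragment<wmma::accumulator, 16, 16, 16, __half> acc_frag;

    // Load tiles from memory (every thread of the warp must participate)
    wmma::load_matrix_sync(a_frag, A_ptr, lda);
    wmma::load_matrix_sync(b_frag, B_ptr, ldb);
    wmma::fill_fragment(acc_frag, __float2half(0.0f));

    // Execute tensor core MMA: acc_frag = a_frag * b_frag + acc_frag
    wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag);

    // Store result
    wmma::store_matrix_sync(C_ptr, acc_frag, ldc, wmma::mem_row_major);
}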

The compiler lowers these calls to the underlying tensor core instructions. Our job is to keep the tensor cores fed by optimizing memory access patterns. As the kernels below show, that is where nearly all of the performance comes from.

Kernel 1: Block Tiling Baseline

The first kernel adapts the SGEMM block tiling strategy for tensor cores. Each threadblock computes a $BM \times BN$ output tile; within the block, each warp computes a $WM \times WN$ sub-tile using multiple $16 \times 16$ WMMA operations.

Figure 2: Three-level tiling hierarchy for tensor core GEMM. A threadblock computes a BM×BN output tile, each warp handles a WM×WN sub-tile, and each WMMA instruction processes 16×16 elements.

The diagram above shows the high-level structure. Let’s trace through the math with concrete numbers:

TILING HIERARCHY (BM=128, BN=128, WM=32, WN=32)
═══════════════════════════════════════════════════════════════

BLOCK TILE (128 × 128)
──────────────────────
                      BN = 128
        ◄────────────────────────────────────►
        ┌────────┬────────┬────────┬────────┐
        │ Warp 0 │ Warp 1 │ Warp 2 │ Warp 3 │  32
        ├────────┼────────┼────────┼────────┤
        │ Warp 4 │ Warp 5 │ Warp 6 │ Warp 7 │  32
 BM=128 ├────────┼────────┼────────┼────────┤
        │ Warp 8 │ Warp 9 │Warp 10 │Warp 11 │  32
        ├────────┼────────┼────────┼────────┤
        │Warp 12 │Warp 13 │Warp 14 │Warp 15 │  32
        └────────┴────────┴────────┴────────┘
            32       32       32       32

        → 16 warps = 512 threads per block
        → Each warp owns a 32×32 output region


WARP TILE (32 × 32)                    ◄── zooming into one warp
───────────────────
              WN = 32
        ◄────────────────►
        ┌───────┬────────┐
        │ MMA   │  MMA   │  16
 WM=32  ├───────┼────────┤
        │ MMA   │  MMA   │  16
        └───────┴────────┘
           16       16

        → 4 MMA operations per warp
        → Each MMA: 16×16×16 = 4096 FMAs


FRAGMENT DISTRIBUTION                  ◄── inside one MMA
─────────────────────
   A 16×16 fragment holds 256 values, but no single
   thread sees the whole matrix:

   Thread 0  → elements [?, ?, ?, ?, ?, ?, ?, ?]  ┐
   Thread 1  → elements [?, ?, ?, ?, ?, ?, ?, ?]  │
      ...                                         ├─ 32 threads × 8 = 256 ✓
   Thread 31 → elements [?, ?, ?, ?, ?, ?, ?, ?]  ┘

   The layout is opaque. WMMA handles distribution internally.
   We just call load_matrix_sync() and mma_sync().

═══════════════════════════════════════════════════════════════

The key insight: we never manipulate individual matrix elements. WMMA fragments are opaque containers. We load them from shared memory, execute MMA operations, and store the results. The hardware handles the complex data distribution across threads.
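The tiling arithmetic in the diagram generalizes to any BM/BN/WM/WN. The sketch below uses hypothetical helper names. It derives the warp grid, thread count, and per-warp WMMA work for a configuration. It assumes the 16×16×16 shape and assumes each accumulator tile's 256 values are spread evenly over the warp's 32 threads.

```cpp
#include <cassert>

// Hypothetical helper: launch geometry implied by BM/BN/WM/WN tiling
// with the 16x16x16 WMMA shape.
struct TilingGeometry {
    int warpsM, warpsN;        // warp grid inside one threadblock
    int warps, threads;        // warps and threads per block
    int mmaTilesM, mmaTilesN;  // 16x16 WMMA tiles per warp along M and N
    int mmasPerWarpPerK;       // mma_sync calls per warp per 16-wide K step
    int accElemsPerThread;     // accumulator values held by each thread
};

constexpr TilingGeometry tiling(int BM, int BN, int WM, int WN) {
    TilingGeometry g{};
    g.warpsM = BM / WM;
    g.warpsN = BN / WN;
    g.warps = g.warpsM * g.warpsN;
    g.threads = g.warps * 32;
    g.mmaTilesM = WM / 16;
    g.mmaTilesN = WN / 16;
    g.mmasPerWarpPerK = g.mmaTilesM * g.mmaTilesN;
    // Each 16x16 accumulator holds 256 values spread over 32 threads.
    g.accElemsPerThread = g.mmasPerWarpPerK * (16 * 16 / 32);
    return g;
}
```

For the diagram's 128×128 block with 32×32 warp tiles, this gives 16 warps, 512 threads, and 4 MMAs per warp. A 128×256 block with 64×64 warp tiles (the shape the autotuner picks later) gives 8 warps, 256 threads, and 16 MMAs per warp per K step.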

The kernel structure is shown below. To keep things focused, I’ve moved the shared memory loading and epilogue logic into helper functions (loadTileA_scalar, loadTileB_scalar, and epilogueAndStore). The full implementations are in the repository.

template <int BM, int BN, int BK, int WM, int WN>
__global__ void wmma_block_tiling(
    int M, int N, int K, __half alpha, 
    const __half *A, const __half *B, __half beta, __half *C)
{
    constexpr int MMA_M = 16, MMA_N = 16, MMA_K = 16;
    constexpr int MMA_M_TILES = WM / MMA_M;
    constexpr int MMA_N_TILES = WN / MMA_N;
    constexpr int WARPS_M = BM / WM;
    constexpr int WARPS_N = BN / WN;
    constexpr int NUM_THREADS = WARPS_M * WARPS_N * 32;  // launch with this many threads

    // WMMA loads/stores require 256-bit (32-byte) aligned pointers
    __shared__ __align__(32) __half As[BM * BK];
    __shared__ __align__(32) __half Bs[BK * BN];

    const uint tid = threadIdx.x;
    const uint warpId = tid / 32;
    const uint warpM = warpId / WARPS_N;
    const uint warpN = warpId % WARPS_N;

    // Point A, B, C at this block's row panel, column panel, and output tile
    A += blockIdx.y * BM * K;
    B += blockIdx.x * BN;
    C += blockIdx.y * BM * N + blockIdx.x * BN;

    // Initialize accumulators
    wmma::fragment<wmma::accumulator, 16, 16, 16, __half> acc[MMA_M_TILES][MMA_N_TILES];
    for (int m = 0; m < MMA_M_TILES; ++m)
        for (int n = 0; n < MMA_N_TILES; ++n)
            wmma::fill_fragment(acc[m][n], __float2half(0.0f));
    
    for (int tileK = 0; tileK < K; tileK += BK) {
        // Load tiles with scalar loads (one element per thread per iteration)
        loadTileA_scalar<BM, BK, NUM_THREADS>(A + tileK, As, K, tid);
        loadTileB_scalar<BK, BN, NUM_THREADS>(B + tileK * N, Bs, N, tid);
        __syncthreads();

        // Process BK elements in MMA_K chunks
        for (int innerK = 0; innerK < BK; innerK += MMA_K) {
            wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, wmma::row_major> 
                a_frag[MMA_M_TILES];
            wmma::fragment<wmma::matrix_b, 16, 16, 16, __half, wmma::row_major> 
                b_frag[MMA_N_TILES];
            
            // Load fragments and execute MMA
            for (int m = 0; m < MMA_M_TILES; ++m)
                wmma::load_matrix_sync(a_frag[m], 
                    &As[(warpM * WM + m * MMA_M) * BK + innerK], BK);
            for (int n = 0; n < MMA_N_TILES; ++n)
                wmma::load_matrix_sync(b_frag[n], 
                    &Bs[innerK * BN + warpN * WN + n * MMA_N], BN);
            
            for (int m = 0; m < MMA_M_TILES; ++m)
                for (int n = 0; n < MMA_N_TILES; ++n)
                    wmma::mma_sync(acc[m][n], a_frag[m], b_frag[n], acc[m][n]);
        }
        __syncthreads();  // don't overwrite As/Bs while other warps still read them
    }

    epilogueAndStore<...>(acc, C, N, alpha, beta, warpM, warpN);
}

Performance: 42 TFLOPS at N=4096 (15% of cuBLAS)

The baseline is slow largely because of its scalar loads. Each thread moves one 2-byte half per load instruction, so filling a tile takes 8× more load instructions than with 16-byte (128-bit) accesses, the widest a single thread can issue.
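Apart from the slow loads, the loop nest itself is easy to validate. Below is a hypothetical CPU model of wmma_block_tiling's loop order: block tile, warp tile, BK tile, 16-wide k-step, then a 16×16×16 MMA. It can be checked against a naive triple loop. The sketch leaves out alpha/beta, shared-memory staging, and FP16 rounding, and it assumes the tile sizes divide M, N, and K evenly.

```cpp
#include <cassert>
#include <vector>

// Hypothetical CPU model of the wmma_block_tiling loop nest.
// Row-major A (MxK), B (KxN), C (MxN); tile sizes must divide M, N, K.
void blockTiledGemm(int M, int N, int K, int BM, int BN, int BK, int WM, int WN,
                    const std::vector<float> &A, const std::vector<float> &B,
                    std::vector<float> &C) {
    const int T = 16;  // WMMA tile edge (m = n = k = 16)
    for (int bm = 0; bm < M; bm += BM)          // threadblock rows
      for (int bn = 0; bn < N; bn += BN)        // threadblock columns
        for (int wm = 0; wm < BM; wm += WM)     // warp rows inside the block
          for (int wn = 0; wn < BN; wn += WN) { // warp columns inside the block
            // One "warp": accumulate its WM x WN region over all of K
            std::vector<float> acc(WM * WN, 0.0f);
            for (int tileK = 0; tileK < K; tileK += BK)        // shared-memory tile
              for (int innerK = 0; innerK < BK; innerK += T)   // WMMA k-step
                for (int m = 0; m < WM; m += T)
                  for (int n = 0; n < WN; n += T)
                    // One 16x16x16 "mma_sync" on accumulator tile (m, n)
                    for (int i = 0; i < T; ++i)
                      for (int j = 0; j < T; ++j)
                        for (int k = 0; k < T; ++k) {
                          int row = bm + wm + m + i, col = bn + wn + n + j;
                          int kk = tileK + innerK + k;
                          acc[(m + i) * WN + (n + j)] += A[row * K + kk] * B[kk * N + col];
                        }
            // Epilogue: write this warp's accumulators to C
            for (int i = 0; i < WM; ++i)
              for (int j = 0; j < WN; ++j)
                C[(bm + wm + i) * N + (bn + wn + j)] = acc[i * WN + j];
          }
}
```

Small integer-valued inputs keep every partial sum exact in float. Any mismatch against the naive loop would therefore point to an indexing error rather than rounding.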

Kernel 2: Vectorized Memory Access

The straightforward fix: use 128-bit (float4) loads, which transfer 8 half-precision values per instruction.

template <int BM, int BK, int NUM_THREADS>
__device__ void loadTileA_vec4(const __half *A, __half *As, int K, uint tid) {
    constexpr int TOTAL_VEC = (BM * BK) / 8;          // one float4 = 8 halves
    constexpr int VEC_PER_THREAD = TOTAL_VEC / NUM_THREADS;
    static_assert(TOTAL_VEC % NUM_THREADS == 0, "tile must split evenly across threads");

    #pragma unroll
    for (int i = 0; i < VEC_PER_THREAD; ++i) {
        uint idx = tid + i * NUM_THREADS;
        uint row = idx / (BK / 8);   // BK / 8 float4s per tile row
        uint col8 = idx % (BK / 8);
        
        // 128-bit load/store: requires K % 8 == 0 and 16-byte aligned bases
        float4 val = reinterpret_cast<const float4*>(&A[row * K + col8 * 8])[0];
        reinterpret_cast<float4*>(&As[row * BK + col8 * 8])[0] = val;
    }
}

Performance: 74 TFLOPS at N=4096 (27% of cuBLAS)

A solid 1.8× improvement, but we’re still synchronous—threads wait for memory before computing.
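The index math in loadTileA_vec4 is worth checking on the host. Each (tid, i) pair should copy a distinct 8-half chunk, and together the pairs should cover the BM×BK tile exactly once. The sketch replays that math (helper names are hypothetical) and also counts load instructions per thread for 2-byte versus 16-byte accesses.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical host replay of loadTileA_vec4's index math: counts how many
// times each half of the BM x BK shared-memory tile gets written.
std::vector<int> vec4Coverage(int BM, int BK, int NUM_THREADS) {
    const int TOTAL_VEC = (BM * BK) / 8;  // float4 = 8 halves
    const int VEC_PER_THREAD = TOTAL_VEC / NUM_THREADS;
    std::vector<int> hits(BM * BK, 0);
    for (int tid = 0; tid < NUM_THREADS; ++tid)
        for (int i = 0; i < VEC_PER_THREAD; ++i) {
            int idx = tid + i * NUM_THREADS;
            int row = idx / (BK / 8);
            int col8 = idx % (BK / 8);
            for (int e = 0; e < 8; ++e)  // the 8 halves in one float4
                ++hits[row * BK + col8 * 8 + e];
        }
    return hits;
}

// True if every element of the tile is loaded exactly once.
bool coversTileOnce(int BM, int BK, int NUM_THREADS) {
    std::vector<int> hits = vec4Coverage(BM, BK, NUM_THREADS);
    return std::all_of(hits.begin(), hits.end(), [](int h) { return h == 1; });
}

// Load instructions each thread issues per tile, for a given access width.
int loadsPerThread(int BM, int BK, int NUM_THREADS, int halvesPerLoad) {
    return (BM * BK) / (halvesPerLoad * NUM_THREADS);
}
```

The 384-thread case shows the failure mode. When BM·BK/8 is not a multiple of NUM_THREADS, VEC_PER_THREAD rounds down and part of the tile is never loaded, so that division is worth guarding with a static_assert.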

Kernel 3: Asynchronous Copy

Ampere introduced cp.async, enabling direct global→shared memory transfers without going through registers. The CUDA pipeline API makes this easy:

#include <cuda_pipeline.h>  // __pipeline_memcpy_async, __pipeline_commit, __pipeline_wait_prior

template <int BM, int BK, int NUM_THREADS>
__device__ void loadTileA_async(const __half *A, __half *As, int K, uint tid) {
    constexpr int TOTAL_VEC = (BM * BK) / 8;
    constexpr int VEC_PER_THREAD = TOTAL_VEC / NUM_THREADS;

    #pragma unroll
    for (int i = 0; i < VEC_PER_THREAD; ++i) {
        uint idx = tid + i * NUM_THREADS;
        uint row = idx / (BK / 8);
        uint col8 = idx % (BK / 8);
        // 16-byte async copy: both addresses must be 16-byte aligned
        __pipeline_memcpy_async(
            &As[row * BK + col8 * 8],
            &A[row * K + col8 * 8],
            sizeof(float4)
        );
    }
}

// In main loop:
loadTileA_async<BM, BK, NUM_THREADS>(A + tileK, As, K, tid);
loadTileB_async<BK, BN, NUM_THREADS>(B + tileK * N, Bs, N, tid);
__pipeline_commit();
__pipeline_wait_prior(0);  // wait for this batch right away: no overlap yet
__syncthreads();

Performance: 82 TFLOPS at N=4096 (30% of cuBLAS)

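The gains are modest. Because __pipeline_wait_prior(0) follows the commit immediately, loads and compute do not overlap yet. The improvement comes from copying global→shared directly instead of staging the data in registers, and a different bottleneck is now limiting us.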

Kernel 4: Shared Memory Padding

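The profiler reveals the real problem: shared memory bank conflicts. Shared memory is divided into 32 banks, each 4 bytes wide, so addresses 128 bytes apart fall in the same bank. When threads of a warp access different addresses in the same bank, the accesses serialize.

With BK=32, each row of As is 64 bytes, so rows 0, 2, 4, and 6 start in the same bank. Fragment loads typically read 16 bytes from each of 8 consecutive rows at a time, so they suffer 4-way conflicts. Bs is worse: with BN=128, its 256-byte rows all start in the same bank.

Padding each row by 8 half-precision elements (16 bytes) gives row strides of 80 bytes (As) and 272 bytes (Bs). These strides spread the 8-row reads across disjoint banks: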

constexpr int SMEM_PAD = 8;
constexpr int A_STRIDE = BK + SMEM_PAD;  // 32 + 8 = 40 halves (80 bytes) per row
constexpr int B_STRIDE = BN + SMEM_PAD;  // e.g. 128 + 8 = 136 halves (272 bytes) per row

// Load with padding: same index math as before, but rows of As are A_STRIDE apart
template <int BM, int BK, int A_STRIDE, int NUM_THREADS>
__device__ void loadTileA_async_padded(
    const __half *A, __half *As, int K, uint tid)
{
    constexpr int TOTAL_VEC = (BM * BK) / 8;
    constexpr int VEC_PER_THREAD = TOTAL_VEC / NUM_THREADS;

    #pragma unroll
    for (int i = 0; i < VEC_PER_THREAD; ++i) {
        uint idx = tid + i * NUM_THREADS;
        uint row = idx / (BK / 8);
        uint col8 = idx % (BK / 8);
        __pipeline_memcpy_async(
            &As[row * A_STRIDE + col8 * 8],  // Padded stride
            &A[row * K + col8 * 8],
            sizeof(float4)
        );
    }
}

Performance: 156 TFLOPS at N=4096 (57% of cuBLAS)

A massive 1.9× jump. Bank conflicts were a major bottleneck, and padding significantly reduces their impact.
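The bank arithmetic behind the padding can be reproduced on the host. The model below is a simplification with hypothetical names. It assumes 32 banks of 4 bytes, and it assumes each shared-memory wavefront of a fragment load reads 16 contiguous bytes from each of 8 consecutive rows, as ldmatrix does. It reports the worst-case number of distinct 4-byte words that collide in one bank.

```cpp
#include <algorithm>
#include <cassert>
#include <set>

// Hypothetical bank-conflict model: 32 banks x 4 bytes. One wavefront is
// assumed to read 16 contiguous bytes from each of 8 consecutive rows,
// starting colByteOffset bytes into each row. Returns the conflict degree.
int conflictDegree(int rowStrideBytes, int colByteOffset) {
    std::set<int> wordsPerBank[32];
    for (int r = 0; r < 8; ++r)
        for (int w = 0; w < 4; ++w) {  // 4 words = 16 bytes per row
            int byteAddr = r * rowStrideBytes + colByteOffset + w * 4;
            int word = byteAddr / 4;
            wordsPerBank[word % 32].insert(word);
        }
    int degree = 0;
    for (const auto &s : wordsPerBank)
        degree = std::max(degree, static_cast<int>(s.size()));
    return degree;
}
```

Under these assumptions, a 64-byte row (BK=32) gives a 4-way conflict and a 256-byte row (BN=128) gives an 8-way conflict. The padded strides of 80 and 272 bytes are conflict-free, including when the read starts 16 halves into the row.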

Kernel 5: Multi-Stage Pipeline

With bank conflicts reduced, we turn to latency hiding. A multi-stage pipeline keeps multiple tiles in flight:

NO PIPELINING

Memory and compute are serialized. Compute stalls waiting for each load.

Cycle:        1     2     3     4     5     6     7     8     9    ...
             ──────────────────────────────────────────────────────────
Memory:      [L0]        [L1]        [L2]        [L3]        [L4]
Compute:           [C0]        [C1]        [C2]        [C3]        ...


2-STAGE PIPELINE

Prologue loads 1 tile. Main loop overlaps next load with current compute.

Cycle:        1     2     3     4     5     6     7
             ────────────────────────────────────────────
Memory:      [L0]  [L1]  [L2]  [L3]  [L4]  [L5]
Compute:           [C0]  [C1]  [C2]  [C3]  [C4]  [C5]
             
             └pro┘ └────────── main loop ───────────┘


3-STAGE PIPELINE

Prologue loads 2 tiles. Deeper buffer hides more memory latency.

Cycle:        1     2     3     4     5     6     7     8
             ─────────────────────────────────────────────────
Memory:      [L0]  [L1]  [L2]  [L3]  [L4]  [L5]
Compute:                 [C0]  [C1]  [C2]  [C3]  [C4]  [C5]
             
             └─prologue─┘ └─────────── main loop ───────────┘
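
In code, the kernel keeps STAGES padded buffers per operand in dynamic shared memory and cycles through them:

// Dynamic shared memory: STAGES buffers for A, then STAGES buffers for B.
// Sizes above 48 KB need cudaFuncSetAttribute(kernel,
// cudaFuncAttributeMaxDynamicSharedMemorySize, bytes) before launch.
extern __shared__ __half smem[];
constexpr int A_STAGE_SIZE = BM * A_STRIDE;  // halves per A buffer
constexpr int B_STAGE_SIZE = BK * B_STRIDE;  // halves per B buffer
__half* As = smem;
__half* Bs = smem + STAGES * A_STAGE_SIZE;
const int numTiles = K / BK;  // assumes K % BK == 0

// Prologue: fill pipeline with STAGES-1 tiles
for (int s = 0; s < STAGES - 1 && s < numTiles; ++s) {
    loadTileA_async_padded<...>(A + s * BK, As + s * A_STAGE_SIZE, K, tid);
    loadTileB_async_padded<...>(B + s * BK * N, Bs + s * B_STAGE_SIZE, N, tid);
    __pipeline_commit();
}
__pipeline_wait_prior(0);
__syncthreads();

// Main loop
for (int tile = 0; tile < numTiles; ++tile) {
    int stage = tile % STAGES;
    
    // Issue next load while computing current tile. The target stage was
    // last read in iteration tile-1, which ended with __syncthreads().
    if (tile + STAGES - 1 < numTiles) {
        int loadStage = (tile + STAGES - 1) % STAGES;
        loadTileA_async_padded<...>(A + (tile + STAGES - 1) * BK, 
                                    As + loadStage * A_STAGE_SIZE, K, tid);
        loadTileB_async_padded<...>(B + (tile + STAGES - 1) * BK * N, 
                                    Bs + loadStage * B_STAGE_SIZE, N, tid);
    }
    // Commit every iteration, even with no copies issued (an empty batch),
    // so wait_prior(STAGES - 2) below always covers tile + 1. Committing only
    // inside the if lets the last tiles be read early when STAGES >= 3.
    __pipeline_commit();
    
    // Compute on current stage
    computeOnStage(As + stage * A_STAGE_SIZE, Bs + stage * B_STAGE_SIZE, acc);
    
    // Wait until tile + 1 has landed, then make it visible to all warps
    __pipeline_wait_prior(STAGES - 2);
    __syncthreads();
}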

Performance: 193 TFLOPS at N=4096 (71% of cuBLAS)

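Pipeline depth is a tradeoff. Each extra stage hides more memory latency but adds another A and B buffer to shared memory, which can lower occupancy. In the autotuning runs below, 3-stage configurations beat their 2-stage counterparts at N=4096.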
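The batch accounting behind __pipeline_wait_prior is easy to get wrong in the tail of the loop. Below is a small host model of the documented semantics: committed batches complete in order, and wait_prior(N) guarantees that all but the N most recently committed batches have finished. The model assumes the worst case, where a copy only finishes when a wait forces it. It replays the Kernel 5 loop with and without committing a (possibly empty) batch every iteration. The names PipelineModel and scheduleIsSafe are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <deque>
#include <vector>

// Hypothetical model of __pipeline_commit / __pipeline_wait_prior(N).
struct PipelineModel {
    std::deque<std::vector<int>> pending;  // committed, unfinished batches
    std::vector<int> current;              // copies issued since last commit
    std::vector<bool> done;                // tile has fully landed in smem
    explicit PipelineModel(int numTiles) : done(numTiles, false) {}
    void copy(int tile) { current.push_back(tile); }
    void commit() { pending.push_back(current); current.clear(); }
    void waitPrior(std::size_t n) {  // finish oldest batches until <= n remain
        while (pending.size() > n) {
            for (int t : pending.front()) done[t] = true;
            pending.pop_front();
        }
    }
};

// Replays the multi-stage main loop; returns true if every tile was
// guaranteed resident when it was computed on.
bool scheduleIsSafe(int numTiles, int STAGES, bool alwaysCommit) {
    PipelineModel p(numTiles);
    for (int s = 0; s < STAGES - 1 && s < numTiles; ++s) { p.copy(s); p.commit(); }
    p.waitPrior(0);
    for (int tile = 0; tile < numTiles; ++tile) {
        if (tile + STAGES - 1 < numTiles) {
            p.copy(tile + STAGES - 1);
            if (!alwaysCommit) p.commit();
        }
        if (alwaysCommit) p.commit();      // empty batch in the tail is fine
        if (!p.done[tile]) return false;   // would compute on unfinished data
        p.waitPrior(STAGES - 2);
    }
    return true;
}
```

With STAGES=2 the wait is wait_prior(0), so the tail is safe either way. With 3 or more stages, committing only when a copy is issued leaves the final tile(s) unguarded, because the number of batches in flight shrinks as the loop drains.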

Kernel 6: Software Pipelining

The next optimization overlaps fragment loads from shared memory with tensor core compute. We double-buffer the WMMA fragments and interleave async DMA with compute, creating three levels of overlap: async DMA (global→shared), fragment loads (shared→register), and tensor core compute.

// Double-buffered fragments (slot k % 2 holds k-step k)
constexpr int K_STEPS = BK / MMA_K;  // 2 for BK = 32; must be even (see Phase 2)
wmma::fragment<wmma::matrix_a, ...> a_frag[2][MMA_M_TILES];
wmma::fragment<wmma::matrix_b, ...> b_frag[2][MMA_N_TILES];

// ====== PROLOGUE: fill pipeline ======
for (int s = 0; s < STAGES - 1 && s < numTiles; ++s) {
    loadTileA_async_padded<...>(A + s * BK, As + s * A_STAGE_SIZE, K, tid);
    loadTileB_async_padded<...>(B + s * BK * N, Bs + s * B_STAGE_SIZE, N, tid);
    __pipeline_commit();
}
__pipeline_wait_prior(0);
__syncthreads();

// Load initial k=0 fragments from stage 0
loadFragments(a_frag[0], b_frag[0], As, Bs, /*k=*/0);

int loadTile = STAGES - 1;

// ====== MAIN LOOP ======
for (int tile = 0; tile < numTiles; ++tile) {
    int stage = tile % STAGES;
    __half* As_tile = As + stage * A_STAGE_SIZE;
    __half* Bs_tile = Bs + stage * B_STAGE_SIZE;

    // Phase 1: load k+1 fragments, MMA on k (for k = 0..K_STEPS-2)
    for (int k = 1; k < K_STEPS; ++k) {
        // Load k into buffer slot (k % 2)
        loadFragments(a_frag[k % 2], b_frag[k % 2], As_tile, Bs_tile, k);
        
        // MMA on k-1 (already in buffer slot (k-1) % 2)
        for (int m = 0; m < MMA_M_TILES; ++m)
            for (int n = 0; n < MMA_N_TILES; ++n)
                wmma::mma_sync(acc[m][n], a_frag[(k-1) % 2][m], 
                               b_frag[(k-1) % 2][n], acc[m][n]);
        
        // Issue a later tile's async copy at the midpoint (overlaps with MMA)
        if (k == K_STEPS / 2) {
            if (loadTile < numTiles) {
                int loadStage = loadTile % STAGES;
                loadTileA_async_padded<...>(A + loadTile * BK, 
                                            As + loadStage * A_STAGE_SIZE, K, tid);
                loadTileB_async_padded<...>(B + loadTile * BK * N, 
                                            Bs + loadStage * B_STAGE_SIZE, N, tid);
            }
            // Commit even when no copy was issued (an empty batch) so the
            // batch count matches __pipeline_wait_prior(STAGES - 2) below
            __pipeline_commit();
            ++loadTile;
        }
    }

    // Wait for next tile's data
    __pipeline_wait_prior(STAGES - 2);
    __syncthreads();

    // Phase 2: load k=0 of NEXT tile into slot 0, MMA on last k of CURRENT tile.
    // The last k-step lives in slot (K_STEPS-1) % 2, which is slot 1 only when
    // K_STEPS is even; otherwise this load would overwrite it.
    if (tile + 1 < numTiles) {
        int nextStage = (tile + 1) % STAGES;
        loadFragments(a_frag[0], b_frag[0], As + nextStage * A_STAGE_SIZE, 
                      Bs + nextStage * B_STAGE_SIZE, /*k=*/0);
    }

    // MMA on last k-step (buffer slot (K_STEPS-1) % 2)
    for (int m = 0; m < MMA_M_TILES; ++m)
        for (int n = 0; n < MMA_N_TILES; ++n)
            wmma::mma_sync(acc[m][n], a_frag[(K_STEPS-1) % 2][m], 
                           b_frag[(K_STEPS-1) % 2][n], acc[m][n]);
}

// ====== EPILOGUE ======
epilogueAndStore<...>(acc, C, N, alpha, beta, warpM, warpN);

The key insight is the double-buffering pattern. With K_STEPS=2 (BK=32) and STAGES=2, the schedule is:

Tile 0:
  Phase 1, k=1: load frag[1], MMA frag[0], issue DMA for tile 1
  Phase 2:      load tile1.k=0 into frag[0], MMA frag[1]

Tile 1:
  Phase 1, k=1: load frag[1], MMA frag[0], issue DMA for tile 2
  Phase 2:      load tile2.k=0 into frag[0], MMA frag[1]
  ...

Performance: 203 TFLOPS at N=4096 (74% of cuBLAS)
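The trace above covers only K_STEPS=2. A small host replay of the fragment slots checks other values. Each slot records which (tile, k) fragment it holds, and every MMA verifies that it reads the fragment it expects. The replay shows that the scheme also holds for K_STEPS=4. It breaks for odd K_STEPS, because Phase 2 writes the next tile's k=0 into slot 0 while the current tile's last k-step still sits there. The function name is hypothetical, and the model tracks only the fragment slots, not shared memory.

```cpp
#include <cassert>
#include <utility>

// Hypothetical replay of Kernel 6's fragment double-buffering.
// Returns the number of MMA steps, or -1 if a slot was overwritten
// before the MMA that needed it.
int replayDoubleBuffer(int numTiles, int K_STEPS) {
    std::pair<int, int> slot[2] = {{-1, -1}, {-1, -1}};  // (tile, k) held
    int steps = 0;
    slot[0] = {0, 0};  // prologue: tile 0, k = 0
    for (int tile = 0; tile < numTiles; ++tile) {
        // Phase 1: load k, then MMA on k - 1
        for (int k = 1; k < K_STEPS; ++k) {
            slot[k % 2] = {tile, k};
            if (slot[(k - 1) % 2] != std::make_pair(tile, k - 1)) return -1;
            ++steps;
        }
        // Phase 2: next tile's k = 0 goes to slot 0, then MMA on the last k
        if (tile + 1 < numTiles) slot[0] = {tile + 1, 0};
        if (slot[(K_STEPS - 1) % 2] != std::make_pair(tile, K_STEPS - 1)) return -1;
        ++steps;
    }
    return steps;
}
```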

Kernel 7: Final Optimizations

The final kernel adds block swizzling for L2 cache locality and a vectorized epilogue:

Block swizzling: Instead of processing blocks row-by-row, we process in column groups. Blocks in the same column share B tiles, improving L2 reuse:

STANDARD 2D GRID (row-major execution)

                        blockIdx.x →
                 0       1       2       3
              ┌───────┬───────┬───────┬───────┐
            0 │   0   │   1   │   2   │   3   │
              ├───────┼───────┼───────┼───────┤
            1 │   4   │   5   │   6   │   7   │
blockIdx.y    ├───────┼───────┼───────┼───────┤
     ↓      2 │   8   │   9   │  10   │  11   │
              ├───────┼───────┼───────┼───────┤
            3 │  12   │  13   │  14   │  15   │
              └───────┴───────┴───────┴───────┘

Execution: 0→1→2→3→4→5→... (row by row)
Problem: Block 4 needs B[:,0] but it may be evicted by B[:,1], B[:,2], B[:,3].


SWIZZLED 1D GRID (BLOCK_STRIDE=2)

                        blockN →
                  0       1       2       3
              ┌───────┬───────*───────┬───────┐
            0 │   0   │   2   *   4   │   6   │
              ├───────┼───────*───────┼───────┤
            1 │   1   │   3   *   5   │   7   │
blockM        * * * * * * * * * * * * * * * * *
     ↓      2 │   8   │  10   *  12   │  14   │
              ├───────┼───────*───────┼───────┤
            3 │   9   │  11   *  13   │  15   │
              └───────┴───────*───────┴───────┘

Execution: 0→1→2→3→4→5→... (down columns within each group; cell values are blockIdx.x)
Blocks 0,1 share B[:,0]. Blocks 2,3 share B[:,1]. Better L2 reuse.
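// Default mapping for a 2D launch: one block per (blockM, blockN) tile
uint blockM = blockIdx.y;
uint blockN = blockIdx.x;
if constexpr (USE_SWIZZLE) {
    // 1D launch with numBlocksM * numBlocksN blocks. Consecutive ids walk
    // down a group of BLOCK_STRIDE block-rows before moving to the next column.
    const uint numBlocksM = (M + BM - 1) / BM;
    const uint numBlocksN = (N + BN - 1) / BN;
    const uint numBlocksInGroup = BLOCK_STRIDE * numBlocksN;
    const uint groupId = blockIdx.x / numBlocksInGroup;
    const uint firstBlockM = groupId * BLOCK_STRIDE;
    // The last group can have fewer than BLOCK_STRIDE block-rows
    const uint groupSizeM = min(numBlocksM - firstBlockM, (uint)BLOCK_STRIDE);
    blockM = firstBlockM + (blockIdx.x % groupSizeM);
    blockN = (blockIdx.x % numBlocksInGroup) / groupSizeM;
}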
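The swizzle arithmetic can be checked on the host. The sketch below copies the index math above into plain C++; the names swizzle and swizzleIsBijection are hypothetical. It confirms that the mapping reproduces the BLOCK_STRIDE=2 diagram and that every output tile is computed by exactly one block, including when the last group has fewer than BLOCK_STRIDE block-rows.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

// Host copy of the swizzle index math: 1D block id -> (blockM, blockN).
std::pair<unsigned, unsigned> swizzle(unsigned bid, unsigned M, unsigned N,
                                      unsigned BM, unsigned BN, unsigned BLOCK_STRIDE) {
    const unsigned numBlocksM = (M + BM - 1) / BM;
    const unsigned numBlocksN = (N + BN - 1) / BN;
    const unsigned numBlocksInGroup = BLOCK_STRIDE * numBlocksN;
    const unsigned groupId = bid / numBlocksInGroup;
    const unsigned firstBlockM = groupId * BLOCK_STRIDE;
    const unsigned groupSizeM = std::min(numBlocksM - firstBlockM, BLOCK_STRIDE);
    const unsigned blockM = firstBlockM + (bid % groupSizeM);
    const unsigned blockN = (bid % numBlocksInGroup) / groupSizeM;
    return {blockM, blockN};
}

// True if each output tile is assigned to exactly one block id.
bool swizzleIsBijection(unsigned M, unsigned N, unsigned BM, unsigned BN, unsigned stride) {
    const unsigned bm = (M + BM - 1) / BM, bn = (N + BN - 1) / BN;
    std::vector<int> hits(bm * bn, 0);
    for (unsigned bid = 0; bid < bm * bn; ++bid) {
        auto [m, n] = swizzle(bid, M, N, BM, BN, stride);
        if (m >= bm || n >= bn) return false;
        ++hits[m * bn + n];
    }
    return std::all_of(hits.begin(), hits.end(), [](int h) { return h == 1; });
}
```

In a partial final group, the ids are visited in a rotated order (for example block-rows 5, 6, 4). Coverage is still exact, so correctness does not depend on the matrix size being a multiple of BLOCK_STRIDE·BM.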

Vectorized epilogue: After accumulation, we reuse the A/B shared memory to stage the output tile, then write to global memory with coalesced float4 stores:

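// Wait for outstanding async copies and for every warp to finish its last
// MMA before the A/B buffers are reused
__pipeline_wait_prior(0);
__syncthreads();
__half* C_smem = smem;  // output tile staged in the (now free) A/B space

// Store fragments to shared memory (reusing A/B space)
for (int m = 0; m < MMA_M_TILES; ++m)
    for (int n = 0; n < MMA_N_TILES; ++n)
        wmma::store_matrix_sync(&C_smem[(warpM * WM + m * MMA_M) * C_SMEM_STRIDE
                                       + (warpN * WN + n * MMA_N)], 
                                acc[m][n], C_SMEM_STRIDE, wmma::mem_row_major);
__syncthreads();

// Vectorized copy to global: C points at this block's output tile, N % 8 == 0.
// (alpha/beta scaling is left out of this excerpt.)
constexpr int VECS_PER_THREAD = (BM * BN / 8) / NUM_THREADS;
for (int i = 0; i < VECS_PER_THREAD; ++i) {
    int vec_idx = tid + i * NUM_THREADS;
    int row = vec_idx / (BN / 8);
    int col8 = vec_idx % (BN / 8);
    float4 val = *reinterpret_cast<float4*>(&C_smem[row * C_SMEM_STRIDE + col8 * 8]);
    *reinterpret_cast<float4*>(&C[row * N + col8 * 8]) = val;
}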

Performance: 219 TFLOPS at N=4096 (80% of cuBLAS)

Autotuning

Tile sizes, pipeline stages, and swizzle parameters significantly affect performance. We autotune across configurations:

[Autotune] Testing 10 configurations on 4096x4096x4096...
  256x128x32_64x64_S3_swfalse_bs16            0.664 ms  207.06 TFLOPS
  256x128x32_64x64_S3_swtrue_bs16             0.663 ms  207.35 TFLOPS
  128x128x32_64x64_S3_swfalse_bs16            0.697 ms  197.29 TFLOPS
  128x128x32_64x64_S3_swtrue_bs16             0.692 ms  198.64 TFLOPS
  128x256x32_64x64_S3_swtrue_bs16             0.640 ms  214.85 TFLOPS  <- Best
  256x128x32_64x32_S3_swtrue_bs16             0.751 ms  183.06 TFLOPS
  256x128x32_64x64_S2_swtrue_bs16             1.058 ms  129.90 TFLOPS
  128x128x32_64x64_S2_swtrue_bs16             0.829 ms  165.76 TFLOPS
  256x128x64_64x64_S2_swtrue_bs16             0.905 ms  151.90 TFLOPS
  128x128x64_64x64_S3_swtrue_bs16             0.848 ms  162.08 TFLOPS
[Autotune] Best: 128x256x32_64x64_S3_swtrue_bs16 (0.640 ms, 214.85 TFLOPS)

Key findings from autotuning on A100:

  • 3-stage pipelines often beat 2-stage at larger sizes where the extra latency hiding matters
  • Swizzle helps only slightly at N=4096 (under 1% in the runs above); the benefit grows with matrix size
  • Best tile sizes vary by N: 128×128 for N=1024, 128×256 for N≥4096
  • BLOCK_STRIDE=16 works well across sizes
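The autotuning tradeoffs are easier to reason about with each configuration's shared memory footprint in hand. The sketch below rests on three assumptions: the padded layout from Kernel 4 (8 halves of padding per row), FP16 storage, and an epilogue that reuses the pipeline buffers for a padded BM×(BN+8) output tile. The function names are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>

// Hypothetical footprint estimate under the assumptions stated above.
constexpr std::size_t kHalf = 2;  // bytes per FP16 element
constexpr int kPad = 8;           // halves of padding per row

// STAGES copies of a padded BM x BK A tile and a padded BK x BN B tile.
constexpr std::size_t pipelineBytes(int BM, int BN, int BK, int STAGES) {
    return std::size_t(STAGES) *
           (std::size_t(BM) * (BK + kPad) + std::size_t(BK) * (BN + kPad)) * kHalf;
}

// Padded BM x BN output tile staged by the vectorized epilogue.
constexpr std::size_t epilogueBytes(int BM, int BN) {
    return std::size_t(BM) * (BN + kPad) * kHalf;
}

// The epilogue reuses the pipeline buffers, so the block needs the larger of the two.
constexpr std::size_t sharedBytes(int BM, int BN, int BK, int STAGES) {
    return std::max(pipelineBytes(BM, BN, BK, STAGES), epilogueBytes(BM, BN));
}
```

Under these assumptions, the best 128x256x32 S3 configuration needs 81,408 bytes. That exceeds the 48 KB default, so the kernel must opt in with cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, bytes). A100 allows up to 163 KB per block.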

Results

Benchmarked on NVIDIA A100-SXM4-40GB with N=4096:

Kernel                Time (ms)   TFLOPS   % cuBLAS
cuBLAS (FP16)              0.50      274       100%
01_WMMABlockTiling         3.27       42        15%
02_WMMAVectorized          1.85       74        27%
03_WMMAAsync               1.68       82        30%
04_WMMAPadded              0.88      156        57%
05_WMMAMultistage          0.71      193        71%
06_WMMAPipelining          0.68      203        74%
07_WMMAFinal               0.63      219        80%
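Every TFLOPS figure in this post follows from the FLOP count of a square GEMM. There are N³ multiply-adds, so 2·N³ FLOPs, divided by the measured time. A minimal helper for reproducing the table (function names are hypothetical):

```cpp
#include <cassert>
#include <cmath>

// A square GEMM of size n performs n^3 multiply-adds, i.e. 2 * n^3 FLOPs.
double gemmTflops(double n, double milliseconds) {
    const double flops = 2.0 * n * n * n;
    return flops / (milliseconds * 1e-3) / 1e12;
}

// Throughput as a percentage of a reference (e.g. cuBLAS).
double percentOf(double tflops, double referenceTflops) {
    return 100.0 * tflops / referenceTflops;
}
```

The times in the table are rounded to 0.01 ms. Recomputing TFLOPS from them can therefore differ slightly from the TFLOPS column; for example, 0.63 ms gives about 218 TFLOPS.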
Figure 3: HGEMM optimization progression on A100.

The Gap to cuBLAS

The remaining ~20% gap comes from techniques beyond WMMA’s abstraction level:

  • PTX mma.sync instructions: Direct control over tensor core operations, enabling finer instruction scheduling and register management
  • XOR-based swizzling: Shared memory addressing patterns that eliminate bank conflicts without padding (our 8-element padding adds 25% to the BK=32 A tile)
  • Architecture-specific tuning: cuBLAS has years of per-GPU, per-size optimization tables that we can’t match with a simple autotuner
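The XOR alternative can be sketched under stated assumptions: a BK=32 FP16 tile with unpadded 64-byte rows, the same 8-row, 16-bytes-per-row access model as a fragment load, and hypothetical helper names. XOR-ing the 16-byte chunk index with (row / 2) only permutes chunks within each row, so no space is wasted. Even so, 8 consecutive rows land in 8 different bank groups. Because load_matrix_sync accepts only a base pointer and a linear leading dimension, a layout like this has to be paired with lower-level loads such as PTX ldmatrix.

```cpp
#include <algorithm>
#include <cassert>
#include <set>

// Hypothetical XOR swizzle for 64-byte rows (four 16-byte chunks, no padding).
// Two rows share one 128-byte span of the 32 banks, so XOR-ing the chunk
// index with (row / 2) spreads 8 consecutive rows over all 8 bank groups.
int swizzledByteOffset(int row, int chunk) {
    int physChunk = chunk ^ ((row >> 1) & 3);
    return row * 64 + physChunk * 16;
}

int plainByteOffset(int row, int chunk) { return row * 64 + chunk * 16; }

// Largest number of distinct 4-byte words hitting one bank when 8 consecutive
// rows each read 16 bytes of the same logical chunk.
template <typename OffsetFn>
int conflictDegree(OffsetFn offset, int firstRow, int chunk) {
    std::set<int> words[32];
    for (int r = firstRow; r < firstRow + 8; ++r)
        for (int w = 0; w < 4; ++w) {
            int word = (offset(r, chunk) + 4 * w) / 4;
            words[word % 32].insert(word);
        }
    int degree = 0;
    for (const auto &s : words) degree = std::max(degree, static_cast<int>(s.size()));
    return degree;
}
```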

WMMA is a good starting point, but extracting the last bits of performance requires dropping to PTX.

Conclusion

Optimizing tensor core GEMM follows the same principles as CUDA core optimization, with additional considerations:

  1. Use tensor cores: 4-16× theoretical speedup over CUDA cores for supported precisions
  2. Vectorize memory access: float4 (128-bit) loads move 8 FP16 values per instruction, the widest access a single thread can issue
  3. Reduce bank conflicts: Padding shared memory rows by 8 FP16 elements helps (XOR swizzle in PTX can do better)
  4. Pipeline at multiple levels: Async DMA, fragment prefetch, and cross-tile overlap
  5. Autotune configurations: Optimal parameters vary significantly by matrix size and GPU architecture

WMMA provides a clean abstraction that captures 80% of cuBLAS performance on A100. The remaining gap requires lower-level techniques: PTX mma.sync for finer instruction scheduling and XOR-based swizzling to eliminate padding overhead. These will be explored in the follow-up post.

References