Introduction

General Matrix Multiply (GEMM) computes C = αAB + βC and is the computational backbone of deep learning. Every linear layer, attention mechanism, and convolution (via im2col) reduces to GEMM. Understanding how to optimize it reveals the core principles of GPU performance.

This post implements single-precision GEMM (SGEMM) from scratch using CUDA cores only, no tensor cores. We progressively optimize from a naive baseline to 89% of cuBLAS performance on an H200. A follow-up post will explore tensor core acceleration. The benchmark code is available in this repository.

SGEMM Performance
SGEMM optimization progression on H200: from 4 TFLOPS (naive) to 45 TFLOPS (89% of cuBLAS).

The Problem

Given matrices A ($M \times K$) and B ($K \times N$), compute C ($M \times N$) where each element is a dot product:

\[c_{ij} = \sum_{k=0}^{K-1} a_{ik} \cdot b_{kj}\]
One thread per output
Figure 1: Each output element requires reading one row of A and one column of B.

The naive approach assigns one thread per output element. Each thread independently computes a dot product by iterating through the K dimension. This results in massive redundant global memory traffic: every row of A and column of B is re-read N and M times respectively.
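
Before optimizing anything, it helps to have a reference to validate every kernel against. A minimal row-major CPU version (a hypothetical helper, not taken from the post's repository) is just the triple loop:

```cpp
// Reference SGEMM: C = alpha * A * B + beta * C, all matrices row-major.
// A is M x K, B is K x N, C is M x N.
void sgemm_ref(int M, int N, int K, float alpha,
               const float *A, const float *B, float beta, float *C) {
    for (int i = 0; i < M; ++i) {
        for (int j = 0; j < N; ++j) {
            float sum = 0.0f;
            for (int k = 0; k < K; ++k) {
                sum += A[i * K + k] * B[k * N + j];
            }
            C[i * N + j] = alpha * sum + beta * C[i * N + j];
        }
    }
}
```

Every GPU kernel below computes exactly this and can be checked against it element-wise, within a small floating-point tolerance.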

Kernel 1: Naive Baseline

template <int BM, int BN>
__global__ void sgemm_baseline(
    int M, int N, int K, float alpha, 
    const float *A, const float *B, float beta, float *C)
{
    // 1D thread index
    uint tid = threadIdx.x;

    // Convert 1D thread index to 2D position within a block tile
    uint tx = tid % BN; // column within block
    uint ty = tid / BN; // row within block

    // Global position in C
    uint col = blockIdx.x * BN + tx;
    uint row = blockIdx.y * BM + ty;

    if (row < M && col < N) {
        // Compute dot product
        float sum = 0.0f;
        for (uint k = 0; k < K; ++k) {
            sum += A[row * K + k] * B[k * N + col];
        }

        C[row * N + col] = alpha * sum + beta * C[row * N + col];
    }
}
// Host side kernel launch
dim3 grid((N + BN - 1) / BN, (M + BM - 1) / BM);
dim3 block(BM * BN);
sgemm_baseline<BM, BN><<<grid, block>>>(M, N, K, alpha, A, B, beta, C);

Each thread loads K elements from A and K elements from B. That’s 2K floats = 8K bytes. For those loads, we get K multiply-adds = 2K FLOPs. The arithmetic intensity:

\[\text{Arithmetic Intensity} = \frac{2K \text{ FLOPs}}{8K \text{ bytes}} = 0.25 \text{ FLOPs/byte}\]

Modern GPUs need 10-30+ FLOPs/byte to be compute-bound. At 0.25, we’re severely memory-bound with zero data reuse.
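
To see where the memory-bound/compute-bound crossover sits, compare against the machine balance. The H200 figures below (~67 TFLOPS peak FP32 on CUDA cores, ~4.8 TB/s HBM3e bandwidth) are approximate spec-sheet values, used only to locate the ridge point:

```cpp
// Arithmetic intensity of the naive kernel: 2K FLOPs per 8K bytes.
double naive_intensity(double K) {
    double flops = 2.0 * K;        // K multiply-adds per thread
    double bytes = 2.0 * K * 4.0;  // K floats from A + K floats from B
    return flops / bytes;          // 0.25, independent of K
}

// Machine balance: FLOPs the GPU can execute per byte it can move.
// Spec-sheet approximations (assumptions, not measured):
double h200_ridge_point() {
    double peak_flops = 67e12;     // FP32 via CUDA cores
    double bandwidth  = 4.8e12;    // HBM3e bytes/s
    return peak_flops / bandwidth; // ~14 FLOPs/byte
}
```

Any kernel below roughly 14 FLOPs/byte on this machine is bandwidth-limited; the naive kernel sits more than 50× under that line.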

Kernel                     Time (ms)   TFLOPS   % cuBLAS
01_Baseline                    269.5      4.1         8%

Kernel 2: Block Tiling with Shared Memory

The insight: threads within a block can share data. By loading tiles of A and B into shared memory cooperatively, we amortize global memory traffic across all threads in the block.

Block tiling
Figure 2: A block loads $BM \times BK$ from A and $BK \times BN$ from B into shared memory, then iterates along K.

This reduces global memory reads from 2K per thread to 2K/BK per thread for the cooperative loads. With BK=32, that’s a $32\times$ reduction in global traffic.

However, we’re still limited: each thread computes one output (2K FLOPs) while the block loads $(BM \times BK + BK \times BN)$ floats per tile $\times$ K/BK tiles. The arithmetic intensity improves but remains modest because each thread only produces one output.
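
The traffic arithmetic can be checked directly (a sketch; K = 8192 matches the benchmark and BK = 32 matches this kernel):

```cpp
// Global-memory floats loaded per thread across the whole K dimension.
unsigned naive_loads(unsigned K) {
    return 2 * K;  // one element of A and one of B per k iteration
}
unsigned tiled_loads(unsigned K, unsigned BK) {
    // One element of As and one of Bs per tile, K / BK tiles.
    return 2 * (K / BK);
}
```

For K = 8192 and BK = 32 this is 16384 versus 512 loads per thread, the 32× reduction claimed above.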

template <int BM, int BN, int BK>
__global__ void sgemm_block_tiling(
    int M, int N, int K, float alpha, 
    const float *A, const float *B, float beta, float *C)
{
    static_assert(BK == BN && BK == BM, "This kernel requires BM == BN == BK");

    __shared__ float As[BM * BK];
    __shared__ float Bs[BK * BN];

    // 1D thread index
    uint tid = threadIdx.x;

    // Convert to 2D position within block tile
    uint tx = tid % BN; // column within block
    uint ty = tid / BN; // row within block

    // Global position in C
    uint col = blockIdx.x * BN + tx;
    uint row = blockIdx.y * BM + ty;

    // Move A and B pointers to this block's starting position
    A += blockIdx.y * BM * K;
    B += blockIdx.x * BN;

    float sum = 0.0f;

    // Loop over tiles along K
    for (uint tileIdx = 0; tileIdx < K; tileIdx += BK) {
        // Load tile into shared memory
        // Each thread loads one element of As and one element of Bs
        As[ty * BK + tx] = A[ty * K + tx]; // Note: tx < BK required
        Bs[ty * BN + tx] = B[ty * N + tx]; // Note: ty < BK required

        __syncthreads();

        // Advance pointers for next iteration
        A += BK;
        B += BK * N;

        // Compute partial dot product from shared memory
        for (uint k = 0; k < BK; ++k) {
            sum += As[ty * BK + k] * Bs[k * BN + tx];
        }
        
        __syncthreads();
    }

    if (row < M && col < N) {
        C[row * N + col] = alpha * sum + beta * C[row * N + col];
    }
}
Kernel                     Time (ms)   TFLOPS   % cuBLAS
01_Baseline                    269.5      4.1         8%
02_BlockTiling                 123.2      8.9        17%

Kernel 3: Thread Tiling (2D Register Blocking)

Block tiling reduces global memory traffic, but each thread still performs two shared memory loads for every multiply-add in the inner loop. The next insight: have each thread compute a tile of outputs, not just one element.

Thread tiling
Figure 3: Each thread computes a $TM \times TN$ tile of outputs using outer products. Thread 0 loads a column fragment from A and a row fragment from B, then accumulates all $TM \times TN$ results.

Instead of dot products, we compute outer products. Each thread loads TM values from A and TN values from B into registers, then computes all $TM \times TN$ products. This amortizes the shared memory loads across $TM \times TN$ outputs.

template <int BM, int BN, int BK, int TM, int TN>
__global__ void sgemm_thread_tiling(
    int M, int N, int K, float alpha, 
    const float *A, const float *B, float beta, float *C) 
{
    constexpr int NUM_THREADS = (BM / TM) * (BN / TN);

    __shared__ float As[BM * BK];
    __shared__ float Bs[BK * BN];
    
    uint tid = threadIdx.x;
    
    // Thread position in block
    uint tx = tid % (BN / TN);
    uint ty = tid / (BN / TN);
    
    // Global starting position for this thread's tile
    uint colStart = blockIdx.x * BN + tx * TN;
    uint rowStart = blockIdx.y * BM + ty * TM;
    
    // Move A and B pointers to this block's starting position
    A += blockIdx.y * BM * K;
    B += blockIdx.x * BN;
    
    // Registers for results and fragments
    float tmp[TM * TN] = {0.0f};
    float regM[TM];
    float regN[TN];
    
    // How many elements each thread loads
    constexpr uint numAPerThread = (BM * BK) / NUM_THREADS;
    constexpr uint numBPerThread = (BK * BN) / NUM_THREADS;
    
    // Loop over tiles along K
    for (uint tileIdx = 0; tileIdx < K; tileIdx += BK) {
        
        // Load As into shared memory
        for (uint i = 0; i < numAPerThread; ++i) {
            uint idx = tid + i * NUM_THREADS;
            uint aRow = idx / BK;
            uint aCol = idx % BK;
            As[aRow * BK + aCol] = A[aRow * K + aCol];
        }
        
        // Load Bs into shared memory
        for (uint i = 0; i < numBPerThread; ++i) {
            uint idx = tid + i * NUM_THREADS;
            uint bRow = idx / BN;
            uint bCol = idx % BN;
            Bs[bRow * BN + bCol] = B[bRow * N + bCol];
        }
        
        __syncthreads();
        
        A += BK;
        B += BK * N;
        
        // Compute outer products
        for (uint k = 0; k < BK; ++k) {
            // Load regM from As
            for (uint m = 0; m < TM; ++m) {
                regM[m] = As[(ty * TM + m) * BK + k];
            }
            
            // Load regN from Bs
            for (uint n = 0; n < TN; ++n) {
                regN[n] = Bs[k * BN + tx * TN + n];
            }
            
            // Outer product
            for (uint m = 0; m < TM; ++m) {
                for (uint n = 0; n < TN; ++n) {
                    tmp[m * TN + n] += regM[m] * regN[n];
                }
            }
        }
        
        __syncthreads();
    }
    
    // Write results
    for (uint m = 0; m < TM; ++m) {
        uint row = rowStart + m;
        for (uint n = 0; n < TN; ++n) {
            uint col = colStart + n;
            if (row < M && col < N) {
                C[row * N + col] = alpha * tmp[m * TN + n] + beta * C[row * N + col];
            }
        }
    }
}

With TM=TN=8, each thread computes 64 outputs. Per K iteration, a thread loads TM + TN = 16 values from shared memory and performs $TM \times TN$ = 64 multiply-adds = 128 FLOPs. The arithmetic intensity for shared memory access:

\[\frac{128 \text{ FLOPs}}{16 \times 4 \text{ bytes}} = 2.0 \text{ FLOPs/byte}\]

That’s $8\times$ the naive kernel’s 0.25 FLOPs/byte, and now against fast shared memory rather than global memory. The global memory intensity improves even more dramatically, since each tile loaded from global memory is reused by every thread in the block across all BK inner iterations.
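
The same intensity calculation, written as a function of the tile shape, shows why larger thread tiles pay off (a sketch; bytes assume 4-byte floats):

```cpp
// Shared-memory arithmetic intensity per inner-loop iteration for a
// TM x TN thread tile: 2*TM*TN FLOPs against 4*(TM + TN) bytes loaded.
double smem_intensity(double TM, double TN) {
    return (2.0 * TM * TN) / (4.0 * (TM + TN));
}
```

A 1×1 tile (kernel 2) gives 0.25 FLOPs/byte, 4×4 gives 1.0, and the 8×8 tile used here gives 2.0: for square tiles, intensity grows linearly with tile width, until register pressure pushes back.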

Kernel                     Time (ms)   TFLOPS   % cuBLAS
01_Baseline                    269.5      4.1         8%
02_BlockTiling                 123.2      8.9        17%
03_ThreadTiling                 59.0     18.6        36%

Kernel 4: Vectorized Memory Access

Global memory transactions are most efficient at 128 bits (float4); this requires 16-byte-aligned addresses, which holds when the matrix pointers are aligned and K and N are multiples of 4. We apply vectorization to both global→shared loads and shared→register loads.

Loading A (with transpose):

Matrix A is row-major, but threads reading a column of As for their regM fragment would cause strided access. We transpose during the store:

// Load 4 floats from a row of A
float4 tmp4 = reinterpret_cast<const float4*>(&A[aRow * K + aCol])[0];

// Store transposed: row becomes column
As[(aCol + 0) * BM + aRow] = tmp4.x;
As[(aCol + 1) * BM + aRow] = tmp4.y;
As[(aCol + 2) * BM + aRow] = tmp4.z;
As[(aCol + 3) * BM + aRow] = tmp4.w;

Now reading a column of As (for regM) accesses consecutive addresses.

Loading B (contiguous):

Matrix B loads directly. Rows of B map to rows of Bs, and threads reading regN access consecutive elements:

// Load 4 floats from B
float4 tmp4 = reinterpret_cast<const float4*>(&B[bRow * N + bCol])[0];

// Store directly (no transpose needed)
reinterpret_cast<float4*>(&Bs[bRow * BN + bCol])[0] = tmp4;

Vectorized fragment loads:

The same principle applies when loading from shared memory to registers. With TM=8, we load regM as two float4 operations:

for (uint m = 0; m < TM; m += 4) {
    float4 tmp4 = reinterpret_cast<const float4*>(&As[k * BM + ty * TM + m])[0];
    regM[m + 0] = tmp4.x;
    regM[m + 1] = tmp4.y;
    regM[m + 2] = tmp4.z;
    regM[m + 3] = tmp4.w;
}

The same pattern applies to regN. Both kernels 04a (vectorized global only) and 04b (vectorized global + shared) were benchmarked:

Kernel                     Time (ms)   TFLOPS   % cuBLAS
03_ThreadTiling                 59.0     18.6        36%
04a_VecGmem                     33.9     32.5        63%
04b_VecGmemSmem                 32.4     33.9        66%

Kernel 5: Double Buffering

So far, each tile iteration follows: load → sync → compute → sync. The compute units sit idle while loads are in flight, and vice versa. Double buffering overlaps the load of tile i+1 with computation on tile i by alternating between two shared memory buffers.

The pipeline pattern:

Without double buffering:
  [Load 0][Sync][Compute 0][Sync][Load 1][Sync][Compute 1][Sync]...

With double buffering:
  [Load 0][Sync]
               [Compute 0 + Load 1][Sync]
                                        [Compute 1 + Load 2][Sync]...

The kernel follows a prologue → main loop → epilogue structure:

__shared__ float As[2][BK * BM];  // Two buffers
__shared__ float Bs[2][BK * BN];

// ====== PROLOGUE ======
// Load first tile (no overlap possible yet)
loadTileA(A, As[0], ...);
loadTileB(B, Bs[0], ...);
__syncthreads();

A += BK;
B += BK * N;

// Prefetch first fragment to registers
loadFragment(As[0], Bs[0], regM[0], regN[0], 0, ...);

// ====== MAIN LOOP ======
uint smemWrite = 1, smemRead = 0;

for (uint tile = 1; tile < numTiles; ++tile) {
    // Prefetch next tile to alternate buffer
    loadTileA(A, As[smemWrite], ...);
    loadTileB(B, Bs[smemWrite], ...);
    
    A += BK;
    B += BK * N;
    
    // Process current tile (overlapped with loads)
    processTile(As[smemRead], Bs[smemRead], regM, regN, tmp, ...);
    
    __syncthreads();
    
    // Swap buffers
    smemWrite = 1 - smemWrite;
    smemRead = 1 - smemRead;
    
    // Prefetch first fragment of next tile
    loadFragment(As[smemRead], Bs[smemRead], regM[0], regN[0], 0, ...);
}

// ====== EPILOGUE ======
// Process final tile (no more loads)
processTile(As[smemRead], Bs[smemRead], regM, regN, tmp, ...);

// Write results to global memory
storeResult(C, tmp, ...);

We also double-buffer at the register level: while computing with regM[0]/regN[0], prefetch regM[1]/regN[1] for the next k iteration. Both shared memory (05a) and register (05b) double buffering were benchmarked:
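
The register-level ping-pong is easiest to see in isolation. A CPU model of the indexing pattern (a sketch, not the actual kernel; on the GPU the prefetch issues while the FMAs execute):

```cpp
#include <cstring>

// CPU model of register double buffering over one BK-wide tile.
// As holds BK column fragments of TM floats; Bs holds BK row fragments of TN.
template <int BK, int TM, int TN>
void process_tile_db(const float *As, const float *Bs, float *acc) {
    float regM[2][TM], regN[2][TN];

    // Prologue: fragment for k = 0 goes into buffer 0.
    std::memcpy(regM[0], As, TM * sizeof(float));
    std::memcpy(regN[0], Bs, TN * sizeof(float));

    for (int k = 0; k < BK; ++k) {
        int cur = k & 1, nxt = 1 - cur;
        // Prefetch fragment k+1 into the alternate buffer; on the GPU
        // this overlaps with the outer product below.
        if (k + 1 < BK) {
            std::memcpy(regM[nxt], As + (k + 1) * TM, TM * sizeof(float));
            std::memcpy(regN[nxt], Bs + (k + 1) * TN, TN * sizeof(float));
        }
        // Outer product from the current buffer.
        for (int m = 0; m < TM; ++m)
            for (int n = 0; n < TN; ++n)
                acc[m * TN + n] += regM[cur][m] * regN[cur][n];
    }
}
```

The compute for iteration k never waits on the loads for k+1: they target different buffers, so there is no dependency between them.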

Kernel                     Time (ms)   TFLOPS   % cuBLAS
04b_VecGmemSmem                 32.4     33.9        66%
05a_DoubleBufferSmem            26.6     41.3        81%
05b_DoubleBufferSmemReg         26.5     41.5        81%

Kernel 6: Asynchronous Copy (Ampere+)

On SM80+, cp.async (exposed through the __pipeline_* intrinsics in <cuda_pipeline.h>) copies data directly from global to shared memory without staging through registers:

__pipeline_memcpy_async(
    &Bs[bRow * BN + bCol],
    &B[bRow * N + bCol],
    sizeof(float4)
);
__pipeline_commit();

// ... compute while copy is in flight ...

__pipeline_wait_prior(0);
__syncthreads();

This further improves latency hiding by freeing registers that would otherwise hold in-flight data.

Kernel                     Time (ms)   TFLOPS   % cuBLAS
05b_DoubleBufferSmemReg         26.5     41.5        81%
06_AsyncCopy                    25.7     42.8        83%

Autotuning

Template parameters (BM, BN, BK, TM, TN) significantly impact performance. Rather than guess, we autotune across configurations:

[Autotune] Testing 20 configurations on 8192x8192x8192...
  64x64x16_8x8              28.175 ms   39.02 TFLOPS
  64x128x16_8x8             24.842 ms   44.26 TFLOPS
  64x256x16_8x8             24.224 ms   45.39 TFLOPS  <- Best
  128x128x16_8x8            25.675 ms   42.82 TFLOPS
  ...
[Autotune] Best: 64x256x16_8x8 (24.224 ms, 45.39 TFLOPS)

The optimal configuration ($64 \times 256 \times 16$ with $8 \times 8$ thread tiles) differs from CUTLASS defaults because our implementation uses row-major layouts and different memory access patterns.
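
Not every configuration is even launchable, so the search space needs pruning first. A sketch of the constraints (the 1024-thread limit is standard CUDA; the shared memory limit is the Hopper opt-in per-block figure; the divisibility rules come from the float4 loads and tiling used above):

```cpp
// Whether a (BM, BN, BK, TM, TN) configuration fits hardware limits.
bool config_valid(int BM, int BN, int BK, int TM, int TN) {
    if (BM % TM || BN % TN) return false;      // thread tiles must divide block tile
    int threads = (BM / TM) * (BN / TN);
    if (threads > 1024) return false;          // CUDA max threads per block
    // Two double-buffered tiles of As and Bs, 4 bytes per float.
    int smem_bytes = 2 * (BM * BK + BK * BN) * 4;
    if (smem_bytes > 227 * 1024) return false; // Hopper opt-in smem per block
    // float4 loads: each thread must load a whole number of 4-float vectors.
    if ((BM * BK) % (4 * threads) || (BK * BN) % (4 * threads)) return false;
    return true;
}
```

The winning 64×256×16 configuration with 8×8 tiles passes with 256 threads and 40 KB of shared memory; many larger tiles fail on thread count or shared memory before they can even be timed.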

Results

Benchmarked on NVIDIA H200 (141 GB HBM3e) with M = N = K = 8192:

Kernel                     Time (ms)   TFLOPS   % cuBLAS
cuBLAS (FP32)                   21.4     51.3       100%
01_Baseline                    269.5      4.1         8%
02_BlockTiling                 123.2      8.9        17%
03_ThreadTiling                 59.0     18.6        36%
04a_VecGmem                     33.9     32.5        63%
04b_VecGmemSmem                 32.4     33.9        66%
05a_DoubleBufferSmem            26.6     41.3        81%
05b_DoubleBufferSmemReg         26.5     41.5        81%
06_AsyncCopy                    25.7     42.8        83%
06_AsyncCopy_Autotuned          24.2     45.4        89%
SGEMM Performance
Figure 4: SGEMM optimization progression on H200. We start at 4 TFLOPS and reach 45 TFLOPS (89% of cuBLAS).

The Gap to cuBLAS

The remaining 11% gap comes from techniques we haven’t implemented:

  • Shared memory swizzling: XOR-based addressing to eliminate bank conflicts entirely
  • Warp specialization: Dedicated producer/consumer warps for better pipelining
  • Instruction scheduling: Hand-tuned assembly for optimal ILP
  • Tensor cores: cuBLAS uses TF32 by default for FP32 inputs (we disabled this for fair comparison via CUBLAS_COMPUTE_32F)

For a hand-written SIMT kernel, 89% of cuBLAS is a reasonable ceiling without diving into PTX.

Conclusion

GEMM optimization follows a clear hierarchy:

  1. Reduce global memory traffic: Block tiling with shared memory
  2. Increase arithmetic intensity: Thread tiling (outer products)
  3. Maximize memory bandwidth: Vectorized loads, coalesced access
  4. Hide latency: Double buffering, async copy
  5. Find optimal parameters: Autotuning

The same principles (tiling, vectorization, latency hiding) recur throughout GPU optimization. Master them on GEMM and they transfer everywhere.

References