
Writing a CuTe DSL Kernel for NVFP4 GEMM

Hello world! This is the first of many blog posts, and it's about my submission to the GPU Mode x NVIDIA Kernel Competition, specifically the NVFP4 General Matrix Multiplication (GEMM) challenge. I placed 13th, slightly less than a microsecond behind the winning submission.


The Challenge

Implement a block-scaled FP4 GEMM kernel that competes with NVIDIA's reference. The setup:

  • FP4 (Float4E2M1FN) matrices A and B
  • FP8 (Float8E4M3FN) scale factors
  • FP16 output C
  • Blackwell SM100 with Tensor Memory (TMEM) and 5th-gen Tensor Cores

The benchmarks:

m=128, n=7168, k=16384
m=128, n=4096, k=7168
m=128, n=7168, k=2048

All are skinny matrices with a small M; this regime needs careful tile-shape and pipeline tuning.
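
To put a number on "skinny", here's the output-tile grid for the first shape, assuming 128x128 output tiles purely for illustration:

# Output-tile grid for m=128, n=7168 with an illustrative 128x128 CTA tile
m, n = 128, 7168
tiles_m, tiles_n = m // 128, n // 128   # 1 x 56
# A single row of output tiles: all the parallelism has to come from N,
# and every tile marches through the full K dimension.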

Block-Scaled FP4

Standard GEMM is C = A @ B. Block-scaled FP4 adds scale factors because 4 bits isn't enough precision on its own.

Values are grouped into blocks of 16, each with a scale factor:

C[i,j] = sum_k (A[i,k] * SFA[i,k//16]) * (B[j,k] * SFB[j,k//16])

SFA and SFB undo the quantization per 16-element block. This is the NVFP4 flavor of microscaling (MX-style block scaling), with 16-element blocks and FP8 E4M3 scales.
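
As a sanity check on the math (not the kernel itself), here's a minimal NumPy reference that applies the per-block scales and then does a plain matmul. It assumes A, B, SFA, SFB have already been decoded from FP4/FP8 into float arrays:

import numpy as np

def blockscaled_gemm_ref(A, B, SFA, SFB, block=16):
    # A: (M, K), B: (N, K), SFA: (M, K // block), SFB: (N, K // block)
    A_deq = A * np.repeat(SFA, block, axis=1)   # one scale per 16-element block of A
    B_deq = B * np.repeat(SFB, block, axis=1)   # one scale per 16-element block of B
    return A_deq @ B_deq.T                      # C: (M, N), accumulated in float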

Blackwell Architecture

Tensor Memory (TMEM): 512-column scratchpad between registers and shared memory. Accumulators live here, so they don't compete with operands for register space.

5th-Gen Tensor Cores: native support for block-scaled formats. Feed them FP4 operands plus FP8 scale factors and the tcgen05 MMA instructions apply the scaling internally.

2-CTA Instructions: two CTAs can fuse to work on a 256xN tile cooperatively. When the M tile is 256, two CTAs pretend to be one.

Kernel Architecture

Warp-specialized persistent tile scheduler, with 6 warps (192 threads) per CTA:

self.epilog_warp_id = (0, 1, 2, 3)  # Warps 0-3: Epilogue (store results)
self.mma_warp_id = 4                 # Warp 4: Matrix multiply
self.tma_warp_id = 5                 # Warp 5: TMA loads

Each warp has one job:

  • TMA Warp: Loads A, B, SFA, SFB from global to shared memory
  • MMA Warp: Matrix multiply on tensor cores
  • Epilogue Warps: Convert and store results to global memory

The Pipeline

Producer-consumer with multiple stages:

[TMA Loads] --> [SMEM Buffers] --> [MMA Compute] --> [TMEM Accum] --> [Epilogue Store]
     ^              |                    ^               |                   |
     |         ab_pipeline               |          acc_pipeline             v
     +------- (empty signal) ------------+                                [GMEM]

ab_pipeline manages the shared memory buffers with barrier synchronization: TMA signals "full" when a load completes, and MMA signals "empty" when it finishes consuming a buffer. Loads and compute overlap.
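
Schematically, one trip through ab_pipeline looks like the sketch below (treat the producer-side calls as shorthand; the consumer-side ones appear verbatim in the mainloop later):

# TMA warp (producer): wait for a free SMEM buffer, then kick off async loads
ab_pipeline.producer_acquire(ab_producer_state)
# ... issue TMA copies of A, B, SFA, SFB into this stage's buffers ...
ab_producer_state.advance()

# MMA warp (consumer): wait until the buffer is full, compute, then hand it back
ab_pipeline.consumer_wait(ab_consumer_state, peek_ab_full_status)
# ... tcgen05 MMAs over this stage ...
ab_pipeline.consumer_release(ab_consumer_state)
ab_consumer_state.advance()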

acc_pipeline manages the accumulator in tensor memory between MMA and epilogue.

Persistent Scheduling

Thread blocks stay resident and grab tiles until done:

tile_sched = utils.StaticPersistentTileScheduler.create(...)
work_tile = tile_sched.initial_work_tile_info()

while work_tile.is_valid_tile:
    # Process current tile
    tile_sched.advance_to_next_work()
    work_tile = tile_sched.get_current_work()

One launch processes every output tile, which amortizes launch overhead and keeps CTAs resident on their SMs.

The Mainloop

MMA warp loops over K tiles:

for k_tile in cutlass.range(k_tile_cnt, unroll_full=True):
    # Wait for A/B data in SMEM
    ab_pipeline.consumer_wait(ab_consumer_state, peek_ab_full_status)
    
    # Copy scale factors from SMEM to TMEM
    cute.copy(tiled_copy_s2t_sfa, tCsSFA_compact_s2t_staged, tCtSFA_compact_s2t)
    cute.copy(tiled_copy_s2t_sfb, tCsSFB_compact_s2t_staged, tCtSFB_compact_s2t)
    
    # MMA for each K block
    for kblock_idx in cutlass.range(num_kblocks, unroll_full=True):
        tiled_mma.set(tcgen05.Field.SFA, tCtSFA[sf_kblock_coord].iterator)
        tiled_mma.set(tcgen05.Field.SFB, tCtSFB_mma[sf_kblock_coord].iterator)
        
        cute.gemm(tiled_mma, tCtAcc, tCrA[kblock_coord], tCrB[kblock_coord], tCtAcc)
        
        tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
    
    # Done with this SMEM buffer
    ab_pipeline.consumer_release(ab_consumer_state)

cute.gemm invokes the Blackwell tcgen05 block-scaled MMA. The ACCUMULATE flag is flipped to true after the first MMA, so the remaining K blocks add onto the accumulator instead of overwriting it.

Memory Layouts and TMA

Scale factors need a permuted layout for the tensor cores:

# ((Atom_M, Rest_M), (Atom_K, Rest_K), RestL)
sfa_layout = blockscaled_utils.tile_atom_to_shape_SF(a_tensor.shape, sf_vec_size)
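
For a sense of scale (using the largest benchmark shape as an illustration), the scale-factor tensors are an eighth the size of the FP4 data they describe:

# Scale-factor footprint for m=128, n=7168, k=16384 with 16-element blocks
m, n, k, sf_vec_size = 128, 7168, 16384, 16
sfa_bytes = m * (k // sf_vec_size)   # 131,072 one-byte FP8 scales -> 128 KiB
sfb_bytes = n * (k // sf_vec_size)   # 7,340,032 scales            -> 7 MiB
a_bytes   = m * k // 2               # the FP4 A matrix itself is 1 MiB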

Multicast TMA broadcasts data to multiple CTAs:

if self.is_a_mcast:
    a_full_mcast_mask = cpasync.create_tma_multicast_mask(
        cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2
    )

The same A slice goes to all CTAs in the cluster that compute different C columns, so A is read from global memory once per cluster instead of once per CTA.

Stage Calculation

How many pipeline stages fit in shared memory:

@staticmethod
def _compute_stages(...):
    # Bytes of A, B, SFA, SFB shared memory for a single pipeline stage
    ab_bytes_per_stage = (
        cute.size_in_bytes(a_dtype, a_smem_layout_staged_one)
        + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one)
        + cute.size_in_bytes(sf_dtype, sfa_smem_layout_staged_one)
        + cute.size_in_bytes(sf_dtype, sfb_smem_layout_staged_one)
    )

    # How many stages fit in the per-CTA SMEM budget, with a floor of 5
    pos = (smem_capacity // occupancy - overhead) // ab_bytes_per_stage
    num_ab_stage = max(pos, 5)

More stages mean more overlap between loads and compute, which hides memory latency. Blackwell gives a CTA up to 228 KB of shared memory to spend on them.
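
Rough numbers, assuming a 128x64 CTA tile and a K-tile of 256 FP4 elements (illustrative values, not necessarily the tuned config):

a_bytes   = 128 * 256 // 2        # FP4 is half a byte per element -> 16,384 B
b_bytes   =  64 * 256 // 2        #                                ->  8,192 B
sfa_bytes = 128 * (256 // 16)     # one FP8 scale per 16 elements  ->  2,048 B
sfb_bytes =  64 * (256 // 16)     #                                ->  1,024 B
per_stage = a_bytes + b_bytes + sfa_bytes + sfb_bytes    # 27 KB per stage
print(228 * 1024 // per_stage)    # -> 8 stages before epilogue/overhead carve-outs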

N=192 and N=64 Edge Cases

Special handling for tile sizes that don't divide evenly:

if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 192):
    offset = cutlass.Int32(2) if mma_tile_coord_mnl[1] % 2 == 1 else cutlass.Int32(0)
    shifted_ptr = cute.recast_ptr(
        acc_tmem_ptr + self.num_accumulator_tmem_cols + self.num_sfa_tmem_cols + offset,
        dtype=self.sf_dtype,
    )
    tCtSFB_mma = cute.make_tensor(shifted_ptr, tCtSFB_layout)

192 doesn't line up evenly with the tensor cores' tile sizes, so odd N tiles shift their SFB base in TMEM by a couple of columns and the indexing gets creative.

The Epilogue

Convert FP32 accumulators to FP16 and store:

for subtile_idx in cutlass.range(subtile_cnt):
    # Load accumulator from TMEM to registers
    cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc)
    
    # Convert to output type
    acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load()
    acc_vec = epilogue_op(acc_vec.to(self.c_dtype))
    tRS_rC.store(acc_vec)
    
    # Store to SMEM
    cute.copy(tiled_copy_r2s, tRS_rC, tRS_sC[(None, None, None, c_buffer)])
    
    # TMA store to global memory
    cute.copy(tma_atom_c, bSG_sC[(None, c_buffer)], bSG_gC[(None, real_subtile_idx)])

Overlapping Accumulator

When N=256, only one accumulator stage fits in TMEM. Release it early:

if cutlass.const_expr(self.overlapping_accum):
    if subtile_idx == self.iter_acc_early_release_in_epilogue:
        cute.arch.fence_view_async_tmem_load()
        acc_pipeline.consumer_release(acc_consumer_state)

The epilogue releases the accumulator as soon as it has been copied out, so the MMA warp can start the next tile: double buffering at sub-tile granularity.

Split-K (Correct but 7x Slower)

For k=16384, I tried split-K: multiple CTAs compute partial sums, then reduce.

I encoded the split index into the L (batch) dimension:

SPLIT_K = 4

# Fold the split index into the L dimension of the tile grid
num_ctas_mnl_splitk = (num_ctas_m, num_ctas_n, l * split_k)

# Recover the real batch index and the split index from the scheduler's tile coordinate
combined_l = cur_tile_coord[2]
actual_l = combined_l // split_k
split_idx = combined_l % split_k

# Each split handles a contiguous chunk of the K tiles
k_tile_begin = split_idx * k_tile_cnt_per_split
k_tile_end = min((split_idx + 1) * k_tile_cnt_per_split, k_tile_cnt)

Partial results go to an FP32 workspace, then a Triton kernel reduces them:

@triton.jit
def fused_reduce_kernel(
    workspace_ptr,  # (L*SPLIT_K, M, N) FP32 partial sums
    output_ptr,     # (M, N, L) FP16
    M, N, L, SPLIT_K: tl.constexpr,
    BLOCK_SIZE: tl.constexpr = 256,
):
    # Each program handles BLOCK_SIZE elements of the flattened (M, N, L) output
    offs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < M * N * L
    l = offs % L
    n = (offs // L) % N
    m = offs // (N * L)

    # Sum the SPLIT_K partial results for each output element
    acc = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for s in range(SPLIT_K):
        ws_offset = (l * SPLIT_K + s) * M * N + m * N + n
        acc += tl.load(workspace_ptr + ws_offset, mask=mask, other=0.0)

    # Write back in FP16; offs is already the flat (M, N, L) index
    tl.store(output_ptr + offs, acc.to(tl.float16), mask=mask)

It produces correct results. It's 7x slower.

Why:

  1. Workspace overhead: FP32 partials = 4x memory traffic vs direct FP16, plus extra kernel launch
  2. Reduction bottleneck: 4 non-coalesced global loads per output element
  3. Already saturated: N=7168 gives 56 output tiles. GPU is busy; split-K adds overhead
  4. Pipeline inefficiency: Fewer K-tiles per split = less work to hide latency
  5. Lost data reuse: Each split reloads A instead of reusing across K iterations

Split-K helps when M and N are small and you can't saturate the GPU. Here, N is big enough. The optimization wasn't needed.

What I Learned

  1. CuTe is powerful but dense: Learning curve is steep. Reading CUTLASS source was essential.

  2. Warp specialization enables overlap: Dedicated warps for loads/compute/store run truly in parallel.

  3. Obvious isn't always fastest: Expected 256x256 tiles to win. 128x64 with more stages was better for skinny matrices.

  4. TMEM changes everything: Accumulators in tensor memory don't compete with operands for registers.

Acknowledgements

Thanks to GPU Mode and NVIDIA for the competition. CUTLASS/CuTe docs and examples were invaluable.