## Summary

Using `MmaAtomSM80Type.get()` with FP8 types (`Float8E4M3FN`, shape `(16, 8, 32)`) causes a segfault during MLIR→PTX compilation when targeting SM120 (RTX 5090, GB10/DGX Spark, `sm_121a`). The MLIR backend appears to lack lowering support for SM120's `mma.sync.aligned.kind::f8f6f4.m16n8k32` instruction variant.
## Background

SM120 has an FP8 MMA instruction defined in C++ CuTe:

```cpp
// cute/arch/mma_sm120.hpp, lines 668-697
template <>
struct SM120_16x8x32_TN<float_e4m3_t, float_e4m3_t, float> {
  // PTX: mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e4m3.e4m3.f32
};
```
This is a different instruction from SM89's FP8 MMA (`mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32`, with no `kind::` prefix). SM120 has `CUTE_ARCH_F8F6F4_MMA_ENABLED` but not `CUTE_ARCH_MMA_F32_SM89_ENABLED` (per `cute/arch/config.hpp`).
The CuTe Python DSL currently only exposes:

- `MmaF16BF16Op` → `MmaAtomSM80Type` with F16/BF16 (works on SM120)
- `MmaSM120BlockScaledOp` → `MmaAtomSM120BlockScaledType` with FP4 block-scaled (works on SM120)

There is no exposed path for SM120's non-block-scaled FP8 MMA.
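For contrast, here is a minimal sketch of the exposed F16/BF16 path; the constructor argument order mirrors the reproducer's fields below, so treat the exact signatures and import paths as assumptions:

```python
import cutlass
from cutlass import cute
from cutlass.cute.typing import BFloat16, Float32

# Works on SM120: the exposed SM80-style warp MMA op with BF16 operands.
op = cute.nvgpu.warp.MmaF16BF16Op(BFloat16, Float32, (16, 8, 16))

# FP8 never reaches the backend through this path: MmaF16BF16Op's
# Python-level validation rejects non-F16/BF16 ab_dtype, hence the
# bypass in the reproducer below.
# op = cute.nvgpu.warp.MmaF16BF16Op(Float8E4M3FN, Float32, (16, 8, 32))
```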
## Reproducer

Bypassing `MmaF16BF16Op`'s Python-level validation and calling `MmaAtomSM80Type.get()` directly with FP8 types:
```python
import cutlass
from cutlass import cute
from cutlass.cute.nvgpu.warp.mma import (
    WarpMmaOp,
    MmaF16BF16Trait,
    make_atom,
    _pack_shape,
)
from cutlass.cute.typing import Float8E4M3FN, Float32
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
from dataclasses import dataclass
from typing import Type


@dataclass(frozen=True)
class MmaFP8Op(WarpMmaOp):
    ab_dtype: Type[cutlass.Numeric] = Float8E4M3FN
    acc_dtype: Type[cutlass.Numeric] = Float32
    shape_mnk: tuple = (16, 8, 32)

    def _make_trait(self, *, loc=None, ip=None, **kwargs):
        # Build the SM80-style atom type directly with FP8 element types,
        # skipping MmaF16BF16Op's dtype validation.
        shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip)
        ty = _cute_nvgpu_ir.MmaAtomSM80Type.get(
            shape_mnk.type.attribute,
            self.ab_dtype.mlir_type,
            self.ab_dtype.mlir_type,
            self.acc_dtype.mlir_type,
        )
        return MmaF16BF16Trait(make_atom(ty, loc=loc, ip=ip))

    # Skip the F16/BF16-specific fragment checks.
    def _verify_fragment_A(self, input, *, loc=None, ip=None):
        pass

    def _verify_fragment_B(self, input, *, loc=None, ip=None):
        pass
```
The MLIR type creation succeeds: a kernel that creates the atom without using it compiles fine. However, when the MMA is actually used in a compute path (`cute.gemm`), the MLIR→PTX lowering segfaults:

```text
!!!!!!! Segfault encountered !!!!!!!
  File "./nptl/pthread_once.c", line 116, in __pthread_once_slow
```
## Behavior

| Step | Result |
| --- | --- |
| `MmaAtomSM80Type.get(shape=(16,8,32), aType=f8E4M3FN, bType=f8E4M3FN, cType=f32)` | Success (MLIR type created) |
| Create `tiled_mma` from the atom | Success |
| `cute.gemm(tiled_mma, acc, fragA, fragB, acc)` in kernel body | Segfault during `cute.compile()` |
## Expected Behavior

`MmaAtomSM80Type.get()` with FP8 types should either:

- Lower to SM120's `mma.sync.aligned.kind::f8f6f4.m16n8k32` PTX instruction when targeting `sm_120a`/`sm_121a`, or
- Raise a clear error (not a segfault) indicating FP8 MMA is unsupported for the target arch, as sketched below.
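A minimal sketch of the second option, as a guard in the op's trait construction; the `_arch_supports_f8f6f4_mma` helper is hypothetical and named here only for illustration:

```python
def _arch_supports_f8f6f4_mma() -> bool:
    # Hypothetical stub; a real check would consult the compile target
    # (e.g. CUTE_DSL_ARCH) for kind::f8f6f4 MMA lowering support.
    return False

def _make_trait(self, *, loc=None, ip=None, **kwargs):
    # Fail loudly at trait construction instead of segfaulting later
    # in the MLIR->PTX lowering.
    if self.ab_dtype is Float8E4M3FN and not _arch_supports_f8f6f4_mma():
        raise NotImplementedError(
            f"FP8 MMA {self.shape_mnk} has no PTX lowering for this target arch"
        )
    ...
```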
## Environment

- Device: NVIDIA GB10 (DGX Spark), SM121a
- CUDA: 13.0
- nvidia-cutlass-dsl: 4.3.5 (pip)
- Target arch: `CUTE_DSL_ARCH=sm_121a`
- Python: 3.13.3
## Impact

SM120's non-block-scaled FP8 MMA (`m16n8k32`, `kind::f8f6f4`) offers 2x theoretical throughput over BF16 (`m16n8k16`) for attention kernels. This is the highest-impact optimization available for SM120 flash attention, and it's currently blocked by this lowering gap.
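The 2x figure follows from the tile shapes alone, since each is a single warp-wide MMA instruction:

```python
# MACs per issued instruction scale with the m*n*k tile volume.
macs_bf16 = 16 * 8 * 16  # m16n8k16 -> 2048 MACs
macs_fp8 = 16 * 8 * 32   # m16n8k32 -> 4096 MACs
print(macs_fp8 / macs_bf16)  # 2.0
```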
The C++ CuTe headers already define the instruction (`SM120_16x8x32_TN`, with 77 specializations for different FP8/FP6/FP4 type combinations); the Python DSL just needs the MLIR lowering to be connected.
Contributed by Second Nature Computing — tested on DGX Spark hardware