
gemm/conv: promote indices_ gather buffer from int32 to int64#36

Open
ladvu wants to merge 1 commit into FindDefinition:main from ladvu:int64-offset-fix

Conversation


ladvu commented Apr 22, 2026

The MaskTileIteratorGather and ForwardDgradSparseIOIterator classes store per-row gather offsets computed as idx * stride * sizeof(dtype) in a fixed int32 array. The multiplication is done in int32 and overflows once idx * stride * nbytes >= 2^31, producing illegal memory accesses or silent wrong results.

Minimal reproducer: shuffle-AC GEMM with fp32 and K = 256. max_idx = 2^21 (where max_idx * K * 4 = 2^31) crashes the kernel; max_idx = 2^21 - 1 works. The breaking point matches the int32 overflow boundary exactly.

Fix:

  • indices_ storage type: int32 -> int64
  • cast indice_ptr_[...] and thread_offset_[...] to int64_t before the stride * nbytes multiply, so the product is computed in 64 bits
  • widen the host-side assertion guard (nvrtc_code.py / conv/main.py / conv/params.py) from INT32_MAX to INT64_MAX so the wrapper no longer refuses shapes the kernel can now handle
  • conv/sparse_iters.py: widen get_indice_offset() return type and get() arg type to int64_t to match; restructure the wgrad-out PTX-load path to load 32-bit values into an int mask_inds[] temporary first (the PTX store is 32-bit, which would otherwise leave the upper half of an int64 indices_[] slot uninitialized)

Also updated cumm/gemm/frozen/{__init__.py,mask_iters.py} which carry a snapshot of the iterator code used by cumm/gemm/gather.py, keeping the AOT gather utility consistent with the NVRTC path.

Validated:

  • test/test_int64_large_gemm.py (new) -- shuffle-AC with N = 16M rows x K = 256 float32 (features = 16 GiB, 8x INT32_MAX). Probes straddle the 2^31 byte boundary and all match a PyTorch reference within fp32 tolerance (rel_err ~1e-6)
  • shuffle-AB sweep (backward-weight) up to 16M x 64 x 4 = 4 GiB, all OK
  • end-to-end through downstream spconv SubMConv3d (kernel=3, N = 4M, C = 256 -> features = 4 GiB), full-tensor rel_err = 5e-6

The MaskTileIteratorGather and ForwardDgradSparseIOIterator classes
store per-row gather offsets computed as `idx * stride * sizeof(dtype)`
in a fixed int32 array. The multiplication is done in int32 and
overflows once `idx * stride * nbytes >= 2^31`, producing illegal
memory accesses or silent wrong results.

Minimal reproducer: shuffle-AC GEMM with fp32 and K=256. max_idx = 2^21
(where max_idx * K * 4 = 2^31) crashes the kernel; max_idx = 2^21 - 1
works. The breaking point matches the int32 overflow boundary exactly.

Fix:
  * indices_ storage type: int32 -> int64
  * cast indice_ptr_[...] and thread_offset_[...] to int64_t before
    the `stride * nbytes` multiply, so the product is computed in 64
    bits
  * widen the host-side assertion guard
    (nvrtc_code.py / conv/main.py / conv/params.py) from INT32_MAX to
    INT64_MAX so the wrapper no longer refuses shapes the kernel can
    now handle
  * conv/sparse_iters.py: widen get_indice_offset() return type and
    get() arg type to int64_t to match; restructure the wgrad-out
    PTX-load path to load 32-bit values into an int mask_inds[]
    temporary first (the PTX store is 32-bit, which would otherwise
    leave the upper half of an int64 indices_[] slot uninitialized)

Also updated cumm/gemm/frozen/{__init__.py,mask_iters.py} which carry
a snapshot of the iterator code used by cumm/gemm/gather.py, keeping
the AOT gather utility consistent with the NVRTC path.

Validated end-to-end:
  * test/test_int64_large_gemm.py (new) -- shuffle-AC with N = 16M
    rows x K = 256 float32 (features = 16 GiB, 8x INT32_MAX). Probes
    straddle the 2^31 byte boundary and all match a PyTorch reference
    within fp32 tolerance (rel_err ~1e-6)
  * shuffle-AB sweep (backward-weight) up to 16M x 64 x 4 = 4 GiB,
    all OK
  * spconv SubMConv3d kernel=3 forward (cumm/conv `else` branch of
    update_indices): N = 4M, C = 256 -> features = 4 GiB, full-tensor
    rel_err = 5e-6 vs torch.nn.functional.conv3d reference
  * spconv SubMConv3d kernel=3 backward-weight (cumm/conv `is_wgrad_out`
    branch of update_indices): same shape, dWeight rel_err = 8.5e-5
    vs torch autograd reference -- confirms the restructured PTX-load
    path is correct
ladvu force-pushed the int64-offset-fix branch from 609e8f0 to 3e149d6 on April 22, 2026 at 19:03
