gemm/conv: promote indices_ gather buffer from int32 to int64 #36
Open
ladvu wants to merge 1 commit into FindDefinition:main
Conversation
The MaskTileIteratorGather and ForwardDgradSparseIOIterator classes
store per-row gather offsets computed as `idx * stride * sizeof(dtype)`
in a fixed int32 array. The multiplication is done in int32 and
overflows once `idx * stride * nbytes >= 2^31`, producing illegal
memory accesses or silent wrong results.
Minimal reproducer: shuffle-AC GEMM with fp32 and K = 256. max_idx = 2^21
(where max_idx * K * 4 = 2^31) crashes the kernel; max_idx = 2^21 - 1
works. The breaking point matches the overflow boundary exactly.
Fix:
* indices_ storage type: int32 -> int64
* cast indice_ptr_[...] and thread_offset_[...] to int64_t before
the `stride * nbytes` multiply, so the product is computed in 64
bits
* widen the host-side assertion guard
(nvrtc_code.py / conv/main.py / conv/params.py) from INT32_MAX to
INT64_MAX so the wrapper no longer refuses shapes the kernel can
now handle
* conv/sparse_iters.py: widen get_indice_offset() return type and
get() arg type to int64_t to match; restructure the wgrad-out
PTX-load path to load 32-bit values into an int mask_inds[]
temporary first (the PTX store is 32-bit, which would otherwise
leave the upper half of an int64 indices_[] slot uninitialized)
Also updated cumm/gemm/frozen/{__init__.py,mask_iters.py}, which carry
a snapshot of the iterator code used by cumm/gemm/gather.py, keeping
the AOT gather utility consistent with the NVRTC path.
Validated end-to-end:
* test/test_int64_large_gemm.py (new) -- shuffle-AC with N = 16M
rows x K = 256 float32 (features = 16 GiB, 8x INT32_MAX). Probes
straddle the 2^31 byte boundary and all match a PyTorch reference
within fp32 tolerance (rel_err ~1e-6)
* shuffle-AB sweep (backward-weight) up to 16M x 64 x 4 = 4 GiB,
all OK
* spconv SubMConv3d kernel=3 forward (cumm/conv `else` branch of
update_indices): N = 4M, C = 256 -> features = 4 GiB, full-tensor
rel_err = 5e-6 vs torch.nn.functional.conv3d reference
* spconv SubMConv3d kernel=3 backward-weight (cumm/conv `is_wgrad_out`
branch of update_indices): same shape, dWeight rel_err = 8.5e-5
vs torch autograd reference -- confirms the restructured PTX-load
path is correct