Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions cumm/conv/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,20 +561,20 @@ def implicit_gemm2(self):
if p.op_type == ConvOpType.kBackwardWeight:
code.raw(f"""
TV_ASSERT_RT_ERR(N == output.dim(0), "error");
TV_ASSERT_RT_ERR(int64_t(N) * int64_t(C) * {ker.dtype_b.bitsize()} / 8 < std::numeric_limits<int32_t>::max(),
TV_ASSERT_RT_ERR(int64_t(N) * int64_t(C) * {ker.dtype_b.bitsize()} / 8 < std::numeric_limits<int64_t>::max(),
"your data exceed int32 range. this will be fixed in cumm + nvrtc (spconv 2.2/2.3).");
TV_ASSERT_RT_ERR(int64_t(N) * int64_t(K) * {ker.dtype_a.bitsize()} / 8 < std::numeric_limits<int32_t>::max(),
TV_ASSERT_RT_ERR(int64_t(N) * int64_t(K) * {ker.dtype_a.bitsize()} / 8 < std::numeric_limits<int64_t>::max(),
"your data exceed int32 range. this will be fixed in cumm + nvrtc (spconv 2.2/2.3).");
""")
elif p.op_type == ConvOpType.kForward:
code.raw(f"""
TV_ASSERT_RT_ERR(N == output.dim(0), "error");
TV_ASSERT_RT_ERR(int64_t(N) * int64_t(C) * {ker.dtype_a.bitsize()} / 8 < std::numeric_limits<int32_t>::max(),
TV_ASSERT_RT_ERR(int64_t(N) * int64_t(C) * {ker.dtype_a.bitsize()} / 8 < std::numeric_limits<int64_t>::max(),
"your data exceed int32 range. this will be fixed in cumm + nvrtc (spconv 2.2/2.3).");
""")
else:
code.raw(f"""
TV_ASSERT_RT_ERR(int64_t(N) * int64_t(K) * {ker.dtype_a.bitsize()} / 8 < std::numeric_limits<int32_t>::max(),
TV_ASSERT_RT_ERR(int64_t(N) * int64_t(K) * {ker.dtype_a.bitsize()} / 8 < std::numeric_limits<int64_t>::max(),
"your data exceed int32 range. this will be fixed in cumm + nvrtc (spconv 2.2/2.3).");
TV_ASSERT_RT_ERR(N == input.dim(0), "error");
""")
Expand Down Expand Up @@ -816,20 +816,20 @@ def implicit_gemm2_deprecated(self):
if p.op_type == ConvOpType.kBackwardWeight:
code.raw(f"""
TV_ASSERT_RT_ERR(N == output.dim(0), "error");
TV_ASSERT_RT_ERR(int64_t(N) * int64_t(C) * {ker.dtype_b.bitsize()} / 8 < std::numeric_limits<int32_t>::max(),
TV_ASSERT_RT_ERR(int64_t(N) * int64_t(C) * {ker.dtype_b.bitsize()} / 8 < std::numeric_limits<int64_t>::max(),
"your data exceed int32 range. this will be fixed in cumm + nvrtc (spconv 2.2/2.3).");
TV_ASSERT_RT_ERR(int64_t(N) * int64_t(K) * {ker.dtype_a.bitsize()} / 8 < std::numeric_limits<int32_t>::max(),
TV_ASSERT_RT_ERR(int64_t(N) * int64_t(K) * {ker.dtype_a.bitsize()} / 8 < std::numeric_limits<int64_t>::max(),
"your data exceed int32 range. this will be fixed in cumm + nvrtc (spconv 2.2/2.3).");
""")
elif p.op_type == ConvOpType.kForward:
code.raw(f"""
TV_ASSERT_RT_ERR(N == output.dim(0), "error");
TV_ASSERT_RT_ERR(int64_t(N) * int64_t(C) * {ker.dtype_a.bitsize()} / 8 < std::numeric_limits<int32_t>::max(),
TV_ASSERT_RT_ERR(int64_t(N) * int64_t(C) * {ker.dtype_a.bitsize()} / 8 < std::numeric_limits<int64_t>::max(),
"your data exceed int32 range. this will be fixed in cumm + nvrtc (spconv 2.2/2.3).");
""")
else:
code.raw(f"""
TV_ASSERT_RT_ERR(int64_t(N) * int64_t(K) * {ker.dtype_a.bitsize()} / 8 < std::numeric_limits<int32_t>::max(),
TV_ASSERT_RT_ERR(int64_t(N) * int64_t(K) * {ker.dtype_a.bitsize()} / 8 < std::numeric_limits<int64_t>::max(),
"your data exceed int32 range. this will be fixed in cumm + nvrtc (spconv 2.2/2.3).");
TV_ASSERT_RT_ERR(N == input.dim(0), "error");
""")
Expand Down
2 changes: 1 addition & 1 deletion cumm/conv/nvrtc_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def nvrtc_conv_template(code: pccm.FunctionCode):
sp_kernel_params.act_beta = kernel_params.act_beta;
sp_kernel_params.act_type = kernel_params.act_type;

constexpr int int_max = std::numeric_limits<int32_t>::max();
constexpr int64_t int_max = std::numeric_limits<int64_t>::max();

if (algo_desp.mask_sparse){{
if (algo_desp.op_type == tv::gemm::ConvOpType::kBackwardWeight){{
Expand Down
2 changes: 1 addition & 1 deletion cumm/conv/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def check_npq_not_overflow(self):
lines: List[str] = []
for i in range(self.ndim + 1):
lines.append(f"int64_t(shape[{i}])")
code.raw("std::abs(" + " * ".join(lines) + ") <= std::numeric_limits<int>::max()")
code.raw("std::abs(" + " * ".join(lines) + ") <= std::numeric_limits<int64_t>::max()")
code.raw(");")
code.ret("bool")
return code
Expand Down
59 changes: 30 additions & 29 deletions cumm/conv/sparse_iters.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,9 +188,11 @@ def __init__(self,
f"tv::array<uint32_t, {self.access_per_vector}>")

# self.add_member("filter_kernel_idxes_", f"tv::array<int, {self.ndim}>")
# indices_ stores gather offsets in bytes. For large inputs the product
# index * channel * nbytes can exceed INT32_MAX, so use int64_t.
self.add_member(
"indices_",
str(dtypes.int32),
str(dtypes.int64),
array=f"[{self.tmap.iterations[0] * self.sub_tile_shape[0]}]")

def get_params(self) -> pccm.ParameterizedClass:
Expand Down Expand Up @@ -282,16 +284,19 @@ def update_indices(self):
code = pccm.cuda.PTXCode()
C_or_K = "C" if self.op_type == ConvOpType.kForward else "K"
if self.is_wgrad_out:
# if False:
# wgrad out only need shuffle.
# wgrad out only need shuffle. PTX loads 32-bit values into a
# temporary int mask_inds[] array (indices_ is int64 now).
code.raw(
f"int mask_inds[{self.tmap.iterations[0] * self.sub_tile_shape[0]}];"
)
for s in range(self.tmap.iterations[0]):
for ss in range(self.sub_tile_shape[0]):
code.raw(f"uint32_t pred{s}_{ss};")
code.raw(
f"pred{s}_{ss} = mask_[0] & (1u << ({s * self.sub_tile_shape[0] * self.tmap.iterations[1]} + {ss}));"
)
with code.asm_block() as asm:
mask_ptr = asm.reg_ptr("indices_", RegDType.B32)
mask_ptr = asm.reg_ptr("mask_inds", RegDType.B32)
pred_ptr = asm.ext_reg(f"pred{s}_{ss}", RegDType.B32)
mask_arg_ptr = asm.global_ptr(
"params_.mask_argsort_ptr_")
Expand All @@ -300,16 +305,13 @@ def update_indices(self):
mask_arg_ptr +
(s * self.tmap.delta[0] + ss) * 4,
mask_ptr[s * self.sub_tile_shape[0] + ss])
# code.raw(f"""
# indices_[{s * self.sub_tile_shape[0] + ss}] = pred{s}_{ss} ? params_.mask_argsort_ptr_[{(s * self.tmap.delta[0] + ss)}] : 0;

# """)
code.raw(f"""
TV_PRAGMA_UNROLL
for (int s = 0; s < {self.tmap.iterations[0]}; ++s){{
TV_PRAGMA_UNROLL
for (int ss = 0; ss < {self.sub_tile_shape[0]}; ++ss){{
indices_[s * {self.sub_tile_shape[0]} + ss] = indices_[s * {self.sub_tile_shape[0]} + ss] *
indices_[s * {self.sub_tile_shape[0]} + ss] =
int64_t(mask_inds[s * {self.sub_tile_shape[0]} + ss]) *
problem_.K * {self.dtype.nbytes_str()} ;
}}
}}
Expand All @@ -334,10 +336,6 @@ def update_indices(self):
mask_arg_ptr +
(s * self.tmap.delta[0] + ss) * 4,
mask_ptr[s * self.sub_tile_shape[0] + ss])
# code.raw(f"""
# mask_inds[{s * self.sub_tile_shape[0] + ss}] = pred ? params_.mask_argsort_ptr_[{(s * self.tmap.delta[0] + ss)}] : 0;

# """)

if self.is_wgrad_input:
C_or_K = "C"
Expand All @@ -347,8 +345,8 @@ def update_indices(self):
TV_PRAGMA_UNROLL
for (int ss = 0; ss < {self.sub_tile_shape[0]}; ++ss){{
if (mask_[0] & (1u << (s * {self.sub_tile_shape[0] * self.tmap.iterations[1]} + ss))){{
indices_[s * {self.sub_tile_shape[0]} + ss] =
indice_ptr_[mask_inds[s * {self.sub_tile_shape[0]} + ss]] *
indices_[s * {self.sub_tile_shape[0]} + ss] =
int64_t(indice_ptr_[mask_inds[s * {self.sub_tile_shape[0]} + ss]]) *
problem_.{C_or_K} * {self.dtype.nbytes_str()} ;
}}
}}
Expand Down Expand Up @@ -592,12 +590,12 @@ def get_indice_offset(self):
code.raw(f"""
return indices_[stride * {self.sub_tile_shape[0]} + ss];
""")
return code.ret(f"int")
return code.ret(f"int64_t")

@pccm.cuda.member_function(device=True, forceinline=True, const=True)
def get(self):
code = FunctionCode()
code.arg("indice_offset", f"int")
code.arg("indice_offset", f"int64_t")
code.raw(f"""
return reinterpret_cast<{self.const_access_pointer}>( pointer_ + indice_offset);
""")
Expand Down Expand Up @@ -778,9 +776,11 @@ def __init__(self,
self.add_member("mask_", f"Mask")

# self.add_member("filter_kernel_idxes_", f"tv::array<int, {self.ndim}>")
# indices_ stores gather offsets in bytes. For large inputs the product
# index * channel * nbytes can exceed INT32_MAX, so use int64_t.
self.add_member(
"indices_",
str(dtypes.int32),
str(dtypes.int64),
array=f"[{self.tmap.iterations[0] * self.sub_tile_shape[0]}]")

def get_params(self) -> pccm.ParameterizedClass:
Expand Down Expand Up @@ -872,19 +872,19 @@ def update_indices(self):
code = pccm.cuda.PTXCode()
C_or_K = "C" if self.op_type == ConvOpType.kForward else "K"
if self.is_wgrad_out:
# if False:
# wgrad out only need shuffle.
# wgrad out only need shuffle. PTX loads 32-bit values into a
# temporary int mask_inds[] array (indices_ is int64 now).
code.raw(
f"int mask_inds[{self.tmap.iterations[0] * self.sub_tile_shape[0]}];"
)
for s in range(self.tmap.iterations[0]):
for ss in range(self.sub_tile_shape[0]):
code.raw(f"uint32_t pred{s}_{ss};")
# code.raw(
# f"pred{s}_{ss} = mask_[0] & (1u << ({s} * {self.sub_tile_shape[0]} + {ss}));"
# )
code.raw(
f"pred{s}_{ss} = mask_.query_coord({s}, 0, {ss}, 0);")

with code.asm_block() as asm:
mask_ptr = asm.reg_ptr("indices_", RegDType.B32)
mask_ptr = asm.reg_ptr("mask_inds", RegDType.B32)
pred_ptr = asm.ext_reg(f"pred{s}_{ss}", RegDType.B32)
mask_arg_ptr = asm.global_ptr(
"params_.mask_argsort_ptr_")
Expand All @@ -898,7 +898,8 @@ def update_indices(self):
for (int s = 0; s < {self.tmap.iterations[0]}; ++s){{
TV_PRAGMA_UNROLL
for (int ss = 0; ss < {self.sub_tile_shape[0]}; ++ss){{
indices_[s * {self.sub_tile_shape[0]} + ss] = indices_[s * {self.sub_tile_shape[0]} + ss] *
indices_[s * {self.sub_tile_shape[0]} + ss] =
int64_t(mask_inds[s * {self.sub_tile_shape[0]} + ss]) *
problem_.K * {self.dtype.nbytes_str()} ;
}}
}}
Expand Down Expand Up @@ -933,8 +934,8 @@ def update_indices(self):
for (int ss = 0; ss < {self.sub_tile_shape[0]}; ++ss){{
// if (mask_[0] & (1u << (s * {self.sub_tile_shape[0]} + ss)))
if (mask_.query_coord(s, 0, ss, 0)){{
indices_[s * {self.sub_tile_shape[0]} + ss] =
indice_ptr_[mask_inds[s * {self.sub_tile_shape[0]} + ss]] *
indices_[s * {self.sub_tile_shape[0]} + ss] =
int64_t(indice_ptr_[mask_inds[s * {self.sub_tile_shape[0]} + ss]]) *
problem_.{C_or_K} * {self.dtype.nbytes_str()} ;
}}
}}
Expand Down Expand Up @@ -1122,12 +1123,12 @@ def get_indice_offset(self):
code.raw(f"""
return indices_[stride * {self.sub_tile_shape[0]} + ss];
""")
return code.ret(f"int")
return code.ret(f"int64_t")

@pccm.cuda.member_function(device=True, forceinline=True, const=True)
def get(self):
code = FunctionCode()
code.arg("indice_offset", f"int")
code.arg("indice_offset", f"int64_t")
code.raw(f"""
return reinterpret_cast<{self.const_access_pointer}>( pointer_ + indice_offset);
""")
Expand Down
14 changes: 8 additions & 6 deletions cumm/gemm/frozen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,11 @@ def __init__(self,
str(dtypes.uint32),
array=f"[{self.num_pred_32}]")
if self.shuffle_in_stride:
# indices_ stores gather offsets in bytes. For large inputs the
# product can exceed INT32_MAX, so use int64_t.
self.add_member(
"indices_",
str(dtypes.int32),
str(dtypes.int64),
array=f"[{self.tmap.iterations[0] * self.sub_tile_shape[0]}]")

# cudasim members
Expand Down Expand Up @@ -338,9 +340,9 @@ def update_indices(self):
TV_PRAGMA_UNROLL
for (int ss = 0; ss < {self.sub_tile_shape[0]}; ++ss){{
if (thread_offset_[0] + s * {self.tmap.delta[0]} + ss < extent_[0])
indices_[s * {self.sub_tile_shape[0]} + ss] =
params_.indice_ptr_[thread_offset_[0] +
s * {self.tmap.delta[0]} + ss] *
indices_[s * {self.sub_tile_shape[0]} + ss] =
int64_t(params_.indice_ptr_[thread_offset_[0] +
s * {self.tmap.delta[0]} + ss]) *
params_.stride_ * {self.dtype.bitsize()} / 8;
else{{
indices_[s * {self.sub_tile_shape[0]} + ss] = 0;
Expand All @@ -363,8 +365,8 @@ def update_indices_identity(self):
TV_PRAGMA_UNROLL
for (int ss = 0; ss < {self.sub_tile_shape[0]}; ++ss){{
if (thread_offset_[0] + s * {self.tmap.delta[0]} + ss < extent_[0])
indices_[s * {self.sub_tile_shape[0]} + ss] =
(thread_offset_[0] + s * {self.tmap.delta[0]} + ss) *
indices_[s * {self.sub_tile_shape[0]} + ss] =
int64_t(thread_offset_[0] + s * {self.tmap.delta[0]} + ss) *
params_.stride_ * {self.dtype.bitsize()} / 8;
else{{
indices_[s * {self.sub_tile_shape[0]} + ss] = 0;
Expand Down
14 changes: 8 additions & 6 deletions cumm/gemm/frozen/mask_iters.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,11 @@ def __init__(self,
str(dtypes.uint32),
array=f"[{self.num_pred_32}]")
if self.shuffle_in_stride:
# indices_ stores gather offsets in bytes. For large inputs the
# product can exceed INT32_MAX, so use int64_t.
self.add_member(
"indices_",
str(dtypes.int32),
str(dtypes.int64),
array=f"[{self.tmap.iterations[0] * self.sub_tile_shape[0]}]")

# cudasim members
Expand Down Expand Up @@ -338,9 +340,9 @@ def update_indices(self):
TV_PRAGMA_UNROLL
for (int ss = 0; ss < {self.sub_tile_shape[0]}; ++ss){{
if (thread_offset_[0] + s * {self.tmap.delta[0]} + ss < extent_[0])
indices_[s * {self.sub_tile_shape[0]} + ss] =
params_.indice_ptr_[thread_offset_[0] +
s * {self.tmap.delta[0]} + ss] *
indices_[s * {self.sub_tile_shape[0]} + ss] =
int64_t(params_.indice_ptr_[thread_offset_[0] +
s * {self.tmap.delta[0]} + ss]) *
params_.stride_ * {self.dtype.bitsize()} / 8;
else{{
indices_[s * {self.sub_tile_shape[0]} + ss] = 0;
Expand All @@ -363,8 +365,8 @@ def update_indices_identity(self):
TV_PRAGMA_UNROLL
for (int ss = 0; ss < {self.sub_tile_shape[0]}; ++ss){{
if (thread_offset_[0] + s * {self.tmap.delta[0]} + ss < extent_[0])
indices_[s * {self.sub_tile_shape[0]} + ss] =
(thread_offset_[0] + s * {self.tmap.delta[0]} + ss) *
indices_[s * {self.sub_tile_shape[0]} + ss] =
int64_t(thread_offset_[0] + s * {self.tmap.delta[0]} + ss) *
params_.stride_ * {self.dtype.bitsize()} / 8;
else{{
indices_[s * {self.sub_tile_shape[0]} + ss] = 0;
Expand Down
11 changes: 7 additions & 4 deletions cumm/gemm/mask_iters.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,9 +783,12 @@ def __init__(self,
str(dtypes.uint32),
array=f"[{self.num_pred_32}]")
if self.shuffle_in_stride:
# indices_ stores gather-offsets in bytes. For large inputs the
# product indice_ptr_[i] * stride_ * sizeof(dtype) can exceed
# INT32_MAX, so use int64_t.
self.add_member(
"indices_",
str(dtypes.int32),
str(dtypes.int64),
array=f"[{self.tmap.iterations[0] * self.sub_tile_shape[0]}]")

# cudasim members
Expand Down Expand Up @@ -978,9 +981,9 @@ def update_indices(self):
TV_PRAGMA_UNROLL
for (int ss = 0; ss < {self.sub_tile_shape[0]}; ++ss){{
if (thread_offset_[0] + s * {self.tmap.delta[0]} + ss < extent_[0])
indices_[s * {self.sub_tile_shape[0]} + ss] =
params_.indice_ptr_[thread_offset_[0] +
s * {self.tmap.delta[0]} + ss] *
indices_[s * {self.sub_tile_shape[0]} + ss] =
int64_t(params_.indice_ptr_[thread_offset_[0] +
s * {self.tmap.delta[0]} + ss]) *
params_.stride_ * {self.dtype.nbytes_str()};
else{{
indices_[s * {self.sub_tile_shape[0]} + ss] = 0;
Expand Down
2 changes: 1 addition & 1 deletion cumm/gemm/nvrtc_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def nvrtc_gemm_template(code: pccm.FunctionCode):

}}
int m, n, k, k2;
constexpr int int_max = std::numeric_limits<int32_t>::max();
constexpr int64_t int_max = std::numeric_limits<int64_t>::max();
if (algo_desp.shuffle_type == tv::gemm::ShuffleStrideType::kShuffleAC){{
TV_ASSERT_RT_ERR(!trans_a, "a of shuffle AB must be row major");
if (!a_inds.empty()){{
Expand Down
Loading