Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions python/tvm/contrib/cutlass/conv2d_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,17 +418,17 @@ def instantiate_conv2d_template(attrs):
size_t workspace_size = conv2d_op.get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
cutlass::Status status = conv2d_op.can_implement(arguments);
CHECK(status == cutlass::Status::kSuccess);
TVM_FFI_ICHECK(status == cutlass::Status::kSuccess);
${split_k_reset}
status = conv2d_op.initialize(arguments, workspace.get());
CHECK(status == cutlass::Status::kSuccess);
TVM_FFI_ICHECK(status == cutlass::Status::kSuccess);
${split_k_update}

auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
cudaStream_t stream = static_cast<cudaStream_t>(func().cast<void*>());

status = conv2d_op(stream);
CHECK(status == cutlass::Status::kSuccess);
TVM_FFI_ICHECK(status == cutlass::Status::kSuccess);
${split_k_reduction}
"""

Expand All @@ -439,7 +439,7 @@ def instantiate_conv2d_template(attrs):
split_k_update = """
arguments.output_op = {ElementComputeEpilogue(1), ElementComputeEpilogue(0)};
status = conv2d_op.update(arguments, workspace.get());
CHECK(status == cutlass::Status::kSuccess);
TVM_FFI_ICHECK(status == cutlass::Status::kSuccess);
"""

split_k_reduction = """
Expand Down
6 changes: 3 additions & 3 deletions python/tvm/contrib/cutlass/gemm_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,15 +341,15 @@ def instantiate_gemm_template(attrs):
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
${kernel} gemm_op;
cutlass::Status status = gemm_op.can_implement(arguments);
CHECK(status == cutlass::Status::kSuccess);
TVM_FFI_ICHECK(status == cutlass::Status::kSuccess);
status = gemm_op.initialize(arguments, workspace.get());
CHECK(status == cutlass::Status::kSuccess);
TVM_FFI_ICHECK(status == cutlass::Status::kSuccess);

auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
cudaStream_t stream = static_cast<cudaStream_t>(func().cast<void*>());

status = gemm_op(stream);
CHECK(status == cutlass::Status::kSuccess);
TVM_FFI_ICHECK(status == cutlass::Status::kSuccess);
"""
op_type = attrs["op_type"]
has_bias = "bias" in op_type
Expand Down
17 changes: 7 additions & 10 deletions src/relax/backend/contrib/codegen_c/codegen_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,31 +83,28 @@ class CodegenCBase {
code_stream_ << "#ifdef __cplusplus\n";
code_stream_ << "extern \"C\" {\n";
code_stream_ << "#endif\n";
code_stream_ << "TVM_DLL int32_t ";
code_stream_ << "TVM_FFI_DLL_EXPORT int32_t ";
code_stream_ << func_name << "(";
code_stream_ << "TVMValue* args, ";
code_stream_ << "int* type_code, ";
code_stream_ << "int num_args, ";
code_stream_ << "TVMValue* out_value, ";
code_stream_ << "int* out_type_code) {\n";
code_stream_ << "tvm::ffi::PackedArgs args, ";
code_stream_ << "tvm::ffi::AnyView* out_value) {\n";
}

/*!
* \brief Adds a line to convert TVMValue args to DLTensors
* \brief Adds a line to convert tvm::ffi::PackedArgs args to DLTensors
*/
void PrintArgToData(int idx) {
PrintIndents();
code_stream_ << "DLTensor* arg" << idx << " = ";
code_stream_ << "(DLTensor*)(((TVMValue*)args)[" << idx << "].v_handle);\n";
code_stream_ << "(DLTensor*)(args[" << idx << "].cast<DLTensor*>());\n";
}

/*!
* \brief Adds a line to convert TVMValue rets to DLTensors
* \brief Adds a line to convert tvm::ffi::PackedArgs rets to DLTensors
*/
void PrintRetToData(int idx) {
PrintIndents();
code_stream_ << "DLTensor* ret" << idx << " = ";
code_stream_ << "(DLTensor*)(((TVMValue*)args)[" << idx << "].v_handle);\n";
code_stream_ << "(DLTensor*)(args[" << idx << "].cast<DLTensor*>());\n";
}

/*!
Expand Down
Loading