-
Notifications
You must be signed in to change notification settings - Fork 3.8k
[TIR] Introduce Pass InjectPTXLDG32 #13973
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.
Generated by tvm-bot |
|
I hope TIR |
src/target/source/codegen_cuda.cc
Outdated
| this->stream << "asm volatile (\n" ; | ||
| this->PrintIndent(); | ||
| stream << "\"{.reg .pred p;\\n\"\n" ; | ||
| this->PrintIndent(); | ||
| stream << "\" setp.ne.b32 p, %2, 0;\\n\"\n" ; | ||
| this->PrintIndent(); | ||
| stream << "\" @!p mov.b32 %0, 0;\\n\"\n"; | ||
| this->PrintIndent(); | ||
| stream << "\" @p ld.global.nc.f32 %0, [%1];}\\n\"\n" ; | ||
| // stream << "\" @p ld.global.nc.L2::128B.f32 %0, [%1];}\\n\"\n" ; | ||
| this->PrintIndent(); | ||
| stream << ": \"=f\"(" << reg << "[" << local_addr << "]" << ")\n" ; | ||
| this->PrintIndent(); | ||
| stream << ": \"l\"((void*)(" << global_buffer << "+" << global_addr << ")), \"r\"((int)" << guard << ")\n" ; | ||
| this->PrintIndent(); | ||
| stream << ");\n" ; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: you may use multi-line string in C++
|
|
||
| // The pass can now be invoked via the pass infrastructure, but we also add a Python binding for it | ||
| TVM_REGISTER_GLOBAL("tir.transform.InjectPTXLDG32").set_body_typed(InjectPTXLDG32); | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: you may use clang-format to somehow organize the file slightly better
src/driver/driver_api.cc
Outdated
| TVM_REGISTER_PASS_CONFIG_OPTION("tir.merge_async_commit_queue_scope", Bool); | ||
| TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_lwp", Bool); | ||
| TVM_REGISTER_PASS_CONFIG_OPTION("tir.vtcm_capacity", Integer); | ||
| TVM_REGISTER_PASS_CONFIG_OPTION("tir.ptx_pred_ldg32", Bool); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the name is a bit confusing, can you discuss with @rainy-memory and figure out together something more comprehensible?
our key objective is that users may need to set at most one flag (zero is the best if possible) so that they could deliver the best GEMM performance out of the box
|
Let's fix the lint and merge it in asap. If you don't like that pylint claims about variable naming, just do: # pylint: disable=invalid-name
you code
# pylint: enable=invalid-name |
include/tvm/tir/builtin.h
Outdated
| * \brief tvm intrinsic for ptx predicate load with 32-bit data type. | ||
| * | ||
| */ | ||
| TVM_DLL const Op& inject_ptx_ldg32(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
naming: we do not need inject prefix as it can just be ptx_ldg32
This PR introduces a new pass InjectPTXLDG32 to change the `if_then_else`
call node to `ptx_pred_ldg32` call node. When the store buffer is local
and the load value is global, the pass can change the if_then_else pattern
to a ptx pattern.
Test the pass with:
```python
with tvm.transform.PassContext(config={"tir.ptx_pred_ldg32": True}):
mod = tvm.build(f, target="cuda")
````
This PR introduces a new pass InjectPTXLDG32 to change the `if_then_else`
call node to `ptx_pred_ldg32` call node. When the store buffer is local
and the load value is global, the pass can change the if_then_else pattern
to a ptx pattern.
Test the pass with:
```python
with tvm.transform.PassContext(config={"tir.ptx_pred_ldg32": True}):
mod = tvm.build(f, target="cuda")
````
| this->stream << "asm volatile (\n"; | ||
| this->stream << "\"{.reg .pred p;\\n\"\n"; | ||
| this->stream << "\" setp.ne.b32 p, %2, 0;\\n\"\n"; | ||
| this->stream << "\" @!p mov.b32 %0, 0;\\n\"\n"; | ||
| this->stream << "\" @p ld.global.nc.f32 %0, [%1];}\\n\"\n"; | ||
| // stream << "\" @p ld.global.nc.L2::128B.f32 %0, [%1];}\\n\"\n" ; | ||
| stream << ": \"=f\"(" << reg << "[" << local_addr << "]" | ||
| << ")\n"; | ||
| stream << ": \"l\"((void*)(" << global_buffer << "+" << global_addr << ")), \"r\"((int)" | ||
| << guard << ")\n"; | ||
| stream << ");\n"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
perhaps it would be clearer to write this way:
This PR introduces a new pass InjectPTXLDG32 to change the `if_then_else`
call node to `ptx_pred_ldg32` call node. When the store buffer is local
and the load value is global, the pass can change the if_then_else pattern
to a ptx pattern.
Test the pass with:
```python
with tvm.transform.PassContext(config={"tir.ptx_pred_ldg32": True}):
mod = tvm.build(f, target="cuda")
````
tqchen
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
my comments have been address, will let @junrushao handle this
|
@andy-yang-1 please fix the unittests and we are good to go |
This PR introduces a new pass InjectPTXLDG32 to change the `if_then_else` call node to `ptx_pred_ldg32` call node. When the store buffer is local and the load value is global, the pass can change the if_then_else pattern to a ptx pattern.
Test the pass with
```python
with tvm.transform.PassContext(config={"tir.ptx_pred_ldg32": True}):
mod = tvm.build(f, target="cuda")
```
This PR introduces a new pass InjectPTXLDG32 to change the
if_then_elsecall node toptx_pred_ldg32call node. When the store buffer is local and the load value is global, the pass can change the if_then_else pattern to a ptx pattern.Test the pass with