Skip to content

Conversation

@andy-yang-1
Copy link
Contributor

This PR introduces a new pass InjectPTXLDG32 to change the if_then_else call node to ptx_pred_ldg32 call node. When the store buffer is local and the load value is global, the pass can change the if_then_else pattern to a ptx pattern.

Test the pass with

with tvm.transform.PassContext(config={"tir.ptx_pred_ldg32": True}): 
    mod = tvm.build(f, target="cuda")

@tvm-bot
Copy link
Collaborator

tvm-bot commented Feb 13, 2023

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

Generated by tvm-bot

@Hzfengsy
Copy link
Member

cc @spectrometerHBH

@masahi
Copy link
Member

masahi commented Feb 13, 2023

I hope TIR BufferLoad would natively support predication, rather than relying on intrinsics. See also https://discuss.tvm.apache.org/t/huge-pr-affecting-buffer-access-semantics-landed/12261/10. cc @wrongtest-intellif @vinx13 @junrushao

Comment on lines 948 to 957
this->stream << "asm volatile (\n" ;
this->PrintIndent();
stream << "\"{.reg .pred p;\\n\"\n" ;
this->PrintIndent();
stream << "\" setp.ne.b32 p, %2, 0;\\n\"\n" ;
this->PrintIndent();
stream << "\" @!p mov.b32 %0, 0;\\n\"\n";
this->PrintIndent();
stream << "\" @p ld.global.nc.f32 %0, [%1];}\\n\"\n" ;
// stream << "\" @p ld.global.nc.L2::128B.f32 %0, [%1];}\\n\"\n" ;
this->PrintIndent();
stream << ": \"=f\"(" << reg << "[" << local_addr << "]" << ")\n" ;
this->PrintIndent();
stream << ": \"l\"((void*)(" << global_buffer << "+" << global_addr << ")), \"r\"((int)" << guard << ")\n" ;
this->PrintIndent();
stream << ");\n" ;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: you may use multi-line string in C++


// The pass can now be invoked via the pass infrastructure, but we also add a Python binding for it
TVM_REGISTER_GLOBAL("tir.transform.InjectPTXLDG32").set_body_typed(InjectPTXLDG32);

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: you may use clang-format to somehow organize the file slightly better

TVM_REGISTER_PASS_CONFIG_OPTION("tir.merge_async_commit_queue_scope", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_lwp", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.vtcm_capacity", Integer);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.ptx_pred_ldg32", Bool);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the name is a bit confusing, can you discuss with @rainy-memory and figure out together something more comprehensible?

our key objective is that users may need to set at most one flag (zero is the best if possible) so that they could deliver the best GEMM performance out of the box

@junrushao
Copy link
Member

Let's fix the lint and merge it in asap. If you don't like that pylint claims about variable naming, just do:

# pylint: disable=invalid-name
you code
# pylint: enable=invalid-name

* \brief tvm intrinsic for ptx predicate load with 32-bit data type.
*
*/
TVM_DLL const Op& inject_ptx_ldg32();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

naming: we do not need inject prefix as it can just be ptx_ldg32

@junrushao junrushao self-assigned this Feb 17, 2023
This PR introduces a new pass InjectPTXLDG32 to change the `if_then_else`
call node to `ptx_pred_ldg32` call node. When the store buffer is local
and the load value is global, the pass can change the if_then_else pattern
to a ptx pattern.

Test the pass with:

```python
with tvm.transform.PassContext(config={"tir.ptx_pred_ldg32": True}):
    mod = tvm.build(f, target="cuda")
````
This PR introduces a new pass InjectPTXLDG32 to change the `if_then_else`
call node to `ptx_pred_ldg32` call node. When the store buffer is local
and the load value is global, the pass can change the if_then_else pattern
to a ptx pattern.

Test the pass with:

```python
with tvm.transform.PassContext(config={"tir.ptx_pred_ldg32": True}):
    mod = tvm.build(f, target="cuda")
````
Comment on lines +949 to +959
this->stream << "asm volatile (\n";
this->stream << "\"{.reg .pred p;\\n\"\n";
this->stream << "\" setp.ne.b32 p, %2, 0;\\n\"\n";
this->stream << "\" @!p mov.b32 %0, 0;\\n\"\n";
this->stream << "\" @p ld.global.nc.f32 %0, [%1];}\\n\"\n";
// stream << "\" @p ld.global.nc.L2::128B.f32 %0, [%1];}\\n\"\n" ;
stream << ": \"=f\"(" << reg << "[" << local_addr << "]"
<< ")\n";
stream << ": \"l\"((void*)(" << global_buffer << "+" << global_addr << ")), \"r\"((int)"
<< guard << ")\n";
stream << ");\n";
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR introduces a new pass InjectPTXLDG32 to change the `if_then_else`
call node to `ptx_pred_ldg32` call node. When the store buffer is local
and the load value is global, the pass can change the if_then_else pattern
to a ptx pattern.

Test the pass with:

```python
with tvm.transform.PassContext(config={"tir.ptx_pred_ldg32": True}):
    mod = tvm.build(f, target="cuda")
````
Copy link
Member

@tqchen tqchen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

my comments have been address, will let @junrushao handle this

@junrushao
Copy link
Member

@andy-yang-1 please fix the unittests and we are good to go

@junrushao junrushao merged commit 87bb8b1 into apache:main Feb 18, 2023
yongwww pushed a commit to yongwww/tvm that referenced this pull request Feb 27, 2023
This PR introduces a new pass InjectPTXLDG32 to change the `if_then_else` call node to `ptx_pred_ldg32` call node. When the store buffer is local and the load value is global, the pass can change the if_then_else pattern to a ptx pattern.

Test the pass with

```python
with tvm.transform.PassContext(config={"tir.ptx_pred_ldg32": True}): 
    mod = tvm.build(f, target="cuda")
```
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants