
Conversation

tqchen (Member) commented on Mar 26, 2018

This PR enables the warp memory abstraction. "warp" is a special memory scope that represents memory shared among the threads in a warp.

From the programmer's perspective, warp memory behaves exactly like shared memory, but it is lowered into local registers, and its reads become shuffle instructions on typical GPUs.
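For context, here is a minimal sketch of what a schedule using the warp scope can look like, written against the TVM Python API of this era; the tensor names, shapes, and split factor are illustrative assumptions, not taken from this PR:

```python
import tvm

# Illustrative sketch (assumed names and sizes, not from this PR).
m = 128
A = tvm.placeholder((m,), name="A")
B = tvm.compute((m,), lambda i: A[i] + 3, name="B")

s = tvm.create_schedule(B.op)
# AA is scheduled like a shared-memory cache, but the "warp" scope is
# lowered to per-thread registers, with reads becoming shuffles.
AA = s.cache_read(A, "warp", [B])

xo, xi = s[B].split(B.op.axis[0], factor=32)  # 32 == CUDA warp size
tx = tvm.thread_axis("threadIdx.x")
s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
s[B].bind(xi, tx)
s[AA].compute_at(s[B], xo)
s[AA].bind(s[AA].op.axis[0], tx)
```

The only schedule-level difference from a shared-memory version is the "warp" scope string passed to cache_read; the lowering pass does the rest.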

So far CUDA works with a simple example. To enable it for other backends:

  • Create a target with the correct warp_size
  • Always set threadIdx.x's extent to the warp size
  • Create an intrinsic dispatch rule for tvm_warp_shuffle that dispatches to the right target intrinsic (see the sketch after this list)
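For the last point, one way to add such a rule from Python is tvm.register_intrin_rule, which registers a rewrite under tvm.intrin.rule.<target>.<name> (the PR registers its rules on the C++ side; per the commit message below, OpenCL dispatches to the Intel shuffle). A hedged sketch, where "mytarget" and the extern name "my_shuffle" are hypothetical placeholders:

```python
import tvm

# Hypothetical rule (not from this PR): rewrite tvm_warp_shuffle calls
# into a call to the target's shuffle builtin. "mytarget" and
# "my_shuffle" are placeholder names.
def my_warp_shuffle_rule(op):
    # op is the tvm_warp_shuffle call expression; forward its value and
    # lane arguments unchanged to the extern function.
    return tvm.call_pure_extern(op.dtype, "my_shuffle", *op.args)

tvm.register_intrin_rule("mytarget", "tvm_warp_shuffle",
                         my_warp_shuffle_rule)
```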

tqchen (Member, Author) commented on Mar 26, 2018

cc @eqy @Laurawly

tqchen merged commit bbe4974 into apache:master on Mar 26, 2018
tqchen added a commit to tqchen/tvm that referenced this pull request Jul 6, 2018
* [SCHEDULE][PASS] Enable Warp memory and lower to shuffle

* OpenCL dispatches for now to intel shuffle
sergei-mironov pushed a commit to sergei-mironov/tvm that referenced this pull request Aug 8, 2018