-
Notifications
You must be signed in to change notification settings - Fork 410
Open
Description
Bug Description
Hello,
I would like to report what I consider inconsistent behavior in kernel loop unrolling with `wp.static`.
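The problem: when a `for` loop inside a `@wp.kernel` uses `wp.static(i)` on its loop variable `i`, and a variable with the same name `i` also exists in the module's global Python scope, `wp.static(i)` evaluates the global `i` instead of the kernel's (unrolled) loop variable.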
import warp as wp
from warp.jax_experimental.ffi import jax_kernel
from warp.jax_experimental.ffi import jax_callable
# Run everything on the first CUDA device.
wp.set_device("cuda:0")
# Raise Warp's loop-unroll limit. The kernels below only loop loop_limit() == 5
# times, so this presumably does not affect the outcome — kept from the repro.
wp.config.max_unroll = 128
# Clear previously compiled kernels, presumably so both kernels below are
# regenerated from source in this run rather than loaded from an old cache.
wp.clear_kernel_cache()
# Element type shared by the input array and the kernels' local arrays.
wp_dtype = wp.float64
# Passed as `array1` to both kernels; neither kernel actually reads it.
in_array = wp.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=wp_dtype)
def loop_limit():
    """Return the trip count (a plain Python int) for the kernels' loops.

    It is called inside ``wp.static(...)``, which Warp evaluates in Python
    during code generation, so the loop bound is a compile-time constant.
    """
    trip_count = 5
    return trip_count
@wp.kernel
def dummy_kernel(array1: wp.array1d(dtype=wp_dtype)):
    # Repro kernel #1. It is launched before any module-level `i` exists, and
    # the sample output shows wp.static(i) tracking the loop index 0..4.
    # `array1` only gives the launch an input; it is never read here.
    thread_id = wp.tid()
    # Local array that is allocated but never used afterwards (kept from the repro).
    scratch = wp.zeros(dtype=wp_dtype, shape=(5,))
    # The loop variable must stay named `i`: the reported behavior depends on
    # how Warp resolves the name `i` inside wp.static(...).
    for i in range(0, wp.static(loop_limit())):
        # NOTE(review): the label names six columns, but only five %d values
        # follow ("stride" has no matching argument). The printed columns are
        # actually: tid, i, i+tid, wp.static(i), wp.static(i)+tid.
        wp.printf(
            "test: threadid, stride, i, i+tid, wp.static(i), wp.static(i)+tid: %d, %d, %d, %d, %d\n",
            thread_id,
            i,
            i + thread_id,
            wp.static(i),
            wp.static(i) + thread_id,
        )
# First launch: no module-level `i` has been defined yet at this point.
# dim=1 launches a single thread, so tid is 0 in every printed line.
wp.launch(dummy_kernel, dim=1, inputs=[in_array])
#### Output of the first launch (wp.static(i) matches i):
test: threadid, stride, i, i+tid, wp.static(i), wp.static(i)+tid: 0, 0, 0, 0, 0
test: threadid, stride, i, i+tid, wp.static(i), wp.static(i)+tid: 0, 1, 1, 1, 1
test: threadid, stride, i, i+tid, wp.static(i), wp.static(i)+tid: 0, 2, 2, 2, 2
test: threadid, stride, i, i+tid, wp.static(i), wp.static(i)+tid: 0, 3, 3, 3, 3
test: threadid, stride, i, i+tid, wp.static(i), wp.static(i)+tid: 0, 4, 4, 4, 4
####
########### I am an unrelated python loop out of kernel scope ############
# After this loop finishes, `i` stays bound at module (global) scope with the
# value 4. Judging by the second sample output, that leftover global is what
# wp.static(i) in dummy_kernel2 ends up reading.
for i in range(5):
    print("Hello from python loop iteration ", i)
##########################################################################
@wp.kernel
def dummy_kernel2(array1: wp.array1d(dtype=wp_dtype)):
    # Repro kernel #2, the same body as dummy_kernel. It is defined and launched
    # after the module-level Python loop has left a global `i` (== 4). The
    # sample output shows wp.static(i) printing 4 on every iteration, i.e. the
    # global `i` shadows the unrolled loop variable. This is the reported bug.
    # `array1` only gives the launch an input; it is never read here.
    thread_id = wp.tid()
    # Local array that is allocated but never used afterwards (kept from the repro).
    scratch = wp.zeros(dtype=wp_dtype, shape=(5,))
    # The loop variable must stay named `i`: it collides with the global `i`,
    # and that collision is what this kernel demonstrates.
    for i in range(0, wp.static(loop_limit())):
        # NOTE(review): the label names six columns, but only five %d values
        # follow ("stride" has no matching argument). The printed columns are
        # actually: tid, i, i+tid, wp.static(i), wp.static(i)+tid.
        wp.printf(
            "test: threadid, stride, i, i+tid, wp.static(i), wp.static(i)+tid: %d, %d, %d, %d, %d\n",
            thread_id,
            i,
            i + thread_id,
            wp.static(i),
            wp.static(i) + thread_id,
        )
# Second launch: the Python loop above has already bound the global `i`
# (final value 4), which the sample output shows wp.static(i) picking up.
wp.launch(dummy_kernel2, dim=1, inputs=[in_array])
#### Output of the second launch (wp.static(i) is stuck at 4, the final value of the global `i`):
test: threadid, stride, i, i+tid, wp.static(i), wp.static(i)+tid: 0, 0, 0, 4, 4
test: threadid, stride, i, i+tid, wp.static(i), wp.static(i)+tid: 0, 1, 1, 4, 4
test: threadid, stride, i, i+tid, wp.static(i), wp.static(i)+tid: 0, 2, 2, 4, 4
test: threadid, stride, i, i+tid, wp.static(i), wp.static(i)+tid: 0, 3, 3, 4, 4
test: threadid, stride, i, i+tid, wp.static(i), wp.static(i)+tid: 0, 4, 4, 4, 4
####
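I think `wp.static(i)` should refer to the `i` of the kernel's own `for` loop, since that is what makes static loop unrolling work as shown in https://nvidia.github.io/warp/codegen.html#example-static-loop-unrolling. Currently the result depends on whether an unrelated variable named `i` happens to exist in the module's global scope when the kernel is compiled: the two kernels above are identical, yet they print different values for `wp.static(i)`.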
System Information
No response
coderabbitai
Metadata
Metadata
Assignees
Labels
bug — Something isn't working