
[BUG] Inconsistent capture of Python loop variables during loop unrolling #1139

@gd193

Bug Description

Hello,

I would like to report what, in my view, is inconsistent behavior in kernel loop unrolling with wp.static (warp.static).

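The problem is that, when a for-loop inside a wp.kernel is unrolled, wp.static(i) does not reliably refer to the kernel's own loop variable i. If a variable named i also exists in the module's global Python scope, wp.static picks up that global instead. In the example below, dummy_kernel behaves as expected, but after an unrelated Python loop leaves a global i behind, the otherwise identical dummy_kernel2 evaluates wp.static(i) to 4 on every iteration.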

```python
import warp as wp

wp.set_device("cuda:0")
wp.config.max_unroll = 128
wp.clear_kernel_cache()


wp_dtype = wp.float64
in_array = wp.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=wp_dtype)

def loop_limit():
    return 5

@wp.kernel
def dummy_kernel(array1: wp.array1d(dtype=wp_dtype)):
    tid = wp.tid()
    # allocate a fixed-size local array
    loc_array = wp.zeros(dtype=wp_dtype, shape=(5,))

    # the range is a compile-time constant, so this loop can be unrolled
    for i in range(0, wp.static(loop_limit())):
        wp.printf("test: threadid, i, i+tid, wp.static(i), wp.static(i)+tid: %d, %d, %d, %d, %d\n", tid, i, i + tid, wp.static(i), wp.static(i) + tid)

wp.launch(dummy_kernel, dim=1, inputs=[in_array])
wp.synchronize()  # wait for the kernel so its printf output is flushed before the Python prints below

#### prints
# test: threadid, i, i+tid, wp.static(i), wp.static(i)+tid: 0, 0, 0, 0, 0
# test: threadid, i, i+tid, wp.static(i), wp.static(i)+tid: 0, 1, 1, 1, 1
# test: threadid, i, i+tid, wp.static(i), wp.static(i)+tid: 0, 2, 2, 2, 2
# test: threadid, i, i+tid, wp.static(i), wp.static(i)+tid: 0, 3, 3, 3, 3
# test: threadid, i, i+tid, wp.static(i), wp.static(i)+tid: 0, 4, 4, 4, 4
####


########### I am an unrelated Python loop out of kernel scope ############
for i in range(5):
    print("Hello from python loop iteration ", i)
##########################################################################

@wp.kernel
def dummy_kernel2(array1: wp.array1d(dtype=wp_dtype)):
    tid = wp.tid()
    # allocate a fixed-size local array
    loc_array = wp.zeros(dtype=wp_dtype, shape=(5,))

    # identical to dummy_kernel, but a global `i` now exists in this module
    for i in range(0, wp.static(loop_limit())):
        wp.printf("test: threadid, i, i+tid, wp.static(i), wp.static(i)+tid: %d, %d, %d, %d, %d\n", tid, i, i + tid, wp.static(i), wp.static(i) + tid)

wp.launch(dummy_kernel2, dim=1, inputs=[in_array])
wp.synchronize()

#### prints
# test: threadid, i, i+tid, wp.static(i), wp.static(i)+tid: 0, 0, 0, 4, 4
# test: threadid, i, i+tid, wp.static(i), wp.static(i)+tid: 0, 1, 1, 4, 4
# test: threadid, i, i+tid, wp.static(i), wp.static(i)+tid: 0, 2, 2, 4, 4
# test: threadid, i, i+tid, wp.static(i), wp.static(i)+tid: 0, 3, 3, 4, 4
# test: threadid, i, i+tid, wp.static(i), wp.static(i)+tid: 0, 4, 4, 4, 4
####
```
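The value 4 in the second run is consistent with Python's scoping rules: a module-level for-loop leaves its loop variable bound in the module's global namespace, holding its last value. The pure-Python check below (no Warp or GPU needed) shows this, together with two ways to keep such a loop from leaving i behind. That either option sidesteps the collision inside Warp is an assumption drawn from the output above, not something verified here.

```python
# A module-level loop binds `i` in the module's global namespace,
# and that binding survives after the loop ends.
for i in range(5):
    pass
leaked = globals().get("i")
print(leaked)  # 4

# Option 1: delete the leaked name right after the loop.
# (Assumption: with no global `i`, wp.static(i) should see the kernel loop variable.)
del i
print("i" in globals())  # False


def run_python_loop():
    # Option 2: inside a function, `i` is a local and never reaches globals().
    for i in range(5):
        pass
    return i


last = run_python_loop()
print(last, "i" in globals())  # 4 False
```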

I would expect wp.static(i) inside the kernel's for-loop to capture the kernel loop variable i, regardless of any global variable with the same name, so that static loop unrolling works as shown in https://nvidia.github.io/warp/codegen.html#example-static-loop-unrolling.
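To make the expected semantics concrete, here is a toy pure-Python model of the symptom. It is an illustration under stated assumptions, not Warp's codegen: resolve_static and unroll are hypothetical helpers that evaluate a static expression once per unrolled iteration, and the only thing that differs between the reported and the expected behavior is whether the loop variable or the module global wins a name clash.

```python
def resolve_static(expr, loop_binding, module_globals, loop_first):
    """Hypothetical model: evaluate a static expression for one unrolled
    iteration. `loop_first` decides which scope wins when a name exists
    in both the loop binding and the module globals."""
    if loop_first:
        scope = {**module_globals, **loop_binding}  # loop variable shadows globals
    else:
        scope = {**loop_binding, **module_globals}  # globals shadow loop variable
    return eval(expr, {"__builtins__": {}}, scope)


def unroll(expr, n, module_globals, loop_first):
    # One evaluation per unrolled iteration i = 0 .. n-1.
    return [resolve_static(expr, {"i": k}, module_globals, loop_first) for k in range(n)]


# dummy_kernel: no global `i`, so the lookup order does not matter.
print(unroll("i", 5, {}, loop_first=False))        # [0, 1, 2, 3, 4]
# dummy_kernel2: a leaked global i == 4 is visible.
print(unroll("i", 5, {"i": 4}, loop_first=False))  # [4, 4, 4, 4, 4]  (reported)
print(unroll("i", 5, {"i": 4}, loop_first=True))   # [0, 1, 2, 3, 4]  (expected)
```

In this model, only a lookup that lets the loop binding shadow the module globals reproduces the first kernel's output in both situations.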

System Information

No response

Labels: bug (Something isn't working)