From e71e3bcfb2492ff99a27fb3b8e2896a56e6d1731 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 3 Mar 2025 18:18:56 -0600 Subject: [PATCH 1/6] _get_f_placeholder_args: set ForceValueArgTag --- arraycontext/impl/pytato/compile.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index e77c1091..7a9389a2 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -218,8 +218,10 @@ def _get_f_placeholder_args(arg, kw, arg_id_to_name, actx): :attr:`BaseLazilyCompilingFunctionCaller.f`. """ if np.isscalar(arg): + from pytato.tags import ForceValueArgTag name = arg_id_to_name[kw,] - return pt.make_placeholder(name, (), np.dtype(type(arg))) + return pt.make_placeholder(name, (), np.dtype(type(arg)), + tags=frozenset({ForceValueArgTag()})) elif isinstance(arg, pt.Array): name = arg_id_to_name[kw,] # Transform the DAG to give metadata inference a chance to do its job From 9114f8c67dd9d243ad79b011cc1d13a6d2d6f6d7 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 4 Mar 2025 09:17:08 -0600 Subject: [PATCH 2/6] Update requirements.txt --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index a4cb4025..cab9efd5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,4 @@ git+https://github.com/inducer/pyopencl.git#egg=pyopencl git+https://github.com/inducer/islpy.git#egg=islpy git+https://github.com/inducer/loopy.git#egg=loopy -git+https://github.com/inducer/pytato.git#egg=pytato +git+https://github.com/inducer/pytato.git@valuearg-placeholder#egg=pytato From a2e322e949e115c5f6ff72a0f811884fa0fefd69 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 5 Mar 2025 16:01:25 -0600 Subject: [PATCH 3/6] skip scalar arg handling in _args_to_device_buffers --- arraycontext/impl/pytato/compile.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index 7a9389a2..90449f0a 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -535,9 +535,8 @@ def _args_to_device_buffers(actx, input_id_to_name_in_program, arg_id_to_arg): for arg_id, arg in arg_id_to_arg.items(): if np.isscalar(arg): if isinstance(actx, PytatoPyOpenCLArrayContext): - import pyopencl.array as cla - arg = cla.to_device(actx.queue, np.array(arg), - allocator=actx.allocator) + # Scalar kernel args are passed as lp.ValueArgs + pass elif isinstance(actx, PytatoJAXArrayContext): import jax arg = jax.device_put(arg) From 0cd3d564a7e7621115b73396117873ace1edbd21 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20Kl=C3=B6ckner?= Date: Tue, 18 Mar 2025 17:09:03 -0500 Subject: [PATCH 4/6] Revert changes to requirements.txt --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index cab9efd5..a4cb4025 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,4 @@ git+https://github.com/inducer/pyopencl.git#egg=pyopencl git+https://github.com/inducer/islpy.git#egg=islpy git+https://github.com/inducer/loopy.git#egg=loopy -git+https://github.com/inducer/pytato.git@valuearg-placeholder#egg=pytato +git+https://github.com/inducer/pytato.git#egg=pytato From cbebe4537a86068cf6deb9677930ae5dcb54f774 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 18 Mar 2025 18:00:57 -0500 Subject: [PATCH 5/6] add a simple test --- test/test_pytato_arraycontext.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/test/test_pytato_arraycontext.py b/test/test_pytato_arraycontext.py index a4050380..9611421c 100644 --- a/test/test_pytato_arraycontext.py +++ b/test/test_pytato_arraycontext.py @@ -247,6 +247,23 @@ def test_transfer(actx_factory): # }}} +def test_pass_args_compiled_func(actx_factory): + import numpy as np + + def twice(x, y, a): + return 2 * x * y * a + + actx = _PytatoPyOpenCLArrayContextForTests(actx_factory().queue) + + import pyopencl.array as cl_array + cl_ary = cl_array.to_device(actx.queue, np.float64(23)) + + f = actx.compile(twice) + + with pytest.raises(ValueError): + f(99.0, np.float64(2.0), cl_ary) + + if __name__ == "__main__": import sys if len(sys.argv) > 1: From 8b67e413bacebca5c08da25d7d91b8f566e6ba54 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 19 Mar 2025 15:29:53 -0500 Subject: [PATCH 6/6] Fix test --- test/test_pytato_arraycontext.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/test/test_pytato_arraycontext.py b/test/test_pytato_arraycontext.py index 9611421c..deee7405 100644 --- a/test/test_pytato_arraycontext.py +++ b/test/test_pytato_arraycontext.py @@ -250,18 +250,28 @@ def test_transfer(actx_factory): def test_pass_args_compiled_func(actx_factory): import numpy as np + import loopy as lp + import pyopencl as cl + import pyopencl.array + import pytato as pt + def twice(x, y, a): return 2 * x * y * a actx = _PytatoPyOpenCLArrayContextForTests(actx_factory().queue) - import pyopencl.array as cl_array - cl_ary = cl_array.to_device(actx.queue, np.float64(23)) + dev_scalar = pt.make_data_wrapper(cl.array.to_device(actx.queue, np.float64(23))) f = actx.compile(twice) - with pytest.raises(ValueError): - f(99.0, np.float64(2.0), cl_ary) + assert actx.to_numpy(f(99.0, np.float64(2.0), dev_scalar)) == 2*23*99*2 + + compiled_func, = f.program_cache.values() + ep = compiled_func.pytato_program.program.t_unit.default_entrypoint + + assert isinstance(ep.arg_dict["_actx_in_0"], lp.ValueArg) + assert isinstance(ep.arg_dict["_actx_in_1"], lp.ValueArg) + assert isinstance(ep.arg_dict["_actx_in_2"], lp.ArrayArg) if __name__ == "__main__":