
Support fancy iterators in cuda.parallel#2788

Merged
gevtushenko merged 125 commits into NVIDIA:main from rwgk:fancy_iterators
Dec 6, 2024

Conversation

@rwgk
Contributor

@rwgk rwgk commented Nov 13, 2024

Description

closes #2479

closes #2480

closes #2536

Partially done: #2481

rwgk added 30 commits October 15, 2024 14:13
…t then fails with: Fatal Python error: Floating point exception
…resolves the Floating point exception (but the `cccl_device_reduce()` call still does not succeed)
LOOOK single_tile_kernel CALL /home/coder/cccl/c/parallel/src/reduce.cu:116

LOOOK EXCEPTION CUDA error: invalid argument  /home/coder/cccl/c/parallel/src/reduce.cu:703
…rametrize: `use_numpy_array`: `[True, False]`, `input_generator`: `["constant", "counting", "arbitrary", "nested"]`
…iterators.py (because numba.cuda cannot JIT classes).
… `unary_op`, which is then compiled with `numba.cuda.compile()`
… the `"map_mul2"` test and the added `"map_add10_map_mul2"` test work, too.
rwgk added 4 commits December 5, 2024 07:33
…o make it obvious that they are never used as Python methods, but exclusively as source for `numba.cuda.compile()`
@rwgk
Contributor Author

rwgk commented Dec 5, 2024

/ok to test

@github-actions
Contributor

github-actions bot commented Dec 5, 2024

🟩 CI finished in 39m 02s: Pass: 100%/3 | Total: 36m 23s | Avg: 12m 07s | Max: 27m 08s
  • 🟩 cccl_c_parallel: Pass: 100%/2 | Total: 9m 15s | Avg: 4m 37s | Max: 7m 02s

    🟩 cpu
      🟩 amd64              Pass: 100%/2   | Total:  9m 15s | Avg:  4m 37s | Max:  7m 02s
    🟩 ctk
      🟩 12.6               Pass: 100%/2   | Total:  9m 15s | Avg:  4m 37s | Max:  7m 02s
    🟩 cudacxx
      🟩 nvcc12.6           Pass: 100%/2   | Total:  9m 15s | Avg:  4m 37s | Max:  7m 02s
    🟩 cudacxx_family
      🟩 nvcc               Pass: 100%/2   | Total:  9m 15s | Avg:  4m 37s | Max:  7m 02s
    🟩 cxx
      🟩 GCC13              Pass: 100%/2   | Total:  9m 15s | Avg:  4m 37s | Max:  7m 02s
    🟩 cxx_family
      🟩 GCC                Pass: 100%/2   | Total:  9m 15s | Avg:  4m 37s | Max:  7m 02s
    🟩 gpu
      🟩 v100               Pass: 100%/2   | Total:  9m 15s | Avg:  4m 37s | Max:  7m 02s
    🟩 jobs
      🟩 Build              Pass: 100%/1   | Total:  2m 13s | Avg:  2m 13s | Max:  2m 13s
      🟩 Test               Pass: 100%/1   | Total:  7m 02s | Avg:  7m 02s | Max:  7m 02s
    
  • 🟩 python: Pass: 100%/1 | Total: 27m 08s | Avg: 27m 08s | Max: 27m 08s

    🟩 cpu
      🟩 amd64              Pass: 100%/1   | Total: 27m 08s | Avg: 27m 08s | Max: 27m 08s
    🟩 ctk
      🟩 12.6               Pass: 100%/1   | Total: 27m 08s | Avg: 27m 08s | Max: 27m 08s
    🟩 cudacxx
      🟩 nvcc12.6           Pass: 100%/1   | Total: 27m 08s | Avg: 27m 08s | Max: 27m 08s
    🟩 cudacxx_family
      🟩 nvcc               Pass: 100%/1   | Total: 27m 08s | Avg: 27m 08s | Max: 27m 08s
    🟩 cxx
      🟩 GCC13              Pass: 100%/1   | Total: 27m 08s | Avg: 27m 08s | Max: 27m 08s
    🟩 cxx_family
      🟩 GCC                Pass: 100%/1   | Total: 27m 08s | Avg: 27m 08s | Max: 27m 08s
    🟩 gpu
      🟩 v100               Pass: 100%/1   | Total: 27m 08s | Avg: 27m 08s | Max: 27m 08s
    🟩 jobs
      🟩 Test               Pass: 100%/1   | Total: 27m 08s | Avg: 27m 08s | Max: 27m 08s
    

👃 Inspect Changes

Modifications in project?

Project
CCCL Infrastructure
libcu++
CUB
Thrust
CUDA Experimental
+/- python
+/- CCCL C Parallel Library
Catch2Helper

Modifications in project or dependencies?

Project
CCCL Infrastructure
libcu++
CUB
Thrust
CUDA Experimental
+/- python
+/- CCCL C Parallel Library
Catch2Helper

🏃‍ Runner counts (total jobs: 3)

# Runner
2 linux-amd64-gpu-v100-latest-1
1 linux-amd64-cpu16

Contributor

@shwina shwina left a comment


This is looking good! Mainly I have some nits, but one relatively important issue is requiring the output type in TransformIterator. We can choose to punt that to #3064 if needed.



class ConstantIterator:
def __init__(self, val, ntype):
Contributor


I think `ntype` -> `dtype` would be better. The use of Numba should be an implementation detail from the user's perspective. Alternatively, we could just accept a typed scalar like ConstantIterator(np.int32(0)).

Ditto for CountingIterator.
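A toolkit-free sketch of the typed-scalar alternative (the class body here is hypothetical, not the PR's implementation): deriving the dtype from a NumPy scalar removes the explicit `ntype` argument, and the Numba type can be recovered internally via `numba.from_dtype` when needed.

```python
import numpy as np

# Hedged sketch: a hypothetical ConstantIterator that accepts a typed scalar
# like np.int32(0) and derives its dtype from the value, instead of taking a
# separate ntype argument. Internally, numba.from_dtype(self.dtype) would
# produce the Numba type as an implementation detail.
class ConstantIterator:
    def __init__(self, value):
        self.value = value
        self.dtype = np.asarray(value).dtype  # dtype('int32') for np.int32(0)

it = ConstantIterator(np.int32(0))
```

This keeps Numba entirely out of the public signature while still giving the implementation an unambiguous value type.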

Comment on lines +245 to +250
def count_advance(this, diff):
this[0] += diff


def count_dereference(this):
return this[0]
Contributor


Come to think of it, it might be better to make these @staticmethod. After all:

$ python -c "import this"  | grep Namespace
Namespaces are one honking great idea -- let's do more of those!

Contributor Author


Done in commit c3c51a5

I added comments:

# Exclusively for numba.cuda.compile (this is not an actual method).

My thinking:

Without a decorator (as we had originally), people will think it's a bound method but wonder why `self` is called `this`.

Explicitly adding @staticmethod will make people believe it really is a static method, but that's not actually true.

Being explicit in the comment is only slightly more verbose than adding a decorator but much more informative.

Contributor

@shwina shwina Dec 6, 2024


Explicitly adding @staticmethod will make people believe it really is a static method, but that's not actually true.

I don't think I understand. If anything, adding @staticmethod will make it even more obvious to the reader that the function is independent of the class. Typically, functions that have no dependency on the class or its members but are otherwise related to it are defined as @staticmethod.

Contributor


In other words, these are truly staticmethods in every sense
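To illustrate the point being debated (class and method names here are illustrative, not the PR's actual code): accessing a @staticmethod through the class yields a plain function with no binding, so it can be handed to a compiler such as `numba.cuda.compile` unchanged.

```python
# Hedged sketch: iterator ops kept at class scope purely as compile sources.
# A @staticmethod retrieved via the class is a plain function (no implicit
# self/this binding), so it could be passed to numba.cuda.compile as-is.
class CountingIteratorOps:
    @staticmethod
    def advance(this, diff):
        # mutate the iterator state in place
        this[0] += diff

    @staticmethod
    def dereference(this):
        # read the current value
        return this[0]

fn = CountingIteratorOps.dereference  # plain function, not a bound method
state = [7]
CountingIteratorOps.advance(state, 3)  # state is now [10]
```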

return self.it.alignment # TODO fix for stateful op


def TransformIterator(op, it, op_return_ntype):
Contributor


I don't think we should require the op_return_ntype here. Numba should in theory have everything it needs to infer the return type when compiling op.

cuda.compile returns both the LTOIR and the inferred return type; we seem to be discarding the latter in extract_ctypes_ltoirs.

Are we able to use the numba inferred return type and not require it from the user?

If not, it might be because numba doesn't have enough typing information. If that is the case, it will be fixed as part of #3064 by defining numba types corresponding to all of our Iterator types.
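A toolkit-free stand-in for the shape of that API (`fake_compile` and the tuple contents here are illustrative; the real call is `numba.cuda.compile`, which per the comment above returns the code together with the inferred return type):

```python
# Hedged stand-in: numba.cuda.compile returns a (code, return_type) pair.
# If the second element were forwarded instead of discarded, TransformIterator
# would not need a user-supplied op_return_ntype.
def fake_compile(op, sig):
    # pretend compilation: return (ltoir_bytes, inferred_return_type)
    return b"<ltoir>", "int32"

ltoir, return_type = fake_compile(lambda x: 2 * x, ("int32",))
# return_type is now usable in place of an explicit user-supplied type
```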

return 3 * val


SUPPORTED_VALUE_TYPE_NAMES = (
Contributor


Why not just use numpy types, which are trivially convertible to numba types via numba.from_dtype(...)?
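A minimal sketch of that boundary (`normalize_dtype` is a hypothetical helper name): accept anything `np.dtype` understands at the public API, and convert to a Numba type only as an internal step.

```python
import numpy as np

# Hedged sketch: normalize user input to a NumPy dtype at the public API.
# Internally, numba.from_dtype(normalize_dtype(x)) would then yield the
# corresponding Numba type, keeping Numba an implementation detail.
def normalize_dtype(dtype_like):
    # accepts np.int32, "int32", np.dtype("int32"), cp.dtype(...) results, ...
    return np.dtype(dtype_like)
```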

@pytest.mark.parametrize(
"type_obj_from_str", [_iterators.numba_type_from_any, numpy.dtype, cp.dtype]
)
@pytest.mark.parametrize("value_type_name", SUPPORTED_VALUE_TYPE_NAMES)
Contributor


In general, we have found parametrized fixtures to be the better choice when sharing parameters across tests, especially as the codebase evolves:

https://docs.rapids.ai/api/cudf/stable/developer_guide/testing/#parametrization-custom-fixtures-and-pytest-mark-parametrize
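A minimal sketch of the fixture pattern from that guide (fixture and test names here are illustrative):

```python
import pytest

# Hedged sketch: a parametrized fixture shared across tests, following the
# cuDF testing guide linked above. The param list is illustrative.
@pytest.fixture(params=["int32", "uint64", "float32", "float64"])
def value_type_name(request):
    return request.param

def test_uses_every_value_type(value_type_name):
    # when collected by pytest, this runs once per fixture param
    assert isinstance(value_type_name, str)
```

Tests request the fixture by name, so the shared parameter list lives in one place and new dtypes propagate to every test automatically.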

Contributor Author


Done in commit 6aeeff3

Nice. I didn't realize fixtures can be used in this way.

import numba.cuda
import numba.types
import cuda.parallel.experimental as cudax
from cuda.parallel.experimental import _iterators
Contributor

@shwina shwina Dec 6, 2024


If we need something from a non-public submodule in the tests, then it's possible that:

  • it should go in a public API
  • we don't really need it

For instance, we are using _iterators.pointer() to construct inputs for one of our tests. This suggests that pointer() should be a public API (OR we are testing something that we don't expect users to ever do).

Collaborator


I think it might be a leftover. _iterators.pointer() is an implementation detail of the transform iterator (a glue layer that lets it support containers). I would suggest avoiding testing reduce with pointer directly, and testing only reduction of a transformed cp.array.

Contributor Author


Summary of short offline discussion: Maybe in a follow-on PR:

TransformIterator(identity_op, cupy_array, op_return_value_type)

This way we'd still have a test targeted at RawPointer, but through a public API.

)


def TransformIterator(op, it, op_return_value_type):
Collaborator


question: can we infer the return type of `op(it.value_type)` somehow? I'd prefer not to have a value type parameter on the transform iterator if possible.

Suggested change
def TransformIterator(op, it, op_return_value_type):
def TransformIterator(op, it):

Contributor


I think so -- see my comment above.

from . import _iterators


def CacheModifiedInputIterator(device_array, value_type, modifier):
Collaborator


question: can we infer the value type from device_array? I'd prefer not to have a value_type parameter on this iterator if possible. The value type should match the underlying memory's value type exactly.

Contributor


Yes, it should be just numba.from_dtype(device_array.dtype).
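For example (with a NumPy array standing in for a device array, and `infer_value_type` a hypothetical helper; `numba.from_dtype` would then map the inferred dtype to the Numba type):

```python
import numpy as np

# Hedged sketch: infer the value type from the array's own dtype, so
# CacheModifiedInputIterator would not need an explicit value_type argument.
# Any array exposing .dtype (NumPy, CuPy, Numba device arrays) works.
def infer_value_type(device_array):
    return np.dtype(device_array.dtype)

arr = np.zeros(4, dtype=np.float32)  # host stand-in for a device array
```

By construction the inferred type matches the underlying memory's value type exactly, which is the invariant the comment above asks for.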

rwgk added 2 commits December 6, 2024 10:41
… functions back to class scope, with comments to explicitly state that these are not actual methods.
Contributor

@shwina shwina left a comment


In an offline sync with @rwgk and @gevtushenko, we decided to merge this sooner rather than later and follow up to address any remaining review items.

@rwgk
Contributor Author

rwgk commented Dec 6, 2024

/ok to test

@rwgk rwgk marked this pull request as ready for review December 6, 2024 22:08
@github-actions
Contributor

github-actions bot commented Dec 6, 2024

🟩 CI finished in 1h 55m: Pass: 100%/3 | Total: 42m 55s | Avg: 14m 18s | Max: 30m 51s
  • 🟩 cccl_c_parallel: Pass: 100%/2 | Total: 12m 04s | Avg: 6m 02s | Max: 9m 56s

    🟩 cpu
      🟩 amd64              Pass: 100%/2   | Total: 12m 04s | Avg:  6m 02s | Max:  9m 56s
    🟩 ctk
      🟩 12.6               Pass: 100%/2   | Total: 12m 04s | Avg:  6m 02s | Max:  9m 56s
    🟩 cudacxx
      🟩 nvcc12.6           Pass: 100%/2   | Total: 12m 04s | Avg:  6m 02s | Max:  9m 56s
    🟩 cudacxx_family
      🟩 nvcc               Pass: 100%/2   | Total: 12m 04s | Avg:  6m 02s | Max:  9m 56s
    🟩 cxx
      🟩 GCC13              Pass: 100%/2   | Total: 12m 04s | Avg:  6m 02s | Max:  9m 56s
    🟩 cxx_family
      🟩 GCC                Pass: 100%/2   | Total: 12m 04s | Avg:  6m 02s | Max:  9m 56s
    🟩 gpu
      🟩 v100               Pass: 100%/2   | Total: 12m 04s | Avg:  6m 02s | Max:  9m 56s
    🟩 jobs
      🟩 Build              Pass: 100%/1   | Total:  2m 08s | Avg:  2m 08s | Max:  2m 08s
      🟩 Test               Pass: 100%/1   | Total:  9m 56s | Avg:  9m 56s | Max:  9m 56s
    
  • 🟩 python: Pass: 100%/1 | Total: 30m 51s | Avg: 30m 51s | Max: 30m 51s

    🟩 cpu
      🟩 amd64              Pass: 100%/1   | Total: 30m 51s | Avg: 30m 51s | Max: 30m 51s
    🟩 ctk
      🟩 12.6               Pass: 100%/1   | Total: 30m 51s | Avg: 30m 51s | Max: 30m 51s
    🟩 cudacxx
      🟩 nvcc12.6           Pass: 100%/1   | Total: 30m 51s | Avg: 30m 51s | Max: 30m 51s
    🟩 cudacxx_family
      🟩 nvcc               Pass: 100%/1   | Total: 30m 51s | Avg: 30m 51s | Max: 30m 51s
    🟩 cxx
      🟩 GCC13              Pass: 100%/1   | Total: 30m 51s | Avg: 30m 51s | Max: 30m 51s
    🟩 cxx_family
      🟩 GCC                Pass: 100%/1   | Total: 30m 51s | Avg: 30m 51s | Max: 30m 51s
    🟩 gpu
      🟩 v100               Pass: 100%/1   | Total: 30m 51s | Avg: 30m 51s | Max: 30m 51s
    🟩 jobs
      🟩 Test               Pass: 100%/1   | Total: 30m 51s | Avg: 30m 51s | Max: 30m 51s
    

👃 Inspect Changes

Modifications in project?

Project
CCCL Infrastructure
libcu++
CUB
Thrust
CUDA Experimental
+/- python
+/- CCCL C Parallel Library
Catch2Helper

Modifications in project or dependencies?

Project
CCCL Infrastructure
libcu++
CUB
Thrust
CUDA Experimental
+/- python
+/- CCCL C Parallel Library
Catch2Helper

🏃‍ Runner counts (total jobs: 3)

# Runner
2 linux-amd64-gpu-v100-latest-1
1 linux-amd64-cpu16
