Skip to content

Commit 5f2d45d

Browse files
PytatoPyOpenCLArrayContext: add support for kernel profiling
1 parent 3cb5ef4 commit 5f2d45d

File tree

2 files changed

+127
-5
lines changed

2 files changed

+127
-5
lines changed

arraycontext/impl/pytato/__init__.py

Lines changed: 117 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,13 @@
5454
import abc
5555
import sys
5656
from collections.abc import Callable
57+
from dataclasses import dataclass
5758
from typing import TYPE_CHECKING, Any
5859

5960
import numpy as np
6061

62+
import pyopencl as cl
63+
import pytools
6164
from pytools import memoize_method
6265
from pytools.tag import Tag, ToTagSetConvertible, normalize_tags
6366

@@ -74,7 +77,6 @@
7477

7578
if TYPE_CHECKING:
7679
import loopy as lp
77-
import pyopencl as cl
7880
import pytato
7981

8082
if getattr(sys, "_BUILDING_SPHINX_DOCS", False):
@@ -235,6 +237,24 @@ def get_target(self):
235237

236238
# {{{ PytatoPyOpenCLArrayContext
237239

240+
241+
@dataclass
class ProfileEvent:
    """Holds a profile event that has not been collected by the profiler yet.

    Instances are appended to the array context's ``profile_events`` list
    right after a program launch and are later converted into timing
    results.
    """

    # OpenCL event returned by the program launch; its ``profile.start`` and
    # ``profile.end`` timestamps are read once the event has completed.
    cl_event: cl.Event
    # The launched program object. The profiler reads
    # ``translation_unit.program.entrypoint`` from it to name the kernel.
    translation_unit: Any
249+
250+
@dataclass
class MultiCallKernelProfile:
    """Aggregated profiling result for all recorded executions of one kernel."""

    # Number of recorded executions that contributed to this profile.
    num_calls: int
    # Sum of the per-execution ``profile.end - profile.start`` differences,
    # in the raw units reported by the OpenCL event (presumably nanoseconds
    # -- confirm against the OpenCL profiling documentation).
    time: int
256+
257+
238258
class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
239259
"""
240260
An :class:`ArrayContext` that uses :mod:`pytato` data types to represent
@@ -259,7 +279,7 @@ def __init__(
259279
self, queue: cl.CommandQueue, allocator=None, *,
260280
use_memory_pool: bool | None = None,
261281
compile_trace_callback: Callable[[Any, str, Any], None] | None = None,
262-
282+
profile: bool = True,
263283
# do not use: only for testing
264284
_force_svm_arg_limit: int | None = None,
265285
) -> None:
@@ -271,9 +291,26 @@ def __init__(
271291
representation. This interface should be considered
272292
unstable.
273293
"""
294+
import pyopencl as cl
295+
274296
if allocator is not None and use_memory_pool is not None:
275297
raise TypeError("may not specify both allocator and use_memory_pool")
276298

299+
self.profile = profile
300+
301+
if profile:
302+
if not queue.properties & cl.command_queue_properties.PROFILING_ENABLE:
303+
raise RuntimeError("Profiling was not enabled in the command queue. "
304+
"Please create the queue with "
305+
"cl.command_queue_properties.PROFILING_ENABLE.")
306+
307+
# List of ProfileEvents that haven't been transferred to profiled
308+
# results yet
309+
self.profile_events: list[ProfileEvent] = []
310+
311+
# Dict of kernel name -> list of kernel execution times
312+
self.profile_results: dict[str, list[int]] = {}
313+
277314
self.using_svm = None
278315

279316
if allocator is None:
@@ -322,6 +359,79 @@ def __init__(
322359

323360
self._force_svm_arg_limit = _force_svm_arg_limit
324361

362+
def _wait_and_transfer_profile_events(self) -> None:
    """Move all pending profile events into ``self.profile_results``.

    Waits for every event in ``self.profile_events`` to complete, records
    each event's ``profile.end - profile.start`` difference under the
    entrypoint name of the program that produced it, and finally resets
    ``self.profile_events`` to an empty list.
    """
    pending = self.profile_events

    # Block until all queued launches have completed before reading
    # their profiling timestamps.
    if pending:
        cl.wait_for_events([entry.cl_event for entry in pending])

    for entry in pending:
        kernel_name = entry.translation_unit.program.entrypoint

        timing = entry.cl_event.profile
        elapsed = timing.end - timing.start

        self.profile_results.setdefault(kernel_name, []).append(elapsed)

    # Everything pending has been recorded; start over with a fresh list.
    self.profile_events = []
376+
377+
def get_profiling_data_for_kernel(self, kernel_name: str) \
        -> MultiCallKernelProfile:
    """Return the accumulated profiling data for kernel *kernel_name*.

    Pending profile events are collected first. If nothing has been
    recorded for *kernel_name*, the returned profile has zero calls and
    zero time.
    """
    self._wait_and_transfer_profile_events()

    # An unknown kernel is treated like one with no recorded executions.
    samples = self.profile_results.get(kernel_name, [])

    return MultiCallKernelProfile(len(samples), sum(samples))
394+
395+
def reset_profiling_data_for_kernel(self, kernel_name: str) -> None:
    """Reset profiling data for kernel *kernel_name*.

    Pending profile events are collected before the reset. Otherwise
    launches issued before this call would still sit in the pending list
    and reappear in the results at the next query, making the reset
    ineffective for them.

    Resetting a kernel without recorded data is a no-op.
    """
    # Flush first so that pre-reset launches are discarded with the rest.
    self._wait_and_transfer_profile_events()
    self.profile_results.pop(kernel_name, None)
398+
399+
def tabulate_profiling_data(self) -> pytools.Table:
    """Return a :class:`pytools.Table` with the profiling results.

    The table has one row per profiled kernel (name, number of calls,
    summed time and average time per call), followed by an empty row and
    a "Total" row. Times are reported in seconds.
    """
    # Collect pending events once up front; the loop below then reads
    # the recorded samples directly instead of re-flushing per kernel.
    self._wait_and_transfer_profile_events()

    # Recorded samples are differences of OpenCL profiling timestamps,
    # which are in nanoseconds, while the column headers promise seconds.
    ns_to_s = 1e-9

    # Precision of results
    g = ".4g"

    tbl = pytools.Table()

    # Table header
    tbl.add_row(("Function", "# Calls", "Time_sum [s]", "Time_avg [s]"))

    total_calls = 0
    total_time = 0.0

    for knl, samples in self.profile_results.items():
        num_calls = len(samples)
        t_sum = sum(samples) * ns_to_s

        total_calls += num_calls
        total_time += t_sum

        # Guard against an empty sample list to avoid dividing by zero.
        time_avg = f"{t_sum / num_calls:{g}}" if num_calls else "--"

        tbl.add_row((knl, num_calls, f"{t_sum:{g}}", time_avg))

    tbl.add_row(("", "", "", ""))
    tbl.add_row(("Total", total_calls, f"{total_time:{g}}", "--"))

    return tbl
434+
325435
@property
326436
def _frozen_array_types(self) -> tuple[type, ...]:
327437
import pyopencl.array as cla
@@ -546,9 +656,13 @@ def _to_frozen(key: tuple[Any, ...], ary) -> TaggableCLArray:
546656
self._dag_transform_cache[normalized_expr])
547657

548658
assert len(pt_prg.bound_arguments) == 0
549-
_evt, out_dict = pt_prg(self.queue,
659+
evt, out_dict = pt_prg(self.queue,
550660
allocator=self.allocator,
551661
**bound_arguments)
662+
663+
if self.profile:
664+
self.profile_events.append(ProfileEvent(evt, pt_prg))
665+
552666
assert len(set(out_dict) & set(key_to_frozen_subary)) == 0
553667

554668
key_to_frozen_subary = {

arraycontext/impl/pytato/compile.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -636,10 +636,14 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer:
636636
input_kwargs_for_loopy = _args_to_device_buffers(
637637
self.actx, self.input_id_to_name_in_program, arg_id_to_arg, fn_name)
638638

639-
_evt, out_dict = self.pytato_program(queue=self.actx.queue,
639+
evt, out_dict = self.pytato_program(queue=self.actx.queue,
640640
allocator=self.actx.allocator,
641641
**input_kwargs_for_loopy)
642642

643+
if self.actx.profile:
644+
from arraycontext.impl.pytato import ProfileEvent
645+
self.actx.profile_events.append(ProfileEvent(evt, self.pytato_program))
646+
643647
def to_output_template(keys, _):
644648
name_in_program = self.output_id_to_name_in_program[keys]
645649
return self.actx.thaw(to_tagged_cl_array(
@@ -675,10 +679,14 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer:
675679
input_kwargs_for_loopy = _args_to_device_buffers(
676680
self.actx, self.input_id_to_name_in_program, arg_id_to_arg, fn_name)
677681

678-
_evt, out_dict = self.pytato_program(queue=self.actx.queue,
682+
evt, out_dict = self.pytato_program(queue=self.actx.queue,
679683
allocator=self.actx.allocator,
680684
**input_kwargs_for_loopy)
681685

686+
if self.actx.profile:
687+
from arraycontext.impl.pytato import ProfileEvent
688+
self.actx.profile_events.append(ProfileEvent(evt, self.pytato_program))
689+
682690
return self.actx.thaw(to_tagged_cl_array(out_dict[self.output_name],
683691
axes=get_cl_axes_from_pt_axes(
684692
self.output_axes),

0 commit comments

Comments
 (0)