5454import abc
5555import sys
5656from collections .abc import Callable
57+ from dataclasses import dataclass
5758from typing import TYPE_CHECKING , Any
5859
5960import numpy as np
6061
62+ import pyopencl as cl
63+ import pytools
6164from pytools import memoize_method
6265from pytools .tag import Tag , ToTagSetConvertible , normalize_tags
6366
7477
7578if TYPE_CHECKING :
7679 import loopy as lp
77- import pyopencl as cl
7880 import pytato
7981
8082if getattr (sys , "_BUILDING_SPHINX_DOCS" , False ):
@@ -235,6 +237,24 @@ def get_target(self):
235237
236238# {{{ PytatoPyOpenCLArrayContext
237239
240+
241+ @dataclass
242+ class ProfileEvent :
243+ """Holds a profile event that has not been collected by the profiler yet."""
244+
245+ cl_event : cl ._cl .Event
246+ translation_unit : Any
247+ # args_tuple: tuple
248+
249+
250+ @dataclass
251+ class MultiCallKernelProfile :
252+ """Class to hold the results of multiple kernel executions."""
253+
254+ num_calls : int
255+ time : int
256+
257+
238258class PytatoPyOpenCLArrayContext (_BasePytatoArrayContext ):
239259 """
240260 An :class:`ArrayContext` that uses :mod:`pytato` data types to represent
@@ -259,7 +279,7 @@ def __init__(
259279 self , queue : cl .CommandQueue , allocator = None , * ,
260280 use_memory_pool : bool | None = None ,
261281 compile_trace_callback : Callable [[Any , str , Any ], None ] | None = None ,
262-
282+ profile : bool = True ,
263283 # do not use: only for testing
264284 _force_svm_arg_limit : int | None = None ,
265285 ) -> None :
@@ -271,9 +291,26 @@ def __init__(
271291 representation. This interface should be considered
272292 unstable.
273293 """
294+ import pyopencl as cl
295+
274296 if allocator is not None and use_memory_pool is not None :
275297 raise TypeError ("may not specify both allocator and use_memory_pool" )
276298
299+ self .profile = profile
300+
301+ if profile :
302+ if not queue .properties & cl .command_queue_properties .PROFILING_ENABLE :
303+ raise RuntimeError ("Profiling was not enabled in the command queue. "
304+ "Please create the queue with "
305+ "cl.command_queue_properties.PROFILING_ENABLE." )
306+
307+ # List of ProfileEvents that haven't been transferred to profiled
308+ # results yet
309+ self .profile_events : list [ProfileEvent ] = []
310+
311+ # Dict of kernel name -> list of kernel execution times
312+ self .profile_results : dict [str , list [int ]] = {}
313+
277314 self .using_svm = None
278315
279316 if allocator is None :
@@ -322,6 +359,79 @@ def __init__(
322359
323360 self ._force_svm_arg_limit = _force_svm_arg_limit
324361
362+ def _wait_and_transfer_profile_events (self ) -> None :
363+ # First, wait for completion of all events
364+ if self .profile_events :
365+ cl .wait_for_events ([p_event .cl_event for p_event in self .profile_events ])
366+
367+ # Then, collect all events and store them
368+ for t in self .profile_events :
369+ name = t .translation_unit .program .entrypoint
370+
371+ time = t .cl_event .profile .end - t .cl_event .profile .start
372+
373+ self .profile_results .setdefault (name , []).append (time )
374+
375+ self .profile_events = []
376+
377+ def get_profiling_data_for_kernel (self , kernel_name : str ) \
378+ -> MultiCallKernelProfile :
379+ """Return profiling data for kernel `kernel_name`."""
380+ self ._wait_and_transfer_profile_events ()
381+
382+ time = 0
383+ num_calls = 0
384+
385+ if kernel_name in self .profile_results :
386+ knl_results = self .profile_results [kernel_name ]
387+
388+ num_calls = len (knl_results )
389+
390+ for r in knl_results :
391+ time += r
392+
393+ return MultiCallKernelProfile (num_calls , time )
394+
395+ def reset_profiling_data_for_kernel (self , kernel_name : str ) -> None :
396+ """Reset profiling data for kernel `kernel_name`."""
397+ self .profile_results .pop (kernel_name , None )
398+
399+ def tabulate_profiling_data (self ) -> pytools .Table :
400+ """Return a :class:`pytools.Table` with the profiling results."""
401+ self ._wait_and_transfer_profile_events ()
402+
403+ tbl = pytools .Table ()
404+
405+ # Table header
406+ tbl .add_row (("Function" , "# Calls" , "Time_sum [s]" , "Time_avg [s]" ))
407+
408+ # Precision of results
409+ g = ".4g"
410+
411+ total_calls = 0
412+ total_time = 0.0
413+
414+ for knl in self .profile_results :
415+ r = self .get_profiling_data_for_kernel (knl )
416+
417+ total_calls += r .num_calls
418+
419+ t_sum = r .time
420+ t_avg = r .time / r .num_calls
421+ if t_sum is not None :
422+ total_time += t_sum
423+
424+ time_sum = f"{ t_sum :{g }} "
425+ time_avg = f"{ t_avg :{g }} "
426+
427+ tbl .add_row ((knl , r .num_calls , time_sum ,
428+ time_avg ,))
429+
430+ tbl .add_row (("" , "" , "" , "" ))
431+ tbl .add_row (("Total" , total_calls , f"{ total_time :{g }} " , "--" ))
432+
433+ return tbl
434+
325435 @property
326436 def _frozen_array_types (self ) -> tuple [type , ...]:
327437 import pyopencl .array as cla
@@ -546,9 +656,13 @@ def _to_frozen(key: tuple[Any, ...], ary) -> TaggableCLArray:
546656 self ._dag_transform_cache [normalized_expr ])
547657
548658 assert len (pt_prg .bound_arguments ) == 0
549- _evt , out_dict = pt_prg (self .queue ,
659+ evt , out_dict = pt_prg (self .queue ,
550660 allocator = self .allocator ,
551661 ** bound_arguments )
662+
663+ if self .profile :
664+ self .profile_events .append (ProfileEvent (evt , pt_prg ))
665+
552666 assert len (set (out_dict ) & set (key_to_frozen_subary )) == 0
553667
554668 key_to_frozen_subary = {
0 commit comments