Skip to content

Commit 6bdb85b

Browse files
refactor to simplify API
1 parent 94d9d16 commit 6bdb85b

File tree

3 files changed

+28
-56
lines changed

3 files changed

+28
-56
lines changed

arraycontext/impl/pytato/__init__.py

Lines changed: 4 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -245,14 +245,6 @@ class ProfileEvent:
245245
t_unit_name: str
246246

247247

248-
@dataclass
249-
class MultiCallKernelProfile:
250-
"""Class to hold the results of multiple kernel executions."""
251-
252-
num_calls: int
253-
time: int
254-
255-
256248
class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
257249
"""
258250
An :class:`ArrayContext` that uses :mod:`pytato` data types to represent
@@ -289,14 +281,13 @@ def __init__(
289281
representation. This interface should be considered
290282
unstable.
291283
"""
292-
import pyopencl as cl
293-
294284
if allocator is not None and use_memory_pool is not None:
295285
raise TypeError("may not specify both allocator and use_memory_pool")
296286

297287
self.profile_kernels = profile_kernels
298288

299289
if profile_kernels:
290+
import pyopencl as cl
300291
if not queue.properties & cl.command_queue_properties.PROFILING_ENABLE:
301292
raise RuntimeError("Profiling was not enabled in the command queue. "
302293
"Please create the queue with "
@@ -379,44 +370,12 @@ def _wait_and_transfer_profile_events(self) -> None:
379370
def _add_profiling_events(self, start: cl._cl.Event, stop: cl._cl.Event,
380371
t_unit_name: str) -> None:
381372
"""Add profiling events to the list of profiling events."""
382-
if self.profile_kernels:
383-
self._profile_events.append(ProfileEvent(start, stop, t_unit_name))
384-
385-
def get_profiling_data_for_kernel(self, kernel_name: str) \
386-
-> MultiCallKernelProfile:
387-
"""Return profiling data for kernel *kernel_name*."""
388-
self._wait_and_transfer_profile_events()
389-
390-
time = 0
391-
num_calls = 0
392-
393-
if kernel_name in self.profile_results:
394-
knl_results = self.profile_results[kernel_name]
395-
396-
num_calls = len(knl_results)
397-
398-
for r in knl_results:
399-
time += r
400-
401-
return MultiCallKernelProfile(num_calls, time)
402-
403-
def reset_profiling_data_for_kernel(self, kernel_name: str) -> None:
404-
"""Reset profiling data for kernel *kernel_name*."""
405-
self.profile_results.pop(kernel_name, None)
406-
407-
def get_and_reset_profiling_data(self) -> dict[str, MultiCallKernelProfile]:
408-
"""Return and reset profiling data."""
409-
self._wait_and_transfer_profile_events()
410-
411-
result = {
412-
kernel_name: MultiCallKernelProfile(len(times), sum(times))
413-
for kernel_name, times in self._profile_results.items()
414-
}
373+
self._profile_events.append(ProfileEvent(start, stop, t_unit_name))
415374

375+
def _reset_profiling_data(self) -> None:
376+
"""Reset profiling data."""
416377
self._profile_results = {}
417378

418-
return result
419-
420379
# }}}
421380

422381
@property

arraycontext/impl/pytato/utils.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,12 @@
44
__doc__ = """
55
.. autofunction:: transfer_from_numpy
66
.. autofunction:: transfer_to_numpy
7+
8+
9+
Profiling-related functions:
10+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
11+
12+
.. autofunction:: tabulate_profiling_data
713
"""
814

915

@@ -52,7 +58,7 @@
5258

5359
from arraycontext import ArrayContext
5460
from arraycontext.impl.pyopencl.taggable_cl_array import Axis as ClAxis
55-
from arraycontext.impl.pytato import MultiCallKernelProfile
61+
from arraycontext.impl.pytato import PytatoPyOpenCLArrayContext
5662

5763

5864
if TYPE_CHECKING:
@@ -224,9 +230,12 @@ def transfer_to_numpy(expr: ArrayOrNames, actx: ArrayContext) -> ArrayOrNames:
224230
# }}}
225231

226232

227-
def tabulate_profiling_data(profile_results: dict[str, MultiCallKernelProfile])\
228-
-> pytools.Table:
233+
# {{{ Profiling
234+
235+
def tabulate_profiling_data(actx: PytatoPyOpenCLArrayContext) -> pytools.Table:
229236
"""Return a :class:`pytools.Table` with the profiling results."""
237+
actx._wait_and_transfer_profile_events()
238+
230239
tbl = pytools.Table()
231240

232241
# Table header
@@ -238,23 +247,27 @@ def tabulate_profiling_data(profile_results: dict[str, MultiCallKernelProfile])\
238247
total_calls = 0
239248
total_time = 0.0
240249

241-
for kernel_name, mckp in profile_results.items():
242-
total_calls += mckp.num_calls
250+
for kernel_name, times in actx._profile_results.items():
251+
num_calls = len(times)
252+
total_calls += num_calls
243253

244-
t_sum = mckp.time
245-
t_avg = mckp.time / mckp.num_calls
254+
t_sum = sum(times)
255+
t_avg = t_sum / num_calls
246256
if t_sum is not None:
247257
total_time += t_sum
248258

249259
time_sum = f"{t_sum:{g}}"
250260
time_avg = f"{t_avg:{g}}"
251261

252-
tbl.add_row((kernel_name, mckp.num_calls, time_sum,
253-
time_avg,))
262+
tbl.add_row((kernel_name, num_calls, time_sum, time_avg))
254263

255264
tbl.add_row(("", "", "", ""))
256265
tbl.add_row(("Total", total_calls, f"{total_time:{g}}", "--"))
257266

267+
actx._reset_profiling_data()
268+
258269
return tbl
259270

271+
# }}}
272+
260273
# vim: foldmethod=marker

test/test_pytato_arraycontext.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ def twice(x):
303303

304304
from arraycontext.impl.pytato.utils import tabulate_profiling_data
305305

306-
print(tabulate_profiling_data(actx.get_and_reset_profiling_data()))
306+
print(tabulate_profiling_data(actx))
307307
assert len(actx._profile_results) == 0
308308

309309
# }}}
@@ -321,7 +321,7 @@ def twice(x):
321321
assert len(actx._profile_results) == 1
322322
assert len(actx._profile_results["frozen_result"]) == 10
323323

324-
print(tabulate_profiling_data(actx.get_and_reset_profiling_data()))
324+
print(tabulate_profiling_data(actx))
325325

326326
assert len(actx._profile_results) == 0
327327

0 commit comments

Comments
 (0)