From 1576a432b2773fe957a8a4e7e89cf2301039c414 Mon Sep 17 00:00:00 2001 From: kimmishi Date: Fri, 24 Nov 2023 09:12:10 +0800 Subject: [PATCH 1/2] Update flops profiler to handle attn and __matmul__ --- deepspeed/profiling/flops_profiler/profiler.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/deepspeed/profiling/flops_profiler/profiler.py b/deepspeed/profiling/flops_profiler/profiler.py index 79e682a73b90..9f8d4515ef69 100644 --- a/deepspeed/profiling/flops_profiler/profiler.py +++ b/deepspeed/profiling/flops_profiler/profiler.py @@ -826,6 +826,13 @@ def _elementwise_flops_compute(input, other): flops = _prod(final_shape) return flops, 0 +def _attn_flops_compute(q, k, v, *args, **kwargs): + """ + Count flops for the scaled_dot_product_attention operation. + """ + macs = _prod(q.shape) * k.shape[-2] + macs += _prod(q.shape[:-1]) * k.shape[-2] * v.shape[-1] + return 2 * macs, macs def wrapFunc(func, funcFlopCompute): oldFunc = func @@ -899,10 +906,14 @@ def _patch_functionals(): # embedding F.embedding = wrapFunc(F.embedding, _embedding_flops_compute) + # attn + F.scaled_dot_product_attention = wrapFunc(F.scaled_dot_product_attention, _attn_flops_compute) + def _patch_tensor_methods(): torch.matmul = wrapFunc(torch.matmul, _matmul_flops_compute) torch.Tensor.matmul = wrapFunc(torch.Tensor.matmul, _matmul_flops_compute) + torch.Tensor.__matmul__ = wrapFunc(torch.Tensor.__matmul__, _matmul_flops_compute) torch.mm = wrapFunc(torch.mm, _matmul_flops_compute) torch.Tensor.mm = wrapFunc(torch.Tensor.mm, _matmul_flops_compute) torch.bmm = wrapFunc(torch.bmm, _matmul_flops_compute) From eee9064cc80969bc41c03ecd7dd9d25c2cc64399 Mon Sep 17 00:00:00 2001 From: kimmishi Date: Fri, 22 Dec 2023 11:53:30 +0800 Subject: [PATCH 2/2] style: update profiler.py --- deepspeed/profiling/flops_profiler/profiler.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/deepspeed/profiling/flops_profiler/profiler.py b/deepspeed/profiling/flops_profiler/profiler.py index 
9f8d4515ef69..de847e59e82e 100644 --- a/deepspeed/profiling/flops_profiler/profiler.py +++ b/deepspeed/profiling/flops_profiler/profiler.py @@ -826,6 +826,7 @@ def _elementwise_flops_compute(input, other): flops = _prod(final_shape) return flops, 0 + def _attn_flops_compute(q, k, v, *args, **kwargs): """ Count flops for the scaled_dot_product_attention operation. @@ -834,6 +835,7 @@ def _attn_flops_compute(q, k, v, *args, **kwargs): macs += _prod(q.shape[:-1]) * k.shape[-2] * v.shape[-1] return 2 * macs, macs + def wrapFunc(func, funcFlopCompute): oldFunc = func name = func.__str__