diff --git a/deepspeed/profiling/flops_profiler/profiler.py b/deepspeed/profiling/flops_profiler/profiler.py
index 79e682a73b90..de847e59e82e 100644
--- a/deepspeed/profiling/flops_profiler/profiler.py
+++ b/deepspeed/profiling/flops_profiler/profiler.py
@@ -827,6 +827,36 @@ def _elementwise_flops_compute(input, other):
         return flops, 0
 
 
+def _attn_flops_compute(q=None, k=None, v=None, *args, **kwargs):
+    """
+    Count flops for the scaled_dot_product_attention operation.
+
+    Only the two matrix products (q @ k^T and weights @ v) are counted;
+    scaling, masking, softmax and dropout are not included.
+
+    The tensors may be passed positionally or under the keyword names
+    ``query``, ``key`` and ``value`` that F.scaled_dot_product_attention
+    declares. All other arguments are ignored.
+
+    NOTE(review): assumes q is (..., L, E), k is (..., S, E) and v is
+    (..., S, Ev) as in the PyTorch documentation -- confirm for callers.
+
+    Returns:
+        tuple: (flops, macs), with flops equal to 2 * macs.
+    """
+    # Keyword-style calls arrive in **kwargs under torch's parameter names.
+    if q is None:
+        q = kwargs.get("query")
+    if k is None:
+        k = kwargs.get("key")
+    if v is None:
+        v = kwargs.get("value")
+    src_len = k.shape[-2]
+    macs = _prod(q.shape) * src_len  # q @ k^T
+    macs += _prod(q.shape[:-1]) * src_len * v.shape[-1]  # weights @ v
+    return 2 * macs, macs
+
+
 def wrapFunc(func, funcFlopCompute):
     oldFunc = func
     name = func.__str__
@@ -899,10 +929,14 @@ def _patch_functionals():
     # embedding
     F.embedding = wrapFunc(F.embedding, _embedding_flops_compute)
 
+    # attn
+    F.scaled_dot_product_attention = wrapFunc(F.scaled_dot_product_attention, _attn_flops_compute)
+
 
 def _patch_tensor_methods():
     torch.matmul = wrapFunc(torch.matmul, _matmul_flops_compute)
     torch.Tensor.matmul = wrapFunc(torch.Tensor.matmul, _matmul_flops_compute)
+    torch.Tensor.__matmul__ = wrapFunc(torch.Tensor.__matmul__, _matmul_flops_compute)
     torch.mm = wrapFunc(torch.mm, _matmul_flops_compute)
     torch.Tensor.mm = wrapFunc(torch.Tensor.mm, _matmul_flops_compute)
     torch.bmm = wrapFunc(torch.bmm, _matmul_flops_compute)