This repository was archived by the owner on Nov 17, 2023. It is now read-only.

[Performance][Numpy] np.einsum can be 500 times slower than torch.einsum #18043

@sxjscience

Description

The performance of np.einsum on GPU is not good: it can be roughly 500 times slower than torch.einsum. Because einsum is essential for implementing the attention mechanisms used in NLP and CV, we should accelerate the implementation.
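
For context, the attention-score computation in question looks like the minimal sketch below. The shapes are made up for illustration; only mx.np.einsum and the 'bnid,bnjd->bnij' pattern are taken from the profiling script that follows:

import mxnet as mx
mx.npx.set_np()

# Hypothetical multi-head attention scores: queries and keys stored as
# (batch, num_heads, seq_len, head_dim); the scores are (batch, num_heads, T_q, T_k).
query = mx.np.random.normal(0, 1, (2, 4, 10, 16))
key = mx.np.random.normal(0, 1, (2, 4, 12, 16))
scores = mx.np.einsum('bnid,bnjd->bnij', query, key)
print(scores.shape)  # (2, 4, 10, 12)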

Here is the code used to profile the different implementations of einsum (also available as a gist: https://gist.github.com/sxjscience/bfda1a8bd2942d93eef5ddf8a15b52b8). The profiling results show the following ordering, from fastest to slowest:

PyTorch einsum > MXNet no-einsum >> MXNet einsum

import mxnet as mx
import numpy as np
import torch as th
import argparse
mx.npx.set_np()  # switch MXNet to NumPy-compatible array semantics

parser = argparse.ArgumentParser(description='Profile einsum')
parser.add_argument('--mode', choices=['einsum', 'no_einsum', 'th_einsum'],
                    default='einsum', required=True)
parser.add_argument('--problem', type=int,
                    choices=[0, 1, 2], help='Problem type.', default=0, required=True)
args = parser.parse_args()

np.random.seed(100)
batch_size = 64
num_heads = 8
seq_length_A = 100
seq_length_B = 50
units = 128

if args.problem == 0:
    # Problem 0: (B, K, T0, C) x (B, K, T1, C) -> (B, K, T0, T1)
    lhs = np.random.normal(0, 1, (batch_size, num_heads, seq_length_A, units))
    rhs = np.random.normal(0, 1, (batch_size, num_heads, seq_length_B, units))
    mx_lhs = mx.np.array(lhs, dtype=np.float32, ctx=mx.gpu())
    mx_rhs = mx.np.array(rhs, dtype=np.float32, ctx=mx.gpu())
    mx.npx.waitall()
    th_lhs = th.from_numpy(lhs).float().cuda()
    th_rhs = th.from_numpy(rhs).float().cuda()
    typ = 'bnid,bnjd->bnij'
    if args.mode == 'einsum':
        out = mx.np.einsum(typ, mx_lhs, mx_rhs)
        out_np = out.asnumpy()
    elif args.mode == 'no_einsum':
        out = mx.npx.batch_dot(mx_lhs, mx_rhs, transpose_b=True)
        out_np = out.asnumpy()
    elif args.mode == 'th_einsum':
        out = th.einsum(typ, th_lhs, th_rhs)
        out_np = out.cpu().numpy()
    else:
        raise NotImplementedError
    print(out_np.shape)
elif args.problem == 1:
    # Problem 1: (B, T0, K, C) x (B, T1, K, C) -> (B, K, T0, T1)
    lhs = np.random.normal(0, 1, (batch_size, seq_length_A, num_heads, units))
    rhs = np.random.normal(0, 1, (batch_size, seq_length_B, num_heads, units))
    mx_lhs = mx.np.array(lhs, dtype=np.float32, ctx=mx.gpu())
    mx_rhs = mx.np.array(rhs, dtype=np.float32, ctx=mx.gpu())
    mx.npx.waitall()
    th_lhs = th.from_numpy(lhs).float().cuda()
    th_rhs = th.from_numpy(rhs).float().cuda()
    typ = 'bind,bjnd->bnij'
    if args.mode == 'einsum':
        out = mx.np.einsum(typ, mx_lhs, mx_rhs)
        out_np = out.asnumpy()
    elif args.mode == 'no_einsum':
        out = mx.npx.batch_dot(mx.np.swapaxes(mx_lhs, 1, 2),
                               mx.np.swapaxes(mx_rhs, 1, 2),
                               transpose_b=True)
        out_np = out.asnumpy()
    elif args.mode == 'th_einsum':
        out = th.einsum(typ, th_lhs, th_rhs)
        out_np = out.cpu().numpy()
    else:
        raise NotImplementedError
    print(out_np.shape)
elif args.problem == 2:
    # Problem 2: (B, T0, K, C) x (T1, K, C) -> (B, K, T0, T1)
    lhs = np.random.normal(0, 1, (batch_size, seq_length_A, num_heads, units))
    rhs = np.random.normal(0, 1, (seq_length_B, num_heads, units))
    mx_lhs = mx.np.array(lhs, dtype=np.float32, ctx=mx.gpu())
    mx_rhs = mx.np.array(rhs, dtype=np.float32, ctx=mx.gpu())
    mx.npx.waitall()
    th_lhs = th.from_numpy(lhs).float().cuda()
    th_rhs = th.from_numpy(rhs).float().cuda()
    typ = 'bind,jnd->bnij'
    if args.mode == 'einsum':
        out = mx.np.einsum(typ, mx_lhs, mx_rhs)
        out_np = out.asnumpy()
    elif args.mode == 'no_einsum':
        out = mx.np.matmul(mx.np.swapaxes(mx_lhs, 1, 2),
                           mx.np.transpose(mx_rhs, (1, 2, 0)))
        out_np = out.asnumpy()
    elif args.mode == 'th_einsum':
        out = th.einsum(typ, th_lhs, th_rhs)
        out_np = out.cpu().numpy()
    else:
        raise NotImplementedError
    print(out_np.shape)
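
In addition to the nvprof numbers below, a rough wall-clock comparison can be made with explicit synchronization, since both frameworks launch GPU kernels asynchronously. This is a sketch and not part of the original script; time_mx and time_th are hypothetical helper names:

import time

def time_mx(fn, repeat=50):
    fn(); mx.npx.waitall()        # warm-up, then drain the async engine
    start = time.time()
    for _ in range(repeat):
        fn()
    mx.npx.waitall()              # wait for all queued GPU work to finish
    return (time.time() - start) / repeat

def time_th(fn, repeat=50):
    fn(); th.cuda.synchronize()   # warm-up, then sync the CUDA stream
    start = time.time()
    for _ in range(repeat):
        fn()
    th.cuda.synchronize()
    return (time.time() - start) / repeat

# Example usage for problem 0:
# print(time_mx(lambda: mx.np.einsum('bnid,bnjd->bnij', mx_lhs, mx_rhs)))
# print(time_th(lambda: th.einsum('bnid,bnjd->bnij', th_lhs, th_rhs)))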

We profiled three different usages of einsum:

  1. (B, K, T0, C) X (B, K, T1, C) --> (B, K, T0, T1)

    • MXNet einsum
      nvprof python profile_einsum.py --mode einsum --problem 0

      Time Kernel
      41.009ms _ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS0_12numpy_einsumILi5ELi1ELb0EdEEJPfNS_6common11StaticArrayIS5_Li16EEEN7mshadow5ShapeILi5EEENS7_ISB_Li16EEESB_SC_iiS5_EEEviDpT0_
    • MXNet implementation without einsum
      nvprof python profile_einsum.py --mode no_einsum --problem 0

      Time Kernel
      198.75us volta_sgemm_128x64_tn
    • PyTorch implementation
      nvprof python profile_einsum.py --mode th_einsum --problem 0

      Time Kernel
      192.35us volta_sgemm_128x64_tn
  2. (B, T0, K, C) X (B, T1, K, C) --> (B, K, T0, T1)

    • MXNet einsum
      nvprof python profile_einsum.py --mode einsum --problem 1

      Time Kernel
      40.665ms _ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS0_12numpy_einsumILi5ELi1ELb0EdEEJPfNS_6common11StaticArrayIS5_Li16EEEN7mshadow5ShapeILi5EEENS7_ISB_Li16EEESB_SC_iiS5_EEEviDpT0_
    • MXNet implementation without einsum
      nvprof python profile_einsum.py --mode no_einsum --problem 1

      Time Kernel
      185.76us volta_sgemm_128x64_tn
      89.519us void mshadow::cuda::MapPlanKernel<mshadow::sv::saveto, int=8, mshadow::expr::Plan<mshadow::Tensor<mshadow::gpu, int=5, float>, float>, mshadow::expr::Plan<mshadow::expr::SwapAxisExp<mshadow::Tensor<mshadow::gpu, int=5, float>, float, int=5, int=2, int=1>, float>>(mshadow::gpu, int, mshadow::Shape<int=2>, int=5)
    • PyTorch implementation
      nvprof python profile_einsum.py --mode th_einsum --problem 1

      Time Kernel
      193.02us volta_sgemm_128x64_tn
      61.967us _ZN2at6native18elementwise_kernelILi128ELi4EZNS0_15gpu_kernel_implIZZZNS0_21copy_device_to_deviceERNS_14TensorIteratorEbENKUlvE_clEvENKUlvE2_clEvEUlfE_EEvS4_RKT_EUliE2_EEviT1_
  3. (B, T0, K, C) X (T1, K, C) --> (B, K, T0, T1)

    • MXNet einsum
      nvprof python profile_einsum.py --mode einsum --problem 2

      Time Kernel
      40.551ms _ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS0_12numpy_einsumILi5ELi1ELb0EdEEJPfNS_6common11StaticArrayIS5_Li16EEEN7mshadow5ShapeILi5EEENS7_ISB_Li16EEESB_SC_iiS5_EEEviDpT0_
    • MXNet implementation without einsum
      nvprof python profile_einsum.py --mode no_einsum --problem 2

      Time Kernel
      322.33us _ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS0_16broadcast_kernelINS0_10mshadow_op8identityEEEJPfS7_N7mshadow5ShapeILi5EEESA_NS_9OpReqTypeEmEEEviDpT0_
      183.23us volta_sgemm_128x64_nn
      120.13us void mshadow::cuda::MapPlanKernel<mshadow::sv::saveto, int=8, mshadow::expr::Plan<mshadow::Tensor<mshadow::gpu, int=5, float>, float>, mshadow::expr::Plan<mshadow::expr::SwapAxisExp<mshadow::Tensor<mshadow::gpu, int=5, float>, float, int=5, int=2, int=1>, float>>(mshadow::gpu, int, mshadow::Shape<int=2>, int=5)
      5.3120us void mxnet::op::cuda::transpose_pseudo2D<float, unsigned long, bool=0>(float*, float, int, int, int, int)
    • PyTorch implementation
      nvprof python profile_einsum.py --mode th_einsum --problem 2

      Time Kernel
      152.16us volta_sgemm_128x64_tn
      28.704us _ZN2at6native18elementwise_kernelILi128ELi4EZNS0_15gpu_kernel_implIZZZNS0_21copy_device_to_deviceERNS_14TensorIteratorEbENKUlvE_clEvENKUlvE2_clEvEUlfE_EEvS4_RKT_EUliE2_EEviT1_
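
Since the no-einsum paths are two orders of magnitude faster, a short-term mitigation is to rewrite these attention-style contractions onto batched GEMM primitives, exactly as the no_einsum branches of the script do. Below is a rough Python-level sketch; fast_attention_einsum is a hypothetical helper, not an existing API. A longer-term fix would be for mx.np.einsum itself to recognize such patterns and dispatch to the GEMM kernels instead of the generic numpy_einsum kernel.

def fast_attention_einsum(eq, lhs, rhs):
    # Map the three profiled equations onto batched GEMM primitives,
    # mirroring the no_einsum branches of the profiling script.
    if eq == 'bnid,bnjd->bnij':
        return mx.npx.batch_dot(lhs, rhs, transpose_b=True)
    if eq == 'bind,bjnd->bnij':
        return mx.npx.batch_dot(mx.np.swapaxes(lhs, 1, 2),
                                mx.np.swapaxes(rhs, 1, 2),
                                transpose_b=True)
    if eq == 'bind,jnd->bnij':
        return mx.np.matmul(mx.np.swapaxes(lhs, 1, 2),
                            mx.np.transpose(rhs, (1, 2, 0)))
    # Anything else falls back to the generic implementation.
    return mx.np.einsum(eq, lhs, rhs)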

@yzhliu @hzfan @haojin2 @reminisce @szha
