# [Performance][Numpy] np.einsum can be 500 times slower than torch.einsum #18043

## Description
The performance of np.einsum on GPU is poor: it is usually around 500 times slower than torch.einsum. Because einsum is essential for implementing the attention mechanism used throughout NLP and CV, we should accelerate the implementation.
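For reference, this contraction is exactly what appears when computing attention scores. A minimal NumPy sketch (shapes are arbitrary, chosen only for illustration):

```python
import numpy as np

batch, heads, q_len, k_len, dim = 2, 4, 10, 12, 16
query = np.random.normal(size=(batch, heads, q_len, dim))
key = np.random.normal(size=(batch, heads, k_len, dim))

# Scaled dot-product attention scores: contract the feature axis d,
# keeping the batch (b), head (n), and position (i, j) axes.
scores = np.einsum('bnid,bnjd->bnij', query, key) / np.sqrt(dim)
print(scores.shape)  # (2, 4, 10, 12)
```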
Here is the code for profiling the different implementations of einsum (also available as a gist: https://gist.github.com/sxjscience/bfda1a8bd2942d93eef5ddf8a15b52b8). The profiling results show the following ordering, from fastest to slowest:

PyTorch einsum > MXNet no-einsum >> MXNet einsum
```python
import mxnet as mx
import numpy as np
import torch as th
import argparse

mx.npx.set_np()

parser = argparse.ArgumentParser(description='Profile einsum')
parser.add_argument('--mode', choices=['einsum', 'no_einsum', 'th_einsum'],
                    default='einsum', required=True)
parser.add_argument('--problem', type=int, choices=[0, 1, 2],
                    help='Problem type.', default=0, required=True)
args = parser.parse_args()

np.random.seed(100)
batch_size = 64
num_heads = 8
seq_length_A = 100
seq_length_B = 50
units = 128

if args.problem == 0:
    # (B, K, T0, C) X (B, K, T1, C) --> (B, K, T0, T1)
    lhs = np.random.normal(0, 1, (batch_size, num_heads, seq_length_A, units))
    rhs = np.random.normal(0, 1, (batch_size, num_heads, seq_length_B, units))
    mx_lhs = mx.np.array(lhs, dtype=np.float32, ctx=mx.gpu())
    mx_rhs = mx.np.array(rhs, dtype=np.float32, ctx=mx.gpu())
    mx.npx.waitall()
    th_lhs = th.from_numpy(lhs).float().cuda()
    th_rhs = th.from_numpy(rhs).float().cuda()
    typ = 'bnid,bnjd->bnij'
    if args.mode == 'einsum':
        out = mx.np.einsum(typ, mx_lhs, mx_rhs)
        out_np = out.asnumpy()
    elif args.mode == 'no_einsum':
        out = mx.npx.batch_dot(mx_lhs, mx_rhs, transpose_b=True)
        out_np = out.asnumpy()
    elif args.mode == 'th_einsum':
        out = th.einsum(typ, th_lhs, th_rhs)
        out_np = out.cpu().numpy()
    else:
        raise NotImplementedError
    print(out_np.shape)
elif args.problem == 1:
    # (B, T0, K, C) X (B, T1, K, C) --> (B, K, T0, T1)
    lhs = np.random.normal(0, 1, (batch_size, seq_length_A, num_heads, units))
    rhs = np.random.normal(0, 1, (batch_size, seq_length_B, num_heads, units))
    mx_lhs = mx.np.array(lhs, dtype=np.float32, ctx=mx.gpu())
    mx_rhs = mx.np.array(rhs, dtype=np.float32, ctx=mx.gpu())
    mx.npx.waitall()
    th_lhs = th.from_numpy(lhs).float().cuda()
    th_rhs = th.from_numpy(rhs).float().cuda()
    typ = 'bind,bjnd->bnij'
    if args.mode == 'einsum':
        out = mx.np.einsum(typ, mx_lhs, mx_rhs)
        out_np = out.asnumpy()
    elif args.mode == 'no_einsum':
        out = mx.npx.batch_dot(mx.np.swapaxes(mx_lhs, 1, 2),
                               mx.np.swapaxes(mx_rhs, 1, 2),
                               transpose_b=True)
        out_np = out.asnumpy()
    elif args.mode == 'th_einsum':
        out = th.einsum(typ, th_lhs, th_rhs)
        out_np = out.cpu().numpy()
    else:
        raise NotImplementedError
    print(out_np.shape)
elif args.problem == 2:
    # (B, T0, K, C) X (T1, K, C) --> (B, K, T0, T1)
    lhs = np.random.normal(0, 1, (batch_size, seq_length_A, num_heads, units))
    rhs = np.random.normal(0, 1, (seq_length_B, num_heads, units))
    mx_lhs = mx.np.array(lhs, dtype=np.float32, ctx=mx.gpu())
    mx_rhs = mx.np.array(rhs, dtype=np.float32, ctx=mx.gpu())
    mx.npx.waitall()
    th_lhs = th.from_numpy(lhs).float().cuda()
    th_rhs = th.from_numpy(rhs).float().cuda()
    typ = 'bind,jnd->bnij'
    if args.mode == 'einsum':
        out = mx.np.einsum(typ, mx_lhs, mx_rhs)
        out_np = out.asnumpy()
    elif args.mode == 'no_einsum':
        out = mx.np.matmul(mx.np.swapaxes(mx_lhs, 1, 2),
                           mx.np.transpose(mx_rhs, (1, 2, 0)))
        out_np = out.asnumpy()
    elif args.mode == 'th_einsum':
        out = th.einsum(typ, th_lhs, th_rhs)
        out_np = out.cpu().numpy()
    else:
        raise NotImplementedError
    print(out_np.shape)
```

We profiled three different usages of einsum:
**Problem 0: `(B, K, T0, C) X (B, K, T1, C) --> (B, K, T0, T1)`**

MXNet einsum (`nvprof python profile_einsum.py --mode einsum --problem 0`):

| Time | Kernel |
| --- | --- |
| 41.009 ms | `_ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS0_12numpy_einsumILi5ELi1ELb0EdEEJPfNS_6common11StaticArrayIS5_Li16EEEN7mshadow5ShapeILi5EEENS7_ISB_Li16EEESB_SC_iiS5_EEEviDpT0_` |

MXNet implementation without einsum (`nvprof python profile_einsum.py --mode no_einsum --problem 0`):

| Time | Kernel |
| --- | --- |
| 198.75 us | `volta_sgemm_128x64_tn` |

PyTorch einsum (`nvprof python profile_einsum.py --mode th_einsum --problem 0`):

| Time | Kernel |
| --- | --- |
| 192.35 us | `volta_sgemm_128x64_tn` |
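This pattern maps one-to-one onto a batched GEMM with the second operand transposed, which is why the no-einsum path runs a single `volta_sgemm` kernel. A quick NumPy check of the identity, with toy shapes for illustration:

```python
import numpy as np

lhs = np.random.normal(size=(2, 3, 5, 7))  # (B, K, T0, C)
rhs = np.random.normal(size=(2, 3, 4, 7))  # (B, K, T1, C)

ref = np.einsum('bnid,bnjd->bnij', lhs, rhs)
# The same contraction as a batched matmul over the trailing two axes.
alt = lhs @ rhs.swapaxes(-1, -2)
print(np.allclose(ref, alt))  # True
```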
**Problem 1: `(B, T0, K, C) X (B, T1, K, C) --> (B, K, T0, T1)`**

MXNet einsum (`nvprof python profile_einsum.py --mode einsum --problem 1`):

| Time | Kernel |
| --- | --- |
| 40.665 ms | `_ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS0_12numpy_einsumILi5ELi1ELb0EdEEJPfNS_6common11StaticArrayIS5_Li16EEEN7mshadow5ShapeILi5EEENS7_ISB_Li16EEESB_SC_iiS5_EEEviDpT0_` |

MXNet implementation without einsum (`nvprof python profile_einsum.py --mode no_einsum --problem 1`):

| Time | Kernel |
| --- | --- |
| 185.76 us | `volta_sgemm_128x64_tn` |
| 89.519 us | `void mshadow::cuda::MapPlanKernel<mshadow::sv::saveto, int=8, mshadow::expr::Plan<mshadow::Tensor<mshadow::gpu, int=5, float>, float>, mshadow::expr::Plan<mshadow::expr::SwapAxisExp<mshadow::Tensor<mshadow::gpu, int=5, float>, float, int=5, int=2, int=1>, float>>(mshadow::gpu, int, mshadow::Shape<int=2>, int=5)` |

PyTorch einsum (`nvprof python profile_einsum.py --mode th_einsum --problem 1`):

| Time | Kernel |
| --- | --- |
| 193.02 us | `volta_sgemm_128x64_tn` |
| 61.967 us | `_ZN2at6native18elementwise_kernelILi128ELi4EZNS0_15gpu_kernel_implIZZZNS0_21copy_device_to_deviceERNS_14TensorIteratorEbENKUlvE_clEvENKUlvE2_clEvEUlfE_EEvS4_RKT_EUliE2_EEviT1_` |
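Here both fast paths pay for an extra layout change before the GEMM: MXNet launches a `MapPlanKernel` for the `swapaxes` copy and PyTorch launches an elementwise copy kernel. The underlying identity, checked in NumPy with toy shapes:

```python
import numpy as np

lhs = np.random.normal(size=(2, 5, 3, 7))  # (B, T0, K, C)
rhs = np.random.normal(size=(2, 4, 3, 7))  # (B, T1, K, C)

ref = np.einsum('bind,bjnd->bnij', lhs, rhs)
# Move the head axis forward on both operands, then do a batched matmul;
# the explicit layout change is the extra copy kernel seen in the profiles.
alt = lhs.transpose(0, 2, 1, 3) @ rhs.transpose(0, 2, 3, 1)
print(np.allclose(ref, alt))  # True
```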
**Problem 2: `(B, T0, K, C) X (T1, K, C) --> (B, K, T0, T1)`**

MXNet einsum (`nvprof python profile_einsum.py --mode einsum --problem 2`):

| Time | Kernel |
| --- | --- |
| 40.551 ms | `_ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS0_12numpy_einsumILi5ELi1ELb0EdEEJPfNS_6common11StaticArrayIS5_Li16EEEN7mshadow5ShapeILi5EEENS7_ISB_Li16EEESB_SC_iiS5_EEEviDpT0_` |

MXNet implementation without einsum (`nvprof python profile_einsum.py --mode no_einsum --problem 2`):

| Time | Kernel |
| --- | --- |
| 322.33 us | `_ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS0_16broadcast_kernelINS0_10mshadow_op8identityEEEJPfS7_N7mshadow5ShapeILi5EEESA_NS_9OpReqTypeEmEEEviDpT0_` |
| 183.23 us | `volta_sgemm_128x64_nn` |
| 120.13 us | `void mshadow::cuda::MapPlanKernel<mshadow::sv::saveto, int=8, mshadow::expr::Plan<mshadow::Tensor<mshadow::gpu, int=5, float>, float>, mshadow::expr::Plan<mshadow::expr::SwapAxisExp<mshadow::Tensor<mshadow::gpu, int=5, float>, float, int=5, int=2, int=1>, float>>(mshadow::gpu, int, mshadow::Shape<int=2>, int=5)` |
| 5.3120 us | `void mxnet::op::cuda::transpose_pseudo2D<float, unsigned long, bool=0>(float*, float, int, int, int, int)` |

PyTorch einsum (`nvprof python profile_einsum.py --mode th_einsum --problem 2`):

| Time | Kernel |
| --- | --- |
| 152.16 us | `volta_sgemm_128x64_tn` |
| 28.704 us | `_ZN2at6native18elementwise_kernelILi128ELi4EZNS0_15gpu_kernel_implIZZZNS0_21copy_device_to_deviceERNS_14TensorIteratorEbENKUlvE_clEvENKUlvE2_clEvEUlfE_EEvS4_RKT_EUliE2_EEviT1_` |
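In this case the MXNet no-einsum path additionally launches a `broadcast_kernel`, i.e. it appears to materialize the batch broadcast of the 3D operand before the GEMM, whereas NumPy's `matmul` (and PyTorch's path) can broadcast without an explicit copy. The identity, with toy shapes:

```python
import numpy as np

lhs = np.random.normal(size=(2, 5, 3, 7))  # (B, T0, K, C)
rhs = np.random.normal(size=(4, 3, 7))     # (T1, K, C)

ref = np.einsum('bind,jnd->bnij', lhs, rhs)
# (B, K, T0, C) @ (K, C, T1): matmul broadcasts rhs across the batch axis.
alt = np.matmul(lhs.swapaxes(1, 2), rhs.transpose(1, 2, 0))
print(np.allclose(ref, alt))  # True
```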
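Until the generic einsum kernel is optimized, one possible workaround is to special-case these attention-style patterns and dispatch them to the BLAS-backed ops used above. A minimal sketch (the helper name `einsum_fast` and the dispatch-by-string approach are illustrative, not an existing MXNet API):

```python
import mxnet as mx
mx.npx.set_np()

def einsum_fast(typ, lhs, rhs):
    """Hypothetical helper: route common attention patterns to fast ops."""
    if typ == 'bnid,bnjd->bnij':
        # Batched GEMM with the second operand transposed.
        return mx.npx.batch_dot(lhs, rhs, transpose_b=True)
    if typ == 'bind,bjnd->bnij':
        # Move the head axis forward, then run a batched GEMM.
        return mx.npx.batch_dot(mx.np.swapaxes(lhs, 1, 2),
                                mx.np.swapaxes(rhs, 1, 2),
                                transpose_b=True)
    if typ == 'bind,jnd->bnij':
        # Broadcasted matmul against the 3D operand.
        return mx.np.matmul(mx.np.swapaxes(lhs, 1, 2),
                            mx.np.transpose(rhs, (1, 2, 0)))
    # Fall back to the generic kernel for everything else.
    return mx.np.einsum(typ, lhs, rhs)
```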