diff --git a/tests/python/llama_inf_tests/graph_0.py b/tests/python/llama_inf_tests/graph_0.py
new file mode 100644
index 00000000000..0bfc9db0d00
--- /dev/null
+++ b/tests/python/llama_inf_tests/graph_0.py
@@ -0,0 +1,162 @@
+"""Stand-alone nvFuser benchmark for fusion graph 0 of llama_inf_tests.
+
+Defines the fusion, runs warm-up executions, then times 100 calls to
+``fd.execute`` inside a cudaProfilerStart/Stop region so that the same loop
+can be captured with nsys (commands and recorded results are at the end).
+"""
+
+import time
+
+import torch
+from nvfuser import FusionDefinition, DataType
+
+
+def nvfuser_fusion_id0(fd: FusionDefinition) -> None:
+    """Populate ``fd`` with an embedding lookup plus two [1, 1, 6, 6] masks.
+
+    Inputs (in ``define_tensor`` order):
+        T0: Int [1, 6], indices for the embedding lookup.
+        T1: BFloat16 [128256, 2048], embedding table.
+        T2: Int [1, 6], added to the mask below (presumably a padding /
+            attention mask -- TODO confirm against the model).
+
+    Outputs (in ``add_output`` order):
+        T6: ``embedding_fwd(T0, T1, ...)``.
+        T66: BFloat16 [1, 1, 6, 6] mask holding the fill value -3.38953e+38
+            where the column index exceeds the row index and 0 elsewhere.
+        T86: T66 with -3.38953e+38 written wherever ``T66 + T2`` equals 0.
+    """
+    T0 = fd.define_tensor(shape=[1, 6], contiguity=[None, True], dtype=DataType.Int, is_cpu=False, stride_order=[1, 0])
+    T1 = fd.define_tensor(shape=[128256, 2048], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
+    T2 = fd.define_tensor(shape=[1, 6], contiguity=[None, True], dtype=DataType.Int, is_cpu=False, stride_order=[1, 0])
+    # Embedding lookup of T0 into T1.
+    S3 = fd.define_scalar(2.00000, dtype=DataType.Double)
+    S4 = fd.define_scalar(False, dtype=DataType.Bool)
+    S5 = fd.define_scalar(False, dtype=DataType.Bool)
+    T6 = fd.ops.embedding_fwd(T0, T1, None, None, S3, S4, S5)
+    # iota gives 0..5; T27 broadcasts it as the column index, T31 as the row index.
+    S7 = fd.define_scalar(6, dtype=DataType.Int)
+    S8 = fd.define_scalar(0, dtype=DataType.Int)
+    S9 = fd.define_scalar(1, dtype=DataType.Int)
+    T10 = fd.ops.iota(S7, S8, S9, dtype=DataType.Int)
+    T14 = fd.ops.broadcast_in_dim(T10, shape=[1, 6], broadcast_dims=[1])
+    S15 = fd.define_scalar(-3.38953e+38, dtype=DataType.Double)
+    T19 = fd.ops.full(shape=[6, 6], fill_value=S15, dtype=DataType.BFloat16)
+    T23 = fd.ops.broadcast_in_dim(T10, shape=[6, 1], broadcast_dims=[0])
+    T27 = fd.ops.broadcast_in_dim(T14, shape=[6, 6], broadcast_dims=[0, 1])
+    T31 = fd.ops.broadcast_in_dim(T23, shape=[6, 6], broadcast_dims=[0, 1])
+    T32 = fd.ops.sub(T27, T31)
+    S33 = fd.define_scalar(1, dtype=DataType.Int)
+    T34 = fd.ops.ge(T32, S33)
+    S35 = fd.define_scalar(0.00000, dtype=DataType.Double)
+    T36 = fd.ops.where(T34, T19, S35)
+    # Second "column index > row index" comparison, multiplied into the mask.
+    T40 = fd.ops.reshape(T10, new_shape=[6, 1])
+    T44 = fd.ops.broadcast_in_dim(T10, shape=[6, 6], broadcast_dims=[1])
+    T48 = fd.ops.broadcast_in_dim(T40, shape=[6, 6], broadcast_dims=[0, 1])
+    T49 = fd.ops.gt(T44, T48)
+    T50 = fd.ops.cast(T36, dtype=DataType.Float)
+    T51 = fd.ops.cast(T49, dtype=DataType.Float)
+    T52 = fd.ops.mul(T50, T51)
+    T53 = fd.ops.cast(T52, dtype=DataType.BFloat16)
+    T59 = fd.ops.broadcast_in_dim(T53, shape=[1, 1, 6, 6], broadcast_dims=[2, 3])
+    T65 = fd.ops.broadcast_in_dim(T59, shape=[1, 1, 6, 6], broadcast_dims=[0, 1, 2, 3])
+    T66 = fd.ops.set(T65)
+    # Add T2 (broadcast over rows) and refill the positions where the sum is 0.
+    T72 = fd.ops.broadcast_in_dim(T2, shape=[1, 1, 1, 6], broadcast_dims=[0, 3])
+    T78 = fd.ops.broadcast_in_dim(T72, shape=[1, 1, 6, 6], broadcast_dims=[0, 1, 2, 3])
+    T79 = fd.ops.cast(T66, dtype=DataType.Float)
+    T80 = fd.ops.cast(T78, dtype=DataType.Float)
+    T81 = fd.ops.add(T79, T80)
+    T82 = fd.ops.cast(T81, dtype=DataType.BFloat16)
+    S83 = fd.define_scalar(0.00000, dtype=DataType.Double)
+    T84 = fd.ops.eq(T82, S83)
+    S85 = fd.define_scalar(-3.38953e+38, dtype=DataType.Double)
+    T86 = fd.ops.where(T84, S85, T66)
+    fd.add_output(T6)
+    fd.add_output(T66)
+    fd.add_output(T86)
+
+
+with FusionDefinition() as fd:
+    nvfuser_fusion_id0(fd)
+
+inputs = [
+    torch.ones((1, 6), dtype=torch.int64, device='cuda:0'),
+    torch.testing.make_tensor((128256, 2048), dtype=torch.bfloat16, device='cuda:0'),
+    torch.ones((1, 6), dtype=torch.int64, device='cuda:0'),
+]
+
+# Warm-up: one initial call plus three more, all outside the timed region.
+fd.execute(inputs)
+
+for _ in range(3):
+    fd.execute(inputs)
+
+torch.cuda.synchronize()
+# perf_counter(): monotonic, high-resolution clock meant for measuring intervals.
+start = time.perf_counter()
+# Mark the profiling region
+torch.cuda.cudart().cudaProfilerStart()
+
+for _ in range(100):
+    fd.execute(inputs)
+
+torch.cuda.cudart().cudaProfilerStop()
+torch.cuda.synchronize()
+end = time.perf_counter()
+
+print((end-start)*1000, " ms")
+
+# Recorded totals for the 100 timed iterations (output of the print above):
+# Before:
+# 12.0 ms
+# After:
+# 3.1 ms
+
+# rm report*
+# nsys profile -c cudaProfilerApi python tests/python/llama_inf_tests/graph_0.py
+# nsys stats report1.nsys-rep
+
+# Before:
+# Time (%) Total Time (ns) Instances Avg (ns) Med (ns) Min (ns) Max (ns) StdDev (ns) Style Range
+# -------- --------------- --------- -------- -------- -------- -------- ----------- ------- ----------------------------------------------
+# 13.8 10011392 100 100113.9 80400.0 76319 768432 82944.7 PushPop :FusionExecutorCache::runFusionWithInputs
+# 13.0 9409367 100 94093.7 77940.0 74188 765647 79435.0 PushPop :FusionKernelRuntime::runWithInputs
+# 12.9 9353511 100 93535.1 77347.0 73635 764599 79335.4 PushPop :FusionKernelRuntime::runSegmentsWithInputs
+# 12.4 8989375 300 29964.6 26494.0 12397 698157 44537.4 PushPop :FusionKernelRuntime::runKernelWithInput
+# 12.3 8896373 300 29654.6 26056.5 12135 697796 44508.6 PushPop :ExecutorDispatch::run2
+# 10.1 7309840 200 36549.2 31871.0 24321 697376 51775.8 PushPop :KernelExecutor::runFusion
+# 6.7 4859672 200 24298.4 22950.5 13246 684391 48960.1 PushPop :KernelExecutor::runFusion::execute_kernel
+# 5.9 4316308 1200 3596.9 2396.0 1980 175457 7203.4 PushPop :ExpressionEvaluator::evaluate
+# 5.6 4086264 200 20431.3 19635.5 10005 674394 48476.2 PushPop :KernelExecutor::recomputeArgs
+# 2.0 1455689 100 14556.9 12596.0 11930 176349 16430.4 PushPop :ExprEvalExecutor::run
+# 1.9 1362206 200 6811.0 7320.5 3864 174236 12085.8 PushPop :fusion_executor::allocations::allocateOutputs
+# 1.4 997365 600 1662.3 1368.0 1205 167712 6793.6 PushPop :fusion_executor::allocations::allocateTensor
+# 1.0 690717 200 3453.6 3288.0 2816 10209 890.9 PushPop :ExecutorRunFusion::cuLaunchKernel
+# 0.4 294173 300 980.6 831.0 107 9627 925.9 PushPop :executor_utils::bindInputs
+# 0.3 228065 300 760.2 152.5 122 165328 9538.9 PushPop :ExecutorDispatch::isCompiled
+# 0.3 192836 200 964.2 108.0 99 164903 11650.7 PushPop :KernelExecutor::runFusion::intermediates
+# 0.1 75680 100 756.8 712.0 459 4308 388.0 PushPop :FusionExecutorCache::setCacheId
+# 0.0 17393 100 173.9 133.0 112 875 100.7 PushPop :FusionExecutorCache::getKernelRuntimeFor
+
+# After:
+# Time (%) Total Time (ns) Instances Avg (ns) Med (ns) Min (ns) Max (ns) StdDev (ns) Style Range
+# -------- --------------- --------- -------- -------- -------- -------- ----------- ------- ----------------------------------------------
+# 17.1 5182038 100 51820.4 40488.5 38316 309012 45433.7 PushPop :FusionExecutorCache::runFusionWithInputs
+# 15.5 4712599 100 47126.0 38026.5 36027 293111 40961.2 PushPop :FusionKernelRuntime::runWithInputs
+# 15.3 4653120 100 46531.2 37485.5 35536 290957 40853.9 PushPop :FusionKernelRuntime::runSegmentsWithInputs
+# 13.5 4099585 300 13665.3 11896.0 8647 197301 18602.1 PushPop :FusionKernelRuntime::runKernelWithInput
+# 12.5 3810167 300 12700.6 11606.0 8426 196668 15305.2 PushPop :ExecutorDispatch::run2
+# 7.0 2114207 200 10571.0 10721.5 8123 45738 3065.6 PushPop :KernelExecutor::runFusion
+# 5.3 1601371 100 16013.7 11992.5 11374 196441 25772.1 PushPop :ExprEvalExecutor::run
+# 4.6 1406004 100 14060.0 10303.5 9756 194564 24892.8 PushPop :ExpressionEvaluator::evaluate
+# 2.6 803300 200 4016.5 4956.5 2605 16355 1490.5 PushPop :fusion_executor::allocations::allocateOutputs
+# 2.4 722700 200 3613.5 3430.5 2910 15725 1130.3 PushPop :KernelExecutor::runFusion::execute_kernel
+# 2.2 666998 200 3335.0 3195.5 2708 14553 1040.9 PushPop :ExecutorRunFusion::cuLaunchKernel
+# 0.8 257659 100 2576.6 707.5 483 182114 18138.7 PushPop :FusionExecutorCache::setCacheId
+# 0.4 116565 200 582.8 602.0 420 2139 184.6 PushPop :KernelExecutor::computeArgs2
+# 0.3 96582 100 965.8 835.5 772 9563 880.2 PushPop :executor_utils::bindInputs
+# 0.2 66419 300 221.4 157.0 127 2363 221.8 PushPop :ExecutorDispatch::isCompiled
+# 0.1 29389 100 293.9 132.5 111 10088 1006.3 PushPop :FusionExecutorCache::getKernelRuntimeFor
+# 0.1 29125 200 145.6 108.0 97 798 107.9 PushPop :KernelExecutor::runFusion::intermediates
diff --git a/tests/python/llama_inf_tests/graph_1.py b/tests/python/llama_inf_tests/graph_1.py
new file mode 100644
index 00000000000..4581586db20
--- /dev/null
+++ b/tests/python/llama_inf_tests/graph_1.py
@@ -0,0 +1,248 @@
+"""Stand-alone nvFuser benchmark for fusion graph 1 of llama_inf_tests.
+
+Defines the fusion, runs warm-up executions, then times 100 calls to
+``fd.execute`` inside a cudaProfilerStart/Stop region so that the same loop
+can be captured with nsys (commands and recorded results are at the end).
+"""
+
+import time
+
+import torch
+from nvfuser import FusionDefinition, DataType
+
+
+def nvfuser_fusion_id1(fd: FusionDefinition) -> None:
+    """Populate ``fd`` with an RMS norm, three linear projections and a rotation.
+
+    Inputs (in ``define_tensor`` order):
+        T0: BFloat16 [1, 6, 2048], activations that get normalized.
+        T1: BFloat16 [2048], scale applied after the normalization.
+        T2: BFloat16 [32], multiplied with the positions 0..5; sin/cos of the
+            product drive the rotation (presumably rotary inverse
+            frequencies -- TODO confirm).
+        T3: BFloat16 [512, 2048], ``linear`` weight; result permuted to
+            [1, 8, 6, 64].
+        T4: BFloat16 [2048, 2048], ``linear`` weight; result permuted to
+            [1, 32, 6, 64].
+        T5: BFloat16 [1, 1, 6, 6], mask compared against -3.38953e+38.
+        T6: BFloat16 [512, 2048], ``linear`` weight (same layout as T3's result).
+
+    Outputs (in ``add_output`` order):
+        T287: T5 times a per-row 0/1 flag that is 1 when the row holds at least
+            one entry different from -3.38953e+38.
+        T230: the T6 projection, [1, 8, 6, 64].
+        T231: the rotated T3 projection, [1, 8, 6, 64].
+        T286: the rotated T4 projection, [1, 32, 6, 64].
+        T285: T231 with each dim-1 slice repeated 4 times, [1, 32, 6, 64].
+        T284: T230 with each dim-1 slice repeated 4 times, [1, 32, 6, 64].
+    """
+    T0 = fd.define_tensor(shape=[1, 6, 2048], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
+    T1 = fd.define_tensor(shape=[2048], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
+    T2 = fd.define_tensor(shape=[32], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
+    T3 = fd.define_tensor(shape=[512, 2048], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
+    T4 = fd.define_tensor(shape=[2048, 2048], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
+    T5 = fd.define_tensor(shape=[1, 1, 6, 6], contiguity=[True, None, None, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 3, 0])
+    T6 = fd.define_tensor(shape=[512, 2048], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
+    # RMS normalization of T0 over the last dim (eps 1e-05), computed in Float.
+    T7 = fd.ops.cast(T0, dtype=DataType.Float)
+    S8 = fd.define_scalar(2.00000, dtype=DataType.Double)
+    T9 = fd.ops.pow(T7, S8)
+    T10 = fd.ops.sum(T9, dims=[2], keepdim=False, dtype=DataType.Null)
+    T15 = fd.ops.broadcast_in_dim(T10, shape=[1, 6, 1], broadcast_dims=[0, 1])
+    S16 = fd.define_scalar(2048.00, dtype=DataType.Double)
+    S17 = fd.ops.reciprocal(S16)
+    T18 = fd.ops.mul(T15, S17)
+    S19 = fd.define_scalar(1.00000e-05, dtype=DataType.Double)
+    T20 = fd.ops.add(T18, S19)
+    T21 = fd.ops.rsqrt(T20)
+    T26 = fd.ops.broadcast_in_dim(T21, shape=[1, 6, 2048], broadcast_dims=[0, 1, 2])
+    T27 = fd.ops.mul(T7, T26)
+    # The generated ops interleave from here on: the T1 scale, the position
+    # angles (iota times T2 -> sin/cos) and the T3/T4/T6 projections.
+    S28 = fd.define_scalar(6, dtype=DataType.Int)
+    S29 = fd.define_scalar(0, dtype=DataType.Int)
+    S30 = fd.define_scalar(1, dtype=DataType.Int)
+    T31 = fd.ops.iota(S28, S29, S30, dtype=DataType.Int)
+    T36 = fd.ops.broadcast_in_dim(T1, shape=[1, 6, 2048], broadcast_dims=[2])
+    T40 = fd.ops.broadcast_in_dim(T31, shape=[1, 6], broadcast_dims=[1])
+    T45 = fd.ops.broadcast_in_dim(T2, shape=[1, 32, 1], broadcast_dims=[1])
+    T46 = fd.ops.cast(T36, dtype=DataType.Float)
+    T51 = fd.ops.broadcast_in_dim(T40, shape=[1, 1, 6], broadcast_dims=[0, 2])
+    T52 = fd.ops.cast(T45, dtype=DataType.Float)
+    T53 = fd.ops.mul(T46, T27)
+    T54 = fd.ops.cast(T51, dtype=DataType.Float)
+    T59 = fd.ops.broadcast_in_dim(T52, shape=[1, 32, 1], broadcast_dims=[0, 1, 2])
+    T60 = fd.ops.cast(T53, dtype=DataType.BFloat16)
+    T61 = fd.ops.matmul(T59, T54)
+    T62 = fd.ops.linear(T60, T3)
+    T63 = fd.ops.permute(T61, dims=[0, 2, 1])
+    T69 = fd.ops.reshape(T62, new_shape=[1, 6, 8, 64])
+    T70 = fd.ops.cat([T63, T63], dim=-1, manual_padding=0)
+    T71 = fd.ops.permute(T69, dims=[0, 2, 1, 3])
+    T72 = fd.ops.sin(T70)
+    T88 = fd.ops.slice(T71, start_indices=[0, 0, 0, 32], end_indices=[1, 8, 6, 64], strides=[1, 1, 1, 1], manual_normalization=0)
+    T89 = fd.ops.cos(T70)
+    T90 = fd.ops.linear(T60, T4)
+    S91 = fd.define_scalar(1.00000, dtype=DataType.Double)
+    T92 = fd.ops.mul(T72, S91)
+    T93 = fd.ops.cast(T88, dtype=DataType.Float)
+    S94 = fd.define_scalar(1.00000, dtype=DataType.Double)
+    T95 = fd.ops.mul(T89, S94)
+    T101 = fd.ops.reshape(T90, new_shape=[1, 6, 32, 64])
+    T102 = fd.ops.cast(T92, dtype=DataType.BFloat16)
+    T103 = fd.ops.neg(T93)
+    T104 = fd.ops.cast(T95, dtype=DataType.BFloat16)
+    T105 = fd.ops.permute(T101, dims=[0, 2, 1, 3])
+    T111 = fd.ops.broadcast_in_dim(T102, shape=[1, 1, 6, 64], broadcast_dims=[0, 2, 3])
+    T127 = fd.ops.slice(T71, start_indices=[0, 0, 0, 0], end_indices=[1, 8, 6, 32], strides=[1, 1, 1, 1], manual_normalization=0)
+    T128 = fd.ops.cast(T103, dtype=DataType.BFloat16)
+    T134 = fd.ops.broadcast_in_dim(T104, shape=[1, 1, 6, 64], broadcast_dims=[0, 2, 3])
+    T150 = fd.ops.slice(T105, start_indices=[0, 0, 0, 32], end_indices=[1, 32, 6, 64], strides=[1, 1, 1, 1], manual_normalization=0)
+    # T152 starts the mask path (entries of T5 equal to the fill value); it is
+    # interleaved with the rotation of the T3 and T4 projections.
+    S151 = fd.define_scalar(-3.38953e+38, dtype=DataType.Double)
+    T152 = fd.ops.eq(T5, S151)
+    T158 = fd.ops.broadcast_in_dim(T111, shape=[1, 8, 6, 64], broadcast_dims=[0, 1, 2, 3])
+    T159 = fd.ops.cat([T128, T127], dim=-1, manual_padding=0)
+    T165 = fd.ops.broadcast_in_dim(T134, shape=[1, 8, 6, 64], broadcast_dims=[0, 1, 2, 3])
+    T166 = fd.ops.cast(T150, dtype=DataType.Float)
+    T167 = fd.ops.bitwise_not(T152)
+    T168 = fd.ops.cast(T158, dtype=DataType.Float)
+    T169 = fd.ops.cast(T159, dtype=DataType.Float)
+    T170 = fd.ops.cast(T165, dtype=DataType.Float)
+    T171 = fd.ops.cast(T71, dtype=DataType.Float)
+    T172 = fd.ops.neg(T166)
+    T173 = fd.ops.cast(T167, dtype=DataType.Int)
+    T174 = fd.ops.mul(T169, T168)
+    T175 = fd.ops.mul(T171, T170)
+    T191 = fd.ops.slice(T105, start_indices=[0, 0, 0, 0], end_indices=[1, 32, 6, 32], strides=[1, 1, 1, 1], manual_normalization=0)
+    T192 = fd.ops.cast(T172, dtype=DataType.BFloat16)
+    T193 = fd.ops.sum(T173, dims=[3], keepdim=False, dtype=DataType.Null)
+    T199 = fd.ops.broadcast_in_dim(T111, shape=[1, 32, 6, 64], broadcast_dims=[0, 1, 2, 3])
+    T200 = fd.ops.cat([T192, T191], dim=-1, manual_padding=0)
+    T206 = fd.ops.broadcast_in_dim(T134, shape=[1, 32, 6, 64], broadcast_dims=[0, 1, 2, 3])
+    T212 = fd.ops.broadcast_in_dim(T193, shape=[1, 1, 6, 1], broadcast_dims=[0, 1, 2])
+    T213 = fd.ops.linear(T60, T6)
+    T214 = fd.ops.cast(T199, dtype=DataType.Float)
+    T215 = fd.ops.cast(T200, dtype=DataType.Float)
+    T216 = fd.ops.cast(T206, dtype=DataType.Float)
+    T217 = fd.ops.cast(T105, dtype=DataType.Float)
+    S218 = fd.define_scalar(0, dtype=DataType.Int)
+    T219 = fd.ops.ne(T212, S218)
+    T225 = fd.ops.reshape(T213, new_shape=[1, 6, 8, 64])
+    T226 = fd.ops.add(T175, T174)
+    T227 = fd.ops.mul(T215, T214)
+    T228 = fd.ops.mul(T217, T216)
+    T229 = fd.ops.bitwise_not(T219)
+    T230 = fd.ops.permute(T225, dims=[0, 2, 1, 3])
+    T231 = fd.ops.cast(T226, dtype=DataType.BFloat16)
+    T232 = fd.ops.bitwise_not(T229)
+    T239 = fd.ops.broadcast_in_dim(T230, shape=[1, 8, 1, 6, 64], broadcast_dims=[0, 1, 3, 4])
+    T246 = fd.ops.broadcast_in_dim(T231, shape=[1, 8, 1, 6, 64], broadcast_dims=[0, 1, 3, 4])
+    T252 = fd.ops.broadcast_in_dim(T232, shape=[1, 1, 6, 6], broadcast_dims=[0, 1, 2, 3])
+    T259 = fd.ops.broadcast_in_dim(T239, shape=[1, 8, 4, 6, 64], broadcast_dims=[0, 1, 2, 3, 4])
+    T266 = fd.ops.broadcast_in_dim(T246, shape=[1, 8, 4, 6, 64], broadcast_dims=[0, 1, 2, 3, 4])
+    T267 = fd.ops.add(T228, T227)
+    T268 = fd.ops.cast(T252, dtype=DataType.Float)
+    T269 = fd.ops.cast(T5, dtype=DataType.Float)
+    T275 = fd.ops.reshape(T259, new_shape=[1, 32, 6, 64])
+    T281 = fd.ops.reshape(T266, new_shape=[1, 32, 6, 64])
+    T282 = fd.ops.cast(T267, dtype=DataType.BFloat16)
+    T283 = fd.ops.mul(T269, T268)
+    T284 = fd.ops.stride_order(T275, stride_order=[3, 2, 1, 0])
+    T285 = fd.ops.stride_order(T281, stride_order=[3, 2, 1, 0])
+    T286 = fd.ops.stride_order(T282, stride_order=[3, 2, 1, 0])
+    T287 = fd.ops.cast(T283, dtype=DataType.BFloat16)
+    fd.add_output(T287)
+    fd.add_output(T230)
+    fd.add_output(T231)
+    fd.add_output(T286)
+    fd.add_output(T285)
+    fd.add_output(T284)
+
+
+with FusionDefinition() as fd:
+    nvfuser_fusion_id1(fd)
+
+inputs = [
+    torch.testing.make_tensor((1, 6, 2048), dtype=torch.bfloat16, device='cuda:0'),
+    torch.testing.make_tensor((2048,), dtype=torch.bfloat16, device='cuda:0'),
+    torch.testing.make_tensor((32,), dtype=torch.bfloat16, device='cuda:0'),
+    torch.testing.make_tensor((512, 2048), dtype=torch.bfloat16, device='cuda:0'),
+    torch.testing.make_tensor((2048, 2048), dtype=torch.bfloat16, device='cuda:0'),
+    torch.testing.make_tensor((1, 1, 6, 6), dtype=torch.bfloat16, device='cuda:0'),
+    torch.testing.make_tensor((512, 2048), dtype=torch.bfloat16, device='cuda:0'),
+]
+
+# Warm-up: one initial call plus three more, all outside the timed region.
+fd.execute(inputs)
+
+for _ in range(3):
+    fd.execute(inputs)
+
+torch.cuda.synchronize()
+# perf_counter(): monotonic, high-resolution clock meant for measuring intervals.
+start = time.perf_counter()
+# Mark the profiling region
+torch.cuda.cudart().cudaProfilerStart()
+
+for _ in range(100):
+    fd.execute(inputs)
+
+torch.cuda.cudart().cudaProfilerStop()
+torch.cuda.synchronize()
+end = time.perf_counter()
+
+print((end-start)*1000, " ms")
+
+# Recorded totals for the 100 timed iterations (output of the print above):
+# Before:
+# 19.8 ms
+# After:
+# 10.6 ms
+
+# rm report*
+# nsys profile -c cudaProfilerApi python tests/python/llama_inf_tests/graph_1.py
+# nsys stats report1.nsys-rep
+
+# Before:
+# Time (%) Total Time (ns) Instances Avg (ns) Med (ns) Min (ns) Max (ns) StdDev (ns) Style Range
+# -------- --------------- --------- -------- -------- -------- -------- ----------- ------- ----------------------------------------------
+# 14.2 31791843 100 317918.4 268246.5 246507 762170 88138.8 PushPop :FusionExecutorCache::runFusionWithInputs
+# 13.6 30602349 100 306023.5 261735.5 239889 737741 82786.1 PushPop :FusionKernelRuntime::runWithInputs
+# 13.6 30480294 100 304802.9 260895.0 239116 735007 82461.4 PushPop :FusionKernelRuntime::runSegmentsWithInputs
+# 13.0 29106369 1300 22389.5 18414.5 1815 266605 22601.1 PushPop :FusionKernelRuntime::runKernelWithInput
+# 12.5 28090755 1300 21608.3 17832.0 1556 265963 22146.4 PushPop :ExecutorDispatch::run2
+# 8.1 18224542 500 36449.1 32308.0 16901 265512 24691.3 PushPop :KernelExecutor::runFusion
+# 7.2 16182053 4100 3946.8 3299.5 258 152312 4334.0 PushPop :ExpressionEvaluator::evaluate
+# 5.3 11797199 500 23594.4 18368.5 10877 209862 17543.9 PushPop :KernelExecutor::runFusion::execute_kernel
+# 4.1 9212380 500 18424.8 12260.5 7177 200666 15103.3 PushPop :KernelExecutor::recomputeArgs
+# 3.9 8816273 800 11020.3 11422.5 1369 159960 9005.4 PushPop :ExprEvalExecutor::run
+# 1.5 3394339 500 6788.7 3757.5 1977 196369 9934.9 PushPop :fusion_executor::allocations::allocateOutputs
+# 1.2 2593352 900 2881.5 1502.5 1213 186775 6863.1 PushPop :fusion_executor::allocations::allocateTensor
+# 1.0 2179586 500 4359.2 3768.0 2935 185683 8309.3 PushPop :ExecutorRunFusion::cuLaunchKernel
+# 0.6 1350809 1300 1039.1 838.0 420 6663 603.7 PushPop :executor_utils::bindInputs
+# 0.2 498945 1300 383.8 172.0 102 2870 400.6 PushPop :ExecutorDispatch::isCompiled
+# 0.1 159524 500 319.0 120.0 97 6098 431.0 PushPop :KernelExecutor::runFusion::intermediates
+# 0.1 145263 100 1452.6 1330.0 907 5442 594.3 PushPop :FusionExecutorCache::setCacheId
+# 0.0 37566 100 375.7 182.0 130 1791 350.7 PushPop :FusionExecutorCache::getKernelRuntimeFor
+
+# After:
+# Time (%) Total Time (ns) Instances Avg (ns) Med (ns) Min (ns) Max (ns) StdDev (ns) Style Range
+# -------- --------------- --------- -------- -------- -------- -------- ----------- ------- ----------------------------------------------
+# 16.1 19038957 100 190389.6 172709.0 145764 682753 77360.1 PushPop :FusionExecutorCache::runFusionWithInputs
+# 15.2 18050839 100 180508.4 164976.0 140671 653055 72317.9 PushPop :FusionKernelRuntime::runWithInputs
+# 15.2 17951998 100 179520.0 164040.5 139910 650172 72127.0 PushPop :FusionKernelRuntime::runSegmentsWithInputs
+# 13.8 16310510 1300 12546.5 11597.5 1791 229198 15422.3 PushPop :FusionKernelRuntime::runKernelWithInput
+# 13.2 15617300 1300 12013.3 11062.5 1542 228905 14374.1 PushPop :ExecutorDispatch::run2
+# 7.4 8714870 1900 4586.8 2470.0 255 214450 9636.4 PushPop :ExpressionEvaluator::evaluate
+# 7.2 8564673 800 10705.8 10866.5 1348 200141 10054.6 PushPop :ExprEvalExecutor::run
+# 5.5 6549775 500 13099.5 10053.5 6840 228580 19230.8 PushPop :KernelExecutor::runFusion
+# 1.9 2290425 500 4580.9 3967.5 3054 200810 9071.5 PushPop :KernelExecutor::runFusion::execute_kernel
+# 1.8 2092522 500 4185.0 3594.0 2841 200306 9052.8 PushPop :ExecutorRunFusion::cuLaunchKernel
+# 1.3 1562864 500 3125.7 1808.5 1417 205151 9320.5 PushPop :fusion_executor::allocations::allocateOutputs
+# 0.6 717665 900 797.4 669.0 433 11922 529.8 PushPop :executor_utils::bindInputs
+# 0.4 423085 500 846.2 644.0 258 12710 764.5 PushPop :KernelExecutor::computeArgs2
+# 0.2 264123 1300 203.2 153.0 103 2587 140.1 PushPop :ExecutorDispatch::isCompiled
+# 0.1 131127 100 1311.3 1206.5 829 6301 574.0 PushPop :FusionExecutorCache::setCacheId
+# 0.1 101062 500 202.1 122.0 98 2736 238.3 PushPop :KernelExecutor::runFusion::intermediates
+# 0.0 20536 100 205.4 158.0 115 1009 135.9 PushPop :FusionExecutorCache::getKernelRuntimeFor
diff --git a/tests/python/llama_inf_tests/graph_2.py b/tests/python/llama_inf_tests/graph_2.py
new file mode 100644
index 00000000000..0ea9137d433
--- /dev/null
+++ b/tests/python/llama_inf_tests/graph_2.py
@@ -0,0 +1,191 @@
+"""Stand-alone nvFuser benchmark for fusion graph 2 of llama_inf_tests.
+
+Defines the fusion, runs warm-up executions, then times 100 calls to
+``fd.execute`` inside a cudaProfilerStart/Stop region so that the same loop
+can be captured with nsys (commands and recorded results are at the end).
+"""
+
+import time
+
+import torch
+from nvfuser import FusionDefinition, DataType
+
+
+def nvfuser_fusion_id2(fd: FusionDefinition) -> None:
+    """Populate ``fd`` with projection + residual, gated MLP and final projection.
+
+    Inputs (in ``define_tensor`` order):
+        T0: BFloat16 [1, 32, 6, 64], permuted and reshaped to [1, 6, 2048].
+        T1: BFloat16 [2048, 2048], weight of the ``linear`` applied to T0.
+        T2: BFloat16 [1, 6, 2048], added to that projection (giving T20).
+        T3: BFloat16 [2048], scale of the first RMS normalization.
+        T4, T5: BFloat16 [8192, 2048], weights of the two ``linear`` ops whose
+            results x and y are combined as ``x * sigmoid(x) * y``.
+        T6: BFloat16 [2048, 8192], weight of the ``linear`` applied to that product.
+        T7: BFloat16 [2048], scale of the second RMS normalization.
+        T8: BFloat16 [128256, 2048], weight of the final ``linear``.
+
+    Output:
+        T92: ``linear`` of the second normalized tensor with T8.
+    """
+    T0 = fd.define_tensor(shape=[1, 32, 6, 64], contiguity=[None, True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[3, 1, 2, 0])
+    T1 = fd.define_tensor(shape=[2048, 2048], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
+    T2 = fd.define_tensor(shape=[1, 6, 2048], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
+    T3 = fd.define_tensor(shape=[2048], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
+    T4 = fd.define_tensor(shape=[8192, 2048], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
+    T5 = fd.define_tensor(shape=[8192, 2048], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
+    T6 = fd.define_tensor(shape=[2048, 8192], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
+    T7 = fd.define_tensor(shape=[2048], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
+    T8 = fd.define_tensor(shape=[128256, 2048], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
+    # Permute/reshape T0 to [1, 6, 2048], project with T1, add T2 (in Float).
+    T9 = fd.ops.permute(T0, dims=[0, 2, 1, 3])
+    T10 = fd.ops.stride_order(T9, stride_order=[3, 2, 1, 0])
+    T15 = fd.ops.reshape(T10, new_shape=[1, 6, 2048])
+    T16 = fd.ops.stride_order(T15, stride_order=[2, 1, 0])
+    T17 = fd.ops.linear(T16, T1)
+    T18 = fd.ops.cast(T2, dtype=DataType.Float)
+    T19 = fd.ops.cast(T17, dtype=DataType.Float)
+    T20 = fd.ops.add(T18, T19)
+    # First RMS normalization (of T20, last dim, eps 1e-05), scaled by T3.
+    S21 = fd.define_scalar(2.00000, dtype=DataType.Double)
+    T22 = fd.ops.pow(T20, S21)
+    T23 = fd.ops.sum(T22, dims=[2], keepdim=False, dtype=DataType.Null)
+    T28 = fd.ops.broadcast_in_dim(T23, shape=[1, 6, 1], broadcast_dims=[0, 1])
+    S29 = fd.define_scalar(2048.00, dtype=DataType.Double)
+    S30 = fd.ops.reciprocal(S29)
+    T31 = fd.ops.mul(T28, S30)
+    S32 = fd.define_scalar(1.00000e-05, dtype=DataType.Double)
+    T33 = fd.ops.add(T31, S32)
+    T34 = fd.ops.rsqrt(T33)
+    T39 = fd.ops.broadcast_in_dim(T34, shape=[1, 6, 2048], broadcast_dims=[0, 1, 2])
+    T40 = fd.ops.mul(T20, T39)
+    T45 = fd.ops.broadcast_in_dim(T3, shape=[1, 6, 2048], broadcast_dims=[2])
+    T46 = fd.ops.cast(T45, dtype=DataType.Float)
+    T47 = fd.ops.mul(T46, T40)
+    T48 = fd.ops.cast(T47, dtype=DataType.BFloat16)
+    # x * sigmoid(x) on the T4 projection, gated by the T5 projection.
+    T49 = fd.ops.linear(T48, T4)
+    T50 = fd.ops.cast(T49, dtype=DataType.Float)
+    T51 = fd.ops.neg(T50)
+    T52 = fd.ops.exp(T51)
+    S53 = fd.define_scalar(1.00000, dtype=DataType.Double)
+    T54 = fd.ops.add(S53, T52)
+    T55 = fd.ops.reciprocal(T54)
+    T56 = fd.ops.mul(T50, T55)
+    T57 = fd.ops.linear(T48, T5)
+    T58 = fd.ops.cast(T57, dtype=DataType.Float)
+    T59 = fd.ops.mul(T56, T58)
+    T60 = fd.ops.cast(T59, dtype=DataType.BFloat16)
+    # Project with T6 and add the residual T20.
+    T61 = fd.ops.linear(T60, T6)
+    T62 = fd.ops.cast(T61, dtype=DataType.Float)
+    T63 = fd.ops.add(T20, T62)
+    # Second RMS normalization (of T63), scaled by T7, then the final linear.
+    S64 = fd.define_scalar(2.00000, dtype=DataType.Double)
+    T65 = fd.ops.pow(T63, S64)
+    T66 = fd.ops.sum(T65, dims=[2], keepdim=False, dtype=DataType.Null)
+    T71 = fd.ops.broadcast_in_dim(T66, shape=[1, 6, 1], broadcast_dims=[0, 1])
+    S72 = fd.define_scalar(2048.00, dtype=DataType.Double)
+    S73 = fd.ops.reciprocal(S72)
+    T74 = fd.ops.mul(T71, S73)
+    S75 = fd.define_scalar(1.00000e-05, dtype=DataType.Double)
+    T76 = fd.ops.add(T74, S75)
+    T77 = fd.ops.rsqrt(T76)
+    T82 = fd.ops.broadcast_in_dim(T77, shape=[1, 6, 2048], broadcast_dims=[0, 1, 2])
+    T83 = fd.ops.mul(T63, T82)
+    T88 = fd.ops.broadcast_in_dim(T7, shape=[1, 6, 2048], broadcast_dims=[2])
+    T89 = fd.ops.cast(T88, dtype=DataType.Float)
+    T90 = fd.ops.mul(T89, T83)
+    T91 = fd.ops.cast(T90, dtype=DataType.BFloat16)
+    T92 = fd.ops.linear(T91, T8)
+    fd.add_output(T92)
+
+
+with FusionDefinition() as fd:
+    nvfuser_fusion_id2(fd)
+
+inputs = [
+    torch.randn(12288, dtype=torch.bfloat16, device='cuda:0').as_strided((1, 32, 6, 64), (12288, 64, 2048, 1)),
+    torch.testing.make_tensor((2048, 2048), dtype=torch.bfloat16, device='cuda:0'),
+    torch.testing.make_tensor((1, 6, 2048), dtype=torch.bfloat16, device='cuda:0'),
+    torch.testing.make_tensor((2048,), dtype=torch.bfloat16, device='cuda:0'),
+    torch.testing.make_tensor((8192, 2048), dtype=torch.bfloat16, device='cuda:0'),
+    torch.testing.make_tensor((8192, 2048), dtype=torch.bfloat16, device='cuda:0'),
+    torch.testing.make_tensor((2048, 8192), dtype=torch.bfloat16, device='cuda:0'),
+    torch.testing.make_tensor((2048,), dtype=torch.bfloat16, device='cuda:0'),
+    torch.testing.make_tensor((128256, 2048), dtype=torch.bfloat16, device='cuda:0'),
+]
+
+# Warm-up: one initial call plus three more, all outside the timed region.
+fd.execute(inputs)
+
+for _ in range(3):
+    fd.execute(inputs)
+
+torch.cuda.synchronize()
+# perf_counter(): monotonic, high-resolution clock meant for measuring intervals.
+start = time.perf_counter()
+# Mark the profiling region
+torch.cuda.cudart().cudaProfilerStart()
+
+for _ in range(100):
+    fd.execute(inputs)
+
+torch.cuda.cudart().cudaProfilerStop()
+torch.cuda.synchronize()
+end = time.perf_counter()
+
+print((end-start)*1000, " ms")
+
+# Recorded totals for the 100 timed iterations (output of the print above):
+# Before:
+# 18.9 ms
+# After:
+# 18.8 ms
+
+# rm report*
+# nsys profile -c cudaProfilerApi python tests/python/llama_inf_tests/graph_2.py
+# nsys stats report1.nsys-rep
+
+# Before:
+# Time (%) Total Time (ns) Instances Avg (ns) Med (ns) Min (ns) Max (ns) StdDev (ns) Style Range
+# -------- --------------- --------- -------- -------- -------- -------- ----------- ------- ----------------------------------------------
+# 14.3 21273988 100 212739.9 191180.5 179102 647287 68515.0 PushPop :FusionExecutorCache::runFusionWithInputs
+# 13.9 20711603 100 207116.0 185424.0 174989 625555 67241.1 PushPop :FusionKernelRuntime::runWithInputs
+# 13.9 20634952 100 206349.5 184704.0 174362 623140 67108.8 PushPop :FusionKernelRuntime::runSegmentsWithInputs
+# 13.1 19477550 900 21641.7 19736.5 5253 229134 19235.9 PushPop :FusionKernelRuntime::runKernelWithInput
+# 12.8 18979906 900 21088.8 19402.5 5008 228699 18181.9 PushPop :ExecutorDispatch::run2
+# 9.1 13569155 2100 6461.5 3299.5 1250 188373 8072.4 PushPop :ExpressionEvaluator::evaluate
+# 6.8 10071317 600 16785.5 16953.0 4816 226456 12582.7 PushPop :ExprEvalExecutor::run
+# 5.8 8593835 300 28646.1 23748.5 18304 209021 24037.1 PushPop :KernelExecutor::runFusion
+# 4.1 6042339 300 20141.1 17470.5 12833 200139 18266.7 PushPop :KernelExecutor::runFusion::execute_kernel
+# 3.2 4802005 300 16006.7 13105.5 9464 195217 18063.1 PushPop :KernelExecutor::recomputeArgs
+# 0.7 1083270 300 3610.9 3488.0 2803 9530 804.5 PushPop :ExecutorRunFusion::cuLaunchKernel
+# 0.7 1066491 300 3555.0 2310.5 1934 173206 9901.1 PushPop :fusion_executor::allocations::allocateOutputs
+# 0.7 1059251 900 1176.9 864.0 534 174544 5806.8 PushPop :executor_utils::bindInputs
+# 0.5 753282 400 1883.2 1430.0 1237 169947 8427.1 PushPop :fusion_executor::allocations::allocateTensor
+# 0.1 168892 900 187.7 147.0 103 1888 115.3 PushPop :ExecutorDispatch::isCompiled
+# 0.1 154076 100 1540.8 1415.5 1012 6127 537.9 PushPop :FusionExecutorCache::setCacheId
+# 0.0 51330 300 171.1 114.5 97 1893 169.0 PushPop :KernelExecutor::runFusion::intermediates
+# 0.0 19650 100 196.5 156.0 117 845 109.3 PushPop :FusionExecutorCache::getKernelRuntimeFor
+
+# After:
+# Time (%) Total Time (ns) Instances Avg (ns) Med (ns) Min (ns) Max (ns) StdDev (ns) Style Range
+# -------- --------------- --------- -------- -------- -------- -------- ----------- ------- ----------------------------------------------
+# 16.0 16373962 100 163739.6 141242.5 134284 628006 71340.5 PushPop :FusionExecutorCache::runFusionWithInputs
+# 15.1 15382598 100 153826.0 136014.5 129986 600299 64199.4 PushPop :FusionKernelRuntime::runWithInputs
+# 15.0 15308089 100 153080.9 135307.0 129396 597501 64038.0 PushPop :FusionKernelRuntime::runSegmentsWithInputs
+# 13.8 14094916 900 15661.0 16073.0 5213 251448 17412.1 PushPop :FusionKernelRuntime::runKernelWithInput
+# 13.3 13579134 900 15087.9 15684.5 4944 251078 16245.8 PushPop :ExecutorDispatch::run2
+# 9.7 9923498 600 16539.2 16927.0 4741 250699 14459.9 PushPop :ExprEvalExecutor::run
+# 9.4 9632237 900 10702.5 13644.5 1314 248454 13132.5 PushPop :ExpressionEvaluator::evaluate
+# 3.3 3330448 300 11101.5 8723.0 6965 201143 18774.3 PushPop :KernelExecutor::runFusion
+# 1.1 1129730 300 3765.8 3564.0 2899 11811 941.0 PushPop :KernelExecutor::runFusion::execute_kernel
+# 1.0 1013555 300 3378.5 3257.5 2673 10550 813.0 PushPop :ExecutorRunFusion::cuLaunchKernel
+# 0.9 917298 300 3057.7 1641.0 1451 193374 11094.4 PushPop :fusion_executor::allocations::allocateOutputs
+# 0.5 545570 600 909.3 832.0 531 11018 553.9 PushPop :executor_utils::bindInputs
+# 0.4 374638 900 416.3 150.0 101 199723 6652.2 PushPop :ExecutorDispatch::isCompiled
+# 0.2 204834 100 2048.3 148.5 110 185215 18502.2 PushPop :FusionExecutorCache::getKernelRuntimeFor
+# 0.2 166768 300 555.9 478.5 341 5239 394.4 PushPop :KernelExecutor::computeArgs2
+# 0.2 165881 100 1658.8 1535.0 1125 7627 674.0 PushPop :FusionExecutorCache::setCacheId
+# 0.1 57354 300 191.2 111.0 97 2077 236.0 PushPop :KernelExecutor::runFusion::intermediates