Conversation

@Hzfengsy (Member) commented Apr 26, 2024

New implementation of cumsum for axis=-1, which is the common case in LLM sampling and MoE inference.

Tested with CUDA and Vulkan on an RTX 3080, and Metal on an Apple M1 Pro. The new implementation is faster than or on par with vendor libraries at small sizes (<128k), and nearly 10x faster than the TOPI implementation in all cases.

batch size cuda-new cuda-thrust cuda-old vulkan-new vulkan-old metal-new metal-mlx metal-old
1 16 6.25 7.97 67.37 9.44 45.86 58.94 177.03 292.87
1 32 6.12 8.10 82.38 8.44 55.43 49.31 174.33 309.60
1 64 6.39 8.43 97.46 8.72 61.21 58.55 165.82 359.07
1 128 6.33 8.43 113.11 9.15 70.78 58.92 163.98 409.17
1 256 6.42 8.38 128.70 9.26 79.22 59.98 164.43 451.95
1 512 8.32 8.24 144.36 12.82 88.85 82.05 167.91 483.10
1 1024 8.52 8.54 160.91 12.88 101.88 79.05 178.71 537.88
1 2048 8.38 8.22 177.79 13.41 109.76 76.88 191.39 575.48
1 4096 8.58 8.44 195.53 13.51 128.23 79.58 195.81 631.32
1 8192 8.47 8.44 213.76 13.41 135.06 86.05 198.75 680.58
1 16384 8.62 8.48 233.33 13.63 150.49 84.88 208.80 718.50
1 32768 8.30 8.38 253.40 14.07 155.97 88.79 233.70 776.12
1 65536 8.53 8.49 274.29 14.71 167.76 90.55 262.51 834.81
1 131072 8.37 8.33 298.67 15.81 189.46 104.44 381.32 885.58
1 262144 12.91 8.42 341.27 30.34 205.42 143.01 465.27 948.23
1 524288 17.06 9.02 424.09 34.39 271.09 149.22 724.91 1049.32
2 16 6.41 8.90 67.47 8.77 42.70 57.66 165.14 251.76
2 32 6.06 8.63 82.67 8.51 48.47 57.67 171.58 302.67
2 64 6.37 9.02 97.77 8.83 58.16 57.87 176.88 346.19
2 128 6.11 8.60 113.27 9.04 70.47 59.19 171.07 388.83
2 256 8.13 9.14 129.01 9.29 77.00 58.96 170.68 435.14
2 512 8.37 8.90 145.17 12.83 87.47 81.19 181.40 483.08
2 1024 8.38 9.04 160.99 12.81 100.23 80.46 159.27 514.73
2 2048 8.12 8.78 178.08 12.93 113.50 80.18 188.20 580.38
2 4096 8.53 9.06 195.80 13.16 119.79 84.44 192.16 626.83
2 8192 8.54 9.14 214.19 13.08 134.16 84.65 192.03 672.50
2 16384 8.64 9.25 233.42 13.35 141.18 83.34 225.20 725.09
2 32768 7.90 8.81 254.05 13.89 160.75 87.98 248.16 777.93
2 65536 8.46 9.03 278.40 15.70 176.83 87.15 273.48 854.96
2 131072 8.99 8.64 321.05 16.81 200.23 101.25 333.40 890.80
2 262144 16.37 10.65 403.68 33.10 239.63 145.10 463.82 1211.26
2 524288 27.62 20.50 563.95 45.57 321.47 194.45 744.36 1825.24
4 16 6.36 9.83 68.00 8.91 42.77 53.82 170.16 265.54
4 32 6.40 8.93 83.37 8.63 52.10 56.97 172.82 303.51
4 64 6.40 8.88 98.87 8.70 61.83 57.81 174.93 342.14
4 128 6.41 8.83 114.38 9.34 71.56 58.35 178.79 396.11
4 256 6.29 8.95 130.44 9.35 74.58 59.26 174.94 449.76
4 512 8.42 8.95 145.96 13.01 89.83 82.67 184.15 490.14
4 1024 8.25 8.96 162.27 13.02 95.91 81.06 194.36 529.71
4 2048 8.39 9.04 179.47 13.10 103.76 85.09 204.80 585.84
4 4096 8.36 8.92 197.21 13.15 116.48 86.59 196.52 630.19
4 8192 8.51 9.09 215.16 13.63 133.20 86.84 204.69 694.02
4 16384 8.25 8.71 234.65 13.95 151.83 86.35 211.78 759.87
4 32768 8.35 9.01 258.72 15.58 170.70 100.12 235.80 814.38
4 65536 9.04 8.86 301.41 17.44 187.34 118.45 268.46 892.72
4 131072 12.39 10.60 383.88 25.88 242.64 125.52 335.74 1285.37
4 262144 26.84 20.48 544.06 47.22 316.59 189.70 476.14 2043.67
4 524288 60.45 34.56 932.24 77.57 557.80 257.83 782.61 3472.33
8 16 6.11 8.71 68.00 8.83 42.57 48.73 183.02 250.98
8 32 6.45 9.15 83.50 8.92 51.54 55.76 174.54 298.05
8 64 6.07 8.69 99.05 8.76 60.45 58.85 173.86 349.45
8 128 6.41 9.17 115.08 9.34 68.63 59.84 177.25 402.53
8 256 6.06 8.64 130.46 9.19 74.15 59.05 188.81 455.95
8 512 8.46 9.11 145.96 13.04 89.17 82.91 172.61 486.68
8 1024 8.11 8.61 162.29 12.88 96.17 81.61 190.08 558.30
8 2048 8.62 9.04 179.62 13.36 110.39 81.73 202.68 599.61
8 4096 8.35 8.96 197.18 13.22 126.85 76.91 202.04 656.39
8 8192 8.48 8.94 215.54 14.04 138.62 93.70 201.95 705.73
8 16384 8.08 8.59 238.38 15.50 155.09 93.64 225.18 745.43
8 32768 9.05 9.11 280.93 17.49 183.07 106.23 258.64 918.06
8 65536 12.36 10.62 363.60 21.97 225.67 123.30 305.32 1386.03
8 131072 22.76 20.52 522.70 33.03 309.12 165.73 354.11 2246.53
8 262144 59.72 34.62 911.46 74.96 545.36 282.01 520.06 3854.05
8 524288 109.59 62.37 1610.09 126.90 964.50 527.66 860.12 7112.14
16 16 6.31 8.90 68.13 8.90 42.36 57.37 173.51 258.43
16 32 6.51 9.02 83.69 8.84 53.61 57.34 186.03 316.75
16 64 6.11 8.66 99.70 8.70 59.64 58.21 177.09 365.92
16 128 6.39 9.04 115.09 9.27 68.49 60.20 179.41 417.18
16 256 6.28 9.01 130.46 9.31 76.14 49.81 180.26 464.03
16 512 8.32 8.98 146.01 12.97 88.49 80.68 189.09 516.36
16 1024 8.32 9.05 162.39 13.05 99.48 79.52 220.84 565.10
16 2048 8.38 8.88 179.54 13.29 111.38 87.14 217.45 601.32
16 4096 8.46 8.98 197.37 14.18 125.31 85.47 201.06 675.17
16 8192 8.48 9.05 219.23 15.67 141.25 97.23 209.27 753.48
16 16384 9.04 8.78 260.45 17.47 176.67 115.19 214.45 992.25
16 32768 12.37 10.61 342.67 22.25 204.74 127.93 256.51 1510.12
16 65536 22.77 20.48 502.90 32.95 292.88 154.95 289.82 2458.10
16 131072 55.75 34.02 891.14 64.40 533.90 229.87 372.80 4316.44
16 262144 108.69 62.55 1588.98 126.13 949.58 509.06 365.06 7999.45
16 524288 207.09 116.06 2950.06 223.72 1780.71 879.54 600.47 15989.44
32 16 6.25 8.77 68.29 8.76 43.32 57.91 175.19 258.41
32 32 6.42 9.06 84.36 8.81 48.97 56.20 163.83 321.86
32 64 6.18 8.60 99.73 8.62 59.35 59.32 158.40 364.88
32 128 6.39 9.10 115.09 9.33 64.93 58.94 181.50 411.87
32 256 6.17 8.66 130.57 9.21 79.81 57.93 176.96 457.31
32 512 8.37 9.01 146.09 19.36 89.57 84.12 186.84 512.48
32 1024 8.05 8.66 162.51 12.91 100.45 83.54 204.85 567.89
32 2048 8.48 8.95 179.96 14.20 116.00 84.92 220.63 623.03
32 4096 8.37 8.99 201.32 15.25 127.29 109.70 207.02 693.30
32 8192 9.02 9.04 241.51 17.59 159.22 122.43 220.55 1034.00
32 16384 12.30 10.61 322.61 22.08 195.73 125.69 241.14 1618.62
32 32768 22.67 20.49 482.70 33.49 284.26 153.76 273.32 2678.89
32 65536 55.58 34.58 871.77 63.65 517.39 231.97 328.83 4778.87
32 131072 104.57 62.56 1570.44 114.61 937.68 451.23 349.83 8878.56
32 262144 206.36 117.31 2932.04 223.22 1765.90 898.76 570.13 17701.32
32 524288 402.89 223.87 5598.63 418.52 3430.84 1666.30 1184.20 34668.30
64 16 6.28 8.83 69.03 8.98 43.40 57.84 163.56 271.59
64 32 6.35 9.06 84.46 8.84 49.50 58.20 144.15 325.09
64 64 6.23 8.74 99.90 8.88 59.47 58.65 138.63 369.52
64 128 6.35 9.01 115.32 9.45 70.82 56.34 144.73 420.50
64 256 6.36 8.97 130.83 9.40 77.52 56.23 177.07 463.65
64 512 8.42 9.02 146.82 13.15 96.55 84.53 205.03 539.64
64 1024 8.51 9.05 163.50 14.20 106.52 84.52 222.60 600.26
64 2048 8.34 8.92 184.88 15.71 125.38 79.54 227.01 714.73
64 4096 9.00 8.89 225.64 17.22 154.40 82.62 220.93 1072.33
64 8192 12.26 10.59 305.93 21.41 198.65 137.26 252.04 1706.42
64 16384 22.94 20.53 464.34 32.96 281.64 204.45 275.30 2864.55
64 32768 55.38 34.49 855.93 65.21 517.36 265.17 325.48 5193.89
64 65536 104.33 62.45 1553.73 113.07 931.52 467.74 343.35 9767.25
64 131072 202.20 116.94 2915.69 212.45 1756.54 877.37 598.26 19409.01
64 262144 401.75 223.63 5579.62 419.43 3424.43 1633.52 934.39 38109.44
64 524288 793.03 439.06 10869.96 813.65 6753.53 3066.89 1721.42 75524.41
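For reference, the operation being benchmarked is an inclusive cumulative sum over the innermost axis. Its semantics can be sketched with NumPy (this is a reference model, not the TVM kernel):

```python
import numpy as np

# Inclusive scan along the last axis. The shape mirrors the table's
# (batch, size) layout; (8, 2000) matches the test case later in this PR.
data = np.random.randint(0, 10, size=(8, 2000)).astype("int32")
out = np.cumsum(data, axis=-1)

# Each row is scanned independently; the last column holds the row total.
assert (out[:, -1] == data.sum(axis=-1)).all()
```

Because every row is an independent scan, the GPU kernel can assign one workgroup per row, which is why the axis=-1 case is easier to make fast than a scan over an arbitrary axis.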

@Hzfengsy (Member Author)

cc @tqchen @vinx13

@tqchen tqchen merged commit 278a6af into apache:main Apr 26, 2024
size = (8, 2000)
np_data = np.random.randint(0, 10, size).astype("int32")
np_cumsum = np.cumsum(np_data, axis=-1)
for target in ["cuda", "vulkan -supports_int64=1"]:
Contributor:

Nitpick: use @tvm.testing.parametrize_targets("cuda", "vulkan -supports_int64=1") instead of looping over each target. This runs each test case in a separate pytest environment:

  • Exercises each test in a separate pytest case. Can distinguish between failure on one specific backend as opposed to failure on every backend.
  • Applies the appropriate @tvm.testing.requires_* marks for each target. Currently, this test would fail if a developer runs it with set(USE_CUDA ON) and set(USE_VULKAN OFF).
@tvm.testing.parametrize_targets("cuda", "vulkan -supports_int64=1")
def test_dispatch_cumsum_gpu(target, dev):
   ...
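The benefit of parametrizing over looping can be illustrated with plain pytest (a minimal sketch; `test_cumsum_dispatch` and `TARGETS` are illustrative names, not from the PR):

```python
import pytest

# Illustrative target list; in the PR this would be the GPU target strings.
TARGETS = ["cuda", "vulkan -supports_int64=1"]

@pytest.mark.parametrize("target", TARGETS)
def test_cumsum_dispatch(target):
    # pytest reports one case per target, so a failure on e.g. vulkan is
    # distinguishable from a failure on cuda. With a plain for-loop, the
    # first failing target aborts the test and masks the rest.
    assert isinstance(target, str) and len(target) > 0
```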

Member Author:

fixed in #16947

shape = call.struct_info.shape
kwargs = {}
if (
(axis == -1 or axis == len(shape) - 1)
Contributor:

For tensors of unknown shape, the shape field is None. Instead of len(call.struct_info.shape), can we use call.struct_info.ndim? (Alternatively, since it looks like the implementation requires an explicit shape in order to apply a reshape, we could add shape is not None to this condition.)
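The guard under discussion can be sketched in plain Python (`call.struct_info` is mocked with SimpleNamespace here, not the real Relax object; the function name is illustrative):

```python
from types import SimpleNamespace

def last_axis_dispatchable(call, axis):
    sinfo = call.struct_info
    if sinfo.shape is None:        # unknown shape: the reshape is impossible
        return False
    # ndim works even when shape is symbolic; -1 and ndim-1 both mean "last axis"
    return axis == -1 or axis == sinfo.ndim - 1

known = SimpleNamespace(struct_info=SimpleNamespace(shape=(8, 2000), ndim=2))
unknown = SimpleNamespace(struct_info=SimpleNamespace(shape=None, ndim=2))
assert last_axis_dispatchable(known, -1)
assert not last_axis_dispatchable(unknown, -1)
```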

Member Author:

Thanks for catching this. Unfortunately, the original implementation does not support unknown shapes, so I added a check in the pass.

@Lunderberg (Contributor) commented Apr 26, 2024

Looks like I took too long to review. I think the changes requested should probably be made in a follow-up PR.

Hzfengsy pushed a commit to Hzfengsy/tvm that referenced this pull request Apr 28, 2024
The current pass `ForceNarrowIndexToI32` fails to narrow dtype for let
binding. This PR fixes the issue.

BTW, this PR addresses the comments in apache#16934
tqchen pushed a commit that referenced this pull request May 6, 2024
The current pass `ForceNarrowIndexToI32` fails to narrow dtype for let
binding. This PR fixes the issue.

BTW, this PR addresses the comments in #16934
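The fix described in the follow-up commit (narrowing index dtypes through let bindings) can be illustrated with a toy AST walker. This is a schematic under assumed node names, not the actual ForceNarrowIndexToI32 pass:

```python
from dataclasses import dataclass

# Toy expression nodes standing in for TIR constructs.
@dataclass
class IntImm:
    dtype: str
    value: int

@dataclass
class Let:
    var: str
    value: object
    body: object

INT32_MIN, INT32_MAX = -2**31, 2**31 - 1

def narrow(e):
    if isinstance(e, IntImm) and e.dtype == "int64" \
            and INT32_MIN <= e.value <= INT32_MAX:
        return IntImm("int32", e.value)
    if isinstance(e, Let):
        # The reported bug: the pass previously did not recurse into the
        # bound value and body of a let binding, leaving i64 indices behind.
        return Let(e.var, narrow(e.value), narrow(e.body))
    return e

expr = Let("i", IntImm("int64", 5), IntImm("int64", 7))
out = narrow(expr)
assert out.value.dtype == "int32" and out.body.dtype == "int32"
```

Values outside the int32 range are deliberately left as int64, mirroring the safety condition any such narrowing pass must respect.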