
Skip host synchronization when it is safe to do so #419

Closed
pxl-th wants to merge 8 commits into master from pxl-th/wait

Conversation

pxl-th (Member) commented Apr 30, 2023

Closes #405.

Since we use HSA for Julia kernels & HIP for other library calls (gemm, etc.), we have to perform a host-side wait at the HSA | HIP boundary, for example:
(x .* x) * x
    hsa  |  hip

However, when running only HSA kernels or only HIP kernels in a row, we can rely on on-device serialization and skip the host wait.
This allows us to dispatch kernels asynchronously.

The only restriction is that all such operations must use the same HSA queue or the same HIP stream, which is fine since we've moved to TLS.
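The dispatch rule above can be modeled in a few lines. This is an illustrative Python stand-in for the logic (not the AMDGPU.jl implementation; `needs_host_wait` and the `(backend, queue)` tuples are hypothetical names): a host wait is needed only when consecutive operations cross the HSA | HIP boundary or land on different queues.

```python
# Hypothetical model of the boundary rule: back-to-back ops on the same
# backend and same queue rely on on-device serialization; crossing the
# hsa | hip boundary (or switching queues) requires a host wait.

def needs_host_wait(prev_op, next_op):
    """Return True if the host must wait before dispatching next_op."""
    if prev_op is None:
        return False                      # nothing in flight yet
    backend_prev, queue_prev = prev_op
    backend_next, queue_next = next_op
    if backend_prev != backend_next:
        return True                       # hsa | hip boundary: must wait
    return queue_prev != queue_next       # same backend: safe only on one queue

# (x .* x) * x -> broadcast on HSA, then matmul on HIP: wait at the boundary
ops = [("hsa", 0), ("hip", 0)]
prev = None
for op in ops:
    print(needs_host_wait(prev, op))      # False, then True
    prev = op
```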

Changes to the SyncState are the following:

  • Maintain same_queue & same_stream flags to check whether all signals belong to the same queue / all streams are equal.
  • If either same_queue == false or same_stream == false, fall back to the old wait behavior.
  • Otherwise, if SyncState contains only HSA signals and we are dispatching another HSA kernel, skip wait! when it is called.
  • To signal that you are dispatching another HSA kernel or making another HIP call and want to skip the (potentially unnecessary) host wait, use hsa_wait! or hip_wait!.

hsa_wait! always waits for any HSA signal present in a SyncState; it is meant to be used right before a HIP library call, e.g. before gemm.
hip_wait! always waits for any HIP stream present in a SyncState; it is meant to be used right before HSA kernel dispatches, e.g. inside the @roc macro.

  • When skipping the host wait, e.g. for HSA, remove all HSA signals from the SyncState except the last one. This ensures we still synchronize if the next op is a HIP library call.

  • Avoid duplication in SyncState: code like broadcast!(cos, x, x) previously pushed the same signal twice into x's SyncState.

  • Synchronize on a HIPEvent instead of a HIPStream for HIP-based libraries.
    The HIPEvent is created at the moment of mark!.
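The SyncState bookkeeping described above can be sketched as follows. This is a minimal Python model, not the Julia implementation; class and method names (`SyncState`, `push_hsa`, `hsa_wait`) are illustrative stand-ins. It shows the three behaviors from the list: deduplicating pushed signals, tracking same_queue, and, when the host wait is skipped, keeping only the last signal so a later HIP call can still synchronize.

```python
# Illustrative sketch (hypothetical names) of the SyncState bookkeeping.

class SyncState:
    def __init__(self):
        self.hsa_signals = []   # (queue_id, signal) pairs in dispatch order
        self.same_queue = True  # do all recorded signals share one HSA queue?

    def push_hsa(self, queue_id, signal):
        if (queue_id, signal) in self.hsa_signals:
            return              # avoid duplicates, e.g. broadcast!(cos, x, x)
        if self.hsa_signals and self.hsa_signals[-1][0] != queue_id:
            self.same_queue = False
        self.hsa_signals.append((queue_id, signal))

    def hsa_wait(self, host_wait):
        """Called before dispatching another HSA kernel.

        Returns True if a host wait was actually performed."""
        if self.same_queue:
            # Skip the host wait, but keep the last signal so the next
            # HIP library call can still synchronize against it.
            if self.hsa_signals:
                self.hsa_signals = [self.hsa_signals[-1]]
            return False
        host_wait(self.hsa_signals)     # fall back to the old behavior
        self.hsa_signals.clear()
        return True
```

With signals from a single queue, hsa_wait skips the host wait and trims the list to the last signal; as soon as a second queue appears, same_queue flips and it falls back to waiting on everything.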

Code

using BenchmarkTools
using AMDGPU

x = ROCArray(rand(Float32, 128, 128))

function matmul(x)
    t1 = x * x
    t2 = x * t1
    t3 = x * t2
    AMDGPU.synchronize()
    AMDGPU.unsafe_free!(t1)
    AMDGPU.unsafe_free!(t2)
    AMDGPU.unsafe_free!(t3)
    return
end

function el_mul(x)
    t1 = x .* x
    t2 = x .* t1
    t3 = x .* t2
    AMDGPU.synchronize()
    AMDGPU.unsafe_free!(t1)
    AMDGPU.unsafe_free!(t2)
    AMDGPU.unsafe_free!(t3)
    return
end

Benchmarks

Without final synchronization (measuring dispatch times)

Before:

julia> @benchmark matmul(x)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):   79.702 μs …  1.990 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     114.373 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   139.584 μs ± 42.262 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

     ▅▆▂           ▃█▃                     ▁       ▅▆▂  ▆▆   ▄ ▂
  ▃▇▇████▆█▆▅▅▆█▆▅▅████▆▇▇▆▆▃▃▃▁▁▄▃▃▄▄▇▅▄▃▆█▇▃▄▅▅▄▆████▇███▇▆█ █
  79.7 μs       Histogram: log(frequency) by time       189 μs <

 Memory estimate: 2.44 KiB, allocs estimate: 51.

julia> @benchmark el_mul(x)
BenchmarkTools.Trial: 5051 samples with 1 evaluation.
 Range (min … max):  471.790 μs … 46.804 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     727.775 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   987.706 μs ±  2.229 ms  ┊ GC (mean ± σ):  0.15% ± 0.36%

  █▇                                                           ▁
  ██▅▁▃▁▁▁▁▁▁▁▁▁▁▃▁▃▁▁▁▁▁▁▁▃▁▁▃▃▄▁▃▁▁▁▁▁▁▁▁▁▁▃▄▆▅▁▄▃▁▁▁▃▁▃▅▁▄▇ █
  472 μs        Histogram: log(frequency) by time      16.6 ms <

 Memory estimate: 16.92 KiB, allocs estimate: 252.

After:

julia> @benchmark matmul(x)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  22.440 μs … 267.228 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     24.511 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   24.658 μs ±   2.835 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

       ▃▅▆▆▆▅▃▁▃▄███▅▃                                          
  ▂▂▄▆█████████████████▇▆▆▅▅▄▄▄▃▃▃▃▃▃▃▂▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▂▂▂▂ ▄
  22.4 μs         Histogram: frequency by time           31 μs <

 Memory estimate: 3.48 KiB, allocs estimate: 77.

julia> @benchmark el_mul(x)
BenchmarkTools.Trial: 8878 samples with 1 evaluation.
 Range (min … max):  108.223 μs … 26.969 ms  ┊ GC (min … max): 0.00% … 37.01%
 Time  (median):     315.509 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   558.149 μs ±  1.562 ms  ┊ GC (mean ± σ):  0.79% ±  1.38%

  ▇█▃▃▂                                                        ▁
  ██████▇▆▆▇▆▅▅▄▄▁▁▃▁▃▁▁▁▄▄▃▄▁▄▃▁▁▃▁▃▁▁▁▁▁▃▁▃▃▁▁▄▃▁▃▅▆▆▆▅▆▆▅▅▆ █
  108 μs        Histogram: log(frequency) by time      9.96 ms <

 Memory estimate: 18.34 KiB, allocs estimate: 296.

With final synchronization

Before:

julia> @benchmark matmul(x)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):   91.804 μs … 784.200 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     116.636 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   120.744 μs ±  31.422 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

     ▁ ▂▂▆█▄▂▂▂▂▁                                               ▁
  ▅▆▇█████████████▆▆▅▅▄▄▄▃▅▄▁▄▁▅▁▃▄▁▃▄▄▁▃▁▁▃▁▃▁▄▁▃▁▃▄▁▃▃▃▁▄▃▃▁▄ █
  91.8 μs       Histogram: log(frequency) by time        278 μs <

 Memory estimate: 2.58 KiB, allocs estimate: 54.

julia> @benchmark el_mul(x)
BenchmarkTools.Trial: 5038 samples with 1 evaluation.
 Range (min … max):  517.035 μs … 27.458 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     571.632 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   989.842 μs ±  2.346 ms  ┊ GC (mean ± σ):  0.18% ± 0.71%

  █                                                             
  ██▆▄▅▅▃▄▅▃▃▁▁▃▄▄▃▄▄▃▄▁▁▃▃▄▁▄▄▁▅▄▃▁▄▄▃▄▄▁▄▃▃▁▅▃▅▅▆▅▆▅▅▅▃▅▅▄▄▄ █
  517 μs        Histogram: log(frequency) by time      16.1 ms <

 Memory estimate: 17.08 KiB, allocs estimate: 256.

After:

julia> @benchmark matmul(x)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  45.983 μs …  16.096 ms  ┊ GC (min … max): 0.00% … 76.16%
 Time  (median):     56.414 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   60.288 μs ± 162.063 μs  ┊ GC (mean ± σ):  2.03% ±  0.76%

      ▁  ▁ ▁   ▁▇█▇▄▂▃▂▂                                       ▂
  ▅▅▅▇███████▇▇████████████████▇▇█▇▆▆▆▆▅▅▄▄▅▃▄▄▃▄▃▄▄▄▃▄▁▃▄▃▁▄▃ █
  46 μs         Histogram: log(frequency) by time      86.4 μs <

 Memory estimate: 3.62 KiB, allocs estimate: 80.

julia> @benchmark el_mul(x)
BenchmarkTools.Trial: 5501 samples with 1 evaluation.
 Range (min … max):  494.148 μs … 28.415 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     552.011 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   907.265 μs ±  2.274 ms  ┊ GC (mean ± σ):  0.37% ± 0.92%

  █                                                            ▁
  █▇▅▄▄▄▁▄▄▁▁▁▁▁▁▄▃▁▁▃▃▁▁▁▅▃▃▁▁▄▁▃▁▁▃▁▁▁▃▃▃▃▄▃▄▅▁▄▅▅▅▄▃▅▄▅▃▄▄▆ █
  494 μs        Histogram: log(frequency) by time      16.4 ms <

 Memory estimate: 18.50 KiB, allocs estimate: 300.

@pxl-th pxl-th marked this pull request as ready for review May 2, 2023 09:21
pxl-th (Member, Author) commented May 2, 2023

Nerf.jl benchmark (1000 training steps)

  • Before: 103.576595 seconds (18.76 M allocations: 741.180 MiB, 3.55% gc time)
  • After: 85.742314 seconds (21.02 M allocations: 782.243 MiB, 1.71% gc time, 0.01% compilation time)

Flux model inference (private repo)

  • Before: 60 seconds per inference
  • After: 44 seconds per inference

pxl-th (Member, Author) commented Jun 16, 2023

Superseded by #423

@pxl-th pxl-th closed this Jun 16, 2023
@pxl-th pxl-th deleted the pxl-th/wait branch July 6, 2023 08:08

Successfully merging this pull request may close these issues:

Optimize wait! for HSA kernel launches