Description
This command often fails on A100 and H100:
```
$ bin/nvfuser_bench --benchmark_filter='NvFuserScheduler_Matmul_Manual/nvfuser_splitk_NT/M:1024/N:2048/K:4096/warps:8/stages:3/splitk_factor:5/smem_epilogue:1/manual_time'
...
what(): Fusion returns wrong results! The result tensor has shape [..., 1024,2048]. Mismatch happens at region result[...,5:784,0:1544]
Exception raised from checkMatch at /opt/pytorch/nvfuser/benchmarks/cpp/matmul.cpp:104
```
Disabling index hoisting and expression simplification makes this go away, but I am 99% confident this is not a hoisting or simplification bug; rather, it appears to be a data race. The inputs are fixed (fixed seeds; I've verified that the inputs and outputs are stable apart from the mismatches). There are about 10 mismatches per million elements in the output, but when running under `compute-sanitizer`, or when slowed down via `NVFUSER_DISABLE` options, there are no mismatches. Running with
`NVFUSER_DUMP=debug_info compute-sanitizer --tool racecheck` reports a data hazard in the circular-buffering code that uses `cp.async`.
The hazard warning goes away with `smem_epilogue:0`, in which case we also get no `checkMatch` errors. With `splitk_factor:1/smem_epilogue:1` we get no `checkMatch` errors, but the racecheck warnings are still present. If I add `asm volatile("cp.async.wait_group %0;\n"::"n"(0LL));` (or `wait_all`) after the main loop and just before the smem unswizzling loop, then the race warnings and `checkMatch` failures go away in all cases I've tested.
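For concreteness, here is roughly where the workaround wait sits in the generated kernel. This is an illustrative sketch of the code shape, not nvfuser's actual codegen output:

```cuda
// ... main loop: cp.async circular buffering, with waits that keep
//     keepStages() groups in flight ...

// Workaround: drain ALL outstanding cp.async groups before the epilogue
// touches (and possibly aliases) the circular-buffer smem region.
asm volatile("cp.async.wait_group %0;\n" :: "n"(0LL));  // or cp.async.wait_all

// ... smem unswizzling loop / smem epilogue ...
```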
I believe this is due to smem reuse being unaware of `cp.async`, and potentially due to some additional copies that go unused during circular buffering. The smem aliasing logic looks at the lifetime of an smem tensor as defined by its last read expression, and inserts a `__syncthreads()` to ensure that all threads in the block have reached that last read before allowing aliased writes to that chunk of smem. However, when using `cp.async` there can still be in-flight writes to that smem from the previous lifetime, since `__syncthreads()` alone does not wait for outstanding async copies to complete. I am not sure how to check for this programmatically, and I'd also be surprised if we are issuing `cp.async.cg.shared.global` transfers that go unused.
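A minimal standalone sketch of the suspected failure mode (buffer sizes, names, and the epilogue here are made up for illustration; only the synchronization pattern is the point):

```cuda
#include <cuda_runtime.h>

// Sketch: __syncthreads() orders threads, but does NOT wait for cp.async
// writes still in flight. A copy issued for a circular-buffer stage can land
// in smem AFTER the barrier, racing with the epilogue's aliased reuse of
// that region.
__global__ void hazard_sketch(const float* gmem, float* out) {
  extern __shared__ float smem[];  // stage buffers, later aliased by epilogue

  // Prefetch 16B per thread into the stage buffer via cp.async.
  unsigned smem_addr = static_cast<unsigned>(
      __cvta_generic_to_shared(&smem[threadIdx.x * 4]));
  asm volatile("cp.async.cg.shared.global [%0], [%1], 16;\n"
               :: "r"(smem_addr), "l"(gmem + threadIdx.x * 4));
  asm volatile("cp.async.commit_group;\n" ::);

  // With keepStages()==1, one group is allowed to remain in flight here:
  asm volatile("cp.async.wait_group %0;\n" :: "n"(1));

  __syncthreads();  // barrier only; the in-flight copy can still land later

  // Epilogue reuses (aliases) the same smem region -> data race with the
  // pending cp.async write, unless cp.async.wait_group 0 / wait_all is
  // issued first.
  smem[threadIdx.x * 4] = 0.0f;  // aliased write
  out[threadIdx.x] = smem[threadIdx.x * 4];
}
```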
We seem to be using `AsyncAwait` with `keepStages() == 1`, so there is one unsynced copy, and that copy can at times interfere with the smem epilogue buffer. I think what may need to happen is for the smem aliasing pass to insert, in addition to the `__syncthreads()` it currently inserts, an `AsyncAwait` with `keepStages() == 0` whenever it detects that the aliased smem was the destination of an async copy. We could do this analysis much like the existing search for block syncs in smem aliasing, I believe.
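Under that proposal, the code emitted at the alias point would look roughly like this (a sketch of the intended generated code, assuming `keepStages() == 0` lowers to a full wait; not what the pass currently produces):

```cuda
// At the smem reuse point, when the aliased region was a cp.async destination:
asm volatile("cp.async.wait_all;\n" ::);  // keepStages()==0: no copies in flight
__syncthreads();                          // the barrier the pass already inserts
// ... aliased writes to the reused smem region are now safe ...
```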