Motivation
On Hopper, efficient gemm requires warp specialization, which is not currently supported by nvFuser. This doc proposes extending nvFuser to support this optimization. I believe this will benefit not only matmul, but also other cases like optimal cat/stack scheduling, horizontal fusion, etc.; see the "Potential applications" section for more detail.
Design
Notation: I will mostly use the term "task parallelism" for the new feature being added to nvFuser. "Warp specialization" is a special case of "vertical task parallelism" (described below) on the thread index.
Partition of DAG as tasks
In order to use task parallelism, we first need to partition the DAG into tasks. Tasks are non-overlapping and dense; that is, every Val in the fusion definition except fusion inputs belongs to a task (fusion inputs are special because they are given instead of computed), and one Val can belong to only one task. Initially, all Vals belong to task 0. Example partition:
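As a toy illustration of this invariant, here is a minimal Python sketch (the `Val`/`task` names are hypothetical stand-ins, not nvFuser's C++ API):

```python
# Illustrative model only; Val/task are hypothetical Python stand-ins,
# not the nvFuser C++ API.
class Val:
    def __init__(self, name, is_input=False):
        self.name = name
        self.is_input = is_input
        # Inputs are given rather than computed, so they belong to no task;
        # every other Val initially belongs to task 0.
        self.task = None if is_input else 0

def check_partition(vals):
    """The partition must be dense and non-overlapping: every non-input
    Val belongs to exactly one task."""
    for v in vals:
        if v.is_input:
            assert v.task is None, f"{v.name} is an input, not computed"
        else:
            assert isinstance(v.task, int), f"{v.name} must be in one task"

tv0 = Val("tv0", is_input=True)
tv1, tv2 = Val("tv1"), Val("tv2")
tv1.task, tv2.task = 1, 2  # move tv1 and tv2 out of the default task 0
check_partition([tv0, tv1, tv2])
```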
Grouping tasks into a hierarchical structure
Tasks are further grouped into task groups. Task groups form a hierarchical structure.
Parallelization of task groups
A task group can be parallelized, for example:

group1->parallelize(ParallelType::TIDy);
group3->parallelize(ParallelType::TIDz);

Not all task groups can be parallelized. A parallelizable group is either a "horizontal group" or a "vertical group". A "horizontal group" is a group whose members have no data dependencies on each other; for example, group 1 is a horizontal group. A "vertical group" is a group whose members are connected by data dependencies; for example, group 3 is a vertical group.
Below is an example where group 4 is neither a horizontal group nor a vertical group:
However, you can make it a horizontal group by grouping group 2 and group 3 together:
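The two parallelizability conditions can be sketched as graph checks on the task dependency structure. This is a minimal Python model with hypothetical names, not nvFuser API; the vertical-group check uses one plausible reading of "connected":

```python
# Illustrative model only. deps maps each task to the tasks whose
# outputs it consumes.

def depends(deps, a, b):
    """True if task a transitively consumes data produced by task b."""
    stack, seen = list(deps.get(a, [])), set()
    while stack:
        t = stack.pop()
        if t == b:
            return True
        if t not in seen:
            seen.add(t)
            stack.extend(deps.get(t, []))
    return False

def is_horizontal(deps, members):
    """Horizontal group: no member depends on any other member."""
    return all(not depends(deps, a, b)
               for a in members for b in members if a != b)

def is_vertical(deps, members):
    """Vertical group: members form one connected component of the
    (undirected) dependency graph restricted to the group."""
    members = set(members)
    edges = {m: set() for m in members}
    for m in members:
        for d in deps.get(m, []):
            if d in members:
                edges[m].add(d)
                edges[d].add(m)
    stack, seen = [next(iter(members))], set()
    while stack:
        t = stack.pop()
        if t not in seen:
            seen.add(t)
            stack.extend(edges[t])
    return seen == members

# Mirroring the running example: task 5 (the cat) consumes tasks 1-4.
deps = {"task5": ["task1", "task2", "task3", "task4"]}
assert is_horizontal(deps, ["task1", "task2", "task3", "task4"])
assert is_vertical(deps, ["task1", "task5"])
```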

Expression sorting
Expression sorting must be task and task group aware. For the above example, the sorted expressions can be

group 3

which zooms into
group 1
group 2

which zooms into
task 1
task 2
task 3
task 4
task 5
// the following order is also valid, and expr sort is free to choose one from all valid orders
task 3
task 4
task 1
task 2
task 5
// There are more valid orders...
// In later context, we assume the first order is picked

which zooms into
tv1 = sin(tv0);
tv5 = cos(tv1);
tv2 = tan(tv0);
tv6 = relu(tv2);
tv3 = exp(tv0);
tv7 = sigmoid(tv3);
tv4 = log(tv0);
tv8 = tanh(tv4);
tv9 = cat([tv5, tv6, tv7, tv8]);
tv10 = neg(tv9);

Loop nest generation
When generating the loop nest, an unparallelized group simply generates its members one after another. A parallelized group generates kir::IfThenElse nodes to dispatch between its members.
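As a toy model of this dispatch logic, the following Python sketch (not nvFuser's actual kir lowering) derives the predicate and shifted loop index for each member of a horizontal group parallelized on TIDy, assuming member i owns the index range [i*size0, (i+1)*size0):

```python
# Hypothetical sketch, not nvFuser's kir lowering: derive, for each
# member of a horizontal group parallelized on TIDy, the dispatch
# predicate and the shifted loop index the branch iterates over.
def dispatch_branches(num_members, extent="size0"):
    branches = []
    for i in range(num_members):
        # Member i owns [i*extent, (i+1)*extent); shift threadIdx.y so
        # each branch sees a local index starting at 0.
        idx = "threadIdx.y" if i == 0 else f"threadIdx.y - {i}*{extent}"
        branches.append((f"{idx} >= 0 && {idx} < {extent}", idx))
    return branches

for i, (pred, idx) in enumerate(dispatch_branches(4)):
    kw = "IF" if i == 0 else "ELSE IF"
    print(f"{kw} {pred}:  // loop index: {idx}")
```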
Assuming tv0-tv8 have [BIDx, TIDy{size0}, TIDx], tv9 and tv10 have [BIDx, TIDy{size0*4}, TIDx] (the cat dim is 1; I am assuming the cat dim is untouched by the scheduler in this example), and inline most.
Then the generated loop nest structure will be
FOR blockIdx.x in BIDx:
  IF threadIdx.z == 0:
    IF threadIdx.y >= 0 && threadIdx.y < size0:
      FOR threadIdx.y in TIDy{size0}:
        FOR threadIdx.x in TIDx:
          tv1 = sin(tv0);
          tv5 = cos(tv1);
    ELSE IF threadIdx.y - size0 >= 0 && threadIdx.y - size0 < size0:
      FOR threadIdx.y - size0 in TIDy{size0}:
        FOR threadIdx.x in TIDx:
          tv2 = tan(tv0);
          tv6 = relu(tv2);
    ELSE IF threadIdx.y - 2*size0 >= 0 && threadIdx.y - 2*size0 < size0:
      FOR threadIdx.y - 2*size0 in TIDy{size0}:
        FOR threadIdx.x in TIDx:
          tv3 = exp(tv0);
          tv7 = sigmoid(tv3);
    ELSE IF threadIdx.y - 3*size0 >= 0 && threadIdx.y - 3*size0 < size0:
      FOR threadIdx.y - 3*size0 in TIDy{size0}:
        FOR threadIdx.x in TIDx:
          tv4 = log(tv0);
          tv8 = tanh(tv4);
  ELSE IF threadIdx.z == 1:
    FOR threadIdx.y in TIDy{4*size0}:
      FOR threadIdx.x in TIDx:
        tv9 = cat([tv5, tv6, tv7, tv8]);
        tv10 = neg(tv9);

Synchronization
For the parallelization of horizontal task groups, synchronization must happen before and after the dispatch. Depending on the parallel type, a block sync or grid sync might be needed. For the parallelization of vertical task groups (a.k.a. warp specialization), the tensors at the parallelization boundary (in this case tv5-tv8) must be double/circular buffered, and arrive-wait barriers are used for synchronization.
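The circular-buffering requirement at a vertical boundary can be illustrated with a toy producer/consumer model. This is a Python simulation only: real code would use shared-memory arrive-wait barriers, for which Python semaphores merely stand in:

```python
# Toy model of the handoff at a vertical-group boundary. The producer
# role corresponds to one warp group (e.g. loads), the consumer to the
# other (e.g. mma+store); semaphores stand in for arrive-wait barriers.
import threading

DEPTH = 2                 # circular-buffer depth (double buffering)
N = 8                     # number of tiles handed off
buf = [None] * DEPTH      # the boundary tensor's circular buffer
empty = threading.Semaphore(DEPTH)  # slots the producer may fill
full = threading.Semaphore(0)       # slots the consumer may read
results = []

def producer():
    for i in range(N):
        empty.acquire()           # wait until a slot is free
        buf[i % DEPTH] = i * i    # "load" tile i into its slot
        full.release()            # signal the slot is ready

def consumer():
    for i in range(N):
        full.acquire()            # wait until tile i is ready
        results.append(buf[i % DEPTH])
        empty.release()           # free the slot for reuse

t1 = threading.Thread(target=producer)
t2 = threading.Thread(target=consumer)
t1.start(); t2.start(); t1.join(); t2.join()
```

With DEPTH=1 the two sides would fully serialize; DEPTH=2 lets the producer fill one slot while the consumer drains the other, which is exactly why the boundary tensors must be double/circular buffered.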
Potential applications
Efficient matmul on Hopper
Warp specialization is used: loads and mma+store are performed in different warps.
Horizontal fusion
For example, suppose we have multiple separate fusions, independently scheduled. If all fusions use only BIDx but not BIDy or BIDz, then we can trivially fuse them horizontally by partitioning each fusion as a task in the combined fusion and parallelizing these tasks horizontally on BIDx.
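A minimal sketch of the bookkeeping involved, assuming a hypothetical pass that gives each sub-fusion a contiguous range of blockIdx.x (names are illustrative, not nvFuser API):

```python
# Hypothetical sketch: horizontally fuse independent fusions that each
# use only BIDx by assigning each one a contiguous blockIdx.x range.
def assign_block_ranges(grid_sizes):
    """grid_sizes[i] = number of BIDx blocks fusion i needs.
    Returns (total blocks, list of half-open (start, end) ranges)."""
    ranges, start = [], 0
    for n in grid_sizes:
        ranges.append((start, start + n))
        start += n
    return start, ranges

def dispatch(block_ranges, bidx):
    """Which fusion runs on this block, and its local blockIdx.x."""
    for fusion_id, (lo, hi) in enumerate(block_ranges):
        if lo <= bidx < hi:
            return fusion_id, bidx - lo
    raise ValueError("blockIdx.x out of range")

# Three independent fusions needing 3, 5, and 2 blocks respectively.
total, ranges = assign_block_ranges([3, 5, 2])
assert total == 10
assert dispatch(ranges, 4) == (1, 1)  # block 4 -> fusion 1, local block 1
```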
Cat/stack schedule
For cat/stack, the size of the output tensor is naturally the sum of the sizes of the inputs, so we could parallelize the computation of the inputs the same way group 1 is parallelized in the above example.
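The dispatch pattern for cat can be modeled in plain Python: each output index falls into exactly one input's slice, mirroring the ELSE-IF chain in the loop-nest example above (illustrative only, not nvFuser code):

```python
# Toy model of cat computed branch-wise: each output index y is owned
# by exactly one input, which computes it at local index y - lo. This
# mirrors dispatching each input's task to its own TIDy range.
def cat_via_dispatch(inputs):
    extents = [len(x) for x in inputs]
    out = []
    for y in range(sum(extents)):
        lo = 0
        for inp, e in zip(inputs, extents):
            if lo <= y < lo + e:          # this input owns index y
                out.append(inp[y - lo])   # shifted local index
                break
            lo += e
    return out
```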


