Description
I'd like to facilitate a discussion on how to represent hierarchical programs in our IR. This topic has come up repeatedly in various contexts so there might be an opportunity to introduce a nice abstraction for it.
Current status
Our IR currently has `IrContainer`, which is the base class for `Fusion`. `IrContainer` owns all the `Statement`s in the "fusion", while the `Fusion` class mostly manages inputs and outputs; this includes updating `val->uses()` when expressions are registered, since registration can change the reachability of each statement. All other subclasses of `IrContainer` (namely `kir::Kernel`) are themselves subclasses of `Fusion`.
Exprs are the simplest subfusions
In some sense, the core function of a `Fusion` (managing inputs and outputs) overlaps with that of `Expr`: both represent a computation that consumes a collection of `Val`s and produces other `Val`s. This is the only sense in which we currently support "sub-fusions"; that is, since a `Fusion` contains multiple `Expr`s, we can view it as a two-level hierarchy of subfusions.
The purpose of this proposal is to represent subfusions at intermediate granularities between "entire fusion" and "single expression".
Use cases
Partitioning Fusions into tasks
Partitioning the graph and executing segments separately from one another. This is important for:
- Segmentation. Currently segmentation requires a specialized container called `SegmentedFusion` which tracks collections of edges and groups of expressions representing segments.
- Tasks for multidevice programs
- Task groups for warp specialization (#210)
During segmentation, to analyze each segment we use a guard class that swaps the fusion inputs and outputs so that analysis can use our standard traversal utilities. Instead, it might be useful to represent segments during segmentation as subfusions that overlay the base fusion. This idea could be generalized to handle task parallelism.
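The guard idea above can be illustrated with a minimal standalone sketch. `FusionLike` and `SegmentIoGuard` are hypothetical stand-ins (not the actual nvFuser classes); the point is only the RAII shape: swap the fusion's I/O to the segment's I/O on entry, restore on destruction.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical stand-in for Fusion; real Fusion I/O is Val*, simplified here.
struct FusionLike {
  std::vector<int> inputs;
  std::vector<int> outputs;
};

// RAII guard: while alive, traversals over the fusion see the segment's
// inputs/outputs instead of the full fusion's.
class SegmentIoGuard {
 public:
  SegmentIoGuard(FusionLike& f, std::vector<int> seg_in, std::vector<int> seg_out)
      : fusion_(f), saved_in_(f.inputs), saved_out_(f.outputs) {
    fusion_.inputs = std::move(seg_in);
    fusion_.outputs = std::move(seg_out);
  }
  ~SegmentIoGuard() {
    // Restore the original fusion boundary on scope exit.
    fusion_.inputs = std::move(saved_in_);
    fusion_.outputs = std::move(saved_out_);
  }

 private:
  FusionLike& fusion_;
  std::vector<int> saved_in_, saved_out_;
};
```

A `Task`-based overlay would make this swap unnecessary, since each segment would carry its own I/O boundary.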
ExpressionEvaluator patterns (to replace composite IR nodes)
Recently, we introduced IR nodes for matmul patterns (#2175, #2240, #2294). These have greatly simplified our ATen fallback evaluation mode, replacing complex pattern matching code. However, fusion of these patterns still requires decomposition (#2236) into smaller primitives. If we had a working subfusion system, we could do both at once at program definition. For example, our `linear` function might create `BroadcastOp` nodes as well as an `MmaOp` node, add bias with a `BinaryOp`, then cast to reduced precision using a `UnaryOp`. All these ops would then be grouped into a subfusion identified as a "Linear" pattern, so that evaluation of its output could easily be identified and computed using ATen.
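The dispatch this enables can be sketched standalone. Everything here is a hypothetical mock (not nvFuser API): a value produced inside a task tagged as a known pattern is evaluated as that whole pattern, and only otherwise do we fall back to per-op evaluation.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical mock of a pattern-tagged Task.
struct Task {
  std::string pattern;  // e.g. "linear", "matmul"
};

// Hypothetical mock of a Val with the proposed taskDefinitions() query.
struct Val {
  std::vector<Task*> taskDefinitions;  // tasks that output this Val
  bool hasOpDefinition = false;        // stand-in for val->definition()
};

// Returns the evaluation strategy: a supported composite pattern name
// (evaluate the whole task via ATen), "per-op" fallback, or "" if neither.
std::string chooseEvaluation(const Val& v) {
  for (Task* t : v.taskDefinitions) {
    if (t->pattern == "linear" || t->pattern == "matmul") {
      return t->pattern;
    }
  }
  return v.hasOpDefinition ? "per-op" : "";
}
```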
Lambdas for generalizing reduction/scan
When discussing #622 recently, @naoyam pointed out that representing lambdas directly in the tensor IR tends to clutter and complicate the fusion (see #2307). With the right abstraction, we could instead maintain a lambda function as a subfusion that is disconnected from the main `Fusion` inputs and outputs. Then we would no longer need special `IterType`s, and we could use a single op like `FoldOp` to represent a reduction or scan. Note that we could still schedule the lambda as desired, but we would not need to take as much care with inlining and allocation concerns if it is disconnected and represented in this way.
Possible approach
One approach would be to keep the current system intact, but add a new `Task` type like this:
```cpp
class Task : PolymorphicBase {
 public:
  const std::vector<Val*>& inputs() const;
  const std::vector<Val*>& outputs() const;

  // Compute uses of a Val* within this Task
  const std::vector<Expr*>& uses(Val*);

  // ...
};
```

`Fusion` could own these `Task`s just like it currently owns `Statement`s. `Task` could potentially be made a subclass of `Statement` to facilitate cloning.
We could generalize `Val::isFusionInput()` and `Val::isFusionOutput()` with the following:
```cpp
class Val : public Statement {
 public:
  // ...

  // Compute tasks having this Val as an input
  std::vector<Task*> taskUses() const;

  // Compute tasks having this Val as an output. Note that multiple Tasks
  // might output the same Val, as they may be nested.
  std::vector<Task*> taskDefinitions() const;
};
```

Does this address the use cases above?
- No translation would be required for the matmul use case since we would create the full decomposed fusion at definition, and also create a matmul `Task` at that time. The `ExpressionEvaluator` could use this to check for supported fallback patterns by first checking for a supported `val->taskDefinitions()`, then checking `val->definition()`.
- Likewise, we could have a `class Segment : public Task {};` so that creating and modifying segments doesn't alter the fusion but just modifies `Task`s. Then checking whether a `Val` is a segmentation edge just means finding a task in `val->taskDefinitions()` or `val->taskUses()` that `isA<Segment>()`.
- For fold/scan, we could create a `Task` to use as a lambda, then just attach it to our `FoldOp` or `ScanOp` as an attribute.
- I am not sure yet how this would impact multidevice scheduling or how it might help us implement warp specialization.
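The segmentation-edge check can be sketched standalone. `isSegmentationEdge` and the mock classes are hypothetical; the sketch models `isA<Segment>()` with `dynamic_cast`, which is roughly what `PolymorphicBase::isA` does in spirit.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Mock base class; virtual destructor makes it polymorphic so dynamic_cast
// can stand in for isA<Segment>().
struct Task {
  virtual ~Task() = default;
};
struct Segment : Task {};

// A Val is a segmentation edge if any task defining it or using it is a
// Segment. The two vectors stand in for val->taskDefinitions() and
// val->taskUses().
bool isSegmentationEdge(const std::vector<Task*>& defs,
                        const std::vector<Task*>& uses) {
  auto isSeg = [](Task* t) { return dynamic_cast<Segment*>(t) != nullptr; };
  return std::any_of(defs.begin(), defs.end(), isSeg) ||
         std::any_of(uses.begin(), uses.end(), isSeg);
}
```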
### Tasks
- [ ] Come up with a better name than "Task"
- [ ] Remove inheritance of `IrContainer` from `Fusion`. Instead, `Fusion` will hold a `shared_ptr<IrContainer>`, accessible via `IrContainer* container() const`.
- [ ] Create "`Task`" with example of a task-specific traversal
- [ ] Identify issues and address...
- [ ] Adapt segmentation to use "tasks"
- [ ] Adapt MatmulOp etc. to use tasks