Layout propagation#1788

Merged
jjsjann123 merged 55 commits intomainfrom
layout_propagation_pr_0
Feb 20, 2024

Conversation


@jjsjann123 (Collaborator) commented Feb 19, 2024

Stacked PRs:
#1755 enabling layout propagation through runtime
#1792 propagation rule for broadcast
#1790 propagation rule for binary op
==== #1788 adding layout inference pass <- this one

What's in this PR:
inferenceAllocationOrder pass that works on an entire Fusion:
It computes an AllocationOrder for each input by comparing the TensorView's allocation_domain against its rfactor_domain;
It then uses predefined rules (in AllocationOrderInferencer) to traverse the fusion and propagate AllocationOrder from the inputs to the entire fusion;

Note that the pass itself doesn't mutate the fusion IR. It's a utility function that suggests how to specify allocation domains, for other optimization passes to consume.

  • adding the inferenceAllocationOrder pass function;
  • adding a propagation rule for unary ops;
  • adding a cpp test to verify the propagation rule;
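To make the input-side step concrete, here is a small standalone sketch. Everything in it (the `AllocationOrder` alias, `computeAllocationOrder`, and treating domains as lists of opaque axis ids) is an assumption for illustration only, not nvFuser's actual API:

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Toy model (assumed names, not nvFuser's API): an AllocationOrder is a
// permutation mapping each allocation-domain position back to a rfactor-domain
// axis. E.g. an NHWC allocation of an NCHW logical domain is {0, 2, 3, 1}.
using AllocationOrder = std::vector<int64_t>;

// Hypothetical helper: given the rfactor axes and the allocation axes (both as
// opaque axis ids), recover the permutation; returns empty on a mismatch.
AllocationOrder computeAllocationOrder(
    const std::vector<int64_t>& rfactor_axes,
    const std::vector<int64_t>& allocation_axes) {
  AllocationOrder order;
  for (int64_t ax : allocation_axes) {
    auto it = std::find(rfactor_axes.begin(), rfactor_axes.end(), ax);
    if (it == rfactor_axes.end()) {
      return {};  // axis not found in the rfactor domain
    }
    order.push_back(it - rfactor_axes.begin());
  }
  return order;
}
```

For example, allocation axes `{10, 12, 13, 11}` over rfactor axes `{10, 11, 12, 13}` recover the channels-last-style permutation `{0, 2, 3, 1}`.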

Quick design doc: #1756

Future Work:

  • expanding the propagation rules to cover more operations;

@jjsjann123

!build

@jjsjann123

!build

@jjsjann123

The failing test isn't related; merging as-is.

@jjsjann123 jjsjann123 merged commit 7d2740c into main Feb 20, 2024
@jjsjann123 jjsjann123 deleted the layout_propagation_pr_0 branch February 20, 2024 23:42
jjsjann123 added a commit that referenced this pull request Feb 21, 2024
Stacked PRs:
==== #1755 enabling layout propagation through runtime <- **_this one_**
#1792 propagation rule for broadcast
#1790 propagation rule for binary op
#1788 adding layout inference pass

What's in this PR:
Enables the MemoryFormat optimization pass in the runtime; it runs as part of the pre_segment optimization passes.
Adds a cpp test to verify the optimization behavior.

Quick design doc: #1756

---------

Co-authored-by: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com>
Co-authored-by: Jingyue Wu <wujingyue@gmail.com>
jjsjann123 added a commit that referenced this pull request Feb 21, 2024
Stacked PRs:
#1755 enabling layout propagation through runtime
#1792 propagation rule for broadcast
==== #1790 propagation rule for binary op **_<- this one_**
#1788 adding layout inference pass

What's in this PR:

BinaryOp propagation tries to merge the allocation orders of both inputs:
* when only one operand is a tensor, we simply forward its recorded
allocation order;
* when both operands are tensors, we resolve the conflict by:
    i. prioritizing the tensor with fewer broadcast iterdomains;
    ii. otherwise, propagating the allocation order of the lhs.

Propagation rule for binary operations:

- [x] adding propagate rule for binary op;
- [x] handling two scalar inputs;
- [x] handling intermediate tensors (factory tensor);
- [x] adding cpp test to verify propagation rule;
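The two-tensor resolution steps above can be sketched as follows. The `Operand` struct and `resolveBinaryOp` are assumed names for illustration, not the actual AllocationOrderInferencer implementation:

```cpp
#include <cstdint>
#include <vector>

// Toy operand model (assumption): scalars carry no allocation order, tensors
// carry one plus a count of their broadcast iterdomains.
struct Operand {
  bool is_tensor;
  int64_t num_broadcast_dims;
  std::vector<int64_t> allocation_order;
};

std::vector<int64_t> resolveBinaryOp(const Operand& lhs, const Operand& rhs) {
  // Only one operand is a tensor: forward its recorded allocation order.
  if (!rhs.is_tensor) return lhs.allocation_order;
  if (!lhs.is_tensor) return rhs.allocation_order;
  // i. prioritize the tensor with fewer broadcast iterdomains;
  if (rhs.num_broadcast_dims < lhs.num_broadcast_dims) {
    return rhs.allocation_order;
  }
  // ii. otherwise, propagate the allocation order of the lhs.
  return lhs.allocation_order;
}
```

Preferring the operand with fewer broadcast iterdomains makes sense because a broadcast dimension's position in the allocation domain says little about the memory layout the user actually cares about.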

---------

Co-authored-by: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com>
Co-authored-by: Jingyue Wu <wujingyue@gmail.com>
jjsjann123 added a commit that referenced this pull request Feb 21, 2024
Stacked PRs:
#1755 enabling layout propagation through runtime
==== #1792 propagation rule for broadcast **_<- this one_**
#1790 propagation rule for binary op
#1788 adding layout inference pass

What's in this PR:

BroadcastOp propagation pushes all newly created broadcast iterdomains to the
outer dimensions of the output tensor.

- [x] adding propagate rule for broadcast op;
- [x] adding cpp test to verify propagation rule;
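The rule above can be sketched as a small standalone function. `propagateBroadcast` and its flag-vector interface are assumptions for illustration, not the nvFuser implementation:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Toy sketch (assumed names): the output's allocation order lists every newly
// created broadcast axis first (outermost), then follows the input's
// allocation order, remapped to output axis positions.
std::vector<int64_t> propagateBroadcast(
    const std::vector<int64_t>& in_order,         // permutation over input axes
    const std::vector<bool>& is_new_broadcast) {  // one flag per output axis
  std::vector<int64_t> in_to_out;  // input axis -> output axis position
  std::vector<int64_t> out_order;
  for (size_t i = 0; i < is_new_broadcast.size(); ++i) {
    if (is_new_broadcast[i]) {
      out_order.push_back(static_cast<int64_t>(i));  // broadcasts go outermost
    } else {
      in_to_out.push_back(static_cast<int64_t>(i));
    }
  }
  for (int64_t ax : in_order) {
    out_order.push_back(in_to_out[ax]);  // keep the input's relative order
  }
  return out_order;
}
```

For instance, broadcasting a 2D input with allocation order `{1, 0}` into a 3D output whose axis 0 is the new broadcast yields `{0, 2, 1}`: the broadcast stays outermost while the original axes keep their relative order.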

---------

Co-authored-by: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com>
Co-authored-by: Jingyue Wu <wujingyue@gmail.com>