
Layout propagation#1744

Closed
jjsjann123 wants to merge 42 commits into main from layout_propagation

Conversation

@jjsjann123
Collaborator

@jjsjann123 jjsjann123 commented Feb 9, 2024

Stacked PRs:
#1755 enabling layout propagation through runtime
#1744 adding layout inference pass <- this one

What's in this PR:
The pass works on a Fusion IR:
It summarizes the MemoryFormat of inputs by looking at each TensorView's allocation_domain and rfactor_domain;
It uses a predefined rule (MemoryFormatInferencer) to propagate MemoryFormat from the inputs to the entire fusion;

Note that the pass itself doesn't mutate the fusion IR. It's just a utility function that suggests ways to specify allocation domain to be used by other optimization passes.
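For illustration, the permutation bookkeeping described above can be sketched as standalone C++ (iter domains modeled as plain ints; computePermutationSketch is a hypothetical stand-in for the ir_utils::computePermutation utility, not code from this PR):

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// MemoryFormat is the permutation mapping rfactor order to allocation order.
using MemoryFormat = std::vector<int64_t>;

// Returns perm such that alloc_dom[k] == rfactor_dom[perm[k]], or nullopt if
// alloc_dom is not a pure permutation of rfactor_dom.
std::optional<MemoryFormat> computePermutationSketch(
    const std::vector<int>& rfactor_dom,
    const std::vector<int>& alloc_dom) {
  if (rfactor_dom.size() != alloc_dom.size()) {
    return std::nullopt;
  }
  MemoryFormat perm;
  perm.reserve(alloc_dom.size());
  for (int id : alloc_dom) {
    auto it = std::find(rfactor_dom.begin(), rfactor_dom.end(), id);
    if (it == rfactor_dom.end()) {
      return std::nullopt; // not a pure permutation
    }
    perm.push_back(it - rfactor_dom.begin());
  }
  return perm;
}
```

e.g. an nhwc input with rfactor order [N, C, H, W] and allocation order [N, H, W, C] would be recorded as {0, 2, 3, 1}.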

  • adding simple rules to propagate layout through the Fusion IR;
  • adding a cpp test to verify the propagation rules;

Quick design doc: #1756

Future Work:

  • expanding the propagation rules to cover more operations;

@jjsjann123 jjsjann123 added the "allocation domain" label (issues related to allocation domain support) Feb 13, 2024
@jjsjann123
Collaborator Author

!build

test/utils.h Outdated

// allows overload resolution with size-1 initializer list
inline TensorView* makeSymbolicTensor(
std::initializer_list<int64_t> shape,
Collaborator Author

This API is just added to allow overload of makeSymbolicTensor({-1}, ...), which would otherwise be called into makeSymbolicTensor(size_t, ...)

Collaborator Author

note to myself. Do a quick clean up for other APIs as well!

Collaborator

C++ syntax is weird...

// TV1 has b5 -> i4 -> i3
// we see that TV0 encounters a non-broadcast iter domain first, so TV0 is the
// dominating tensor. We'll produce an output with stride order identical to
// that of TV0 in the record.
Collaborator Author

@kevinstephano This was what I was describing on the propagation rule for binary operations.

@jjsjann123
Collaborator Author

😮‍💨 I realized I went with my old implementation, and the memory order permutation is inconsistent with the stride_order in our python API.

Let me refactor that... 😱

@jjsjann123
Collaborator Author

!build

@jjsjann123
Collaborator Author

!build

@naoyam
Collaborator

naoyam commented Feb 15, 2024

@jjsjann123 I'm a bit lost with what's addressed in this PR. According to your design doc, what'll be done are:

  1. It looks up the permutation from rfactor_dom to allocation_dom on input TensorViews and records the permutation as the MemoryFormat for those tensors;
  2. The pass traverses the fusion to propagate MemoryFormat. It uses a set of propagation rules, where it computes & records the MemoryFormat of outputs from the recorded MemoryFormat of inputs;
  3. Lastly, the pass iterates through all output tensors and tries to specify their allocation domain as per the recorded MemoryFormat.

Am I correct that this PR does items 1 and 2?

Also, what are propagation rules?

It uses a set of propagation rules, where it computes & records the MemoryFormat of outputs from the recorded MemoryFormat of inputs;

private:
void handle(const UnaryOp*) override;
void handle(const BinaryOp*) override;
void handle(const BroadcastOp*) override;
Collaborator Author

@naoyam Propagation rules are specified here per operation. I'll add a note on the commit description.

@jjsjann123
Collaborator Author

jjsjann123 commented Feb 15, 2024

Am I correct that this PR does items 1 and 2?

Yes.
item 1 is done inside the inferenceMemoryFormat function, before it calls MemoryFormatInferencer to propagate it;
item 2 is done inside MemoryFormatInferencer, which propagates the memory format from the inputs to the entire fusion.
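As a rough sketch of that split (hypothetical names; tensors are modeled as ints and only a unary forwarding rule is shown, not the PR's actual classes):

```cpp
#include <cassert>
#include <cstdint>
#include <unordered_map>
#include <vector>

using MemoryFormat = std::vector<int64_t>;
using FormatMap = std::unordered_map<int, MemoryFormat>;

// A stand-in for a unary expr: one input tensor, one output tensor.
struct UnaryExpr {
  int in;
  int out;
};

// Step 1 (seeding the map from fusion inputs) happens before this is called;
// step 2 walks the exprs in topological order and applies the per-op rule.
FormatMap propagateSketch(
    FormatMap format_map, // seeded with the formats of fusion inputs
    const std::vector<UnaryExpr>& topo_exprs) {
  for (const auto& e : topo_exprs) {
    if (auto it = format_map.find(e.in); it != format_map.end()) {
      format_map[e.out] = it->second; // unary rule: forward the format as-is
    }
  }
  return format_map;
}
```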

@jjsjann123
Collaborator Author

Failing test seems to be coming from #1743.
cc'ing @Priya2698

@jjsjann123 jjsjann123 requested a review from wujingyue February 15, 2024 01:16
Collaborator

@wujingyue wujingyue left a comment


If it's convenient for you to git rebase and git add -p, I'd suggest separating the BinaryOp change into a different PR. That would reduce the size a lot and make review easier.

Comment on lines +50 to +52
if (auto iter = format_map_.find(in); iter != format_map_.end()) {
format_map_[out] = iter->second;
}
Collaborator

This pattern seems to appear in multiple places in this file. Consider making it a helper. Maybe something like

copyFormat(from, to);
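A minimal sketch of what that helper could look like (map keys modeled as ints instead of const TensorView*; everything here is illustrative):

```cpp
#include <cassert>
#include <cstdint>
#include <unordered_map>
#include <vector>

using MemoryFormat = std::vector<int64_t>;
using FormatMap = std::unordered_map<int, MemoryFormat>;

// Copies the recorded format of `from` to `to`, if one exists; returns
// whether anything was copied.
bool copyFormat(FormatMap& format_map, int from, int to) {
  if (auto it = format_map.find(from); it != format_map.end()) {
    format_map[to] = it->second;
    return true;
  }
  return false;
}
```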

// e.g.
// lhs TV0 rfactor_dom [i0, i1, b2]
// 0 2 1
// rhs TV0 rfactor_dom [i3, i4, b5]
Collaborator

Suggested change
// rhs TV0 rfactor_dom [i3, i4, b5]
// rhs TV1 rfactor_dom [i3, i4, b5]

// TV0 has i1 -> b2 -> i0
// TV1 has b5 -> i4 -> i3
// we see that TV0 encounters a non-broadcast iter domain first, so TV0 is the
// dominating tensor. We'll produce an output with stride order identical to
Collaborator

so TV0 is the dominating tensor

Why are we in favor of the memory format that first hits a non-broadcast? (I suspect it's something about vectorization, but the comment wasn't clear about that)

Collaborator

I have the same question. Why is this better than just using lhs? @jjsjann123 Could you add the explanation to the code here as a comment?

Collaborator

I somehow feel that, what we should do is: if this binary op contains a broadcast concretization, then respect the one with most number of concrete IDs, otherwise, just use lhs. cc @naoyam

Collaborator

If we are propagating only in the forward direction, it seems to me that we can't really know what will be the most advantageous stride order. For example if we later do a sum on some outer dimension then it might wind up that we would have preferred that dimension to be allocated inner-most, but we would need to propagate that information backwards. If we're sticking with forward-only, why not just use the first input's stride order for the output and call it a day? If we want to chase more optimality we could consider doing an iterative optimization on the segmented fusion, allowing the schedulers to specify weighted preferences for the allocation orderings of their inputs and propagating changes to the outputs using simple rules like the one here, but that optimization is a bigger change to tackle.

Collaborator Author

Why are we in favor of the memory format that first hits a non-broadcast? (I suspect it's something about vectorization, but the comment wasn't clear about that)

This one should have been updated. I was doing this earlier when I used a different propagation rule for broadcast, so I needed this trick to propagate nhwc.

tv0 = [i0 i1 i2 i3] @ {0 2 3 1}
bias0 = [i4] @ {0} -> broadcast_bias0 [b5 i4 b6 b7] @ {0 1 2 3}

But now I feel @zasdfgbnm's suggestion makes a lot more sense instead.

I somehow feel that, what we should do is: if this binary op contains a broadcast concretization, then respect the one with most number of concrete IDs, otherwise, just use lhs.

This makes a lot more sense to me, i.e. favoring the larger tensor (hopefully more concrete IDs would lead to a larger tensor). I'll update.
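A sketch of that rule under those assumptions (broadcast flags per iter domain; names and shapes here are illustrative, not the PR's actual code):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

enum class Operand { Lhs, Rhs };

// Picks the operand with more concrete (non-broadcast) iter domains;
// defaults to lhs on a tie.
Operand pickDominatingOperand(
    const std::vector<bool>& lhs_is_broadcast,
    const std::vector<bool>& rhs_is_broadcast) {
  auto countConcrete = [](const std::vector<bool>& is_broadcast) {
    std::size_t n = 0;
    for (bool b : is_broadcast) {
      if (!b) {
        ++n;
      }
    }
    return n;
  };
  return countConcrete(rhs_is_broadcast) > countConcrete(lhs_is_broadcast)
      ? Operand::Rhs
      : Operand::Lhs;
}
```

In the nhwc bias example above, tv0 = [i0 i1 i2 i3] against broadcast_bias0 = [b5 i4 b6 b7] would pick tv0, since it has more concrete IDs.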

// we see that TV0 encounters a non-broadcast iter domain first, so TV0 is the
// dominating tensor. We'll produce an output with stride order identical to
// that of TV0 in the record.
// In the event of a tie, we'll just propagate the memory format of lhs.
Collaborator

Is

i->b->i
i->i->b

considered a tie? I.e., do you care about just the first non-broadcast or the first difference in which case the non-broadcast wins? Either case, why?

@naoyam
Collaborator

naoyam commented Feb 15, 2024

Can you please define what exactly the memory format means? Does it just mean the allocation domain?

@naoyam
Collaborator

naoyam commented Feb 15, 2024

Can you please define what exactly the memory format means? Does it just mean the allocation domain?

I found a definition for tensors with an allocation domain:

// TensorView with allocation
//   domain that's a permutation of its corresponding rfactor domain and record
//   it as the memory format of the tensor;

What about tensors with no allocation domain?

@naoyam
Collaborator

naoyam commented Feb 15, 2024

Do we just want to infer a preferred allocation domain of each output tensor?

How would you propagate an inferred format through reshape?

// unordered_map from TensorView to permutation.
//
// See details in Note [ Memory Format Propagation ]
std::unordered_map<const TensorView*, MemoryFormat> inferenceMemoryFormat(
Collaborator

nit

Suggested change
std::unordered_map<const TensorView*, MemoryFormat> inferenceMemoryFormat(
std::unordered_map<const TensorView*, MemoryFormat> inferMemoryFormat(

std::unordered_map<const TensorView*, MemoryFormat>& format_map_;
};

// UnaryOp propagation forwards the memory format from input to output
Collaborator

What if the output has an allocation domain? Shouldn't the permutation be calculated here too?

Collaborator Author

I made the decision to limit the scope of the pass to only propagate from inputs to outputs. So any intermediate tensor with an allocation domain would just be ignored.

Now I feel @zasdfgbnm's comment (is this just a pass, or an actual optimization thing?) is quite on point. A real optimization run should have considered existing allocation domains on intermediates.


namespace nvfuser {

using MemoryFormat = std::vector<int64_t>;
Collaborator

Should we just call this StrideOrder?

Collaborator Author
@jjsjann123 jjsjann123 Feb 15, 2024

This is a messy topic.

I was avoiding the term StrideOrder, because that's used in our python API. I want our python API to match integration's semantics of StrideOrder (i.e., an nhwc tensor would be written as [3, 0, 2, 1]).

Meanwhile, the format notation used in codegen would mark an nhwc tensor as [0, 2, 3, 1]. The reason we want that is that it looks more consistent with our setAllocationDomain API.

tv0->setAllocationDomain({tv0->axis(0), tv0->axis(2), tv0->axis(3), tv0->axis(1)}, true);
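For illustration, the relation between the two notations can be sketched as follows (assuming the codegen permutation lists rfactor axes in allocation order, outermost first, while stride_order assigns each rfactor axis its stride rank with larger meaning outer; this is my reading of the comment above, not code from the PR):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Converts a codegen-style permutation (rfactor axes in allocation order,
// outermost first) into a python-API-style stride_order.
std::vector<int64_t> permutationToStrideOrder(
    const std::vector<int64_t>& perm) {
  const int64_t n = static_cast<int64_t>(perm.size());
  std::vector<int64_t> stride_order(perm.size());
  for (int64_t k = 0; k < n; ++k) {
    // perm[k] is the k-th outermost axis, so n-1-k axes are inner to it.
    stride_order[perm[k]] = n - 1 - k;
  }
  return stride_order;
}
```

With this reading, the nhwc permutation {0, 2, 3, 1} maps to stride_order {3, 0, 2, 1}, matching the two notations quoted above.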

Collaborator

Interesting... Could you add this note to the code as a comment?

Collaborator

Why is this file placed inside csrc/optimization? Is the layout inference an "optimization"? Should we just call it passes or something like that?

Collaborator

One could argue this is an optimization, but I support changing the name since some other passes are not necessarily optimizing. passes might be too generic as there is already device_lower/pass. The debug dump option is fusion_ir_preseg and these are really the last thing before segmentation, so what about preseg_passes?

Collaborator

preseg_passes works for me. Or even simpler, just preseg. I have no preference over preseg_passes vs preseg.


// BinaryOp propagation tries to merge the memory format of both inputs
//
// 1. when only one operand has a recorded memory format, it forwards
Collaborator

Is this possible? I think exprs are visited in topological order. Should we just NVF_ERROR(both operands have recorded memory formats)?

Collaborator Author

For inputs without an allocation domain, we're leaving them as empty, which sounds like a bad idea.

Meanwhile, this could still happen for tensors created with factory methods. Since we are only recording the memory format of input tensors, I don't want that to affect the output memory format.

This resonates with @jacobhinkle's other comment on whether we should have backward propagation as well.


// e.g. TV0 rfactor domain [i0, i1, i2]
// alloc domain [i0, i2, i1]
// memory format 0, 2, 1
std::unordered_map<const TensorView*, MemoryFormat>& format_map_;
Collaborator

Question: If a tensor has [I1, r2, b3, I4], should the MemoryFormat be 2d, 3d, or 4d?

Collaborator Author

I haven't touched that yet.

But I think it should be 3d here, i.e. we'll want to exclude the reduction iterdomain, since it doesn't help resolve propagation with a binary op. We can probably just leave the reduction iterdomain on the left of the allocation domain... or better yet, maybe we should just remove it from the allocation domain, since it doesn't carry any real meaning.
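A sketch of the exclusion step being discussed (iter domains modeled as small structs; noReductionsSketch is a hypothetical analogue of TensorDomain::noReductions):

```cpp
#include <cassert>
#include <vector>

struct FakeIterDomain {
  int id;
  bool is_reduction;
};

// Drops reduction iter domains before any permutation is computed, so a
// [I1, r2, b3, I4] tensor contributes a 3d format.
std::vector<FakeIterDomain> noReductionsSketch(
    const std::vector<FakeIterDomain>& dom) {
  std::vector<FakeIterDomain> out;
  for (const auto& d : dom) {
    if (!d.is_reduction) {
      out.push_back(d);
    }
  }
  return out;
}
```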


namespace {

class MemoryFormatInferencer : public OptOutConstDispatch {
Collaborator

Is there any reason for not making this a subclass of IterVisitor?

Collaborator Author

Definitely should have used that one instead. Thanks 🙇

for (auto tv : ir_utils::filterByType<TensorView>(fusion->inputs())) {
std::optional<MemoryFormat> permutation = ir_utils::computePermutation(
TensorDomain::noReductions(tv->getMaybeRFactorDomain()),
tv->getMaybeAllocationDomain());
Collaborator

Should this be TensorDomain::noReductions(tv->getMaybeAllocationDomain())? IIRC allocation domains do have these reductions, although it makes no sense to do so.

Collaborator

Also, should we make sure that reductions in the allocation domain are correctly handled?


Collaborator

@jacobhinkle jacobhinkle left a comment


I am trying to understand how propagation can be more useful than the default (or arbitrary rules like using the first input's stride order) if we are only propagating in the forward direction.



@jjsjann123
Collaborator Author

Closing this PR since we are handling this in #1788 #1790 #1792.

@jjsjann123 jjsjann123 closed this Feb 19, 2024