Skip to content

Conversation

@slyubomirsky
Copy link
Contributor

@slyubomirsky slyubomirsky commented Nov 14, 2023

Pursuant to issue #15319, this pass implements an in-place transformation for dataflow blocks. For each operator call in the dataflow block, the pass if at least one argument is not live past the call and has no aliases that are live past the call. If this condition is met for at least one argument, the operator can be replaced with an in-place version called via call_tir_inplace. The in-place version is a TIR PrimFunc that is produced by taking the legalizer for that operator and replacing its output with one of the inputs.

The liveness analysis is very simple since we are handling only single dataflow blocks. Handling general Relax control flow would require a solution more akin to #15689. The alias analysis, by contrast, is more complex since alias analysis is generally a tricky issue. It focuses on being able to handle common cases (especially tuples).

@junrushao junrushao force-pushed the unity branch 2 times, most recently from c95d45f to 45eeb8c Compare December 18, 2023 21:00
}

// Replace buffers in a PrimFunc according to the mapping.
tir::Stmt RemapBuffers(const tir::Stmt& stmt, const Map<tir::Buffer, tir::Buffer>& buffer_map) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this could be handled by calling tir::Specialize.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ooh! I didn't know about it

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a nice one for a lot of signature manipulations. There may need to be some post-processing, but it should work in this case. If the selected in-place input is parameter i, and the output is parameter j, then it would be called as auto new_func = Specialize(func, {{func->params[j], func->buffer_map[func->params[i]]}}).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, using tir::Specialize still left some uses of the old output buffer around, which still seems like it would require the manual rewrite I had. It was odd. Namely, the old buffer was being used in the writes field of the block.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did the output buffer have a DeclBuffer statement for it, or was it just used implicitly?

I think I ran into a similar issue, with #14565 to resolve it. That PR ended up not being merged, since RFC#70 would have required the DeclBuffer node. Unfortunately, #14778, which implements RFC#70, has been stuck for some time on ethos-u test cases.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There was no DeclBuffer, there was just a loose variable hanging around.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it. In that case, it would fall into the same bug that #14565 was intended to resolve. I think for now, it's probably better to use the separate mutator. If either RFC#70 gets fully implemented, or if the relax-to-TIR conversion provides a DeclBuffer annotation, then it could be updated.

// replace the call with a call to an in-place PrimFunc.
// (Made public for testing.)
Call CreateInplaceCall(const Call& call, const Array<Integer>& inplace_indices) {
static const auto& legalize_map = Op::GetAttrMap<FLegalize>("FLegalize");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this use a separate legalization map? Even if they map to the same function, the per-operator support could then be checked by seeing if "FLegalizeInPlace" is defined for an operator.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's worth thinking about. I didn't use it here because the operators that can be done in-place can also be implemented by applying the transformation given here to their legalized version. @tqchen, @MasterJH5574 (and others), do you think it would be worthwhile to add another map for in-place legalizations?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. In that case, it probably makes sense to use FLegalize as-is. At some point in the future, the explicit listing of operators could be replaced by a TIR analysis pass. I think that (hypothetical future enhancement) would work especially well if this pass is then moved downstream of FuseTIR, so that it would be inspecting the already-generated TIR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would like for a pass like this to come later in compilation, but the issue is the fact it relies on working within single dataflow blocks, so it has to come before we get rid of dataflow blocks. That is also causing me a few issues with thinking about how to handle in-place split and concat.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is probably a larger change, but would it make sense to either (A) implement it in terms of bindings without requiring a dataflow block, or (B) implement the pass as Sequential(ConvertToDataFlow(), current_pass(), ToNonDataflow())?

(B) would be easier to implement, but the use of ToNonDataflow as a post-processor would require that it be after any passes that require dataflows to be present.

Copy link
Contributor Author

@slyubomirsky slyubomirsky Jan 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we don't want to have the dataflow requirement, we would also need to handle control flow. That's the main issue and there's no obvious way to deal with control flow without going with a solution more like abstract interpretation (see my comment about the general-purpose liveness analysis PR).

We could, I suppose, fake the requirement be analyzing the blocks before, after, and within control flow. It would be a large change but it could be done. The result would be less general than a solution that properly accounts for control flow but it would be a start.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using the sequential transformation would be a less ad hoc version of the above.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call, and agreed that the sequential transform would be the better way to handle it, and a more consistent pattern that could be applied to other passes.

vm = relax.VirtualMachine(ex, tvm.cpu())
res = vm["main"](x, y)
assert (expected == res.numpy()).all()

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want a pytest.xfail case for internal functions, to indicate that they are not yet supported? I'm thinking of cases like the following, where the current local analysis would conclude that Module.func cannot do in-place operations, but non-local analysis could perform Module.func in-place for some cases, depending on the liveness analysis in the calling scope.

@I.ir_module
class Module:
    @R.function
    def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
        cls = Module
        with R.dataflow():
            # Cannot be in-place, uses a function argument
            y = cls.subroutine(x)
            
            # Can be in-place, uses a local binding
            z = cls.subroutine(y)
            R.output(z)
        return z
    
    @R.function(private=True)
    def subroutine(x: R.Tensor((), "int32"), y: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
        with R.dataflow():
            z = R.add(x,R.const(1, "int32"))
            R.output(z)
        return z

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm always in favor of testing error conditions, so I'll add the test case too.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good! Though, I'd probably see this one as a not-yet-supported test case (pytest.mark.xfail), rather than a test of an error case (with pytest.raises(...)). Rather than verifying that a specific error is raised, it would document what behavior isn't covered by the current implementation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On second thought, I'm not sure we should add an error-condition test case for something we don't even claim to support (the description of the pass does not reference trying to do function calls in place).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could go either way on it. The high-level name DataflowUseInplaceCalls doesn't specify that the InplaceCalls only refer to native relax operators, and excludes potentially-in-place calls to other relax functions. The docstring does specify operators, so you're right that we don't claim support for it there.

I'd see a @pytest.mark.xfail(reason="Not currently supported") as a way to explicitly disclaim responsibility for that use case. It would indicate what that support would look like, and that it could be supported at some point, but that it isn't currently supported.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can appreciate your point of view here, but I would point out that while tests are a form of documentation, documentation (such as doc comments) is also a form of documentation 🙂

If there is a strong need to reinforce the point that we are only looking for a small number of operators, we can do this, but I don't think there is right now, especially since handling Relax functions in-place is not something we have even planned on doing. (Possible in principle, but it's not on the horizon.)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, and mostly just nitpicks on my side. While I think documenting potential future expansions is useful, I don't think it's worth delaying this PR. (Approved!)

@slyubomirsky
Copy link
Contributor Author

Thank you for drawing my attention to the dynamic case, @Lunderberg, I would not have guessed that it would "just work" like that 😄

Copy link
Contributor

@Lunderberg Lunderberg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you on the changes, and I like the functionality!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants