
Conversation

masahi (Member) commented Oct 31, 2020

This fixes the issue reported at https://discuss.tvm.apache.org/t/graph-plan-memory-doesnt-support-nested-tuples/8278.

A storage token is now created for each tensor in a nested tuple. The graph runtime doesn't deal with tuples, so nested tuples are flattened entirely (runtime.get_num_outputs() and runtime.get_output(i) don't take tuples or nesting into account).

Please review @tqchen @zhiics (cc @mbaret @manupa-arm).
I didn't look into memory planning in detail, but I hope this is correct.
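The flattening described above can be modelled in plain Python. This is a hypothetical sketch of the behaviour, not the actual TVM implementation; `flatten_outputs` is an illustrative name:

```python
def flatten_outputs(output):
    """Flatten an arbitrarily nested tuple of tensors into a flat list,
    mirroring how the graph runtime exposes outputs via get_output(i)."""
    if isinstance(output, tuple):
        flat = []
        for field in output:
            flat.extend(flatten_outputs(field))  # recurse into any nesting depth
        return flat
    return [output]  # a single tensor contributes exactly one output slot

# A two-level nest like relay.Tuple([x1, relay.Tuple([x2, x3])]):
nested = ("x1", ("x2", "x3"))
print(flatten_outputs(nested))       # ['x1', 'x2', 'x3']
print(len(flatten_outputs(nested)))  # 3, the analogue of get_num_outputs()
```

In this model, callers see three outputs in left-to-right order and the tuple structure is not recoverable from the runtime API, which matches the behaviour described above.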

@masahi masahi force-pushed the graph-memory-nested-tuple branch from bf5f22e to 3fadfe2 Compare November 1, 2020 06:09
Reviewed change (graph_plan_memory.cc):

-      auto tok = GetToken(field);
-      ICHECK_EQ(tok.size(), 1U);
-      fields.push_back(tok[0]);
+      auto tokens = GetToken(field);
Member:
Does this handle 2 levels of nesting?

masahi (Member Author):
Yes. GetToken calls VisitExpr(field) internally, so nesting is visited recursively and a token is generated for each tensor in the nest.

x1 = x + relay.const(1.0)
x2 = x1 + relay.const(1.0)
x3 = x2 + relay.const(1.0)
out = relay.Tuple([x1, relay.Tuple([x2, x3])])
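For this graph, the recursion described above assigns one storage token per tensor regardless of nesting depth, and a tuple's token list is just the concatenation of its fields' token lists. A minimal sketch of that token-collection logic, with illustrative names rather than the real StorageAllocator code:

```python
class Tensor:
    """Stand-in for a tensor-typed Relay expression."""
    pass

def get_tokens(expr, counter=None):
    """Collect one fresh storage token per tensor, flattening tuples.

    Plays the role of GetToken in the sketch: tuples recurse into their
    fields (like VisitExpr), tensors allocate one token each.
    """
    if counter is None:
        counter = [0]
    if isinstance(expr, tuple):  # TupleNode: concatenate field tokens
        tokens = []
        for field in expr:
            tokens.extend(get_tokens(field, counter))
        return tokens
    tok = counter[0]             # one fresh token for this tensor
    counter[0] += 1
    return [tok]

# out = relay.Tuple([x1, relay.Tuple([x2, x3])]) corresponds to:
nested = (Tensor(), (Tensor(), Tensor()))
print(get_tokens(nested))  # [0, 1, 2]: three tokens, nesting flattened
```

Because the tuple case simply concatenates, adding another level of nesting (e.g. making x2 itself a tuple) still yields one token per leaf tensor.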
Member:

For example, what if x2 is also a tuple?

masahi (Member Author):

I've just updated this test to verify that two-level nesting works.

@masahi masahi force-pushed the graph-memory-nested-tuple branch from 3fadfe2 to 208cfc5 Compare November 2, 2020 03:30
mbaret (Contributor) left a comment:

This change LGTM, but I have a question relating to the original Discuss post. I've been looking at compile_engine recently and noticed a number of TODOs there with the comment 'Allow recursive tuple'; see for instance: https://github.com/apache/incubator-tvm/blob/3222cad9464cadafcab29dbfb5cf149bf083913f/src/relay/backend/compile_engine.cc#L118

It also seems to me that there's logic in graph_runtime_codegen.cc that flattens tuples assuming they're not recursive. Therefore, I'm a bit surprised it was only graph_plan_memory that was blocking you.

masahi commented Nov 2, 2020

@mbaret Yes, this patch suffices for my simple use case of returning nested tuples from the main function. But for use cases like BYOC, where the checked_type of CallNode->op can be a nested tuple, it seems I need further fixes (one in graph_plan_memory.cc and another in graph_runtime_codegen.cc).

For example, I've just tried compiling this graph containing a sub-function with a nested tuple output, and it doesn't compile. I'll work on the fix; that should enable removing the tuple flattening in the BYOC partitioning pass @manupa-arm

#[version = "0.0.5"]
def @func0(%x: Tensor[(10, 10), uint8]) -> (Tensor[(5, 10), uint8], (Tensor[(5, 10), uint8], Tensor[(5, 10), uint8])) {
  %0 = split(%x, indices_or_sections=2) /* ty=(Tensor[(5, 10), uint8], Tensor[(5, 10), uint8]) */;
  %1 = %0.0;
  %2 = %0.1;
  %3 = %0.0;
  %4 = abs(%3) /* ty=Tensor[(5, 10), uint8] */;
  %5 = (%2, %4);
  (%1, %5)
}

def @main(%a: Tensor[(10, 10), uint8]) -> (Tensor[(5, 10), uint8], (Tensor[(5, 10), uint8], Tensor[(5, 10), uint8])) {
  @func0(%a) /* ty=(Tensor[(5, 10), uint8], (Tensor[(5, 10), uint8], Tensor[(5, 10), uint8])) */
}

> I've been looking at the compile_engine recently and noticed a number of TODOs there with the comment 'Allow recursive tuple'

Those TODOs look unrelated to graph memory planning. They refer to Function params and shape func return type not supporting nested tuples, and that's definitely still the case.

@masahi masahi marked this pull request as draft November 2, 2020 20:12
@masahi masahi force-pushed the graph-memory-nested-tuple branch from 208cfc5 to d384062 Compare November 2, 2020 21:29
@masahi masahi marked this pull request as ready for review November 2, 2020 21:29

masahi commented Nov 2, 2020

@tqchen @zhiics @jroesch @mbaret Supporting nested tuples in the return type of CallNode->op required non-trivial changes to multiple files; see commit d384062.

It might help BYOC use cases, but otherwise I don't know if we want to support this in practice. So I'm happy to revert that change and only support returning nested tuples from the main function.

zhiics commented Nov 2, 2020

I am actually okay without the new change.

mbaret commented Nov 2, 2020

I'm fine to leave the further changes out of this PR (especially as it's titled with graph plan memory). If it looks more complicated, that may need to be the subject of an RFC.

@masahi masahi force-pushed the graph-memory-nested-tuple branch from 72b9564 to b139231 Compare November 3, 2020 00:01
masahi commented Nov 3, 2020

OK, reverted.

@masahi masahi force-pushed the graph-memory-nested-tuple branch from b139231 to 2171fac Compare November 4, 2020 00:21
@masahi masahi merged commit 6019db2 into apache:main Nov 4, 2020
masahi commented Nov 4, 2020

Thanks @zhiics @jroesch @mbaret

trevor-m pushed a commit to trevor-m/tvm that referenced this pull request Dec 2, 2020
* add test

* test working

* uncomment other tests

* remove redundant visit

* test double nesting

* support nested tuple in CallNode's return type

* Revert "support nested tuple in CallNode's return type"

This reverts commit 66225ed.
trevor-m pushed a commit to trevor-m/tvm that referenced this pull request Dec 4, 2020
trevor-m pushed a commit to neo-ai/tvm that referenced this pull request Dec 4, 2020