From 8fa532f4f1da6aa9eb59a49a32d1163df5bab5f7 Mon Sep 17 00:00:00 2001 From: SRK Reddy Date: Mon, 7 Jan 2019 19:37:48 +0530 Subject: [PATCH] [RELAY][PASS] AlterOpLayout bugfix to handle Tuple args properly. --- src/relay/pass/alter_op_layout.cc | 74 +++++++++++++++++++++---------- 1 file changed, 51 insertions(+), 23 deletions(-) diff --git a/src/relay/pass/alter_op_layout.cc b/src/relay/pass/alter_op_layout.cc index 5c4475259086..7341c21fe836 100644 --- a/src/relay/pass/alter_op_layout.cc +++ b/src/relay/pass/alter_op_layout.cc @@ -165,7 +165,22 @@ Call CallAlter(const Call& ref_call, } } if (!modified) { - new_e = CallNode::make(ref_call->op, new_args, + std::vector new_ref_args; + int arg_index = 0; + // build according to ref_call. + for (auto arg : ref_call->args) { + if (arg->is_type()) { + Tuple tuple_arg = Downcast(arg); + Array tuple_fields; + for (auto x : tuple_arg->fields) { + tuple_fields.push_back(new_args[arg_index++]); + } + new_ref_args.push_back(TupleNode::make(tuple_fields)); + } else { + new_ref_args.push_back(new_args[arg_index++]); + } + } + new_e = CallNode::make(ref_call->op, new_ref_args, ref_call->attrs, ref_call->type_args); } @@ -261,39 +276,52 @@ Expr AlterOpLayoutRewrite(const Call &ref_call, CHECK_EQ(new_in.size(), new_in2.size()) << "The number of input nodes should keep the same during alter_op_layout"; + // Expand new_call args + std::vector expanded_new_args; + for (auto arg : new_call->args) { + if (arg->is_type()) { // expand tuple + Tuple tuple_arg = Downcast(arg); + for (auto x : tuple_arg->fields) { + expanded_new_args.push_back(x); + } + } else { + expanded_new_args.push_back(arg); + } + } + // if (new_in != new_in2): insert transform (new_in -> new_in2) Array transformed_args; for (size_t i = 0; i < inputs.size(); ++i) { - transformed_args.push_back(memorizer.Transform(new_call->args[i], new_in[i], new_in2[i])); + transformed_args.push_back(memorizer.Transform(expanded_new_args[i], new_in[i], new_in2[i])); } // state[node] = (old_out, new_out) CHECK(ref_call->checked_type_.defined()) << "Call infer_type pass before alter_op_layout pass"; - if (ref_call->checked_type()->is_type()) { - Expr tuple_output = CallNode::make(new_call->op, transformed_args, - new_call->attrs, new_call->type_args); - Array fields; - for (size_t i = 0; i < new_out.size(); ++i) { - auto rnode = make_node(); - rnode->value = TupleGetItemNode::make(tuple_output, i); - rnode->old_layout = old_out[i]; - rnode->new_layout = new_out[i]; - rnode->memorizer = memorizer; - fields.push_back(Expr(rnode)); + Array transformed_args_new; + int arg_index = 0; + for (auto arg : ref_call->args) { + if (arg->is_type()) { + Tuple tuple_arg = Downcast(arg); + Array tuple_fields; + for (auto x : tuple_arg->fields) { + tuple_fields.push_back(transformed_args[arg_index++]); + } + transformed_args_new.push_back(TupleNode::make(tuple_fields)); + } else { + transformed_args_new.push_back(transformed_args[arg_index++]); } - return TupleNode::make(fields); - } else { - auto rnode = make_node(); - CHECK_EQ(new_out.size(), 1); - rnode->value = CallNode::make(new_call->op, transformed_args, - new_call->attrs, new_call->type_args); - rnode->old_layout = old_out[0]; - rnode->new_layout = new_out[0]; - rnode->memorizer = memorizer; - return Expr(rnode); } + + auto rnode = make_node(); + CHECK_EQ(new_out.size(), 1); + rnode->value = CallNode::make(new_call->op, transformed_args_new, + new_call->attrs, new_call->type_args); + rnode->old_layout = old_out[0]; + rnode->new_layout = new_out[0]; + rnode->memorizer = memorizer; + return Expr(rnode); } TVM_REGISTER_API("relay._ir_pass.AlterOpLayout")