diff --git a/tests/python/relay/test_pipeline_executor.py b/tests/python/relay/test_pipeline_executor.py
index b97966dde0c8..541f3bba13da 100644
--- a/tests/python/relay/test_pipeline_executor.py
+++ b/tests/python/relay/test_pipeline_executor.py
@@ -22,12 +22,195 @@
 import tvm
 import tvm.testing
 from tvm import relay
-from tvm.relay import transform
+from tvm.relay import transform, build_module
+from tvm.relay.testing import run_opt_pass
 from tvm.contrib import graph_executor, pipeline_executor, pipeline_executor_build
 from tvm._ffi import get_global_func
 from tvm.contrib import cc as _cc
 
 
+def graph_split(expr, split_conf, params=None):
+    """Split the graph into a list of subgraphs."""
+
+    def get_dep_var(sub_var_dep):
+        return [var for var in sub_var_dep[len(sub_var_dep) - 1]["ref_nodes"]]
+
+    def parse_dependency(value, snode_dep, new_input_idx):
+        new_args = []
+        need_update = False
+        for var in value.args:
+            is_free_var = False
+            for dep in snode_dep[:-1]:
+                if var in dep["nodes"]:
+                    # Mark the previous subgraph node as a dependency.
+                    dep["nodes"][var] += 1
+                    dep["ref_nodes"][var] = dep["nodes"][var]
+                    # The var of this call is a free_var.
+                    is_free_var = True
+            # If the var of this call is a free_var, recreate it and give it a fixed input name.
+            if is_free_var:
+                need_update = True
+                new_args.append(relay.var(f"data_n_{new_input_idx}", var.checked_type))
+                new_input_idx += 1
+            else:
+                new_args.append(var)
+        # If the 'tvm.relay.expr.Call' has a free_var, recreate it with a new name 'data_n_*'.
+        if need_update:
+            value = tvm.relay.expr.Call(
+                value.op, new_args, value.attrs, value.type_args, value.span
+            )
+        return value, snode_dep, new_input_idx
+
+    def merge_constant_expr(constant_expr, expr):
+        # Merge the constant expression with an expression.
+        if not isinstance(constant_expr.body, tvm.relay.expr.Let):
+            return tvm.relay.expr.Let(constant_expr.var, constant_expr.value, expr)
+
+        return tvm.relay.expr.Let(
+            constant_expr.var, constant_expr.value, merge_constant_expr(constant_expr.body, expr)
+        )
+
+    def _recursion(anf, pipeline_mods, split_conf, constant_expr):
+        # Enumerate all operators of the compute graph, then split the compute graph into a
+        # group of subgraphs.
+        nonlocal operator_index_map
+        nonlocal new_input_idx
+        nonlocal snode_dep
+        cur_node_dep = snode_dep[len(snode_dep) - 1]
+        if isinstance(anf, tvm.relay.Function):
+            return tvm.relay.Function(
+                anf.params,
+                _recursion(anf.body, pipeline_mods, split_conf, constant_expr),
+                anf.ret_type,
+                anf.type_params,
+                anf.attrs,
+            )
+        if isinstance(anf, tvm.relay.expr.Let):
+            value = anf.value
+            # Record the constant expr to make sure all subgraphs can find the correct constant.
+            if isinstance(value, tvm.relay.expr.Constant):
+                if not constant_expr:
+                    constant_expr = tvm.relay.expr.Let(anf.var, value, anf.var)
+                else:
+                    constant_expr = tvm.relay.expr.Let(anf.var, value, constant_expr)
+            if isinstance(value, tvm.relay.expr.Call):
+                new_args = []
+                # Build the current var list.
+                cur_node_dep["nodes"][anf.var] = 0
+                # Get the dependency information of the nodes.
+                value, snode_dep, new_input_idx = parse_dependency(value, snode_dep, new_input_idx)
+                if isinstance(value.op, tvm.ir.Op):
+                    if value.op.name in operator_index_map:
+                        operator_index_map[value.op.name] += 1
+                    else:
+                        operator_index_map[value.op.name] = 0
+                    split_operator_name = split_conf[0]["op_name"] if split_conf else ""
+                    split_operator_index = split_conf[0]["op_index"] if split_conf else ""
+                    # If an operator name and its repeat count in the network match the values
+                    # of the 'split configuration', then this place is where we should do the
+                    # graph splitting.
+                    if (
+                        split_conf
+                        and split_operator_name in operator_index_map
+                        and operator_index_map[split_operator_name] >= split_operator_index
+                    ):
+                        # Do graph splitting.
+                        split_conf.pop(0)
+                        snode_dep.append({"nodes": {}, "ref_nodes": {}})
+                        ann = _recursion(
+                            anf.body,
+                            pipeline_mods,
+                            split_conf,
+                            constant_expr,
+                        )
+                        snode_dep.pop()
+                        dep_vars = get_dep_var(snode_dep)
+                        # When nodes of the current subgraph are dependency nodes of another
+                        # subgraph, we need to set them as outputs of the current subgraph.
+                        body = relay.Tuple(dep_vars) if len(dep_vars) > 1 else anf.var
+                        # When an operator of the current subgraph uses a constant of a previous
+                        # subgraph as the argument of a "relay.expr.Call", such a constant may
+                        # become a free variable if it does not exist in the current subgraph.
+                        # Merge the previous constants into the current subgraph to avoid this.
+                        if constant_expr:
+                            ann = merge_constant_expr(constant_expr, ann)
+                        ann = run_opt_pass(ann, transform.ToGraphNormalForm())
+                        mod = tvm.IRModule.from_expr(ann)
+                        pipeline_mods.insert(0, mod)
+                        # Return the last node of the current subgraph.
+                        return tvm.relay.expr.Let(anf.var, value, body)
+            return tvm.relay.expr.Let(
+                anf.var,
+                value,
+                _recursion(anf.body, pipeline_mods, split_conf, constant_expr),
+            )
+        else:
+            return anf
+
+    snode_dep = [{"nodes": {}, "ref_nodes": {}}]
+    pipeline_mods = []
+    operator_index_map = {}
+    # Used to track the new inputs created by graph splitting.
+    new_input_idx = 0
+    constant_expr = None
+    subgraph_split_conf = split_conf.copy()
+    # Bind the parameters.
+    if params:
+        expr = build_module.bind_params_by_name(expr, params)
+    anf = run_opt_pass(expr, transform.ToANormalForm())
+    anf = run_opt_pass(anf, transform.InferType())
+    ann = _recursion(
+        anf,
+        pipeline_mods,
+        subgraph_split_conf,
+        constant_expr,
+    )
+    ann = run_opt_pass(ann.body, transform.ToGraphNormalForm())
+    mod = tvm.IRModule.from_expr(ann)
+    pipeline_mods.insert(0, mod)
+    return pipeline_mods
+
+
+def get_network():
+    # Build the network that the tests below split into three pipelined subgraphs.
+    mods = []
+    dshape = (3, 3)
+    data = relay.var("data_0", relay.TensorType(dshape, "float32"))
+    data21 = relay.var("data_1", relay.TensorType(dshape, "float32"))
+    data_net1_output_1 = relay.var("data_0", relay.TensorType(dshape, "float32"))
+    data_net1_output_2 = relay.var("data_1", relay.TensorType(dshape, "float32"))
+    data_net2_output_1 = relay.var("data_0", relay.TensorType(dshape, "float32"))
+    mvalue1 = np.full((1), 1).astype("float32")
+    mvalue2 = np.full((1), 2).astype("float32")
+    mvalue3 = np.full((1), 3).astype("float32")
+    mv1 = relay.Constant(tvm.nd.array(mvalue1))
+    mv2 = relay.Constant(tvm.nd.array(mvalue2))
+    mv3 = relay.Constant(tvm.nd.array(mvalue3))
+    # The first model computes three values: net1_output1, net1_output2, and net1_output3.
+    net1_output1 = relay.add(data, mv1)
+    net1_output2 = relay.subtract(data, mv2)
+    net1_output3 = relay.concatenate((net1_output1, net1_output2), axis=0)
+    (net1_output3, _) = relay.split(net1_output3, indices_or_sections=2, axis=0)
+    net1_output3 = relay.add(net1_output3, mv2)
+    # The second model uses the output net1_output3 of the first model as its first input;
+    # its second input is data21.
+    net2 = relay.add(net1_output3, mv2)
+    net2 = relay.add(net2, data21)
+    net2_output = relay.add(net2, mv3)
+    # The third model uses the output net2_output of the second model as its first input
+    # and the output net1_output2 of the first model as its second input.
+    net3 = relay.multiply(net2_output, mv3)
+    net3 = relay.add(net3, net1_output2)
+    return tvm.IRModule.from_expr(relay.Function([data, data21], relay.Tuple([net3]))), dshape
+
+
+def get_split_mod():
+    mod, dshape = get_network()
+    split_conf = [{"op_name": "add", "op_index": 1}, {"op_name": "add", "op_index": 4}]
+    mods = graph_split(mod["main"], split_conf)
+    return mods, dshape
+
+
 def get_mannual_mod():
     # Get a list of modules representing subgraphs.
     mods = []
@@ -83,9 +266,8 @@ def get_manual_conf(mods, target):
         "mod_idx": 0,
         "cpu_affinity": "0",
         "output": [
-            {"output_idx": 0, "dependencies": [{"mod_idx": 1, "input_name": "data_0"}]},
-            {"output_idx": 1, "dependencies": [{"mod_idx": 2, "input_name": "data_0"}]},
-            {"output_idx": 2, "dependencies": [{"global_output_index": 0}]},
+            {"output_idx": 0, "dependencies": [{"mod_idx": 1, "input_name": "data_n_0"}]},
+            {"output_idx": 1, "dependencies": [{"mod_idx": 2, "input_name": "data_n_2"}]},
         ],
     }
     mod_config[mods[0]] = {
@@ -103,7 +285,7 @@ def get_manual_conf(mods, target):
         "mod_idx": 1,
         "cpu_affinity": "0",
         "output": [
-            {"output_idx": 0, "dependencies": [{"mod_idx": 2, "input_name": "data_1"}]},
+            {"output_idx": 0, "dependencies": [{"mod_idx": 2, "input_name": "data_n_1"}]},
        ],
     }
     mod_config[mods[1]] = {
@@ -120,7 +302,7 @@ def get_manual_conf(mods, target):
     pipe_config3 = {
         "mod_idx": 2,
         "cpu_affinity": "0",
-        "output": [{"output_idx": 0, "dependencies": [{"global_output_index": 1}]}],
+        "output": [{"output_idx": 0, "dependencies": [{"global_output_index": 0}]}],
     }
     mod_config[mods[2]] = {
         "pipeline": pipe_config3,
@@ -222,7 +404,7 @@ def test_pipe_runtime_error_check():
     # This function is used to trigger runtime error by applying wrong logic.
     if pipeline_executor_build.pipeline_executor_build_enabled():
         # Get three pipeline modules here.
-        (mod1, mod2, mod3), dshape = get_mannual_mod()
+        (mod1, mod2, mod3), dshape = get_split_mod()
 
         # The input or output name is illegal and expects a runtime error.
         pipe_error = pipeline_executor_build.PipelineConfig()
@@ -283,7 +465,7 @@ def test_pipeline():
     for target in target_list:
         affinity = os.sched_getaffinity(0)
         # Get the three pipeline modules here.
-        (mod1, mod2, mod3), dshape = get_mannual_mod()
+        (mod1, mod2, mod3), dshape = get_split_mod()
 
         # Prepare batch data for pipeline computation.
         datas = []
@@ -305,33 +487,29 @@ def test_pipeline():
         pipe_config["input"]["data_b"].connect(pipe_config[mod2]["input"]["data_1"])
 
-        # The mod1 output[0] will be connected to a input named "data_0" of mod2.
-        pipe_config[mod1]["output"][0].connect(pipe_config[mod2]["input"]["data_0"])
+        # The mod1 output[0] will be connected to an input named "data_n_0" of mod2.
+        pipe_config[mod1]["output"][0].connect(pipe_config[mod2]["input"]["data_n_0"])
 
-        # The mod1 output[1] will be connected to a input named "data_0" of mod3.
-        pipe_config[mod1]["output"][1].connect(pipe_config[mod3]["input"]["data_0"])
+        # The mod1 output[1] will be connected to an input named "data_n_2" of mod3.
+        pipe_config[mod1]["output"][1].connect(pipe_config[mod3]["input"]["data_n_2"])
 
-        # The mod2 output[2] will be connected to a input named "data_1" of mod3.
-        pipe_config[mod2]["output"][0].connect(pipe_config[mod3]["input"]["data_1"])
-
-        # The mod1 output[2] will be connected to pipeline output[0].
-        pipe_config[mod1]["output"][2].connect(pipe_config["output"]["0"])
+        # The mod2 output[0] will be connected to an input named "data_n_1" of mod3.
+        pipe_config[mod2]["output"][0].connect(pipe_config[mod3]["input"]["data_n_1"])
 
-        # The mod3 output[0] will be connected to pipeline output[1].
-        pipe_config[mod3]["output"][0].connect(pipe_config["output"]["1"])
-        # Print configueration (print(pipe_config)), the result looks like following.
+        # The mod3 output[0] will be connected to pipeline output[0].
+        pipe_config[mod3]["output"][0].connect(pipe_config["output"]["0"])
+        # Print the configuration (print(pipe_config)); the result looks like the following.
         #
         # Inputs
         #   |data_a: mod1:data_0
         #   |data_b: mod2:data_1
         #
         # output
-        #   |output(1) : mod1.output(2)
-        #   |output(2) : mod3.output(0)
+        #   |output(1) : mod3.output(0)
         #
         # connections
-        #   |mod1.output(0)-> mod2.data_0
-        #   |mod1.output(1)-> mod3.data_0
-        #   |mod2.output(0)-> mod3.data_1
+        #   |mod1.output(0)-> mod2.data_n_0
+        #   |mod1.output(1)-> mod3.data_n_2
+        #   |mod2.output(0)-> mod3.data_n_1
 
         # Set other parameters.
         pipe_config[mod1].target = target[0]
@@ -367,7 +545,7 @@ def test_pipeline():
 
         # Use the import function to create and initialize PipelineModule.
         pipeline_module_test = pipeline_executor.PipelineModule.load_library(config_file_name)
-        assert pipeline_module_test.num_outputs == 2
+        assert pipeline_module_test.num_outputs == 1
 
         input_map = pipeline_module_test.get_input_pipeline_map("data_b")
         assert input_map[0] == "1" and input_map[1] == "data_1"
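
For reviewers who want to try the new splitting helper by hand, the sketch below walks the same path as `get_split_mod`. It is illustrative only, not part of the patch: it assumes the definitions added above (`get_network`, `graph_split`) are in scope, e.g. by running it from inside this test file.

```python
# A minimal sketch of driving graph_split manually (assumes get_network and
# graph_split from this patch are in scope).
mod, dshape = get_network()

# Two split points ("add" #1 and "add" #4, counted in the order the operators
# are visited in A-normal form) yield three standalone subgraph modules,
# matching what get_split_mod returns.
split_conf = [{"op_name": "add", "op_index": 1}, {"op_name": "add", "op_index": 4}]
subgraphs = graph_split(mod["main"], split_conf)
assert len(subgraphs) == 3

# Inputs created at a split boundary are renamed "data_n_<idx>", which is why
# the pipeline configuration above connects module outputs to names such as
# "data_n_0" instead of the original "data_0".
for idx, subgraph in enumerate(subgraphs):
    print(f"subgraph {idx} inputs:", [p.name_hint for p in subgraph["main"].params])
```

The sketch stops after the split; `test_pipeline` goes on to wire the three modules together through `PipelineConfig`, which is where the `data_n_*` names printed above are consumed.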