224 changes: 201 additions & 23 deletions tests/python/relay/test_pipeline_executor.py
@@ -22,12 +22,195 @@
import tvm
import tvm.testing
from tvm import relay
from tvm.relay import transform
from tvm.relay import transform, build_module
from tvm.relay.testing import run_opt_pass
from tvm.contrib import graph_executor, pipeline_executor, pipeline_executor_build
from tvm._ffi import get_global_func
from tvm.contrib import cc as _cc


def graph_split(expr, split_conf, params=None):
"""Splitting the graph into a list of subgraphs"""

def get_dep_var(sub_var_dep):
return [var for var in sub_var_dep[len(sub_var_dep) - 1]["ref_nodes"]]

def parse_dependency(value, snode_dep, new_input_idx):
new_args = []
need_update = False
for var in value.args:
is_free_var = False
for dep in snode_dep[:-1]:
if var in dep["nodes"]:
# Mark the previous subgraph node as a dependency.
dep["nodes"][var] += 1
dep["ref_nodes"][var] = dep["nodes"][var]
                    # The var of this call is a free_var.
                    is_free_var = True
            # If the var of this call is a free_var, recreate it with a fixed input name.
if is_free_var:
need_update = True
new_args.append(relay.var(f"data_n_{new_input_idx}", var.checked_type))
new_input_idx += 1
else:
new_args.append(var)
        # If the 'tvm.relay.expr.Call' has a free_var, recreate it with a new name of the form 'data_n_*'.
if need_update:
value = tvm.relay.expr.Call(
value.op, new_args, value.attrs, value.type_args, value.span
)
return value, snode_dep, new_input_idx

def merge_constant_expr(constant_expr, expr):
        # Merge the constant expression into the given expression.
if not isinstance(constant_expr.body, tvm.relay.expr.Let):
return tvm.relay.expr.Let(constant_expr.var, constant_expr.value, expr)

return tvm.relay.expr.Let(
constant_expr.var, constant_expr.value, merge_constant_expr(constant_expr.body, expr)
)
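
    # Worked example (an illustrative assumption, not code from the original test):
    # merging the constant chain Let(c1, 1, Let(c2, 2, c2)) with a body add(x, c1)
    # produces Let(c1, 1, Let(c2, 2, add(x, c1))), so constants bound in an upstream
    # subgraph stay available inside the current subgraph's expression.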

def _recursion(anf, pipeline_mods, split_conf, constant_expr):
        # Enumerate all operators of the compute graph, then split the compute graph into a
        # group of subgraphs.
nonlocal operator_index_map
nonlocal new_input_idx
nonlocal snode_dep
cur_node_dep = snode_dep[len(snode_dep) - 1]
if isinstance(anf, tvm.relay.Function):
return tvm.relay.Function(
anf.params,
_recursion(anf.body, pipeline_mods, split_conf, constant_expr),
anf.ret_type,
anf.type_params,
anf.attrs,
)
if isinstance(anf, tvm.relay.expr.Let):
value = anf.value
            # Record the constant expr to make sure all subgraphs can find the correct constant.
if isinstance(value, tvm.relay.expr.Constant):
if not constant_expr:
constant_expr = tvm.relay.expr.Let(anf.var, value, anf.var)
else:
constant_expr = tvm.relay.expr.Let(anf.var, value, constant_expr)
if isinstance(value, tvm.relay.expr.Call):
new_args = []
                # Build the current var list.
cur_node_dep["nodes"][anf.var] = 0
# Get the dependency information of the nodes.
value, snode_dep, new_input_idx = parse_dependency(value, snode_dep, new_input_idx)
if isinstance(value.op, tvm.ir.Op):
if value.op.name in operator_index_map:
operator_index_map[value.op.name] += 1
else:
operator_index_map[value.op.name] = 0
split_operator_name = split_conf[0]["op_name"] if split_conf else ""
split_operator_index = split_conf[0]["op_index"] if split_conf else ""
                # If an operator name and its repetition count in the network match the values
                # in the 'split configuration', then this is the place where we should split
                # the graph.
if (
split_conf
and split_operator_name in operator_index_map
and operator_index_map[split_operator_name] >= split_operator_index
):
# Do graph splitting.
split_conf.pop(0)
snode_dep.append({"nodes": {}, "ref_nodes": {}})
ann = _recursion(
anf.body,
pipeline_mods,
split_conf,
constant_expr,
)
snode_dep.pop()
dep_vars = get_dep_var(snode_dep)
                    # When nodes of the current subgraph are dependency nodes of another
                    # subgraph, we need to set them as outputs of the current subgraph.
body = relay.Tuple(dep_vars) if len(dep_vars) > 1 else anf.var
                    # When an operator of the current subgraph uses a constant from a previous
                    # subgraph as the argument of a "relay.expr.Call", that constant may become
                    # a free variable if it does not exist in the current subgraph. Merge the
                    # previous constants into the current subgraph to avoid this issue.
if constant_expr:
ann = merge_constant_expr(constant_expr, ann)
ann = run_opt_pass(ann, transform.ToGraphNormalForm())
mod = tvm.IRModule.from_expr(ann)
pipeline_mods.insert(0, mod)
# Return the last node of the current subgraph.
return tvm.relay.expr.Let(anf.var, value, body)
return tvm.relay.expr.Let(
anf.var,
value,
_recursion(anf.body, pipeline_mods, split_conf, constant_expr),
)
else:
return anf

snode_dep = [{"nodes": {}, "ref_nodes": {}}]
pipeline_mods = []
operator_index_map = {}
    # Used to track new inputs created by graph splitting.
new_input_idx = 0
constant_expr = None
subgraph_split_conf = split_conf.copy()
    # Bind the parameters.
if params:
expr = build_module.bind_params_by_name(expr, params)
anf = run_opt_pass(expr, transform.ToANormalForm())
anf = run_opt_pass(anf, transform.InferType())
ann = _recursion(
anf,
pipeline_mods,
subgraph_split_conf,
constant_expr,
)
ann = run_opt_pass(ann.body, transform.ToGraphNormalForm())
mod = tvm.IRModule.from_expr(ann)
pipeline_mods.insert(0, mod)
return pipeline_mods
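

# A minimal usage sketch of graph_split (illustrative only: the toy network and the
# helper name `_example_graph_split_usage` are assumptions, not part of the original
# test, and the helper is never invoked by the test suite).
def _example_graph_split_usage():
    dshape = (3, 3)
    data = relay.var("data_0", relay.TensorType(dshape, "float32"))
    const = relay.Constant(tvm.nd.array(np.full((1), 1).astype("float32")))
    out1 = relay.add(data, const)  # "add" occurrence with op_index 0
    out2 = relay.multiply(out1, const)
    out3 = relay.add(out2, const)  # "add" occurrence with op_index 1
    mod = tvm.IRModule.from_expr(relay.Function([data], out3))
    # A single split point at the first "add" should yield two subgraph modules:
    # the first ends at that "add"; the second consumes its output as "data_n_0".
    subgraphs = graph_split(mod["main"], [{"op_name": "add", "op_index": 0}])
    assert len(subgraphs) == 2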


def get_network():
    # Build a network that will later be split into a list of subgraph modules.
mods = []
dshape = (3, 3)
data = relay.var("data_0", relay.TensorType(dshape, "float32"))
data21 = relay.var("data_1", relay.TensorType(dshape, "float32"))
data_net1_output_1 = relay.var("data_0", relay.TensorType(dshape, "float32"))
data_net1_output_2 = relay.var("data_1", relay.TensorType(dshape, "float32"))
data_net2_output_1 = relay.var("data_0", relay.TensorType(dshape, "float32"))
mvalue1 = np.full((1), 1).astype("float32")
mvalue2 = np.full((1), 2).astype("float32")
mvalue3 = np.full((1), 3).astype("float32")
mv1 = relay.Constant(tvm.nd.array(mvalue1))
mv2 = relay.Constant(tvm.nd.array(mvalue2))
mv3 = relay.Constant(tvm.nd.array(mvalue3))
# There are three outputs in the first model.
net1_output1 = relay.add(data, mv1)
net1_output2 = relay.subtract(data, mv2)
net1_output3 = relay.concatenate((net1_output1, net1_output2), axis=0)
(net1_output3, _) = relay.split(net1_output3, indices_or_sections=2, axis=0)
net1_output3 = relay.add(net1_output3, mv2)
    # The second model uses the output named net1_output3 of the first model as its first
    # input; its second input is data21.
net2 = relay.add(net1_output3, mv2)
net2 = relay.add(net2, data21)
net2_output = relay.add(net2, mv3)
# The third model uses the output named net2_output of the second model as the first input
# and uses the output named net1_output2 of the first model as the second input.
net3 = relay.multiply(net2_output, mv3)
net3 = relay.add(net3, net1_output2)
return tvm.IRModule.from_expr(relay.Function([data, data21], relay.Tuple([net3]))), dshape


def get_split_mod():
mod, dshape = get_network()
split_conf = [{"op_name": "add", "op_index": 1}, {"op_name": "add", "op_index": 4}]
mods = graph_split(mod["main"], split_conf)
return mods, dshape
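

# Sanity note (illustrative, not an assertion from the original test): the two split
# points in get_split_mod should produce exactly three subgraph modules, matching the
# (mod1, mod2, mod3) unpacking used by the tests below. The helper name is hypothetical
# and the function is never invoked by the test suite.
def _example_split_mod_count():
    mods, _ = get_split_mod()
    assert len(mods) == 3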


def get_mannual_mod():
# Get a list of modules representing subgraphs.
mods = []
@@ -83,9 +266,8 @@ def get_manual_conf(mods, target):
"mod_idx": 0,
"cpu_affinity": "0",
"output": [
{"output_idx": 0, "dependencies": [{"mod_idx": 1, "input_name": "data_0"}]},
{"output_idx": 1, "dependencies": [{"mod_idx": 2, "input_name": "data_0"}]},
{"output_idx": 2, "dependencies": [{"global_output_index": 0}]},
{"output_idx": 0, "dependencies": [{"mod_idx": 1, "input_name": "data_n_0"}]},
{"output_idx": 1, "dependencies": [{"mod_idx": 2, "input_name": "data_n_2"}]},
],
}
mod_config[mods[0]] = {
@@ -103,7 +285,7 @@
"mod_idx": 1,
"cpu_affinity": "0",
"output": [
{"output_idx": 0, "dependencies": [{"mod_idx": 2, "input_name": "data_1"}]},
{"output_idx": 0, "dependencies": [{"mod_idx": 2, "input_name": "data_n_1"}]},
],
}
mod_config[mods[1]] = {
@@ -120,7 +302,7 @@
pipe_config3 = {
"mod_idx": 2,
"cpu_affinity": "0",
"output": [{"output_idx": 0, "dependencies": [{"global_output_index": 1}]}],
"output": [{"output_idx": 0, "dependencies": [{"global_output_index": 0}]}],
}
mod_config[mods[2]] = {
"pipeline": pipe_config3,
@@ -222,7 +404,7 @@ def test_pipe_runtime_error_check():
    # This function is used to trigger a runtime error by applying wrong logic.
if pipeline_executor_build.pipeline_executor_build_enabled():
# Get three pipeline modules here.
(mod1, mod2, mod3), dshape = get_mannual_mod()
(mod1, mod2, mod3), dshape = get_split_mod()

# The input or output name is illegal and expects a runtime error.
pipe_error = pipeline_executor_build.PipelineConfig()
Expand Down Expand Up @@ -283,7 +465,7 @@ def test_pipeline():
for target in target_list:
affinity = os.sched_getaffinity(0)
# Get the three pipeline modules here.
(mod1, mod2, mod3), dshape = get_mannual_mod()
(mod1, mod2, mod3), dshape = get_split_mod()

# Prepare batch data for pipeline computation.
datas = []
@@ -305,33 +487,29 @@
pipe_config["input"]["data_b"].connect(pipe_config[mod2]["input"]["data_1"])

        # The mod1 output[0] will be connected to an input named "data_n_0" of mod2.
pipe_config[mod1]["output"][0].connect(pipe_config[mod2]["input"]["data_0"])
pipe_config[mod1]["output"][0].connect(pipe_config[mod2]["input"]["data_n_0"])

        # The mod1 output[1] will be connected to an input named "data_n_2" of mod3.
pipe_config[mod1]["output"][1].connect(pipe_config[mod3]["input"]["data_0"])
pipe_config[mod1]["output"][1].connect(pipe_config[mod3]["input"]["data_n_2"])

        # The mod2 output[0] will be connected to an input named "data_n_1" of mod3.
pipe_config[mod2]["output"][0].connect(pipe_config[mod3]["input"]["data_1"])

# The mod1 output[2] will be connected to pipeline output[0].
pipe_config[mod1]["output"][2].connect(pipe_config["output"]["0"])
pipe_config[mod2]["output"][0].connect(pipe_config[mod3]["input"]["data_n_1"])

# The mod3 output[0] will be connected to pipeline output[1].
pipe_config[mod3]["output"][0].connect(pipe_config["output"]["1"])
# Print configueration (print(pipe_config)), the result looks like following.
# The mod3 output[0] will be connected to pipeline output[0].
pipe_config[mod3]["output"][0].connect(pipe_config["output"]["0"])
        # Print the configuration (print(pipe_config)); the result looks like the following.
#
# Inputs
# |data_a: mod1:data_0
# |data_b: mod2:data_1
#
# output
# |output(1) : mod1.output(2)
# |output(2) : mod3.output(0)
# |output(1) : mod3.output(0)
#
# connections
# |mod1.output(0)-> mod2.data_0
# |mod1.output(1)-> mod3.data_0
# |mod2.output(0)-> mod3.data_1
# |mod1.output(0)-> mod2.data_n_0
# |mod1.output(1)-> mod3.data_n_2
# |mod2.output(0)-> mod3.data_n_1

# Set other parameters.
pipe_config[mod1].target = target[0]
@@ -367,7 +545,7 @@ def test_pipeline():

# Use the import function to create and initialize PipelineModule.
pipeline_module_test = pipeline_executor.PipelineModule.load_library(config_file_name)
assert pipeline_module_test.num_outputs == 2
assert pipeline_module_test.num_outputs == 1

input_map = pipeline_module_test.get_input_pipeline_map("data_b")
assert input_map[0] == "1" and input_map[1] == "data_1"
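
        # A hypothetical companion check (an assumption, not part of the original test):
        # "data_a" is connected to mod1's input "data_0", and mod1's runtime index is
        # expected to be "0", so the analogous lookup would be:
        # input_map = pipeline_module_test.get_input_pipeline_map("data_a")
        # assert input_map[0] == "0" and input_map[1] == "data_0"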