Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
228 changes: 207 additions & 21 deletions python/tvm/contrib/pipeline_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,19 @@ def build(pipe_configs):
Common interface for pipeline executor factory modules.
"""
libs = {}
mod_n_configs = pipe_configs.get_config()
config = pipe_configs.get_config()
if "module_connection" not in config:
raise RuntimeError('"module_connection" is missing')
if "input_connection" not in config:
raise RuntimeError('"input_connection" is missing')

mod_n_configs = config["module_connection"]
config_len = len(mod_n_configs)
string_config = [{} for _ in range(config_len)]
module_string_config = [{} for _ in range(config_len)]
# Build the backend modules then create the config of the connections in string form.
for ir_mod, mod_config in mod_n_configs.items():
mconf = mod_config["pipeline"].copy()
mod_idx = mconf["mod_idx"]
pipe_config = mod_config["pipeline"].copy()
mod_idx = pipe_config["mod_idx"]
dev = mod_config["dev"]
target = mod_config["target"]
build_func = relay.build
Expand All @@ -70,11 +77,18 @@ def build(pipe_configs):
mod_name=mod_config["mod_name"],
)

mconf["dev"] = "{},{}".format(dev.device_type, dev.device_id)
pipe_config["dev"] = "{},{}".format(dev.device_type, dev.device_id)
# Create a pipeline configuration.
string_config[mod_idx] = mconf
module_string_config[mod_idx] = pipe_config
libs[mod_idx] = {"lib": lib, "dev": dev}

# Merge the "input_connection", the "param_connection" and the "module_connection" into one
# configuration.
string_config = {}
string_config["input_connection"] = config["input_connection"]
string_config["param_connection"] = config["param_connection"]
string_config["module_connection"] = module_string_config

return PipelineExecutorFactoryModule(libs, string_config)


Expand All @@ -93,8 +107,80 @@ def __init__(self, module):
else:
self.module = module
# Get the packed functions from the pipeline executor.
self._run = self.module["run"]
self._stop = self.module["stop"]
self._set_input = self.module["set_input"]
self._set_param = self.module["set_param"]
self._get_input = self.module["get_input"]
self._get_output = self.module["get_output"]
self._get_num_inputs = self.module["get_num_inputs"]
self._get_num_outputs = self.module["get_num_outputs"]

def run(self, sync=False):
    """Start pipeline execution.

    Parameters
    ----------
    sync : bool
        Whether to run the pipeline synchronously; defaults to asynchronous.
    """
    self._run(sync)

def stop(self):
    """Stop the pipeline executor by delegating to the packed "stop" function."""
    self._stop()

def set_input(self, key, value):
    """Set "value" as the global input named "key". A global input is
    defined during the pipeline configuration; it is connected with a
    graph module input.

    Parameters
    ----------
    key : str
        The input key.

    value : array_like
        The input value.
    """
    # Look up the pre-allocated input buffer; a missing key yields None.
    target = self._get_input(key)
    if target is None:
        raise RuntimeError("Could not find '%s' in pipeline's inputs" % key)
    target.copyfrom(value)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just realized that you need to do this because all buffers were created when initializing the pipeline executor. In this way we will double the memory usage of global inputs. Can we avoid allocating redundant buffers to reduce memory consumption, similar to graph executor and VM?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure — the logic here already follows what graph_executor does, and no redundant buffer is created. Could you give more detail about which part causes the problem?


def set_params(self, params_name, params_data):
    """Set "params_data" as the global params group named "params_name". The
    global params name is defined during the pipeline configuration creation;
    it is connected with the params of a graph module, which is a dictionary
    constructed from keys and values.

    Parameters
    ----------
    params_name : str
        The params name.

    params_data : dict of str to NDArray
        A dict mapping each param key name to its data.
    """
    for param_key, param_value in params_data.items():
        self._set_param(params_name, param_key, param_value)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Related to the above comment, this is a weird API.


def get_input(self, key):
    """Get an input via its input name.

    Parameters
    ----------
    key : str
        The input key.

    Returns
    -------
    data : NDArray
        The input data (None if "key" matches no input, per the check in
        set_input).
    """
    # Bug fix: the original discarded the looked-up value even though the
    # docstring promises a return.
    return self._get_input(key)

def get_output(self):
    """Get the outputs.

    Returns
    -------
    data : Array[NDArray]
        A list of output data.
    """
    return self._get_output()

@property
def num_outputs(self):
"""Get the number of outputs.
Expand All @@ -105,6 +191,16 @@ def num_outputs(self):
"""
return self._get_num_outputs()

@property
def num_inputs(self):
    """Number of global pipeline inputs.

    Returns
    -------
    count : int
        The number of inputs.
    """
    return self._get_num_inputs()

@staticmethod
def load_library(config_file_name):
"""Import files to create a pipeline executor.
Expand Down Expand Up @@ -154,7 +250,7 @@ class Binding:
The class who owns this interface.

io_type : str
The I/O type of this interface. It can only be "input" or "output".
The I/O type of this interface. It can only be "input" or "output" or "param".

name : str/integer
Name, for input it is string such as "data0", for output it is an integer such as 0.
Expand All @@ -171,7 +267,6 @@ def __init__(self, owner, io_type, name, data_type=None):
self.bindings = []
# Parents interfaces that this interface depend on.
self.parents = []

self.data_type = data_type

def get_name(self):
Expand Down Expand Up @@ -199,12 +294,48 @@ def is_pipeline_executor_interface(self):
return not isinstance(self.io_owner, PipelineConfig.ModuleWrapper)

def __repr__(self):
    """Render this interface and every interface bound to it as a string."""
    pieces = [" |{}: ".format(self.name)]
    for bound in self.bindings:
        owner_name, dep_name = bound.get_name()
        pieces.append("{0}:{1} ".format(owner_name, dep_name))
    return "".join(pieces)

def check_binding_dict(self, connection_dict):
    """Validate the dict form of this binding.

    Parameters
    ----------
    connection_dict : Dict[str, Any]
        The dict of an input or parameter connection.

    Raises
    ------
    RuntimeError
        If a required key is missing or the global mapping is not one-to-one.
    """
    # Fixed the "inteface_name" typo and the stray trailing quote in the
    # original messages; dropped f-prefixes on strings with no placeholders.
    if "interface_name" not in connection_dict:
        raise RuntimeError('"interface_name" is missing in global config!')
    if "connection" not in connection_dict:
        raise RuntimeError('"connection" is missing!')
    # The global interface mapping should be one-to-one.
    if not connection_dict["connection"]:
        raise RuntimeError("The global interface map is empty!")
    if len(connection_dict["connection"]) > 1:
        raise RuntimeError("A global interface maps multiple module interfaces!")
    if "mod_idx" not in connection_dict["connection"][0]:
        raise RuntimeError('"mod_idx" is missing!')

def get_binding_dict(self):
    """Return the binding information in the form of a dict.

    Returns
    -------
    data : Dict[str, Any]
        The binding information in the form of a dict, validated by
        check_binding_dict before being returned.
    """
    connection = []
    for bound in self.bindings:
        _, dep_name = bound.get_name()
        connection.append({"mod_idx": bound.get_owner_idx(), "interface_name": dep_name})

    result = {"interface_name": self.name, "connection": connection}
    self.check_binding_dict(result)
    return result

def check_dag_acyclic(self, start, inputs):
"""This is to check whether the DAG containing these input interfaces is acyclic.
Expand Down Expand Up @@ -245,6 +376,15 @@ def connect(self, binding):
if self.io_owner == binding.io_owner:
raise RuntimeError(f"Can not bind itself.")

if self.io_type == "param" and not self.is_pipeline_executor_interface():
raise RuntimeError(f'Only a pipeline executor can do "param" binding!')

if self.io_type == "param" and binding.io_type != "param":
raise RuntimeError(
f"A global param interface can only bind with a module \
param interface!"
)

if not self.is_pipeline_executor_interface() and self.io_type == "input":
raise RuntimeError(f"Module can only bind from output interface!")

Expand All @@ -265,7 +405,11 @@ def connect(self, binding):
if self.is_pipeline_executor_interface() and self.io_type == "output":
raise RuntimeError(f"Global output can not be used as binding start point.")

if self.is_pipeline_executor_interface() and binding.io_type != "input":
if (
self.is_pipeline_executor_interface()
and self.io_type == "input"
and binding.io_type != "input"
):
raise RuntimeError(f"Global input can only bind with module input.")

self.bindings.append(binding)
Expand Down Expand Up @@ -342,6 +486,7 @@ def __init__(self, mod=None):
self.output_type = InferType()(mod)["main"].checked_type.ret_type
self.input_bindings = PipelineConfig.BindingList(self, "input")
self.output_bindings = PipelineConfig.BindingList(self, "output")
self.param_binding = PipelineConfig.Binding(self, "param", "param")

def __eq__(self, other):
if isinstance(other, PipelineConfig.ModuleWrapper):
Expand All @@ -353,11 +498,13 @@ def __getitem__(self, key):
if isinstance(key, str):
if key == "input":
return self.input_bindings

if key == "output":
return self.output_bindings
if key == "param":
return self.param_binding
raise RuntimeError(f"{key} not found!")

raise RuntimeError(f"{key} not found!")
raise RuntimeError(f'The data type of "key" is not supported!')

def get_data_type(self, key, interface_type):
"""Get the module interface data type according to the key value and interface type.
Expand Down Expand Up @@ -411,14 +558,22 @@ def __init__(self):
self.mod_wrapper = {}
self.input_bindings = self.BindingList(self, "input")
self.output_bindings = self.BindingList(self, "output")
# The mapping of global parameters and module parameters.
self.param_bindings = self.BindingList(self, "param")

def __str__(self):
# Get configuration information as a string.

# Use topological sort to get correct module order.
self.dag_topology_sort()

# Get param dependencies.
param_dump = "Params\n"
for param_name in self.param_bindings.bindings:
inf = self.param_bindings.bindings[param_name]
param_dump += str(inf) + "\n"
# Get the input dependencies.
input_dump = "Inputs\n"
input_dump = "\nInputs\n"
for input_name in self.input_bindings.bindings:
inf = self.input_bindings.bindings[input_name]
input_dump += str(inf) + "\n"
Expand All @@ -444,7 +599,7 @@ def __str__(self):
for name in sorted(output.keys()):
output_dump += f" |output({name}) : {output[name]}\n"

return input_dump + output_dump + connections_dump
return param_dump + input_dump + output_dump + connections_dump

def __getitem__(self, key):
if isinstance(key, tvm.ir.module.IRModule):
Expand All @@ -457,6 +612,8 @@ def __getitem__(self, key):
return self.input_bindings
if key == "output":
return self.output_bindings
if key == "param":
return self.param_bindings

raise RuntimeError(f"{key} not found.")

Expand All @@ -468,6 +625,8 @@ def get_config(self):
# Use topological sort to get the correct order of modules.
self.dag_topology_sort()
mconfig = {}
module_connection = {}
input_connection = {}
for mod in self.mod_wrapper:
# Generate pipeline configuration.
mconf = {}
Expand Down Expand Up @@ -495,7 +654,7 @@ def get_config(self):
mconf["mod_idx"] = module.idx
mconf["output"] = output_conf

mconfig[mod] = {
module_connection[mod] = {
"pipeline": mconf,
"target_host": module.target_host,
"mod_name": "default",
Expand All @@ -505,6 +664,33 @@ def get_config(self):
"dev": module.dev,
}

# Create a mapping of global input and module input.
input_connection = []
for input_name in self.input_bindings.bindings:
input_dict = self.input_bindings.bindings[input_name].get_binding_dict()
if "interface_name" not in input_dict["connection"][0]:
raise RuntimeError(f"interface_name is missing in connection config!")
# Establish the mapping of global interface and the mapping of module interfaces.
input_map = {
"global_interface_name": input_dict["interface_name"],
"mod_idx": input_dict["connection"][0]["mod_idx"],
"module_interface_name": input_dict["connection"][0]["interface_name"],
}
input_connection.append(input_map)

# Create a mapping of global param and module param.
param_connection = []
for param_name in self.param_bindings.bindings:
param_dict = self.param_bindings.bindings[param_name].get_binding_dict()
param_map = {
"global_param_name": param_dict["interface_name"],
"mod_idx": param_dict["connection"][0]["mod_idx"],
}
param_connection.append(param_map)

mconfig["module_connection"] = module_connection
mconfig["input_connection"] = input_connection
mconfig["param_connection"] = param_connection
return mconfig

def dag_topology_sort(self):
Expand Down Expand Up @@ -627,13 +813,13 @@ def export_library(self, directory_path):
)

# Get the graph, lib, and parameters from GraphExecutorFactoryModule.
graph, lib, params = self.pipeline_mods[lib_index]["lib"]
lib = self.pipeline_mods[lib_index]["lib"]
# Export the lib, graph, and parameters to disk.
lib.export_library(mconfig["lib_name"])
with open(mconfig["json_name"], "w") as file_handle:
file_handle.write(graph)
file_handle.write(lib.graph_json)
with open(mconfig["params_name"], "wb") as file_handle:
file_handle.write(relay.save_param_dict(params))
file_handle.write(relay.save_param_dict(lib.params))

load_config.append(mconfig)

Expand Down
Loading