Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 94 additions & 17 deletions python/tvm/contrib/pipeline_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ def build(pipe_configs):
raise RuntimeError('"module_connection" is missing')
if "input_connection" not in config:
raise RuntimeError('"input_connection" is missing')
if "param_connection" not in config:
raise RuntimeError('"param_connection" is missing')

mod_n_configs = config["module_connection"]
config_len = len(mod_n_configs)
Expand Down Expand Up @@ -91,6 +93,7 @@ def build(pipe_configs):
# map of global input and subgraph input, and the "module_connection" is used to
# record module dependency.
string_config = {}
string_config["param_connection"] = config["param_connection"]
string_config["input_connection"] = config["input_connection"]
string_config["module_connection"] = module_string_config

Expand All @@ -114,6 +117,8 @@ def __init__(self, module):
# Get the packed functions from the pipeline executor.
self._get_num_outputs = self.module["get_num_outputs"]
self._get_input_pipeline_map = self.module["get_input_pipeline_map"]
self._get_params_group_pipeline_map = self.module["get_params_group_pipeline_map"]
self._set_param = self.module["set_param"]

def get_input_pipeline_map(self, name):
"""Using the "name" to get the corresponding subgraph index and also get the "input name"
Expand All @@ -125,6 +130,39 @@ def get_input_pipeline_map(self, name):
"""
return self._get_input_pipeline_map(name)

def get_params_group_pipeline_map(self, name):
"""Use the name of the parameters group to get the corresponding runtime module index.

Parameters
----------
name: str
The parameter group name.

Returns
-------
module_index: int
The index of the runtime module.
"""
return self._get_params_group_pipeline_map(name)

def set_params(self, params_group_name, params_data):
"""Set the parameter group value given the parameter group name. Note that the parameter
group name is declared in the pipeline executor config.

Parameters
----------
params_group_name : str
The parameters group name.

params_data : Dict[str, NDArray]
A map from parameter name to data.
"""
if not params_data:
raise RuntimeError('"params_data is empty!"')

for key, val in params_data.items():
self._set_param(params_group_name, key, val)

@property
def num_outputs(self):
"""Get the number of outputs.
Expand Down Expand Up @@ -311,9 +349,19 @@ def connect(self, binding):
if self.io_owner == binding.io_owner:
raise RuntimeError("Can not bind itself.")

if self.io_type == "param" and not self.is_pipeline_executor_interface():
raise RuntimeError(
'The "param" binding can only be used by a pipeline executor interface!'
)

if not self.is_pipeline_executor_interface() and self.io_type == "input":
raise RuntimeError("Module can only bind from output interface!")

if self.io_type == "param" and binding.io_type != "param":
raise RuntimeError(
'A global "param" interface can only be bind with a module "param" interface!'
)

if (
not self.is_pipeline_executor_interface()
and not binding.is_pipeline_executor_interface()
Expand Down Expand Up @@ -412,6 +460,7 @@ def __init__(self, mod=None):
self.output_type = InferType()(mod)["main"].checked_type.ret_type
self.input_bindings = PipelineConfig.BindingList(self, "input")
self.output_bindings = PipelineConfig.BindingList(self, "output")
self.param_binding = PipelineConfig.Binding(self, "param", "param")

def __eq__(self, other):
if isinstance(other, PipelineConfig.ModuleWrapper):
Expand All @@ -427,6 +476,9 @@ def __getitem__(self, key):
if key == "output":
return self.output_bindings

if key == "param":
return self.param_binding

raise RuntimeError(f"{key} not found!")

raise RuntimeError('The data type of "key" is not supported!')
Expand Down Expand Up @@ -483,14 +535,21 @@ def __init__(self):
self.mod_wrapper = {}
self.input_bindings = self.BindingList(self, "input")
self.output_bindings = self.BindingList(self, "output")
# There is a map of global parameters group and module index.
self.param_group_bindings = self.BindingList(self, "param")

def __str__(self):
# Get configuration information as a string.

# Use topological sort to get correct module order.
self.dag_topology_sort()
# Getting the parameters dependencies.
param_dump = "Params\n"
for param_name in self.param_group_bindings.bindings:
inf = self.param_group_bindings.bindings[param_name]
param_dump += str(inf) + "\n"
# Get the input dependencies.
input_dump = "Inputs\n"
input_dump = "\nInputs\n"
for input_name in self.input_bindings.bindings:
inf = self.input_bindings.bindings[input_name]
input_dump += str(inf) + "\n"
Expand All @@ -516,7 +575,7 @@ def __str__(self):
for name in sorted(output.keys()):
output_dump += f" |output({name}) : {output[name]}\n"

return input_dump + output_dump + connections_dump
return param_dump + input_dump + output_dump + connections_dump

def __getitem__(self, key):
if isinstance(key, tvm.ir.module.IRModule):
Expand All @@ -529,8 +588,12 @@ def __getitem__(self, key):
return self.input_bindings
if key == "output":
return self.output_bindings
if key == "param_group":
return self.param_group_bindings

raise RuntimeError(f"{key} not found!")

raise RuntimeError(f"{key} not found.")
raise RuntimeError(f'The key type "{type(key)}" is not supported!')

def get_config(self):
"""Get the configuration information in dictionary form, this configuration
Expand All @@ -541,7 +604,6 @@ def get_config(self):
self.dag_topology_sort()
mconfig = {}
module_connection = {}
input_connection = {}
for mod in self.mod_wrapper:
# Generate pipeline configuration.
mconf = {}
Expand Down Expand Up @@ -579,22 +641,33 @@ def get_config(self):
"dev": module.dev,
}

# Create a map of pipeline input and subgraph input.
input_connection = []
for input_name in self.input_bindings.bindings:
input_dict = self.input_bindings.bindings[input_name].get_binding_dict()
if "interface_name" not in input_dict["connection"][0]:
raise RuntimeError("interface_name is missing in connection config!")
# Creating the map of global interface and subgraph interface.
input_map = {
"global_interface_name": input_dict["interface_name"],
"mod_idx": input_dict["connection"][0]["mod_idx"],
"module_interface_name": input_dict["connection"][0]["interface_name"],
}
input_connection.append(input_map)
# Creating a map including pipeline inputs and subgraph inputs.
input_connection = []
for input_name in self.input_bindings.bindings:
input_dict = self.input_bindings.bindings[input_name].get_binding_dict()
if "interface_name" not in input_dict["connection"][0]:
raise RuntimeError("interface_name is missing in connection config!")
# Creating the map including global interfaces and subgraph interfaces.
input_map = {
"global_interface_name": input_dict["interface_name"],
"mod_idx": input_dict["connection"][0]["mod_idx"],
"module_interface_name": input_dict["connection"][0]["interface_name"],
}
input_connection.append(input_map)

# Create a map including global parameters groups and modules.
param_connection = []
for param_name in self.param_group_bindings.bindings:
param_dict = self.param_group_bindings.bindings[param_name].get_binding_dict()
param_map = {
"global_param_name": param_dict["interface_name"],
"mod_idx": param_dict["connection"][0]["mod_idx"],
}
param_connection.append(param_map)

mconfig["module_connection"] = module_connection
mconfig["input_connection"] = input_connection
mconfig["param_connection"] = param_connection
return mconfig

def dag_topology_sort(self):
Expand All @@ -613,8 +686,12 @@ def dag_topology_sort(self):

mlist += temp_list

mod_wrapper_sort = {}
for mod, i in zip(mlist, range(len(mlist))):
self.mod_wrapper[mod].set_idx_name(i)
mod_wrapper_sort[mod] = self.mod_wrapper[mod]

self.mod_wrapper = mod_wrapper_sort

def get_mod_idx(self, mod):
# Return the module index.
Expand Down
42 changes: 39 additions & 3 deletions src/runtime/pipeline/pipeline_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,27 @@ PackedFunc PipelineExecutor::GetFunction(const std::string& name,
} else if (name == "get_input_pipeline_map") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
if (String::CanConvertFrom(args[0])) {
*rv = this->GetInputPipeplineMapping(args[0].operator String());
*rv = this->GetInputPipeplineMap(args[0].operator String());
} else {
LOG(FATAL) << "Function only support the input name value in the form of string";
}
});
} else if (name == "get_params_group_pipeline_map") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
if (String::CanConvertFrom(args[0])) {
*rv = this->GetParamsGroupPipelineMap(args[0].operator String());
} else {
LOG(FATAL) << "Function only support the input name value in the form of string";
}
});
} else if (name == "set_param") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
if (String::CanConvertFrom(args[0]) && String::CanConvertFrom(args[1])) {
this->SetParam(args[0].operator String(), args[1].operator String(), args[2]);
} else {
LOG(FATAL) << "Function only support the parameter name and the key in the form of string";
}
});
} else {
LOG(FATAL) << "Unknown packed function: " << name;
return PackedFunc();
Expand All @@ -55,11 +71,20 @@ PackedFunc PipelineExecutor::GetFunction(const std::string& name,
* \param The global input name.
* \return Returning the index and the input interface name of corresponding subgraph.
*/
Array<String> PipelineExecutor::GetInputPipeplineMapping(std::string input_name) {
Array<String> PipelineExecutor::GetInputPipeplineMap(std::string input_name) {
std::pair<int, std::string> map = input_connection_config[input_name];
return {std::to_string(map.first), map.second};
}

/*!
* \brief Return the module index for the parameters group name.
* \param name The parameters group name.
* \return int The module index.
*/
int PipelineExecutor::GetParamsGroupPipelineMap(const std::string& name) {
return param_connection_config[name];
}

/*!
* \brief Use the mod_config information to create a graph runtime list.
* \param mod_config The config information that generates by the export library function call.
Expand Down Expand Up @@ -115,7 +140,18 @@ std::vector<Module> PipelineExecutor::CreateGraphModules(const ModuleConfig& mod
}
return ret;
}

/*!
* \brief Set a parameter into a graph module.
* \param param_group_name The parameters group name.
* \param param_key_name The parameter key name.
* \param data_in The parameter data.
*/
void PipelineExecutor::SetParam(std::string param_group_name, std::string param_key_name,
DLTensor* data_in) {
// Get the module index from the param name.
int module_index = this->GetParamsGroupPipelineMap(param_group_name);
// TODO(huajsj): set the parameters into runtime module.
}
/*!
* \brief Initialize the pipeline executor with a list of modules to be pipelined
* and config in JSON format.
Expand Down
21 changes: 19 additions & 2 deletions src/runtime/pipeline/pipeline_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,27 @@ class TVM_DLL PipelineExecutor : public ModuleNode {
* \param The global input name.
* \return Returning the index and the input interface name of corresponding subgraph.
*/
Array<String> GetInputPipeplineMapping(std::string input_name);
Array<String> GetInputPipeplineMap(std::string input_name);
/*!
* \brief This function return a module index for the global parameters group name.
* \param name The parameters group name.
* \return Returning a runtime module index.
*/
int GetParamsGroupPipelineMap(const std::string& name);
/*!
* \brief Use the parameters group name to get the specific backend runtime then use
* the param_key_name to set param data for the said backend runtime.
* \param param_group_name The parameters group name.
* \param param_key_name The parameter key name.
* \param data_in The parameter value.
*/
void SetParam(std::string param_group_name, std::string param_key_name, DLTensor* data_in);
/*!
* \brief Get the number of outputs.
*
* \return The number of outputs.
*/
int NumOutputs() const { return num_outputs_; }

/*!\brief Load the module files information.*/
ModuleConfig& LoadModuleConfig(dmlc::JSONReader* reader) {
reader->BeginArray();
Expand Down Expand Up @@ -126,6 +139,8 @@ class TVM_DLL PipelineExecutor : public ModuleNode {
ConfigPipelineExecution pipeline_config_;
/*!\brief The map of global input and subgraph input.*/
InputConnectionConfig input_connection_config;
/*!\brief The map includes global parameters groups and runtime modules.*/
ParamConnectionConfig param_connection_config;
/*!\brief The module information used to create the graph runtimes.*/
ModuleConfig mod_config_;
/*!\brief How many outputs are in this pipeline executor.*/
Expand All @@ -139,6 +154,8 @@ class TVM_DLL PipelineExecutor : public ModuleNode {
reader->Read(&pipeline_config_);
} else if (key == "input_connection") {
reader->Read(&input_connection_config);
} else if (key == "param_connection") {
reader->Read(&param_connection_config);
} else {
LOG(FATAL) << "do not support key " << key;
}
Expand Down
42 changes: 42 additions & 0 deletions src/runtime/pipeline/pipeline_struct.h
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,48 @@ struct InputConnectionConfig {
}
};

/*!
* \brief A map includes global module parameters groups and graph modudles.
*/
struct ParamConnectionConfig {
/*!\brief Mapping from the name of a global module parameters group to the index of a runtime
* module.
*/
std::unordered_map<std::string, int> param_connection;
bool Empty() { return param_connection.empty(); }
int operator[](const std::string key) {
if (param_connection.find(key) == param_connection.end()) {
LOG(FATAL) << "do not support key " << key;
}
return param_connection[key];
}
/*!
* \brief Load from JSONReader.
* \param reader Json reader.
*/
void Load(dmlc::JSONReader* reader) {
reader->BeginArray();
while (reader->NextArrayItem()) {
reader->BeginObject();
std::string key;
std::string global_param_name;
int mod_idx = -1;
while (reader->NextObjectItem(&key)) {
if (key == "global_param_name") {
reader->Read(&global_param_name);
} else if (key == "mod_idx") {
reader->Read(&mod_idx);
} else {
LOG(FATAL) << "do not support key " << key;
}
}
ICHECK(mod_idx >= 0) << "Invalid module index value " << mod_idx;
ICHECK(!global_param_name.empty()) << "Invalid global parameter group name value";
param_connection[global_param_name] = mod_idx;
}
}
};

/*!
* \brief The information used to initialize the graph executor module, the information
* come from the export library function call.
Expand Down
Loading