diff --git a/python/tvm/contrib/pipeline_executor.py b/python/tvm/contrib/pipeline_executor.py
index 7cabb8b3d2ed..c75aa3dad43b 100644
--- a/python/tvm/contrib/pipeline_executor.py
+++ b/python/tvm/contrib/pipeline_executor.py
@@ -54,6 +54,8 @@ def build(pipe_configs):
         raise RuntimeError('"module_connection" is missing')
     if "input_connection" not in config:
         raise RuntimeError('"input_connection" is missing')
+    if "param_connection" not in config:
+        raise RuntimeError('"param_connection" is missing')
 
     mod_n_configs = config["module_connection"]
     config_len = len(mod_n_configs)
@@ -91,6 +93,7 @@ def build(pipe_configs):
     # map of global input and subgraph input, and the "module_connection" is used to
     # record module dependency.
     string_config = {}
+    string_config["param_connection"] = config["param_connection"]
     string_config["input_connection"] = config["input_connection"]
     string_config["module_connection"] = module_string_config
 
@@ -114,6 +117,8 @@ def __init__(self, module):
         # Get the packed functions from the pipeline executor.
         self._get_num_outputs = self.module["get_num_outputs"]
         self._get_input_pipeline_map = self.module["get_input_pipeline_map"]
+        self._get_params_group_pipeline_map = self.module["get_params_group_pipeline_map"]
+        self._set_param = self.module["set_param"]
 
     def get_input_pipeline_map(self, name):
         """Using the "name" to get the corresponding subgraph index and also get the "input name"
@@ -125,6 +130,39 @@ def get_input_pipeline_map(self, name):
         """
         return self._get_input_pipeline_map(name)
 
+    def get_params_group_pipeline_map(self, name):
+        """Use the name of the parameters group to get the corresponding runtime module index.
+
+        Parameters
+        ----------
+        name: str
+            The parameters group name.
+
+        Returns
+        -------
+        module_index: int
+            The index of the runtime module.
+        """
+        return self._get_params_group_pipeline_map(name)
+
+    def set_params(self, params_group_name, params_data):
+        """Set the parameter group value given the parameter group name. Note that the parameter
+        group name is declared in the pipeline executor config.
+
+        Parameters
+        ----------
+        params_group_name : str
+            The parameters group name.
+
+        params_data : Dict[str, NDArray]
+            A map from parameter name to data.
+        """
+        if not params_data:
+            raise RuntimeError('"params_data" is empty!')
+
+        for key, val in params_data.items():
+            self._set_param(params_group_name, key, val)
+
     @property
     def num_outputs(self):
         """Get the number of outputs.
@@ -311,9 +349,19 @@ def connect(self, binding):
             if self.io_owner == binding.io_owner:
                 raise RuntimeError("Can not bind itself.")
 
+            if self.io_type == "param" and not self.is_pipeline_executor_interface():
+                raise RuntimeError(
+                    'The "param" binding can only be used by a pipeline executor interface!'
+                )
+
             if not self.is_pipeline_executor_interface() and self.io_type == "input":
                 raise RuntimeError("Module can only bind from output interface!")
 
+            if self.io_type == "param" and binding.io_type != "param":
+                raise RuntimeError(
+                    'A global "param" interface can only be bound to a module "param" interface!'
+                )
+
             if (
                 not self.is_pipeline_executor_interface()
                 and not binding.is_pipeline_executor_interface()
@@ -412,6 +460,7 @@ def __init__(self, mod=None):
             self.output_type = InferType()(mod)["main"].checked_type.ret_type
             self.input_bindings = PipelineConfig.BindingList(self, "input")
             self.output_bindings = PipelineConfig.BindingList(self, "output")
+            self.param_binding = PipelineConfig.Binding(self, "param", "param")
 
         def __eq__(self, other):
             if isinstance(other, PipelineConfig.ModuleWrapper):
@@ -427,6 +476,9 @@ def __getitem__(self, key):
                 if key == "output":
                     return self.output_bindings
 
+                if key == "param":
+                    return self.param_binding
+
                 raise RuntimeError(f"{key} not found!")
 
             raise RuntimeError('The data type of "key" is not supported!')
@@ -483,14 +535,21 @@ def __init__(self):
         self.mod_wrapper = {}
         self.input_bindings = self.BindingList(self, "input")
         self.output_bindings = self.BindingList(self, "output")
+        # A map of global parameters groups and module indexes.
+        self.param_group_bindings = self.BindingList(self, "param")
 
     def __str__(self):
         # Get configuration information as a string.
 
        # Use topological sort to get correct module order.
         self.dag_topology_sort()
+        # Get the parameters group dependencies.
+        param_dump = "Params\n"
+        for param_name in self.param_group_bindings.bindings:
+            inf = self.param_group_bindings.bindings[param_name]
+            param_dump += str(inf) + "\n"
         # Get the input dependencies.
-        input_dump = "Inputs\n"
+        input_dump = "\nInputs\n"
         for input_name in self.input_bindings.bindings:
             inf = self.input_bindings.bindings[input_name]
             input_dump += str(inf) + "\n"
@@ -516,7 +575,7 @@ def __str__(self):
         for name in sorted(output.keys()):
             output_dump += f"  |output({name}) : {output[name]}\n"
 
-        return input_dump + output_dump + connections_dump
+        return param_dump + input_dump + output_dump + connections_dump
 
     def __getitem__(self, key):
         if isinstance(key, tvm.ir.module.IRModule):
@@ -529,8 +588,12 @@ def __getitem__(self, key):
                 return self.input_bindings
             if key == "output":
                 return self.output_bindings
+            if key == "param_group":
+                return self.param_group_bindings
+
+            raise RuntimeError(f"{key} not found!")
 
-        raise RuntimeError(f"{key} not found.")
+        raise RuntimeError(f'The key type "{type(key)}" is not supported!')
 
     def get_config(self):
         """Get the configuration information in dictionary form, this configuration
@@ -541,7 +604,6 @@ def get_config(self):
         self.dag_topology_sort()
         mconfig = {}
         module_connection = {}
-        input_connection = {}
        for mod in self.mod_wrapper:
             # Generate pipeline configuration.
             mconf = {}
@@ -579,22 +641,33 @@ def get_config(self):
                 "dev": module.dev,
             }
 
-            # Create a map of pipeline input and subgraph input.
-            input_connection = []
-            for input_name in self.input_bindings.bindings:
-                input_dict = self.input_bindings.bindings[input_name].get_binding_dict()
-                if "interface_name" not in input_dict["connection"][0]:
-                    raise RuntimeError("interface_name is missing in connection config!")
-                # Creating the map of global interface and subgraph interface.
-                input_map = {
-                    "global_interface_name": input_dict["interface_name"],
-                    "mod_idx": input_dict["connection"][0]["mod_idx"],
-                    "module_interface_name": input_dict["connection"][0]["interface_name"],
-                }
-                input_connection.append(input_map)
+        # Create a map of pipeline inputs and subgraph inputs.
+        input_connection = []
+        for input_name in self.input_bindings.bindings:
+            input_dict = self.input_bindings.bindings[input_name].get_binding_dict()
+            if "interface_name" not in input_dict["connection"][0]:
+                raise RuntimeError("interface_name is missing in connection config!")
+            # Create the map of global interfaces and subgraph interfaces.
+            input_map = {
+                "global_interface_name": input_dict["interface_name"],
+                "mod_idx": input_dict["connection"][0]["mod_idx"],
+                "module_interface_name": input_dict["connection"][0]["interface_name"],
+            }
+            input_connection.append(input_map)
+
+        # Create a map of global parameters groups and modules.
+        param_connection = []
+        for param_name in self.param_group_bindings.bindings:
+            param_dict = self.param_group_bindings.bindings[param_name].get_binding_dict()
+            param_map = {
+                "global_param_name": param_dict["interface_name"],
+                "mod_idx": param_dict["connection"][0]["mod_idx"],
+            }
+            param_connection.append(param_map)
 
         mconfig["module_connection"] = module_connection
         mconfig["input_connection"] = input_connection
+        mconfig["param_connection"] = param_connection
 
         return mconfig
 
     def dag_topology_sort(self):
@@ -613,8 +686,12 @@ def dag_topology_sort(self):
 
             mlist += temp_list
 
+        mod_wrapper_sort = {}
         for mod, i in zip(mlist, range(len(mlist))):
             self.mod_wrapper[mod].set_idx_name(i)
+            mod_wrapper_sort[mod] = self.mod_wrapper[mod]
+
+        self.mod_wrapper = mod_wrapper_sort
 
     def get_mod_idx(self, mod):
         # Return the module index.
diff --git a/src/runtime/pipeline/pipeline_executor.cc b/src/runtime/pipeline/pipeline_executor.cc
index 32414c607df6..0ca291a2fbbe 100644
--- a/src/runtime/pipeline/pipeline_executor.cc
+++ b/src/runtime/pipeline/pipeline_executor.cc
@@ -37,11 +37,27 @@ PackedFunc PipelineExecutor::GetFunction(const std::string& name,
   } else if (name == "get_input_pipeline_map") {
     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
       if (String::CanConvertFrom(args[0])) {
-        *rv = this->GetInputPipeplineMapping(args[0].operator String());
+        *rv = this->GetInputPipeplineMap(args[0].operator String());
       } else {
         LOG(FATAL) << "Function only support the input name value in the form of string";
       }
     });
+  } else if (name == "get_params_group_pipeline_map") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+      if (String::CanConvertFrom(args[0])) {
+        *rv = this->GetParamsGroupPipelineMap(args[0].operator String());
+      } else {
+        LOG(FATAL) << "Function only supports the parameters group name in the form of string";
+      }
+    });
+  } else if (name == "set_param") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+      if (String::CanConvertFrom(args[0]) && String::CanConvertFrom(args[1])) {
+        this->SetParam(args[0].operator String(), args[1].operator String(), args[2]);
+      } else {
+        LOG(FATAL) << "Function only supports the parameters group name and the key in the form of string";
+      }
+    });
   } else {
     LOG(FATAL) << "Unknown packed function: " << name;
     return PackedFunc();
@@ -55,11 +71,20 @@ PackedFunc PipelineExecutor::GetFunction(const std::string& name,
  * \param The global input name.
  * \return Returning the index and the input interface name of corresponding subgraph.
  */
-Array<String> PipelineExecutor::GetInputPipeplineMapping(std::string input_name) {
+Array<String> PipelineExecutor::GetInputPipeplineMap(std::string input_name) {
   std::pair<int, std::string> map = input_connection_config[input_name];
   return {std::to_string(map.first), map.second};
 }
 
+/*!
+ * \brief Return the module index for the parameters group name.
+ * \param name The parameters group name.
+ * \return The module index.
+ */
+int PipelineExecutor::GetParamsGroupPipelineMap(const std::string& name) {
+  return param_connection_config[name];
+}
+
 /*!
  * \brief Use the mod_config information to create a graph runtime list.
  * \param mod_config The config information that generates by the export library function call.
@@ -115,7 +140,18 @@ std::vector<Module> PipelineExecutor::CreateGraphModules(const ModuleConfig& mod
   }
   return ret;
 }
-
+/*!
+ * \brief Set a parameter into a graph module.
+ * \param param_group_name The parameters group name.
+ * \param param_key_name The parameter key name.
+ * \param data_in The parameter data.
+ */
+void PipelineExecutor::SetParam(std::string param_group_name, std::string param_key_name,
+                                DLTensor* data_in) {
+  // Get the module index from the parameters group name.
+  int module_index = this->GetParamsGroupPipelineMap(param_group_name);
+  // TODO(huajsj): set the parameters into runtime module.
+}
 /*!
  * \brief Initialize the pipeline executor with a list of modules to be pipelined
  * and config in JSON format.
diff --git a/src/runtime/pipeline/pipeline_executor.h b/src/runtime/pipeline/pipeline_executor.h
index 1ae52e07c260..6d4c7ba1fa4f 100644
--- a/src/runtime/pipeline/pipeline_executor.h
+++ b/src/runtime/pipeline/pipeline_executor.h
@@ -75,14 +75,27 @@ class TVM_DLL PipelineExecutor : public ModuleNode {
    * \param The global input name.
    * \return Returning the index and the input interface name of corresponding subgraph.
    */
-  Array<String> GetInputPipeplineMapping(std::string input_name);
+  Array<String> GetInputPipeplineMap(std::string input_name);
+  /*!
+   * \brief This function returns the module index for the global parameters group name.
+   * \param name The parameters group name.
+   * \return Returning a runtime module index.
+   */
+  int GetParamsGroupPipelineMap(const std::string& name);
+  /*!
+   * \brief Use the parameters group name to get the specific backend runtime, then use
+   *  the param_key_name to set the parameter data for that backend runtime.
+   * \param param_group_name The parameters group name.
+   * \param param_key_name The parameter key name.
+   * \param data_in The parameter value.
+   */
+  void SetParam(std::string param_group_name, std::string param_key_name, DLTensor* data_in);
   /*!
    * \brief Get the number of outputs.
    *
    * \return The number of outputs.
    */
   int NumOutputs() const { return num_outputs_; }
-
   /*!\brief Load the module files information.*/
   ModuleConfig& LoadModuleConfig(dmlc::JSONReader* reader) {
     reader->BeginArray();
@@ -126,6 +139,8 @@ class TVM_DLL PipelineExecutor : public ModuleNode {
   ConfigPipelineExecution pipeline_config_;
   /*!\brief The map of global input and subgraph input.*/
   InputConnectionConfig input_connection_config;
+  /*!\brief The map of global parameters groups and runtime modules.*/
+  ParamConnectionConfig param_connection_config;
   /*!\brief The module information used to create the graph runtimes.*/
   ModuleConfig mod_config_;
   /*!\brief How many outputs are in this pipeline executor.*/
@@ -139,6 +154,8 @@ class TVM_DLL PipelineExecutor : public ModuleNode {
         reader->Read(&pipeline_config_);
       } else if (key == "input_connection") {
         reader->Read(&input_connection_config);
+      } else if (key == "param_connection") {
+        reader->Read(&param_connection_config);
       } else {
         LOG(FATAL) << "do not support key " << key;
       }
diff --git a/src/runtime/pipeline/pipeline_struct.h b/src/runtime/pipeline/pipeline_struct.h
index 52422b764564..aa831070ccdb 100644
--- a/src/runtime/pipeline/pipeline_struct.h
+++ b/src/runtime/pipeline/pipeline_struct.h
@@ -251,6 +251,48 @@ struct InputConnectionConfig {
   }
 };
 
+/*!
+ * \brief A map of global module parameters groups and graph modules.
+ */
+struct ParamConnectionConfig {
+  /*!\brief Mapping from the name of a global module parameters group to the index of a runtime
+   * module.
+   */
+  std::unordered_map<std::string, int> param_connection;
+  bool Empty() { return param_connection.empty(); }
+  int operator[](const std::string key) {
+    if (param_connection.find(key) == param_connection.end()) {
+      LOG(FATAL) << "do not support key " << key;
+    }
+    return param_connection[key];
+  }
+  /*!
+   * \brief Load from JSONReader.
+   * \param reader Json reader.
+   */
+  void Load(dmlc::JSONReader* reader) {
+    reader->BeginArray();
+    while (reader->NextArrayItem()) {
+      reader->BeginObject();
+      std::string key;
+      std::string global_param_name;
+      int mod_idx = -1;
+      while (reader->NextObjectItem(&key)) {
+        if (key == "global_param_name") {
+          reader->Read(&global_param_name);
+        } else if (key == "mod_idx") {
+          reader->Read(&mod_idx);
+        } else {
+          LOG(FATAL) << "do not support key " << key;
+        }
+      }
+      ICHECK(mod_idx >= 0) << "Invalid module index value " << mod_idx;
+      ICHECK(!global_param_name.empty()) << "Invalid global parameter group name value";
+      param_connection[global_param_name] = mod_idx;
+    }
+  }
+};
+
 /*!
  * \brief The information used to initialize the graph executor module, the information
  * come from the export library function call.
diff --git a/tests/python/relay/test_pipeline_executor.py b/tests/python/relay/test_pipeline_executor.py
index 4e51f873b3fa..83cf237dbfcc 100644
--- a/tests/python/relay/test_pipeline_executor.py
+++ b/tests/python/relay/test_pipeline_executor.py
@@ -126,44 +126,76 @@ def get_manual_conf(mods, target):
     return mod_config
 
 
-def test_pipe_config_check():
-    # This function is used to trigger runtime error by applying wrong logic connection.
+def recreate_parameters(mod):
+    # Get the binding parameters from a module, then create the same parameters with different data.
+    # This function is used to test the "parameter" connection.
+    with tvm.transform.PassContext(opt_level=3):
+        lib = relay.build(mod, "llvm")
 
-    # Get three pipeline modules here.
-    (mod1, mod2, mod3), dshape = get_mannual_mod()
+    mod_customized_params = {}
+    for key, value in lib.params.items():
+        new_value = value.numpy() + np.full(value.shape, 10).astype(value.dtype)
+        mod_customized_params[key] = tvm.nd.array(new_value)
+    return mod_customized_params
 
-    # The input or output name is illegal and expects a runtime error.
-    pipe_error = pipeline_executor.PipelineConfig()
-    with pytest.raises(RuntimeError):
-        pipe_error[mod1]["output"][9]
 
-    with pytest.raises(RuntimeError):
-        pipe_error[mod1]["input"]["data_9"]
+def test_pipe_runtime_error_check():
+    # This function is used to trigger runtime error by applying wrong logic.
+    if pipeline_executor.pipeline_executor_enabled():
+        # Get three pipeline modules here.
+        (mod1, mod2, mod3), dshape = get_mannual_mod()
+
+        # The input or output name is illegal and expects a runtime error.
+        pipe_error = pipeline_executor.PipelineConfig()
+        with pytest.raises(RuntimeError):
+            pipe_error[mod1]["output"][9]
+
+        with pytest.raises(RuntimeError):
+            pipe_error[mod1]["input"]["data_9"]
+
+        # The module connection will cause a cycle in DAG and expects a runtime error.
+        with pytest.raises(RuntimeError):
+            pipe_error[mod1]["output"][0].connect(pipe_error[mod2]["input"]["data_0"])
+            pipe_error[mod2]["output"][0].connect(pipe_error[mod1]["input"]["data_0"])
+
+        # The module connection is illegal and expects a runtime error.
+
+        with pytest.raises(RuntimeError):
+            pipe_error[mod1]["output"][0].connect(pipe_error[mod1]["input"]["data_0"])
 
-    # The module connection will cause a cycle in DAG and expects runtime error.
-    with pytest.raises(RuntimeError):
-        pipe_error[mod1]["output"][0].connect(pipe_error[mod2]["input"]["data_0"])
-        pipe_error[mod2]["output"][0].connect(pipe_error[mod1]["input"]["data_0"])
+        with pytest.raises(RuntimeError):
+            pipe_error[mod1]["input"]["data_0"].connect(pipe_error[mod1]["input"]["data_0"])
 
-    # The module connection is illegal and expects runtime error.
+        with pytest.raises(RuntimeError):
+            pipe_error[mod1]["input"]["data_0"].connect(pipe_error[mod2]["input"]["data_0"])
 
-    with pytest.raises(RuntimeError):
-        pipe_error[mod1]["output"][0].connect(pipe_error[mod1]["input"]["data_0"])
+        with pytest.raises(RuntimeError):
+            pipe_error[mod1]["output"][0].connect(pipe_error["input"]["data_0"])
 
-    with pytest.raises(RuntimeError):
-        pipe_error[mod1]["input"]["data_0"].connect(pipe_error[mod1]["input"]["data_0"])
+        with pytest.raises(RuntimeError):
+            pipe_error["input"]["data_0"].connect(pipe_error[mod1]["output"][0])
 
-    with pytest.raises(RuntimeError):
-        pipe_error[mod1]["input"]["data_0"].connect(pipe_error[mod2]["input"]["data_0"])
+        with pytest.raises(RuntimeError):
+            pipe_error["output"]["0"].connect(pipe_error[mod1]["output"][0])
 
-    with pytest.raises(RuntimeError):
-        pipe_error[mod1]["output"][0].connect(pipe_error["input"]["data_0"])
+        # Create a pipeline executor to check the executor runtime errors.
+        pipe_config = pipeline_executor.PipelineConfig()
+        pipe_config[mod1].target = "llvm"
+        pipe_config[mod1].dev = tvm.cpu(0)
+        pipe_config["param_group"]["param_0"].connect(pipe_config[mod1]["param"])
+        pipe_config[mod1]["output"][0].connect(pipe_config["output"]["0"])
+        # Build and create a pipeline module.
+        with tvm.transform.PassContext(opt_level=3):
+            pipeline_mod_factory = pipeline_executor.build(pipe_config)
+        pipeline_module = pipeline_executor.PipelineModule(pipeline_mod_factory)
+        customized_parameters = recreate_parameters(mod1)
 
-    with pytest.raises(RuntimeError):
-        pipe_error["input"]["data_0"].connect(pipe_error[mod1]["output"][0])
+        # Checking the pipeline executor runtime errors.
+        with pytest.raises(RuntimeError):
+            pipeline_module.set_params("param_0", None)
 
-    with pytest.raises(RuntimeError):
-        pipe_error["output"]["0"].connect(pipe_error[mod1]["output"][0])
+        with pytest.raises(RuntimeError):
+            pipeline_module.set_params("param_1", customized_parameters)
 
 
 def test_pipeline():
@@ -180,6 +212,9 @@ def test_pipeline():
 
             pipe_config = pipeline_executor.PipelineConfig()
 
+            customized_parameters = recreate_parameters(mod2)
+            # The global parameters group named "param_0" will be connected to "mod2" as parameters.
+            pipe_config["param_group"]["param_0"].connect(pipe_config[mod2]["param"])
             # The pipeline input named "data_0" will be connected to a input named "data_0"
             # of mod1.
             pipe_config["input"]["data_a"].connect(pipe_config[mod1]["input"]["data_0"])
@@ -202,6 +237,7 @@ def test_pipeline():
             # The mod3 output[0] will be connected to pipeline output[1].
             pipe_config[mod3]["output"][0].connect(pipe_config["output"]["1"])
 
+            print(pipe_config)
             # Print configueration (print(pipe_config)), the result looks like following.
             #
             # Inputs
@@ -254,6 +290,10 @@ def test_pipeline():
             assert input_map[0] == "1" and input_map[1] == "data_1"
             input_map = pipeline_module_test.get_input_pipeline_map("data_a")
             assert input_map[0] == "0" and input_map[1] == "data_0"
+            module_index = pipeline_module_test.get_params_group_pipeline_map("param_0")
+            assert module_index == 1
+            # Use the parameters group name to set parameters.
+            pipeline_module_test.set_params("param_0", customized_parameters)
 
 
 if __name__ == "__main__":
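Usage note (not part of the patch): the following is a minimal, hypothetical sketch of how the parameters-group API added above is meant to be used, assembled from the test changes in this diff. The module, input, and parameter names ("data_0", "weight_0", "data_a", "param_0") are illustrative, and PipelineExecutor::SetParam is still a TODO on the C++ side, so set_params currently only exercises the new plumbing.

import numpy as np
import tvm
from tvm import relay
from tvm.contrib import pipeline_executor

# A tiny stand-in relay module; "data_0" and "weight_0" are illustrative names.
data = relay.var("data_0", relay.TensorType((2, 2), "float32"))
weight = relay.var("weight_0", relay.TensorType((2, 2), "float32"))
mod = tvm.IRModule.from_expr(relay.Function([data, weight], relay.multiply(data, weight)))

pipe_config = pipeline_executor.PipelineConfig()
pipe_config[mod].target = "llvm"
pipe_config[mod].dev = tvm.cpu(0)
# Declare the global parameters group "param_0" and bind it to the module's "param" interface.
pipe_config["param_group"]["param_0"].connect(pipe_config[mod]["param"])
# Wire a global input and output so the configuration forms a valid DAG.
pipe_config["input"]["data_a"].connect(pipe_config[mod]["input"]["data_0"])
pipe_config[mod]["output"][0].connect(pipe_config["output"]["0"])

with tvm.transform.PassContext(opt_level=3):
    factory = pipeline_executor.build(pipe_config)
pipeline_module = pipeline_executor.PipelineModule(factory)

# "param_0" resolves to the index of the runtime module it was bound to
# (expected to be 0 here, since there is a single module).
module_index = pipeline_module.get_params_group_pipeline_map("param_0")

# set_params takes the group name and a {parameter name: NDArray} dict; an empty dict or an
# unknown group name raises RuntimeError, as covered by test_pipe_runtime_error_check.
new_params = {"weight_0": tvm.nd.array(np.full((2, 2), 10, dtype="float32"))}
pipeline_module.set_params("param_0", new_params)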