4 changes: 4 additions & 0 deletions merlin/systems/dag/ops/compat.py
@@ -31,3 +31,7 @@
import cuml.ensemble as cuml_ensemble
except ImportError:
cuml_ensemble = None
try:
import triton_python_backend_utils as pb_utils
except ImportError:
pb_utils = None
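The new guarded import follows the same compat pattern as the other optional dependencies in this module: `pb_utils` only resolves inside a Triton Python backend process. A minimal sketch (not part of the diff) of how downstream code can guard on it:

```python
from merlin.systems.dag.ops.compat import pb_utils


def require_pb_utils():
    # pb_utils is only importable inside Triton's Python backend runtime;
    # fail with a clear message when running outside of Triton.
    if pb_utils is None:
        raise RuntimeError(
            "triton_python_backend_utils is unavailable; "
            "this code path must run inside the Triton Python backend."
        )
    return pb_utils
```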
187 changes: 181 additions & 6 deletions merlin/systems/dag/ops/fil.py
@@ -1,3 +1,18 @@
#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import pathlib
import pickle
@@ -9,8 +24,142 @@

from merlin.dag import ColumnSelector # noqa
from merlin.schema import ColumnSchema, Schema # noqa
from merlin.systems.dag.ops.compat import cuml_ensemble, lightgbm, sklearn_ensemble, xgboost
from merlin.systems.dag.ops.operator import InferenceOperator
from merlin.systems.dag.ops.compat import (
cuml_ensemble,
lightgbm,
pb_utils,
sklearn_ensemble,
xgboost,
)
from merlin.systems.dag.ops.operator import (
InferenceDataFrame,
InferenceOperator,
PipelineableInferenceOperator,
)


class PredictForest(PipelineableInferenceOperator):
"""Operator for running inference on Forest models.

This works for gradient-boosted decision trees (GBDTs) and random forests (RFs).
While the two families of algorithms differ in how they train their models,
both produce a decision forest as output.

Uses the Forest Inference Library (FIL) backend for inference.
"""

def __init__(self, model, input_schema, *, backend="python", **fil_params):
"""Instantiate a FIL inference operator.

Parameters
----------
model : forest model instance
A trained forest model. XGBoost, LightGBM, and scikit-learn models are supported.
input_schema : merlin.schema.Schema
The schema representing the input columns expected by the model.
backend : str
The Triton backend to use when running this operator.
**fil_params
The parameters to pass to the FIL operator.
"""
if model is not None:
self.fil_op = FIL(model, **fil_params)
self.backend = backend
self.input_schema = input_schema
self._fil_model_name = None

def compute_output_schema(
self,
input_schema: Schema,
col_selector: ColumnSelector,
prev_output_schema: Schema = None,
) -> Schema:
"""Return the output schema representing the columns this operator returns."""
return self.fil_op.compute_output_schema(
input_schema, col_selector, prev_output_schema=prev_output_schema
)

def compute_input_schema(
self,
root_schema: Schema,
parents_schema: Schema,
deps_schema: Schema,
selector: ColumnSelector,
) -> Schema:
"""Return the input schema representing the input columns this operator expects to use."""
return self.input_schema

def export(self, path, input_schema, output_schema, params=None, node_id=None, version=1):
"""Export the class and related files to the path specified."""
fil_model_config = self.fil_op.export(
path,
input_schema,
output_schema,
params=params,
node_id=node_id,
version=version,
)
params = params or {}
params = {**params, "fil_model_name": fil_model_config.name}
return super().export(
path,
input_schema,
output_schema,
params=params,
node_id=node_id,
version=version,
backend=self.backend,
)

@classmethod
def from_config(cls, config: dict) -> "PredictForest":
"""Instantiate the class from a dictionary representation.

Expected structure:
{
"input_dict": str # JSON dict with input names and schemas
"params": str # JSON dict with params saved at export
}

"""
column_schemas = [
ColumnSchema(name, **schema_properties)
for name, schema_properties in json.loads(config["input_dict"]).items()
]
input_schema = Schema(column_schemas)
cls_instance = cls(None, input_schema)
params = json.loads(config["params"])
cls_instance.set_fil_model_name(params["fil_model_name"])
return cls_instance

@property
def fil_model_name(self):
return self._fil_model_name

def set_fil_model_name(self, fil_model_name):
self._fil_model_name = fil_model_name

def transform(self, df: InferenceDataFrame) -> InferenceDataFrame:
"""Transform the dataframe by applying this FIL operator to the set of input columns.

Parameters
----------
df : InferenceDataFrame
The input tensors for this operator, wrapped in an InferenceDataFrame

Returns
-------
InferenceDataFrame
Returns a transformed dataframe for this operator"""
input0 = np.array([x.ravel() for x in df.tensors.values()]).astype(np.float32).T
inference_request = pb_utils.InferenceRequest(
model_name=self.fil_model_name,
requested_output_names=["output__0"],
inputs=[pb_utils.Tensor("input__0", input0)],
)
inference_response = inference_request.exec()
output0 = pb_utils.get_output_tensor_by_name(inference_response, "output__0")
return InferenceDataFrame({"output__0": output0})


class FIL(InferenceOperator):
@@ -32,6 +181,7 @@ def __init__(
threads_per_tree=1,
blocks_per_sm=0,
transfer_threshold=0,
instance_group="AUTO",
):
"""Instantiate a FIL inference operator.

@@ -88,6 +238,9 @@ def __init__(
to the GPU for processing) will provide optimal latency and throughput, but
for low-latency deployments with the use_experimental_optimizations flag set
to true, higher values may be desirable.
instance_group : str
One of "AUTO", "GPU", "CPU". Default value is "AUTO". Specifies whether
inference will take place on the GPU or CPU.
"""
self.max_batch_size = max_batch_size
self.parameters = dict(
@@ -98,6 +251,7 @@
blocks_per_sm=blocks_per_sm,
storage_type=storage_type,
threshold=threshold,
instance_group=instance_group,
)
self.fil_model = get_fil_model(model)
super().__init__()
@@ -121,7 +275,15 @@ def compute_output_schema(
"""Returns output schema for FIL op"""
return Schema([ColumnSchema("output__0", dtype=np.float32)])

def export(self, path, input_schema, output_schema, node_id=None, version=1):
def export(
self,
path,
input_schema,
output_schema,
params: dict = None,
node_id=None,
version=1,
):
"""Export the model to the supplied path. Returns the config"""
node_name = f"{node_id}_{self.export_name}" if node_id is not None else self.export_name
node_export_path = pathlib.Path(path) / node_name
@@ -391,6 +553,7 @@ def fil_config(
blocks_per_sm=0,
threads_per_tree=1,
transfer_threshold=0,
instance_group="AUTO",
) -> model_config.ModelConfig:
"""Construct and return a FIL ModelConfig protobuf object.

@@ -453,6 +616,9 @@ def fil_config(
to the GPU for processing) will provide optimal latency and throughput, but
for low-latency deployments with the use_experimental_optimizations flag set
to true, higher values may be desirable.
instance_group : str
One of "AUTO", "GPU", "CPU". Default value is "AUTO". Specifies whether
inference will take place on the GPU or CPU.

Returns
model_config.ModelConfig
@@ -485,6 +651,17 @@ def fil_config(
"transfer_threshold": f"{transfer_threshold:d}",
}

supported_instance_groups = {"auto", "cpu", "gpu"}
instance_group = instance_group.lower() if isinstance(instance_group, str) else instance_group
if instance_group == "auto":
instance_group_kind = model_config.ModelInstanceGroup.Kind.KIND_AUTO
elif instance_group == "cpu":
instance_group_kind = model_config.ModelInstanceGroup.Kind.KIND_CPU
elif instance_group == "gpu":
instance_group_kind = model_config.ModelInstanceGroup.Kind.KIND_GPU
else:
raise ValueError(f"instance_group must be one of {supported_instance_groups}")

config = model_config.ModelConfig(
name=name,
backend="fil",
Expand All @@ -501,9 +678,7 @@ def fil_config(
name="output__0", data_type=model_config.TYPE_FP32, dims=[output_dim]
)
],
instance_group=[
model_config.ModelInstanceGroup(kind=model_config.ModelInstanceGroup.Kind.KIND_AUTO)
],
instance_group=[model_config.ModelInstanceGroup(kind=instance_group_kind)],
)

for parameter_key, parameter_value in parameters.items():
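For context, here is a hedged sketch of how the new `PredictForest` operator and the `instance_group` option added above could be wired into an ensemble; the toy training data, schema, and export path are assumptions for illustration, not part of this PR:

```python
import numpy as np
import xgboost as xgb

from merlin.schema import ColumnSchema, Schema
from merlin.systems.dag.ensemble import Ensemble
from merlin.systems.dag.ops.fil import PredictForest

# Assumed toy setup: three float features and a binary target.
feature_names = ["f0", "f1", "f2"]
X = np.random.rand(256, 3).astype(np.float32)
y = np.random.randint(2, size=256)
model = xgb.XGBClassifier(n_estimators=5, max_depth=3).fit(X, y)

input_schema = Schema([ColumnSchema(name, dtype=np.float32) for name in feature_names])

# instance_group is the new option introduced in this PR ("AUTO", "GPU", or "CPU").
forest = input_schema.column_names >> PredictForest(
    model, input_schema, instance_group="GPU"
)

ensemble = Ensemble(forest, input_schema)
ensemble.export("/tmp/forest_ensemble")  # export path is illustrative
```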
3 changes: 2 additions & 1 deletion merlin/systems/dag/ops/operator.py
@@ -166,6 +166,7 @@ def export(
params: dict = None,
node_id: int = None,
version: int = 1,
backend: str = "python",
):
"""
Export the class object as a config and all related files to the user-defined path.
@@ -200,7 +201,7 @@
node_export_path = pathlib.Path(path) / node_name
node_export_path.mkdir(parents=True, exist_ok=True)

config = model_config.ModelConfig(name=node_name, backend="python", platform="op_runner")
config = model_config.ModelConfig(name=node_name, backend=backend, platform="op_runner")

config.parameters["operator_names"].string_value = json.dumps([node_name])

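The only behavioral change here is that the `backend` field of the generated op-runner config is now parameterized instead of hard-coded to `"python"`. A rough sketch of the config this produces for the default case; the node name `0_predictforest` is an assumption for the example:

```python
import json

import tritonclient.grpc.model_config_pb2 as model_config

# What the op-runner export writes with the default backend.
config = model_config.ModelConfig(
    name="0_predictforest", backend="python", platform="op_runner"
)
config.parameters["operator_names"].string_value = json.dumps(["0_predictforest"])
print(config)
```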
52 changes: 45 additions & 7 deletions merlin/systems/triton/oprunner_model.py
@@ -42,11 +42,48 @@


class TritonPythonModel:
"""Model for Triton Python Backend.

Every Python model must have "TritonPythonModel" as the class name
"""

def initialize(self, args):
"""Called only once when the model is being loaded. Allowing
the model to initialize any state associated with this model.

Parameters
----------
args : dict
Both keys and values are strings. The dictionary keys and values are:
* model_config: A JSON string containing the model configuration
* model_instance_kind: A string containing model instance kind
* model_instance_device_id: A string containing model instance device ID
* model_repository: Model repository path
* model_version: Model version
* model_name: Model name
"""
self.model_config = json.loads(args["model_config"])
self.runner = OperatorRunner(self.model_config)

def execute(self, requests: List[InferenceRequest]) -> List[InferenceResponse]:
"""Receives a list of pb_utils.InferenceRequest as the only argument. This
function is called when an inference is requested for this model. Depending on the
batching configuration (e.g. Dynamic Batching) used, `requests` may contain
multiple requests. Every Python model must create one pb_utils.InferenceResponse
for every pb_utils.InferenceRequest in `requests`. If there is an error, you can
set the error argument when creating a pb_utils.InferenceResponse.

Parameters
----------
requests : list
A list of pb_utils.InferenceRequest

Returns
-------
list
A list of pb_utils.InferenceResponse. The length of this list must
be the same as `requests`
"""
params = self.model_config["parameters"]
op_names = json.loads(params["operator_names"]["string_value"])
first_operator_name = op_names[0]
@@ -67,14 +104,15 @@ def execute(self, requests: List[InferenceRequest]) -> List[InferenceResponse]:

raw_tensor_tuples = self.runner.execute(inf_df)

tensors = {
name: (data.get() if hasattr(data, "get") else data)
for name, data in raw_tensor_tuples
}

result = [Tensor(name, data) for name, data in tensors.items()]
output_tensors = []
for name, data in raw_tensor_tuples:
    # Pass through values that are already pb_utils Tensors; otherwise copy
    # device arrays to host (via .get()) and wrap them in a Tensor.
    if isinstance(data, Tensor):
        output_tensors.append(data)
    else:
        data = data.get() if hasattr(data, "get") else data
        tensor = Tensor(name, data)
        output_tensors.append(tensor)

responses.append(InferenceResponse(result))
responses.append(InferenceResponse(output_tensors))

except Exception: # pylint: disable=broad-except
exc_type, exc_value, exc_traceback = sys.exc_info()
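To round this out, a hedged sketch of a client request against a deployed FIL model that follows the `input__0`/`output__0` naming used by `fil_config()`; the server URL and the model name `0_fil` are assumptions for this example:

```python
import numpy as np
import tritonclient.grpc as grpcclient

# Server URL and model name are assumptions for this sketch.
client = grpcclient.InferenceServerClient(url="localhost:8001")

features = np.random.rand(8, 3).astype(np.float32)

# fil_config() declares a single FP32 input "input__0" and output "output__0".
infer_input = grpcclient.InferInput("input__0", list(features.shape), "FP32")
infer_input.set_data_from_numpy(features)
requested = grpcclient.InferRequestedOutput("output__0")

response = client.infer("0_fil", inputs=[infer_input], outputs=[requested])
print(response.as_numpy("output__0"))
```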