diff --git a/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py b/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py
index 0509950b1140..070fc80dd769 100644
--- a/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py
+++ b/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py
@@ -28,8 +28,8 @@ import torch
 from apache_beam.io.filesystems import FileSystems
 from apache_beam.ml.inference.base import KeyedModelHandler
-from apache_beam.ml.inference.api import PredictionResult
-from apache_beam.ml.inference.api import RunInference
+from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.ml.inference.base import RunInference
 from apache_beam.ml.inference.pytorch_inference import PytorchModelHandler
 from apache_beam.options.pipeline_options import PipelineOptions
 from apache_beam.options.pipeline_options import SetupOptions
@@ -135,9 +135,7 @@ def run(argv=None, model_class=None, model_params=None, save_main_session=True):
           lambda file_name, data: (file_name, preprocess_image(data))))
   predictions = (
       filename_value_pair
-      |
-      'PyTorchRunInference' >> RunInference(model_handler).with_output_types(
-          Tuple[str, PredictionResult])
+      | 'PyTorchRunInference' >> RunInference(model_handler)
       | 'ProcessOutput' >> beam.ParDo(PostProcessor()))
 
   if known_args.output:
diff --git a/sdks/python/apache_beam/ml/inference/__init__.py b/sdks/python/apache_beam/ml/inference/__init__.py
index cce3acad34a4..d3b4ff354067 100644
--- a/sdks/python/apache_beam/ml/inference/__init__.py
+++ b/sdks/python/apache_beam/ml/inference/__init__.py
@@ -14,3 +14,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+from apache_beam.ml.inference.base import RunInference
diff --git a/sdks/python/apache_beam/ml/inference/api.py b/sdks/python/apache_beam/ml/inference/api.py
deleted file mode 100644
index 3d70f874733b..000000000000
--- a/sdks/python/apache_beam/ml/inference/api.py
+++ /dev/null
@@ -1,62 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements.  See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License.  You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-# mypy: ignore-errors
-
-from dataclasses import dataclass
-from typing import Tuple
-from typing import TypeVar
-from typing import Union
-
-import apache_beam as beam
-from apache_beam.ml.inference import base
-
-_K = TypeVar('_K')
-_INPUT_TYPE = TypeVar('_INPUT_TYPE')
-_OUTPUT_TYPE = TypeVar('_OUTPUT_TYPE')
-
-
-@dataclass
-class PredictionResult:
-  example: _INPUT_TYPE
-  inference: _OUTPUT_TYPE
-
-
-@beam.typehints.with_input_types(Union[_INPUT_TYPE, Tuple[_K, _INPUT_TYPE]])
-@beam.typehints.with_output_types(Union[PredictionResult, Tuple[_K, PredictionResult]])  # pylint: disable=line-too-long
-class RunInference(beam.PTransform):
-  """
-  NOTE: This API and its implementation are under development and
-  do not provide backward compatibility guarantees.
-
-  A transform that takes a PCollection of examples (or features) to be used on
-  an ML model. It will then output inferences (or predictions) for those
-  examples in a PCollection of PredictionResults, containing the input examples
-  and output inferences.
-
-  If examples are paired with keys, it will output a tuple
-  (key, PredictionResult) for each (key, example) input.
-
-  Models for supported frameworks can be loaded via a URI. Supported services
-  can also be used.
-
-  TODO(BEAM-14046): Add and link to help documentation
-  """
-  def __init__(self, model_loader: base.ModelHandler):
-    self._model_loader = model_loader
-
-  def expand(self, pcoll: beam.PCollection) -> beam.PCollection:
-    return pcoll | base.RunInference(self._model_loader)
diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index 6d4d54c911d7..ad7191cb59b9 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -14,6 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+# TODO: https://github.com/apache/beam/issues/21822
+# mypy: ignore-errors
 
 """An extensible run inference transform.
 
@@ -32,6 +34,7 @@
 import pickle
 import sys
 import time
+from dataclasses import dataclass
 from typing import Any
 from typing import Generic
 from typing import Iterable
@@ -56,9 +59,17 @@
 ModelT = TypeVar('ModelT')
 ExampleT = TypeVar('ExampleT')
 PredictionT = TypeVar('PredictionT')
+_INPUT_TYPE = TypeVar('_INPUT_TYPE')
+_OUTPUT_TYPE = TypeVar('_OUTPUT_TYPE')
 KeyT = TypeVar('KeyT')
 
 
+@dataclass
+class PredictionResult:
+  example: _INPUT_TYPE
+  inference: _OUTPUT_TYPE
+
+
 def _to_milliseconds(time_ns: int) -> int:
   return int(time_ns / _NANOSECOND_TO_MILLISECOND)
 
@@ -206,6 +217,19 @@ class RunInference(beam.PTransform[beam.PCollection[ExampleT],
   Args:
     model_handler: An implementation of ModelHandler.
     clock: A clock implementing get_current_time_in_microseconds.
+
+  A transform that takes a PCollection of examples (or features) to be used on
+  an ML model. It will then output inferences (or predictions) for those
+  examples in a PCollection of PredictionResults, containing the input examples
+  and output inferences.
+
+  If examples are paired with keys, it will output a tuple
+  (key, PredictionResult) for each (key, example) input.
+
+  Models for supported frameworks can be loaded via a URI. Supported services
+  can also be used.
+
+  TODO(BEAM-14046): Add and link to help documentation
   """
   def __init__(
       self,
diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
index 1a1afaaf1c86..d8ab31b8b708 100644
--- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py
+++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
@@ -27,8 +27,8 @@
 import torch
 
 from apache_beam.io.filesystems import FileSystems
-from apache_beam.ml.inference.api import PredictionResult
 from apache_beam.ml.inference.base import ModelHandler
+from apache_beam.ml.inference.base import PredictionResult
 
 
 class PytorchModelHandler(ModelHandler[torch.Tensor,
diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py b/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py
index 604307c3d326..ad51a4e77f7b 100644
--- a/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py
+++ b/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py
@@ -35,7 +35,7 @@
 # pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports
 try:
   import torch
-  from apache_beam.ml.inference.api import PredictionResult
+  from apache_beam.ml.inference.base import PredictionResult
   from apache_beam.ml.inference.base import RunInference
   from apache_beam.ml.inference.pytorch_inference import PytorchModelHandler
 except ImportError:
diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference.py b/sdks/python/apache_beam/ml/inference/sklearn_inference.py
index 5ca6a18b5d1d..d7e0b7395deb 100644
--- a/sdks/python/apache_beam/ml/inference/sklearn_inference.py
+++ b/sdks/python/apache_beam/ml/inference/sklearn_inference.py
@@ -28,8 +28,8 @@
 from sklearn.base import BaseEstimator
 
 from apache_beam.io.filesystems import FileSystems
-from apache_beam.ml.inference.api import PredictionResult
 from apache_beam.ml.inference.base import ModelHandler
+from apache_beam.ml.inference.base import PredictionResult
 
 try:
   import joblib
diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py b/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py
index 0d7294eb4063..1e788525dc04 100644
--- a/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py
+++ b/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py
@@ -37,8 +37,9 @@
 from sklearn.preprocessing import StandardScaler
 
 import apache_beam as beam
-from apache_beam.ml.inference import api
-from apache_beam.ml.inference import base
+from apache_beam.ml.inference.base import KeyedModelHandler
+from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.ml.inference.base import RunInference
 from apache_beam.ml.inference.sklearn_inference import ModelFileType
 from apache_beam.ml.inference.sklearn_inference import SklearnModelHandler
 from apache_beam.testing.test_pipeline import TestPipeline
@@ -134,9 +135,9 @@ def test_predict_output(self):
         numpy.array([1, 2, 3]), numpy.array([4, 5, 6]), numpy.array([7, 8, 9])
     ]
     expected_predictions = [
-        api.PredictionResult(numpy.array([1, 2, 3]), 6),
-        api.PredictionResult(numpy.array([4, 5, 6]), 15),
-        api.PredictionResult(numpy.array([7, 8, 9]), 24)
+        PredictionResult(numpy.array([1, 2, 3]), 6),
+        PredictionResult(numpy.array([4, 5, 6]), 15),
+        PredictionResult(numpy.array([7, 8, 9]), 24)
     ]
     inferences = inference_runner.run_inference(batched_examples, fake_model)
     for actual, expected in zip(inferences, expected_predictions):
@@ -181,11 +182,11 @@ def test_pipeline_pickled(self):
 
       pcoll = pipeline | 'start' >> beam.Create(examples)
       #TODO(BEAM-14305) Test against the public API.
-      actual = pcoll | base.RunInference(
+      actual = pcoll | RunInference(
           SklearnModelHandler(model_uri=temp_file_name))
       expected = [
-          api.PredictionResult(numpy.array([0, 0]), 0),
-          api.PredictionResult(numpy.array([1, 1]), 1)
+          PredictionResult(numpy.array([0, 0]), 0),
+          PredictionResult(numpy.array([1, 1]), 1)
       ]
       assert_that(
           actual, equal_to(expected, equals_fn=_compare_prediction_result))
@@ -201,12 +202,12 @@ def test_pipeline_joblib(self):
 
       pcoll = pipeline | 'start' >> beam.Create(examples)
       #TODO(BEAM-14305) Test against the public API.
-      actual = pcoll | base.RunInference(
+      actual = pcoll | RunInference(
           SklearnModelHandler(
               model_uri=temp_file_name, model_file_type=ModelFileType.JOBLIB))
       expected = [
-          api.PredictionResult(numpy.array([0, 0]), 0),
-          api.PredictionResult(numpy.array([1, 1]), 1)
+          PredictionResult(numpy.array([0, 0]), 0),
+          PredictionResult(numpy.array([1, 1]), 1)
       ]
       assert_that(
           actual, equal_to(expected, equals_fn=_compare_prediction_result))
@@ -217,7 +218,7 @@ def test_bad_file_raises(self):
       examples = [numpy.array([0, 0])]
       pcoll = pipeline | 'start' >> beam.Create(examples)
       # TODO(BEAM-14305) Test against the public API.
-      _ = pcoll | base.RunInference(
+      _ = pcoll | RunInference(
           SklearnModelHandler(model_uri='/var/bad_file_name'))
       pipeline.run()
 
@@ -239,15 +240,15 @@ def test_pipeline_pandas(self):
       dataframe = pandas_dataframe()
       splits = [dataframe.loc[[i]] for i in dataframe.index]
       pcoll = pipeline | 'start' >> beam.Create(splits)
-      actual = pcoll | api.RunInference(
+      actual = pcoll | RunInference(
          SklearnModelHandler(model_uri=temp_file_name))
 
       expected = [
-          api.PredictionResult(splits[0], 5),
-          api.PredictionResult(splits[1], 8),
-          api.PredictionResult(splits[2], 1),
-          api.PredictionResult(splits[3], 1),
-          api.PredictionResult(splits[4], 2),
+          PredictionResult(splits[0], 5),
+          PredictionResult(splits[1], 8),
+          PredictionResult(splits[2], 1),
+          PredictionResult(splits[3], 1),
+          PredictionResult(splits[4], 2),
       ]
       assert_that(
           actual, equal_to(expected, equals_fn=_compare_dataframe_predictions))
@@ -264,14 +265,14 @@ def test_pipeline_pandas_with_keys(self):
       keyed_rows = [(key, value) for key, value in zip(keys, splits)]
 
       pcoll = pipeline | 'start' >> beam.Create(keyed_rows)
-      actual = pcoll | api.RunInference(
-          base.KeyedModelHandler(SklearnModelHandler(model_uri=temp_file_name)))
+      actual = pcoll | RunInference(
+          KeyedModelHandler(SklearnModelHandler(model_uri=temp_file_name)))
       expected = [
-          ('0', api.PredictionResult(splits[0], 5)),
-          ('1', api.PredictionResult(splits[1], 8)),
-          ('2', api.PredictionResult(splits[2], 1)),
-          ('3', api.PredictionResult(splits[3], 1)),
-          ('4', api.PredictionResult(splits[4], 2)),
+          ('0', PredictionResult(splits[0], 5)),
+          ('1', PredictionResult(splits[1], 8)),
+          ('2', PredictionResult(splits[2], 1)),
+          ('3', PredictionResult(splits[3], 1)),
+          ('4', PredictionResult(splits[4], 2)),
       ]
       assert_that(
           actual, equal_to(expected, equals_fn=_compare_dataframe_predictions))
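
For context, the snippet below is a minimal usage sketch of the API after this change, using only names the patch itself establishes (RunInference, PredictionResult, and KeyedModelHandler in apache_beam.ml.inference.base, plus SklearnModelHandler). The model path and example values are hypothetical placeholders; this is illustrative and not part of the patch.

import apache_beam as beam
import numpy

from apache_beam.ml.inference.base import KeyedModelHandler
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.sklearn_inference import SklearnModelHandler


def format_result(keyed_result):
  # For keyed inputs, RunInference emits (key, PredictionResult); each
  # PredictionResult carries the input example and the model's inference.
  key, result = keyed_result
  return '%s: %s' % (key, result.inference)


with beam.Pipeline() as pipeline:
  # Hypothetical keyed examples; the file at model_uri is assumed to be a
  # pickled sklearn estimator.
  keyed_examples = [('a', numpy.array([0, 0])), ('b', numpy.array([1, 1]))]
  _ = (
      pipeline
      | 'Create' >> beam.Create(keyed_examples)
      # KeyedModelHandler forwards the values to the wrapped handler and
      # re-attaches the keys to the resulting PredictionResults.
      | 'RunInference' >> RunInference(
          KeyedModelHandler(SklearnModelHandler(model_uri='/tmp/model.pkl')))
      | 'Format' >> beam.Map(format_result)
      | 'Print' >> beam.Map(print))

For unkeyed elements, drop the KeyedModelHandler wrapper and pass SklearnModelHandler directly; the outputs are then plain PredictionResult values, constructed positionally as PredictionResult(example, inference) in the updated tests.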