From 682945959f8c1589ff6d5f93fed3e770db71862d Mon Sep 17 00:00:00 2001 From: Anand Inguva Date: Fri, 10 Jun 2022 14:53:48 -0400 Subject: [PATCH 01/10] refactor code from api to base --- .../apache_beam/ml/inference/__init__.py | 1 + sdks/python/apache_beam/ml/inference/base.py | 28 +++++++++++++ .../ml/inference/pytorch_inference.py | 2 +- .../ml/inference/pytorch_inference_it_test.py | 2 +- .../ml/inference/pytorch_inference_test.py | 2 +- .../ml/inference/sklearn_inference.py | 2 +- .../ml/inference/sklearn_inference_test.py | 39 +++++++++---------- 7 files changed, 52 insertions(+), 24 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/__init__.py b/sdks/python/apache_beam/ml/inference/__init__.py index cce3acad34a4..d3b4ff354067 100644 --- a/sdks/python/apache_beam/ml/inference/__init__.py +++ b/sdks/python/apache_beam/ml/inference/__init__.py @@ -14,3 +14,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from apache_beam.ml.inference.base import RunInference diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index 534512a44f3f..d7b95e889a4e 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +# mypy: ignore-errors """An extensible run inference transform. @@ -32,12 +33,15 @@ import pickle import sys import time +from dataclasses import dataclass from typing import Any from typing import Generic from typing import Iterable from typing import List from typing import Mapping +from typing import Tuple from typing import TypeVar +from typing import Union import apache_beam as beam from apache_beam.utils import shared @@ -54,6 +58,15 @@ ModelT = TypeVar('ModelT') ExampleT = TypeVar('ExampleT') PredictionT = TypeVar('PredictionT') +_K = TypeVar('_K') +_INPUT_TYPE = TypeVar('_INPUT_TYPE') +_OUTPUT_TYPE = TypeVar('_OUTPUT_TYPE') + + +@dataclass +class PredictionResult: + example: _INPUT_TYPE + inference: _OUTPUT_TYPE def _to_milliseconds(time_ns: int) -> int: @@ -93,12 +106,27 @@ def batch_elements_kwargs(self) -> Mapping[str, Any]: return {} +@beam.typehints.with_input_types(Union[_INPUT_TYPE, Tuple[_K, _INPUT_TYPE]]) +@beam.typehints.with_output_types(Union[PredictionResult, Tuple[_K, PredictionResult]]) # pylint: disable=line-too-long class RunInference(beam.PTransform[beam.PCollection[ExampleT], beam.PCollection[PredictionT]]): """An extensible transform for running inferences. Args: model_handler: An implementation of ModelHandler. clock: A clock implementing get_current_time_in_microseconds. + + A transform that takes a PCollection of examples (or features) to be used on + an ML model. It will then output inferences (or predictions) for those + examples in a PCollection of PredictionResults, containing the input examples + and output inferences. + + If examples are paired with keys, it will output a tuple + (key, PredictionResult) for each (key, example) input. + + Models for supported frameworks can be loaded via a URI. Supported services + can also be used. + + TODO(BEAM-14046): Add and link to help documentation """ def __init__( self, diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py index 3a4fb2926f81..d5fd642cfce0 100644 --- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py +++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py @@ -27,8 +27,8 @@ import torch from apache_beam.io.filesystems import FileSystems -from apache_beam.ml.inference.api import PredictionResult from apache_beam.ml.inference.base import ModelHandler +from apache_beam.ml.inference.base import PredictionResult class PytorchModelHandler(ModelHandler[torch.Tensor, diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference_it_test.py b/sdks/python/apache_beam/ml/inference/pytorch_inference_it_test.py index 066d667b7863..bec4d0c6cdda 100644 --- a/sdks/python/apache_beam/ml/inference/pytorch_inference_it_test.py +++ b/sdks/python/apache_beam/ml/inference/pytorch_inference_it_test.py @@ -66,7 +66,7 @@ class PyTorchInference(unittest.TestCase): @pytest.mark.uses_pytorch @pytest.mark.it_postcommit def test_torch_run_inference_imagenet_mobilenetv2(self): - test_pipeline = TestPipeline(is_integration_test=True) + test_pipeline = TestPipeline(is_integration_test=False) # text files containing absolute path to the imagenet validation data on GCS file_of_image_names = 'gs://apache-beam-ml/testing/inputs/it_mobilenetv2_imagenet_validation_inputs.txt' # disable: line-too-long output_file_dir = 'gs://apache-beam-ml/testing/predictions' diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py b/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py index 7f563d7cf4c4..9d149063d07f 100644 --- a/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py +++ b/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py @@ -35,7 +35,7 @@ # pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports try: import torch - from apache_beam.ml.inference.api import PredictionResult + from apache_beam.ml.inference.base import PredictionResult from apache_beam.ml.inference.base import RunInference from apache_beam.ml.inference.pytorch_inference import PytorchModelHandler except ImportError: diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference.py b/sdks/python/apache_beam/ml/inference/sklearn_inference.py index 3c8eddfd7d3a..c1e0f3a22f7c 100644 --- a/sdks/python/apache_beam/ml/inference/sklearn_inference.py +++ b/sdks/python/apache_beam/ml/inference/sklearn_inference.py @@ -28,7 +28,7 @@ from sklearn.base import BaseEstimator from apache_beam.io.filesystems import FileSystems -from apache_beam.ml.inference.api import PredictionResult +from apache_beam.ml.inference.base import PredictionResult from apache_beam.ml.inference.base import ModelHandler try: diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py b/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py index 91eb86e2de4b..43c7dc4e3220 100644 --- a/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py +++ b/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py @@ -37,7 +37,6 @@ from sklearn.preprocessing import StandardScaler import apache_beam as beam -from apache_beam.ml.inference import api from apache_beam.ml.inference import base from apache_beam.ml.inference.sklearn_inference import ModelFileType from apache_beam.ml.inference.sklearn_inference import SklearnModelHandler @@ -134,9 +133,9 @@ def test_predict_output(self): numpy.array([1, 2, 3]), numpy.array([4, 5, 6]), numpy.array([7, 8, 9]) ] expected_predictions = [ - api.PredictionResult(numpy.array([1, 2, 3]), 6), - api.PredictionResult(numpy.array([4, 5, 6]), 15), - api.PredictionResult(numpy.array([7, 8, 9]), 24) + base.PredictionResult(numpy.array([1, 2, 3]), 6), + base.PredictionResult(numpy.array([4, 5, 6]), 15), + base.PredictionResult(numpy.array([7, 8, 9]), 24) ] inferences = inference_runner.run_inference(batched_examples, fake_model) for actual, expected in zip(inferences, expected_predictions): @@ -184,8 +183,8 @@ def test_pipeline_pickled(self): actual = pcoll | base.RunInference( SklearnModelHandler(model_uri=temp_file_name)) expected = [ - api.PredictionResult(numpy.array([0, 0]), 0), - api.PredictionResult(numpy.array([1, 1]), 1) + base.PredictionResult(numpy.array([0, 0]), 0), + base.PredictionResult(numpy.array([1, 1]), 1) ] assert_that( actual, equal_to(expected, equals_fn=_compare_prediction_result)) @@ -205,8 +204,8 @@ def test_pipeline_joblib(self): SklearnModelHandler( model_uri=temp_file_name, model_file_type=ModelFileType.JOBLIB)) expected = [ - api.PredictionResult(numpy.array([0, 0]), 0), - api.PredictionResult(numpy.array([1, 1]), 1) + base.PredictionResult(numpy.array([0, 0]), 0), + base.PredictionResult(numpy.array([1, 1]), 1) ] assert_that( actual, equal_to(expected, equals_fn=_compare_prediction_result)) @@ -239,15 +238,15 @@ def test_pipeline_pandas(self): dataframe = pandas_dataframe() splits = [dataframe.loc[[i]] for i in dataframe.index] pcoll = pipeline | 'start' >> beam.Create(splits) - actual = pcoll | api.RunInference( + actual = pcoll | base.RunInference( SklearnModelHandler(model_uri=temp_file_name)) expected = [ - api.PredictionResult(splits[0], 5), - api.PredictionResult(splits[1], 8), - api.PredictionResult(splits[2], 1), - api.PredictionResult(splits[3], 1), - api.PredictionResult(splits[4], 2), + base.PredictionResult(splits[0], 5), + base.PredictionResult(splits[1], 8), + base.PredictionResult(splits[2], 1), + base.PredictionResult(splits[3], 1), + base.PredictionResult(splits[4], 2), ] assert_that( actual, equal_to(expected, equals_fn=_compare_dataframe_predictions)) @@ -264,14 +263,14 @@ def test_pipeline_pandas_with_keys(self): keyed_rows = [(key, value) for key, value in zip(keys, splits)] pcoll = pipeline | 'start' >> beam.Create(keyed_rows) - actual = pcoll | api.RunInference( + actual = pcoll | base.RunInference( SklearnModelHandler(model_uri=temp_file_name)) expected = [ - ('0', api.PredictionResult(splits[0], 5)), - ('1', api.PredictionResult(splits[1], 8)), - ('2', api.PredictionResult(splits[2], 1)), - ('3', api.PredictionResult(splits[3], 1)), - ('4', api.PredictionResult(splits[4], 2)), + ('0', base.PredictionResult(splits[0], 5)), + ('1', base.PredictionResult(splits[1], 8)), + ('2', base.PredictionResult(splits[2], 1)), + ('3', base.PredictionResult(splits[3], 1)), + ('4', base.PredictionResult(splits[4], 2)), ] assert_that( actual, equal_to(expected, equals_fn=_compare_dataframe_predictions)) From 8235a5b3d0ce6847925deae9128fe6a4e6ace9db Mon Sep 17 00:00:00 2001 From: Anand Inguva Date: Fri, 10 Jun 2022 14:56:44 -0400 Subject: [PATCH 02/10] delete api.py --- sdks/python/apache_beam/ml/inference/api.py | 62 ------------------- .../ml/inference/pytorch_inference_it_test.py | 2 +- 2 files changed, 1 insertion(+), 63 deletions(-) delete mode 100644 sdks/python/apache_beam/ml/inference/api.py diff --git a/sdks/python/apache_beam/ml/inference/api.py b/sdks/python/apache_beam/ml/inference/api.py deleted file mode 100644 index 3d70f874733b..000000000000 --- a/sdks/python/apache_beam/ml/inference/api.py +++ /dev/null @@ -1,62 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# mypy: ignore-errors - -from dataclasses import dataclass -from typing import Tuple -from typing import TypeVar -from typing import Union - -import apache_beam as beam -from apache_beam.ml.inference import base - -_K = TypeVar('_K') -_INPUT_TYPE = TypeVar('_INPUT_TYPE') -_OUTPUT_TYPE = TypeVar('_OUTPUT_TYPE') - - -@dataclass -class PredictionResult: - example: _INPUT_TYPE - inference: _OUTPUT_TYPE - - -@beam.typehints.with_input_types(Union[_INPUT_TYPE, Tuple[_K, _INPUT_TYPE]]) -@beam.typehints.with_output_types(Union[PredictionResult, Tuple[_K, PredictionResult]]) # pylint: disable=line-too-long -class RunInference(beam.PTransform): - """ - NOTE: This API and its implementation are under development and - do not provide backward compatibility guarantees. - - A transform that takes a PCollection of examples (or features) to be used on - an ML model. It will then output inferences (or predictions) for those - examples in a PCollection of PredictionResults, containing the input examples - and output inferences. - - If examples are paired with keys, it will output a tuple - (key, PredictionResult) for each (key, example) input. - - Models for supported frameworks can be loaded via a URI. Supported services - can also be used. - - TODO(BEAM-14046): Add and link to help documentation - """ - def __init__(self, model_loader: base.ModelHandler): - self._model_loader = model_loader - - def expand(self, pcoll: beam.PCollection) -> beam.PCollection: - return pcoll | base.RunInference(self._model_loader) diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference_it_test.py b/sdks/python/apache_beam/ml/inference/pytorch_inference_it_test.py index bec4d0c6cdda..066d667b7863 100644 --- a/sdks/python/apache_beam/ml/inference/pytorch_inference_it_test.py +++ b/sdks/python/apache_beam/ml/inference/pytorch_inference_it_test.py @@ -66,7 +66,7 @@ class PyTorchInference(unittest.TestCase): @pytest.mark.uses_pytorch @pytest.mark.it_postcommit def test_torch_run_inference_imagenet_mobilenetv2(self): - test_pipeline = TestPipeline(is_integration_test=False) + test_pipeline = TestPipeline(is_integration_test=True) # text files containing absolute path to the imagenet validation data on GCS file_of_image_names = 'gs://apache-beam-ml/testing/inputs/it_mobilenetv2_imagenet_validation_inputs.txt' # disable: line-too-long output_file_dir = 'gs://apache-beam-ml/testing/predictions' From 5ecbf647b4c25ad163d3696b7ffb8771d63f784f Mon Sep 17 00:00:00 2001 From: Anand Inguva Date: Fri, 10 Jun 2022 15:41:18 -0400 Subject: [PATCH 03/10] modify imports --- .../ml/inference/sklearn_inference_test.py | 47 ++++++++++--------- 1 file changed, 24 insertions(+), 23 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py b/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py index 43c7dc4e3220..e974fd1163fa 100644 --- a/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py +++ b/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py @@ -37,7 +37,8 @@ from sklearn.preprocessing import StandardScaler import apache_beam as beam -from apache_beam.ml.inference import base +from apache_beam.ml.inference.base import PredictionResult +from apache_beam.ml.inference.base import RunInference from apache_beam.ml.inference.sklearn_inference import ModelFileType from apache_beam.ml.inference.sklearn_inference import SklearnModelHandler from apache_beam.testing.test_pipeline import TestPipeline @@ -133,9 +134,9 @@ def test_predict_output(self): numpy.array([1, 2, 3]), numpy.array([4, 5, 6]), numpy.array([7, 8, 9]) ] expected_predictions = [ - base.PredictionResult(numpy.array([1, 2, 3]), 6), - base.PredictionResult(numpy.array([4, 5, 6]), 15), - base.PredictionResult(numpy.array([7, 8, 9]), 24) + PredictionResult(numpy.array([1, 2, 3]), 6), + PredictionResult(numpy.array([4, 5, 6]), 15), + PredictionResult(numpy.array([7, 8, 9]), 24) ] inferences = inference_runner.run_inference(batched_examples, fake_model) for actual, expected in zip(inferences, expected_predictions): @@ -180,11 +181,11 @@ def test_pipeline_pickled(self): pcoll = pipeline | 'start' >> beam.Create(examples) #TODO(BEAM-14305) Test against the public API. - actual = pcoll | base.RunInference( + actual = pcoll | RunInference( SklearnModelHandler(model_uri=temp_file_name)) expected = [ - base.PredictionResult(numpy.array([0, 0]), 0), - base.PredictionResult(numpy.array([1, 1]), 1) + PredictionResult(numpy.array([0, 0]), 0), + PredictionResult(numpy.array([1, 1]), 1) ] assert_that( actual, equal_to(expected, equals_fn=_compare_prediction_result)) @@ -200,12 +201,12 @@ def test_pipeline_joblib(self): pcoll = pipeline | 'start' >> beam.Create(examples) #TODO(BEAM-14305) Test against the public API. - actual = pcoll | base.RunInference( + actual = pcoll | RunInference( SklearnModelHandler( model_uri=temp_file_name, model_file_type=ModelFileType.JOBLIB)) expected = [ - base.PredictionResult(numpy.array([0, 0]), 0), - base.PredictionResult(numpy.array([1, 1]), 1) + PredictionResult(numpy.array([0, 0]), 0), + PredictionResult(numpy.array([1, 1]), 1) ] assert_that( actual, equal_to(expected, equals_fn=_compare_prediction_result)) @@ -216,7 +217,7 @@ def test_bad_file_raises(self): examples = [numpy.array([0, 0])] pcoll = pipeline | 'start' >> beam.Create(examples) # TODO(BEAM-14305) Test against the public API. - _ = pcoll | base.RunInference( + _ = pcoll | RunInference( SklearnModelHandler(model_uri='/var/bad_file_name')) pipeline.run() @@ -238,15 +239,15 @@ def test_pipeline_pandas(self): dataframe = pandas_dataframe() splits = [dataframe.loc[[i]] for i in dataframe.index] pcoll = pipeline | 'start' >> beam.Create(splits) - actual = pcoll | base.RunInference( + actual = pcoll | RunInference( SklearnModelHandler(model_uri=temp_file_name)) expected = [ - base.PredictionResult(splits[0], 5), - base.PredictionResult(splits[1], 8), - base.PredictionResult(splits[2], 1), - base.PredictionResult(splits[3], 1), - base.PredictionResult(splits[4], 2), + PredictionResult(splits[0], 5), + PredictionResult(splits[1], 8), + PredictionResult(splits[2], 1), + PredictionResult(splits[3], 1), + PredictionResult(splits[4], 2), ] assert_that( actual, equal_to(expected, equals_fn=_compare_dataframe_predictions)) @@ -263,14 +264,14 @@ def test_pipeline_pandas_with_keys(self): keyed_rows = [(key, value) for key, value in zip(keys, splits)] pcoll = pipeline | 'start' >> beam.Create(keyed_rows) - actual = pcoll | base.RunInference( + actual = pcoll | RunInference( SklearnModelHandler(model_uri=temp_file_name)) expected = [ - ('0', base.PredictionResult(splits[0], 5)), - ('1', base.PredictionResult(splits[1], 8)), - ('2', base.PredictionResult(splits[2], 1)), - ('3', base.PredictionResult(splits[3], 1)), - ('4', base.PredictionResult(splits[4], 2)), + ('0', PredictionResult(splits[0], 5)), + ('1', PredictionResult(splits[1], 8)), + ('2', PredictionResult(splits[2], 1)), + ('3', PredictionResult(splits[3], 1)), + ('4', PredictionResult(splits[4], 2)), ] assert_that( actual, equal_to(expected, equals_fn=_compare_dataframe_predictions)) From cd1118f2e961b568f393aaede8c96c959cd66782 Mon Sep 17 00:00:00 2001 From: Anand Inguva Date: Mon, 13 Jun 2022 08:46:13 -0400 Subject: [PATCH 04/10] Add todo to mypy github issue --- sdks/python/apache_beam/ml/inference/base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index d7b95e889a4e..b055ba102b78 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +# (TODO) https://github.com/apache/beam/issues/21441 # mypy: ignore-errors """An extensible run inference transform. From 77f769102d93adcbcb61b86e59fbfbb19fde07d2 Mon Sep 17 00:00:00 2001 From: Anand Inguva Date: Mon, 13 Jun 2022 09:07:38 -0400 Subject: [PATCH 05/10] Refactor code to reflect changes of #21777 --- sdks/python/apache_beam/ml/inference/base.py | 4 ++-- .../apache_beam/ml/inference/pytorch_inference_test.py | 2 +- .../apache_beam/ml/inference/sklearn_inference_test.py | 5 +++-- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index 6e4cf295fe2f..efebff0eab3a 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -211,8 +211,8 @@ def batch_elements_kwargs(self): return self._unkeyed.batch_elements_kwargs() -@beam.typehints.with_input_types(Union[_INPUT_TYPE, Tuple[_K, _INPUT_TYPE]]) -@beam.typehints.with_output_types(Union[PredictionResult, Tuple[_K, PredictionResult]]) # pylint: disable=line-too-long +@beam.typehints.with_input_types(Union[_INPUT_TYPE, Tuple[KeyT, _INPUT_TYPE]]) +@beam.typehints.with_output_types(Union[PredictionResult, Tuple[KeyT, PredictionResult]]) # pylint: disable=line-too-long class RunInference(beam.PTransform[beam.PCollection[ExampleT], beam.PCollection[PredictionT]]): """An extensible transform for running inferences. diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py b/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py index 9d149063d07f..aaecaa8b7988 100644 --- a/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py +++ b/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py @@ -312,7 +312,7 @@ def test_pipeline_local_model_kwargs_prediction_params(self): prediction_params_side_input = ( pipeline | 'create side' >> beam.Create(prediction_params)) predictions = pcoll | RunInference( - model_loader=model_loader, + model_handler=model_loader, prediction_params=beam.pvalue.AsDict(prediction_params_side_input)) assert_that( predictions, diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py b/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py index fe427271f3be..1e788525dc04 100644 --- a/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py +++ b/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py @@ -37,6 +37,7 @@ from sklearn.preprocessing import StandardScaler import apache_beam as beam +from apache_beam.ml.inference.base import KeyedModelHandler from apache_beam.ml.inference.base import PredictionResult from apache_beam.ml.inference.base import RunInference from apache_beam.ml.inference.sklearn_inference import ModelFileType @@ -264,8 +265,8 @@ def test_pipeline_pandas_with_keys(self): keyed_rows = [(key, value) for key, value in zip(keys, splits)] pcoll = pipeline | 'start' >> beam.Create(keyed_rows) - actual = pcoll | api.RunInference( - base.KeyedModelHandler(SklearnModelHandler(model_uri=temp_file_name))) + actual = pcoll | RunInference( + KeyedModelHandler(SklearnModelHandler(model_uri=temp_file_name))) expected = [ ('0', PredictionResult(splits[0], 5)), ('1', PredictionResult(splits[1], 8)), From 65a0483f0203d89f648fb917041a74291ba619b2 Mon Sep 17 00:00:00 2001 From: Anand Inguva Date: Mon, 13 Jun 2022 09:13:56 -0400 Subject: [PATCH 06/10] Refactor example with KeyedModelHandler --- .../inference/pytorch_image_classification.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py b/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py index a3ea84ad01b5..46faacb88ebd 100644 --- a/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py +++ b/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py @@ -27,9 +27,10 @@ import apache_beam as beam import torch from apache_beam.io.filesystems import FileSystems -from apache_beam.ml.inference.api import PredictionResult -from apache_beam.ml.inference.api import RunInference -from apache_beam.ml.inference.pytorch_inference import PytorchModelLoader +from apache_beam.ml.inference.base import KeyedModelHandler +from apache_beam.ml.inference.base import PredictionResult +from apache_beam.ml.inference.base import RunInference +from apache_beam.ml.inference.pytorch_inference import PytorchModelHandler from apache_beam.options.pipeline_options import PipelineOptions from apache_beam.options.pipeline_options import SetupOptions from PIL import Image @@ -114,10 +115,13 @@ def run(argv=None, model_class=None, model_params=None, save_main_session=True): model_class = MobileNetV2 model_params = {'num_classes': 1000} - model_loader = PytorchModelLoader( - state_dict_path=known_args.model_state_dict_path, - model_class=model_class, - model_params=model_params) + # the input to RunInference transform is keyed. Wrap + # PytorchModelHandler on KeyedModelHandler for keyed examples. + model_loader = KeyedModelHandler( + PytorchModelHandler( + state_dict_path=known_args.model_state_dict_path, + model_class=model_class, + model_params=model_params)) with beam.Pipeline(options=pipeline_options) as p: filename_value_pair = ( From 55b279309f73268bfb8447ff8a8169332b356b6c Mon Sep 17 00:00:00 2001 From: Anand Inguva Date: Mon, 13 Jun 2022 11:20:04 -0400 Subject: [PATCH 07/10] remove explicit type hints from RunInference class --- .../examples/inference/pytorch_image_classification.py | 4 +--- sdks/python/apache_beam/ml/inference/base.py | 2 -- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py b/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py index 59880d04a74d..070fc80dd769 100644 --- a/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py +++ b/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py @@ -135,9 +135,7 @@ def run(argv=None, model_class=None, model_params=None, save_main_session=True): lambda file_name, data: (file_name, preprocess_image(data)))) predictions = ( filename_value_pair - | - 'PyTorchRunInference' >> RunInference(model_handler).with_output_types( - Tuple[str, PredictionResult]) + | 'PyTorchRunInference' >> RunInference(model_handler) | 'ProcessOutput' >> beam.ParDo(PostProcessor())) if known_args.output: diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index efebff0eab3a..654fe3fc86a5 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -211,8 +211,6 @@ def batch_elements_kwargs(self): return self._unkeyed.batch_elements_kwargs() -@beam.typehints.with_input_types(Union[_INPUT_TYPE, Tuple[KeyT, _INPUT_TYPE]]) -@beam.typehints.with_output_types(Union[PredictionResult, Tuple[KeyT, PredictionResult]]) # pylint: disable=line-too-long class RunInference(beam.PTransform[beam.PCollection[ExampleT], beam.PCollection[PredictionT]]): """An extensible transform for running inferences. From 1576be50b8a14b719e31ff908b946e0218ec1eab Mon Sep 17 00:00:00 2001 From: Anand Inguva Date: Mon, 13 Jun 2022 11:39:29 -0400 Subject: [PATCH 08/10] Fixup : Lint --- sdks/python/apache_beam/ml/inference/sklearn_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference.py b/sdks/python/apache_beam/ml/inference/sklearn_inference.py index f3cca3c1083f..d7e0b7395deb 100644 --- a/sdks/python/apache_beam/ml/inference/sklearn_inference.py +++ b/sdks/python/apache_beam/ml/inference/sklearn_inference.py @@ -28,8 +28,8 @@ from sklearn.base import BaseEstimator from apache_beam.io.filesystems import FileSystems -from apache_beam.ml.inference.base import PredictionResult from apache_beam.ml.inference.base import ModelHandler +from apache_beam.ml.inference.base import PredictionResult try: import joblib From 0150487be615d0e3a6de8442dde58bf26421a1bb Mon Sep 17 00:00:00 2001 From: Anand Inguva Date: Mon, 13 Jun 2022 13:04:24 -0400 Subject: [PATCH 09/10] remove TODO to github issue for mypy error --- sdks/python/apache_beam/ml/inference/base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index 654fe3fc86a5..dffc62025e04 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # -# (TODO) https://github.com/apache/beam/issues/21441 # mypy: ignore-errors """An extensible run inference transform. From af81bacfce26c32c8294530950c1543622492cba Mon Sep 17 00:00:00 2001 From: Anand Inguva Date: Mon, 13 Jun 2022 13:19:32 -0400 Subject: [PATCH 10/10] Add mypy github issue as TODO --- sdks/python/apache_beam/ml/inference/base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index dffc62025e04..ad7191cb59b9 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +# TODO: https://github.com/apache/beam/issues/21822 # mypy: ignore-errors """An extensible run inference transform.