From 68f1322ded4e4797ccf03fa87214b286d9541363 Mon Sep 17 00:00:00 2001 From: Andy Ye Date: Fri, 10 Jun 2022 15:09:24 -0400 Subject: [PATCH 1/8] Remove kwargs and add runinference_args --- sdks/python/apache_beam/ml/inference/base.py | 27 +++++++++++++------ .../apache_beam/ml/inference/base_test.py | 26 ++++++++++-------- .../ml/inference/pytorch_inference.py | 11 +++++--- .../ml/inference/pytorch_inference_test.py | 16 +++++------ .../ml/inference/sklearn_inference.py | 5 ++-- 5 files changed, 51 insertions(+), 34 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index 2d10cf7a1561..38c9626ea4d2 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -33,10 +33,12 @@ import sys import time from typing import Any +from typing import Dict from typing import Generic from typing import Iterable from typing import List from typing import Mapping +from typing import Optional from typing import TypeVar import apache_beam as beam @@ -66,8 +68,12 @@ def _to_microseconds(time_ns: int) -> int: class InferenceRunner(Generic[ExampleT, PredictionT, ModelT]): """Implements running inferences for a framework.""" - def run_inference(self, batch: List[ExampleT], model: ModelT, - **kwargs) -> Iterable[PredictionT]: + def run_inference( + self, + batch: List[ExampleT], + model: ModelT, + extra_runinference_args: Optional[Dict[str, Any]] = None + ) -> Iterable[PredictionT]: """Runs inferences on a batch of examples and returns an Iterable of Predictions.""" raise NotImplementedError(type(self)) @@ -112,9 +118,9 @@ def __init__( self, model_loader: ModelLoader[ExampleT, PredictionT, Any], clock=time, - **kwargs): + extra_runinference_args: Optional[Dict[str, Any]] = None): self._model_loader = model_loader - self._kwargs = kwargs + self._extra_runinference_args = extra_runinference_args self._clock = clock # TODO(BEAM-14208): Add batch_size back off in the case there @@ -129,7 +135,8 @@ def expand( | ( beam.ParDo( _RunInferenceDoFn(self._model_loader, self._clock), - **self._kwargs).with_resource_hints(**resource_hints))) + self._extra_runinference_args).with_resource_hints( + **resource_hints))) class _MetricsCollector: @@ -213,7 +220,7 @@ def setup(self): self._inference_runner.get_metrics_namespace()) self._model = self._load_model() - def process(self, batch, **kwargs): + def process(self, batch, extra_runinference_args): # Process supports both keyed data, and example only data. # First keys and samples are separated (if there are keys) has_keys = isinstance(batch[0], tuple) @@ -225,8 +232,12 @@ def process(self, batch, **kwargs): keys = None start_time = _to_microseconds(self._clock.time_ns()) - result_generator = self._inference_runner.run_inference( - examples, self._model, **kwargs) + if extra_runinference_args: + result_generator = self._inference_runner.run_inference( + examples, self._model, extra_runinference_args) + else: + result_generator = self._inference_runner.run_inference( + examples, self._model) predictions = list(result_generator) end_time = _to_microseconds(self._clock.time_ns()) diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py index 9616ba6060a6..c9c51c866855 100644 --- a/sdks/python/apache_beam/ml/inference/base_test.py +++ b/sdks/python/apache_beam/ml/inference/base_test.py @@ -39,8 +39,11 @@ class FakeInferenceRunner(base.InferenceRunner[int, int, FakeModel]): def __init__(self, clock=None): self._fake_clock = clock - def run_inference(self, batch: List[int], model: FakeModel, - **kwargs) -> Iterable[int]: + def run_inference( + self, + batch: List[int], + model: FakeModel, + ) -> Iterable[int]: if self._fake_clock: self._fake_clock.current_time_ns += 3_000_000 # 3 milliseconds for example in batch: @@ -89,16 +92,16 @@ def batch_elements_kwargs(self): return {'min_batch_size': 9999} -class FakeInferenceRunnerKwargs(FakeInferenceRunner): - def run_inference(self, batch, unused_model, **kwargs): - if not kwargs.get('key'): - raise ValueError('key should be True') +class FakeInferenceRunnerExtraArgs(FakeInferenceRunner): + def run_inference(self, batch, unused_model, extra_runinference_args): + if not extra_runinference_args: + raise ValueError('extra_runinference_args should exist') return batch -class FakeLoaderWithKwargs(FakeModelLoader): +class FakeLoaderWithExtraArgs(FakeModelLoader): def get_inference_runner(self): - return FakeInferenceRunnerKwargs() + return FakeInferenceRunnerExtraArgs() class RunInferenceBaseTest(unittest.TestCase): @@ -119,12 +122,13 @@ def test_run_inference_impl_with_keyed_examples(self): actual = pcoll | base.RunInference(FakeModelLoader()) assert_that(actual, equal_to(expected), label='assert:inferences') - def test_run_inference_impl_kwargs(self): + def test_run_inference_impl_extra_runinference_args(self): with TestPipeline() as pipeline: examples = [1, 5, 3, 10] pcoll = pipeline | 'start' >> beam.Create(examples) - kwargs = {'key': True} - actual = pcoll | base.RunInference(FakeLoaderWithKwargs(), **kwargs) + extra_args = {'key': True} + actual = pcoll | base.RunInference( + FakeLoaderWithExtraArgs(), extra_runinference_args=extra_args) assert_that(actual, equal_to(examples), label='assert:inferences') def test_counted_metrics(self): diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py index d2a47d530635..bf4ffb0cae40 100644 --- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py +++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py @@ -23,6 +23,7 @@ from typing import Dict from typing import Iterable from typing import List +from typing import Optional from typing import Union import torch @@ -58,7 +59,8 @@ def run_inference( self, batch: List[Union[torch.Tensor, Dict[str, torch.Tensor]]], model: torch.nn.Module, - **kwargs) -> Iterable[PredictionResult]: + extra_runinference_args: Optional[Dict[str, Any]] = None + ) -> Iterable[PredictionResult]: """ Runs inferences on a batch of Tensors and returns an Iterable of Tensor Predictions. @@ -66,7 +68,8 @@ def run_inference( This method stacks the list of Tensors in a vectorized format to optimize the inference call. """ - prediction_params = kwargs.get('prediction_params', {}) + extra_runinference_args = ( + extra_runinference_args if extra_runinference_args else {}) # If elements in `batch` are provided as a dictionaries from key to Tensors, # then iterate through the batch list, and group Tensors to the same key @@ -80,12 +83,12 @@ def run_inference( batched_tensors = torch.stack(key_to_tensor_list[key]) batched_tensors = self._convert_to_device(batched_tensors) key_to_batched_tensors[key] = batched_tensors - predictions = model(**key_to_batched_tensors, **prediction_params) + predictions = model(**key_to_batched_tensors, **extra_runinference_args) else: # If elements in `batch` are provided as Tensors, then do a regular stack batched_tensors = torch.stack(batch) batched_tensors = self._convert_to_device(batched_tensors) - predictions = model(batched_tensors, **prediction_params) + predictions = model(batched_tensors, **extra_runinference_args) return [PredictionResult(x, y) for x, y in zip(batch, predictions)] def get_num_bytes(self, batch: List[torch.Tensor]) -> int: diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py b/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py index 0011ba58244d..468713c68496 100644 --- a/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py +++ b/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py @@ -208,13 +208,13 @@ def forward(self, k1, k2): for actual, expected in zip(predictions, KWARGS_TORCH_PREDICTIONS): self.assertTrue(_compare_prediction_result(actual, expected)) - def test_inference_runner_kwargs_prediction_params(self): + def test_inference_runner_kwargs_extra_args(self): """ This tests for non-batchable input arguments. Since we do the batching for the user, we have to distinguish between the inputs that should be batched and the ones that should not be batched. """ - prediction_params = { + extra_args = { 'prediction_param_array': torch.from_numpy( np.array([1, 2], dtype="float32")), 'prediction_param_bool': True @@ -231,7 +231,7 @@ def test_inference_runner_kwargs_prediction_params(self): predictions = inference_runner.run_inference( batch=KWARGS_TORCH_EXAMPLES, model=model, - prediction_params=prediction_params) + extra_runinference_args=extra_args) for actual, expected in zip(predictions, KWARGS_TORCH_PREDICTIONS): self.assertEqual(actual, expected) @@ -278,9 +278,9 @@ def test_pipeline_local_model_simple(self): equal_to( TWO_FEATURES_PREDICTIONS, equals_fn=_compare_prediction_result)) - def test_pipeline_local_model_kwargs_prediction_params(self): + def test_pipeline_local_model_kwargs_extra_args(self): with TestPipeline() as pipeline: - prediction_params = { + extra_args = { 'prediction_param_array': torch.from_numpy( np.array([1, 2], dtype="float32")), 'prediction_param_bool': True @@ -299,11 +299,11 @@ def test_pipeline_local_model_kwargs_prediction_params(self): }) pcoll = pipeline | 'start' >> beam.Create(KWARGS_TORCH_EXAMPLES) - prediction_params_side_input = ( - pipeline | 'create side' >> beam.Create(prediction_params)) + extra_args_side_input = ( + pipeline | 'create side' >> beam.Create(extra_args)) predictions = pcoll | RunInference( model_loader=model_loader, - prediction_params=beam.pvalue.AsDict(prediction_params_side_input)) + extra_runinference_args=beam.pvalue.AsDict(extra_args_side_input)) assert_that( predictions, equal_to( diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference.py b/sdks/python/apache_beam/ml/inference/sklearn_inference.py index 00d63f9fb6c1..c6276577d8ba 100644 --- a/sdks/python/apache_beam/ml/inference/sklearn_inference.py +++ b/sdks/python/apache_beam/ml/inference/sklearn_inference.py @@ -44,9 +44,8 @@ class ModelFileType(enum.Enum): class SklearnInferenceRunner(InferenceRunner[numpy.ndarray, PredictionResult, BaseEstimator]): - def run_inference( - self, batch: List[numpy.ndarray], model: BaseEstimator, - **kwargs) -> Iterable[PredictionResult]: + def run_inference(self, batch: List[numpy.ndarray], + model: BaseEstimator) -> Iterable[PredictionResult]: # vectorize data for better performance vectorized_batch = numpy.stack(batch, axis=0) predictions = model.predict(vectorized_batch) From b1b9f41dc5121989caf2827aa7afe9c4cee87043 Mon Sep 17 00:00:00 2001 From: Andy Ye Date: Fri, 10 Jun 2022 16:05:52 -0400 Subject: [PATCH 2/8] Fix names in the test --- .../ml/inference/pytorch_inference_test.py | 40 +++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py b/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py index 243b17857209..8b30c5d3821b 100644 --- a/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py +++ b/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py @@ -62,7 +62,7 @@ for f1, f2 in TWO_FEATURES_EXAMPLES]).reshape(-1, 1)) ] -KWARGS_TORCH_EXAMPLES = [ +KEYED_TORCH_EXAMPLES = [ { 'k1': torch.from_numpy(np.array([1], dtype="float32")), 'k2': torch.from_numpy(np.array([1.5], dtype="float32")) @@ -81,12 +81,12 @@ }, ] -KWARGS_TORCH_PREDICTIONS = [ +KEYED_TORCH_PREDICTIONS = [ PredictionResult(ex, pred) for ex, pred in zip( - KWARGS_TORCH_EXAMPLES, + KEYED_TORCH_EXAMPLES, torch.Tensor([(example['k1'] * 2.0 + 0.5) + (example['k2'] * 2.0 + 0.5) - for example in KWARGS_TORCH_EXAMPLES]).reshape(-1, 1)) + for example in KEYED_TORCH_EXAMPLES]).reshape(-1, 1)) ] @@ -115,12 +115,12 @@ def forward(self, x): return out -class PytorchLinearRegressionKwargsPredictionParams(torch.nn.Module): +class PytorchLinearRegressionKeyedBatchAndExtraParams(torch.nn.Module): """ - A linear model with kwargs inputs and non-batchable input params. + A linear model with batched keyed inputs and non-batchable extra params. - Note: k1 and k2 are batchable inputs passed in as a kwargs. - prediction_param_array, prediction_param_bool are non-batchable inputs + Note: k1 and k2 are batchable examples passed in as a keyed to torch dict. + prediction_param_array, prediction_param_bool are non-batchable extra params (typically model-related info) used to configure the model before its predict call is invoked """ @@ -179,10 +179,10 @@ def test_run_inference_multiple_tensor_features(self): for actual, expected in zip(predictions, TWO_FEATURES_PREDICTIONS): self.assertEqual(actual, expected) - def test_run_inference_kwargs(self): + def test_run_inference_keyed(self): """ This tests for inputs that are passed as a dictionary from key to tensor - instead of a standard non-kwarg input. + instead of a standard non-keyed tensor example. Example: Typical input format is @@ -211,11 +211,11 @@ def forward(self, k1, k2): inference_runner = TestPytorchModelHandlerForInferenceOnly( torch.device('cpu')) - predictions = inference_runner.run_inference(KWARGS_TORCH_EXAMPLES, model) - for actual, expected in zip(predictions, KWARGS_TORCH_PREDICTIONS): + predictions = inference_runner.run_inference(KEYED_TORCH_EXAMPLES, model) + for actual, expected in zip(predictions, KEYED_TORCH_PREDICTIONS): self.assertTrue(_compare_prediction_result(actual, expected)) - def test_inference_runner_kwargs_extra_args(self): + def test_inference_runner_extra_args(self): """ This tests for non-batchable input arguments. Since we do the batching for the user, we have to distinguish between the inputs that should be @@ -227,7 +227,7 @@ def test_inference_runner_kwargs_extra_args(self): 'prediction_param_bool': True } - model = PytorchLinearRegressionKwargsPredictionParams( + model = PytorchLinearRegressionKeyedBatchAndExtraParams( input_dim=1, output_dim=1) model.load_state_dict( OrderedDict([('linear.weight', torch.Tensor([[2.0]])), @@ -237,10 +237,10 @@ def test_inference_runner_kwargs_extra_args(self): inference_runner = TestPytorchModelHandlerForInferenceOnly( torch.device('cpu')) predictions = inference_runner.run_inference( - batch=KWARGS_TORCH_EXAMPLES, + batch=KEYED_TORCH_EXAMPLES, model=model, extra_runinference_args=extra_args) - for actual, expected in zip(predictions, KWARGS_TORCH_PREDICTIONS): + for actual, expected in zip(predictions, KEYED_TORCH_PREDICTIONS): self.assertEqual(actual, expected) def test_num_bytes(self): @@ -288,7 +288,7 @@ def test_pipeline_local_model_simple(self): equal_to( TWO_FEATURES_PREDICTIONS, equals_fn=_compare_prediction_result)) - def test_pipeline_local_model_kwargs_extra_args(self): + def test_pipeline_local_model_extra_args(self): with TestPipeline() as pipeline: extra_args = { 'prediction_param_array': torch.from_numpy( @@ -303,12 +303,12 @@ def test_pipeline_local_model_kwargs_extra_args(self): model_loader = PytorchModelHandler( state_dict_path=path, - model_class=PytorchLinearRegressionKwargsPredictionParams, + model_class=PytorchLinearRegressionKeyedBatchAndExtraParams, model_params={ 'input_dim': 1, 'output_dim': 1 }) - pcoll = pipeline | 'start' >> beam.Create(KWARGS_TORCH_EXAMPLES) + pcoll = pipeline | 'start' >> beam.Create(KEYED_TORCH_EXAMPLES) extra_args_side_input = ( pipeline | 'create side' >> beam.Create(extra_args)) predictions = pcoll | RunInference( @@ -317,7 +317,7 @@ def test_pipeline_local_model_kwargs_extra_args(self): assert_that( predictions, equal_to( - KWARGS_TORCH_PREDICTIONS, equals_fn=_compare_prediction_result)) + KEYED_TORCH_PREDICTIONS, equals_fn=_compare_prediction_result)) @unittest.skipIf(GCSFileSystem is None, 'GCP dependencies are not installed') def test_pipeline_gcs_model(self): From dfbdde9efdb335a922f342e34d15ea02044129e1 Mon Sep 17 00:00:00 2001 From: Andy Ye Date: Fri, 10 Jun 2022 16:22:20 -0400 Subject: [PATCH 3/8] Revert change in api.py --- sdks/python/apache_beam/ml/inference/api.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/api.py b/sdks/python/apache_beam/ml/inference/api.py index 9574be380416..3d70f874733b 100644 --- a/sdks/python/apache_beam/ml/inference/api.py +++ b/sdks/python/apache_beam/ml/inference/api.py @@ -55,8 +55,8 @@ class RunInference(beam.PTransform): TODO(BEAM-14046): Add and link to help documentation """ - def __init__(self, model_handler: base.ModelHandler): - self._model_handler = model_handler + def __init__(self, model_loader: base.ModelHandler): + self._model_loader = model_loader def expand(self, pcoll: beam.PCollection) -> beam.PCollection: - return pcoll | base.RunInference(self._model_handler) + return pcoll | base.RunInference(self._model_loader) From 8a726e7e2e2437e8435e3a13d9d1f39f264b92b8 Mon Sep 17 00:00:00 2001 From: Andy Ye Date: Fri, 10 Jun 2022 17:27:45 -0400 Subject: [PATCH 4/8] Fix variable name; Add docstring --- sdks/python/apache_beam/ml/inference/base.py | 18 ++++++------ .../apache_beam/ml/inference/base_test.py | 14 +++++----- .../ml/inference/pytorch_inference.py | 11 ++++---- .../ml/inference/pytorch_inference_test.py | 28 +++++++++---------- 4 files changed, 34 insertions(+), 37 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index 24d47168ff81..9719ac58f8bc 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -76,8 +76,7 @@ def run_inference( self, batch: List[ExampleT], model: ModelT, - extra_runinference_args: Optional[Dict[str, Any]] = None - ) -> Iterable[PredictionT]: + extra_kwargs: Optional[Dict[str, Any]] = None) -> Iterable[PredictionT]: """Runs inferences on a batch of examples and returns an Iterable of Predictions.""" raise NotImplementedError(type(self)) @@ -105,14 +104,16 @@ class RunInference(beam.PTransform[beam.PCollection[ExampleT], Args: model_handler: An implementation of ModelHandler. clock: A clock implementing get_current_time_in_microseconds. + extra_kwargs: Extra arguments for models whose inference call requires + extra parameters. """ def __init__( self, model_handler: ModelHandler[ExampleT, PredictionT, Any], clock=time, - extra_runinference_args: Optional[Dict[str, Any]] = None): + extra_kwargs: Optional[Dict[str, Any]] = None): self._model_handler = model_handler - self._extra_runinference_args = extra_runinference_args + self._extra_kwargs = extra_kwargs self._clock = clock # TODO(BEAM-14208): Add batch_size back off in the case there @@ -127,8 +128,7 @@ def expand( | ( beam.ParDo( _RunInferenceDoFn(self._model_handler, self._clock), - self._extra_runinference_args).with_resource_hints( - **resource_hints))) + self._extra_kwargs).with_resource_hints(**resource_hints))) class _MetricsCollector: @@ -211,7 +211,7 @@ def setup(self): self._model_handler.get_metrics_namespace()) self._model = self._load_model() - def process(self, batch, extra_runinference_args): + def process(self, batch, extra_kwargs): # Process supports both keyed data, and example only data. # First keys and samples are separated (if there are keys) has_keys = isinstance(batch[0], tuple) @@ -223,9 +223,9 @@ def process(self, batch, extra_runinference_args): keys = None start_time = _to_microseconds(self._clock.time_ns()) - if extra_runinference_args: + if extra_kwargs: result_generator = self._model_handler.run_inference( - examples, self._model, extra_runinference_args) + examples, self._model, extra_kwargs) else: result_generator = self._model_handler.run_inference( examples, self._model) diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py index dafdde1126e4..360c82a3d6c0 100644 --- a/sdks/python/apache_beam/ml/inference/base_test.py +++ b/sdks/python/apache_beam/ml/inference/base_test.py @@ -75,10 +75,10 @@ def batch_elements_kwargs(self): return {'min_batch_size': 9999} -class FakeModelHandlerExtraArgs(FakeModelHandler): - def run_inference(self, batch, unused_model, extra_runinference_args): - if not extra_runinference_args: - raise ValueError('extra_runinference_args should exist') +class FakeModelHandlerExtraKwargs(FakeModelHandler): + def run_inference(self, batch, unused_model, extra_kwargs): + if not extra_kwargs: + raise ValueError('extra_kwargs should exist') return batch @@ -100,13 +100,13 @@ def test_run_inference_impl_with_keyed_examples(self): actual = pcoll | base.RunInference(FakeModelHandler()) assert_that(actual, equal_to(expected), label='assert:inferences') - def test_run_inference_impl_extra_runinference_args(self): + def test_run_inference_impl_extra_kwargs(self): with TestPipeline() as pipeline: examples = [1, 5, 3, 10] pcoll = pipeline | 'start' >> beam.Create(examples) - extra_args = {'key': True} + extra_kwargs = {'key': True} actual = pcoll | base.RunInference( - FakeModelHandlerExtraArgs(), extra_runinference_args=extra_args) + FakeModelHandlerExtraKwargs(), extra_kwargs=extra_kwargs) assert_that(actual, equal_to(examples), label='assert:inferences') def test_counted_metrics(self): diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py index eab2c5fe0829..eaedbe18003f 100644 --- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py +++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py @@ -90,8 +90,8 @@ def run_inference( self, batch: List[Union[torch.Tensor, Dict[str, torch.Tensor]]], model: torch.nn.Module, - extra_runinference_args: Optional[Dict[str, Any]] = None - ) -> Iterable[PredictionResult]: + extra_kwargs: Optional[Dict[str, + Any]] = None) -> Iterable[PredictionResult]: """ Runs inferences on a batch of Tensors and returns an Iterable of Tensor Predictions. @@ -99,8 +99,7 @@ def run_inference( This method stacks the list of Tensors in a vectorized format to optimize the inference call. """ - extra_runinference_args = ( - extra_runinference_args if extra_runinference_args else {}) + extra_kwargs = extra_kwargs if extra_kwargs else {} # If elements in `batch` are provided as a dictionaries from key to Tensors, # then iterate through the batch list, and group Tensors to the same key @@ -114,12 +113,12 @@ def run_inference( batched_tensors = torch.stack(key_to_tensor_list[key]) batched_tensors = self._convert_to_device(batched_tensors) key_to_batched_tensors[key] = batched_tensors - predictions = model(**key_to_batched_tensors, **extra_runinference_args) + predictions = model(**key_to_batched_tensors, **extra_kwargs) else: # If elements in `batch` are provided as Tensors, then do a regular stack batched_tensors = torch.stack(batch) batched_tensors = self._convert_to_device(batched_tensors) - predictions = model(batched_tensors, **extra_runinference_args) + predictions = model(batched_tensors, **extra_kwargs) return [PredictionResult(x, y) for x, y in zip(batch, predictions)] def get_num_bytes(self, batch: List[torch.Tensor]) -> int: diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py b/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py index 8b30c5d3821b..a37195aa4b61 100644 --- a/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py +++ b/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py @@ -115,12 +115,12 @@ def forward(self, x): return out -class PytorchLinearRegressionKeyedBatchAndExtraParams(torch.nn.Module): +class PytorchLinearRegressionKeyedBatchAndExtraKwargs(torch.nn.Module): """ - A linear model with batched keyed inputs and non-batchable extra params. + A linear model with batched keyed inputs and non-batchable extra args. - Note: k1 and k2 are batchable examples passed in as a keyed to torch dict. - prediction_param_array, prediction_param_bool are non-batchable extra params + Note: k1 and k2 are batchable examples passed in as a dict from str to tensor. + prediction_param_array, prediction_param_bool are non-batchable extra args (typically model-related info) used to configure the model before its predict call is invoked """ @@ -215,19 +215,19 @@ def forward(self, k1, k2): for actual, expected in zip(predictions, KEYED_TORCH_PREDICTIONS): self.assertTrue(_compare_prediction_result(actual, expected)) - def test_inference_runner_extra_args(self): + def test_inference_runner_extra_kwargs(self): """ This tests for non-batchable input arguments. Since we do the batching for the user, we have to distinguish between the inputs that should be batched and the ones that should not be batched. """ - extra_args = { + extra_kwargs = { 'prediction_param_array': torch.from_numpy( np.array([1, 2], dtype="float32")), 'prediction_param_bool': True } - model = PytorchLinearRegressionKeyedBatchAndExtraParams( + model = PytorchLinearRegressionKeyedBatchAndExtraKwargs( input_dim=1, output_dim=1) model.load_state_dict( OrderedDict([('linear.weight', torch.Tensor([[2.0]])), @@ -237,9 +237,7 @@ def test_inference_runner_extra_args(self): inference_runner = TestPytorchModelHandlerForInferenceOnly( torch.device('cpu')) predictions = inference_runner.run_inference( - batch=KEYED_TORCH_EXAMPLES, - model=model, - extra_runinference_args=extra_args) + batch=KEYED_TORCH_EXAMPLES, model=model, extra_kwargs=extra_kwargs) for actual, expected in zip(predictions, KEYED_TORCH_PREDICTIONS): self.assertEqual(actual, expected) @@ -290,7 +288,7 @@ def test_pipeline_local_model_simple(self): def test_pipeline_local_model_extra_args(self): with TestPipeline() as pipeline: - extra_args = { + extra_kwargs = { 'prediction_param_array': torch.from_numpy( np.array([1, 2], dtype="float32")), 'prediction_param_bool': True @@ -303,17 +301,17 @@ def test_pipeline_local_model_extra_args(self): model_loader = PytorchModelHandler( state_dict_path=path, - model_class=PytorchLinearRegressionKeyedBatchAndExtraParams, + model_class=PytorchLinearRegressionKeyedBatchAndExtraKwargs, model_params={ 'input_dim': 1, 'output_dim': 1 }) pcoll = pipeline | 'start' >> beam.Create(KEYED_TORCH_EXAMPLES) - extra_args_side_input = ( - pipeline | 'create side' >> beam.Create(extra_args)) + extra_kwargs_side_input = ( + pipeline | 'create side' >> beam.Create(extra_kwargs)) predictions = pcoll | RunInference( model_handler=model_loader, - extra_runinference_args=beam.pvalue.AsDict(extra_args_side_input)) + extra_kwargs=beam.pvalue.AsDict(extra_kwargs_side_input)) assert_that( predictions, equal_to( From 4af9194724549667a6989dda874e53cf59f9f360 Mon Sep 17 00:00:00 2001 From: Andy Ye Date: Mon, 13 Jun 2022 07:13:23 -0400 Subject: [PATCH 5/8] Refactor out remaining kwargs --- sdks/python/apache_beam/ml/inference/base.py | 25 +++++++++++-------- .../apache_beam/ml/inference/base_test.py | 11 +++++--- .../ml/inference/sklearn_inference.py | 6 ++++- 3 files changed, 27 insertions(+), 15 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index e62e8b25bd33..419c8cbc5008 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -119,11 +119,14 @@ def load_model(self) -> ModelT: return self._unkeyed.load_model() def run_inference( - self, batch: Sequence[Tuple[KeyT, ExampleT]], model: ModelT, - **kwargs) -> Iterable[Tuple[KeyT, PredictionT]]: + self, + batch: Sequence[Tuple[KeyT, ExampleT]], + model: ModelT, + extra_kwargs: Optional[Dict[str, Any]] = None + ) -> Iterable[Tuple[KeyT, PredictionT]]: keys, unkeyed_batch = zip(*batch) return zip( - keys, self._unkeyed.run_inference(unkeyed_batch, model, **kwargs)) + keys, self._unkeyed.run_inference(unkeyed_batch, model, extra_kwargs)) def get_num_bytes(self, batch: Sequence[Tuple[KeyT, ExampleT]]) -> int: keys, unkeyed_batch = zip(*batch) @@ -167,7 +170,7 @@ def run_inference( self, batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]], model: ModelT, - **kwargs + extra_kwargs: Optional[Dict[str, Any]] = None ) -> Union[Iterable[PredictionT], Iterable[Tuple[KeyT, PredictionT]]]: # Really the input should be # Union[Sequence[ExampleT], Sequence[Tuple[KeyT, ExampleT]]] @@ -179,7 +182,7 @@ def run_inference( is_keyed = False unkeyed_batch = batch # type: ignore[assignment] unkeyed_results = self._unkeyed.run_inference( - unkeyed_batch, model, **kwargs) + unkeyed_batch, model, extra_kwargs) if is_keyed: return zip(keys, unkeyed_results) else: @@ -320,11 +323,13 @@ def setup(self): def process(self, batch, extra_kwargs): start_time = _to_microseconds(self._clock.time_ns()) - if extra_kwargs: - result_generator = self._model_handler.run_inference( - batch, self._model, extra_kwargs) - else: - result_generator = self._model_handler.run_inference(batch, self._model) + # if extra_kwargs: + # result_generator = self._model_handler.run_inference( + # batch, self._model, extra_kwargs) + # else: + # result_generator = self._model_handler.run_inference(batch, self._model) + result_generator = self._model_handler.run_inference( + batch, self._model, extra_kwargs) predictions = list(result_generator) end_time = _to_microseconds(self._clock.time_ns()) diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py index d8d3e5c481f1..018194b22e93 100644 --- a/sdks/python/apache_beam/ml/inference/base_test.py +++ b/sdks/python/apache_beam/ml/inference/base_test.py @@ -44,8 +44,11 @@ def load_model(self): self._fake_clock.current_time_ns += 500_000_000 # 500ms return FakeModel() - def run_inference(self, batch: Sequence[int], - model: FakeModel) -> Iterable[int]: + def run_inference( + self, + batch: Sequence[int], + model: FakeModel, + extra_kwargs=None) -> Iterable[int]: if self._fake_clock: self._fake_clock.current_time_ns += 3_000_000 # 3 milliseconds for example in batch: @@ -67,7 +70,7 @@ def process(self, prediction_result): class FakeModelHandlerNeedsBigBatch(FakeModelHandler): - def run_inference(self, batch, unused_model): + def run_inference(self, batch, unused_model, extra_kwargs=None): if len(batch) < 100: raise ValueError('Unexpectedly small batch') return batch @@ -77,7 +80,7 @@ def batch_elements_kwargs(self): class FakeModelHandlerExtraKwargs(FakeModelHandler): - def run_inference(self, batch, unused_model, extra_kwargs): + def run_inference(self, batch, unused_model, extra_kwargs=None): if not extra_kwargs: raise ValueError('extra_kwargs should exist') return batch diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference.py b/sdks/python/apache_beam/ml/inference/sklearn_inference.py index aeda29734eab..a67873c257c6 100644 --- a/sdks/python/apache_beam/ml/inference/sklearn_inference.py +++ b/sdks/python/apache_beam/ml/inference/sklearn_inference.py @@ -19,7 +19,9 @@ import pickle import sys from typing import Any +from typing import Dict from typing import Iterable +from typing import Optional from typing import Sequence from typing import Union @@ -76,7 +78,9 @@ def load_model(self) -> BaseEstimator: def run_inference( self, batch: Sequence[Union[numpy.ndarray, pandas.DataFrame]], - model: BaseEstimator) -> Iterable[PredictionResult]: + model: BaseEstimator, + extra_kwargs: Optional[Dict[str, + Any]] = None) -> Iterable[PredictionResult]: # TODO(github.com/apache/beam/issues/21769): Use supplied input type hint. if isinstance(batch[0], numpy.ndarray): return SklearnModelHandler._predict_np_array(batch, model) From 4e700ce14d69d83519b2d1faea56ea9426ecc32d Mon Sep 17 00:00:00 2001 From: Andy Ye Date: Mon, 13 Jun 2022 17:45:06 -0400 Subject: [PATCH 6/8] Fix missing extra_kwargs --- sdks/python/apache_beam/ml/inference/sklearn_inference.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference.py b/sdks/python/apache_beam/ml/inference/sklearn_inference.py index 020ba77a6254..e6e1ef11520e 100644 --- a/sdks/python/apache_beam/ml/inference/sklearn_inference.py +++ b/sdks/python/apache_beam/ml/inference/sklearn_inference.py @@ -77,8 +77,11 @@ def load_model(self) -> BaseEstimator: return _load_model(self._model_uri, self._model_file_type) def run_inference( - self, batch: Sequence[numpy.ndarray], model: BaseEstimator, - **kwargs) -> Iterable[PredictionResult]: + self, + batch: Sequence[numpy.ndarray], + model: BaseEstimator, + extra_kwargs: Optional[Dict[str, + Any]] = None) -> Iterable[PredictionResult]: # vectorize data for better performance vectorized_batch = numpy.stack(batch, axis=0) predictions = model.predict(vectorized_batch) From 527f24c212a423350c8f27da967b852777e55e1c Mon Sep 17 00:00:00 2001 From: Andy Ye Date: Tue, 14 Jun 2022 10:52:20 -0400 Subject: [PATCH 7/8] Remove comments --- sdks/python/apache_beam/ml/inference/base.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index f8fbfb6bac21..828808302f0a 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -347,11 +347,6 @@ def setup(self): def process(self, batch, extra_kwargs): start_time = _to_microseconds(self._clock.time_ns()) - # if extra_kwargs: - # result_generator = self._model_handler.run_inference( - # batch, self._model, extra_kwargs) - # else: - # result_generator = self._model_handler.run_inference(batch, self._model) result_generator = self._model_handler.run_inference( batch, self._model, extra_kwargs) predictions = list(result_generator) From b7f5d94a1e32f94c6061cb86f7ad23b3967bcfb2 Mon Sep 17 00:00:00 2001 From: Andy Ye Date: Tue, 14 Jun 2022 14:37:05 -0400 Subject: [PATCH 8/8] Add extra_args back to sklearn; Throw exception for sklearn --- sdks/python/apache_beam/ml/inference/base.py | 24 +++++++++---------- .../apache_beam/ml/inference/base_test.py | 18 +++++++------- .../ml/inference/pytorch_inference.py | 16 ++++++------- .../ml/inference/pytorch_inference_test.py | 22 ++++++++--------- .../ml/inference/sklearn_inference.py | 24 +++++++++++++++---- .../ml/inference/sklearn_inference_test.py | 12 +++++++++- 6 files changed, 71 insertions(+), 45 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index 828808302f0a..4e5ec960e701 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -90,7 +90,7 @@ def run_inference( self, batch: Sequence[ExampleT], model: ModelT, - extra_kwargs: Optional[Dict[str, Any]] = None) -> Iterable[PredictionT]: + inference_args: Optional[Dict[str, Any]] = None) -> Iterable[PredictionT]: """Runs inferences on a batch of examples and returns an Iterable of Predictions.""" raise NotImplementedError(type(self)) @@ -133,11 +133,11 @@ def run_inference( self, batch: Sequence[Tuple[KeyT, ExampleT]], model: ModelT, - extra_kwargs: Optional[Dict[str, Any]] = None + inference_args: Optional[Dict[str, Any]] = None ) -> Iterable[Tuple[KeyT, PredictionT]]: keys, unkeyed_batch = zip(*batch) return zip( - keys, self._unkeyed.run_inference(unkeyed_batch, model, extra_kwargs)) + keys, self._unkeyed.run_inference(unkeyed_batch, model, inference_args)) def get_num_bytes(self, batch: Sequence[Tuple[KeyT, ExampleT]]) -> int: keys, unkeyed_batch = zip(*batch) @@ -181,7 +181,7 @@ def run_inference( self, batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]], model: ModelT, - extra_kwargs: Optional[Dict[str, Any]] = None + inference_args: Optional[Dict[str, Any]] = None ) -> Union[Iterable[PredictionT], Iterable[Tuple[KeyT, PredictionT]]]: # Really the input should be # Union[Sequence[ExampleT], Sequence[Tuple[KeyT, ExampleT]]] @@ -193,7 +193,7 @@ def run_inference( is_keyed = False unkeyed_batch = batch # type: ignore[assignment] unkeyed_results = self._unkeyed.run_inference( - unkeyed_batch, model, extra_kwargs) + unkeyed_batch, model, inference_args) if is_keyed: return zip(keys, unkeyed_results) else: @@ -225,8 +225,8 @@ class RunInference(beam.PTransform[beam.PCollection[ExampleT], Args: model_handler: An implementation of ModelHandler. clock: A clock implementing get_current_time_in_microseconds. - extra_kwargs: Extra arguments for models whose inference call requires - extra parameters. + inference_args: Extra arguments for models whose inference call requires + extra parameters. A transform that takes a PCollection of examples (or features) to be used on an ML model. It will then output inferences (or predictions) for those @@ -245,9 +245,9 @@ def __init__( self, model_handler: ModelHandler[ExampleT, PredictionT, Any], clock=time, - extra_kwargs: Optional[Dict[str, Any]] = None): + inference_args: Optional[Dict[str, Any]] = None): self._model_handler = model_handler - self._extra_kwargs = extra_kwargs + self._inference_args = inference_args self._clock = clock # TODO(BEAM-14208): Add batch_size back off in the case there @@ -262,7 +262,7 @@ def expand( | ( beam.ParDo( _RunInferenceDoFn(self._model_handler, self._clock), - self._extra_kwargs).with_resource_hints(**resource_hints))) + self._inference_args).with_resource_hints(**resource_hints))) class _MetricsCollector: @@ -345,10 +345,10 @@ def setup(self): self._model_handler.get_metrics_namespace()) self._model = self._load_model() - def process(self, batch, extra_kwargs): + def process(self, batch, inference_args): start_time = _to_microseconds(self._clock.time_ns()) result_generator = self._model_handler.run_inference( - batch, self._model, extra_kwargs) + batch, self._model, inference_args) predictions = list(result_generator) end_time = _to_microseconds(self._clock.time_ns()) diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py index 018194b22e93..71bf2131a68b 100644 --- a/sdks/python/apache_beam/ml/inference/base_test.py +++ b/sdks/python/apache_beam/ml/inference/base_test.py @@ -48,7 +48,7 @@ def run_inference( self, batch: Sequence[int], model: FakeModel, - extra_kwargs=None) -> Iterable[int]: + inference_args=None) -> Iterable[int]: if self._fake_clock: self._fake_clock.current_time_ns += 3_000_000 # 3 milliseconds for example in batch: @@ -70,7 +70,7 @@ def process(self, prediction_result): class FakeModelHandlerNeedsBigBatch(FakeModelHandler): - def run_inference(self, batch, unused_model, extra_kwargs=None): + def run_inference(self, batch, unused_model, inference_args=None): if len(batch) < 100: raise ValueError('Unexpectedly small batch') return batch @@ -79,10 +79,10 @@ def batch_elements_kwargs(self): return {'min_batch_size': 9999} -class FakeModelHandlerExtraKwargs(FakeModelHandler): - def run_inference(self, batch, unused_model, extra_kwargs=None): - if not extra_kwargs: - raise ValueError('extra_kwargs should exist') +class FakeModelHandlerExtraInferenceArgs(FakeModelHandler): + def run_inference(self, batch, unused_model, inference_args=None): + if not inference_args: + raise ValueError('inference_args should exist') return batch @@ -122,13 +122,13 @@ def test_run_inference_impl_with_maybe_keyed_examples(self): model_handler) assert_that(keyed_actual, equal_to(keyed_expected), label='CheckKeyed') - def test_run_inference_impl_extra_kwargs(self): + def test_run_inference_impl_inference_args(self): with TestPipeline() as pipeline: examples = [1, 5, 3, 10] pcoll = pipeline | 'start' >> beam.Create(examples) - extra_kwargs = {'key': True} + inference_args = {'key': True} actual = pcoll | base.RunInference( - FakeModelHandlerExtraKwargs(), extra_kwargs=extra_kwargs) + FakeModelHandlerExtraInferenceArgs(), inference_args=inference_args) assert_that(actual, equal_to(examples), label='assert:inferences') def test_counted_metrics(self): diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py index 04fbafea5f09..331677d76c2b 100644 --- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py +++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py @@ -95,8 +95,8 @@ def run_inference( self, batch: Sequence[torch.Tensor], model: torch.nn.Module, - extra_kwargs: Optional[Dict[str, - Any]] = None) -> Iterable[PredictionResult]: + inference_args: Optional[Dict[str, Any]] = None + ) -> Iterable[PredictionResult]: """ Runs inferences on a batch of Tensors and returns an Iterable of Tensor Predictions. @@ -104,11 +104,11 @@ def run_inference( This method stacks the list of Tensors in a vectorized format to optimize the inference call. """ - extra_kwargs = {} if not extra_kwargs else extra_kwargs + inference_args = {} if not inference_args else inference_args batched_tensors = torch.stack(batch) batched_tensors = _convert_to_device(batched_tensors, self._device) - predictions = model(batched_tensors, **extra_kwargs) + predictions = model(batched_tensors, **inference_args) return [PredictionResult(x, y) for x, y in zip(batch, predictions)] def get_num_bytes(self, batch: Sequence[torch.Tensor]) -> int: @@ -168,8 +168,8 @@ def run_inference( self, batch: Sequence[Dict[str, torch.Tensor]], model: torch.nn.Module, - extra_kwargs: Optional[Dict[str, - Any]] = None) -> Iterable[PredictionResult]: + inference_args: Optional[Dict[str, Any]] = None + ) -> Iterable[PredictionResult]: """ Runs inferences on a batch of Keyed Tensors and returns an Iterable of Tensor Predictions. @@ -177,7 +177,7 @@ def run_inference( For the same key across all examples, this will stack all Tensors values in a vectorized format to optimize the inference call. """ - extra_kwargs = {} if not extra_kwargs else extra_kwargs + inference_args = {} if not inference_args else inference_args # If elements in `batch` are provided as a dictionaries from key to Tensors, # then iterate through the batch list, and group Tensors to the same key @@ -190,7 +190,7 @@ def run_inference( batched_tensors = torch.stack(key_to_tensor_list[key]) batched_tensors = _convert_to_device(batched_tensors, self._device) key_to_batched_tensors[key] = batched_tensors - predictions = model(**key_to_batched_tensors, **extra_kwargs) + predictions = model(**key_to_batched_tensors, **inference_args) return [PredictionResult(x, y) for x, y in zip(batch, predictions)] def get_num_bytes(self, batch: Sequence[torch.Tensor]) -> int: diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py b/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py index b0fffeefeca9..5749efa8a8bd 100644 --- a/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py +++ b/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py @@ -122,7 +122,7 @@ def forward(self, x): return out -class PytorchLinearRegressionKeyedBatchAndExtraKwargs(torch.nn.Module): +class PytorchLinearRegressionKeyedBatchAndExtraInferenceArgs(torch.nn.Module): """ A linear model with batched keyed inputs and non-batchable extra args. @@ -222,19 +222,19 @@ def forward(self, k1, k2): for actual, expected in zip(predictions, KEYED_TORCH_PREDICTIONS): self.assertTrue(_compare_prediction_result(actual, expected)) - def test_inference_runner_extra_kwargs(self): + def test_inference_runner_inference_args(self): """ This tests for non-batchable input arguments. Since we do the batching for the user, we have to distinguish between the inputs that should be batched and the ones that should not be batched. """ - extra_kwargs = { + inference_args = { 'prediction_param_array': torch.from_numpy( np.array([1, 2], dtype="float32")), 'prediction_param_bool': True } - model = PytorchLinearRegressionKeyedBatchAndExtraKwargs( + model = PytorchLinearRegressionKeyedBatchAndExtraInferenceArgs( input_dim=1, output_dim=1) model.load_state_dict( OrderedDict([('linear.weight', torch.Tensor([[2.0]])), @@ -244,7 +244,7 @@ def test_inference_runner_extra_kwargs(self): inference_runner = TestPytorchModelHandlerKeyedTensorForInferenceOnly( torch.device('cpu')) predictions = inference_runner.run_inference( - batch=KEYED_TORCH_EXAMPLES, model=model, extra_kwargs=extra_kwargs) + batch=KEYED_TORCH_EXAMPLES, model=model, inference_args=inference_args) for actual, expected in zip(predictions, KEYED_TORCH_PREDICTIONS): self.assertEqual(actual, expected) @@ -293,9 +293,9 @@ def test_pipeline_local_model_simple(self): equal_to( TWO_FEATURES_PREDICTIONS, equals_fn=_compare_prediction_result)) - def test_pipeline_local_model_extra_args(self): + def test_pipeline_local_model_extra_inference_args(self): with TestPipeline() as pipeline: - extra_kwargs = { + inference_args = { 'prediction_param_array': torch.from_numpy( np.array([1, 2], dtype="float32")), 'prediction_param_bool': True @@ -308,17 +308,17 @@ def test_pipeline_local_model_extra_args(self): model_handler = PytorchModelHandlerKeyedTensor( state_dict_path=path, - model_class=PytorchLinearRegressionKeyedBatchAndExtraKwargs, + model_class=PytorchLinearRegressionKeyedBatchAndExtraInferenceArgs, model_params={ 'input_dim': 1, 'output_dim': 1 }) pcoll = pipeline | 'start' >> beam.Create(KEYED_TORCH_EXAMPLES) - extra_kwargs_side_input = ( - pipeline | 'create side' >> beam.Create(extra_kwargs)) + inference_args_side_input = ( + pipeline | 'create side' >> beam.Create(inference_args)) predictions = pcoll | RunInference( model_handler=model_handler, - extra_kwargs=beam.pvalue.AsDict(extra_kwargs_side_input)) + inference_args=beam.pvalue.AsDict(inference_args_side_input)) assert_that( predictions, equal_to( diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference.py b/sdks/python/apache_beam/ml/inference/sklearn_inference.py index e6e1ef11520e..80b550730581 100644 --- a/sdks/python/apache_beam/ml/inference/sklearn_inference.py +++ b/sdks/python/apache_beam/ml/inference/sklearn_inference.py @@ -59,6 +59,20 @@ def _load_model(model_uri, file_type): raise AssertionError('Unsupported serialization type.') +def _validate_inference_args(inference_args): + """Confirms that inference_args is None. + + scikit-learn models do not need extra arguments in their predict() call. + However, since inference_args is an argument in the RunInference interface, + we want to make sure it is not passed here in Sklearn's implementation of + RunInference. + """ + if inference_args: + raise ValueError( + 'inference_args were provided, but should be None because scikit-learn ' + 'models do not need extra arguments in their predict() call.') + + class SklearnModelHandlerNumpy(ModelHandler[numpy.ndarray, PredictionResult, BaseEstimator]): @@ -80,8 +94,9 @@ def run_inference( self, batch: Sequence[numpy.ndarray], model: BaseEstimator, - extra_kwargs: Optional[Dict[str, - Any]] = None) -> Iterable[PredictionResult]: + inference_args: Optional[Dict[str, Any]] = None + ) -> Iterable[PredictionResult]: + _validate_inference_args(inference_args) # vectorize data for better performance vectorized_batch = numpy.stack(batch, axis=0) predictions = model.predict(vectorized_batch) @@ -116,8 +131,9 @@ def run_inference( self, batch: Sequence[pandas.DataFrame], model: BaseEstimator, - extra_kwargs: Optional[Dict[str, - Any]] = None) -> Iterable[PredictionResult]: + inference_args: Optional[Dict[str, Any]] = None + ) -> Iterable[PredictionResult]: + _validate_inference_args(inference_args) # sklearn_inference currently only supports single rowed dataframes. for dataframe in batch: if dataframe.shape[0] != 1: diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py b/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py index ecd81d204d6d..fc533940bab4 100644 --- a/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py +++ b/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py @@ -300,12 +300,22 @@ def test_pipeline_pandas_with_keys(self): actual, equal_to(expected, equals_fn=_compare_dataframe_predictions)) def test_infer_too_many_rows_in_dataframe(self): - with self.assertRaises(ValueError): + with self.assertRaisesRegex( + ValueError, r'Only dataframes with single rows are supported'): data_frame_too_many_rows = pandas_dataframe() fake_model = FakeModel() inference_runner = SklearnModelHandlerPandas(model_uri='unused') inference_runner.run_inference([data_frame_too_many_rows], fake_model) + def test_inference_args_passed(self): + with self.assertRaisesRegex(ValueError, r'inference_args were provided'): + data_frame = pandas_dataframe() + fake_model = FakeModel() + inference_runner = SklearnModelHandlerPandas(model_uri='unused') + inference_runner.run_inference([data_frame], + fake_model, + inference_args={'key1': 'value1'}) + if __name__ == '__main__': unittest.main()