diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index d703f7b0312a..9664c6261356 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -35,10 +35,12 @@ import sys import time from typing import Any +from typing import Dict from typing import Generic from typing import Iterable from typing import Mapping from typing import NamedTuple +from typing import Optional from typing import Sequence from typing import Tuple from typing import TypeVar @@ -84,8 +86,11 @@ def load_model(self) -> ModelT: """Loads and initializes a model for processing.""" raise NotImplementedError(type(self)) - def run_inference(self, batch: Sequence[ExampleT], model: ModelT, - **kwargs) -> Iterable[PredictionT]: + def run_inference( + self, + batch: Sequence[ExampleT], + model: ModelT, + inference_args: Optional[Dict[str, Any]] = None) -> Iterable[PredictionT]: """Runs inferences on a batch of examples and returns an Iterable of Predictions.""" raise NotImplementedError(type(self)) @@ -125,11 +130,14 @@ def load_model(self) -> ModelT: return self._unkeyed.load_model() def run_inference( - self, batch: Sequence[Tuple[KeyT, ExampleT]], model: ModelT, - **kwargs) -> Iterable[Tuple[KeyT, PredictionT]]: + self, + batch: Sequence[Tuple[KeyT, ExampleT]], + model: ModelT, + inference_args: Optional[Dict[str, Any]] = None + ) -> Iterable[Tuple[KeyT, PredictionT]]: keys, unkeyed_batch = zip(*batch) return zip( - keys, self._unkeyed.run_inference(unkeyed_batch, model, **kwargs)) + keys, self._unkeyed.run_inference(unkeyed_batch, model, inference_args)) def get_num_bytes(self, batch: Sequence[Tuple[KeyT, ExampleT]]) -> int: keys, unkeyed_batch = zip(*batch) @@ -173,7 +181,7 @@ def run_inference( self, batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]], model: ModelT, - **kwargs + inference_args: Optional[Dict[str, Any]] = None ) -> Union[Iterable[PredictionT], Iterable[Tuple[KeyT, PredictionT]]]: # Really the input should be # Union[Sequence[ExampleT], Sequence[Tuple[KeyT, ExampleT]]] @@ -185,7 +193,7 @@ def run_inference( is_keyed = False unkeyed_batch = batch # type: ignore[assignment] unkeyed_results = self._unkeyed.run_inference( - unkeyed_batch, model, **kwargs) + unkeyed_batch, model, inference_args) if is_keyed: return zip(keys, unkeyed_results) else: @@ -217,6 +225,8 @@ class RunInference(beam.PTransform[beam.PCollection[ExampleT], Args: model_handler: An implementation of ModelHandler. clock: A clock implementing get_current_time_in_microseconds. + inference_args: Extra arguments for models whose inference call requires + extra parameters. A transform that takes a PCollection of examples (or features) to be used on an ML model. It will then output inferences (or predictions) for those @@ -236,9 +246,9 @@ def __init__( self, model_handler: ModelHandler[ExampleT, PredictionT, Any], clock=time, - **kwargs): + inference_args: Optional[Dict[str, Any]] = None): self._model_handler = model_handler - self._kwargs = kwargs + self._inference_args = inference_args self._clock = clock @classmethod @@ -268,7 +278,7 @@ def expand( | ( beam.ParDo( _RunInferenceDoFn(self._model_handler, self._clock), - **self._kwargs).with_resource_hints(**resource_hints))) + self._inference_args).with_resource_hints(**resource_hints))) class _MetricsCollector: @@ -352,10 +362,10 @@ def setup(self): self._model_handler.get_metrics_namespace()) self._model = self._load_model() - def process(self, batch, **kwargs): + def process(self, batch, inference_args): start_time = _to_microseconds(self._clock.time_ns()) result_generator = self._model_handler.run_inference( - batch, self._model, **kwargs) + batch, self._model, inference_args) predictions = list(result_generator) end_time = _to_microseconds(self._clock.time_ns()) diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py index 702cdd77f937..98fc2523b6dd 100644 --- a/sdks/python/apache_beam/ml/inference/base_test.py +++ b/sdks/python/apache_beam/ml/inference/base_test.py @@ -44,8 +44,11 @@ def load_model(self): self._fake_clock.current_time_ns += 500_000_000 # 500ms return FakeModel() - def run_inference(self, batch: Sequence[int], model: FakeModel, - **kwargs) -> Iterable[int]: + def run_inference( + self, + batch: Sequence[int], + model: FakeModel, + inference_args=None) -> Iterable[int]: if self._fake_clock: self._fake_clock.current_time_ns += 3_000_000 # 3 milliseconds for example in batch: @@ -67,7 +70,7 @@ def process(self, prediction_result): class FakeModelHandlerNeedsBigBatch(FakeModelHandler): - def run_inference(self, batch, unused_model): + def run_inference(self, batch, unused_model, inference_args=None): if len(batch) < 100: raise ValueError('Unexpectedly small batch') return batch @@ -76,10 +79,10 @@ def batch_elements_kwargs(self): return {'min_batch_size': 9999} -class FakeModelHandlerWithKwargs(FakeModelHandler): - def run_inference(self, batch, unused_model, **kwargs): - if not kwargs.get('key'): - raise ValueError('key should be True') +class FakeModelHandlerExtraInferenceArgs(FakeModelHandler): + def run_inference(self, batch, unused_model, inference_args=None): + if not inference_args: + raise ValueError('inference_args should exist') return batch @@ -119,12 +122,13 @@ def test_run_inference_impl_with_maybe_keyed_examples(self): model_handler) assert_that(keyed_actual, equal_to(keyed_expected), label='CheckKeyed') - def test_run_inference_impl_kwargs(self): + def test_run_inference_impl_inference_args(self): with TestPipeline() as pipeline: examples = [1, 5, 3, 10] pcoll = pipeline | 'start' >> beam.Create(examples) - kwargs = {'key': True} - actual = pcoll | base.RunInference(FakeModelHandlerWithKwargs(), **kwargs) + inference_args = {'key': True} + actual = pcoll | base.RunInference( + FakeModelHandlerExtraInferenceArgs(), inference_args=inference_args) assert_that(actual, equal_to(examples), label='assert:inferences') def test_counted_metrics(self): diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py index 959bce4778eb..331677d76c2b 100644 --- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py +++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py @@ -22,6 +22,7 @@ from typing import Callable from typing import Dict from typing import Iterable +from typing import Optional from typing import Sequence import torch @@ -91,8 +92,11 @@ def load_model(self) -> torch.nn.Module: **self._model_params) def run_inference( - self, batch: Sequence[torch.Tensor], model: torch.nn.Module, - **kwargs) -> Iterable[PredictionResult]: + self, + batch: Sequence[torch.Tensor], + model: torch.nn.Module, + inference_args: Optional[Dict[str, Any]] = None + ) -> Iterable[PredictionResult]: """ Runs inferences on a batch of Tensors and returns an Iterable of Tensor Predictions. @@ -100,10 +104,11 @@ def run_inference( This method stacks the list of Tensors in a vectorized format to optimize the inference call. """ - prediction_params = kwargs.get('prediction_params', {}) + inference_args = {} if not inference_args else inference_args + batched_tensors = torch.stack(batch) batched_tensors = _convert_to_device(batched_tensors, self._device) - predictions = model(batched_tensors, **prediction_params) + predictions = model(batched_tensors, **inference_args) return [PredictionResult(x, y) for x, y in zip(batch, predictions)] def get_num_bytes(self, batch: Sequence[torch.Tensor]) -> int: @@ -163,7 +168,8 @@ def run_inference( self, batch: Sequence[Dict[str, torch.Tensor]], model: torch.nn.Module, - **kwargs) -> Iterable[PredictionResult]: + inference_args: Optional[Dict[str, Any]] = None + ) -> Iterable[PredictionResult]: """ Runs inferences on a batch of Keyed Tensors and returns an Iterable of Tensor Predictions. @@ -171,7 +177,7 @@ def run_inference( For the same key across all examples, this will stack all Tensors values in a vectorized format to optimize the inference call. """ - prediction_params = kwargs.get('prediction_params', {}) + inference_args = {} if not inference_args else inference_args # If elements in `batch` are provided as a dictionaries from key to Tensors, # then iterate through the batch list, and group Tensors to the same key @@ -184,7 +190,7 @@ def run_inference( batched_tensors = torch.stack(key_to_tensor_list[key]) batched_tensors = _convert_to_device(batched_tensors, self._device) key_to_batched_tensors[key] = batched_tensors - predictions = model(**key_to_batched_tensors, **prediction_params) + predictions = model(**key_to_batched_tensors, **inference_args) return [PredictionResult(x, y) for x, y in zip(batch, predictions)] def get_num_bytes(self, batch: Sequence[torch.Tensor]) -> int: diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py b/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py index d852dd72bb74..5749efa8a8bd 100644 --- a/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py +++ b/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py @@ -63,7 +63,7 @@ for f1, f2 in TWO_FEATURES_EXAMPLES]).reshape(-1, 1)) ] -KWARGS_TORCH_EXAMPLES = [ +KEYED_TORCH_EXAMPLES = [ { 'k1': torch.from_numpy(np.array([1], dtype="float32")), 'k2': torch.from_numpy(np.array([1.5], dtype="float32")) @@ -82,12 +82,12 @@ }, ] -KWARGS_TORCH_PREDICTIONS = [ +KEYED_TORCH_PREDICTIONS = [ PredictionResult(ex, pred) for ex, pred in zip( - KWARGS_TORCH_EXAMPLES, + KEYED_TORCH_EXAMPLES, torch.Tensor([(example['k1'] * 2.0 + 0.5) + (example['k2'] * 2.0 + 0.5) - for example in KWARGS_TORCH_EXAMPLES]).reshape(-1, 1)) + for example in KEYED_TORCH_EXAMPLES]).reshape(-1, 1)) ] @@ -122,12 +122,12 @@ def forward(self, x): return out -class PytorchLinearRegressionKwargsPredictionParams(torch.nn.Module): +class PytorchLinearRegressionKeyedBatchAndExtraInferenceArgs(torch.nn.Module): """ - A linear model with kwargs inputs and non-batchable input params. + A linear model with batched keyed inputs and non-batchable extra args. - Note: k1 and k2 are batchable inputs passed in as a kwargs. - prediction_param_array, prediction_param_bool are non-batchable inputs + Note: k1 and k2 are batchable examples passed in as a dict from str to tensor. + prediction_param_array, prediction_param_bool are non-batchable extra args (typically model-related info) used to configure the model before its predict call is invoked """ @@ -186,10 +186,10 @@ def test_run_inference_multiple_tensor_features(self): for actual, expected in zip(predictions, TWO_FEATURES_PREDICTIONS): self.assertEqual(actual, expected) - def test_run_inference_kwargs(self): + def test_run_inference_keyed(self): """ This tests for inputs that are passed as a dictionary from key to tensor - instead of a standard non-kwarg input. + instead of a standard non-keyed tensor example. Example: Typical input format is @@ -218,23 +218,23 @@ def forward(self, k1, k2): inference_runner = TestPytorchModelHandlerKeyedTensorForInferenceOnly( torch.device('cpu')) - predictions = inference_runner.run_inference(KWARGS_TORCH_EXAMPLES, model) - for actual, expected in zip(predictions, KWARGS_TORCH_PREDICTIONS): + predictions = inference_runner.run_inference(KEYED_TORCH_EXAMPLES, model) + for actual, expected in zip(predictions, KEYED_TORCH_PREDICTIONS): self.assertTrue(_compare_prediction_result(actual, expected)) - def test_run_inference_kwargs_prediction_params(self): + def test_inference_runner_inference_args(self): """ This tests for non-batchable input arguments. Since we do the batching for the user, we have to distinguish between the inputs that should be batched and the ones that should not be batched. """ - prediction_params = { + inference_args = { 'prediction_param_array': torch.from_numpy( np.array([1, 2], dtype="float32")), 'prediction_param_bool': True } - model = PytorchLinearRegressionKwargsPredictionParams( + model = PytorchLinearRegressionKeyedBatchAndExtraInferenceArgs( input_dim=1, output_dim=1) model.load_state_dict( OrderedDict([('linear.weight', torch.Tensor([[2.0]])), @@ -244,10 +244,8 @@ def test_run_inference_kwargs_prediction_params(self): inference_runner = TestPytorchModelHandlerKeyedTensorForInferenceOnly( torch.device('cpu')) predictions = inference_runner.run_inference( - batch=KWARGS_TORCH_EXAMPLES, - model=model, - prediction_params=prediction_params) - for actual, expected in zip(predictions, KWARGS_TORCH_PREDICTIONS): + batch=KEYED_TORCH_EXAMPLES, model=model, inference_args=inference_args) + for actual, expected in zip(predictions, KEYED_TORCH_PREDICTIONS): self.assertEqual(actual, expected) def test_num_bytes(self): @@ -295,9 +293,9 @@ def test_pipeline_local_model_simple(self): equal_to( TWO_FEATURES_PREDICTIONS, equals_fn=_compare_prediction_result)) - def test_pipeline_local_model_kwargs_prediction_params(self): + def test_pipeline_local_model_extra_inference_args(self): with TestPipeline() as pipeline: - prediction_params = { + inference_args = { 'prediction_param_array': torch.from_numpy( np.array([1, 2], dtype="float32")), 'prediction_param_bool': True @@ -310,21 +308,21 @@ def test_pipeline_local_model_kwargs_prediction_params(self): model_handler = PytorchModelHandlerKeyedTensor( state_dict_path=path, - model_class=PytorchLinearRegressionKwargsPredictionParams, + model_class=PytorchLinearRegressionKeyedBatchAndExtraInferenceArgs, model_params={ 'input_dim': 1, 'output_dim': 1 }) - pcoll = pipeline | 'start' >> beam.Create(KWARGS_TORCH_EXAMPLES) - prediction_params_side_input = ( - pipeline | 'create side' >> beam.Create(prediction_params)) + pcoll = pipeline | 'start' >> beam.Create(KEYED_TORCH_EXAMPLES) + inference_args_side_input = ( + pipeline | 'create side' >> beam.Create(inference_args)) predictions = pcoll | RunInference( model_handler=model_handler, - prediction_params=beam.pvalue.AsDict(prediction_params_side_input)) + inference_args=beam.pvalue.AsDict(inference_args_side_input)) assert_that( predictions, equal_to( - KWARGS_TORCH_PREDICTIONS, equals_fn=_compare_prediction_result)) + KEYED_TORCH_PREDICTIONS, equals_fn=_compare_prediction_result)) @unittest.skipIf(GCSFileSystem is None, 'GCP dependencies are not installed') def test_pipeline_gcs_model(self): diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference.py b/sdks/python/apache_beam/ml/inference/sklearn_inference.py index 19c6f5b5046f..80b550730581 100644 --- a/sdks/python/apache_beam/ml/inference/sklearn_inference.py +++ b/sdks/python/apache_beam/ml/inference/sklearn_inference.py @@ -18,7 +18,10 @@ import enum import pickle import sys +from typing import Any +from typing import Dict from typing import Iterable +from typing import Optional from typing import Sequence import numpy @@ -56,6 +59,20 @@ def _load_model(model_uri, file_type): raise AssertionError('Unsupported serialization type.') +def _validate_inference_args(inference_args): + """Confirms that inference_args is None. + + scikit-learn models do not need extra arguments in their predict() call. + However, since inference_args is an argument in the RunInference interface, + we want to make sure it is not passed here in Sklearn's implementation of + RunInference. + """ + if inference_args: + raise ValueError( + 'inference_args were provided, but should be None because scikit-learn ' + 'models do not need extra arguments in their predict() call.') + + class SklearnModelHandlerNumpy(ModelHandler[numpy.ndarray, PredictionResult, BaseEstimator]): @@ -74,8 +91,12 @@ def load_model(self) -> BaseEstimator: return _load_model(self._model_uri, self._model_file_type) def run_inference( - self, batch: Sequence[numpy.ndarray], model: BaseEstimator, - **kwargs) -> Iterable[PredictionResult]: + self, + batch: Sequence[numpy.ndarray], + model: BaseEstimator, + inference_args: Optional[Dict[str, Any]] = None + ) -> Iterable[PredictionResult]: + _validate_inference_args(inference_args) # vectorize data for better performance vectorized_batch = numpy.stack(batch, axis=0) predictions = model.predict(vectorized_batch) @@ -107,8 +128,12 @@ def load_model(self) -> BaseEstimator: return _load_model(self._model_uri, self._model_file_type) def run_inference( - self, batch: Sequence[pandas.DataFrame], model: BaseEstimator, - **kwargs) -> Iterable[PredictionResult]: + self, + batch: Sequence[pandas.DataFrame], + model: BaseEstimator, + inference_args: Optional[Dict[str, Any]] = None + ) -> Iterable[PredictionResult]: + _validate_inference_args(inference_args) # sklearn_inference currently only supports single rowed dataframes. for dataframe in batch: if dataframe.shape[0] != 1: diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py b/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py index 422ee8c613db..978c3a8934d2 100644 --- a/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py +++ b/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py @@ -310,12 +310,22 @@ def test_pipeline_pandas_with_keys(self): actual, equal_to(expected, equals_fn=_compare_dataframe_predictions)) def test_infer_too_many_rows_in_dataframe(self): - with self.assertRaises(ValueError): + with self.assertRaisesRegex( + ValueError, r'Only dataframes with single rows are supported'): data_frame_too_many_rows = pandas_dataframe() fake_model = FakeModel() inference_runner = SklearnModelHandlerPandas(model_uri='unused') inference_runner.run_inference([data_frame_too_many_rows], fake_model) + def test_inference_args_passed(self): + with self.assertRaisesRegex(ValueError, r'inference_args were provided'): + data_frame = pandas_dataframe() + fake_model = FakeModel() + inference_runner = SklearnModelHandlerPandas(model_uri='unused') + inference_runner.run_inference([data_frame], + fake_model, + inference_args={'key1': 'value1'}) + if __name__ == '__main__': unittest.main()