diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index 534512a44f3f..ae07ac0531ee 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -35,9 +35,11 @@
 from typing import Any
 from typing import Generic
 from typing import Iterable
-from typing import List
 from typing import Mapping
+from typing import Sequence
+from typing import Tuple
 from typing import TypeVar
+from typing import Union
 
 import apache_beam as beam
 from apache_beam.utils import shared
@@ -54,6 +56,7 @@
 ModelT = TypeVar('ModelT')
 ExampleT = TypeVar('ExampleT')
 PredictionT = TypeVar('PredictionT')
+KeyT = TypeVar('KeyT')
 
 
 def _to_milliseconds(time_ns: int) -> int:
@@ -70,13 +73,13 @@
   def load_model(self) -> ModelT:
     """Loads and initializes a model for processing."""
     raise NotImplementedError(type(self))
 
-  def run_inference(self, batch: List[ExampleT], model: ModelT,
+  def run_inference(self, batch: Sequence[ExampleT], model: ModelT,
                     **kwargs) -> Iterable[PredictionT]:
     """Runs inferences on a batch of examples and
     returns an Iterable of Predictions."""
     raise NotImplementedError(type(self))
 
-  def get_num_bytes(self, batch: List[ExampleT]) -> int:
+  def get_num_bytes(self, batch: Sequence[ExampleT]) -> int:
     """Returns the number of bytes of data for a batch."""
     return len(pickle.dumps(batch))
 
@@ -93,6 +96,111 @@ def batch_elements_kwargs(self) -> Mapping[str, Any]:
     return {}
 
 
+class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
+                        ModelHandler[Tuple[KeyT, ExampleT],
+                                     Tuple[KeyT, PredictionT],
+                                     ModelT]):
+  """A ModelHandler that takes keyed examples and returns keyed predictions.
+
+  For example, if the original model was used with RunInference to take a
+  PCollection[E] to a PCollection[P], this would take a
+  PCollection[Tuple[K, E]] to a PCollection[Tuple[K, P]], allowing one to
+  associate the outputs with the inputs based on the key.
+  """
+  def __init__(self, unkeyed: ModelHandler[ExampleT, PredictionT, ModelT]):
+    self._unkeyed = unkeyed
+
+  def load_model(self) -> ModelT:
+    return self._unkeyed.load_model()
+
+  def run_inference(
+      self, batch: Sequence[Tuple[KeyT, ExampleT]], model: ModelT,
+      **kwargs) -> Iterable[Tuple[KeyT, PredictionT]]:
+    keys, unkeyed_batch = zip(*batch)
+    return zip(
+        keys, self._unkeyed.run_inference(unkeyed_batch, model, **kwargs))
+
+  def get_num_bytes(self, batch: Sequence[Tuple[KeyT, ExampleT]]) -> int:
+    keys, unkeyed_batch = zip(*batch)
+    return len(pickle.dumps(keys)) + self._unkeyed.get_num_bytes(unkeyed_batch)
+
+  def get_metrics_namespace(self) -> str:
+    return self._unkeyed.get_metrics_namespace()
+
+  def get_resource_hints(self):
+    return self._unkeyed.get_resource_hints()
+
+  def batch_elements_kwargs(self):
+    return self._unkeyed.batch_elements_kwargs()
+
+
+class MaybeKeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
+                             ModelHandler[Union[ExampleT, Tuple[KeyT,
+                                                                ExampleT]],
+                                          Union[PredictionT,
+                                                Tuple[KeyT, PredictionT]],
+                                          ModelT]):
+  """A ModelHandler that takes possibly keyed examples and returns possibly
+  keyed predictions.
+
+  For example, if the original model was used with RunInference to take a
+  PCollection[E] to a PCollection[P], this would take either PCollection[E] to a
+  PCollection[P] or PCollection[Tuple[K, E]] to a PCollection[Tuple[K, P]],
+  depending on whether the elements happen to be tuples, allowing one to
+  associate the outputs with the inputs based on the key.
+
+  Note that this cannot be used if E happens to be a tuple type. In addition,
+  either all examples should be keyed, or none of them.
+  """
+  def __init__(self, unkeyed: ModelHandler[ExampleT, PredictionT, ModelT]):
+    self._unkeyed = unkeyed
+
+  def load_model(self) -> ModelT:
+    return self._unkeyed.load_model()
+
+  def run_inference(
+      self,
+      batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]],
+      model: ModelT,
+      **kwargs
+  ) -> Union[Iterable[PredictionT], Iterable[Tuple[KeyT, PredictionT]]]:
+    # Really the input should be
+    #     Union[Sequence[ExampleT], Sequence[Tuple[KeyT, ExampleT]]]
+    # but there's not a good way to express (or check) that.
+    if isinstance(batch[0], tuple):
+      is_keyed = True
+      keys, unkeyed_batch = zip(*batch)  # type: ignore[arg-type]
+    else:
+      is_keyed = False
+      unkeyed_batch = batch  # type: ignore[assignment]
+    unkeyed_results = self._unkeyed.run_inference(
+        unkeyed_batch, model, **kwargs)
+    if is_keyed:
+      return zip(keys, unkeyed_results)
+    else:
+      return unkeyed_results
+
+  def get_num_bytes(
+      self, batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]]) -> int:
+    # MyPy can't follow the branching logic.
+    if isinstance(batch[0], tuple):
+      keys, unkeyed_batch = zip(*batch)  # type: ignore[arg-type]
+      return len(
+          pickle.dumps(keys)) + self._unkeyed.get_num_bytes(unkeyed_batch)
+    else:
+      return self._unkeyed.get_num_bytes(batch)  # type: ignore[arg-type]
+
+  def get_metrics_namespace(self) -> str:
+    return self._unkeyed.get_metrics_namespace()
+
+  def get_resource_hints(self):
+    return self._unkeyed.get_resource_hints()
+
+  def batch_elements_kwargs(self):
+    return self._unkeyed.batch_elements_kwargs()
+
+
 class RunInference(beam.PTransform[beam.PCollection[ExampleT],
                                    beam.PCollection[PredictionT]]):
   """An extensible transform for running inferences.
@@ -205,32 +313,18 @@ def setup(self):
     self._model = self._load_model()
 
   def process(self, batch, **kwargs):
-    # Process supports both keyed data, and example only data.
-    # First keys and samples are separated (if there are keys)
-    has_keys = isinstance(batch[0], tuple)
-    if has_keys:
-      examples = [example for _, example in batch]
-      keys = [key for key, _ in batch]
-    else:
-      examples = batch
-      keys = None
-
     start_time = _to_microseconds(self._clock.time_ns())
     result_generator = self._model_handler.run_inference(
-        examples, self._model, **kwargs)
+        batch, self._model, **kwargs)
     predictions = list(result_generator)
 
     end_time = _to_microseconds(self._clock.time_ns())
     inference_latency = end_time - start_time
-    num_bytes = self._model_handler.get_num_bytes(examples)
+    num_bytes = self._model_handler.get_num_bytes(batch)
     num_elements = len(batch)
     self._metrics_collector.update(num_elements, num_bytes, inference_latency)
-    # Keys are recombined with predictions in the RunInference PTransform.
-    if has_keys:
-      yield from zip(keys, predictions)
-    else:
-      yield from predictions
+    return predictions
 
   def finish_bundle(self):
     # TODO(BEAM-13970): Figure out why there is a cache.
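
Not part of the patch, but to make the intended usage concrete, here is a minimal sketch of how the new wrappers compose with RunInference. AddOneModelHandler is a hypothetical handler modeled on the FakeModelHandler in base_test.py below; KeyedModelHandler, MaybeKeyedModelHandler and RunInference are the APIs added or touched by this diff.

from typing import Iterable, Sequence

import apache_beam as beam
from apache_beam.ml.inference import base


class AddOneModelHandler(base.ModelHandler[int, int, object]):
  # Hypothetical unkeyed handler: the "model" is a placeholder object and
  # inference just adds one to every example in the batch.
  def load_model(self) -> object:
    return object()

  def run_inference(self, batch: Sequence[int], model: object,
                    **kwargs) -> Iterable[int]:
    return [example + 1 for example in batch]


with beam.Pipeline() as p:
  # Unkeyed elements use the handler directly, exactly as before this change.
  _ = (
      p | 'CreateUnkeyed' >> beam.Create([1, 5, 3])
      | 'RunUnkeyed' >> base.RunInference(AddOneModelHandler()))

  # Keyed elements are wrapped so each key is re-attached to its prediction.
  _ = (
      p | 'CreateKeyed' >> beam.Create([('a', 1), ('b', 5)])
      | 'RunKeyed' >> base.RunInference(
          base.KeyedModelHandler(AddOneModelHandler())))

  # MaybeKeyedModelHandler accepts either shape; here it sees keyed input.
  _ = (
      p | 'CreateMaybeKeyed' >> beam.Create([('c', 2), ('d', 7)])
      | 'RunMaybeKeyed' >> base.RunInference(
          base.MaybeKeyedModelHandler(AddOneModelHandler())))
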
diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py
index 3ea2a9db12b5..52f8f883203f 100644
--- a/sdks/python/apache_beam/ml/inference/base_test.py
+++ b/sdks/python/apache_beam/ml/inference/base_test.py
@@ -20,7 +20,7 @@
 import pickle
 import unittest
 from typing import Iterable
-from typing import List
+from typing import Sequence
 
 import apache_beam as beam
 from apache_beam.metrics.metric import MetricsFilter
@@ -44,7 +44,7 @@ def load_model(self):
     self._fake_clock.current_time_ns += 500_000_000  # 500ms
     return FakeModel()
 
-  def run_inference(self, batch: List[int], model: FakeModel,
+  def run_inference(self, batch: Sequence[int], model: FakeModel,
                     **kwargs) -> Iterable[int]:
     if self._fake_clock:
       self._fake_clock.current_time_ns += 3_000_000  # 3 milliseconds
@@ -98,9 +98,27 @@ def test_run_inference_impl_with_keyed_examples(self):
       keyed_examples = [(i, example) for i, example in enumerate(examples)]
       expected = [(i, example + 1) for i, example in enumerate(examples)]
       pcoll = pipeline | 'start' >> beam.Create(keyed_examples)
-      actual = pcoll | base.RunInference(FakeModelHandler())
+      actual = pcoll | base.RunInference(
+          base.KeyedModelHandler(FakeModelHandler()))
       assert_that(actual, equal_to(expected), label='assert:inferences')
 
+  def test_run_inference_impl_with_maybe_keyed_examples(self):
+    with TestPipeline() as pipeline:
+      examples = [1, 5, 3, 10]
+      keyed_examples = [(i, example) for i, example in enumerate(examples)]
+      expected = [example + 1 for example in examples]
+      keyed_expected = [(i, example + 1) for i, example in enumerate(examples)]
+      model_handler = base.MaybeKeyedModelHandler(FakeModelHandler())
+
+      pcoll = pipeline | 'Unkeyed' >> beam.Create(examples)
+      actual = pcoll | 'RunUnkeyed' >> base.RunInference(model_handler)
+      assert_that(actual, equal_to(expected), label='CheckUnkeyed')
+
+      keyed_pcoll = pipeline | 'Keyed' >> beam.Create(keyed_examples)
+      keyed_actual = keyed_pcoll | 'RunKeyed' >> base.RunInference(
+          model_handler)
+      assert_that(keyed_actual, equal_to(keyed_expected), label='CheckKeyed')
+
   def test_run_inference_impl_kwargs(self):
     with TestPipeline() as pipeline:
       examples = [1, 5, 3, 10]
diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
index 3a4fb2926f81..1a1afaaf1c86 100644
--- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py
+++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
@@ -22,7 +22,7 @@
 from typing import Callable
 from typing import Dict
 from typing import Iterable
-from typing import List
+from typing import Sequence
 from typing import Union
 
 import torch
@@ -87,7 +87,7 @@ def _convert_to_device(self, examples: torch.Tensor) -> torch.Tensor:
 
   def run_inference(
       self,
-      batch: List[Union[torch.Tensor, Dict[str, torch.Tensor]]],
+      batch: Sequence[Union[torch.Tensor, Dict[str, torch.Tensor]]],
       model: torch.nn.Module,
       **kwargs) -> Iterable[PredictionResult]:
     """
@@ -119,7 +119,7 @@ def run_inference(
     predictions = model(batched_tensors, **prediction_params)
     return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
 
-  def get_num_bytes(self, batch: List[torch.Tensor]) -> int:
+  def get_num_bytes(self, batch: Sequence[torch.Tensor]) -> int:
     """Returns the number of bytes of data for a batch of Tensors."""
     # If elements in `batch` are provided as a dictionaries from key to Tensors
     if isinstance(batch[0], dict):
diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference.py b/sdks/python/apache_beam/ml/inference/sklearn_inference.py
index 3c8eddfd7d3a..5ca6a18b5d1d 100644
--- a/sdks/python/apache_beam/ml/inference/sklearn_inference.py
+++ b/sdks/python/apache_beam/ml/inference/sklearn_inference.py
@@ -20,7 +20,7 @@
 import sys
 from typing import Any
 from typing import Iterable
-from typing import List
+from typing import Sequence
 from typing import Union
 
 import numpy
@@ -75,7 +75,7 @@ def load_model(self) -> BaseEstimator:
 
   def run_inference(
       self,
-      batch: List[Union[numpy.ndarray, pandas.DataFrame]],
+      batch: Sequence[Union[numpy.ndarray, pandas.DataFrame]],
       model: BaseEstimator,
       **kwargs) -> Iterable[PredictionResult]:
     # TODO(github.com/apache/beam/issues/21769): Use supplied input type hint.
@@ -86,7 +86,7 @@ def run_inference(
     raise ValueError('Unsupported data type.')
 
   @staticmethod
-  def _predict_np_array(batch: List[numpy.ndarray],
+  def _predict_np_array(batch: Sequence[numpy.ndarray],
                         model: Any) -> Iterable[PredictionResult]:
     # vectorize data for better performance
     vectorized_batch = numpy.stack(batch, axis=0)
@@ -94,7 +94,7 @@ def _predict_np_array(batch: List[numpy.ndarray],
     return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
 
   @staticmethod
-  def _predict_pandas_dataframe(batch: List[pandas.DataFrame],
+  def _predict_pandas_dataframe(batch: Sequence[pandas.DataFrame],
                                 model: Any) -> Iterable[PredictionResult]:
     # sklearn_inference currently only supports single rowed dataframes.
     for dataframe in batch:
@@ -113,11 +113,11 @@ def _predict_pandas_dataframe(batch: List[pandas.DataFrame],
     ]
 
   def get_num_bytes(
-      self, batch: List[Union[numpy.ndarray, pandas.DataFrame]]) -> int:
+      self, batch: Sequence[Union[numpy.ndarray, pandas.DataFrame]]) -> int:
     """Returns the number of bytes of data for a batch."""
     if isinstance(batch[0], numpy.ndarray):
       return sum(sys.getsizeof(element) for element in batch)
     elif isinstance(batch[0], pandas.DataFrame):
-      data_frames: List[pandas.DataFrame] = batch
+      data_frames: Sequence[pandas.DataFrame] = batch
       return sum(df.memory_usage(deep=True).sum() for df in data_frames)
     raise ValueError('Unsupported data type.')
diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py b/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py
index 91eb86e2de4b..0d7294eb4063 100644
--- a/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py
+++ b/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py
@@ -265,7 +265,7 @@ def test_pipeline_pandas_with_keys(self):
 
       pcoll = pipeline | 'start' >> beam.Create(keyed_rows)
       actual = pcoll | api.RunInference(
-          SklearnModelHandler(model_uri=temp_file_name))
+          base.KeyedModelHandler(SklearnModelHandler(model_uri=temp_file_name)))
       expected = [
           ('0', api.PredictionResult(splits[0], 5)),
          ('1', api.PredictionResult(splits[1], 8)),
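
For the sklearn path, a similar sketch of the keyed flow exercised by the test change above. The toy LinearRegression model, the temporary pickle file, and the example arrays are assumptions made purely for illustration; SklearnModelHandler(model_uri=...), KeyedModelHandler and RunInference are the pieces this diff touches.

import os
import pickle
import tempfile

import numpy
from sklearn.linear_model import LinearRegression

import apache_beam as beam
from apache_beam.ml.inference import base
from apache_beam.ml.inference.sklearn_inference import SklearnModelHandler

# Train and pickle a toy model so the pipeline has something to load.
model = LinearRegression().fit(
    numpy.array([[1.0], [2.0], [3.0]]), numpy.array([2.0, 4.0, 6.0]))
fd, model_path = tempfile.mkstemp(suffix='.pkl')
with os.fdopen(fd, 'wb') as f:
  pickle.dump(model, f)

keyed_examples = [('a', numpy.array([1.0])), ('b', numpy.array([2.0]))]

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create(keyed_examples)
      # The wrapper strips the keys, runs the sklearn handler on the values,
      # and pairs each key with the corresponding PredictionResult.
      | base.RunInference(
          base.KeyedModelHandler(SklearnModelHandler(model_uri=model_path)))
      | beam.Map(print))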