From a4d253ee7ee9ef8a55a742c15eb913e2e4a09594 Mon Sep 17 00:00:00 2001
From: Robert Bradshaw
Date: Thu, 9 Jun 2022 12:58:57 -0700
Subject: [PATCH 1/7] Make keying of examples explicit.

This decouples the keying logic from the DoFn and helps with type inference.
A MaybeKeyedModelLoader could be added to make this decision dynamically if
desired.
---
 sdks/python/apache_beam/ml/inference/base.py  | 64 ++++++++++++++-----
 .../apache_beam/ml/inference/base_test.py     |  3 +-
 2 files changed, 49 insertions(+), 18 deletions(-)

diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index 2d10cf7a1561..6688f79a619a 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -37,6 +37,7 @@
 from typing import Iterable
 from typing import List
 from typing import Mapping
+from typing import Tuple
 from typing import TypeVar
 
 import apache_beam as beam
@@ -54,6 +55,7 @@
 ModelT = TypeVar('ModelT')
 ExampleT = TypeVar('ExampleT')
 PredictionT = TypeVar('PredictionT')
+KeyT = TypeVar('KeyT')
 
 
 def _to_milliseconds(time_ns: int) -> int:
@@ -101,6 +103,48 @@ def batch_elements_kwargs(self) -> Mapping[str, Any]:
     return {}
 
 
+class KeyedModelLoader(Generic[KeyT, ExampleT, PredictionT, ModelT],
+                       ModelLoader[Tuple[KeyT, ExampleT],
+                                   Tuple[KeyT, PredictionT],
+                                   ModelT]):
+  def __init__(self, unkeyed: ModelLoader[ExampleT, PredictionT, ModelT]):
+    self._unkeyed = unkeyed
+
+  def load_model(self) -> ModelT:
+    return self._unkeyed.load_model()
+
+  def get_inference_runner(self):
+    return KeyedInferenceRunner(self._unkeyed.get_inference_runner())
+
+  def get_resource_hints(self):
+    return self._unkeyed.get_resource_hints()
+
+  def batch_elements_kwargs(self):
+    return self._unkeyed.batch_elements_kwargs()
+
+
+class KeyedInferenceRunner(Generic[KeyT, ExampleT, PredictionT, ModelT],
+                           InferenceRunner[Tuple[KeyT, ExampleT],
+                                           Tuple[KeyT, PredictionT],
+                                           ModelT]):
+  def __init__(self, unkeyed: InferenceRunner[ExampleT, PredictionT, ModelT]):
+    self._unkeyed = unkeyed
+
+  def run_inference(
+      self, batch: List[Tuple[KeyT, ExampleT]], model: ModelT,
+      **kwargs) -> Iterable[Tuple[KeyT, PredictionT]]:
+    keys, unkeyed_batch = zip(*batch)
+    return zip(
+        keys, self._unkeyed.run_inference(unkeyed_batch, model, **kwargs))
+
+  def get_num_bytes(self, batch: List[Tuple[KeyT, ExampleT]]) -> int:
+    keys, unkeyed_batch = zip(*batch)
+    return len(pickle.dumps(keys)) + self._unkeyed.get_num_bytes(unkeyed_batch)
+
+  def get_metrics_namespace(self) -> str:
+    return self._unkeyed.get_metrics_namespace()
+
+
 class RunInference(beam.PTransform[beam.PCollection[ExampleT],
                                    beam.PCollection[PredictionT]]):
   """An extensible transform for running inferences.
@@ -214,32 +258,18 @@ def setup(self):
     self._model = self._load_model()
 
   def process(self, batch, **kwargs):
-    # Process supports both keyed data, and example only data.
-    # First keys and samples are separated (if there are keys)
-    has_keys = isinstance(batch[0], tuple)
-    if has_keys:
-      examples = [example for _, example in batch]
-      keys = [key for key, _ in batch]
-    else:
-      examples = batch
-      keys = None
-
     start_time = _to_microseconds(self._clock.time_ns())
     result_generator = self._inference_runner.run_inference(
-        examples, self._model, **kwargs)
+        batch, self._model, **kwargs)
     predictions = list(result_generator)
     end_time = _to_microseconds(self._clock.time_ns())
     inference_latency = end_time - start_time
-    num_bytes = self._inference_runner.get_num_bytes(examples)
+    num_bytes = self._inference_runner.get_num_bytes(batch)
     num_elements = len(batch)
     self._metrics_collector.update(num_elements, num_bytes, inference_latency)
 
-    # Keys are recombined with predictions in the RunInference PTransform.
-    if has_keys:
-      yield from zip(keys, predictions)
-    else:
-      yield from predictions
+    return predictions
 
   def finish_bundle(self):
     # TODO(BEAM-13970): Figure out why there is a cache.
diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py
index 9616ba6060a6..9b4cad1dec48 100644
--- a/sdks/python/apache_beam/ml/inference/base_test.py
+++ b/sdks/python/apache_beam/ml/inference/base_test.py
@@ -116,7 +116,8 @@ def test_run_inference_impl_with_keyed_examples(self):
       keyed_examples = [(i, example) for i, example in enumerate(examples)]
       expected = [(i, example + 1) for i, example in enumerate(examples)]
       pcoll = pipeline | 'start' >> beam.Create(keyed_examples)
-      actual = pcoll | base.RunInference(FakeModelLoader())
+      actual = pcoll | base.RunInference(
+          base.KeyedModelLoader(FakeModelLoader()))
       assert_that(actual, equal_to(expected), label='assert:inferences')
 
   def test_run_inference_impl_kwargs(self):

From 31c7788f3125d94145967273ff5af67bb2338bcb Mon Sep 17 00:00:00 2001
From: Robert Bradshaw
Date: Thu, 9 Jun 2022 16:40:19 -0700
Subject: [PATCH 2/7] mypy

---
 sdks/python/apache_beam/ml/inference/base.py          | 10 +++++-----
 .../apache_beam/ml/inference/pytorch_inference.py     |  6 +++---
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index 6688f79a619a..0477365adf1f 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -35,8 +35,8 @@
 from typing import Any
 from typing import Generic
 from typing import Iterable
-from typing import List
 from typing import Mapping
+from typing import Sequence
 from typing import Tuple
 from typing import TypeVar
 
@@ -68,13 +68,13 @@ def _to_microseconds(time_ns: int) -> int:
 
 class InferenceRunner(Generic[ExampleT, PredictionT, ModelT]):
   """Implements running inferences for a framework."""
-  def run_inference(self, batch: List[ExampleT], model: ModelT,
+  def run_inference(self, batch: Sequence[ExampleT], model: ModelT,
                     **kwargs) -> Iterable[PredictionT]:
     """Runs inferences on a batch of examples and
     returns an Iterable of Predictions."""
     raise NotImplementedError(type(self))
 
-  def get_num_bytes(self, batch: List[ExampleT]) -> int:
+  def get_num_bytes(self, batch: Sequence[ExampleT]) -> int:
     """Returns the number of bytes of data for a batch."""
     return len(pickle.dumps(batch))
 
@@ -131,13 +131,13 @@ def __init__(self, unkeyed: InferenceRunner[ExampleT, PredictionT, ModelT]):
     self._unkeyed = unkeyed
 
   def run_inference(
-      self, batch: List[Tuple[KeyT, ExampleT]], model: ModelT,
+      self, batch: Sequence[Tuple[KeyT, ExampleT]], model: ModelT,
       **kwargs) -> Iterable[Tuple[KeyT, PredictionT]]:
     keys, unkeyed_batch = zip(*batch)
     return zip(
         keys, self._unkeyed.run_inference(unkeyed_batch, model, **kwargs))
 
-  def get_num_bytes(self, batch: List[Tuple[KeyT, ExampleT]]) -> int:
+  def get_num_bytes(self, batch: Sequence[Tuple[KeyT, ExampleT]]) -> int:
     keys, unkeyed_batch = zip(*batch)
     return len(pickle.dumps(keys)) + self._unkeyed.get_num_bytes(unkeyed_batch)
 
diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
index d2a47d530635..9f50a672ac06 100644
--- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py
+++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
@@ -22,7 +22,7 @@
 from typing import Callable
 from typing import Dict
 from typing import Iterable
-from typing import List
+from typing import Sequence
 from typing import Union
 
 import torch
@@ -56,7 +56,7 @@ def _convert_to_device(self, examples: torch.Tensor) -> torch.Tensor:
 
   def run_inference(
       self,
-      batch: List[Union[torch.Tensor, Dict[str, torch.Tensor]]],
+      batch: Sequence[Union[torch.Tensor, Dict[str, torch.Tensor]]],
       model: torch.nn.Module,
       **kwargs) -> Iterable[PredictionResult]:
     """
@@ -88,7 +88,7 @@ def run_inference(
     predictions = model(batched_tensors, **prediction_params)
     return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
 
-  def get_num_bytes(self, batch: List[torch.Tensor]) -> int:
+  def get_num_bytes(self, batch: Sequence[torch.Tensor]) -> int:
     """Returns the number of bytes of data for a batch of Tensors."""
     # If elements in `batch` are provided as a dictionaries from key to Tensors
    if isinstance(batch[0], dict):

From 7e01380c0725292d6c95071d558c9c1c1a7aa027 Mon Sep 17 00:00:00 2001
From: Robert Bradshaw
Date: Fri, 10 Jun 2022 11:23:23 -0700
Subject: [PATCH 3/7] fix merge, yapf

---
 sdks/python/apache_beam/ml/inference/base.py | 23 +++++++----------------
 1 file changed, 7 insertions(+), 16 deletions(-)

diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index 648715617ef6..5dcff5ae5e83 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -72,7 +72,7 @@ def load_model(self) -> ModelT:
     """Loads and initializes a model for processing."""
     raise NotImplementedError(type(self))
 
-  def run_inference(self, batch: List[ExampleT], model: ModelT,
+  def run_inference(self, batch: Sequence[ExampleT], model: ModelT,
                     **kwargs) -> Iterable[PredictionT]:
     """Runs inferences on a batch of examples and
     returns an Iterable of Predictions."""
     raise NotImplementedError(type(self))
 
@@ -96,10 +96,10 @@ def batch_elements_kwargs(self) -> Mapping[str, Any]:
 
 class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
-                        KeyedModelHandler[Tuple[KeyT, ExampleT],
-                                          Tuple[KeyT, PredictionT],
-                                          ModelT]):
-  def __init__(self, unkeyed: ModelLoader[ExampleT, PredictionT, ModelT]):
+                        ModelHandler[Tuple[KeyT, ExampleT],
+                                     Tuple[KeyT, PredictionT],
+                                     ModelT]):
+  def __init__(self, unkeyed: ModelHandler[ExampleT, PredictionT, ModelT]):
     self._unkeyed = unkeyed
 
   def load_model(self) -> ModelT:
@@ -239,22 +239,13 @@ def setup(self):
 
   def process(self, batch, **kwargs):
     start_time = _to_microseconds(self._clock.time_ns())
-<<<<<<< HEAD
-    result_generator = self._inference_runner.run_inference(
-        batch, self._model, **kwargs)
-=======
     result_generator = self._model_handler.run_inference(
-        examples, self._model, **kwargs)
->>>>>>> master
+        batch, self._model, **kwargs)
     predictions = list(result_generator)
 
     end_time = _to_microseconds(self._clock.time_ns())
     inference_latency = end_time - start_time
-<<<<<<< HEAD
-    num_bytes = self._inference_runner.get_num_bytes(batch)
-=======
-    num_bytes = self._model_handler.get_num_bytes(examples)
->>>>>>> master
+    num_bytes = self._model_handler.get_num_bytes(batch)
     num_elements = len(batch)
     self._metrics_collector.update(num_elements, num_bytes, inference_latency)

From dcd328967eb1b462b6d05bd5fce430c380defca6 Mon Sep 17 00:00:00 2001
From: Robert Bradshaw
Date: Fri, 10 Jun 2022 11:35:29 -0700
Subject: [PATCH 4/7] MaybeKeyed

---
 sdks/python/apache_beam/ml/inference/base.py  | 63 +++++++++++++++++++
 .../apache_beam/ml/inference/base_test.py     | 17 +++++
 2 files changed, 80 insertions(+)

diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index 5dcff5ae5e83..d6746995d888 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -99,6 +99,13 @@ class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
                         ModelHandler[Tuple[KeyT, ExampleT],
                                      Tuple[KeyT, PredictionT],
                                      ModelT]):
+  """A ModelHandler that takes keyed examples and returns keyed predictions.
+
+  For example, if the original model was used with RunInference to take a
+  PCollection[E] to a PCollection[P], this would take a
+  PCollection[Tuple[K, E]] to a PCollection[Tuple[K, P]], allowing one to
+  associate the outputs with the inputs based on the key.
+  """
   def __init__(self, unkeyed: ModelHandler[ExampleT, PredictionT, ModelT]):
     self._unkeyed = unkeyed
 
@@ -124,6 +131,62 @@ def get_resource_hints(self):
     return self._unkeyed.get_resource_hints()
 
   def batch_elements_kwargs(self):
     return self._unkeyed.batch_elements_kwargs()
+    return {}
+
+
+class MaybeKeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
+                             ModelHandler[Tuple[KeyT, ExampleT],
+                                          Tuple[KeyT, PredictionT],
+                                          ModelT]):
+  """A ModelHandler that takes possibly keyed examples and returns possibly
+  keyed predictions.
+
+  For example, if the original model was used with RunInference to take a
+  PCollection[E] to a PCollection[P], this would take either a PCollection[E]
+  to a PCollection[P] or a PCollection[Tuple[K, E]] to a PCollection[Tuple[K, P]],
+  depending on whether the elements happen to be tuples, allowing one to
+  associate the outputs with the inputs based on the key.
+
+  Note that this cannot be used if E happens to be a tuple type.
+  """
+  def __init__(self, unkeyed: ModelHandler[ExampleT, PredictionT, ModelT]):
+    self._unkeyed = unkeyed
+
+  def load_model(self) -> ModelT:
+    return self._unkeyed.load_model()
+
+  def run_inference(
+      self, batch: Sequence[Tuple[KeyT, ExampleT]], model: ModelT,
+      **kwargs) -> Iterable[Tuple[KeyT, PredictionT]]:
+    if isinstance(batch[0], tuple):
+      is_keyed = True
+      keys, unkeyed_batch = zip(*batch)
+    else:
+      is_keyed = False
+      unkeyed_batch = batch
+    unkeyed_results = self._unkeyed.run_inference(
+        unkeyed_batch, model, **kwargs)
+    if is_keyed:
+      return zip(keys, unkeyed_results)
+    else:
+      return unkeyed_results
+
+  def get_num_bytes(self, batch: Sequence[Tuple[KeyT, ExampleT]]) -> int:
+    if isinstance(batch[0], tuple):
+      keys, unkeyed_batch = zip(*batch)
+      return len(
+          pickle.dumps(keys)) + self._unkeyed.get_num_bytes(unkeyed_batch)
+    else:
+      return self._unkeyed.get_num_bytes(batch)
+
+  def get_metrics_namespace(self) -> str:
+    return self._unkeyed.get_metrics_namespace()
+
+  def get_resource_hints(self):
+    return self._unkeyed.get_resource_hints()
+
+  def batch_elements_kwargs(self):
+    return self._unkeyed.batch_elements_kwargs()
 
 class RunInference(beam.PTransform[beam.PCollection[ExampleT],
diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py
index 0341a32289cb..d7ae43eaf6af 100644
--- a/sdks/python/apache_beam/ml/inference/base_test.py
+++ b/sdks/python/apache_beam/ml/inference/base_test.py
@@ -102,6 +102,23 @@ def test_run_inference_impl_with_keyed_examples(self):
         base.KeyedModelHandler(FakeModelHandler()))
       assert_that(actual, equal_to(expected), label='assert:inferences')
 
+  def test_run_inference_impl_with_maybe_keyed_examples(self):
+    with TestPipeline() as pipeline:
+      examples = [1, 5, 3, 10]
+      keyed_examples = [(i, example) for i, example in enumerate(examples)]
+      expected = [example + 1 for example in examples]
+      keyed_expected = [(i, example + 1) for i, example in enumerate(examples)]
+      model_handler = base.MaybeKeyedModelHandler(FakeModelHandler())
+
+      pcoll = pipeline | 'Unkeyed' >> beam.Create(examples)
+      actual = pcoll | 'RunUnkeyed' >> base.RunInference(model_handler)
+      assert_that(actual, equal_to(expected), label='CheckUnkeyed')
+
+      keyed_pcoll = pipeline | 'Keyed' >> beam.Create(keyed_examples)
+      keyed_actual = keyed_pcoll | 'RunKeyed' >> base.RunInference(
+          model_handler)
+      assert_that(keyed_actual, equal_to(keyed_expected), label='CheckKeyed')
+
   def test_run_inference_impl_kwargs(self):
     with TestPipeline() as pipeline:
       examples = [1, 5, 3, 10]

From d6945863091e810d39d242c291aa22cb6c413df2 Mon Sep 17 00:00:00 2001
From: Robert Bradshaw
Date: Fri, 10 Jun 2022 12:53:30 -0700
Subject: [PATCH 5/7] More List -> Sequence changes.
---
 sdks/python/apache_beam/ml/inference/base_test.py      |  4 ++--
 .../apache_beam/ml/inference/sklearn_inference.py      | 12 ++++++------
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py
index d7ae43eaf6af..52f8f883203f 100644
--- a/sdks/python/apache_beam/ml/inference/base_test.py
+++ b/sdks/python/apache_beam/ml/inference/base_test.py
@@ -20,7 +20,7 @@
 import pickle
 import unittest
 from typing import Iterable
-from typing import List
+from typing import Sequence
 
 import apache_beam as beam
 from apache_beam.metrics.metric import MetricsFilter
@@ -44,7 +44,7 @@ def load_model(self):
       self._fake_clock.current_time_ns += 500_000_000  # 500ms
     return FakeModel()
 
-  def run_inference(self, batch: List[int], model: FakeModel,
+  def run_inference(self, batch: Sequence[int], model: FakeModel,
                     **kwargs) -> Iterable[int]:
     if self._fake_clock:
       self._fake_clock.current_time_ns += 3_000_000  # 3 milliseconds
diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference.py b/sdks/python/apache_beam/ml/inference/sklearn_inference.py
index 3c8eddfd7d3a..5ca6a18b5d1d 100644
--- a/sdks/python/apache_beam/ml/inference/sklearn_inference.py
+++ b/sdks/python/apache_beam/ml/inference/sklearn_inference.py
@@ -20,7 +20,7 @@
 import sys
 from typing import Any
 from typing import Iterable
-from typing import List
+from typing import Sequence
 from typing import Union
 
 import numpy
@@ -75,7 +75,7 @@ def load_model(self) -> BaseEstimator:
 
   def run_inference(
       self,
-      batch: List[Union[numpy.ndarray, pandas.DataFrame]],
+      batch: Sequence[Union[numpy.ndarray, pandas.DataFrame]],
      model: BaseEstimator,
       **kwargs) -> Iterable[PredictionResult]:
     # TODO(github.com/apache/beam/issues/21769): Use supplied input type hint.
@@ -86,7 +86,7 @@ def run_inference(
       raise ValueError('Unsupported data type.')
 
   @staticmethod
-  def _predict_np_array(batch: List[numpy.ndarray],
+  def _predict_np_array(batch: Sequence[numpy.ndarray],
                         model: Any) -> Iterable[PredictionResult]:
     # vectorize data for better performance
     vectorized_batch = numpy.stack(batch, axis=0)
@@ -94,7 +94,7 @@ def _predict_np_array(batch: List[numpy.ndarray],
     predictions = model.predict(vectorized_batch)
     return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
 
   @staticmethod
-  def _predict_pandas_dataframe(batch: List[pandas.DataFrame],
+  def _predict_pandas_dataframe(batch: Sequence[pandas.DataFrame],
                                 model: Any) -> Iterable[PredictionResult]:
     # sklearn_inference currently only supports single rowed dataframes.
     for dataframe in batch:
@@ -113,11 +113,11 @@ def _predict_pandas_dataframe(batch: List[pandas.DataFrame],
     ]
 
   def get_num_bytes(
-      self, batch: List[Union[numpy.ndarray, pandas.DataFrame]]) -> int:
+      self, batch: Sequence[Union[numpy.ndarray, pandas.DataFrame]]) -> int:
     """Returns the number of bytes of data for a batch."""
     if isinstance(batch[0], numpy.ndarray):
       return sum(sys.getsizeof(element) for element in batch)
     elif isinstance(batch[0], pandas.DataFrame):
-      data_frames: List[pandas.DataFrame] = batch
+      data_frames: Sequence[pandas.DataFrame] = batch
       return sum(df.memory_usage(deep=True).sum() for df in data_frames)
     raise ValueError('Unsupported data type.')

From 02bfe394570d4f3d406595e536c6b6da6d1a4014 Mon Sep 17 00:00:00 2001
From: Robert Bradshaw
Date: Fri, 10 Jun 2022 13:25:36 -0700
Subject: [PATCH 6/7] one more instance of keyed

---
 sdks/python/apache_beam/ml/inference/sklearn_inference_test.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py b/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py
index 91eb86e2de4b..0d7294eb4063 100644
--- a/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py
+++ b/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py
@@ -265,7 +265,7 @@ def test_pipeline_pandas_with_keys(self):
 
       pcoll = pipeline | 'start' >> beam.Create(keyed_rows)
       actual = pcoll | api.RunInference(
-          SklearnModelHandler(model_uri=temp_file_name))
+          base.KeyedModelHandler(SklearnModelHandler(model_uri=temp_file_name)))
       expected = [
           ('0', api.PredictionResult(splits[0], 5)),
           ('1', api.PredictionResult(splits[1], 8)),

From 1dc8dcce686a70db5931f659d6a2b18fb432f7a0 Mon Sep 17 00:00:00 2001
From: Robert Bradshaw
Date: Fri, 10 Jun 2022 13:46:53 -0700
Subject: [PATCH 7/7] Properly type unions.

---
 sdks/python/apache_beam/ml/inference/base.py | 32 ++++++++++++++------
 1 file changed, 22 insertions(+), 10 deletions(-)

diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index d6746995d888..ae07ac0531ee 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -39,6 +39,7 @@
 from typing import Sequence
 from typing import Tuple
 from typing import TypeVar
+from typing import Union
 
 import apache_beam as beam
 from apache_beam.utils import shared
@@ -135,8 +136,10 @@ class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
 
 
 class MaybeKeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
-                             ModelHandler[Tuple[KeyT, ExampleT],
-                                          Tuple[KeyT, PredictionT],
+                             ModelHandler[Union[ExampleT, Tuple[KeyT,
+                                                                ExampleT]],
+                                          Union[PredictionT,
+                                                Tuple[KeyT, PredictionT]],
                                           ModelT]):
   """A ModelHandler that takes possibly keyed examples and returns possibly
   keyed predictions.
@@ -147,7 +150,8 @@ class MaybeKeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
   depending on whether the elements happen to be tuples, allowing one to
   associate the outputs with the inputs based on the key.
 
-  Note that this cannot be used if E happens to be a tuple type.
+  Note that this cannot be used if E happens to be a tuple type. In addition,
+  either all examples should be keyed, or none of them.
""" def __init__(self, unkeyed: ModelHandler[ExampleT, PredictionT, ModelT]): self._unkeyed = unkeyed @@ -156,14 +160,20 @@ def load_model(self) -> ModelT: return self._unkeyed.load_model() def run_inference( - self, batch: Sequence[Tuple[KeyT, ExampleT]], model: ModelT, - **kwargs) -> Iterable[Tuple[KeyT, PredictionT]]: + self, + batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]], + model: ModelT, + **kwargs + ) -> Union[Iterable[PredictionT], Iterable[Tuple[KeyT, PredictionT]]]: + # Really the input should be + # Union[Sequence[ExampleT], Sequence[Tuple[KeyT, ExampleT]]] + # but there's not a good way to express (or check) that. if isinstance(batch[0], tuple): is_keyed = True - keys, unkeyed_batch = zip(*batch) + keys, unkeyed_batch = zip(*batch) # type: ignore[arg-type] else: is_keyed = False - unkeyed_batch = batch + unkeyed_batch = batch # type: ignore[assignment] unkeyed_results = self._unkeyed.run_inference( unkeyed_batch, model, **kwargs) if is_keyed: @@ -171,13 +181,15 @@ def run_inference( else: return unkeyed_results - def get_num_bytes(self, batch: Sequence[Tuple[KeyT, ExampleT]]) -> int: + def get_num_bytes( + self, batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]]) -> int: + # MyPy can't follow the branching logic. if isinstance(batch[0], tuple): - keys, unkeyed_batch = zip(*batch) + keys, unkeyed_batch = zip(*batch) # type: ignore[arg-type] return len( pickle.dumps(keys)) + self._unkeyed.get_num_bytes(unkeyed_batch) else: - return self._unkeyed.get_num_bytes(batch) + return self._unkeyed.get_num_bytes(batch) # type: ignore[arg-type] def get_metrics_namespace(self) -> str: return self._unkeyed.get_metrics_namespace()