34 changes: 22 additions & 12 deletions sdks/python/apache_beam/ml/inference/base.py
@@ -35,10 +35,12 @@
import sys
import time
from typing import Any
from typing import Dict
from typing import Generic
from typing import Iterable
from typing import Mapping
from typing import NamedTuple
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import TypeVar
@@ -84,8 +86,11 @@ def load_model(self) -> ModelT:
"""Loads and initializes a model for processing."""
raise NotImplementedError(type(self))

def run_inference(self, batch: Sequence[ExampleT], model: ModelT,
**kwargs) -> Iterable[PredictionT]:
def run_inference(
self,
batch: Sequence[ExampleT],
model: ModelT,
inference_args: Optional[Dict[str, Any]] = None) -> Iterable[PredictionT]:
"""Runs inferences on a batch of examples and
returns an Iterable of Predictions."""
raise NotImplementedError(type(self))
@@ -125,11 +130,14 @@ def load_model(self) -> ModelT:
return self._unkeyed.load_model()

def run_inference(
self, batch: Sequence[Tuple[KeyT, ExampleT]], model: ModelT,
**kwargs) -> Iterable[Tuple[KeyT, PredictionT]]:
self,
batch: Sequence[Tuple[KeyT, ExampleT]],
model: ModelT,
inference_args: Optional[Dict[str, Any]] = None
) -> Iterable[Tuple[KeyT, PredictionT]]:
keys, unkeyed_batch = zip(*batch)
return zip(
keys, self._unkeyed.run_inference(unkeyed_batch, model, **kwargs))
keys, self._unkeyed.run_inference(unkeyed_batch, model, inference_args))

def get_num_bytes(self, batch: Sequence[Tuple[KeyT, ExampleT]]) -> int:
keys, unkeyed_batch = zip(*batch)
@@ -173,7 +181,7 @@ def run_inference(
self,
batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]],
model: ModelT,
**kwargs
inference_args: Optional[Dict[str, Any]] = None
) -> Union[Iterable[PredictionT], Iterable[Tuple[KeyT, PredictionT]]]:
# Really the input should be
# Union[Sequence[ExampleT], Sequence[Tuple[KeyT, ExampleT]]]
@@ -185,7 +193,7 @@
is_keyed = False
unkeyed_batch = batch # type: ignore[assignment]
unkeyed_results = self._unkeyed.run_inference(
unkeyed_batch, model, **kwargs)
unkeyed_batch, model, inference_args)
if is_keyed:
return zip(keys, unkeyed_results)
else:
@@ -217,6 +225,8 @@ class RunInference(beam.PTransform[beam.PCollection[ExampleT],
Args:
model_handler: An implementation of ModelHandler.
clock: A clock implementing get_current_time_in_microseconds.
inference_args: Extra arguments for models whose inference call requires
extra parameters.

A transform that takes a PCollection of examples (or features) to be used on
an ML model. It will then output inferences (or predictions) for those
@@ -236,9 +246,9 @@ def __init__(
self,
model_handler: ModelHandler[ExampleT, PredictionT, Any],
clock=time,
**kwargs):
inference_args: Optional[Dict[str, Any]] = None):
self._model_handler = model_handler
self._kwargs = kwargs
self._inference_args = inference_args
self._clock = clock

@classmethod
@@ -268,7 +278,7 @@ def expand(
| (
beam.ParDo(
_RunInferenceDoFn(self._model_handler, self._clock),
**self._kwargs).with_resource_hints(**resource_hints)))
self._inference_args).with_resource_hints(**resource_hints)))


class _MetricsCollector:
@@ -352,10 +362,10 @@ def setup(self):
self._model_handler.get_metrics_namespace())
self._model = self._load_model()

def process(self, batch, **kwargs):
def process(self, batch, inference_args):
start_time = _to_microseconds(self._clock.time_ns())
result_generator = self._model_handler.run_inference(
batch, self._model, **kwargs)
batch, self._model, inference_args)
predictions = list(result_generator)

end_time = _to_microseconds(self._clock.time_ns())
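The change to base.py above replaces **kwargs with an explicit, optional inference_args dict on ModelHandler.run_inference, and RunInference now forwards the dict it was constructed with into the inference DoFn. A minimal sketch of a custom handler written against that contract; the DoublingModelHandler name and its toy "model" are hypothetical illustrations, not part of this change:

from typing import Any, Dict, Iterable, Optional, Sequence

from apache_beam.ml.inference.base import ModelHandler


class DoublingModelHandler(ModelHandler[int, int, Any]):
  """Hypothetical handler illustrating the new run_inference signature."""
  def load_model(self) -> Any:
    # A trivial stand-in for a real model: multiply by a configurable factor.
    return lambda x, factor=2: x * factor

  def run_inference(
      self,
      batch: Sequence[int],
      model: Any,
      inference_args: Optional[Dict[str, Any]] = None) -> Iterable[int]:
    # inference_args replaces the old **kwargs; None means "no extra args".
    inference_args = inference_args or {}
    return [model(example, **inference_args) for example in batch]


# Extra model arguments are now passed explicitly when the transform is built:
#   pcoll | RunInference(DoublingModelHandler(), inference_args={'factor': 3})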
24 changes: 14 additions & 10 deletions sdks/python/apache_beam/ml/inference/base_test.py
@@ -44,8 +44,11 @@ def load_model(self):
self._fake_clock.current_time_ns += 500_000_000 # 500ms
return FakeModel()

def run_inference(self, batch: Sequence[int], model: FakeModel,
**kwargs) -> Iterable[int]:
def run_inference(
self,
batch: Sequence[int],
model: FakeModel,
inference_args=None) -> Iterable[int]:
if self._fake_clock:
self._fake_clock.current_time_ns += 3_000_000 # 3 milliseconds
for example in batch:
@@ -67,7 +70,7 @@ def process(self, prediction_result):


class FakeModelHandlerNeedsBigBatch(FakeModelHandler):
def run_inference(self, batch, unused_model):
def run_inference(self, batch, unused_model, inference_args=None):
if len(batch) < 100:
raise ValueError('Unexpectedly small batch')
return batch
@@ -76,10 +79,10 @@ def batch_elements_kwargs(self):
return {'min_batch_size': 9999}


class FakeModelHandlerWithKwargs(FakeModelHandler):
def run_inference(self, batch, unused_model, **kwargs):
if not kwargs.get('key'):
raise ValueError('key should be True')
class FakeModelHandlerExtraInferenceArgs(FakeModelHandler):
def run_inference(self, batch, unused_model, inference_args=None):
if not inference_args:
raise ValueError('inference_args should exist')
return batch


@@ -119,12 +122,13 @@ def test_run_inference_impl_with_maybe_keyed_examples(self):
model_handler)
assert_that(keyed_actual, equal_to(keyed_expected), label='CheckKeyed')

def test_run_inference_impl_kwargs(self):
def test_run_inference_impl_inference_args(self):
with TestPipeline() as pipeline:
examples = [1, 5, 3, 10]
pcoll = pipeline | 'start' >> beam.Create(examples)
kwargs = {'key': True}
actual = pcoll | base.RunInference(FakeModelHandlerWithKwargs(), **kwargs)
inference_args = {'key': True}
actual = pcoll | base.RunInference(
FakeModelHandlerExtraInferenceArgs(), inference_args=inference_args)
assert_that(actual, equal_to(examples), label='assert:inferences')

def test_counted_metrics(self):
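The keyed tests above rely on the wrapper in base.py that delegates to self._unkeyed: keys are stripped before inference and re-attached to the results, while inference_args reach the wrapped handler unchanged. A small sketch of that path, assuming the KeyedModelHandler wrapper from base.py and a hypothetical scaling handler defined only for illustration:

from typing import Any, Dict, Iterable, Optional, Sequence

from apache_beam.ml.inference.base import KeyedModelHandler, ModelHandler


class _ScalingHandler(ModelHandler[int, int, Any]):
  """Hypothetical handler: multiplies each example by inference_args['factor']."""
  def load_model(self) -> Any:
    return None

  def run_inference(
      self,
      batch: Sequence[int],
      model: Any,
      inference_args: Optional[Dict[str, Any]] = None) -> Iterable[int]:
    factor = (inference_args or {}).get('factor', 1)
    return [example * factor for example in batch]


keyed_handler = KeyedModelHandler(_ScalingHandler())
model = keyed_handler.load_model()
results = keyed_handler.run_inference(
    [('a', 1), ('b', 5)], model, inference_args={'factor': 3})
print(list(results))  # [('a', 3), ('b', 15)]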
20 changes: 13 additions & 7 deletions sdks/python/apache_beam/ml/inference/pytorch_inference.py
@@ -22,6 +22,7 @@
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import Optional
from typing import Sequence

import torch
@@ -91,19 +92,23 @@ def load_model(self) -> torch.nn.Module:
**self._model_params)

def run_inference(
self, batch: Sequence[torch.Tensor], model: torch.nn.Module,
**kwargs) -> Iterable[PredictionResult]:
self,
batch: Sequence[torch.Tensor],
model: torch.nn.Module,
inference_args: Optional[Dict[str, Any]] = None
) -> Iterable[PredictionResult]:
"""
Runs inferences on a batch of Tensors and returns an Iterable of
Tensor Predictions.

This method stacks the list of Tensors in a vectorized format to optimize
the inference call.
"""
prediction_params = kwargs.get('prediction_params', {})
inference_args = {} if not inference_args else inference_args

batched_tensors = torch.stack(batch)
batched_tensors = _convert_to_device(batched_tensors, self._device)
predictions = model(batched_tensors, **prediction_params)
predictions = model(batched_tensors, **inference_args)
return [PredictionResult(x, y) for x, y in zip(batch, predictions)]

def get_num_bytes(self, batch: Sequence[torch.Tensor]) -> int:
@@ -163,15 +168,16 @@ def run_inference(
self,
batch: Sequence[Dict[str, torch.Tensor]],
model: torch.nn.Module,
**kwargs) -> Iterable[PredictionResult]:
inference_args: Optional[Dict[str, Any]] = None
) -> Iterable[PredictionResult]:
"""
Runs inferences on a batch of Keyed Tensors and returns an Iterable of
Tensor Predictions.

For the same key across all examples, this will stack all Tensor values
in a vectorized format to optimize the inference call.
"""
prediction_params = kwargs.get('prediction_params', {})
inference_args = {} if not inference_args else inference_args

# If elements in `batch` are provided as dictionaries from key to Tensors,
# then iterate through the batch list, and group Tensors to the same key
@@ -184,7 +190,7 @@
batched_tensors = torch.stack(key_to_tensor_list[key])
batched_tensors = _convert_to_device(batched_tensors, self._device)
key_to_batched_tensors[key] = batched_tensors
predictions = model(**key_to_batched_tensors, **prediction_params)
predictions = model(**key_to_batched_tensors, **inference_args)
return [PredictionResult(x, y) for x, y in zip(batch, predictions)]

def get_num_bytes(self, batch: Sequence[torch.Tensor]) -> int:
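Both PyTorch handlers in pytorch_inference.py follow the pattern their docstrings describe: per-example tensors are stacked into one batched tensor (per key, for keyed examples), moved to the target device, and the non-batchable inference_args are unpacked as keyword arguments into a single model call. A standalone sketch of that pattern in plain PyTorch; LinearModel and its scale argument are illustrative only, and device placement is omitted:

from typing import Any, Dict, Optional, Sequence

import torch


class LinearModel(torch.nn.Module):
  """Illustrative model that accepts a non-batchable extra argument."""
  def __init__(self):
    super().__init__()
    self.linear = torch.nn.Linear(1, 1)

  def forward(self, x, scale: float = 1.0):
    # `scale` stands in for a prediction_param-style extra argument.
    return self.linear(x) * scale


def run_inference_sketch(
    batch: Sequence[torch.Tensor],
    model: torch.nn.Module,
    inference_args: Optional[Dict[str, Any]] = None):
  inference_args = inference_args or {}
  # Stack per-example tensors into one batched tensor for a vectorized call.
  batched = torch.stack(batch)
  # Extra args are unpacked as keyword arguments, exactly once per batch.
  return model(batched, **inference_args)


batch = [torch.tensor([1.0]), torch.tensor([2.0])]
predictions = run_inference_sketch(batch, LinearModel(), {'scale': 2.0})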
52 changes: 25 additions & 27 deletions sdks/python/apache_beam/ml/inference/pytorch_inference_test.py
@@ -63,7 +63,7 @@
for f1, f2 in TWO_FEATURES_EXAMPLES]).reshape(-1, 1))
]

KWARGS_TORCH_EXAMPLES = [
KEYED_TORCH_EXAMPLES = [
{
'k1': torch.from_numpy(np.array([1], dtype="float32")),
'k2': torch.from_numpy(np.array([1.5], dtype="float32"))
@@ -82,12 +82,12 @@
},
]

KWARGS_TORCH_PREDICTIONS = [
KEYED_TORCH_PREDICTIONS = [
PredictionResult(ex, pred) for ex,
pred in zip(
KWARGS_TORCH_EXAMPLES,
KEYED_TORCH_EXAMPLES,
torch.Tensor([(example['k1'] * 2.0 + 0.5) + (example['k2'] * 2.0 + 0.5)
for example in KWARGS_TORCH_EXAMPLES]).reshape(-1, 1))
for example in KEYED_TORCH_EXAMPLES]).reshape(-1, 1))
]


@@ -122,12 +122,12 @@ def forward(self, x):
return out


class PytorchLinearRegressionKwargsPredictionParams(torch.nn.Module):
class PytorchLinearRegressionKeyedBatchAndExtraInferenceArgs(torch.nn.Module):
"""
A linear model with kwargs inputs and non-batchable input params.
A linear model with batched keyed inputs and non-batchable extra args.

Note: k1 and k2 are batchable inputs passed in as a kwargs.
prediction_param_array, prediction_param_bool are non-batchable inputs
Note: k1 and k2 are batchable examples passed in as a dict from str to tensor.
prediction_param_array, prediction_param_bool are non-batchable extra args
(typically model-related info) used to configure the model before its predict
call is invoked
"""
@@ -186,10 +186,10 @@ def test_run_inference_multiple_tensor_features(self):
for actual, expected in zip(predictions, TWO_FEATURES_PREDICTIONS):
self.assertEqual(actual, expected)

def test_run_inference_kwargs(self):
def test_run_inference_keyed(self):
"""
This tests for inputs that are passed as a dictionary from key to tensor
instead of a standard non-kwarg input.
instead of a standard non-keyed tensor example.

Example:
Typical input format is
@@ -218,23 +218,23 @@ def forward(self, k1, k2):

inference_runner = TestPytorchModelHandlerKeyedTensorForInferenceOnly(
torch.device('cpu'))
predictions = inference_runner.run_inference(KWARGS_TORCH_EXAMPLES, model)
for actual, expected in zip(predictions, KWARGS_TORCH_PREDICTIONS):
predictions = inference_runner.run_inference(KEYED_TORCH_EXAMPLES, model)
for actual, expected in zip(predictions, KEYED_TORCH_PREDICTIONS):
self.assertTrue(_compare_prediction_result(actual, expected))

def test_run_inference_kwargs_prediction_params(self):
def test_inference_runner_inference_args(self):
"""
This tests for non-batchable input arguments. Since we do the batching
for the user, we have to distinguish between the inputs that should be
batched and the ones that should not be batched.
"""
prediction_params = {
inference_args = {
'prediction_param_array': torch.from_numpy(
np.array([1, 2], dtype="float32")),
'prediction_param_bool': True
}

model = PytorchLinearRegressionKwargsPredictionParams(
model = PytorchLinearRegressionKeyedBatchAndExtraInferenceArgs(
input_dim=1, output_dim=1)
model.load_state_dict(
OrderedDict([('linear.weight', torch.Tensor([[2.0]])),
@@ -244,10 +244,8 @@ def test_run_inference_kwargs_prediction_params(self):
inference_runner = TestPytorchModelHandlerKeyedTensorForInferenceOnly(
torch.device('cpu'))
predictions = inference_runner.run_inference(
batch=KWARGS_TORCH_EXAMPLES,
model=model,
prediction_params=prediction_params)
for actual, expected in zip(predictions, KWARGS_TORCH_PREDICTIONS):
batch=KEYED_TORCH_EXAMPLES, model=model, inference_args=inference_args)
for actual, expected in zip(predictions, KEYED_TORCH_PREDICTIONS):
self.assertEqual(actual, expected)

def test_num_bytes(self):
@@ -295,9 +293,9 @@ def test_pipeline_local_model_simple(self):
equal_to(
TWO_FEATURES_PREDICTIONS, equals_fn=_compare_prediction_result))

def test_pipeline_local_model_kwargs_prediction_params(self):
def test_pipeline_local_model_extra_inference_args(self):
with TestPipeline() as pipeline:
prediction_params = {
inference_args = {
'prediction_param_array': torch.from_numpy(
np.array([1, 2], dtype="float32")),
'prediction_param_bool': True
@@ -310,21 +308,21 @@ def test_pipeline_local_model_kwargs_prediction_params(self):

model_handler = PytorchModelHandlerKeyedTensor(
state_dict_path=path,
model_class=PytorchLinearRegressionKwargsPredictionParams,
model_class=PytorchLinearRegressionKeyedBatchAndExtraInferenceArgs,
model_params={
'input_dim': 1, 'output_dim': 1
})

pcoll = pipeline | 'start' >> beam.Create(KWARGS_TORCH_EXAMPLES)
prediction_params_side_input = (
pipeline | 'create side' >> beam.Create(prediction_params))
pcoll = pipeline | 'start' >> beam.Create(KEYED_TORCH_EXAMPLES)
inference_args_side_input = (
pipeline | 'create side' >> beam.Create(inference_args))
predictions = pcoll | RunInference(
model_handler=model_handler,
prediction_params=beam.pvalue.AsDict(prediction_params_side_input))
inference_args=beam.pvalue.AsDict(inference_args_side_input))
assert_that(
predictions,
equal_to(
KWARGS_TORCH_PREDICTIONS, equals_fn=_compare_prediction_result))
KEYED_TORCH_PREDICTIONS, equals_fn=_compare_prediction_result))

@unittest.skipIf(GCSFileSystem is None, 'GCP dependencies are not installed')
def test_pipeline_gcs_model(self):
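The distinction these tests exercise is between per-example tensors, which the handler batches for the user, and inference_args, which are handed to the model once per batch. A short sketch of what the keyed handler does internally with examples like those above; the shapes in comments are for this illustration, not the handler's exact code:

import numpy as np
import torch

examples = [
    {'k1': torch.from_numpy(np.array([1], dtype='float32')),
     'k2': torch.from_numpy(np.array([1.5], dtype='float32'))},
    {'k1': torch.from_numpy(np.array([5], dtype='float32')),
     'k2': torch.from_numpy(np.array([5.5], dtype='float32'))},
]
inference_args = {
    'prediction_param_array': torch.from_numpy(
        np.array([1, 2], dtype='float32')),
    'prediction_param_bool': True,
}

# Per-example tensors are grouped by key and stacked across the batch.
batched_k1 = torch.stack([example['k1'] for example in examples])  # shape (2, 1)
batched_k2 = torch.stack([example['k2'] for example in examples])  # shape (2, 1)
# The model is then called once per batch, with inference_args unpacked
# unstacked, rather than once per example:
#   model(k1=batched_k1, k2=batched_k2, **inference_args)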