diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index e4c3e4cab5e0..117a73de1b9a 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -27,6 +27,7 @@
 collection, sharing model between threads, and batching elements.
 """
 
+import functools
 import logging
 import os
 import pickle
@@ -34,6 +35,8 @@
 import threading
 import time
 import uuid
+from abc import ABC
+from abc import abstractmethod
 from collections import OrderedDict
 from collections import defaultdict
 from copy import deepcopy
@@ -56,7 +59,10 @@
 from typing import Union
 
 import apache_beam as beam
+from apache_beam.io.components.adaptive_throttler import AdaptiveThrottler
+from apache_beam.metrics.metric import Metrics
 from apache_beam.utils import multi_process_shared
+from apache_beam.utils import retry
 from apache_beam.utils import shared
 
 try:
@@ -67,6 +73,7 @@
 
 _NANOSECOND_TO_MILLISECOND = 1_000_000
 _NANOSECOND_TO_MICROSECOND = 1_000
+_MILLISECOND_TO_SECOND = 1_000
 
 ModelT = TypeVar('ModelT')
 ExampleT = TypeVar('ExampleT')
@@ -339,6 +346,139 @@ def should_garbage_collect_on_timeout(self) -> bool:
     return self.share_model_across_processes()
 
 
+class RemoteModelHandler(ABC, ModelHandler[ExampleT, PredictionT, ModelT]):
+  """Has the ability to call a model at a remote endpoint."""
+  def __init__(
+      self,
+      namespace: str = '',
+      num_retries: int = 5,
+      throttle_delay_secs: int = 5,
+      retry_filter: Callable[[Exception], bool] = lambda x: True,
+      *,
+      window_ms: int = 1 * _MILLISECOND_TO_SECOND,
+      bucket_ms: int = 1 * _MILLISECOND_TO_SECOND,
+      overload_ratio: float = 2):
+    """Initializes metrics tracking and an AdaptiveThrottler for client-side
+    throttling of remote calls to an inference service. See
+    https://s.apache.org/beam-client-side-throttling for more details on the
+    configuration of the throttling and retry mechanics.
+
+    Args:
+      namespace: the metrics and logging namespace.
+      num_retries: the maximum number of times to retry a request on retriable
+        errors before failing.
+      throttle_delay_secs: the number of seconds to delay a request when the
+        client-side throttler elects to throttle.
+      retry_filter: a function accepting an exception as an argument and
+        returning a boolean. On a true return, the run_inference call will
+        be retried. Defaults to always retrying.
+      window_ms: length of history to consider, in ms, to set throttling.
+      bucket_ms: granularity of time buckets that we store data in, in ms.
+      overload_ratio: the target ratio between requests sent and successful
+        requests. This is "K" in the formula in
+        https://landing.google.com/sre/book/chapters/handling-overload.html.
+    """
+    # Configure AdaptiveThrottler and throttling metrics for client-side
+    # throttling behavior.
+    self.throttled_secs = Metrics.counter(
+        namespace, "cumulativeThrottlingSeconds")
+    self.throttler = AdaptiveThrottler(
+        window_ms=window_ms, bucket_ms=bucket_ms, overload_ratio=overload_ratio)
+    self.logger = logging.getLogger(namespace)
+
+    self.num_retries = num_retries
+    self.throttle_delay_secs = throttle_delay_secs
+    self.retry_filter = retry_filter
+
+  def __init_subclass__(cls):
+    if cls.load_model is not RemoteModelHandler.load_model:
+      raise Exception(
+          "Cannot override RemoteModelHandler.load_model, "
+          "implement create_client instead.")
+    if cls.run_inference is not RemoteModelHandler.run_inference:
+      raise Exception(
+          "Cannot override RemoteModelHandler.run_inference, "
+          "implement request instead.")
+
+  @abstractmethod
+  def create_client(self) -> ModelT:
+    """Creates the client that is used to make the remote inference request
+    in request(). All relevant arguments should be passed to __init__().
+    """
+    raise NotImplementedError(type(self))
+
+  def load_model(self) -> ModelT:
+    return self.create_client()
+
+  def retry_on_exception(func):
+    @functools.wraps(func)
+    def wrapper(self, *args, **kwargs):
+      return retry.with_exponential_backoff(
+          num_retries=self.num_retries,
+          retry_filter=self.retry_filter)(func)(self, *args, **kwargs)
+
+    return wrapper
+
+  @retry_on_exception
+  def run_inference(
+      self,
+      batch: Sequence[ExampleT],
+      model: ModelT,
+      inference_args: Optional[Dict[str, Any]] = None) -> Iterable[PredictionT]:
+    """Runs inferences on a batch of examples. Calls a remote model for
+    predictions and will retry if a retryable exception is raised.
+
+    Args:
+      batch: A sequence of examples or features.
+      model: The model used to make inferences.
+      inference_args: Extra arguments for models whose inference call requires
+        extra parameters.
+
+    Returns:
+      An Iterable of Predictions.
+    """
+    while self.throttler.throttle_request(time.time() * _MILLISECOND_TO_SECOND):
+      self.logger.info(
+          "Delaying request for %d seconds due to previous failures",
+          self.throttle_delay_secs)
+      time.sleep(self.throttle_delay_secs)
+      self.throttled_secs.inc(self.throttle_delay_secs)
+
+    try:
+      req_time = time.time()
+      predictions = self.request(batch, model, inference_args)
+      self.throttler.successful_request(req_time * _MILLISECOND_TO_SECOND)
+      return predictions
+    except Exception as e:
+      self.logger.error("exception raised as part of request, got %s", e)
+      raise
+
+  @abstractmethod
+  def request(
+      self,
+      batch: Sequence[ExampleT],
+      model: ModelT,
+      inference_args: Optional[Dict[str, Any]] = None) -> Iterable[PredictionT]:
+    """Makes a request to a remote inference service and returns the response.
+    Should raise an exception of some kind if there is an error to enable the
+    retry and client-side throttling logic to work. Returns an iterable of the
+    desired prediction type. This method should return the values directly, as
+    handling return values as a generator can prevent the retry logic from
+    functioning correctly.
+
+    Args:
+      batch: A sequence of examples or features.
+      model: The model used to make inferences.
+      inference_args: Extra arguments for models whose inference call requires
+        extra parameters.
+
+    Returns:
+      An Iterable of Predictions.
+    """
+    raise NotImplementedError(type(self))
+
+
 class _ModelManager:
   """
   A class for efficiently managing copies of multiple models. Will load a
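Editor's note: the sketch below is not part of the diff. It shows how a pipeline author might subclass the new RemoteModelHandler; the _EchoClient class, its predict() behavior, and the endpoint argument are hypothetical stand-ins for a real remote inference client.

from typing import Any, Dict, Iterable, Optional, Sequence

from apache_beam.ml.inference.base import RemoteModelHandler


class _EchoClient:
  """Hypothetical client; a real handler would wrap an HTTP/RPC client here."""
  def __init__(self, endpoint: str):
    self._endpoint = endpoint

  def predict(self, example: int) -> int:
    return example + 1


class EchoRemoteModelHandler(RemoteModelHandler[int, int, _EchoClient]):
  def __init__(self, endpoint: str = 'https://example.com/predict', **kwargs):
    self._endpoint = endpoint
    self._env_vars = kwargs.get('env_vars', {})
    super().__init__(namespace='EchoRemoteModelHandler')

  def create_client(self) -> _EchoClient:
    # Called by load_model(); the returned client is handed back to request()
    # as the `model` argument.
    return _EchoClient(self._endpoint)

  def request(self, batch: Sequence[int], model: _EchoClient,
              inference_args: Optional[Dict[str, Any]] = None) -> Iterable[int]:
    # Return a materialized list rather than a generator so run_inference's
    # retry and throttling logic can observe any exception raised here.
    return [model.predict(example) for example in batch]

With such a handler, examples flow through beam.RunInference(EchoRemoteModelHandler()) exactly as in the tests below.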
diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py
index 31f02c9c61c5..b1dfded99432 100644
--- a/sdks/python/apache_beam/ml/inference/base_test.py
+++ b/sdks/python/apache_beam/ml/inference/base_test.py
@@ -1870,5 +1870,188 @@ def test_model_status_provides_valid_garbage_collection(self):
     self.assertEqual(0, len(tags))
 
 
+def _always_retry(e: Exception) -> bool:
+  return True
+
+
+class FakeRemoteModelHandler(base.RemoteModelHandler[int, int, FakeModel]):
+  def __init__(
+      self,
+      clock=None,
+      min_batch_size=1,
+      max_batch_size=9999,
+      retry_filter=_always_retry,
+      **kwargs):
+    self._fake_clock = clock
+    self._min_batch_size = min_batch_size
+    self._max_batch_size = max_batch_size
+    self._env_vars = kwargs.get('env_vars', {})
+    self._multi_process_shared = multi_process_shared
+    super().__init__(
+        namespace='FakeRemoteModelHandler', retry_filter=retry_filter)
+
+  def create_client(self):
+    return FakeModel()
+
+  def request(self, batch, model, inference_args=None) -> Iterable[int]:
+    responses = []
+    for example in batch:
+      responses.append(model.predict(example))
+    return responses
+
+  def batch_elements_kwargs(self):
+    return {
+        'min_batch_size': self._min_batch_size,
+        'max_batch_size': self._max_batch_size
+    }
+
+
+class FakeAlwaysFailsRemoteModelHandler(base.RemoteModelHandler[int,
+                                                                int,
+                                                                FakeModel]):
+  def __init__(
+      self,
+      clock=None,
+      min_batch_size=1,
+      max_batch_size=9999,
+      retry_filter=_always_retry,
+      **kwargs):
+    self._fake_clock = clock
+    self._min_batch_size = min_batch_size
+    self._max_batch_size = max_batch_size
+    self._env_vars = kwargs.get('env_vars', {})
+    super().__init__(
+        namespace='FakeRemoteModelHandler',
+        retry_filter=retry_filter,
+        num_retries=2,
+        throttle_delay_secs=1)
+
+  def create_client(self):
+    return FakeModel()
+
+  def request(self, batch, model, inference_args=None) -> Iterable[int]:
+    raise Exception
+
+  def batch_elements_kwargs(self):
+    return {
+        'min_batch_size': self._min_batch_size,
+        'max_batch_size': self._max_batch_size
+    }
+
+
+class FakeFailsOnceRemoteModelHandler(base.RemoteModelHandler[int,
+                                                              int,
+                                                              FakeModel]):
+  def __init__(
+      self,
+      clock=None,
+      min_batch_size=1,
+      max_batch_size=9999,
+      retry_filter=_always_retry,
+      **kwargs):
+    self._fake_clock = clock
+    self._min_batch_size = min_batch_size
+    self._max_batch_size = max_batch_size
+    self._env_vars = kwargs.get('env_vars', {})
+    self._should_fail = True
+    super().__init__(
+        namespace='FakeRemoteModelHandler',
+        retry_filter=retry_filter,
+        num_retries=2,
+        throttle_delay_secs=1)
+
+  def create_client(self):
+    return FakeModel()
+
+  def request(self, batch, model, inference_args=None) -> Iterable[int]:
+    if self._should_fail:
+      self._should_fail = False
+      raise Exception
+    else:
+      self._should_fail = True
+      responses = []
+      for example in batch:
+        responses.append(model.predict(example))
+      return responses
+
+  def batch_elements_kwargs(self):
+    return {
+        'min_batch_size': self._min_batch_size,
+        'max_batch_size': self._max_batch_size
+    }
+
+
+class RunInferenceRemoteTest(unittest.TestCase):
+  def test_normal_model_execution(self):
+    with TestPipeline() as pipeline:
+      examples = [1, 5, 3, 10]
+      expected = [example + 1 for example in examples]
+      pcoll = pipeline | 'start' >> beam.Create(examples)
+      actual = pcoll | base.RunInference(FakeRemoteModelHandler())
+      assert_that(actual, equal_to(expected), label='assert:inferences')
+
+  def test_repeated_requests_fail(self):
+    test_pipeline = TestPipeline()
+    with self.assertRaises(Exception):
+      _ = (
+          test_pipeline
+          | beam.Create([1, 2, 3, 4])
+          | base.RunInference(FakeAlwaysFailsRemoteModelHandler()))
+      test_pipeline.run()
+
+  def test_works_on_retry(self):
+    with TestPipeline() as pipeline:
+      examples = [1, 5, 3, 10]
+      expected = [example + 1 for example in examples]
+      pcoll = pipeline | 'start' >> beam.Create(examples)
+      actual = pcoll | base.RunInference(FakeFailsOnceRemoteModelHandler())
+      assert_that(actual, equal_to(expected), label='assert:inferences')
+
+  def test_exception_on_load_model_override(self):
+    with self.assertRaises(Exception):
+
+      class _(base.RemoteModelHandler[int, int, FakeModel]):
+        def __init__(self, clock=None, retry_filter=_always_retry, **kwargs):
+          self._fake_clock = clock
+          self._min_batch_size = 1
+          self._max_batch_size = 1
+          self._env_vars = kwargs.get('env_vars', {})
+          super().__init__(
+              namespace='FakeRemoteModelHandler', retry_filter=retry_filter)
+
+        def load_model(self):
+          return FakeModel()
+
+        def request(self, batch, model, inference_args=None) -> Iterable[int]:
+          responses = []
+          for example in batch:
+            responses.append(model.predict(example))
+          return responses
+
+  def test_exception_on_run_inference_override(self):
+    with self.assertRaises(Exception):
+
+      class _(base.RemoteModelHandler[int, int, FakeModel]):
+        def __init__(self, clock=None, retry_filter=_always_retry, **kwargs):
+          self._fake_clock = clock
+          self._min_batch_size = 1
+          self._max_batch_size = 1
+          self._env_vars = kwargs.get('env_vars', {})
+          super().__init__(
+              namespace='FakeRemoteModelHandler', retry_filter=retry_filter)
+
+        def create_client(self):
+          return FakeModel()
+
+        def run_inference(self,
+                          batch,
+                          model,
+                          inference_args=None) -> Iterable[int]:
+          responses = []
+          for example in batch:
+            responses.append(model.predict(example))
+          return responses
+
+
 if __name__ == '__main__':
   unittest.main()
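Editor's note: the tests above use a retry_filter that always retries; the sketch below (not part of the diff) shows the kind of filter a production handler might pass instead, retrying only exception types assumed here to be transient.

from apache_beam.utils import retry


def _retry_on_transient_errors(exception: Exception) -> bool:
  # Retry connection problems and timeouts; let everything else fail fast.
  return isinstance(exception, (ConnectionError, TimeoutError))


# The same filter a handler would forward to retry.with_exponential_backoff;
# shown here applied directly to a standalone callable.
@retry.with_exponential_backoff(
    num_retries=2, retry_filter=_retry_on_transient_errors)
def _call_remote_service():
  # Hypothetical remote call: a ConnectionError raised here is retried with
  # exponential backoff, while e.g. a ValueError propagates immediately.
  return 42

A subclass would pass the same callable as retry_filter=_retry_on_transient_errors in its super().__init__ call.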
diff --git a/sdks/python/apache_beam/utils/retry.py b/sdks/python/apache_beam/utils/retry.py
index 485fc9d627e9..03e42829b843 100644
--- a/sdks/python/apache_beam/utils/retry.py
+++ b/sdks/python/apache_beam/utils/retry.py
@@ -274,7 +274,9 @@ def with_exponential_backoff(
   The decorator is intended to be used on callables that make HTTP or RPC
   requests that can temporarily timeout or have transient errors. For instance
   the make_http_request() call below will be retried 16 times with exponential
-  backoff and fuzzing of the delay interval (default settings).
+  backoff and fuzzing of the delay interval (default settings). The callable
+  should return values directly instead of yielding them, as generators are not
+  evaluated within the try/except block and will not be retried on exception.
 
   from apache_beam.utils import retry
   # ...
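Editor's note: a minimal illustration (not part of the diff) of the pitfall this docstring addition describes; the function names are made up.

from apache_beam.utils import retry


def _remote_call(item):
  # Hypothetical remote call that may raise a transient error.
  return item + 1


@retry.with_exponential_backoff(num_retries=2)
def fetch_all_generator(items):
  # BAD: calling this returns a generator without running the body, so any
  # exception surfaces later, during iteration, outside the decorator's
  # try/except, and is never retried.
  for item in items:
    yield _remote_call(item)


@retry.with_exponential_backoff(num_retries=2)
def fetch_all_list(items):
  # GOOD: the remote calls happen inside the decorated call, so a failure is
  # caught and retried before anything is returned.
  return [_remote_call(item) for item in items]

This is the same reason RemoteModelHandler.request is documented to return values directly rather than yield them.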