Merged
140 changes: 140 additions & 0 deletions sdks/python/apache_beam/ml/inference/base.py
@@ -27,13 +27,16 @@
collection, sharing model between threads, and batching elements.
"""

import functools
import logging
import os
import pickle
import sys
import threading
import time
import uuid
from abc import ABC
from abc import abstractmethod
from collections import OrderedDict
from collections import defaultdict
from copy import deepcopy
@@ -56,7 +59,10 @@
from typing import Union

import apache_beam as beam
from apache_beam.io.components.adaptive_throttler import AdaptiveThrottler
from apache_beam.metrics.metric import Metrics
from apache_beam.utils import multi_process_shared
from apache_beam.utils import retry
from apache_beam.utils import shared

try:
@@ -67,6 +73,7 @@

_NANOSECOND_TO_MILLISECOND = 1_000_000
_NANOSECOND_TO_MICROSECOND = 1_000
_MILLISECOND_TO_SECOND = 1_000

ModelT = TypeVar('ModelT')
ExampleT = TypeVar('ExampleT')
@@ -339,6 +346,139 @@ def should_garbage_collect_on_timeout(self) -> bool:
return self.share_model_across_processes()


class RemoteModelHandler(ABC, ModelHandler[ExampleT, PredictionT, ModelT]):
"""Has the ability to call a model at a remote endpoint."""
def __init__(
Contributor:
just an idea to use config classes since the init arg list could be long (might be not):

```python
from dataclasses import dataclass, field
from typing import Callable

@dataclass
class RemoteModelHandlerConfig:
    namespace: str = ''
    num_retries: int = 5
    throttle_delay_secs: int = 5
    retry_filter: Callable[[Exception], bool] = field(default_factory=lambda: lambda x: True)
    window_ms: int = 1 * _MILLISECOND_TO_SECOND
    bucket_ms: int = 1 * _MILLISECOND_TO_SECOND
    overload_ratio: float = 2
```
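
For concreteness, a hypothetical sketch (not part of this PR) of how the handler's `__init__` could consume such a config object; the `Optional` default is just one way to avoid a shared mutable default:

```python
from typing import Optional

# Hypothetical only: RemoteModelHandler.__init__ unpacking the suggested
# config object instead of taking each argument individually.
def __init__(self, config: Optional[RemoteModelHandlerConfig] = None):
    config = config or RemoteModelHandlerConfig()
    self.num_retries = config.num_retries
    self.throttle_delay_secs = config.throttle_delay_secs
    self.retry_filter = config.retry_filter
    self.throttled_secs = Metrics.counter(
        config.namespace, "cumulativeThrottlingSeconds")
    self.throttler = AdaptiveThrottler(
        window_ms=config.window_ms,
        bucket_ms=config.bucket_ms,
        overload_ratio=config.overload_ratio)
```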

Contributor:
maybe @shunping has some good ideas.

Contributor (Author):
I know this is largely coming from a perspective of how to surface this for yaml, but is this configuration consideration actually blocking here? We have to surface all of these parameters at some point

Contributor:
I do not think it blocks your work. I just want to call out whether we want to do more like https://github.com/apache/beam/blob/master/sdks/python/apache_beam/ml/anomaly/transforms.py#L61. If the arguments keep growing, a config class might be better.

Contributor (Author):
I mainly ask because you didn't approve the PR

self,
namespace: str = '',
num_retries: int = 5,
throttle_delay_secs: int = 5,
retry_filter: Callable[[Exception], bool] = lambda x: True,
*,
window_ms: int = 1 * _MILLISECOND_TO_SECOND,
bucket_ms: int = 1 * _MILLISECOND_TO_SECOND,
overload_ratio: float = 2):
"""Initializes metrics tracking + an AdaptiveThrottler class for enabling
client-side throttling for remote calls to an inference service.
See https://s.apache.org/beam-client-side-throttling for more details
on the configuration of the throttling and retry
mechanics.

Args:
namespace: the metrics and logging namespace
num_retries: the maximum number of times to retry a request on retriable
errors before failing
throttle_delay_secs: the amount of time to throttle when the client-side
elects to throttle
retry_filter: a function accepting an exception as an argument and
returning a boolean. On a true return, the run_inference call will
be retried. Defaults to always retrying.
window_ms: length of history to consider, in ms, to set throttling.
bucket_ms: granularity of time buckets that we store data in, in ms.
overload_ratio: the target ratio between requests sent and successful
requests. This is "K" in the formula in
https://landing.google.com/sre/book/chapters/handling-overload.html.
"""
# Configure AdaptiveThrottler and throttling metrics for client-side
# throttling behavior.
self.throttled_secs = Metrics.counter(
namespace, "cumulativeThrottlingSeconds")
self.throttler = AdaptiveThrottler(
window_ms=window_ms, bucket_ms=bucket_ms, overload_ratio=overload_ratio)
self.logger = logging.getLogger(namespace)
Contributor:
does this work if namespace is empty?

Contributor (Author):
it will work, the logging and the metric will just not be specific to the model handler (for logging that isn't a big deal, but if you hypothetically had multiple distinct remote model handler classes they would share the same cumulativeThrottlingSeconds counter)

Contributor:
I see. Maybe we can put a docstring about this behavior.
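
To make the shared-counter behavior concrete, a small sketch (the handler namespaces below are made up for illustration):

```python
from apache_beam.metrics.metric import Metrics

# Counters are identified by (namespace, name), so two handlers that both
# fall back to the same default namespace feed the same metric:
a = Metrics.counter('FakeNamespace', 'cumulativeThrottlingSeconds')
b = Metrics.counter('FakeNamespace', 'cumulativeThrottlingSeconds')  # same metric as a

# Passing a distinct namespace per handler keeps throttling metrics separable:
c = Metrics.counter('VertexAIHandler', 'cumulativeThrottlingSeconds')
d = Metrics.counter('BedrockHandler', 'cumulativeThrottlingSeconds')
```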


self.num_retries = num_retries
self.throttle_delay_secs = throttle_delay_secs
self.retry_filter = retry_filter

def __init_subclass__(cls):
if cls.load_model is not RemoteModelHandler.load_model:
raise Exception(
"Cannot override RemoteModelHandler.load_model, ",
"implement create_client instead.")
if cls.run_inference is not RemoteModelHandler.run_inference:
raise Exception(
"Cannot override RemoteModelHandler.run_inference, ",
"implement request instead.")

@abstractmethod
def create_client(self) -> ModelT:
"""Creates the client that is used to make the remote inference request
in request(). All relevant arguments should be passed to __init__().
"""
raise NotImplementedError(type(self))

def load_model(self) -> ModelT:
return self.create_client()

def retry_on_exception(func):
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
return retry.with_exponential_backoff(
num_retries=self.num_retries,
retry_filter=self.retry_filter)(func)(self, *args, **kwargs)

return wrapper

@retry_on_exception
def run_inference(
self,
batch: Sequence[ExampleT],
model: ModelT,
inference_args: Optional[Dict[str, Any]] = None) -> Iterable[PredictionT]:
"""Runs inferences on a batch of examples. Calls a remote model for
predictions and will retry if a retryable exception is raised.

Args:
batch: A sequence of examples or features.
model: The model used to make inferences.
inference_args: Extra arguments for models whose inference call requires
extra parameters.

Returns:
An Iterable of Predictions.
"""
while self.throttler.throttle_request(time.time() * _MILLISECOND_TO_SECOND):
self.logger.info(
"Delaying request for %d seconds due to previous failures",
self.throttle_delay_secs)
time.sleep(self.throttle_delay_secs)
self.throttled_secs.inc(self.throttle_delay_secs)

try:
req_time = time.time()
predictions = self.request(batch, model, inference_args)
self.throttler.successful_request(req_time * _MILLISECOND_TO_SECOND)
return predictions
except Exception as e:
self.logger.error("exception raised as part of request, got %s", e)
raise

@abstractmethod
def request(
self,
batch: Sequence[ExampleT],
model: ModelT,
inference_args: Optional[Dict[str, Any]] = None) -> Iterable[PredictionT]:
"""Makes a request to a remote inference service and returns the response.
Should raise an exception of some kind if there is an error to enable the
retry and client-side throttling logic to work. Returns an iterable of the
desired prediction type. This method should return the values directly, as
handling return values as a generator can prevent the retry logic from
functioning correctly.

Args:
batch: A sequence of examples or features.
model: The model used to make inferences.
inference_args: Extra arguments for models whose inference call requires
extra parameters.

Returns:
An Iterable of Predictions.
"""
raise NotImplementedError(type(self))


class _ModelManager:
"""
A class for efficiently managing copies of multiple models. Will load a
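
To illustrate the intended extension points (a hypothetical sketch, not part of the diff — the endpoint, the `requests` usage, and the response shape are assumptions), a subclass implements only `create_client` and `request` and leaves `load_model` and `run_inference` to the base class:

```python
import requests  # assumed HTTP client for this sketch only

from apache_beam.ml.inference.base import RemoteModelHandler


class HttpJsonModelHandler(RemoteModelHandler[str, str, requests.Session]):
  """Hypothetical handler that forwards each batch to a REST endpoint."""
  def __init__(self, endpoint: str, num_retries: int = 5):
    self._endpoint = endpoint
    self._env_vars = {}
    super().__init__(
        namespace='HttpJsonModelHandler', num_retries=num_retries)

  def create_client(self) -> requests.Session:
    # Invoked via the base class's load_model(); returns the reusable client.
    return requests.Session()

  def request(self, batch, model, inference_args=None):
    # Any exception raised here feeds the retry and client-side throttling
    # logic implemented in RemoteModelHandler.run_inference.
    response = model.post(self._endpoint, json={'instances': list(batch)})
    response.raise_for_status()
    return response.json()['predictions']
```

Such a handler would then be used like any other, e.g. `pcoll | RunInference(HttpJsonModelHandler(endpoint=...))`, with the fake handlers in the tests below following the same pattern.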
183 changes: 183 additions & 0 deletions sdks/python/apache_beam/ml/inference/base_test.py
@@ -1870,5 +1870,188 @@ def test_model_status_provides_valid_garbage_collection(self):
self.assertEqual(0, len(tags))


def _always_retry(e: Exception) -> bool:
return True


class FakeRemoteModelHandler(base.RemoteModelHandler[int, int, FakeModel]):
def __init__(
self,
clock=None,
min_batch_size=1,
max_batch_size=9999,
retry_filter=_always_retry,
**kwargs):
self._fake_clock = clock
self._min_batch_size = min_batch_size
self._max_batch_size = max_batch_size
self._env_vars = kwargs.get('env_vars', {})
self._multi_process_shared = multi_process_shared
super().__init__(
namespace='FakeRemoteModelHandler', retry_filter=retry_filter)

def create_client(self):
return FakeModel()

def request(self, batch, model, inference_args=None) -> Iterable[int]:
responses = []
for example in batch:
responses.append(model.predict(example))
return responses

def batch_elements_kwargs(self):
return {
'min_batch_size': self._min_batch_size,
'max_batch_size': self._max_batch_size
}


class FakeAlwaysFailsRemoteModelHandler(base.RemoteModelHandler[int,
int,
FakeModel]):
def __init__(
self,
clock=None,
min_batch_size=1,
max_batch_size=9999,
retry_filter=_always_retry,
**kwargs):
self._fake_clock = clock
self._min_batch_size = min_batch_size
self._max_batch_size = max_batch_size
self._env_vars = kwargs.get('env_vars', {})
super().__init__(
namespace='FakeRemoteModelHandler',
retry_filter=retry_filter,
num_retries=2,
throttle_delay_secs=1)

def create_client(self):
return FakeModel()

def request(self, batch, model, inference_args=None) -> Iterable[int]:
raise Exception

def batch_elements_kwargs(self):
return {
'min_batch_size': self._min_batch_size,
'max_batch_size': self._max_batch_size
}


class FakeFailsOnceRemoteModelHandler(base.RemoteModelHandler[int,
int,
FakeModel]):
def __init__(
self,
clock=None,
min_batch_size=1,
max_batch_size=9999,
retry_filter=_always_retry,
**kwargs):
self._fake_clock = clock
self._min_batch_size = min_batch_size
self._max_batch_size = max_batch_size
self._env_vars = kwargs.get('env_vars', {})
self._should_fail = True
super().__init__(
namespace='FakeRemoteModelHandler',
retry_filter=retry_filter,
num_retries=2,
throttle_delay_secs=1)

def create_client(self):
return FakeModel()

def request(self, batch, model, inference_args=None) -> Iterable[int]:
if self._should_fail:
self._should_fail = False
raise Exception
else:
self._should_fail = True
responses = []
for example in batch:
responses.append(model.predict(example))
return responses

def batch_elements_kwargs(self):
return {
'min_batch_size': self._min_batch_size,
'max_batch_size': self._max_batch_size
}


class RunInferenceRemoteTest(unittest.TestCase):
def test_normal_model_execution(self):
with TestPipeline() as pipeline:
examples = [1, 5, 3, 10]
expected = [example + 1 for example in examples]
pcoll = pipeline | 'start' >> beam.Create(examples)
actual = pcoll | base.RunInference(FakeRemoteModelHandler())
assert_that(actual, equal_to(expected), label='assert:inferences')

def test_repeated_requests_fail(self):
test_pipeline = TestPipeline()
with self.assertRaises(Exception):
_ = (
test_pipeline
| beam.Create([1, 2, 3, 4])
| base.RunInference(FakeAlwaysFailsRemoteModelHandler()))
test_pipeline.run()

def test_works_on_retry(self):
with TestPipeline() as pipeline:
examples = [1, 5, 3, 10]
expected = [example + 1 for example in examples]
pcoll = pipeline | 'start' >> beam.Create(examples)
actual = pcoll | base.RunInference(FakeFailsOnceRemoteModelHandler())
assert_that(actual, equal_to(expected), label='assert:inferences')

def test_exception_on_load_model_override(self):
with self.assertRaises(Exception):

class _(base.RemoteModelHandler[int, int, FakeModel]):
def __init__(self, clock=None, retry_filter=_always_retry, **kwargs):
self._fake_clock = clock
self._min_batch_size = 1
self._max_batch_size = 1
self._env_vars = kwargs.get('env_vars', {})
super().__init__(
namespace='FakeRemoteModelHandler', retry_filter=retry_filter)

def load_model(self):
return FakeModel()

def request(self, batch, model, inference_args=None) -> Iterable[int]:
responses = []
for example in batch:
responses.append(model.predict(example))
return responses

def test_exception_on_run_inference_override(self):
with self.assertRaises(Exception):

class _(base.RemoteModelHandler[int, int, FakeModel]):
def __init__(self, clock=None, retry_filter=_always_retry, **kwargs):
self._fake_clock = clock
self._min_batch_size = 1
self._max_batch_size = 1
self._env_vars = kwargs.get('env_vars', {})
super().__init__(
namespace='FakeRemoteModelHandler', retry_filter=retry_filter)

def create_client(self):
return FakeModel()

def run_inference(self,
batch,
model,
inference_args=None) -> Iterable[int]:
responses = []
for example in batch:
responses.append(model.predict(example))
return responses


if __name__ == '__main__':
unittest.main()
4 changes: 3 additions & 1 deletion sdks/python/apache_beam/utils/retry.py
@@ -274,7 +274,9 @@ def with_exponential_backoff(
The decorator is intended to be used on callables that make HTTP or RPC
requests that can temporarily timeout or have transient errors. For instance
the make_http_request() call below will be retried 16 times with exponential
backoff and fuzzing of the delay interval (default settings).
backoff and fuzzing of the delay interval (default settings). The callable
should return values directly instead of yielding them, as generators are not
evaluated within the try-catch block and will not be retried on exception.

from apache_beam.utils import retry
# ...
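
The behavior this new docstring note describes can be illustrated with a small sketch (`call_remote` is a placeholder for any flaky RPC, not a Beam API):

```python
from apache_beam.utils import retry


def call_remote(item):
  # Placeholder for an HTTP/RPC call that may fail transiently.
  ...


@retry.with_exponential_backoff(num_retries=2, retry_filter=lambda e: True)
def fetch_all(items):
  # Materializing the list keeps every call_remote() invocation inside the
  # decorated call frame, so a transient failure is caught and retried.
  return [call_remote(item) for item in items]


# A generator version would NOT be retried: the decorated function returns
# immediately, and the remote calls only run when the caller iterates,
# outside the decorator's try/except.
# def fetch_all(items):
#   for item in items:
#     yield call_remote(item)
```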