Introduce RemoteModelHandler abstract base class #34379
Changes from all commits: 0837518, a666394, 3609147, 35d44b9, 6e674b8, be288a9, 181b218, d5a424e, 36af373, 59a969a
@@ -27,13 +27,16 @@

```python
collection, sharing model between threads, and batching elements.
"""

import functools
import logging
import os
import pickle
import sys
import threading
import time
import uuid
from abc import ABC
from abc import abstractmethod
from collections import OrderedDict
from collections import defaultdict
from copy import deepcopy
```
@@ -56,7 +59,10 @@

```python
from typing import Union

import apache_beam as beam
from apache_beam.io.components.adaptive_throttler import AdaptiveThrottler
from apache_beam.metrics.metric import Metrics
from apache_beam.utils import multi_process_shared
from apache_beam.utils import retry
from apache_beam.utils import shared

try:
```
@@ -67,6 +73,7 @@

```python
_NANOSECOND_TO_MILLISECOND = 1_000_000
_NANOSECOND_TO_MICROSECOND = 1_000
_MILLISECOND_TO_SECOND = 1_000

ModelT = TypeVar('ModelT')
ExampleT = TypeVar('ExampleT')
```
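For orientation, here is a sketch (with illustrative parameter values) of the client-side throttling protocol that these constants and the `AdaptiveThrottler` import support; `run_inference` in the new class below follows the same pattern:

```python
import time

from apache_beam.io.components.adaptive_throttler import AdaptiveThrottler

throttler = AdaptiveThrottler(window_ms=1000, bucket_ms=1000, overload_ratio=2)

# time.time() is in seconds; the throttler works in milliseconds, hence the
# * 1_000 (the module's _MILLISECOND_TO_SECOND constant).
now_ms = time.time() * 1_000
if throttler.throttle_request(now_ms):
  pass  # recent failure history says to back off before calling the service
else:
  # ... make the remote call here ...
  throttler.successful_request(now_ms)  # feed the success back into the window
```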
@@ -339,6 +346,139 @@ def should_garbage_collect_on_timeout(self) -> bool:

```python
    return self.share_model_across_processes()


class RemoteModelHandler(ABC, ModelHandler[ExampleT, PredictionT, ModelT]):
  """Has the ability to call a model at a remote endpoint."""
  def __init__(
      self,
      namespace: str = '',
      num_retries: int = 5,
      throttle_delay_secs: int = 5,
      retry_filter: Callable[[Exception], bool] = lambda x: True,
      *,
      window_ms: int = 1 * _MILLISECOND_TO_SECOND,
      bucket_ms: int = 1 * _MILLISECOND_TO_SECOND,
      overload_ratio: float = 2):
    """Initializes metrics tracking + an AdaptiveThrottler class for enabling
    client-side throttling for remote calls to an inference service. See
    https://s.apache.org/beam-client-side-throttling for more details on the
    configuration of the throttling and retry mechanics.

    Args:
      namespace: the metrics and logging namespace
      num_retries: the maximum number of times to retry a request on retriable
        errors before failing
      throttle_delay_secs: the amount of time to throttle when the client-side
        elects to throttle
      retry_filter: a function accepting an exception as an argument and
        returning a boolean. On a true return, the run_inference call will
        be retried. Defaults to always retrying.
      window_ms: length of history to consider, in ms, to set throttling.
      bucket_ms: granularity of time buckets that we store data in, in ms.
      overload_ratio: the target ratio between requests sent and successful
        requests. This is "K" in the formula in
        https://landing.google.com/sre/book/chapters/handling-overload.html.
    """
    # Configure AdaptiveThrottler and throttling metrics for client-side
    # throttling behavior.
    self.throttled_secs = Metrics.counter(
        namespace, "cumulativeThrottlingSeconds")
    self.throttler = AdaptiveThrottler(
        window_ms=window_ms, bucket_ms=bucket_ms, overload_ratio=overload_ratio)
    self.logger = logging.getLogger(namespace)
```
Contributor: does this work if namespace is empty?

Author: it will work, the logging and the metric will just not be specific to the model handler (for logging that isn't a big deal, but if you hypothetically had multiple distinct remote model handler classes they would share the same cumulativeThrottlingSeconds counter)

Contributor: I see. Maybe we can put a docstring about this behavior.
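To make the reviewer's point concrete: Beam metrics are keyed by (namespace, name), so with the default `namespace=''` every handler subclass would increment the same counter. A small sketch (handler names are illustrative):

```python
from apache_beam.metrics.metric import Metrics

# Distinct namespaces keep per-handler throttling metrics separate:
a_secs = Metrics.counter("HandlerA", "cumulativeThrottlingSeconds")
b_secs = Metrics.counter("HandlerB", "cumulativeThrottlingSeconds")

# With the default empty namespace, two different handler classes would
# share one "cumulativeThrottlingSeconds" counter and be indistinguishable.
shared_secs = Metrics.counter("", "cumulativeThrottlingSeconds")
```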
The diff continues:

```python
    self.num_retries = num_retries
    self.throttle_delay_secs = throttle_delay_secs
    self.retry_filter = retry_filter

  def __init_subclass__(cls):
    if cls.load_model is not RemoteModelHandler.load_model:
      raise Exception(
          "Cannot override RemoteModelHandler.load_model, "
          "implement create_client instead.")
    if cls.run_inference is not RemoteModelHandler.run_inference:
      raise Exception(
          "Cannot override RemoteModelHandler.run_inference, "
          "implement request instead.")

  @abstractmethod
  def create_client(self) -> ModelT:
    """Creates the client that is used to make the remote inference request
    in request(). All relevant arguments should be passed to __init__().
    """
    raise NotImplementedError(type(self))

  def load_model(self) -> ModelT:
    return self.create_client()

  # Used as a decorator within the class body; not meant to be called on
  # instances.
  def retry_on_exception(func):
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
      return retry.with_exponential_backoff(
          num_retries=self.num_retries,
          retry_filter=self.retry_filter)(func)(self, *args, **kwargs)

    return wrapper

  @retry_on_exception
  def run_inference(
      self,
      batch: Sequence[ExampleT],
      model: ModelT,
      inference_args: Optional[Dict[str, Any]] = None) -> Iterable[PredictionT]:
    """Runs inferences on a batch of examples. Calls a remote model for
    predictions and will retry if a retryable exception is raised.

    Args:
      batch: A sequence of examples or features.
      model: The model used to make inferences.
      inference_args: Extra arguments for models whose inference call requires
        extra parameters.

    Returns:
      An Iterable of Predictions.
    """
    while self.throttler.throttle_request(time.time() * _MILLISECOND_TO_SECOND):
      self.logger.info(
          "Delaying request for %d seconds due to previous failures",
          self.throttle_delay_secs)
      time.sleep(self.throttle_delay_secs)
      self.throttled_secs.inc(self.throttle_delay_secs)

    try:
      req_time = time.time()
      predictions = self.request(batch, model, inference_args)
      self.throttler.successful_request(req_time * _MILLISECOND_TO_SECOND)
      return predictions
    except Exception as e:
      self.logger.error("exception raised as part of request, got %s", e)
      raise

  @abstractmethod
  def request(
      self,
      batch: Sequence[ExampleT],
      model: ModelT,
      inference_args: Optional[Dict[str, Any]] = None) -> Iterable[PredictionT]:
    """Makes a request to a remote inference service and returns the response.
    Should raise an exception of some kind if there is an error to enable the
    retry and client-side throttling logic to work. Returns an iterable of the
    desired prediction type. This method should return the values directly, as
    handling return values as a generator can prevent the retry logic from
    functioning correctly.

    Args:
      batch: A sequence of examples or features.
      model: The model used to make inferences.
      inference_args: Extra arguments for models whose inference call requires
        extra parameters.

    Returns:
      An Iterable of Predictions.
    """
    raise NotImplementedError(type(self))


class _ModelManager:
  """
  A class for efficiently managing copies of multiple models. Will load a
```
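To show how the two abstract hooks fit together, here is a minimal sketch of a concrete subclass. It assumes the new class is importable from apache_beam.ml.inference.base alongside ModelHandler; `_EchoClient` and `EchoModelHandler` are hypothetical stand-ins, not part of this PR:

```python
from typing import Any, Dict, Iterable, Optional, Sequence

from apache_beam.ml.inference.base import RemoteModelHandler


class _EchoClient:
  """Hypothetical stand-in for a real remote inference client."""
  def predict(self, instances):
    return [{"echo": x} for x in instances]


class EchoModelHandler(RemoteModelHandler[str, Dict[str, Any], _EchoClient]):
  def __init__(self, endpoint: str):
    self.endpoint = endpoint  # illustrative; a real client would use this
    # The base __init__ wires up the throttler, metrics, and retry settings.
    super().__init__(namespace='EchoModelHandler', num_retries=3)

  def create_client(self) -> _EchoClient:
    # Called by the base class's load_model(); load_model itself must not be
    # overridden (enforced by __init_subclass__).
    return _EchoClient()

  def request(
      self,
      batch: Sequence[str],
      model: _EchoClient,
      inference_args: Optional[Dict[str, Any]] = None
  ) -> Iterable[Dict[str, Any]]:
    # Return a materialized list rather than yielding lazily: a generator
    # would defer exceptions past run_inference's retry/throttle handling.
    return model.predict(list(batch))
```

The inherited run_inference (which also must not be overridden) then wraps every request call in the throttling loop and exponential-backoff retries shown above.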
Contributor: just an idea: use config classes, since the init arg list could get long (or might not).

Contributor: maybe @shunping has some good ideas.

Author: I know this is largely coming from the perspective of how to surface this for YAML, but is this configuration consideration actually blocking here? We have to surface all of these parameters at some point.

Contributor: I do not think it blocks your work. I just want to call out whether we want to do more like https://github.com/apache/beam/blob/master/sdks/python/apache_beam/ml/anomaly/transforms.py#L61. If the arguments keep growing, a config class might be better.
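For reference, a minimal sketch of the config-object pattern being discussed (names here are illustrative, not part of this PR):

```python
from dataclasses import dataclass
from typing import Callable


@dataclass(frozen=True)
class RemoteRetryConfig:
  """Hypothetical bundle for the retry/throttle constructor arguments."""
  num_retries: int = 5
  throttle_delay_secs: int = 5
  retry_filter: Callable[[Exception], bool] = lambda _: True


# A handler could then accept one config object instead of a growing arg
# list, e.g. MyRemoteModelHandler(retry_config=RemoteRetryConfig(num_retries=3)).
```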
Author: I mainly ask because you didn't approve the PR.

Contributor: Oh, #34379 (comment)