From 0837518906e11d36df478bf9cbd88671c959c58d Mon Sep 17 00:00:00 2001
From: Jack McCluskey
Date: Wed, 19 Mar 2025 16:54:00 -0400
Subject: [PATCH 01/10] Stash first functioning version

---
 sdks/python/apache_beam/ml/inference/base.py | 97 +++++++++++++++++++
 .../apache_beam/ml/inference/base_test.py    | 62 ++++++++++++
 2 files changed, 159 insertions(+)

diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index e4c3e4cab5e0..e9f5f61d61ad 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -56,7 +56,10 @@ from typing import Union
 
 import apache_beam as beam
+from apache_beam.io.components.adaptive_throttler import AdaptiveThrottler
+from apache_beam.metrics.metric import Metrics
 from apache_beam.utils import multi_process_shared
+from apache_beam.utils import retry
 from apache_beam.utils import shared
 
 try:
@@ -67,6 +70,7 @@
 
 _NANOSECOND_TO_MILLISECOND = 1_000_000
 _NANOSECOND_TO_MICROSECOND = 1_000
+_MILLISECOND_TO_SECOND = 1_000
 
 ModelT = TypeVar('ModelT')
 ExampleT = TypeVar('ExampleT')
@@ -339,6 +343,99 @@ def should_garbage_collect_on_timeout(self) -> bool:
     return self.share_model_across_processes()
 
 
+class RemoteModelHandler(ModelHandler[ExampleT, PredictionT, ModelT]):
+  """Has the ability to call a model at a remote endpoint."""
+  def __init__(
+      self,
+      namespace: str = '',
+      num_retries: int = 5,
+      throttle_delay_secs: int = 5,
+      retry_filter: Callable[[Exception], bool] = lambda x: True,
+      *,
+      window_ms: int = 1 * _MILLISECOND_TO_SECOND,
+      bucket_ms: int = 1 * _MILLISECOND_TO_SECOND,
+      overload_ratio: float = 2):
+    """Initializes metrics tracking + an AdaptiveThrottler class for enabling
+    client-side throttling for remote calls to an inference service.
+
+    Args:
+      namespace: the metrics and logging namespace
+      num_retries: the maximum number of times to retry a request on retriable
+        errors before failing
+      throttle_delay_secs: the amount of time to throttle when the client-side
+        elects to throttle
+      retry_filter: a function accepting an exception as an argument and
+        returning a boolean. On a true return, the run_inference call will
+        be retried. Defaults to always retrying.
+      window_ms: length of history to consider, in ms, to set throttling.
+      bucket_ms: granularity of time buckets that we store data in, in ms.
+      overload_ratio: the target ratio between requests sent and successful
+        requests. This is "K" in the formula in
+        https://landing.google.com/sre/book/chapters/handling-overload.html.
+    """
+    # Configure AdaptiveThrottler and throttling metrics for client-side
+    # throttling behavior.
+    # See https://docs.google.com/document/d/1ePorJGZnLbNCmLD9mR7iFYOdPsyDA1rDnTpYnbdrzSU/edit?usp=sharing
+    # for more details.
+    self.throttled_secs = Metrics.counter(
+        namespace, "cumulativeThrottlingSeconds")
+    self.throttler = AdaptiveThrottler(
+        window_ms=window_ms, bucket_ms=bucket_ms, overload_ratio=overload_ratio)
+    self.logger = logging.getLogger(namespace)
+
+    self.num_retries = num_retries
+    self.throttle_delay_secs = throttle_delay_secs
+    self.retry_filter = retry_filter
+
+  def retry_on_exception(func):
+    def wrapper(self, *args, **kwargs):
+      return retry.with_exponential_backoff(
+          num_retries=self.num_retries, retry_filter=self.retry_filter)(
+              func)(self, *args, **kwargs)
+    return wrapper
+
+  @retry_on_exception
+  def run_inference(
+      self,
+      batch: Sequence[ExampleT],
+      model: ModelT,
+      inference_args: Optional[Dict[str, Any]] = None) -> Iterable[PredictionT]:
+    """Runs inferences on a batch of examples. Calls a remote model for
+    predictions and will retry if a retryable exception is raised.
+
+    Args:
+      batch: A sequence of examples or features.
+      model: The model used to make inferences.
+      inference_args: Extra arguments for models whose inference call requires
+        extra parameters.
+
+    Returns:
+      An Iterable of Predictions.
+    """
+    while self.throttler.throttle_request(time.time() * _MILLISECOND_TO_SECOND):
+      self.logger.info(
+          "Delaying request for %d seconds due to previous failures",
+          self.throttle_delay_secs)
+      time.sleep(self.throttle_delay_secs)
+      self.throttled_secs.inc(self.throttle_delay_secs)
+
+    try:
+      req_time = time.time()
+      predictions = self.request(batch, model, inference_args)
+      self.throttler.successful_request(req_time * _MILLISECOND_TO_SECOND)
+      return predictions
+    except Exception as e:
+      self.logger.error("exception raised as part of request, got %s", e)
+      raise
+
+  def request(
+      self,
+      batch: Sequence[ExampleT],
+      model: ModelT,
+      inference_args: Optional[Dict[str, Any]] = None) -> Iterable[PredictionT]:
+    raise NotImplementedError(type(self))
+
+
 class _ModelManager:
   """
   A class for efficiently managing copies of multiple models. Will load a
diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py
index 31f02c9c61c5..e21bf65aa11a 100644
--- a/sdks/python/apache_beam/ml/inference/base_test.py
+++ b/sdks/python/apache_beam/ml/inference/base_test.py
@@ -1870,5 +1870,67 @@ def test_model_status_provides_valid_garbage_collection(self):
     self.assertEqual(0, len(tags))
 
 
+class FakeRemoteModelHandler(base.RemoteModelHandler[int, int, FakeModel]):
+  def __init__(
+      self,
+      clock=None,
+      min_batch_size=1,
+      max_batch_size=9999,
+      max_copies=1,
+      num_bytes_per_element=None,
+      retry_filter = lambda x: True,
+      **kwargs):
+    self._fake_clock = clock
+    self._min_batch_size = min_batch_size
+    self._max_batch_size = max_batch_size
+    self._env_vars = kwargs.get('env_vars', {})
+    self._multi_process_shared = multi_process_shared
+    self._max_copies = max_copies
+    self._num_bytes_per_element = num_bytes_per_element
+    super().__init__(namespace='FakeRemoteModelHandler', retry_filter=retry_filter)
+
+  def load_model(self):
+    return FakeModel()
+
+  def request(self, batch, model, inference_args=None) -> Iterable[int]:
+    for example in batch:
+      yield model.predict(example)
+
+  def update_model_path(self, model_path: Optional[str] = None):
+    pass
+
+  def batch_elements_kwargs(self):
+    return {
+        'min_batch_size': self._min_batch_size,
+        'max_batch_size': self._max_batch_size
+    }
+
+  def share_model_across_processes(self):
+    return self._multi_process_shared
+
+  def model_copies(self):
+    return self._max_copies
+
+  def get_num_bytes(self, batch: Sequence[int]) -> int:
+    if self._num_bytes_per_element:
+      return self._num_bytes_per_element * len(batch)
+    return super().get_num_bytes(batch)
+
+
+class RunInferenceRemoteTest(unittest.TestCase):
+  def test_normal_model_execution(self):
+    with TestPipeline() as pipeline:
+      examples = [1, 5, 3, 10]
+      expected = [example + 1 for example in examples]
+      pcoll = pipeline | 'start' >> beam.Create(examples)
+      actual = pcoll | base.RunInference(FakeRemoteModelHandler())
+      assert_that(actual, equal_to(expected), label='assert:inferences')
+
+  def test_works_on_retry(self):
+    pass
+
+  def test_repeated_requests_fail(self):
+    pass
+
 if __name__ == '__main__':
   unittest.main()

From a666394d8be8b44ef5630a1fafb2bb6fe9b3271a Mon Sep 17 00:00:00 2001
From: Jack McCluskey
Date: Thu, 20 Mar 2025 10:50:36 -0400
Subject: [PATCH 02/10] unit tests

---
 .../apache_beam/ml/inference/base_test.py | 71 +++++++++++++------
 1 file changed, 48 insertions(+), 23 deletions(-)

diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py
index e21bf65aa11a..275a39b4afde 100644
--- a/sdks/python/apache_beam/ml/inference/base_test.py
+++ b/sdks/python/apache_beam/ml/inference/base_test.py
@@ -45,6 +45,7 @@ from apache_beam.transforms import window
 from apache_beam.transforms.periodicsequence import TimestampedValue
 from apache_beam.utils import multi_process_shared
+from apache_beam.utils import retry
 
 
 class FakeModel:
@@ -1870,51 +1871,71 @@ def test_model_status_provides_valid_garbage_collection(self):
     self.assertEqual(0, len(tags))
 
 
-class FakeRemoteModelHandler(base.RemoteModelHandler[int, int, FakeModel]):
+def _always_retry(e: Exception) -> bool:
+  return True
+
+
+class FakeRemoteModelHandler(base.RemoteModelHandler[int, int, FakeModel]):
   def __init__(
       self,
       clock=None,
       min_batch_size=1,
       max_batch_size=9999,
-      max_copies=1,
-      num_bytes_per_element=None,
-      retry_filter = lambda x: True,
+      retry_filter=_always_retry,
       **kwargs):
     self._fake_clock = clock
     self._min_batch_size = min_batch_size
     self._max_batch_size = max_batch_size
     self._env_vars = kwargs.get('env_vars', {})
     self._multi_process_shared = multi_process_shared
-    self._max_copies = max_copies
-    self._num_bytes_per_element = num_bytes_per_element
-    super().__init__(namespace='FakeRemoteModelHandler', retry_filter=retry_filter)
+    super().__init__(
+        namespace='FakeRemoteModelHandler', retry_filter=retry_filter)
 
   def load_model(self):
     return FakeModel()
-
+
   def request(self, batch, model, inference_args=None) -> Iterable[int]:
     for example in batch:
       yield model.predict(example)
 
-  def update_model_path(self, model_path: Optional[str] = None):
-    pass
-
   def batch_elements_kwargs(self):
     return {
         'min_batch_size': self._min_batch_size,
        'max_batch_size': self._max_batch_size
     }
 
-  def share_model_across_processes(self):
-    return self._multi_process_shared
 
-  def model_copies(self):
-    return self._max_copies
+class FakeAlwaysFailsRemoteModelHandler(base.RemoteModelHandler[int,
+                                                                int,
+                                                                FakeModel]):
+  def __init__(
+      self,
+      clock=None,
+      min_batch_size=1,
+      max_batch_size=9999,
+      retry_filter=_always_retry,
+      **kwargs):
+    self._fake_clock = clock
+    self._min_batch_size = min_batch_size
+    self._max_batch_size = max_batch_size
+    self._env_vars = kwargs.get('env_vars', {})
+    super().__init__(
+        namespace='FakeRemoteModelHandler',
+        retry_filter=retry_filter,
+        num_retries=2,
+        throttle_delay_secs=1)
 
-  def get_num_bytes(self, batch: Sequence[int]) -> int:
-    if self._num_bytes_per_element:
-      return self._num_bytes_per_element * len(batch)
-    return super().get_num_bytes(batch)
+  def load_model(self):
+    return FakeModel()
+
+  def request(self, batch, model, inference_args=None) -> Iterable[int]:
+    raise Exception
+
+  def batch_elements_kwargs(self):
+    return {
+        'min_batch_size': self._min_batch_size,
+        'max_batch_size': self._max_batch_size
+    }
 
 
 class RunInferenceRemoteTest(unittest.TestCase):
@@ -1926,11 +1947,15 @@ def test_normal_model_execution(self):
       actual = pcoll | base.RunInference(FakeRemoteModelHandler())
       assert_that(actual, equal_to(expected), label='assert:inferences')
 
-  def test_works_on_retry(self):
-    pass
-
   def test_repeated_requests_fail(self):
-    pass
+    test_pipeline = TestPipeline()
+    with self.assertRaises(Exception) as e:
+      _ = (
+          test_pipeline
+          | beam.Create([1, 2, 3, 4])
+          | base.RunInference(FakeAlwaysFailsRemoteModelHandler()))
+      test_pipeline.run()
+
 
 if __name__ == '__main__':
   unittest.main()

From 36091478f21b72b500e1367b7346e082a30cf238 Mon Sep 17 00:00:00 2001
From: Jack McCluskey
Date: Fri, 21 Mar 2025 09:47:36 -0400
Subject: [PATCH 03/10] Unit tests, documentation

---
 sdks/python/apache_beam/ml/inference/base.py | 18 ++++++
 .../apache_beam/ml/inference/base_test.py    | 55 ++++++++++++++++++-
 sdks/python/apache_beam/utils/retry.py       |  4 +-
 3 files changed, 75 insertions(+), 2 deletions(-)

diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index e9f5f61d61ad..4387c863c97a 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -27,6 +27,7 @@
 collection, sharing model between threads, and batching elements.
""" +import functools import logging import os import pickle @@ -388,6 +389,7 @@ def __init__( self.retry_filter = retry_filter def retry_on_exception(func): + @functools.wraps(func) def wrapper(self, *args, **kwargs): return retry.with_exponential_backoff( num_retries=self.num_retries, retry_filter=self.retry_filter)( @@ -433,6 +435,22 @@ def request( batch: Sequence[ExampleT], model: ModelT, inference_args: Optional[Dict[str, Any]] = None) -> Iterable[PredictionT]: + """Makes a request to a remote inference service and returns the response. + Should raise an exception of some kind if there is an error to enable the + retry and client-side throttling logic to work. Returns an iterable of the + desired prediction type. This method should return the values directly, as + handling return values as a generator can prevent the retry logic from + functioning correctly. + + Args: + batch: A sequence of examples or features. + model: The model used to make inferences. + inference_args: Extra arguments for models whose inference call requires + extra parameters. + + Returns: + An Iterable of Predictions. + """ raise NotImplementedError(type(self)) diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py index 275a39b4afde..39d6000bb284 100644 --- a/sdks/python/apache_beam/ml/inference/base_test.py +++ b/sdks/python/apache_beam/ml/inference/base_test.py @@ -1895,8 +1895,10 @@ def load_model(self): return FakeModel() def request(self, batch, model, inference_args=None) -> Iterable[int]: + responses = [] for example in batch: - yield model.predict(example) + responses.append(model.predict(example)) + return responses def batch_elements_kwargs(self): return { @@ -1936,6 +1938,49 @@ def batch_elements_kwargs(self): 'min_batch_size': self._min_batch_size, 'max_batch_size': self._max_batch_size } + + +class FakeFailsOnceRemoteModelHandler(base.RemoteModelHandler[int, + int, + FakeModel]): + def __init__( + self, + clock=None, + min_batch_size=1, + max_batch_size=9999, + retry_filter=_always_retry, + **kwargs): + self._fake_clock = clock + self._min_batch_size = min_batch_size + self._max_batch_size = max_batch_size + self._env_vars = kwargs.get('env_vars', {}) + self._should_fail = True + super().__init__( + namespace='FakeRemoteModelHandler', + retry_filter=retry_filter, + num_retries=2, + throttle_delay_secs=1) + + def load_model(self): + return FakeModel() + + def request(self, batch, model, inference_args=None) -> Iterable[int]: + if self._should_fail: + self._should_fail = False + raise Exception + else: + self._should_fail = True + responses = [] + for example in batch: + responses.append(model.predict(example)) + return responses + + + def batch_elements_kwargs(self): + return { + 'min_batch_size': self._min_batch_size, + 'max_batch_size': self._max_batch_size + } class RunInferenceRemoteTest(unittest.TestCase): @@ -1956,6 +2001,14 @@ def test_repeated_requests_fail(self): | base.RunInference(FakeAlwaysFailsRemoteModelHandler())) test_pipeline.run() + def test_works_on_retry(self): + with TestPipeline() as pipeline: + examples = [1, 5, 3, 10] + expected = [example + 1 for example in examples] + pcoll = pipeline | 'start' >> beam.Create(examples) + actual = pcoll | base.RunInference(FakeFailsOnceRemoteModelHandler()) + assert_that(actual, equal_to(expected), label='assert:inferences') + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/utils/retry.py b/sdks/python/apache_beam/utils/retry.py index 
--- a/sdks/python/apache_beam/utils/retry.py
+++ b/sdks/python/apache_beam/utils/retry.py
@@ -274,7 +274,9 @@ def with_exponential_backoff(
   The decorator is intended to be used on callables that make HTTP or RPC
   requests that can temporarily timeout or have transient errors. For instance
   the make_http_request() call below will be retried 16 times with exponential
-  backoff and fuzzing of the delay interval (default settings).
+  backoff and fuzzing of the delay interval (default settings). The callable
+  should return values directly instead of yielding them, as generators are not
+  evaluated within the try-catch block and will not be retried on exception.
 
   from apache_beam.utils import retry
   # ...

From 35d44b9618ac18f776bb658f0938e40e54778a4d Mon Sep 17 00:00:00 2001
From: Jack McCluskey
Date: Fri, 21 Mar 2025 10:49:01 -0400
Subject: [PATCH 04/10] linting and formatting

---
 sdks/python/apache_beam/ml/inference/base.py      |  5 +++--
 sdks/python/apache_beam/ml/inference/base_test.py | 10 ++++------
 2 files changed, 7 insertions(+), 8 deletions(-)

diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index 4387c863c97a..fda53ae922aa 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -392,8 +392,9 @@ def retry_on_exception(func):
     @functools.wraps(func)
     def wrapper(self, *args, **kwargs):
       return retry.with_exponential_backoff(
-          num_retries=self.num_retries, retry_filter=self.retry_filter)(
-              func)(self, *args, **kwargs)
+          num_retries=self.num_retries,
+          retry_filter=self.retry_filter)(func)(self, *args, **kwargs)
+
     return wrapper
 
   @retry_on_exception
diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py
index 39d6000bb284..697ce33f57ab 100644
--- a/sdks/python/apache_beam/ml/inference/base_test.py
+++ b/sdks/python/apache_beam/ml/inference/base_test.py
@@ -45,7 +45,6 @@ from apache_beam.transforms import window
 from apache_beam.transforms.periodicsequence import TimestampedValue
 from apache_beam.utils import multi_process_shared
-from apache_beam.utils import retry
 
 
 class FakeModel:
@@ -1938,11 +1937,11 @@ def batch_elements_kwargs(self):
       'min_batch_size': self._min_batch_size,
       'max_batch_size': self._max_batch_size
     }
-
+
 
 class FakeFailsOnceRemoteModelHandler(base.RemoteModelHandler[int,
-                                                             int,
-                                                             FakeModel]):
+                                                              int,
+                                                              FakeModel]):
   def __init__(
       self,
       clock=None,
@@ -1975,7 +1974,6 @@ def request(self, batch, model, inference_args=None) -> Iterable[int]:
         responses.append(model.predict(example))
       return responses
 
-
   def batch_elements_kwargs(self):
     return {
         'min_batch_size': self._min_batch_size,
@@ -1994,7 +1992,7 @@ def test_normal_model_execution(self):
 
   def test_repeated_requests_fail(self):
     test_pipeline = TestPipeline()
-    with self.assertRaises(Exception) as e:
+    with self.assertRaises(Exception):
       _ = (
           test_pipeline
           | beam.Create([1, 2, 3, 4])

From 6e674b88aeaed3c7dddba301e96b5d3540c84c33 Mon Sep 17 00:00:00 2001
From: Jack McCluskey
Date: Fri, 21 Mar 2025 13:47:43 -0400
Subject: [PATCH 05/10] adjust line-too-long

---
 sdks/python/apache_beam/ml/inference/base.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index fda53ae922aa..67619f2ad5f5 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -358,6 +358,9 @@ def __init__(
       overload_ratio: float = 2):
     """Initializes metrics tracking + an AdaptiveThrottler class for enabling
     client-side throttling for remote calls to an inference service.
+    See https://docs.google.com/document/d/1ePorJGZnLbNCmLD9mR7iFYOdPsyDA1rDnTpYnbdrzSU/edit?usp=sharing
+    for more details. on the configuration of the throttling and retry
+    mechanics.
 
     Args:
       namespace: the metrics and logging namespace
@@ -376,7 +379,7 @@
     """
     # Configure AdaptiveThrottler and throttling metrics for client-side
     # throttling behavior.
-    # See https://docs.google.com/document/d/1ePorJGZnLbNCmLD9mR7iFYOdPsyDA1rDnTpYnbdrzSU/edit?usp=sharing
+    # See https://docs.google.com/document/d/1ePorJGZnLbNCmLD9mR7iFYOdPsyDA1rDnTpYnbdrzSU/edit?usp=sharing
     # for more details.
     self.throttled_secs = Metrics.counter(
         namespace, "cumulativeThrottlingSeconds")

From be288a9fffcc42924838d9dc42d08edea8da9599 Mon Sep 17 00:00:00 2001
From: Jack McCluskey
Date: Mon, 24 Mar 2025 10:15:08 -0400
Subject: [PATCH 06/10] move external doc to shortlink

---
 sdks/python/apache_beam/ml/inference/base.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index 67619f2ad5f5..ae3f70fe6ec7 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -358,8 +358,8 @@ def __init__(
       overload_ratio: float = 2):
     """Initializes metrics tracking + an AdaptiveThrottler class for enabling
     client-side throttling for remote calls to an inference service.
-    See https://docs.google.com/document/d/1ePorJGZnLbNCmLD9mR7iFYOdPsyDA1rDnTpYnbdrzSU/edit?usp=sharing
-    for more details. on the configuration of the throttling and retry
+    See https://s.apache.org/beam-client-side-throttling for more details
+    on the configuration of the throttling and retry
     mechanics.
 
     Args:
@@ -379,8 +379,6 @@ def __init__(
     """
     # Configure AdaptiveThrottler and throttling metrics for client-side
     # throttling behavior.
-    # See https://docs.google.com/document/d/1ePorJGZnLbNCmLD9mR7iFYOdPsyDA1rDnTpYnbdrzSU/edit?usp=sharing
-    # for more details.
     self.throttled_secs = Metrics.counter(
         namespace, "cumulativeThrottlingSeconds")
     self.throttler = AdaptiveThrottler(

From 181b218da5a4e744a7d38e1fb8c4b9fd1bfa868d Mon Sep 17 00:00:00 2001
From: Jack McCluskey
Date: Tue, 25 Mar 2025 10:45:27 -0400
Subject: [PATCH 07/10] create_client + __init_subclass__

---
 sdks/python/apache_beam/ml/inference/base.py | 19 +++++++
 .../apache_beam/ml/inference/base_test.py    | 53 +++++++++++++++++--
 2 files changed, 69 insertions(+), 3 deletions(-)

diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index ae3f70fe6ec7..739f4c4463bf 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -389,6 +389,25 @@ def __init__(
     self.throttle_delay_secs = throttle_delay_secs
     self.retry_filter = retry_filter
 
+  def __init_subclass__(cls):
+    if cls.load_model is not RemoteModelHandler.load_model:
+      raise Exception(
+          "Cannot override RemoteModelHandler.load_model, ",
+          "implement load_client instead.")
+    if cls.run_inference is not RemoteModelHandler.run_inference:
+      raise Exception(
+          "Cannot override RemoteModelHandler.run_inference, ",
+          "implement request instead.")
+
+  def create_client(self) -> ModelT:
+    """Creates the client that is used to make the remote inference request
+    in request(). All relevant arguments should be passed to __init__().
+ """ + raise NotImplementedError(type(self)) + + def load_model(self) -> ModelT: + return self.create_client() + def retry_on_exception(func): @functools.wraps(func) def wrapper(self, *args, **kwargs): diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py index 697ce33f57ab..fdd07c5c7a70 100644 --- a/sdks/python/apache_beam/ml/inference/base_test.py +++ b/sdks/python/apache_beam/ml/inference/base_test.py @@ -1890,7 +1890,7 @@ def __init__( super().__init__( namespace='FakeRemoteModelHandler', retry_filter=retry_filter) - def load_model(self): + def create_client(self): return FakeModel() def request(self, batch, model, inference_args=None) -> Iterable[int]: @@ -1926,7 +1926,7 @@ def __init__( num_retries=2, throttle_delay_secs=1) - def load_model(self): + def create_client(self): return FakeModel() def request(self, batch, model, inference_args=None) -> Iterable[int]: @@ -1960,7 +1960,7 @@ def __init__( num_retries=2, throttle_delay_secs=1) - def load_model(self): + def create_client(self): return FakeModel() def request(self, batch, model, inference_args=None) -> Iterable[int]: @@ -2007,6 +2007,53 @@ def test_works_on_retry(self): actual = pcoll | base.RunInference(FakeFailsOnceRemoteModelHandler()) assert_that(actual, equal_to(expected), label='assert:inferences') + def test_exception_on_run_inference_override(self): + with self.assertRaises(Exception): + + class FakeRemoteModelHandlerOverridesLoadModel( + base.RemoteModelHandler[int, int, FakeModel]): + def __init__(self, clock=None, retry_filter=_always_retry, **kwargs): + self._fake_clock = clock + self._min_batch_size = 1 + self._max_batch_size = 1 + self._env_vars = kwargs.get('env_vars', {}) + super().__init__( + namespace='FakeRemoteModelHandler', retry_filter=retry_filter) + + def load_model(self): + return FakeModel() + + def request(self, batch, model, inference_args=None) -> Iterable[int]: + responses = [] + for example in batch: + responses.append(model.predict(example)) + return responses + + def test_exception_on_run_inference_override(self): + with self.assertRaises(Exception): + + class FakeRemoteModelHandlerOverridesRunInference( + base.RemoteModelHandler[int, int, FakeModel]): + def __init__(self, clock=None, retry_filter=_always_retry, **kwargs): + self._fake_clock = clock + self._min_batch_size = 1 + self._max_batch_size = 1 + self._env_vars = kwargs.get('env_vars', {}) + super().__init__( + namespace='FakeRemoteModelHandler', retry_filter=retry_filter) + + def create_client(self): + return FakeModel() + + def run_inference(self, + batch, + model, + inference_args=None) -> Iterable[int]: + responses = [] + for example in batch: + responses.append(model.predict(example)) + return responses + if __name__ == '__main__': unittest.main() From d5a424e99c58dc52fff3eab1db9aefb6eb60b222 Mon Sep 17 00:00:00 2001 From: Jack McCluskey Date: Tue, 25 Mar 2025 11:16:39 -0400 Subject: [PATCH 08/10] fix exception message to match symbol change --- sdks/python/apache_beam/ml/inference/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index 739f4c4463bf..1a0ba3944de4 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -393,7 +393,7 @@ def __init_subclass__(cls): if cls.load_model is not RemoteModelHandler.load_model: raise Exception( "Cannot override RemoteModelHandler.load_model, ", - "implement 
+          "implement create_client instead.")
     if cls.run_inference is not RemoteModelHandler.run_inference:
       raise Exception(
           "Cannot override RemoteModelHandler.run_inference, ",
           "implement request instead.")

From 36af373da98f2affabcd040e0b40ca2dc4de527c Mon Sep 17 00:00:00 2001
From: Jack McCluskey
Date: Tue, 25 Mar 2025 13:03:07 -0400
Subject: [PATCH 09/10] Fix linting

---
 sdks/python/apache_beam/ml/inference/base_test.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py
index fdd07c5c7a70..b1dfded99432 100644
--- a/sdks/python/apache_beam/ml/inference/base_test.py
+++ b/sdks/python/apache_beam/ml/inference/base_test.py
@@ -2007,11 +2007,10 @@ def test_works_on_retry(self):
       actual = pcoll | base.RunInference(FakeFailsOnceRemoteModelHandler())
       assert_that(actual, equal_to(expected), label='assert:inferences')
 
-  def test_exception_on_run_inference_override(self):
+  def test_exception_on_load_model_override(self):
     with self.assertRaises(Exception):
 
-      class FakeRemoteModelHandlerOverridesLoadModel(
-          base.RemoteModelHandler[int, int, FakeModel]):
+      class _(base.RemoteModelHandler[int, int, FakeModel]):
         def __init__(self, clock=None, retry_filter=_always_retry, **kwargs):
           self._fake_clock = clock
           self._min_batch_size = 1
@@ -2032,8 +2031,7 @@ def request(self, batch, model, inference_args=None) -> Iterable[int]:
   def test_exception_on_run_inference_override(self):
     with self.assertRaises(Exception):
 
-      class FakeRemoteModelHandlerOverridesRunInference(
-          base.RemoteModelHandler[int, int, FakeModel]):
+      class _(base.RemoteModelHandler[int, int, FakeModel]):
         def __init__(self, clock=None, retry_filter=_always_retry, **kwargs):
           self._fake_clock = clock
           self._min_batch_size = 1

From 59a969a0a8c5b9d08488b7948fc662fd17f1aa25 Mon Sep 17 00:00:00 2001
From: Jack McCluskey
Date: Wed, 2 Apr 2025 13:00:38 -0400
Subject: [PATCH 10/10] make RemoteModelHandler an explicit abstract base
 class

---
 sdks/python/apache_beam/ml/inference/base.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index 1a0ba3944de4..117a73de1b9a 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -35,6 +35,8 @@
 import threading
 import time
 import uuid
+from abc import ABC
+from abc import abstractmethod
 from collections import OrderedDict
 from collections import defaultdict
 from copy import deepcopy
@@ -344,7 +346,7 @@ def should_garbage_collect_on_timeout(self) -> bool:
     return self.share_model_across_processes()
 
 
-class RemoteModelHandler(ModelHandler[ExampleT, PredictionT, ModelT]):
+class RemoteModelHandler(ABC, ModelHandler[ExampleT, PredictionT, ModelT]):
   """Has the ability to call a model at a remote endpoint."""
   def __init__(
       self,
@@ -399,6 +401,7 @@ def __init_subclass__(cls):
           "Cannot override RemoteModelHandler.run_inference, ",
           "implement request instead.")
 
+  @abstractmethod
   def create_client(self) -> ModelT:
     """Creates the client that is used to make the remote inference request
     in request(). All relevant arguments should be passed to __init__().
@@ -451,6 +454,7 @@ def run_inference(
       self.logger.error("exception raised as part of request, got %s", e)
       raise
 
+  @abstractmethod
   def request(
     self,
       batch: Sequence[ExampleT],
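
Note on the subclassing contract introduced by this series: after
[PATCH 10/10], user-facing handlers implement create_client() and request(),
while load_model() and run_inference() are sealed by __init_subclass__ and the
abstract methods. The sketch below is illustrative only and is not part of the
patches; EchoClient and its predict() method are hypothetical stand-ins for a
real remote inference client.

    from typing import Iterable

    import apache_beam as beam
    from apache_beam.ml.inference import base


    class EchoClient:
      """Hypothetical client for a remote inference service."""
      def predict(self, example: int) -> int:
        return example + 1


    class EchoModelHandler(base.RemoteModelHandler[int, int, EchoClient]):
      def __init__(self, **kwargs):
        self._env_vars = kwargs.get('env_vars', {})
        super().__init__(
            namespace='EchoModelHandler',
            num_retries=3,
            throttle_delay_secs=2,
            retry_filter=lambda e: isinstance(e, ConnectionError))

      def create_client(self) -> EchoClient:
        # Invoked by the sealed load_model(); builds the remote client.
        return EchoClient()

      def request(self, batch, model, inference_args=None) -> Iterable[int]:
        # Return a realized list, not a generator, so retries can fire.
        return [model.predict(example) for example in batch]


    with beam.Pipeline() as p:
      _ = (
          p
          | beam.Create([1, 2, 3])
          | base.RunInference(EchoModelHandler())
          | beam.Map(print))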
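
Why request() must return values directly (see the retry.py docstring change
in [PATCH 03/10]): a generator body does not execute until it is iterated, so
an exception raised per element escapes only after with_exponential_backoff
has already returned, outside the retried call. A minimal sketch of the
failure mode, independent of Beam:

    def flaky(batch):
      # This body runs lazily, at iteration time, not at call time.
      for example in batch:
        raise ConnectionError("raised only once iteration starts")
        yield example

    gen = flaky([1, 2, 3])  # No exception here; a retry wrapper sees success.
    # list(gen)             # ConnectionError surfaces here, past the wrapper.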
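
On the AdaptiveThrottler parameters: requests and successes are counted in
bucket_ms-sized buckets over the trailing window_ms, and overload_ratio is the
"K" from the SRE handbook chapter linked in the docstring. A hedged sketch of
the rejection decision, under the assumption that the bookkeeping in
apache_beam.io.components.adaptive_throttler follows that formula (the exact
implementation lives there, not in this series):

    import random

    def should_throttle(requests: int, successes: int, k: float) -> bool:
      # Rejection probability rises as requests outpace k * successes
      # within the trailing window.
      reject_probability = max(
          0.0, (requests - k * successes) / (requests + 1))
      return random.uniform(0, 1) < reject_probability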