diff --git a/sdks/python/apache_beam/examples/inference/gemini_text_classification.py b/sdks/python/apache_beam/examples/inference/gemini_text_classification.py new file mode 100644 index 000000000000..e82f407374a7 --- /dev/null +++ b/sdks/python/apache_beam/examples/inference/gemini_text_classification.py @@ -0,0 +1,116 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""A sample pipeline using the RunInference API to classify text using an LLM. +This pipeline creates a set of prompts, sends them to a Gemini service, and +then returns the predictions from the classifier model. This example uses the +gemini-2.0-flash-001 model. +""" + +import argparse +import logging +from collections.abc import Iterable + +import apache_beam as beam +from apache_beam.ml.inference.base import PredictionResult +from apache_beam.ml.inference.base import RunInference +from apache_beam.ml.inference.gemini_inference import GeminiModelHandler +from apache_beam.ml.inference.gemini_inference import generate_from_string +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options import SetupOptions +from apache_beam.runners.runner import PipelineResult + + +def parse_known_args(argv): + """Parses args for the workflow.""" + parser = argparse.ArgumentParser() + parser.add_argument( + '--output', + dest='output', + type=str, + required=True, + help='Path to save output predictions.') + parser.add_argument( + '--api_key', + dest='api_key', + type=str, + required=False, + help='Gemini Developer API key.') + parser.add_argument( + '--cloud_project', + dest='project', + type=str, + required=False, + help='GCP Project') + parser.add_argument( + '--cloud_region', + dest='location', + type=str, + required=False, + help='GCP location for the Endpoint') + return parser.parse_known_args(argv) + + +class PostProcessor(beam.DoFn): + def process(self, element: PredictionResult) -> Iterable[str]: + yield "Input: " + str(element.example) + " Output: " + str( + element.inference[1][0].content.parts[0].text) + + +def run( + argv=None, save_main_session=True, test_pipeline=None) -> PipelineResult: + """ + Args: + argv: Command line arguments defined for this example. + save_main_session: Used for internal testing. + test_pipeline: Used for internal testing.
+ """ + known_args, pipeline_args = parse_known_args(argv) + pipeline_options = PipelineOptions(pipeline_args) + pipeline_options.view_as(SetupOptions).save_main_session = save_main_session + model_handler = GeminiModelHandler( + model_name='gemini-2.0-flash-001', + request_fn=generate_from_string, + api_key=known_args.api_key, + project=known_args.project, + location=known_args.location) + + pipeline = test_pipeline + if not test_pipeline: + pipeline = beam.Pipeline(options=pipeline_options) + + prompts = [ + "What is 5+2?", + "Who is the protagonist of Lord of the Rings?", + "What is the air-speed velocity of a laden swallow?" + ] + + read_prompts = pipeline | "Get prompt" >> beam.Create(prompts) + predictions = read_prompts | "RunInference" >> RunInference(model_handler) + processed = predictions | "PostProcess" >> beam.ParDo(PostProcessor()) + _ = processed | "PrintOutput" >> beam.Map(print) + _ = processed | "WriteOutput" >> beam.io.WriteToText( + known_args.output, shard_name_template='', append_trailing_newlines=True) + + result = pipeline.run() + result.wait_until_finish() + return result + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + run() diff --git a/sdks/python/apache_beam/ml/inference/gemini_inference.py b/sdks/python/apache_beam/ml/inference/gemini_inference.py new file mode 100644 index 000000000000..fd1a7b0f7ac9 --- /dev/null +++ b/sdks/python/apache_beam/ml/inference/gemini_inference.py @@ -0,0 +1,172 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import logging +from collections.abc import Callable +from collections.abc import Iterable +from collections.abc import Sequence +from typing import Any +from typing import Optional + +from google import genai +from google.genai import errors + +from apache_beam.ml.inference import utils +from apache_beam.ml.inference.base import PredictionResult +from apache_beam.ml.inference.base import RemoteModelHandler + +LOGGER = logging.getLogger("GeminiModelHandler") + + +def _retry_on_appropriate_service_error(exception: Exception) -> bool: + """ + Retry filter that returns True if a returned HTTP error code is 5xx or 429. + This is used to retry remote requests that fail, most notably 429 + (throttling by the service) + + Args: + exception: the returned exception encountered during the request/response + loop. + + Returns: + boolean indication whether or not the exception is a ServerError (5xx) or + a 429 error. 
+ """ + if not isinstance(exception, errors.APIError): + return False + return exception.code == 429 or exception.code >= 500 + + +def generate_from_string( + model_name: str, + batch: Sequence[str], + model: genai.Client, + inference_args: dict[str, Any]): + return model.models.generate_content( + model=model_name, contents=batch, **inference_args) + + +class GeminiModelHandler(RemoteModelHandler[Any, PredictionResult, + genai.Client]): + def __init__( + self, + model_name: str, + request_fn: Callable[[str, Sequence[Any], genai.Client, dict[str, Any]], + Any], + api_key: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + *, + min_batch_size: Optional[int] = None, + max_batch_size: Optional[int] = None, + max_batch_duration_secs: Optional[int] = None, + **kwargs): + """Implementation of the ModelHandler interface for Google Gemini. + **NOTE:** This API and its implementation are under development and + do not provide backward compatibility guarantees. + Gemini can be accessed through either the Vertex AI API or the Gemini + Developer API, and this handler chooses which to connect to based upon + the arguments provided. As written, this model handler operates solely on + string input. + + Args: + model_name: the Gemini model to send the request to + request_fn: the function to use to send the request. Should take the + model name and the parameters from request() and return the responses + from Gemini. The class will handle bundling the inputs and responses + together. + api_key: the Gemini Developer API key to use for the requests. Setting + this parameter sends requests for this job to the Gemini Developer API. + If this paramter is provided, do not set the project or location + parameters. + project: the GCP project to use for Vertex AI requests. Setting this + parameter routes requests to Vertex AI. If this paramter is provided, + location must also be provided and api_key should not be set. + location: the GCP project to use for Vertex AI requests. Setting this + parameter routes requests to Vertex AI. If this paramter is provided, + project must also be provided and api_key should not be set. + min_batch_size: optional. the minimum batch size to use when batching + inputs. + max_batch_size: optional. the maximum batch size to use when batching + inputs. + max_batch_duration_secs: optional. the maximum amount of time to buffer + a batch before emitting; used in streaming contexts. + """ + self._batching_kwargs = {} + self._env_vars = kwargs.get('env_vars', {}) + if min_batch_size is not None: + self._batching_kwargs["min_batch_size"] = min_batch_size + if max_batch_size is not None: + self._batching_kwargs["max_batch_size"] = max_batch_size + if max_batch_duration_secs is not None: + self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs + + self.model_name = model_name + self.request_fn = request_fn + + if api_key: + if project or location: + raise ValueError("project and location must be None if api_key is set") + self.api_key = api_key + self.use_vertex = False + else: + if project is None or location is None: + raise ValueError( + "project and location must both be provided if api_key is None") + self.project = project + self.location = location + self.use_vertex = True + + super().__init__( + namespace='GeminiModelHandler', + retry_filter=_retry_on_appropriate_service_error, + **kwargs) + + def create_client(self) -> genai.Client: + """Creates the GenAI client used to send requests. 
Creates a version for + the Vertex AI API or the Gemini Developer API based on the arguments + provided when the GeminiModelHandler class is instantiated. + """ + if self.use_vertex: + return genai.Client( + vertexai=True, project=self.project, location=self.location) + return genai.Client(api_key=self.api_key) + + def request( + self, + batch: Sequence[Any], + model: genai.Client, + inference_args: Optional[dict[str, Any]] = None + ) -> Iterable[PredictionResult]: + """Sends a prediction request containing a batch of inputs to a Gemini + service and matches each input with the prediction response from + the endpoint as an iterable of PredictionResults. + + Args: + batch: a sequence of any values to be passed to the Gemini service. + Should be inputs accepted by the provided inference function. + model: a genai.Client object configured to access the desired service. + inference_args: any additional arguments to send as part of the + prediction request. + + Returns: + An iterable of PredictionResults. + """ + if inference_args is None: + inference_args = {} + responses = self.request_fn(self.model_name, batch, model, inference_args) + return utils._convert_to_result(batch, responses, self.model_name) diff --git a/sdks/python/apache_beam/ml/inference/gemini_inference_it_test.py b/sdks/python/apache_beam/ml/inference/gemini_inference_it_test.py new file mode 100644 index 000000000000..d0cd9c236d67 --- /dev/null +++ b/sdks/python/apache_beam/ml/inference/gemini_inference_it_test.py @@ -0,0 +1,58 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# + +"""End-to-End test for Gemini Remote Inference""" + +import logging +import unittest +import uuid + +import pytest + +from apache_beam.io.filesystems import FileSystems +from apache_beam.testing.test_pipeline import TestPipeline + +# pylint: disable=ungrouped-imports +try: + from apache_beam.examples.inference import gemini_text_classification +except ImportError as e: + raise unittest.SkipTest("Gemini model handler dependencies are not installed") + +_OUTPUT_DIR = "gs://apache-beam-ml/testing/outputs/gemini" +_TEST_PROJECT = "apache-beam-testing" +_TEST_REGION = "us-central1" + + +class GeminiInference(unittest.TestCase): + @pytest.mark.gemini_postcommit + def test_gemini_text_classification(self): + output_file = '/'.join([_OUTPUT_DIR, str(uuid.uuid4()), 'output.txt']) + + test_pipeline = TestPipeline(is_integration_test=True) + extra_opts = { + 'output': output_file, + 'cloud_project': _TEST_PROJECT, + 'cloud_region': _TEST_REGION + } + gemini_text_classification.run( + test_pipeline.get_full_options_as_args(**extra_opts)) + self.assertEqual(FileSystems().exists(output_file), True) + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.DEBUG) + unittest.main() diff --git a/sdks/python/apache_beam/ml/inference/gemini_inference_test.py b/sdks/python/apache_beam/ml/inference/gemini_inference_test.py new file mode 100644 index 000000000000..bb6127a32872 --- /dev/null +++ b/sdks/python/apache_beam/ml/inference/gemini_inference_test.py @@ -0,0 +1,84 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# pytype: skip-file + +import unittest + +try: + from apache_beam.ml.inference.gemini_inference import _retry_on_appropriate_service_error + from apache_beam.ml.inference.gemini_inference import GeminiModelHandler + from apache_beam.ml.inference.gemini_inference import generate_from_string + from google.genai import errors +except ImportError: + raise unittest.SkipTest('Gemini dependencies are not installed') + + +class RetryOnClientErrorTest(unittest.TestCase): + def test_retry_on_client_error_positive(self): + e = errors.APIError(code=429, response_json={}) + self.assertTrue(_retry_on_appropriate_service_error(e)) + + def test_retry_on_client_error_negative(self): + e = errors.APIError(code=404, response_json={}) + self.assertFalse(_retry_on_appropriate_service_error(e)) + + def test_retry_on_server_error(self): + e = errors.APIError(code=501, response_json={}) + self.assertTrue(_retry_on_appropriate_service_error(e)) + + +class ModelHandlerArgConditions(unittest.TestCase): + def test_all_params_set(self): + self.assertRaises( + ValueError, + GeminiModelHandler, + model_name="gemini-model-123", + request_fn=generate_from_string, + api_key="123456789", + project="testproject", + location="us-central1", + ) + + def test_missing_vertex_location_param(self): + self.assertRaises( + ValueError, + GeminiModelHandler, + model_name="gemini-model-123", + request_fn=generate_from_string, + project="testproject", + ) + + def test_missing_vertex_project_param(self): + self.assertRaises( + ValueError, + GeminiModelHandler, + model_name="gemini-model-123", + request_fn=generate_from_string, + location="us-central1", + ) + + def test_missing_all_params(self): + self.assertRaises( + ValueError, + GeminiModelHandler, + model_name="gemini-model-123", + request_fn=generate_from_string, + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/sdks/python/apache_beam/ml/inference/gemini_tests_requirements.txt b/sdks/python/apache_beam/ml/inference/gemini_tests_requirements.txt new file mode 100644 index 000000000000..722ed40777b7 --- /dev/null +++ b/sdks/python/apache_beam/ml/inference/gemini_tests_requirements.txt @@ -0,0 +1,18 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +google-genai>=1.16.1 \ No newline at end of file diff --git a/sdks/python/pytest.ini b/sdks/python/pytest.ini index b62a44aa25e7..2b53441927d9 100644 --- a/sdks/python/pytest.ini +++ b/sdks/python/pytest.ini @@ -69,6 +69,7 @@ markers = uses_testcontainer: tests that use testcontainers. uses_mock_api: tests that uses the mock API cluster. uses_feast: tests that uses feast in some way + gemini_postcommit: gemini postcommits that need additional deps. # Default timeout intended for unit tests. 
# If certain tests need a different value, please see the docs on how to diff --git a/sdks/python/test-suites/dataflow/common.gradle b/sdks/python/test-suites/dataflow/common.gradle index f1f7158165cd..c18b12654f78 100644 --- a/sdks/python/test-suites/dataflow/common.gradle +++ b/sdks/python/test-suites/dataflow/common.gradle @@ -509,6 +509,33 @@ task vertexAIInferenceTest { } } +// Gemini RunInference IT tests +task geminiInferenceTest { + dependsOn 'initializeForDataflowJob' + dependsOn ':sdks:python:sdist' + def requirementsFile = "${rootDir}/sdks/python/apache_beam/ml/inference/gemini_tests_requirements.txt" + doFirst { + exec { + executable 'sh' + args '-c', ". ${envdir}/bin/activate && pip install -r $requirementsFile" + } + } + doLast { + def testOpts = basicTestOpts + def argMap = [ + "test_opts": testOpts, + "suite": "GeminiTests-df-py${pythonVersionSuffix}", + "collect": "gemini_postcommit", + "runner": "TestDataflowRunner", + "requirements_file": "$requirementsFile" + ] + def cmdArgs = mapToArgString(argMap) + exec { + executable 'sh' + args '-c', ". ${envdir}/bin/activate && ${runScriptsDir}/run_integration_test.sh $cmdArgs" + } + } +} task installTFTRequirements { dependsOn 'initializeForDataflowJob' @@ -576,6 +603,7 @@ project.tasks.register("inferencePostCommitIT") { // model is fixed. // 'tensorRTtests', 'vertexAIInferenceTest', + 'geminiInferenceTest', 'mockAPITests', ] }
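For reviewers, a minimal sketch of how the new handler is wired up, condensed from the example pipeline in this diff. The model name, prompt, and import paths come straight from the patch; the API key, project, and region values are hypothetical placeholders, and this assumes the google-genai dependency from gemini_tests_requirements.txt is installed.

```python
import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.gemini_inference import GeminiModelHandler
from apache_beam.ml.inference.gemini_inference import generate_from_string

# Gemini Developer API mode: set api_key and leave project/location unset.
handler = GeminiModelHandler(
    model_name='gemini-2.0-flash-001',
    request_fn=generate_from_string,
    api_key='YOUR_API_KEY')  # hypothetical placeholder

# Vertex AI mode instead: set project and location, and no api_key.
# handler = GeminiModelHandler(
#     model_name='gemini-2.0-flash-001',
#     request_fn=generate_from_string,
#     project='my-gcp-project',  # hypothetical project id
#     location='us-central1')

with beam.Pipeline() as pipeline:
  _ = (
      pipeline
      | beam.Create(['What is 5+2?'])
      | RunInference(handler)
      | beam.Map(print))
```

Setting api_key together with project or location raises a ValueError, as exercised by the unit tests in gemini_inference_test.py.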
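The request_fn contract documented in the handler is (model_name, batch, client, inference_args) -> responses, so callers are not limited to generate_from_string. Below is a sketch of a custom request function that pins a generation config; it relies on types.GenerateContentConfig from the google-genai package, and the specific field shown (temperature) is an illustrative choice, not part of this patch.

```python
from collections.abc import Sequence
from typing import Any

from google import genai
from google.genai import types


def generate_with_config(
    model_name: str,
    batch: Sequence[str],
    model: genai.Client,
    inference_args: dict[str, Any]):
  # Same contract as generate_from_string, but fixes a generation config
  # (temperature=0.0) so repeated pipeline runs give more stable outputs.
  config = types.GenerateContentConfig(temperature=0.0)
  return model.models.generate_content(
      model=model_name, contents=batch, config=config, **inference_args)
```

Such a function is passed at construction time via request_fn=generate_with_config; the handler still takes care of bundling inputs and responses into PredictionResults.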
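The constructor docstring also documents min_batch_size, max_batch_size, and max_batch_duration_secs for controlling how RunInference groups prompts into requests. A construction sketch under those documented knobs; the values below are arbitrary tuning examples, not recommendations, and the exact batching behavior is delegated to the RemoteModelHandler base class.

```python
# Batching knobs are keyword-only; all values below are illustrative.
batched_handler = GeminiModelHandler(
    model_name='gemini-2.0-flash-001',
    request_fn=generate_from_string,
    api_key='YOUR_API_KEY',  # hypothetical placeholder
    min_batch_size=1,
    max_batch_size=8,  # cap prompts sent per generate_content call
    max_batch_duration_secs=10)  # flush partial batches when streaming
```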