From 6ff743944f6be8235db8b9c1d2dada31e3410a52 Mon Sep 17 00:00:00 2001 From: Pavan Kumar Date: Fri, 27 Jun 2025 10:46:58 +0100 Subject: [PATCH 1/2] Attempt2: Fix mypy in gcp generative_model --- .../google/cloud/hooks/vertex_ai/generative_model.py | 11 ++++++----- .../cloud/operators/vertex_ai/generative_model.py | 2 +- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py index 3ed03b66b63a2..be15a282905f0 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py @@ -35,8 +35,9 @@ from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook if TYPE_CHECKING: - from google.cloud.aiplatform_v1 import types as types_v1 from google.cloud.aiplatform_v1beta1 import types as types_v1beta1 + from vertexai.generative_models._generative_models import _GenerativeModel + from vertexai.tuning._supervised_tuning import SupervisedTuningJob class GenerativeModelHook(GoogleBaseHook): @@ -50,7 +51,7 @@ def get_text_embedding_model(self, pretrained_model: str): def get_generative_model( self, pretrained_model: str, - system_instruction: str | None = None, + system_instruction: Any | None = None, generation_config: dict | None = None, safety_settings: dict | None = None, tools: list | None = None, @@ -82,7 +83,7 @@ def get_eval_task( def get_cached_context_model( self, cached_content_name: str, - ) -> preview_generative_model: + ) -> _GenerativeModel: """Return a Generative Model with Cached Context.""" cached_content = CachedContent(cached_content_name=cached_content_name) @@ -167,7 +168,7 @@ def supervised_fine_tuning_train( adapter_size: Literal[1, 4, 8, 16] | None = None, learning_rate_multiplier: float | None = None, project_id: str = PROVIDE_PROJECT_ID, - ) -> types_v1.TuningJob: + ) -> SupervisedTuningJob: """ Use the Supervised Fine Tuning API to create a tuning job. @@ -300,7 +301,7 @@ def create_cached_content( model_name: str, location: str, ttl_hours: float = 1, - system_instruction: str | None = None, + system_instruction: Any | None = None, contents: list[Any] | None = None, display_name: str | None = None, project_id: str = PROVIDE_PROJECT_ID, diff --git a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py index 686e1674078ac..20257dad196c8 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py @@ -473,7 +473,7 @@ def __init__( project_id: str, location: str, model_name: str, - system_instruction: str | None = None, + system_instruction: Any | None = None, contents: list[Any] | None = None, ttl_hours: float = 1, display_name: str | None = None, From d99748af7cb8793f49e72ac359f826fcb2c7bbac Mon Sep 17 00:00:00 2001 From: Pavan Kumar Date: Fri, 27 Jun 2025 11:26:42 +0100 Subject: [PATCH 2/2] Remove private class imports --- .../google/cloud/hooks/vertex_ai/generative_model.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py index be15a282905f0..871ab81d10d1e 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py @@ -36,8 +36,6 @@ if TYPE_CHECKING: from google.cloud.aiplatform_v1beta1 import types as types_v1beta1 - from vertexai.generative_models._generative_models import _GenerativeModel - from vertexai.tuning._supervised_tuning import SupervisedTuningJob class GenerativeModelHook(GoogleBaseHook): @@ -83,7 +81,7 @@ def get_eval_task( def get_cached_context_model( self, cached_content_name: str, - ) -> _GenerativeModel: + ) -> Any: """Return a Generative Model with Cached Context.""" cached_content = CachedContent(cached_content_name=cached_content_name) @@ -168,7 +166,7 @@ def supervised_fine_tuning_train( adapter_size: Literal[1, 4, 8, 16] | None = None, learning_rate_multiplier: float | None = None, project_id: str = PROVIDE_PROJECT_ID, - ) -> SupervisedTuningJob: + ) -> Any: """ Use the Supervised Fine Tuning API to create a tuning job.