From a43e59521b5ab67ee1ff9806988864567f5e7e7e Mon Sep 17 00:00:00 2001 From: anand Date: Tue, 3 Jun 2025 12:11:19 -0500 Subject: [PATCH 01/10] added retry logic for sync and async requests to snowflake api --- .../snowflake/hooks/snowflake_sql_api.py | 104 ++++++++++++++++-- .../snowflake/operators/snowflake.py | 3 + 2 files changed, 96 insertions(+), 11 deletions(-) diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py index 7e4ef8dfe0288..6e84843c84e92 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py +++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py @@ -25,6 +25,7 @@ import aiohttp import requests +import tenacity from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization @@ -75,12 +76,23 @@ def __init__( snowflake_conn_id: str, token_life_time: timedelta = LIFETIME, token_renewal_delta: timedelta = RENEWAL_DELTA, + api_retry_args: dict[Any, Any] | None = None, # Optional retry arguments passed to tenacity.retry *args: Any, **kwargs: Any, ): self.snowflake_conn_id = snowflake_conn_id self.token_life_time = token_life_time self.token_renewal_delta = token_renewal_delta + self.retry_config = { + "retry": tenacity.retry_if_exception(self._should_retry_on_error), + "wait": tenacity.wait_exponential(multiplier=1, min=1, max=60), + "stop": tenacity.stop_after_attempt(5), + "before_sleep": tenacity.before_sleep_log(self.log, logger_level=20), # INFO level + "reraise": True, + } + if api_retry_args: + self.retry_config.update(api_retry_args) + super().__init__(snowflake_conn_id, *args, **kwargs) self.private_key: Any = None @@ -168,9 +180,8 @@ def execute_query( "query_tag": query_tag, }, } - response = requests.post(url, json=data, headers=headers, params=params) try: - response.raise_for_status() + response = self._make_api_call_with_retries("POST", url, headers, params, data) except requests.exceptions.HTTPError as e: # pragma: no cover msg = f"Response: {e.response.content.decode()} Status Code: {e.response.status_code}" raise AirflowException(msg) @@ -295,9 +306,7 @@ def get_sql_api_query_status(self, query_id: str) -> dict[str, str | list[str]]: """ self.log.info("Retrieving status for query id %s", query_id) header, params, url = self.get_request_url_header_params(query_id) - response = requests.get(url, params=params, headers=header) - status_code = response.status_code - resp = response.json() + status_code, resp = self._make_api_call_with_retries("GET", url, header, params) return self._process_response(status_code, resp) async def get_sql_api_query_status_async(self, query_id: str) -> dict[str, str | list[str]]: @@ -308,10 +317,83 @@ async def get_sql_api_query_status_async(self, query_id: str) -> dict[str, str | """ self.log.info("Retrieving status for query id %s", query_id) header, params, url = self.get_request_url_header_params(query_id) - async with ( - aiohttp.ClientSession(headers=header) as session, - session.get(url, params=params) as response, + status_code, resp = await self._make_api_call_with_retries_async("GET", url, header, params) + return self._process_response(status_code, resp) + + @staticmethod + def _should_retry_on_error(exception) -> bool: + """ + Determine if the exception should trigger a retry based on error type and status code. + + Retries on HTTP errors 429 (Too Many Requests), 503 (Service Unavailable), + and 504 (Gateway Timeout) as recommended by Snowflake error handling docs. + Retries on connection errors and timeouts. + + :param exception: The exception to check + :return: True if the request should be retried, False otherwise + """ + if isinstance(exception, (requests.exceptions.HTTPError, aiohttp.ClientResponseError)): + return exception.response.status_code in [429, 503, 504] + if isinstance( + exception, + (requests.exceptions.ConnectionError, requests.exceptions.Timeout, aiohttp.ClientConnectionError), ): - status_code = response.status - resp = await response.json() - return self._process_response(status_code, resp) + return True + return False + + def _make_api_call_with_retries(self, method, url, headers, params=None, data=None): + """ + Make an API call to the Snowflake SQL API with retry logic for specific HTTP errors. + + Error handling implemented based on Snowflake error handling docs: + https://docs.snowflake.com/en/developer-guide/sql-api/handling-errors + + :param method: The HTTP method to use for the API call. + :param url: The URL for the API endpoint. + :param headers: The headers to include in the API call. + :param params: (Optional) The query parameters to include in the API call. + :param data: (Optional) The data to include in the API call. + :return: The response object from the API call. + """ + + @tenacity.retry(**self.retry_config) # Use the retry args defined in constructor + def _make_request(): + if method.upper() == "GET": + response = requests.get(url, headers=headers, params=params) + elif method.upper() == "POST": + response = requests.post(url, headers=headers, params=params, json=data) + else: + raise ValueError(f"Unsupported HTTP method: {method}") + response.raise_for_status() + return response.status_code, response.json() + + return _make_request() + + async def _make_api_call_with_retries_async(self, method, url, headers, params=None): + """ + Make an API call to the Snowflake SQL API asynchronously with retry logic for specific HTTP errors. + + Error handling implemented based on Snowflake error handling docs: + https://docs.snowflake.com/en/developer-guide/sql-api/handling-errors + + :param method: The HTTP method to use for the API call. Only GET is supported as is synchronous. + :param url: The URL for the API endpoint. + :param headers: The headers to include in the API call. + :param params: (Optional) The query parameters to include in the API call. + :param data: (Optional) The data to include in the API call. + :return: The response object from the API call. + """ + + @tenacity.retry(**self.retry_config) + async def _make_request(): + async with aiohttp.ClientSession(headers=headers) as session: + if method.upper() == "GET": + async with session.get(url, params=params) as response: + response.raise_for_status() + # Return status and json content for async processing + content = await response.json() + return response.status, content + else: + raise ValueError(f"Unsupported HTTP method: {method}") + + return await _make_request() diff --git a/providers/snowflake/src/airflow/providers/snowflake/operators/snowflake.py b/providers/snowflake/src/airflow/providers/snowflake/operators/snowflake.py index 2c8db1391e82b..21eda67132df9 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/operators/snowflake.py +++ b/providers/snowflake/src/airflow/providers/snowflake/operators/snowflake.py @@ -381,6 +381,7 @@ def __init__( token_renewal_delta: timedelta = RENEWAL_DELTA, bindings: dict[str, Any] | None = None, deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + snowflake_api_retry_args: dict[str, Any] | None = None, **kwargs: Any, ) -> None: self.snowflake_conn_id = snowflake_conn_id @@ -390,6 +391,7 @@ def __init__( self.token_renewal_delta = token_renewal_delta self.bindings = bindings self.execute_async = False + self.snowflake_api_retry_args = snowflake_api_retry_args or {} self.deferrable = deferrable self.query_ids: list[str] = [] if any([warehouse, database, role, schema, authenticator, session_parameters]): # pragma: no cover @@ -412,6 +414,7 @@ def _hook(self): token_life_time=self.token_life_time, token_renewal_delta=self.token_renewal_delta, deferrable=self.deferrable, + api_retry_args=self.snowflake_api_retry_args, **self.hook_params, ) From 3a4c95fd168c73f1b8ff491772337ed88898c7b7 Mon Sep 17 00:00:00 2001 From: anand Date: Tue, 3 Jun 2025 16:35:18 -0500 Subject: [PATCH 02/10] first draft of unit tests --- .../snowflake/hooks/snowflake_sql_api.py | 26 ++- .../snowflake/hooks/test_snowflake_sql_api.py | 190 +++++++++++++++++- 2 files changed, 206 insertions(+), 10 deletions(-) diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py index 6e84843c84e92..b968df7f30140 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py +++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py @@ -26,8 +26,10 @@ import aiohttp import requests import tenacity +from aiohttp import ClientConnectionError, ClientResponseError from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization +from requests.exceptions import ConnectionError, HTTPError, Timeout from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook @@ -87,7 +89,7 @@ def __init__( "retry": tenacity.retry_if_exception(self._should_retry_on_error), "wait": tenacity.wait_exponential(multiplier=1, min=1, max=60), "stop": tenacity.stop_after_attempt(5), - "before_sleep": tenacity.before_sleep_log(self.log, logger_level=20), # INFO level + "before_sleep": tenacity.before_sleep_log(self.log, log_level=20), # INFO level "reraise": True, } if api_retry_args: @@ -180,12 +182,8 @@ def execute_query( "query_tag": query_tag, }, } - try: - response = self._make_api_call_with_retries("POST", url, headers, params, data) - except requests.exceptions.HTTPError as e: # pragma: no cover - msg = f"Response: {e.response.content.decode()} Status Code: {e.response.status_code}" - raise AirflowException(msg) - json_response = response.json() + + _, json_response = self._make_api_call_with_retries("POST", url, headers, params, data) self.log.info("Snowflake SQL POST API response: %s", json_response) if "statementHandles" in json_response: self.query_ids = json_response["statementHandles"] @@ -332,11 +330,21 @@ def _should_retry_on_error(exception) -> bool: :param exception: The exception to check :return: True if the request should be retried, False otherwise """ - if isinstance(exception, (requests.exceptions.HTTPError, aiohttp.ClientResponseError)): + if isinstance( + exception, + ( + HTTPError, + ClientResponseError, + ), + ): return exception.response.status_code in [429, 503, 504] if isinstance( exception, - (requests.exceptions.ConnectionError, requests.exceptions.Timeout, aiohttp.ClientConnectionError), + ( + ConnectionError, + Timeout, + ClientConnectionError, + ), ): return True return False diff --git a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py index 21a3fa7a999a8..ddceaa0bfb83c 100644 --- a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py +++ b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py @@ -332,7 +332,7 @@ def test_check_query_output(self, mock_geturl_header_params, mock_requests, quer hook.check_query_output(query_ids) mock_log_info.assert_called_with(GET_RESPONSE) - @pytest.mark.parametrize("query_ids", [(["uuid", "uuid1"])]) + @pytest.mark.parametrize("query_ids", [["uuid", "uuid1"]]) @mock.patch( "airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook." "get_request_url_header_params" @@ -623,6 +623,9 @@ def __init__(self, status_code, data): def json(self): return self.data + def raise_for_status(self): + return + mock_requests.get.return_value = MockResponse(status_code, response) hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn") assert hook.get_sql_api_query_status("uuid") == expected_response @@ -813,3 +816,188 @@ def test_proper_parametrization_of_execute_query_api_request( hook.execute_query(sql, statement_count) mock_requests.post.assert_called_once_with(url, headers=HEADERS, json=expected_payload, params=params) + + @pytest.mark.parametrize( + "status_code,should_retry", + [ + (429, True), # Too Many Requests - should retry + (503, True), # Service Unavailable - should retry + (504, True), # Gateway Timeout - should retry + (500, False), # Internal Server Error - should not retry + (400, False), # Bad Request - should not retry + (401, False), # Unauthorized - should not retry + (404, False), # Not Found - should not retry + ], + ) + @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests") + def test_make_api_call_with_retries_http_errors(self, mock_requests, status_code, should_retry): + """ + Test that _make_api_call_with_retries method only retries on specific HTTP status codes. + Should retry on 429, 503, 504 but not on other error codes. + """ + hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn") + + # Mock failed response + failed_response = mock.MagicMock() + failed_response.status_code = status_code + failed_response.json.return_value = {"error": "test error"} + failed_response.raise_for_status.side_effect = requests.exceptions.HTTPError(response=failed_response) + + # Mock successful response for retries + success_response = mock.MagicMock() + success_response.status_code = 200 + success_response.json.return_value = {"statementHandle": "uuid"} + success_response.raise_for_status.return_value = None + + if should_retry: + # For retryable errors, first call fails, second succeeds + mock_requests.get.side_effect = [failed_response, success_response] + status_code, resp_json = hook._make_api_call_with_retries( + "GET", "https://test.snowflakecomputing.com/api/v2/statements/test", HEADERS + ) + assert status_code == 200 + assert resp_json == {"statementHandle": "uuid"} + assert mock_requests.get.call_count == 2 + else: + # For non-retryable errors, should fail immediately + mock_requests.get.side_effect = [failed_response] + with pytest.raises(requests.exceptions.HTTPError): + hook._make_api_call_with_retries( + "GET", "https://test.snowflakecomputing.com/api/v2/statements/test", HEADERS + ) + assert mock_requests.get.call_count == 1 + + @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests") + def test_make_api_call_with_retries_connection_errors(self, mock_requests): + """ + Test that _make_api_call_with_retries method retries on connection errors. + """ + hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn") + + # Mock connection error then success + success_response = mock.MagicMock() + success_response.status_code = 200 + success_response.json.return_value = {"statementHandle": "uuid"} + success_response.raise_for_status.return_value = None + + mock_requests.post.side_effect = [ + requests.exceptions.ConnectionError("Connection failed"), + success_response, + ] + + status_code, resp_json = hook._make_api_call_with_retries( + "POST", "https://test.snowflakecomputing.com/api/v2/statements", HEADERS, data={"test": "data"} + ) + + assert status_code == 200 + assert resp_json == {"statementHandle": "uuid"} + assert mock_requests.post.call_count == 2 + + @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests") + def test_make_api_call_with_retries_timeout_errors(self, mock_requests): + """ + Test that _make_api_call_with_retries method retries on timeout errors. + """ + hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn") + + # Mock timeout error then success + success_response = mock.MagicMock() + success_response.status_code = 200 + success_response.json.return_value = {"statementHandle": "uuid"} + success_response.raise_for_status.return_value = None + + mock_requests.get.side_effect = [requests.exceptions.Timeout("Request timed out"), success_response] + + status_code, resp_json = hook._make_api_call_with_retries( + "GET", "https://test.snowflakecomputing.com/api/v2/statements/test", HEADERS + ) + + assert status_code == 200 + assert resp_json == {"statementHandle": "uuid"} + assert mock_requests.get.call_count == 2 + + @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests") + def test_make_api_call_with_retries_max_attempts(self, mock_requests): + """ + Test that _make_api_call_with_retries method respects max retry attempts. + """ + hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn") + + # Mock response that always fails with retryable error + failed_response = mock.MagicMock() + failed_response.status_code = 429 + failed_response.json.return_value = {"error": "rate limited"} + failed_response.raise_for_status.side_effect = requests.exceptions.HTTPError(response=failed_response) + + mock_requests.get.side_effect = [failed_response] * 10 # More failures than max retries + + with pytest.raises(requests.exceptions.HTTPError): + hook._make_api_call_with_retries( + "GET", "https://test.snowflakecomputing.com/api/v2/statements/test", HEADERS + ) + + # Should attempt 5 times (initial + 4 retries) based on default retry config + assert mock_requests.get.call_count == 5 + + @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests") + def test_make_api_call_with_retries_success_no_retry(self, mock_requests): + """ + Test that _make_api_call_with_retries method doesn't retry on successful requests. + """ + hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn") + + # Mock successful response + success_response = mock.MagicMock() + success_response.status_code = 200 + success_response.json.return_value = {"statementHandle": "uuid"} + success_response.raise_for_status.return_value = None + + mock_requests.post.return_value = success_response + + status_code, resp_json = hook._make_api_call_with_retries( + "POST", "https://test.snowflakecomputing.com/api/v2/statements", HEADERS, data={"test": "data"} + ) + + assert status_code == 200 + assert resp_json == {"statementHandle": "uuid"} + assert mock_requests.post.call_count == 1 + + def test_make_api_call_with_retries_unsupported_method(self): + """ + Test that _make_api_call_with_retries method raises ValueError for unsupported HTTP methods. + """ + hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn") + + with pytest.raises(ValueError, match="Unsupported HTTP method: PUT"): + hook._make_api_call_with_retries( + "PUT", "https://test.snowflakecomputing.com/api/v2/statements", HEADERS + ) + + @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests") + def test_make_api_call_with_retries_custom_retry_config(self, mock_requests): + """ + Test that _make_api_call_with_retries method respects custom retry configuration. + """ + import tenacity + + # Create hook with custom retry config + custom_retry_args = { + "stop": tenacity.stop_after_attempt(2), # Only 2 attempts instead of default 5 + } + hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn", api_retry_args=custom_retry_args) + + # Mock response that always fails with retryable error + failed_response = mock.MagicMock() + failed_response.status_code = 503 + failed_response.json.return_value = {"error": "service unavailable"} + failed_response.raise_for_status.side_effect = requests.exceptions.HTTPError(response=failed_response) + + mock_requests.get.side_effect = [failed_response] * 3 + + with pytest.raises(requests.exceptions.HTTPError): + hook._make_api_call_with_retries( + "GET", "https://test.snowflakecomputing.com/api/v2/statements/test", HEADERS + ) + + # Should attempt only 2 times due to custom config + assert mock_requests.get.call_count == 2 From 18b5c90ec78df6399c22c64bf2f47bfff211e32f Mon Sep 17 00:00:00 2001 From: anand Date: Thu, 5 Jun 2025 13:30:57 -0500 Subject: [PATCH 03/10] unit tests for hook and retries --- .../snowflake/hooks/snowflake_sql_api.py | 13 +- .../snowflake/hooks/test_snowflake_sql_api.py | 149 ++++++++++++++++++ 2 files changed, 155 insertions(+), 7 deletions(-) diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py index b968df7f30140..76459b33d3653 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py +++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py @@ -184,6 +184,9 @@ def execute_query( } _, json_response = self._make_api_call_with_retries("POST", url, headers, params, data) + # except requests.exceptions.HTTPError as e: # pragma: no cover + # msg = f"Response: {e.response.content.decode()} Status Code: {e.response.status_code}" + # raise AirflowException(msg) self.log.info("Snowflake SQL POST API response: %s", json_response) if "statementHandles" in json_response: self.query_ids = json_response["statementHandles"] @@ -330,14 +333,10 @@ def _should_retry_on_error(exception) -> bool: :param exception: The exception to check :return: True if the request should be retried, False otherwise """ - if isinstance( - exception, - ( - HTTPError, - ClientResponseError, - ), - ): + if isinstance(exception, HTTPError): return exception.response.status_code in [429, 503, 504] + if isinstance(exception, ClientResponseError): + return exception.status in [429, 503, 504] if isinstance( exception, ( diff --git a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py index ddceaa0bfb83c..576ca03981454 100644 --- a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py +++ b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py @@ -23,6 +23,7 @@ from unittest import mock from unittest.mock import AsyncMock, PropertyMock +import aiohttp import pytest import requests from cryptography.hazmat.backends import default_backend @@ -167,6 +168,35 @@ def create_post_side_effect(status_code=429): return response +def create_async_request_client_response_error(request_info=None, history=None, status_code=429): + """Create mock response for async request side effect""" + response = mock.MagicMock() + request_info = mock.MagicMock() if request_info is None else request_info + history = mock.MagicMock() if history is None else history + response.status = status_code + response.reason = f"{status_code} Error" + response.raise_for_status.side_effect = aiohttp.ClientResponseError( + request_info=request_info, history=history, status=status_code, message=response.reason + ) + return response + + +def create_async_connection_error(): + response = mock.MagicMock() + response.raise_for_status.side_effect = aiohttp.ClientConnectionError() + return response + + +def create_async_request_client_response_success(json=GET_RESPONSE, status_code=200): + """Create mock response for async request side effect""" + response = mock.MagicMock() + response.status = status_code + response.reason = "test" + response.json = AsyncMock(return_value=json) + response.raise_for_status.side_effect = None + return response + + class TestSnowflakeSqlApiHook: @pytest.mark.parametrize( "sql,statement_count,expected_response, expected_query_ids", @@ -1001,3 +1031,122 @@ def test_make_api_call_with_retries_custom_retry_config(self, mock_requests): # Should attempt only 2 times due to custom config assert mock_requests.get.call_count == 2 + + @pytest.mark.asyncio + @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.aiohttp.ClientSession.get") + async def test_make_api_call_with_retries_async_success(self, mock_get): + """ + Test that _make_api_call_with_retries_async returns response on success. + """ + hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn") + + mock_response = create_async_request_client_response_success() + mock_get.return_value.__aenter__.return_value = mock_response + status_code, resp_json = await hook._make_api_call_with_retries_async( + "GET", "https://test.snowflakecomputing.com/api/v2/statements/test", HEADERS + ) + assert status_code == 200 + assert resp_json == GET_RESPONSE + assert mock_get.call_count == 1 + + @pytest.mark.asyncio + @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.aiohttp.ClientSession.get") + async def test_make_api_call_with_retries_async_retryable_http_error(self, mock_get): + """ + Test that _make_api_call_with_retries_async retries on retryable HTTP errors (429, 503, 504). + """ + hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn") + + # First response: 429, then 200 + mock_response_429 = create_async_request_client_response_error() + mock_response_200 = create_async_request_client_response_success() + # Side effect for request context manager + mock_get.return_value.__aenter__.side_effect = [ + mock_response_429, + mock_response_200, + ] + + status_code, resp_json = await hook._make_api_call_with_retries_async( + "GET", "https://test.snowflakecomputing.com/api/v2/statements/test", HEADERS + ) + assert status_code == 200 + assert resp_json == GET_RESPONSE + assert mock_get.call_count == 2 + + @pytest.mark.asyncio + @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.aiohttp.ClientSession.get") + async def test_make_api_call_with_retries_async_non_retryable_http_error(self, mock_get): + """ + Test that _make_api_call_with_retries_async does not retry on non-retryable HTTP errors. + """ + hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn") + + mock_response_400 = create_async_request_client_response_error(status_code=400) + + mock_get.return_value.__aenter__.return_value = mock_response_400 + + with pytest.raises(aiohttp.ClientResponseError): + await hook._make_api_call_with_retries_async( + "GET", "https://test.snowflakecomputing.com/api/v2/statements/test", HEADERS + ) + assert mock_get.call_count == 1 + + @pytest.mark.asyncio + @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.aiohttp.ClientSession.get") + async def test_make_api_call_with_retries_async_connection_error(self, mock_get): + """ + Test that _make_api_call_with_retries_async retries on connection errors. + """ + hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn") + + # First: connection error, then: success + failed_conn = create_async_connection_error() + + mock_request_200 = create_async_request_client_response_success() + + # Side effect for request context manager + mock_get.return_value.__aenter__.side_effect = [ + failed_conn, + mock_request_200, + ] + + status_code, resp_json = await hook._make_api_call_with_retries_async( + "GET", "https://test.snowflakecomputing.com/api/v2/statements/test", HEADERS + ) + assert status_code == 200 + assert resp_json == GET_RESPONSE + assert mock_get.call_count == 2 + + @pytest.mark.asyncio + @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.aiohttp.ClientSession.get") + async def test_make_api_call_with_retries_async_max_attempts(self, mock_get): + """ + Test that _make_api_call_with_retries_async respects max retry attempts. + """ + hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn") + mock_request_429 = create_async_request_client_response_error(status_code=429) + + # Always returns 429 + mock_get.return_value.__aenter__.side_effect = [mock_request_429] * 5 + + with pytest.raises(aiohttp.ClientResponseError): + await hook._make_api_call_with_retries_async( + "GET", "https://test.snowflakecomputing.com/api/v2/statements/test", HEADERS + ) + # Should attempt 5 times (default max retries) + assert mock_get.call_count == 5 + + @pytest.mark.asyncio + @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.aiohttp.ClientSession.get") + async def test_make_api_call_with_retries_async_unsupported_method(self, mock_session): + """ + Test that _make_api_call_with_retries_async raises ValueError for unsupported HTTP methods. + """ + hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn") + + with pytest.raises(ValueError, match="Unsupported HTTP method: PATCH"): + await hook._make_api_call_with_retries_async( + "PATCH", "https://test.snowflakecomputing.com/api/v2/statements/test", HEADERS + ) + # No HTTP call should be made + assert mock_session.call_count == 0 From 488290c77da2932a266f0c33f181d1e60890297b Mon Sep 17 00:00:00 2001 From: anand Date: Thu, 5 Jun 2025 14:13:58 -0500 Subject: [PATCH 04/10] remove comment --- .../src/airflow/providers/snowflake/hooks/snowflake_sql_api.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py index 76459b33d3653..12617c89d0434 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py +++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py @@ -184,9 +184,6 @@ def execute_query( } _, json_response = self._make_api_call_with_retries("POST", url, headers, params, data) - # except requests.exceptions.HTTPError as e: # pragma: no cover - # msg = f"Response: {e.response.content.decode()} Status Code: {e.response.status_code}" - # raise AirflowException(msg) self.log.info("Snowflake SQL POST API response: %s", json_response) if "statementHandles" in json_response: self.query_ids = json_response["statementHandles"] From d913a96bed9e085188e183efac98c2ba03dd86b2 Mon Sep 17 00:00:00 2001 From: anand Date: Tue, 10 Jun 2025 16:37:25 -0500 Subject: [PATCH 05/10] update sync request to use request.request --- .../snowflake/hooks/snowflake_sql_api.py | 23 ++- .../snowflake/hooks/test_snowflake_sql_api.py | 155 ++++++++++-------- 2 files changed, 101 insertions(+), 77 deletions(-) diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py index 12617c89d0434..9873895bf746f 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py +++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py @@ -268,13 +268,10 @@ def check_query_output(self, query_ids: list[str]) -> None: """ for query_id in query_ids: header, params, url = self.get_request_url_header_params(query_id) - try: - response = requests.get(url, headers=header, params=params) - response.raise_for_status() - self.log.info(response.json()) - except requests.exceptions.HTTPError as e: - msg = f"Response: {e.response.content.decode()}, Status Code: {e.response.status_code}" - raise AirflowException(msg) + _, response_json = self._make_api_call_with_retries( + method="GET", url=url, headers=header, params=params + ) + self.log.info(response_json) def _process_response(self, status_code, resp): self.log.info("Snowflake SQL GET statements status API response: %s", resp) @@ -345,7 +342,9 @@ def _should_retry_on_error(exception) -> bool: return True return False - def _make_api_call_with_retries(self, method, url, headers, params=None, data=None): + def _make_api_call_with_retries( + self, method: str, url: str, headers: dict, params: dict = None, json: dict = None + ): """ Make an API call to the Snowflake SQL API with retry logic for specific HTTP errors. @@ -362,10 +361,10 @@ def _make_api_call_with_retries(self, method, url, headers, params=None, data=No @tenacity.retry(**self.retry_config) # Use the retry args defined in constructor def _make_request(): - if method.upper() == "GET": - response = requests.get(url, headers=headers, params=params) - elif method.upper() == "POST": - response = requests.post(url, headers=headers, params=params, json=data) + if method.upper() in ("GET", "POST"): + response = requests.request( + method=method.lower(), url=url, headers=headers, params=params, json=json + ) else: raise ValueError(f"Unsupported HTTP method: {method}") response.raise_for_status() diff --git a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py index 576ca03981454..edbf2a1df6aa8 100644 --- a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py +++ b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py @@ -21,15 +21,15 @@ import uuid from typing import TYPE_CHECKING, Any from unittest import mock -from unittest.mock import AsyncMock, PropertyMock +from unittest.mock import AsyncMock, PropertyMock, call import aiohttp import pytest import requests +import tenacity from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import rsa -from responses import RequestsMock from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.models import Connection @@ -150,6 +150,8 @@ "role": "airflow_role", } +API_URL = "https://test.snowflakecomputing.com/api/v2/statements/test" + def create_successful_response_mock(content): """Create mock response for success state""" @@ -223,11 +225,11 @@ def test_execute_query( ): """Test execute_query method, run query by mocking post request method and return the query ids""" mock_requests.codes.ok = 200 - mock_requests.post.side_effect = [ + mock_requests.request.side_effect = [ create_successful_response_mock(expected_response), ] status_code_mock = mock.PropertyMock(return_value=200) - type(mock_requests.post.return_value).status_code = status_code_mock + type(mock_requests.request.return_value).status_code = status_code_mock hook = SnowflakeSqlApiHook("mock_conn_id") query_ids = hook.execute_query(sql, statement_count) @@ -298,7 +300,7 @@ def test_execute_query_exception_without_statement_handle( without statementHandle in the response """ side_effect = create_post_side_effect() - mock_requests.post.side_effect = side_effect + mock_requests.request.side_effect = side_effect hook = SnowflakeSqlApiHook("mock_conn_id") with pytest.raises(AirflowException) as exception_info: @@ -329,7 +331,7 @@ def test_execute_query_bindings_warning( """Test execute_query method logs warning when bindings are provided for multi-statement queries""" mock_conn_params.return_value = CONN_PARAMS mock_get_headers.return_value = HEADERS - mock_requests.post.return_value = create_successful_response_mock( + mock_requests.request.return_value = create_successful_response_mock( {"statementHandles": ["uuid", "uuid1"]} ) @@ -356,18 +358,19 @@ def test_check_query_output(self, mock_geturl_header_params, mock_requests, quer req_id = uuid.uuid4() params = {"requestId": str(req_id), "page": 2, "pageSize": 10} mock_geturl_header_params.return_value = HEADERS, params, "/test/airflow/" - mock_requests.get.return_value.json.return_value = GET_RESPONSE + mock_requests.request.return_value.json.return_value = GET_RESPONSE hook = SnowflakeSqlApiHook("mock_conn_id") with mock.patch.object(hook.log, "info") as mock_log_info: hook.check_query_output(query_ids) mock_log_info.assert_called_with(GET_RESPONSE) @pytest.mark.parametrize("query_ids", [["uuid", "uuid1"]]) + @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests") @mock.patch( "airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook." "get_request_url_header_params" ) - def test_check_query_output_exception(self, mock_geturl_header_params, query_ids): + def test_check_query_output_exception(self, mock_geturl_header_params, mock_request, query_ids): """ Test check_query_output by passing query ids as params and mock get_request_url_header_params to raise airflow exception and mock with http error @@ -375,11 +378,13 @@ def test_check_query_output_exception(self, mock_geturl_header_params, query_ids req_id = uuid.uuid4() params = {"requestId": str(req_id), "page": 2, "pageSize": 10} mock_geturl_header_params.return_value = HEADERS, params, "https://test/airflow/" - hook = SnowflakeSqlApiHook("mock_conn_id") - with mock.patch.object(hook.log, "error"), RequestsMock() as requests_mock: - requests_mock.get(url="https://test/airflow/", json={"foo": "bar"}, status=500) - with pytest.raises(AirflowException, match='Response: {"foo": "bar"}, Status Code: 500'): - hook.check_query_output(query_ids) + custom_retry_args = { + "stop": tenacity.stop_after_attempt(2), # Only 2 attempts instead of default 5 + } + hook = SnowflakeSqlApiHook("mock_conn_id", api_retry_args=custom_retry_args) + mock_request.request.side_effect = [create_post_side_effect(status_code=500)] * 3 + with pytest.raises(requests.exceptions.HTTPError): + hook.check_query_output(query_ids) @mock.patch( "airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params", @@ -656,7 +661,7 @@ def json(self): def raise_for_status(self): return - mock_requests.get.return_value = MockResponse(status_code, response) + mock_requests.request.return_value = MockResponse(status_code, response) hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn") assert hook.get_sql_api_query_status("uuid") == expected_response @@ -834,7 +839,7 @@ def test_proper_parametrization_of_execute_query_api_request( mock_conn_param.return_value = CONN_PARAMS mock_get_headers.return_value = HEADERS mock_requests.codes.ok = 200 - mock_requests.post.side_effect = [ + mock_requests.request.side_effect = [ create_successful_response_mock(expected_response), ] status_code_mock = mock.PropertyMock(return_value=200) @@ -845,7 +850,9 @@ def test_proper_parametrization_of_execute_query_api_request( hook.execute_query(sql, statement_count) - mock_requests.post.assert_called_once_with(url, headers=HEADERS, json=expected_payload, params=params) + mock_requests.request.assert_called_once_with( + method="post", url=url, headers=HEADERS, json=expected_payload, params=params + ) @pytest.mark.parametrize( "status_code,should_retry", @@ -881,21 +888,50 @@ def test_make_api_call_with_retries_http_errors(self, mock_requests, status_code if should_retry: # For retryable errors, first call fails, second succeeds - mock_requests.get.side_effect = [failed_response, success_response] + mock_requests.request.side_effect = [failed_response, success_response] status_code, resp_json = hook._make_api_call_with_retries( - "GET", "https://test.snowflakecomputing.com/api/v2/statements/test", HEADERS + method="GET", + url=API_URL, + headers=HEADERS, ) assert status_code == 200 assert resp_json == {"statementHandle": "uuid"} - assert mock_requests.get.call_count == 2 + assert mock_requests.request.call_count == 2 + mock_requests.request.assert_has_calls( + [ + call( + method="get", + json=None, + url=API_URL, + params=None, + headers=HEADERS, + ), + call( + method="get", + json=None, + url=API_URL, + params=None, + headers=HEADERS, + ), + ] + ) else: # For non-retryable errors, should fail immediately - mock_requests.get.side_effect = [failed_response] + mock_requests.request.side_effect = [failed_response] with pytest.raises(requests.exceptions.HTTPError): hook._make_api_call_with_retries( - "GET", "https://test.snowflakecomputing.com/api/v2/statements/test", HEADERS + method="GET", + url=API_URL, + headers=HEADERS, ) - assert mock_requests.get.call_count == 1 + assert mock_requests.request.call_count == 1 + mock_requests.request.assert_called_with( + method="get", + json=None, + url=API_URL, + params=None, + headers=HEADERS, + ) @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests") def test_make_api_call_with_retries_connection_errors(self, mock_requests): @@ -910,18 +946,25 @@ def test_make_api_call_with_retries_connection_errors(self, mock_requests): success_response.json.return_value = {"statementHandle": "uuid"} success_response.raise_for_status.return_value = None - mock_requests.post.side_effect = [ + mock_requests.request.side_effect = [ requests.exceptions.ConnectionError("Connection failed"), success_response, ] status_code, resp_json = hook._make_api_call_with_retries( - "POST", "https://test.snowflakecomputing.com/api/v2/statements", HEADERS, data={"test": "data"} + "POST", API_URL, HEADERS, json={"test": "data"} ) assert status_code == 200 + mock_requests.request.assert_called_with( + method="post", + url=API_URL, + params=None, + headers=HEADERS, + json={"test": "data"}, + ) assert resp_json == {"statementHandle": "uuid"} - assert mock_requests.post.call_count == 2 + assert mock_requests.request.call_count == 2 @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests") def test_make_api_call_with_retries_timeout_errors(self, mock_requests): @@ -936,15 +979,16 @@ def test_make_api_call_with_retries_timeout_errors(self, mock_requests): success_response.json.return_value = {"statementHandle": "uuid"} success_response.raise_for_status.return_value = None - mock_requests.get.side_effect = [requests.exceptions.Timeout("Request timed out"), success_response] + mock_requests.request.side_effect = [ + requests.exceptions.Timeout("Request timed out"), + success_response, + ] - status_code, resp_json = hook._make_api_call_with_retries( - "GET", "https://test.snowflakecomputing.com/api/v2/statements/test", HEADERS - ) + status_code, resp_json = hook._make_api_call_with_retries("GET", API_URL, HEADERS) assert status_code == 200 assert resp_json == {"statementHandle": "uuid"} - assert mock_requests.get.call_count == 2 + assert mock_requests.request.call_count == 2 @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests") def test_make_api_call_with_retries_max_attempts(self, mock_requests): @@ -959,15 +1003,13 @@ def test_make_api_call_with_retries_max_attempts(self, mock_requests): failed_response.json.return_value = {"error": "rate limited"} failed_response.raise_for_status.side_effect = requests.exceptions.HTTPError(response=failed_response) - mock_requests.get.side_effect = [failed_response] * 10 # More failures than max retries + mock_requests.request.side_effect = [failed_response] * 10 # More failures than max retries with pytest.raises(requests.exceptions.HTTPError): - hook._make_api_call_with_retries( - "GET", "https://test.snowflakecomputing.com/api/v2/statements/test", HEADERS - ) + hook._make_api_call_with_retries("GET", API_URL, HEADERS) # Should attempt 5 times (initial + 4 retries) based on default retry config - assert mock_requests.get.call_count == 5 + assert mock_requests.request.call_count == 5 @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests") def test_make_api_call_with_retries_success_no_retry(self, mock_requests): @@ -982,15 +1024,15 @@ def test_make_api_call_with_retries_success_no_retry(self, mock_requests): success_response.json.return_value = {"statementHandle": "uuid"} success_response.raise_for_status.return_value = None - mock_requests.post.return_value = success_response + mock_requests.request.return_value = success_response status_code, resp_json = hook._make_api_call_with_retries( - "POST", "https://test.snowflakecomputing.com/api/v2/statements", HEADERS, data={"test": "data"} + "POST", API_URL, HEADERS, json={"test": "data"} ) assert status_code == 200 assert resp_json == {"statementHandle": "uuid"} - assert mock_requests.post.call_count == 1 + assert mock_requests.request.call_count == 1 def test_make_api_call_with_retries_unsupported_method(self): """ @@ -999,16 +1041,13 @@ def test_make_api_call_with_retries_unsupported_method(self): hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn") with pytest.raises(ValueError, match="Unsupported HTTP method: PUT"): - hook._make_api_call_with_retries( - "PUT", "https://test.snowflakecomputing.com/api/v2/statements", HEADERS - ) + hook._make_api_call_with_retries("PUT", API_URL, HEADERS) @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests") def test_make_api_call_with_retries_custom_retry_config(self, mock_requests): """ Test that _make_api_call_with_retries method respects custom retry configuration. """ - import tenacity # Create hook with custom retry config custom_retry_args = { @@ -1022,15 +1061,13 @@ def test_make_api_call_with_retries_custom_retry_config(self, mock_requests): failed_response.json.return_value = {"error": "service unavailable"} failed_response.raise_for_status.side_effect = requests.exceptions.HTTPError(response=failed_response) - mock_requests.get.side_effect = [failed_response] * 3 + mock_requests.request.side_effect = [failed_response] * 3 with pytest.raises(requests.exceptions.HTTPError): - hook._make_api_call_with_retries( - "GET", "https://test.snowflakecomputing.com/api/v2/statements/test", HEADERS - ) + hook._make_api_call_with_retries("GET", API_URL, HEADERS) # Should attempt only 2 times due to custom config - assert mock_requests.get.call_count == 2 + assert mock_requests.request.call_count == 2 @pytest.mark.asyncio @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.aiohttp.ClientSession.get") @@ -1042,9 +1079,7 @@ async def test_make_api_call_with_retries_async_success(self, mock_get): mock_response = create_async_request_client_response_success() mock_get.return_value.__aenter__.return_value = mock_response - status_code, resp_json = await hook._make_api_call_with_retries_async( - "GET", "https://test.snowflakecomputing.com/api/v2/statements/test", HEADERS - ) + status_code, resp_json = await hook._make_api_call_with_retries_async("GET", API_URL, HEADERS) assert status_code == 200 assert resp_json == GET_RESPONSE assert mock_get.call_count == 1 @@ -1066,9 +1101,7 @@ async def test_make_api_call_with_retries_async_retryable_http_error(self, mock_ mock_response_200, ] - status_code, resp_json = await hook._make_api_call_with_retries_async( - "GET", "https://test.snowflakecomputing.com/api/v2/statements/test", HEADERS - ) + status_code, resp_json = await hook._make_api_call_with_retries_async("GET", API_URL, HEADERS) assert status_code == 200 assert resp_json == GET_RESPONSE assert mock_get.call_count == 2 @@ -1086,9 +1119,7 @@ async def test_make_api_call_with_retries_async_non_retryable_http_error(self, m mock_get.return_value.__aenter__.return_value = mock_response_400 with pytest.raises(aiohttp.ClientResponseError): - await hook._make_api_call_with_retries_async( - "GET", "https://test.snowflakecomputing.com/api/v2/statements/test", HEADERS - ) + await hook._make_api_call_with_retries_async("GET", API_URL, HEADERS) assert mock_get.call_count == 1 @pytest.mark.asyncio @@ -1110,9 +1141,7 @@ async def test_make_api_call_with_retries_async_connection_error(self, mock_get) mock_request_200, ] - status_code, resp_json = await hook._make_api_call_with_retries_async( - "GET", "https://test.snowflakecomputing.com/api/v2/statements/test", HEADERS - ) + status_code, resp_json = await hook._make_api_call_with_retries_async("GET", API_URL, HEADERS) assert status_code == 200 assert resp_json == GET_RESPONSE assert mock_get.call_count == 2 @@ -1130,9 +1159,7 @@ async def test_make_api_call_with_retries_async_max_attempts(self, mock_get): mock_get.return_value.__aenter__.side_effect = [mock_request_429] * 5 with pytest.raises(aiohttp.ClientResponseError): - await hook._make_api_call_with_retries_async( - "GET", "https://test.snowflakecomputing.com/api/v2/statements/test", HEADERS - ) + await hook._make_api_call_with_retries_async("GET", API_URL, HEADERS) # Should attempt 5 times (default max retries) assert mock_get.call_count == 5 @@ -1145,8 +1172,6 @@ async def test_make_api_call_with_retries_async_unsupported_method(self, mock_se hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn") with pytest.raises(ValueError, match="Unsupported HTTP method: PATCH"): - await hook._make_api_call_with_retries_async( - "PATCH", "https://test.snowflakecomputing.com/api/v2/statements/test", HEADERS - ) + await hook._make_api_call_with_retries_async("PATCH", API_URL, HEADERS) # No HTTP call should be made assert mock_session.call_count == 0 From 6df437b7a6ac9dfb86f5fc9c5f7be28d235bf1e6 Mon Sep 17 00:00:00 2001 From: anand Date: Tue, 10 Jun 2025 17:02:48 -0500 Subject: [PATCH 06/10] mypy fixes --- .../airflow/providers/snowflake/hooks/snowflake_sql_api.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py index 9873895bf746f..f9d3888e4e657 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py +++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py @@ -343,7 +343,7 @@ def _should_retry_on_error(exception) -> bool: return False def _make_api_call_with_retries( - self, method: str, url: str, headers: dict, params: dict = None, json: dict = None + self, method: str, url: str, headers: dict, params: dict | None = None, json: dict | None = None ): """ Make an API call to the Snowflake SQL API with retry logic for specific HTTP errors. @@ -359,7 +359,7 @@ def _make_api_call_with_retries( :return: The response object from the API call. """ - @tenacity.retry(**self.retry_config) # Use the retry args defined in constructor + @tenacity.retry(**self.retry_config) # type: ignore def _make_request(): if method.upper() in ("GET", "POST"): response = requests.request( @@ -387,7 +387,7 @@ async def _make_api_call_with_retries_async(self, method, url, headers, params=N :return: The response object from the API call. """ - @tenacity.retry(**self.retry_config) + @tenacity.retry(**self.retry_config) # type: ignore async def _make_request(): async with aiohttp.ClientSession(headers=headers) as session: if method.upper() == "GET": From f710a1c5701babbd3e4e00330a2c524e5c9f65ce Mon Sep 17 00:00:00 2001 From: anand Date: Wed, 18 Jun 2025 10:05:45 -0500 Subject: [PATCH 07/10] updated sync and async api call methods to use tenacity context manager --- .../snowflake/hooks/snowflake_sql_api.py | 67 ++++++------ .../snowflake/hooks/test_snowflake_sql_api.py | 103 +++++++++--------- 2 files changed, 88 insertions(+), 82 deletions(-) diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py index f9d3888e4e657..ded3bee463d5e 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py +++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py @@ -25,11 +25,18 @@ import aiohttp import requests -import tenacity from aiohttp import ClientConnectionError, ClientResponseError from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization from requests.exceptions import ConnectionError, HTTPError, Timeout +from tenacity import ( + AsyncRetrying, + Retrying, + before_sleep_log, + retry_if_exception, + stop_after_attempt, + wait_exponential, +) from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook @@ -86,10 +93,10 @@ def __init__( self.token_life_time = token_life_time self.token_renewal_delta = token_renewal_delta self.retry_config = { - "retry": tenacity.retry_if_exception(self._should_retry_on_error), - "wait": tenacity.wait_exponential(multiplier=1, min=1, max=60), - "stop": tenacity.stop_after_attempt(5), - "before_sleep": tenacity.before_sleep_log(self.log, log_level=20), # INFO level + "retry": retry_if_exception(self._should_retry_on_error), + "wait": wait_exponential(multiplier=1, min=1, max=60), + "stop": stop_after_attempt(5), + "before_sleep": before_sleep_log(self.log, log_level=20), # INFO level "reraise": True, } if api_retry_args: @@ -358,19 +365,17 @@ def _make_api_call_with_retries( :param data: (Optional) The data to include in the API call. :return: The response object from the API call. """ - - @tenacity.retry(**self.retry_config) # type: ignore - def _make_request(): - if method.upper() in ("GET", "POST"): - response = requests.request( - method=method.lower(), url=url, headers=headers, params=params, json=json - ) - else: - raise ValueError(f"Unsupported HTTP method: {method}") - response.raise_for_status() - return response.status_code, response.json() - - return _make_request() + with requests.Session() as session: + for attempt in Retrying(**self.retry_config): # type: ignore + with attempt: + if method.upper() in ("GET", "POST"): + response = session.request( + method=method.lower(), url=url, headers=headers, params=params, json=json + ) + else: + raise ValueError(f"Unsupported HTTP method: {method}") + response.raise_for_status() + return response.status_code, response.json() async def _make_api_call_with_retries_async(self, method, url, headers, params=None): """ @@ -383,20 +388,16 @@ async def _make_api_call_with_retries_async(self, method, url, headers, params=N :param url: The URL for the API endpoint. :param headers: The headers to include in the API call. :param params: (Optional) The query parameters to include in the API call. - :param data: (Optional) The data to include in the API call. :return: The response object from the API call. """ - - @tenacity.retry(**self.retry_config) # type: ignore - async def _make_request(): - async with aiohttp.ClientSession(headers=headers) as session: - if method.upper() == "GET": - async with session.get(url, params=params) as response: - response.raise_for_status() - # Return status and json content for async processing - content = await response.json() - return response.status, content - else: - raise ValueError(f"Unsupported HTTP method: {method}") - - return await _make_request() + async with aiohttp.ClientSession(headers=headers) as session: + async for attempt in AsyncRetrying(**self.retry_config): # type: ignore + with attempt: + if method.upper() == "GET": + async with session.request(method=method.lower(), url=url, params=params) as response: + response.raise_for_status() + # Return status and json content for async processing + content = await response.json() + return response.status, content + else: + raise ValueError(f"Unsupported HTTP method: {method}") diff --git a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py index edbf2a1df6aa8..8f34d751837d1 100644 --- a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py +++ b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py @@ -153,6 +153,26 @@ API_URL = "https://test.snowflakecomputing.com/api/v2/statements/test" +@pytest.fixture +def mock_requests(): + with mock.patch( + "airflow.providers.snowflake.hooks.snowflake_sql_api.requests.Session" + ) as mock_session_cls: + mock_session = mock.MagicMock() + mock_session_cls.return_value.__enter__.return_value = mock_session + yield mock_session + + +@pytest.fixture +def mock_async_request(): + with mock.patch( + "airflow.providers.snowflake.hooks.snowflake_sql_api.aiohttp.ClientSession.request" + ) as mock_session_cls: + mock_request = mock.MagicMock() + mock_session_cls.return_value = mock_request + yield mock_request + + def create_successful_response_mock(content): """Create mock response for success state""" response = mock.MagicMock() @@ -207,7 +227,6 @@ class TestSnowflakeSqlApiHook: (SQL_MULTIPLE_STMTS, 4, {"statementHandles": ["uuid", "uuid1"]}, ["uuid", "uuid1"]), ], ) - @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests") @mock.patch( "airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params", new_callable=PropertyMock, @@ -217,11 +236,11 @@ def test_execute_query( self, mock_get_header, mock_conn_param, - mock_requests, sql, statement_count, expected_response, expected_query_ids, + mock_requests, ): """Test execute_query method, run query by mocking post request method and return the query ids""" mock_requests.codes.ok = 200 @@ -279,7 +298,6 @@ def test_execute_query_multiple_times_give_fresh_query_ids_each_time( "sql,statement_count,expected_response, expected_query_ids", [(SINGLE_STMT, 1, {"statementHandle": "uuid"}, ["uuid"])], ) - @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests") @mock.patch( "airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params", new_callable=PropertyMock, @@ -289,11 +307,11 @@ def test_execute_query_exception_without_statement_handle( self, mock_get_header, mock_conn_param, - mock_requests, sql, statement_count, expected_response, expected_query_ids, + mock_requests, ): """ Test execute_query method by mocking the exception response and raise airflow exception @@ -313,7 +331,6 @@ def test_execute_query_exception_without_statement_handle( (SQL_MULTIPLE_STMTS, 4, {"1": {"type": "FIXED", "value": "123"}}), ], ) - @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests") @mock.patch( "airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params", new_callable=PropertyMock, @@ -323,10 +340,10 @@ def test_execute_query_bindings_warning( self, mock_get_headers, mock_conn_params, - mock_requests, sql, statement_count, bindings, + mock_requests, ): """Test execute_query method logs warning when bindings are provided for multi-statement queries""" mock_conn_params.return_value = CONN_PARAMS @@ -348,12 +365,11 @@ def test_execute_query_bindings_warning( (["uuid", "uuid1"]), ], ) - @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests") @mock.patch( "airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook." "get_request_url_header_params" ) - def test_check_query_output(self, mock_geturl_header_params, mock_requests, query_ids): + def test_check_query_output(self, mock_geturl_header_params, query_ids, mock_requests): """Test check_query_output by passing query ids as params and mock get_request_url_header_params""" req_id = uuid.uuid4() params = {"requestId": str(req_id), "page": 2, "pageSize": 10} @@ -365,12 +381,16 @@ def test_check_query_output(self, mock_geturl_header_params, mock_requests, quer mock_log_info.assert_called_with(GET_RESPONSE) @pytest.mark.parametrize("query_ids", [["uuid", "uuid1"]]) - @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests") @mock.patch( "airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook." "get_request_url_header_params" ) - def test_check_query_output_exception(self, mock_geturl_header_params, mock_request, query_ids): + def test_check_query_output_exception( + self, + mock_geturl_header_params, + query_ids, + mock_requests, + ): """ Test check_query_output by passing query ids as params and mock get_request_url_header_params to raise airflow exception and mock with http error @@ -382,7 +402,7 @@ def test_check_query_output_exception(self, mock_geturl_header_params, mock_requ "stop": tenacity.stop_after_attempt(2), # Only 2 attempts instead of default 5 } hook = SnowflakeSqlApiHook("mock_conn_id", api_retry_args=custom_retry_args) - mock_request.request.side_effect = [create_post_side_effect(status_code=500)] * 3 + mock_requests.request.side_effect = [create_post_side_effect(status_code=500)] * 3 with pytest.raises(requests.exceptions.HTTPError): hook.check_query_output(query_ids) @@ -640,9 +660,8 @@ def test_get_private_key_should_support_private_auth_with_unencrypted_key( "airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook." "get_request_url_header_params" ) - @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests") def test_get_sql_api_query_status( - self, mock_requests, mock_geturl_header_params, status_code, response, expected_response + self, mock_geturl_header_params, status_code, response, expected_response, mock_requests ): """Test get_sql_api_query_status function by mocking the status, response and expected response""" @@ -704,17 +723,16 @@ def raise_for_status(self): "airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook." "get_request_url_header_params" ) - @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.aiohttp.ClientSession.get") async def test_get_sql_api_query_status_async( - self, mock_get, mock_geturl_header_params, status_code, response, expected_response + self, mock_geturl_header_params, status_code, response, expected_response, mock_async_request ): """Test Async get_sql_api_query_status_async function by mocking the status, response and expected response""" req_id = uuid.uuid4() params = {"requestId": str(req_id), "page": 2, "pageSize": 10} mock_geturl_header_params.return_value = HEADERS, params, "/test/airflow/" - mock_get.return_value.__aenter__.return_value.status = status_code - mock_get.return_value.__aenter__.return_value.json = AsyncMock(return_value=response) + mock_async_request.__aenter__.return_value.status = status_code + mock_async_request.__aenter__.return_value.json = AsyncMock(return_value=response) hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn") response = await hook.get_sql_api_query_status_async("uuid") assert response == expected_response @@ -812,7 +830,6 @@ def test_hook_parameter_propagation(self, hook_params): ], ) @mock.patch("uuid.uuid4") - @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests") @mock.patch( "airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params", new_callable=PropertyMock, @@ -822,13 +839,13 @@ def test_proper_parametrization_of_execute_query_api_request( self, mock_get_headers, mock_conn_param, - mock_requests, mock_uuid, test_hook_params, sql, statement_count, expected_payload, expected_response, + mock_requests, ): """ This tests if the query execution ordered by POST request to Snowflake API @@ -866,8 +883,7 @@ def test_proper_parametrization_of_execute_query_api_request( (404, False), # Not Found - should not retry ], ) - @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests") - def test_make_api_call_with_retries_http_errors(self, mock_requests, status_code, should_retry): + def test_make_api_call_with_retries_http_errors(self, status_code, should_retry, mock_requests): """ Test that _make_api_call_with_retries method only retries on specific HTTP status codes. Should retry on 429, 503, 504 but not on other error codes. @@ -933,7 +949,6 @@ def test_make_api_call_with_retries_http_errors(self, mock_requests, status_code headers=HEADERS, ) - @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests") def test_make_api_call_with_retries_connection_errors(self, mock_requests): """ Test that _make_api_call_with_retries method retries on connection errors. @@ -966,7 +981,6 @@ def test_make_api_call_with_retries_connection_errors(self, mock_requests): assert resp_json == {"statementHandle": "uuid"} assert mock_requests.request.call_count == 2 - @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests") def test_make_api_call_with_retries_timeout_errors(self, mock_requests): """ Test that _make_api_call_with_retries method retries on timeout errors. @@ -990,7 +1004,6 @@ def test_make_api_call_with_retries_timeout_errors(self, mock_requests): assert resp_json == {"statementHandle": "uuid"} assert mock_requests.request.call_count == 2 - @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests") def test_make_api_call_with_retries_max_attempts(self, mock_requests): """ Test that _make_api_call_with_retries method respects max retry attempts. @@ -1011,7 +1024,6 @@ def test_make_api_call_with_retries_max_attempts(self, mock_requests): # Should attempt 5 times (initial + 4 retries) based on default retry config assert mock_requests.request.call_count == 5 - @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests") def test_make_api_call_with_retries_success_no_retry(self, mock_requests): """ Test that _make_api_call_with_retries method doesn't retry on successful requests. @@ -1043,7 +1055,6 @@ def test_make_api_call_with_retries_unsupported_method(self): with pytest.raises(ValueError, match="Unsupported HTTP method: PUT"): hook._make_api_call_with_retries("PUT", API_URL, HEADERS) - @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests") def test_make_api_call_with_retries_custom_retry_config(self, mock_requests): """ Test that _make_api_call_with_retries method respects custom retry configuration. @@ -1070,23 +1081,21 @@ def test_make_api_call_with_retries_custom_retry_config(self, mock_requests): assert mock_requests.request.call_count == 2 @pytest.mark.asyncio - @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.aiohttp.ClientSession.get") - async def test_make_api_call_with_retries_async_success(self, mock_get): + async def test_make_api_call_with_retries_async_success(self, mock_async_request): """ Test that _make_api_call_with_retries_async returns response on success. """ hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn") mock_response = create_async_request_client_response_success() - mock_get.return_value.__aenter__.return_value = mock_response + mock_async_request.__aenter__.return_value = mock_response status_code, resp_json = await hook._make_api_call_with_retries_async("GET", API_URL, HEADERS) assert status_code == 200 assert resp_json == GET_RESPONSE - assert mock_get.call_count == 1 + assert mock_async_request.__aenter__.call_count == 1 @pytest.mark.asyncio - @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.aiohttp.ClientSession.get") - async def test_make_api_call_with_retries_async_retryable_http_error(self, mock_get): + async def test_make_api_call_with_retries_async_retryable_http_error(self, mock_async_request): """ Test that _make_api_call_with_retries_async retries on retryable HTTP errors (429, 503, 504). """ @@ -1096,7 +1105,7 @@ async def test_make_api_call_with_retries_async_retryable_http_error(self, mock_ mock_response_429 = create_async_request_client_response_error() mock_response_200 = create_async_request_client_response_success() # Side effect for request context manager - mock_get.return_value.__aenter__.side_effect = [ + mock_async_request.__aenter__.side_effect = [ mock_response_429, mock_response_200, ] @@ -1104,11 +1113,10 @@ async def test_make_api_call_with_retries_async_retryable_http_error(self, mock_ status_code, resp_json = await hook._make_api_call_with_retries_async("GET", API_URL, HEADERS) assert status_code == 200 assert resp_json == GET_RESPONSE - assert mock_get.call_count == 2 + assert mock_async_request.__aenter__.call_count == 2 @pytest.mark.asyncio - @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.aiohttp.ClientSession.get") - async def test_make_api_call_with_retries_async_non_retryable_http_error(self, mock_get): + async def test_make_api_call_with_retries_async_non_retryable_http_error(self, mock_async_request): """ Test that _make_api_call_with_retries_async does not retry on non-retryable HTTP errors. """ @@ -1116,15 +1124,14 @@ async def test_make_api_call_with_retries_async_non_retryable_http_error(self, m mock_response_400 = create_async_request_client_response_error(status_code=400) - mock_get.return_value.__aenter__.return_value = mock_response_400 + mock_async_request.__aenter__.return_value = mock_response_400 with pytest.raises(aiohttp.ClientResponseError): await hook._make_api_call_with_retries_async("GET", API_URL, HEADERS) - assert mock_get.call_count == 1 + assert mock_async_request.__aenter__.call_count == 1 @pytest.mark.asyncio - @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.aiohttp.ClientSession.get") - async def test_make_api_call_with_retries_async_connection_error(self, mock_get): + async def test_make_api_call_with_retries_async_connection_error(self, mock_async_request): """ Test that _make_api_call_with_retries_async retries on connection errors. """ @@ -1136,7 +1143,7 @@ async def test_make_api_call_with_retries_async_connection_error(self, mock_get) mock_request_200 = create_async_request_client_response_success() # Side effect for request context manager - mock_get.return_value.__aenter__.side_effect = [ + mock_async_request.__aenter__.side_effect = [ failed_conn, mock_request_200, ] @@ -1144,11 +1151,10 @@ async def test_make_api_call_with_retries_async_connection_error(self, mock_get) status_code, resp_json = await hook._make_api_call_with_retries_async("GET", API_URL, HEADERS) assert status_code == 200 assert resp_json == GET_RESPONSE - assert mock_get.call_count == 2 + assert mock_async_request.__aenter__.call_count == 2 @pytest.mark.asyncio - @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.aiohttp.ClientSession.get") - async def test_make_api_call_with_retries_async_max_attempts(self, mock_get): + async def test_make_api_call_with_retries_async_max_attempts(self, mock_async_request): """ Test that _make_api_call_with_retries_async respects max retry attempts. """ @@ -1156,16 +1162,15 @@ async def test_make_api_call_with_retries_async_max_attempts(self, mock_get): mock_request_429 = create_async_request_client_response_error(status_code=429) # Always returns 429 - mock_get.return_value.__aenter__.side_effect = [mock_request_429] * 5 + mock_async_request.__aenter__.side_effect = [mock_request_429] * 5 with pytest.raises(aiohttp.ClientResponseError): await hook._make_api_call_with_retries_async("GET", API_URL, HEADERS) # Should attempt 5 times (default max retries) - assert mock_get.call_count == 5 + assert mock_async_request.__aenter__.call_count == 5 @pytest.mark.asyncio - @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.aiohttp.ClientSession.get") - async def test_make_api_call_with_retries_async_unsupported_method(self, mock_session): + async def test_make_api_call_with_retries_async_unsupported_method(self, mock_async_request): """ Test that _make_api_call_with_retries_async raises ValueError for unsupported HTTP methods. """ @@ -1174,4 +1179,4 @@ async def test_make_api_call_with_retries_async_unsupported_method(self, mock_se with pytest.raises(ValueError, match="Unsupported HTTP method: PATCH"): await hook._make_api_call_with_retries_async("PATCH", API_URL, HEADERS) # No HTTP call should be made - assert mock_session.call_count == 0 + assert mock_async_request.__aenter__.call_count == 0 From ebe29c649f2f59436f4f1c179f2ddcb63e118e77 Mon Sep 17 00:00:00 2001 From: anand Date: Wed, 18 Jun 2025 11:04:41 -0500 Subject: [PATCH 08/10] update unit test with correct method --- .../tests/unit/snowflake/hooks/test_snowflake_sql_api.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py index 8f34d751837d1..410588a7a71bb 100644 --- a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py +++ b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py @@ -254,7 +254,6 @@ def test_execute_query( query_ids = hook.execute_query(sql, statement_count) assert query_ids == expected_query_ids - @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests") @mock.patch( "airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params", new_callable=PropertyMock, @@ -272,11 +271,11 @@ def test_execute_query_multiple_times_give_fresh_query_ids_each_time( ) mock_requests.codes.ok = 200 - mock_requests.post.side_effect = [ + mock_requests.request.side_effect = [ create_successful_response_mock(expected_response), ] status_code_mock = mock.PropertyMock(return_value=200) - type(mock_requests.post.return_value).status_code = status_code_mock + type(mock_requests.request.return_value).status_code = status_code_mock hook = SnowflakeSqlApiHook("mock_conn_id") query_ids = hook.execute_query(sql, statement_count) @@ -288,7 +287,7 @@ def test_execute_query_multiple_times_give_fresh_query_ids_each_time( {"statementHandle": "uuid"}, ["uuid"], ) - mock_requests.post.side_effect = [ + mock_requests.request.side_effect = [ create_successful_response_mock(expected_response), ] query_ids = hook.execute_query(sql, statement_count) From bb87d361a293d132a3c661242230e359c8317f97 Mon Sep 17 00:00:00 2001 From: anand Date: Wed, 18 Jun 2025 12:14:57 -0500 Subject: [PATCH 09/10] retry args docs --- .../src/airflow/providers/snowflake/hooks/snowflake_sql_api.py | 1 + .../src/airflow/providers/snowflake/operators/snowflake.py | 1 + 2 files changed, 2 insertions(+) diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py index ded3bee463d5e..55928638da5c3 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py +++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py @@ -75,6 +75,7 @@ class SnowflakeSqlApiHook(SnowflakeHook): :param token_life_time: lifetime of the JWT Token in timedelta :param token_renewal_delta: Renewal time of the JWT Token in timedelta :param deferrable: Run operator in the deferrable mode. + :param api_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` & ``tenacity.AsyncRetrying`` classes. """ LIFETIME = timedelta(minutes=59) # The tokens will have a 59 minute lifetime diff --git a/providers/snowflake/src/airflow/providers/snowflake/operators/snowflake.py b/providers/snowflake/src/airflow/providers/snowflake/operators/snowflake.py index 21eda67132df9..48086e53cd8c7 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/operators/snowflake.py +++ b/providers/snowflake/src/airflow/providers/snowflake/operators/snowflake.py @@ -355,6 +355,7 @@ class SnowflakeSqlApiOperator(SQLExecuteQueryOperator): When executing the statement, Snowflake replaces placeholders (? and :name) in the statement with these specified values. :param deferrable: Run operator in the deferrable mode. + :param snowflake_api_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` & ``tenacity.AsyncRetrying`` classes. """ LIFETIME = timedelta(minutes=59) # The tokens will have a 59 minutes lifetime From 088fc172977fe78419f3c598a6b48b28992b80b3 Mon Sep 17 00:00:00 2001 From: anand Date: Wed, 18 Jun 2025 14:03:45 -0500 Subject: [PATCH 10/10] reorder so self.log is initialized --- .../airflow/providers/snowflake/hooks/snowflake_sql_api.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py index 55928638da5c3..185a49ffaba06 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py +++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py @@ -93,6 +93,10 @@ def __init__( self.snowflake_conn_id = snowflake_conn_id self.token_life_time = token_life_time self.token_renewal_delta = token_renewal_delta + + super().__init__(snowflake_conn_id, *args, **kwargs) + self.private_key: Any = None + self.retry_config = { "retry": retry_if_exception(self._should_retry_on_error), "wait": wait_exponential(multiplier=1, min=1, max=60), @@ -103,9 +107,6 @@ def __init__( if api_retry_args: self.retry_config.update(api_retry_args) - super().__init__(snowflake_conn_id, *args, **kwargs) - self.private_key: Any = None - def get_private_key(self) -> None: """Get the private key from snowflake connection.""" conn = self.get_connection(self.snowflake_conn_id)