From 654a223805093715f9245b1cb37805976cbe0e9e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= Date: Thu, 17 Aug 2023 14:58:15 -0700 Subject: [PATCH] bugfix: strip tags when falling back to update in `SageMakerEndpointOperator` also fixed the condition to fallback so that we don't retry when it's useless + added a warning on fallback to make the behavior more obvious to users --- .../amazon/aws/operators/sagemaker.py | 34 ++++++++++++----- .../aws/operators/test_sagemaker_endpoint.py | 37 +++++++++++++++++++ 2 files changed, 62 insertions(+), 9 deletions(-) diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py b/airflow/providers/amazon/aws/operators/sagemaker.py index ce0fa6f7c553a..1547d2203c076 100644 --- a/airflow/providers/amazon/aws/operators/sagemaker.py +++ b/airflow/providers/amazon/aws/operators/sagemaker.py @@ -494,16 +494,32 @@ def execute(self, context: Context) -> dict: try: response = sagemaker_operation( endpoint_info, - wait_for_completion=False, - ) - # waiting for completion is handled here in the operator - except ClientError: - self.operation = "update" - sagemaker_operation = self.hook.update_endpoint - response = sagemaker_operation( - endpoint_info, - wait_for_completion=False, + wait_for_completion=False, # waiting for completion is handled here in the operator ) + except ClientError as ce: + if self.operation == "create" and ce.response["Error"]["Message"].startswith( + "Cannot create already existing endpoint" + ): + # if we get an error because the endpoint already exists, we try to update it instead + self.operation = "update" + sagemaker_operation = self.hook.update_endpoint + self.log.warning( + "cannot create already existing endpoint %s, " + "updating it with the given config instead", + endpoint_info["EndpointName"], + ) + if "Tags" in endpoint_info: + self.log.warning( + "Provided tags will be ignored in the update operation " + "(tags on the existing endpoint will be unchanged)" + ) + endpoint_info.pop("Tags") + response = sagemaker_operation( + endpoint_info, + wait_for_completion=False, + ) + else: + raise if response["ResponseMetadata"]["HTTPStatusCode"] != 200: raise AirflowException(f"Sagemaker endpoint creation failed: {response}") diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py b/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py index 8a566535b98e3..d31556f9bf9a8 100644 --- a/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py +++ b/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py @@ -122,6 +122,43 @@ def test_execute_with_duplicate_endpoint_creation( } self.sagemaker.execute(None) + @mock.patch.object(SageMakerHook, "get_conn") + @mock.patch.object(SageMakerHook, "create_model") + @mock.patch.object(SageMakerHook, "create_endpoint_config") + @mock.patch.object(SageMakerHook, "create_endpoint") + @mock.patch.object(SageMakerHook, "update_endpoint") + @mock.patch.object(sagemaker, "serialize", return_value="") + def test_execute_with_duplicate_endpoint_removes_tags( + self, + serialize, + mock_endpoint_update, + mock_endpoint_create, + mock_endpoint_config, + mock_model, + mock_client, + ): + mock_endpoint_create.side_effect = ClientError( + error_response={ + "Error": { + "Code": "ValidationException", + "Message": "Cannot create already existing endpoint.", + } + }, + operation_name="CreateEndpoint", + ) + + def _check_no_tags(config, wait_for_completion): + assert "Tags" not in config + return { + "EndpointArn": "test_arn", + "ResponseMetadata": {"HTTPStatusCode": 200}, + } + + mock_endpoint_update.side_effect = _check_no_tags + + self.sagemaker.config["Endpoint"]["Tags"] = {"Key": "k", "Value": "v"} + self.sagemaker.execute(None) + @mock.patch.object(SageMakerHook, "create_model") @mock.patch.object(SageMakerHook, "create_endpoint_config") @mock.patch.object(SageMakerHook, "create_endpoint")