Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 25 additions & 9 deletions airflow/providers/amazon/aws/operators/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,16 +494,32 @@ def execute(self, context: Context) -> dict:
try:
response = sagemaker_operation(
endpoint_info,
wait_for_completion=False,
)
# waiting for completion is handled here in the operator
except ClientError:
self.operation = "update"
sagemaker_operation = self.hook.update_endpoint
response = sagemaker_operation(
endpoint_info,
wait_for_completion=False,
wait_for_completion=False, # waiting for completion is handled here in the operator
)
except ClientError as ce:
if self.operation == "create" and ce.response["Error"]["Message"].startswith(
"Cannot create already existing endpoint"
):
# if we get an error because the endpoint already exists, we try to update it instead
self.operation = "update"
sagemaker_operation = self.hook.update_endpoint
self.log.warning(
"cannot create already existing endpoint %s, "
"updating it with the given config instead",
endpoint_info["EndpointName"],
)
if "Tags" in endpoint_info:
self.log.warning(
"Provided tags will be ignored in the update operation "
"(tags on the existing endpoint will be unchanged)"
)
endpoint_info.pop("Tags")
response = sagemaker_operation(
endpoint_info,
wait_for_completion=False,
)
else:
raise

if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
raise AirflowException(f"Sagemaker endpoint creation failed: {response}")
Expand Down
37 changes: 37 additions & 0 deletions tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,43 @@ def test_execute_with_duplicate_endpoint_creation(
}
self.sagemaker.execute(None)

@mock.patch.object(SageMakerHook, "get_conn")
@mock.patch.object(SageMakerHook, "create_model")
@mock.patch.object(SageMakerHook, "create_endpoint_config")
@mock.patch.object(SageMakerHook, "create_endpoint")
@mock.patch.object(SageMakerHook, "update_endpoint")
@mock.patch.object(sagemaker, "serialize", return_value="")
def test_execute_with_duplicate_endpoint_removes_tags(
self,
serialize,
mock_endpoint_update,
mock_endpoint_create,
mock_endpoint_config,
mock_model,
mock_client,
):
mock_endpoint_create.side_effect = ClientError(
error_response={
"Error": {
"Code": "ValidationException",
"Message": "Cannot create already existing endpoint.",
}
},
operation_name="CreateEndpoint",
)

def _check_no_tags(config, wait_for_completion):
assert "Tags" not in config
return {
"EndpointArn": "test_arn",
"ResponseMetadata": {"HTTPStatusCode": 200},
}

mock_endpoint_update.side_effect = _check_no_tags

self.sagemaker.config["Endpoint"]["Tags"] = {"Key": "k", "Value": "v"}
self.sagemaker.execute(None)

@mock.patch.object(SageMakerHook, "create_model")
@mock.patch.object(SageMakerHook, "create_endpoint_config")
@mock.patch.object(SageMakerHook, "create_endpoint")
Expand Down