diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py b/airflow/providers/amazon/aws/operators/sagemaker.py index 791000ed781a6..d82dfa5f297de 100644 --- a/airflow/providers/amazon/aws/operators/sagemaker.py +++ b/airflow/providers/amazon/aws/operators/sagemaker.py @@ -25,6 +25,7 @@ from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook +from airflow.utils.json import AirflowJsonEncoder if TYPE_CHECKING: from airflow.utils.context import Context @@ -33,6 +34,10 @@ CHECK_INTERVAL_SECOND: int = 30 +def serialize(result: Dict) -> str: + return json.loads(json.dumps(result, cls=AirflowJsonEncoder)) + + class SageMakerBaseOperator(BaseOperator): """This is the base operator for all SageMaker operators. @@ -188,7 +193,7 @@ def execute(self, context: 'Context') -> Dict: ) if response['ResponseMetadata']['HTTPStatusCode'] != 200: raise AirflowException(f'Sagemaker Processing Job creation failed: {response}') - return {'Processing': self.hook.describe_processing_job(self.config['ProcessingJobName'])} + return {'Processing': serialize(self.hook.describe_processing_job(self.config['ProcessingJobName']))} class SageMakerEndpointConfigOperator(SageMakerBaseOperator): @@ -360,8 +365,10 @@ def execute(self, context: 'Context') -> Dict: raise AirflowException(f'Sagemaker endpoint creation failed: {response}') else: return { - 'EndpointConfig': self.hook.describe_endpoint_config(endpoint_info['EndpointConfigName']), - 'Endpoint': self.hook.describe_endpoint(endpoint_info['EndpointName']), + 'EndpointConfig': serialize( + self.hook.describe_endpoint_config(endpoint_info['EndpointConfigName']) + ), + 'Endpoint': serialize(self.hook.describe_endpoint(endpoint_info['EndpointName'])), } @@ -475,8 +482,10 @@ def execute(self, context: 'Context') -> Dict: raise AirflowException(f'Sagemaker transform Job creation failed: {response}') else: return { - 'Model': self.hook.describe_model(transform_config['ModelName']), - 'Transform': self.hook.describe_transform_job(transform_config['TransformJobName']), + 'Model': serialize(self.hook.describe_model(transform_config['ModelName'])), + 'Transform': serialize( + self.hook.describe_transform_job(transform_config['TransformJobName']) + ), } def _check_if_transform_job_exists(self) -> None: @@ -570,7 +579,9 @@ def execute(self, context: 'Context') -> Dict: if response['ResponseMetadata']['HTTPStatusCode'] != 200: raise AirflowException(f'Sagemaker Tuning Job creation failed: {response}') else: - return {'Tuning': self.hook.describe_tuning_job(self.config['HyperParameterTuningJobName'])} + return { + 'Tuning': serialize(self.hook.describe_tuning_job(self.config['HyperParameterTuningJobName'])) + } class SageMakerModelOperator(SageMakerBaseOperator): @@ -609,7 +620,7 @@ def execute(self, context: 'Context') -> Dict: if response['ResponseMetadata']['HTTPStatusCode'] != 200: raise AirflowException(f'Sagemaker model creation failed: {response}') else: - return {'Model': self.hook.describe_model(self.config['ModelName'])} + return {'Model': serialize(self.hook.describe_model(self.config['ModelName']))} class SageMakerTrainingOperator(SageMakerBaseOperator): @@ -698,7 +709,7 @@ def execute(self, context: 'Context') -> Dict: if response['ResponseMetadata']['HTTPStatusCode'] != 200: raise AirflowException(f'Sagemaker Training Job creation failed: {response}') else: - return {'Training': self.hook.describe_training_job(self.config['TrainingJobName'])} + return {'Training': serialize(self.hook.describe_training_job(self.config['TrainingJobName']))} def _check_if_job_exists(self) -> None: training_job_name = self.config['TrainingJobName'] diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py b/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py index 3f252ab4d8a17..76682ff98487a 100644 --- a/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py +++ b/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py @@ -25,14 +25,12 @@ from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook +from airflow.providers.amazon.aws.operators import sagemaker from airflow.providers.amazon.aws.operators.sagemaker import SageMakerEndpointOperator CREATE_MODEL_PARAMS: Dict = { 'ModelName': 'model_name', - 'PrimaryContainer': { - 'Image': 'image_name', - 'ModelDataUrl': 'output_path', - }, + 'PrimaryContainer': {'Image': 'image_name', 'ModelDataUrl': 'output_path'}, 'ExecutionRoleArn': 'arn:aws:iam:role/test-role', } CREATE_ENDPOINT_CONFIG_PARAMS: Dict = { @@ -71,7 +69,8 @@ def setUp(self): @mock.patch.object(SageMakerHook, 'create_model') @mock.patch.object(SageMakerHook, 'create_endpoint_config') @mock.patch.object(SageMakerHook, 'create_endpoint') - def test_integer_fields(self, mock_endpoint, mock_endpoint_config, mock_model, mock_client): + @mock.patch.object(sagemaker, 'serialize', return_value="") + def test_integer_fields(self, serialize, mock_endpoint, mock_endpoint_config, mock_model, mock_client): mock_endpoint.return_value = {'EndpointArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 200}} self.sagemaker.execute(None) assert self.sagemaker.integer_fields == EXPECTED_INTEGER_FIELDS @@ -82,7 +81,8 @@ def test_integer_fields(self, mock_endpoint, mock_endpoint_config, mock_model, m @mock.patch.object(SageMakerHook, 'create_model') @mock.patch.object(SageMakerHook, 'create_endpoint_config') @mock.patch.object(SageMakerHook, 'create_endpoint') - def test_execute(self, mock_endpoint, mock_endpoint_config, mock_model, mock_client): + @mock.patch.object(sagemaker, 'serialize', return_value="") + def test_execute(self, serialize, mock_endpoint, mock_endpoint_config, mock_model, mock_client): mock_endpoint.return_value = {'EndpointArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 200}} self.sagemaker.execute(None) mock_model.assert_called_once_with(CREATE_MODEL_PARAMS) @@ -108,8 +108,9 @@ def test_execute_with_failure(self, mock_endpoint, mock_endpoint_config, mock_mo @mock.patch.object(SageMakerHook, 'create_endpoint_config') @mock.patch.object(SageMakerHook, 'create_endpoint') @mock.patch.object(SageMakerHook, 'update_endpoint') + @mock.patch.object(sagemaker, 'serialize', return_value="") def test_execute_with_duplicate_endpoint_creation( - self, mock_endpoint_update, mock_endpoint, mock_endpoint_config, mock_model, mock_client + self, serialize, mock_endpoint_update, mock_endpoint, mock_endpoint_config, mock_model, mock_client ): response = { 'Error': {'Code': 'ValidationException', 'Message': 'Cannot create already existing endpoint.'} diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_model.py b/tests/providers/amazon/aws/operators/test_sagemaker_model.py index 98990f7c7451e..4ccc98525ad06 100644 --- a/tests/providers/amazon/aws/operators/test_sagemaker_model.py +++ b/tests/providers/amazon/aws/operators/test_sagemaker_model.py @@ -24,6 +24,7 @@ from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook +from airflow.providers.amazon.aws.operators import sagemaker from airflow.providers.amazon.aws.operators.sagemaker import ( SageMakerDeleteModelOperator, SageMakerModelOperator, @@ -47,14 +48,16 @@ def setUp(self): @mock.patch.object(SageMakerHook, 'get_conn') @mock.patch.object(SageMakerHook, 'create_model') - def test_integer_fields(self, mock_model, mock_client): + @mock.patch.object(sagemaker, 'serialize', return_value="") + def test_integer_fields(self, serialize, mock_model, mock_client): mock_model.return_value = {'ModelArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 200}} self.sagemaker.execute(None) assert self.sagemaker.integer_fields == EXPECTED_INTEGER_FIELDS @mock.patch.object(SageMakerHook, 'get_conn') @mock.patch.object(SageMakerHook, 'create_model') - def test_execute(self, mock_model, mock_client): + @mock.patch.object(sagemaker, 'serialize', return_value="") + def test_execute(self, serialize, mock_model, mock_client): mock_model.return_value = {'ModelArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 200}} self.sagemaker.execute(None) mock_model.assert_called_once_with(CREATE_MODEL_PARAMS) diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_processing.py b/tests/providers/amazon/aws/operators/test_sagemaker_processing.py index 10e4f3feda779..b2f90bf1b5861 100644 --- a/tests/providers/amazon/aws/operators/test_sagemaker_processing.py +++ b/tests/providers/amazon/aws/operators/test_sagemaker_processing.py @@ -23,6 +23,7 @@ from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook +from airflow.providers.amazon.aws.operators import sagemaker from airflow.providers.amazon.aws.operators.sagemaker import SageMakerProcessingOperator CREATE_PROCESSING_PARAMS: Dict = { @@ -99,7 +100,10 @@ def setUp(self): 'create_processing_job', return_value={'ProcessingJobArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 200}}, ) - def test_integer_fields_without_stopping_condition(self, mock_processing, mock_hook, mock_client): + @mock.patch.object(sagemaker, 'serialize', return_value="") + def test_integer_fields_without_stopping_condition( + self, serialize, mock_processing, mock_hook, mock_client + ): sagemaker = SageMakerProcessingOperator( **self.processing_config_kwargs, config=CREATE_PROCESSING_PARAMS ) @@ -115,7 +119,8 @@ def test_integer_fields_without_stopping_condition(self, mock_processing, mock_h 'create_processing_job', return_value={'ProcessingJobArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 200}}, ) - def test_integer_fields_with_stopping_condition(self, mock_processing, mock_hook, mock_client): + @mock.patch.object(sagemaker, 'serialize', return_value="") + def test_integer_fields_with_stopping_condition(self, serialize, mock_processing, mock_hook, mock_client): sagemaker = SageMakerProcessingOperator( **self.processing_config_kwargs, config=CREATE_PROCESSING_PARAMS_WITH_STOPPING_CONDITION ) @@ -137,7 +142,8 @@ def test_integer_fields_with_stopping_condition(self, mock_processing, mock_hook 'create_processing_job', return_value={'ProcessingJobArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 200}}, ) - def test_execute(self, mock_processing, mock_hook, mock_client): + @mock.patch.object(sagemaker, 'serialize', return_value="") + def test_execute(self, serialize, mock_processing, mock_hook, mock_client): sagemaker = SageMakerProcessingOperator( **self.processing_config_kwargs, config=CREATE_PROCESSING_PARAMS ) @@ -153,7 +159,8 @@ def test_execute(self, mock_processing, mock_hook, mock_client): 'create_processing_job', return_value={'ProcessingJobArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 200}}, ) - def test_execute_with_stopping_condition(self, mock_processing, mock_hook, mock_client): + @mock.patch.object(sagemaker, 'serialize', return_value="") + def test_execute_with_stopping_condition(self, serialize, mock_processing, mock_hook, mock_client): sagemaker = SageMakerProcessingOperator( **self.processing_config_kwargs, config=CREATE_PROCESSING_PARAMS_WITH_STOPPING_CONDITION ) diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_training.py b/tests/providers/amazon/aws/operators/test_sagemaker_training.py index f8f16e3cfe112..89c19bc629595 100644 --- a/tests/providers/amazon/aws/operators/test_sagemaker_training.py +++ b/tests/providers/amazon/aws/operators/test_sagemaker_training.py @@ -23,6 +23,7 @@ from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook +from airflow.providers.amazon.aws.operators import sagemaker from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTrainingOperator EXPECTED_INTEGER_FIELDS: List[List[str]] = [ @@ -67,7 +68,8 @@ def setUp(self): @mock.patch.object(SageMakerHook, 'get_conn') @mock.patch.object(SageMakerHook, 'create_training_job') - def test_integer_fields(self, mock_training, mock_client): + @mock.patch.object(sagemaker, 'serialize', return_value="") + def test_integer_fields(self, serialize, mock_training, mock_client): mock_training.return_value = { 'TrainingJobArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 200}, @@ -80,7 +82,8 @@ def test_integer_fields(self, mock_training, mock_client): @mock.patch.object(SageMakerHook, 'get_conn') @mock.patch.object(SageMakerHook, 'create_training_job') - def test_execute_with_check_if_job_exists(self, mock_training, mock_client): + @mock.patch.object(sagemaker, 'serialize', return_value="") + def test_execute_with_check_if_job_exists(self, serialize, mock_training, mock_client): mock_training.return_value = { 'TrainingJobArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 200}, @@ -98,7 +101,8 @@ def test_execute_with_check_if_job_exists(self, mock_training, mock_client): @mock.patch.object(SageMakerHook, 'get_conn') @mock.patch.object(SageMakerHook, 'create_training_job') - def test_execute_without_check_if_job_exists(self, mock_training, mock_client): + @mock.patch.object(sagemaker, 'serialize', return_value="") + def test_execute_without_check_if_job_exists(self, serialize, mock_training, mock_client): mock_training.return_value = { 'TrainingJobArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 200}, diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py index 871079b3d445f..f2acf907c9628 100644 --- a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py +++ b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py @@ -24,6 +24,7 @@ from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook +from airflow.providers.amazon.aws.operators import sagemaker from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTransformOperator EXPECTED_INTEGER_FIELDS: List[List[str]] = [ @@ -65,7 +66,8 @@ def setUp(self): @mock.patch.object(SageMakerHook, 'get_conn') @mock.patch.object(SageMakerHook, 'create_model') @mock.patch.object(SageMakerHook, 'create_transform_job') - def test_integer_fields(self, mock_transform, mock_model, mock_client): + @mock.patch.object(sagemaker, 'serialize', return_value="") + def test_integer_fields(self, serialize, mock_transform, mock_model, mock_client): mock_transform.return_value = { 'TransformJobArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 200}, @@ -82,7 +84,8 @@ def test_integer_fields(self, mock_transform, mock_model, mock_client): @mock.patch.object(SageMakerHook, 'get_conn') @mock.patch.object(SageMakerHook, 'create_model') @mock.patch.object(SageMakerHook, 'create_transform_job') - def test_execute(self, mock_transform, mock_model, mock_client): + @mock.patch.object(sagemaker, 'serialize', return_value="") + def test_execute(self, serialize, mock_transform, mock_model, mock_client): mock_transform.return_value = { 'TransformJobArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 200}, @@ -106,7 +109,8 @@ def test_execute_with_failure(self, mock_transform, mock_model, mock_client): @mock.patch.object(SageMakerHook, 'get_conn') @mock.patch.object(SageMakerHook, 'create_transform_job') - def test_execute_with_check_if_job_exists(self, mock_transform, mock_client): + @mock.patch.object(sagemaker, 'serialize', return_value="") + def test_execute_with_check_if_job_exists(self, serialize, mock_transform, mock_client): mock_transform.return_value = { 'TransformJobArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 200}, @@ -123,7 +127,8 @@ def test_execute_with_check_if_job_exists(self, mock_transform, mock_client): @mock.patch.object(SageMakerHook, 'get_conn') @mock.patch.object(SageMakerHook, 'create_transform_job') - def test_execute_without_check_if_job_exists(self, mock_transform, mock_client): + @mock.patch.object(sagemaker, 'serialize', return_value="") + def test_execute_without_check_if_job_exists(self, serialize, mock_transform, mock_client): mock_transform.return_value = { 'TransformJobArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 200}, diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_tuning.py b/tests/providers/amazon/aws/operators/test_sagemaker_tuning.py index 9d3efb0c49168..1eff6a2b81658 100644 --- a/tests/providers/amazon/aws/operators/test_sagemaker_tuning.py +++ b/tests/providers/amazon/aws/operators/test_sagemaker_tuning.py @@ -24,6 +24,7 @@ from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook +from airflow.providers.amazon.aws.operators import sagemaker from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTuningOperator EXPECTED_INTEGER_FIELDS: List[List[str]] = [ @@ -83,7 +84,8 @@ def setUp(self): @mock.patch.object(SageMakerHook, 'get_conn') @mock.patch.object(SageMakerHook, 'create_tuning_job') - def test_integer_fields(self, mock_tuning, mock_client): + @mock.patch.object(sagemaker, 'serialize', return_value="") + def test_integer_fields(self, serialize, mock_tuning, mock_client): mock_tuning.return_value = {'TrainingJobArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 200}} self.sagemaker.execute(None) assert self.sagemaker.integer_fields == EXPECTED_INTEGER_FIELDS @@ -92,7 +94,8 @@ def test_integer_fields(self, mock_tuning, mock_client): @mock.patch.object(SageMakerHook, 'get_conn') @mock.patch.object(SageMakerHook, 'create_tuning_job') - def test_execute(self, mock_tuning, mock_client): + @mock.patch.object(sagemaker, 'serialize', return_value="") + def test_execute(self, serialize, mock_tuning, mock_client): mock_tuning.return_value = {'TrainingJobArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 200}} self.sagemaker.execute(None) mock_tuning.assert_called_once_with( diff --git a/tests/system/providers/amazon/aws/example_sagemaker.py b/tests/system/providers/amazon/aws/example_sagemaker.py index 22d786043fc9f..645b943ba5543 100644 --- a/tests/system/providers/amazon/aws/example_sagemaker.py +++ b/tests/system/providers/amazon/aws/example_sagemaker.py @@ -419,7 +419,6 @@ def delete_logs(env_id): preprocess_raw_data = SageMakerProcessingOperator( task_id='preprocess_raw_data', config=test_setup['processing_config'], - do_xcom_push=False, ) # [END howto_operator_sagemaker_processing] @@ -429,7 +428,6 @@ def delete_logs(env_id): config=test_setup['training_config'], # Waits by default, setting as False to demonstrate the Sensor below. wait_for_completion=False, - do_xcom_push=False, ) # [END howto_operator_sagemaker_training] @@ -437,7 +435,6 @@ def delete_logs(env_id): await_training = SageMakerTrainingSensor( task_id="await_training", job_name=test_setup['training_job_name'], - do_xcom_push=False, ) # [END howto_sensor_sagemaker_training] @@ -445,7 +442,6 @@ def delete_logs(env_id): create_model = SageMakerModelOperator( task_id='create_model', config=test_setup['model_config'], - do_xcom_push=False, ) # [END howto_operator_sagemaker_model] @@ -455,7 +451,6 @@ def delete_logs(env_id): config=test_setup['tuning_config'], # Waits by default, setting as False to demonstrate the Sensor below. wait_for_completion=False, - do_xcom_push=False, ) # [END howto_operator_sagemaker_tuning] @@ -463,7 +458,6 @@ def delete_logs(env_id): await_tune = SageMakerTuningSensor( task_id="await_tuning", job_name=test_setup['tuning_job_name'], - do_xcom_push=False, ) # [END howto_sensor_sagemaker_tuning] @@ -473,7 +467,6 @@ def delete_logs(env_id): config=test_setup['transform_config'], # Waits by default, setting as False to demonstrate the Sensor below. wait_for_completion=False, - do_xcom_push=False, ) # [END howto_operator_sagemaker_transform] @@ -481,7 +474,6 @@ def delete_logs(env_id): await_transform = SageMakerTransformSensor( task_id="await_transform", job_name=test_setup['transform_job_name'], - do_xcom_push=False, ) # [END howto_sensor_sagemaker_transform] @@ -490,7 +482,6 @@ def delete_logs(env_id): task_id="delete_model", config={'ModelName': test_setup['model_name']}, trigger_rule=TriggerRule.ALL_DONE, - do_xcom_push=False, ) # [END howto_operator_sagemaker_delete_model] diff --git a/tests/system/providers/amazon/aws/example_sagemaker_endpoint.py b/tests/system/providers/amazon/aws/example_sagemaker_endpoint.py index 5733a87a92e76..6fd6d7c5ff146 100644 --- a/tests/system/providers/amazon/aws/example_sagemaker_endpoint.py +++ b/tests/system/providers/amazon/aws/example_sagemaker_endpoint.py @@ -210,20 +210,17 @@ def set_up(env_id, knn_image_uri, role_arn, ti=None): train_model = SageMakerTrainingOperator( task_id='train_model', config=test_setup['training_config'], - do_xcom_push=False, ) create_model = SageMakerModelOperator( task_id='create_model', config=test_setup['model_config'], - do_xcom_push=False, ) # [START howto_operator_sagemaker_endpoint_config] configure_endpoint = SageMakerEndpointConfigOperator( task_id='configure_endpoint', config=test_setup['endpoint_config_config'], - do_xcom_push=False, ) # [END howto_operator_sagemaker_endpoint_config] @@ -233,7 +230,6 @@ def set_up(env_id, knn_image_uri, role_arn, ti=None): config=test_setup['deploy_endpoint_config'], # Waits by default, setting as False to demonstrate the Sensor below. wait_for_completion=False, - do_xcom_push=False, ) # [END howto_operator_sagemaker_endpoint]