Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 19 additions & 8 deletions airflow/providers/amazon/aws/operators/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.utils.json import AirflowJsonEncoder

if TYPE_CHECKING:
from airflow.utils.context import Context
Expand All @@ -33,6 +34,10 @@
CHECK_INTERVAL_SECOND: int = 30


def serialize(result: Dict) -> str:
return json.loads(json.dumps(result, cls=AirflowJsonEncoder))


class SageMakerBaseOperator(BaseOperator):
"""This is the base operator for all SageMaker operators.

Expand Down Expand Up @@ -188,7 +193,7 @@ def execute(self, context: 'Context') -> Dict:
)
if response['ResponseMetadata']['HTTPStatusCode'] != 200:
raise AirflowException(f'Sagemaker Processing Job creation failed: {response}')
return {'Processing': self.hook.describe_processing_job(self.config['ProcessingJobName'])}
return {'Processing': serialize(self.hook.describe_processing_job(self.config['ProcessingJobName']))}


class SageMakerEndpointConfigOperator(SageMakerBaseOperator):
Expand Down Expand Up @@ -360,8 +365,10 @@ def execute(self, context: 'Context') -> Dict:
raise AirflowException(f'Sagemaker endpoint creation failed: {response}')
else:
return {
'EndpointConfig': self.hook.describe_endpoint_config(endpoint_info['EndpointConfigName']),
'Endpoint': self.hook.describe_endpoint(endpoint_info['EndpointName']),
'EndpointConfig': serialize(
self.hook.describe_endpoint_config(endpoint_info['EndpointConfigName'])
),
'Endpoint': serialize(self.hook.describe_endpoint(endpoint_info['EndpointName'])),
}


Expand Down Expand Up @@ -475,8 +482,10 @@ def execute(self, context: 'Context') -> Dict:
raise AirflowException(f'Sagemaker transform Job creation failed: {response}')
else:
return {
'Model': self.hook.describe_model(transform_config['ModelName']),
'Transform': self.hook.describe_transform_job(transform_config['TransformJobName']),
'Model': serialize(self.hook.describe_model(transform_config['ModelName'])),
'Transform': serialize(
self.hook.describe_transform_job(transform_config['TransformJobName'])
),
}

def _check_if_transform_job_exists(self) -> None:
Expand Down Expand Up @@ -570,7 +579,9 @@ def execute(self, context: 'Context') -> Dict:
if response['ResponseMetadata']['HTTPStatusCode'] != 200:
raise AirflowException(f'Sagemaker Tuning Job creation failed: {response}')
else:
return {'Tuning': self.hook.describe_tuning_job(self.config['HyperParameterTuningJobName'])}
return {
'Tuning': serialize(self.hook.describe_tuning_job(self.config['HyperParameterTuningJobName']))
}


class SageMakerModelOperator(SageMakerBaseOperator):
Expand Down Expand Up @@ -609,7 +620,7 @@ def execute(self, context: 'Context') -> Dict:
if response['ResponseMetadata']['HTTPStatusCode'] != 200:
raise AirflowException(f'Sagemaker model creation failed: {response}')
else:
return {'Model': self.hook.describe_model(self.config['ModelName'])}
return {'Model': serialize(self.hook.describe_model(self.config['ModelName']))}


class SageMakerTrainingOperator(SageMakerBaseOperator):
Expand Down Expand Up @@ -698,7 +709,7 @@ def execute(self, context: 'Context') -> Dict:
if response['ResponseMetadata']['HTTPStatusCode'] != 200:
raise AirflowException(f'Sagemaker Training Job creation failed: {response}')
else:
return {'Training': self.hook.describe_training_job(self.config['TrainingJobName'])}
return {'Training': serialize(self.hook.describe_training_job(self.config['TrainingJobName']))}

def _check_if_job_exists(self) -> None:
training_job_name = self.config['TrainingJobName']
Expand Down
15 changes: 8 additions & 7 deletions tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,12 @@

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.operators import sagemaker
from airflow.providers.amazon.aws.operators.sagemaker import SageMakerEndpointOperator

CREATE_MODEL_PARAMS: Dict = {
'ModelName': 'model_name',
'PrimaryContainer': {
'Image': 'image_name',
'ModelDataUrl': 'output_path',
},
'PrimaryContainer': {'Image': 'image_name', 'ModelDataUrl': 'output_path'},
'ExecutionRoleArn': 'arn:aws:iam:role/test-role',
}
CREATE_ENDPOINT_CONFIG_PARAMS: Dict = {
Expand Down Expand Up @@ -71,7 +69,8 @@ def setUp(self):
@mock.patch.object(SageMakerHook, 'create_model')
@mock.patch.object(SageMakerHook, 'create_endpoint_config')
@mock.patch.object(SageMakerHook, 'create_endpoint')
def test_integer_fields(self, mock_endpoint, mock_endpoint_config, mock_model, mock_client):
@mock.patch.object(sagemaker, 'serialize', return_value="")
def test_integer_fields(self, serialize, mock_endpoint, mock_endpoint_config, mock_model, mock_client):
mock_endpoint.return_value = {'EndpointArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 200}}
self.sagemaker.execute(None)
assert self.sagemaker.integer_fields == EXPECTED_INTEGER_FIELDS
Expand All @@ -82,7 +81,8 @@ def test_integer_fields(self, mock_endpoint, mock_endpoint_config, mock_model, m
@mock.patch.object(SageMakerHook, 'create_model')
@mock.patch.object(SageMakerHook, 'create_endpoint_config')
@mock.patch.object(SageMakerHook, 'create_endpoint')
def test_execute(self, mock_endpoint, mock_endpoint_config, mock_model, mock_client):
@mock.patch.object(sagemaker, 'serialize', return_value="")
def test_execute(self, serialize, mock_endpoint, mock_endpoint_config, mock_model, mock_client):
mock_endpoint.return_value = {'EndpointArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 200}}
self.sagemaker.execute(None)
mock_model.assert_called_once_with(CREATE_MODEL_PARAMS)
Expand All @@ -108,8 +108,9 @@ def test_execute_with_failure(self, mock_endpoint, mock_endpoint_config, mock_mo
@mock.patch.object(SageMakerHook, 'create_endpoint_config')
@mock.patch.object(SageMakerHook, 'create_endpoint')
@mock.patch.object(SageMakerHook, 'update_endpoint')
@mock.patch.object(sagemaker, 'serialize', return_value="")
def test_execute_with_duplicate_endpoint_creation(
self, mock_endpoint_update, mock_endpoint, mock_endpoint_config, mock_model, mock_client
self, serialize, mock_endpoint_update, mock_endpoint, mock_endpoint_config, mock_model, mock_client
):
response = {
'Error': {'Code': 'ValidationException', 'Message': 'Cannot create already existing endpoint.'}
Expand Down
7 changes: 5 additions & 2 deletions tests/providers/amazon/aws/operators/test_sagemaker_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.operators import sagemaker
from airflow.providers.amazon.aws.operators.sagemaker import (
SageMakerDeleteModelOperator,
SageMakerModelOperator,
Expand All @@ -47,14 +48,16 @@ def setUp(self):

@mock.patch.object(SageMakerHook, 'get_conn')
@mock.patch.object(SageMakerHook, 'create_model')
def test_integer_fields(self, mock_model, mock_client):
@mock.patch.object(sagemaker, 'serialize', return_value="")
def test_integer_fields(self, serialize, mock_model, mock_client):
mock_model.return_value = {'ModelArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 200}}
self.sagemaker.execute(None)
assert self.sagemaker.integer_fields == EXPECTED_INTEGER_FIELDS

@mock.patch.object(SageMakerHook, 'get_conn')
@mock.patch.object(SageMakerHook, 'create_model')
def test_execute(self, mock_model, mock_client):
@mock.patch.object(sagemaker, 'serialize', return_value="")
def test_execute(self, serialize, mock_model, mock_client):
mock_model.return_value = {'ModelArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 200}}
self.sagemaker.execute(None)
mock_model.assert_called_once_with(CREATE_MODEL_PARAMS)
Expand Down
15 changes: 11 additions & 4 deletions tests/providers/amazon/aws/operators/test_sagemaker_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.operators import sagemaker
from airflow.providers.amazon.aws.operators.sagemaker import SageMakerProcessingOperator

CREATE_PROCESSING_PARAMS: Dict = {
Expand Down Expand Up @@ -99,7 +100,10 @@ def setUp(self):
'create_processing_job',
return_value={'ProcessingJobArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 200}},
)
def test_integer_fields_without_stopping_condition(self, mock_processing, mock_hook, mock_client):
@mock.patch.object(sagemaker, 'serialize', return_value="")
def test_integer_fields_without_stopping_condition(
self, serialize, mock_processing, mock_hook, mock_client
):
sagemaker = SageMakerProcessingOperator(
**self.processing_config_kwargs, config=CREATE_PROCESSING_PARAMS
)
Expand All @@ -115,7 +119,8 @@ def test_integer_fields_without_stopping_condition(self, mock_processing, mock_h
'create_processing_job',
return_value={'ProcessingJobArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 200}},
)
def test_integer_fields_with_stopping_condition(self, mock_processing, mock_hook, mock_client):
@mock.patch.object(sagemaker, 'serialize', return_value="")
def test_integer_fields_with_stopping_condition(self, serialize, mock_processing, mock_hook, mock_client):
sagemaker = SageMakerProcessingOperator(
**self.processing_config_kwargs, config=CREATE_PROCESSING_PARAMS_WITH_STOPPING_CONDITION
)
Expand All @@ -137,7 +142,8 @@ def test_integer_fields_with_stopping_condition(self, mock_processing, mock_hook
'create_processing_job',
return_value={'ProcessingJobArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 200}},
)
def test_execute(self, mock_processing, mock_hook, mock_client):
@mock.patch.object(sagemaker, 'serialize', return_value="")
def test_execute(self, serialize, mock_processing, mock_hook, mock_client):
sagemaker = SageMakerProcessingOperator(
**self.processing_config_kwargs, config=CREATE_PROCESSING_PARAMS
)
Expand All @@ -153,7 +159,8 @@ def test_execute(self, mock_processing, mock_hook, mock_client):
'create_processing_job',
return_value={'ProcessingJobArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 200}},
)
def test_execute_with_stopping_condition(self, mock_processing, mock_hook, mock_client):
@mock.patch.object(sagemaker, 'serialize', return_value="")
def test_execute_with_stopping_condition(self, serialize, mock_processing, mock_hook, mock_client):
sagemaker = SageMakerProcessingOperator(
**self.processing_config_kwargs, config=CREATE_PROCESSING_PARAMS_WITH_STOPPING_CONDITION
)
Expand Down
10 changes: 7 additions & 3 deletions tests/providers/amazon/aws/operators/test_sagemaker_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.operators import sagemaker
from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTrainingOperator

EXPECTED_INTEGER_FIELDS: List[List[str]] = [
Expand Down Expand Up @@ -67,7 +68,8 @@ def setUp(self):

@mock.patch.object(SageMakerHook, 'get_conn')
@mock.patch.object(SageMakerHook, 'create_training_job')
def test_integer_fields(self, mock_training, mock_client):
@mock.patch.object(sagemaker, 'serialize', return_value="")
def test_integer_fields(self, serialize, mock_training, mock_client):
mock_training.return_value = {
'TrainingJobArn': 'test_arn',
'ResponseMetadata': {'HTTPStatusCode': 200},
Expand All @@ -80,7 +82,8 @@ def test_integer_fields(self, mock_training, mock_client):

@mock.patch.object(SageMakerHook, 'get_conn')
@mock.patch.object(SageMakerHook, 'create_training_job')
def test_execute_with_check_if_job_exists(self, mock_training, mock_client):
@mock.patch.object(sagemaker, 'serialize', return_value="")
def test_execute_with_check_if_job_exists(self, serialize, mock_training, mock_client):
mock_training.return_value = {
'TrainingJobArn': 'test_arn',
'ResponseMetadata': {'HTTPStatusCode': 200},
Expand All @@ -98,7 +101,8 @@ def test_execute_with_check_if_job_exists(self, mock_training, mock_client):

@mock.patch.object(SageMakerHook, 'get_conn')
@mock.patch.object(SageMakerHook, 'create_training_job')
def test_execute_without_check_if_job_exists(self, mock_training, mock_client):
@mock.patch.object(sagemaker, 'serialize', return_value="")
def test_execute_without_check_if_job_exists(self, serialize, mock_training, mock_client):
mock_training.return_value = {
'TrainingJobArn': 'test_arn',
'ResponseMetadata': {'HTTPStatusCode': 200},
Expand Down
13 changes: 9 additions & 4 deletions tests/providers/amazon/aws/operators/test_sagemaker_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.operators import sagemaker
from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTransformOperator

EXPECTED_INTEGER_FIELDS: List[List[str]] = [
Expand Down Expand Up @@ -65,7 +66,8 @@ def setUp(self):
@mock.patch.object(SageMakerHook, 'get_conn')
@mock.patch.object(SageMakerHook, 'create_model')
@mock.patch.object(SageMakerHook, 'create_transform_job')
def test_integer_fields(self, mock_transform, mock_model, mock_client):
@mock.patch.object(sagemaker, 'serialize', return_value="")
def test_integer_fields(self, serialize, mock_transform, mock_model, mock_client):
mock_transform.return_value = {
'TransformJobArn': 'test_arn',
'ResponseMetadata': {'HTTPStatusCode': 200},
Expand All @@ -82,7 +84,8 @@ def test_integer_fields(self, mock_transform, mock_model, mock_client):
@mock.patch.object(SageMakerHook, 'get_conn')
@mock.patch.object(SageMakerHook, 'create_model')
@mock.patch.object(SageMakerHook, 'create_transform_job')
def test_execute(self, mock_transform, mock_model, mock_client):
@mock.patch.object(sagemaker, 'serialize', return_value="")
def test_execute(self, serialize, mock_transform, mock_model, mock_client):
mock_transform.return_value = {
'TransformJobArn': 'test_arn',
'ResponseMetadata': {'HTTPStatusCode': 200},
Expand All @@ -106,7 +109,8 @@ def test_execute_with_failure(self, mock_transform, mock_model, mock_client):

@mock.patch.object(SageMakerHook, 'get_conn')
@mock.patch.object(SageMakerHook, 'create_transform_job')
def test_execute_with_check_if_job_exists(self, mock_transform, mock_client):
@mock.patch.object(sagemaker, 'serialize', return_value="")
def test_execute_with_check_if_job_exists(self, serialize, mock_transform, mock_client):
mock_transform.return_value = {
'TransformJobArn': 'test_arn',
'ResponseMetadata': {'HTTPStatusCode': 200},
Expand All @@ -123,7 +127,8 @@ def test_execute_with_check_if_job_exists(self, mock_transform, mock_client):

@mock.patch.object(SageMakerHook, 'get_conn')
@mock.patch.object(SageMakerHook, 'create_transform_job')
def test_execute_without_check_if_job_exists(self, mock_transform, mock_client):
@mock.patch.object(sagemaker, 'serialize', return_value="")
def test_execute_without_check_if_job_exists(self, serialize, mock_transform, mock_client):
mock_transform.return_value = {
'TransformJobArn': 'test_arn',
'ResponseMetadata': {'HTTPStatusCode': 200},
Expand Down
7 changes: 5 additions & 2 deletions tests/providers/amazon/aws/operators/test_sagemaker_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.operators import sagemaker
from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTuningOperator

EXPECTED_INTEGER_FIELDS: List[List[str]] = [
Expand Down Expand Up @@ -83,7 +84,8 @@ def setUp(self):

@mock.patch.object(SageMakerHook, 'get_conn')
@mock.patch.object(SageMakerHook, 'create_tuning_job')
def test_integer_fields(self, mock_tuning, mock_client):
@mock.patch.object(sagemaker, 'serialize', return_value="")
def test_integer_fields(self, serialize, mock_tuning, mock_client):
mock_tuning.return_value = {'TrainingJobArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 200}}
self.sagemaker.execute(None)
assert self.sagemaker.integer_fields == EXPECTED_INTEGER_FIELDS
Expand All @@ -92,7 +94,8 @@ def test_integer_fields(self, mock_tuning, mock_client):

@mock.patch.object(SageMakerHook, 'get_conn')
@mock.patch.object(SageMakerHook, 'create_tuning_job')
def test_execute(self, mock_tuning, mock_client):
@mock.patch.object(sagemaker, 'serialize', return_value="")
def test_execute(self, serialize, mock_tuning, mock_client):
mock_tuning.return_value = {'TrainingJobArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 200}}
self.sagemaker.execute(None)
mock_tuning.assert_called_once_with(
Expand Down
9 changes: 0 additions & 9 deletions tests/system/providers/amazon/aws/example_sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,6 @@ def delete_logs(env_id):
preprocess_raw_data = SageMakerProcessingOperator(
task_id='preprocess_raw_data',
config=test_setup['processing_config'],
do_xcom_push=False,
)
# [END howto_operator_sagemaker_processing]

Expand All @@ -429,23 +428,20 @@ def delete_logs(env_id):
config=test_setup['training_config'],
# Waits by default, setting as False to demonstrate the Sensor below.
wait_for_completion=False,
do_xcom_push=False,
)
# [END howto_operator_sagemaker_training]

# [START howto_sensor_sagemaker_training]
await_training = SageMakerTrainingSensor(
task_id="await_training",
job_name=test_setup['training_job_name'],
do_xcom_push=False,
)
# [END howto_sensor_sagemaker_training]

# [START howto_operator_sagemaker_model]
create_model = SageMakerModelOperator(
task_id='create_model',
config=test_setup['model_config'],
do_xcom_push=False,
)
# [END howto_operator_sagemaker_model]

Expand All @@ -455,15 +451,13 @@ def delete_logs(env_id):
config=test_setup['tuning_config'],
# Waits by default, setting as False to demonstrate the Sensor below.
wait_for_completion=False,
do_xcom_push=False,
)
# [END howto_operator_sagemaker_tuning]

# [START howto_sensor_sagemaker_tuning]
await_tune = SageMakerTuningSensor(
task_id="await_tuning",
job_name=test_setup['tuning_job_name'],
do_xcom_push=False,
)
# [END howto_sensor_sagemaker_tuning]

Expand All @@ -473,15 +467,13 @@ def delete_logs(env_id):
config=test_setup['transform_config'],
# Waits by default, setting as False to demonstrate the Sensor below.
wait_for_completion=False,
do_xcom_push=False,
)
# [END howto_operator_sagemaker_transform]

# [START howto_sensor_sagemaker_transform]
await_transform = SageMakerTransformSensor(
task_id="await_transform",
job_name=test_setup['transform_job_name'],
do_xcom_push=False,
)
# [END howto_sensor_sagemaker_transform]

Expand All @@ -490,7 +482,6 @@ def delete_logs(env_id):
task_id="delete_model",
config={'ModelName': test_setup['model_name']},
trigger_rule=TriggerRule.ALL_DONE,
do_xcom_push=False,
)
# [END howto_operator_sagemaker_delete_model]

Expand Down
Loading