From f0218a976ed3a4aabb927f2ff01b8dd60e06f2de Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Tue, 29 Mar 2022 10:36:53 -0700 Subject: [PATCH 1/6] Adds Sample DAGs and Docs for Amazon Sagemaker --- .../aws/example_dags/example_sagemaker.py | 264 +++++++++++------- .../providers/amazon/aws/hooks/sagemaker.py | 45 ++- .../amazon/aws/operators/sagemaker.py | 85 +++--- .../providers/amazon/aws/sensors/sagemaker.py | 59 ++-- .../operators/sagemaker.rst | 118 ++++++-- docs/spelling_wordlist.txt | 1 + 6 files changed, 369 insertions(+), 203 deletions(-) diff --git a/airflow/providers/amazon/aws/example_dags/example_sagemaker.py b/airflow/providers/amazon/aws/example_dags/example_sagemaker.py index 6162b061db465..df0950526c662 100644 --- a/airflow/providers/amazon/aws/example_dags/example_sagemaker.py +++ b/airflow/providers/amazon/aws/example_dags/example_sagemaker.py @@ -15,163 +15,219 @@ # specific language governing permissions and limitations # under the License. +import io +import os from datetime import datetime -from os import environ + +import numpy as np +import pandas as pd +import requests from airflow import DAG +from airflow.decorators import task +from airflow.providers.amazon.aws.hooks.s3 import S3Hook from airflow.providers.amazon.aws.operators.sagemaker import ( SageMakerDeleteModelOperator, SageMakerModelOperator, - SageMakerProcessingOperator, SageMakerTrainingOperator, SageMakerTransformOperator, ) +from airflow.providers.amazon.aws.sensors.sagemaker import SageMakerTrainingSensor, SageMakerTransformSensor -MODEL_NAME = "sample_model" -TRAINING_JOB_NAME = "sample_training" -IMAGE_URI = environ.get("ECR_IMAGE_URI", "123456789012.dkr.ecr.us-east-1.amazonaws.com/repo_name") -S3_BUCKET = environ.get("BUCKET_NAME", "test-airflow-12345") -ROLE = environ.get("SAGEMAKER_ROLE_ARN", "arn:aws:iam::123456789012:role/role_name") +# This Sample DAG demonstrates using SageMaker to identify various species of Iris flower. +# The Project Name variable below will be used to name various tasks and the required S3 keys. 
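+# NOTE: TIMESTAMP below is an Airflow template string, not a literal value. It is only rendered
+# once it reaches a templated field (the SageMaker operators template `config` and the sensors
+# template `job_name`), where it resolves to the run's logical timestamp, e.g. '20220329T103653'.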
+PROJECT_NAME = 'iris' +TIMESTAMP = '{{ ts_nodash }}' -SAGEMAKER_PROCESSING_JOB_CONFIG = { - "ProcessingJobName": "sample_processing_job", - "ProcessingInputs": [ - { - "InputName": "input", - "AppManaged": False, - "S3Input": { - "S3Uri": f"s3://{S3_BUCKET}/preprocessing/input/", - "LocalPath": "/opt/ml/processing/input/", - "S3DataType": "S3Prefix", - "S3InputMode": "File", - "S3DataDistributionType": "FullyReplicated", - "S3CompressionType": "None", - }, - }, - ], - "ProcessingOutputConfig": { - "Outputs": [ - { - "OutputName": "output", - "S3Output": { - "S3Uri": f"s3://{S3_BUCKET}/preprocessing/output/", - "LocalPath": "/opt/ml/processing/output/", - "S3UploadMode": "EndOfJob", - }, - "AppManaged": False, - } - ] - }, - "ProcessingResources": { - "ClusterConfig": { - "InstanceCount": 1, - "InstanceType": "ml.m5.large", - "VolumeSizeInGB": 5, - } - }, - "StoppingCondition": {"MaxRuntimeInSeconds": 3600}, - "AppSpecification": { - "ImageUri": f"{IMAGE_URI}", - "ContainerEntrypoint": ["python3", "./preprocessing.py"], - }, - "RoleArn": f"{ROLE}", -} +S3_BUCKET = os.getenv('S3_BUCKET', 'S3_bucket') +INPUT_S3_KEY = f'{PROJECT_NAME}/processed-input-data' +OUTPUT_S3_KEY = f'{PROJECT_NAME}/results' +MODEL_NAME = f'{PROJECT_NAME}-KNN-model' +TRAINING_JOB_NAME = f'{PROJECT_NAME}-train-{TIMESTAMP}' + +ROLE_ARN = os.getenv( + 'SAGEMAKER_ROLE_ARN', + 'arn:aws:iam::1234567890:role/service-role/AmazonSageMaker-ExecutionRole', +) -SAGEMAKER_TRAINING_JOB_CONFIG = { +# A Sample dataset hosted by UC Irvine's machine learning repository +DATA_URL = 'https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data' + +# The URI of an Amazon-provided docker image for handling KNN model training. This is a public ECR +# repo cited in public SageMaker documentation, so the account number does not need to be redacted. 
+# For more info see: https://docs.aws.amazon.com/sagemaker/latest/dg/ecr-us-west-2.html#knn-us-west-2.title +KNN_IMAGE_URI = '174872318107.dkr.ecr.us-west-2.amazonaws.com/knn' + +# Define configs for training, model creation, and batch transform jobs +TRAINING_CONFIG = { "AlgorithmSpecification": { - "TrainingImage": f"{IMAGE_URI}", + "TrainingImage": KNN_IMAGE_URI, "TrainingInputMode": "File", }, + "HyperParameters": { + "predictor_type": "classifier", + "feature_dim": "4", + "k": "3", + "sample_size": "150", + }, "InputDataConfig": [ { - "ChannelName": "config", + "ChannelName": "train", "DataSource": { "S3DataSource": { "S3DataType": "S3Prefix", - "S3Uri": f"s3://{S3_BUCKET}/config/", - "S3DataDistributionType": "FullyReplicated", + "S3Uri": f"s3://{S3_BUCKET}/{INPUT_S3_KEY}/train.csv", } }, - "CompressionType": "None", - "RecordWrapperType": "None", - }, + "ContentType": "text/csv", + "InputMode": "File", + } ], - "OutputDataConfig": { - "KmsKeyId": "", - "S3OutputPath": f"s3://{S3_BUCKET}/training/", - }, + "OutputDataConfig": {"S3OutputPath": f"s3://{S3_BUCKET}/{OUTPUT_S3_KEY}/"}, "ResourceConfig": { - "InstanceType": "ml.m5.large", "InstanceCount": 1, - "VolumeSizeInGB": 5, + "InstanceType": "ml.m5.large", + "VolumeSizeInGB": 1, }, + "RoleArn": ROLE_ARN, "StoppingCondition": {"MaxRuntimeInSeconds": 6000}, - "RoleArn": f"{ROLE}", - "EnableNetworkIsolation": False, - "EnableInterContainerTrafficEncryption": False, - "EnableManagedSpotTraining": False, "TrainingJobName": TRAINING_JOB_NAME, } -SAGEMAKER_CREATE_MODEL_CONFIG = { +MODEL_CONFIG = { + "ExecutionRoleArn": ROLE_ARN, "ModelName": MODEL_NAME, - "Containers": [ - { - "Image": f"{IMAGE_URI}", - "Mode": "SingleModel", - "ModelDataUrl": f"s3://{S3_BUCKET}/training/{TRAINING_JOB_NAME}/output/model.tar.gz", - } - ], - "ExecutionRoleArn": f"{ROLE}", - "EnableNetworkIsolation": False, + "PrimaryContainer": { + "Mode": "SingleModel", + "Image": KNN_IMAGE_URI, + "ModelDataUrl": f"s3://{S3_BUCKET}/{OUTPUT_S3_KEY}/{TRAINING_JOB_NAME}/output/model.tar.gz", + }, } -SAGEMAKER_INFERENCE_CONFIG = { - "TransformJobName": "sample_transform_job", - "ModelName": MODEL_NAME, +TRANSFORM_CONFIG = { + # Transform job names can't be reused, so appending a full timestamp tp ensure it is unique. 
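+    # The same templated string is passed as `job_name` to the SageMakerTransformSensor further
+    # down, so the sensor polls exactly the job that this config starts.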
+ "TransformJobName": f"test-{PROJECT_NAME}-knn-{TIMESTAMP}", "TransformInput": { "DataSource": { "S3DataSource": { "S3DataType": "S3Prefix", - "S3Uri": f"s3://{S3_BUCKET}/config/config_date.yml", + "S3Uri": f"s3://{S3_BUCKET}/{INPUT_S3_KEY}/test.csv", } }, - "ContentType": "application/x-yaml", - "CompressionType": "None", - "SplitType": "None", + "SplitType": "Line", + "ContentType": "text/csv", }, - "TransformOutput": {"S3OutputPath": f"s3://{S3_BUCKET}/inferencing/output/"}, - "TransformResources": {"InstanceType": "ml.m5.large", "InstanceCount": 1}, + "TransformOutput": {"S3OutputPath": f"s3://{S3_BUCKET}/{OUTPUT_S3_KEY}"}, + "TransformResources": { + "InstanceCount": 1, + "InstanceType": "ml.m5.large", + }, + "ModelName": MODEL_NAME, } -# [START howto_operator_sagemaker] + +@task +def data_prep(data_url, s3_bucket, input_s3_key): + """ + Grabs the Iris dataset from API, splits into train/test splits, and saves CSV's to S3 using S3 Hook + """ + # Get data from API + iris_response = requests.get(data_url).content + columns = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'species'] + iris = pd.read_csv(io.StringIO(iris_response.decode('utf-8')), names=columns) + + # Process data + iris['species'] = iris['species'].replace({'Iris-virginica': 0, 'Iris-versicolor': 1, 'Iris-setosa': 2}) + iris = iris[['species', 'sepal_length', 'sepal_width', 'petal_length', 'petal_width']] + + # Split into test and train data + iris_train, iris_test = np.split( + iris.sample(frac=1, random_state=np.random.RandomState()), [int(0.7 * len(iris))] + ) + iris_test.drop(['species'], axis=1, inplace=True) + + # Save files to S3 + iris_train.to_csv('iris_train.csv', index=False, header=False) + iris_test.to_csv('iris_test.csv', index=False, header=False) + s3_hook = S3Hook(aws_conn_id='aws-sagemaker') + s3_hook.load_file( + 'iris_train.csv', + f'{input_s3_key}/train.csv', + bucket_name=s3_bucket, + replace=True, + ) + s3_hook.load_file( + 'iris_test.csv', + f'{input_s3_key}/test.csv', + bucket_name=s3_bucket, + replace=True, + ) + + with DAG( - "sample_sagemaker_dag", + dag_id='example_sagemaker', schedule_interval=None, - start_date=datetime(2022, 2, 21), + start_date=datetime(2021, 1, 1), + tags=['example'], catchup=False, ) as dag: - sagemaker_processing_task = SageMakerProcessingOperator( - config=SAGEMAKER_PROCESSING_JOB_CONFIG, - aws_conn_id="aws_default", - task_id="sagemaker_preprocessing_task", + + # [START howto_operator_sagemaker_training] + train_model = SageMakerTrainingOperator( + task_id='train_model', + config=TRAINING_CONFIG, + # Waits by default, setting as False to demonstrate the Sensor below. 
+ wait_for_completion=False, + do_xcom_push=False, ) + # [END howto_operator_sagemaker_training] - training_task = SageMakerTrainingOperator( - config=SAGEMAKER_TRAINING_JOB_CONFIG, aws_conn_id="aws_default", task_id="sagemaker_training_task" + # [START howto_operator_sagemaker_training_sensor] + await_training = SageMakerTrainingSensor( + task_id="await_training", + job_name=TRAINING_JOB_NAME, ) + # [END howto_operator_sagemaker_training_sensor] - model_create_task = SageMakerModelOperator( - config=SAGEMAKER_CREATE_MODEL_CONFIG, aws_conn_id="aws_default", task_id="sagemaker_create_model_task" + # [START howto_operator_sagemaker_model] + create_model = SageMakerModelOperator( + task_id='create_model', + config=MODEL_CONFIG, + do_xcom_push=False, ) + # [END howto_operator_sagemaker_model] - inference_task = SageMakerTransformOperator( - config=SAGEMAKER_INFERENCE_CONFIG, aws_conn_id="aws_default", task_id="sagemaker_inference_task" + # [START howto_operator_sagemaker_transform] + test_model = SageMakerTransformOperator( + task_id='test_model', + config=TRANSFORM_CONFIG, + # Waits by default, setting as False to demonstrate the Sensor below. + wait_for_completion=False, + do_xcom_push=False, ) + # [END howto_operator_sagemaker_transform] - model_delete_task = SageMakerDeleteModelOperator( - task_id="sagemaker_delete_model_task", config={'ModelName': MODEL_NAME}, aws_conn_id="aws_default" + # [START howto_operator_sagemaker_transform_sensor] + await_transform = SageMakerTransformSensor( + task_id="await_transform", + job_name=f"test-{PROJECT_NAME}-knn-{TIMESTAMP}", ) + # [END howto_operator_sagemaker_transform_sensor] - sagemaker_processing_task >> training_task >> model_create_task >> inference_task >> model_delete_task - # [END howto_operator_sagemaker] + # [START howto_operator_sagemaker_delete_model] + delete_model = SageMakerDeleteModelOperator( + task_id="delete_model", + config={'ModelName': MODEL_NAME}, + trigger_rule='all_done', + ) + # [END howto_operator_sagemaker_delete_model] + + ( + data_prep(DATA_URL, S3_BUCKET, INPUT_S3_KEY) + >> train_model + >> await_training + >> create_model + >> test_model + >> await_transform + >> delete_model + ) diff --git a/airflow/providers/amazon/aws/hooks/sagemaker.py b/airflow/providers/amazon/aws/hooks/sagemaker.py index 73348c8555aa6..2c8c28a738ec3 100644 --- a/airflow/providers/amazon/aws/hooks/sagemaker.py +++ b/airflow/providers/amazon/aws/hooks/sagemaker.py @@ -310,7 +310,8 @@ def create_training_job( max_ingestion_time: Optional[int] = None, ): """ - Create a training job + Starts a model training job. After training completes, Amazon SageMaker saves + the resulting model artifacts to an Amazon S3 location that you specify. :param config: the config for training :param wait_for_completion: if the program should keep running until job finishes @@ -357,7 +358,11 @@ def create_tuning_job( max_ingestion_time: Optional[int] = None, ): """ - Create a tuning job + Starts a hyperparameter tuning job. A hyperparameter tuning job finds the + best version of a model by running many training jobs on your dataset using + the algorithm you choose and values for hyperparameters within ranges that + you specify. It then chooses the hyperparameter values that result in a model + that performs the best, as measured by an objective metric that you choose. 
:param config: the config for tuning :param wait_for_completion: if the program should keep running until job finishes @@ -389,7 +394,8 @@ def create_transform_job( max_ingestion_time: Optional[int] = None, ): """ - Create a transform job + Starts a transform job. A transform job uses a trained model to get inferences + on a dataset and saves these results to an Amazon S3 location that you specify. :param config: the config for transform job :param wait_for_completion: if the program should keep running until job finishes @@ -422,7 +428,10 @@ def create_processing_job( max_ingestion_time: Optional[int] = None, ): """ - Create a processing job + Use Amazon SageMaker Processing to analyze data and evaluate machine learning + models on Amazon SageMaker. With Processing, you can use a simplified, managed + experience on SageMaker to run your data processing workloads, such as feature + engineering, data validation, model evaluation, and model interpretation. :param config: the config for processing job :param wait_for_completion: if the program should keep running until job finishes @@ -446,7 +455,10 @@ def create_processing_job( def create_model(self, config: dict): """ - Create a model job + Creates a model in Amazon SageMaker. In the request, you name the model and + describe a primary container. For the primary container, you specify the Docker + image that contains inference code, artifacts (from prior training), and a custom + environment map that the inference code uses when you deploy the model for predictions. :param config: the config for model :return: A response to model creation @@ -455,7 +467,14 @@ def create_model(self, config: dict): def create_endpoint_config(self, config: dict): """ - Create an endpoint config + Creates an endpoint configuration that Amazon SageMaker hosting + services uses to deploy models. In the configuration, you identify + one or more models, created using the CreateModel API, to deploy and + the resources that you want Amazon SageMaker to provision. + + .. seealso:: + :class:`~airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook.create_model` + :class:`~airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook.create_endpoint` :param config: the config for endpoint-config :return: A response to endpoint config creation @@ -470,7 +489,15 @@ def create_endpoint( max_ingestion_time: Optional[int] = None, ): """ - Create an endpoint + When you create a serverless endpoint, SageMaker provisions and manages + the compute resources for you. Then, you can make inference requests to + the endpoint and receive model predictions in response. SageMaker scales + the compute resources up and down as needed to handle your request traffic. + + Requires an Endpoint Config. + .. seealso:: + :class:`~airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook.create_endpoint_config` + :param config: the config for endpoint :param wait_for_completion: if the program should keep running until job finishes @@ -501,7 +528,9 @@ def update_endpoint( max_ingestion_time: Optional[int] = None, ): """ - Update an endpoint + Deploys the new EndpointConfig specified in the request, switches to using + newly created endpoint, and then deletes resources provisioned for the + endpoint using the previous EndpointConfig (there is no availability loss). 
:param config: the config for endpoint :param wait_for_completion: if the program should keep running until job finishes diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py b/airflow/providers/amazon/aws/operators/sagemaker.py index 071c167e9a0d1..3638f05a1a3d1 100644 --- a/airflow/providers/amazon/aws/operators/sagemaker.py +++ b/airflow/providers/amazon/aws/operators/sagemaker.py @@ -105,14 +105,14 @@ def hook(self): class SageMakerProcessingOperator(SageMakerBaseOperator): - """Initiate a SageMaker processing job. - - This operator returns The ARN of the processing job created in Amazon SageMaker. + """ + Use Amazon SageMaker Processing to analyze data and evaluate machine learning + models on Amazon SageMake. With Processing, you can use a simplified, managed + experience on SageMaker to run your data processing workloads, such as feature + engineering, data validation, model evaluation, and model interpretation. :param config: The configuration necessary to start a processing job (templated). - For details of the configuration parameter see :py:meth:`SageMaker.Client.create_processing_job` - :param aws_conn_id: The AWS connection ID to use. :param wait_for_completion: If wait is set to True, the time interval, in seconds, that the operation waits to check the status of the processing job. :param print_log: if the operator should print the cloudwatch log during processing @@ -123,13 +123,13 @@ class SageMakerProcessingOperator(SageMakerBaseOperator): the operation does not timeout. :param action_if_job_exists: Behaviour if the job name already exists. Possible options are "increment" (default) and "fail". + :return Dict: Returns The ARN of the processing job created in Amazon SageMaker. """ def __init__( self, *, config: dict, - aws_conn_id: str, wait_for_completion: bool = True, print_log: bool = True, check_interval: int = 30, @@ -137,7 +137,7 @@ def __init__( action_if_job_exists: str = 'increment', **kwargs, ): - super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs) + super().__init__(config=config, **kwargs) if action_if_job_exists not in ('increment', 'fail'): raise AirflowException( f"Argument action_if_job_exists accepts only 'increment' and 'fail'. \ @@ -185,14 +185,16 @@ def execute(self, context: 'Context') -> dict: class SageMakerEndpointConfigOperator(SageMakerBaseOperator): """ - Create a SageMaker endpoint config. - - This operator returns The ARN of the endpoint config created in Amazon SageMaker + Creates an endpoint configuration that Amazon SageMaker hosting + services uses to deploy models. In the configuration, you identify + one or more models, created using the CreateModel API, to deploy and + the resources that you want Amazon SageMaker to provision. :param config: The configuration necessary to create an endpoint config. For details of the configuration parameter see :py:meth:`SageMaker.Client.create_endpoint_config` :param aws_conn_id: The AWS connection ID to use. + :return Dict: Returns The ARN of the endpoint config created in Amazon SageMaker. """ integer_fields = [['ProductionVariants', 'InitialInstanceCount']] @@ -213,9 +215,12 @@ def execute(self, context: 'Context') -> dict: class SageMakerEndpointOperator(SageMakerBaseOperator): """ - Create a SageMaker endpoint. + When you create a serverless endpoint, SageMaker provisions and manages + the compute resources for you. Then, you can make inference requests to + the endpoint and receive model predictions in response. 
SageMaker scales + the compute resources up and down as needed to handle your request traffic. - This operator returns The ARN of the endpoint created in Amazon SageMaker + Requires an Endpoint Config. :param config: The configuration necessary to create an endpoint. @@ -242,13 +247,13 @@ class SageMakerEndpointOperator(SageMakerBaseOperator): For details of the configuration parameter of endpoint_configuration see :py:meth:`SageMaker.Client.create_endpoint` - :param aws_conn_id: The AWS connection ID to use. :param wait_for_completion: Whether the operator should wait until the endpoint creation finishes. :param check_interval: If wait is set to True, this is the time interval, in seconds, that this operation waits before polling the status of the endpoint creation. :param max_ingestion_time: If wait is set to True, this operation fails if the endpoint creation doesn't finish within max_ingestion_time seconds. If you set this parameter to None it never times out. :param operation: Whether to create an endpoint or update an endpoint. Must be either 'create or 'update'. + :return Dict: Returns The ARN of the endpoint created in Amazon SageMaker. """ def __init__( @@ -331,9 +336,13 @@ def execute(self, context: 'Context') -> dict: class SageMakerTransformOperator(SageMakerBaseOperator): - """Initiate a SageMaker transform job. + """ + Starts a transform job. A transform job uses a trained model to get inferences + on a dataset and saves these results to an Amazon S3 location that you specify. - This operator returns The ARN of the model created in Amazon SageMaker. + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:SageMakerTransformOperator` :param config: The configuration necessary to start a transform job (templated). @@ -354,13 +363,13 @@ class SageMakerTransformOperator(SageMakerBaseOperator): For details of the configuration parameter of model_config, See: :py:meth:`SageMaker.Client.create_model` - :param aws_conn_id: The AWS connection ID to use. :param wait_for_completion: Set to True to wait until the transform job finishes. :param check_interval: If wait is set to True, the time interval, in seconds, that this operation waits to check the status of the transform job. :param max_ingestion_time: If wait is set to True, the operation fails if the transform job doesn't finish within max_ingestion_time seconds. If you set this parameter to None, the operation does not timeout. + :return Dict: Returns The ARN of the model created in Amazon SageMaker. """ def __init__( @@ -422,21 +431,24 @@ def execute(self, context: 'Context') -> dict: class SageMakerTuningOperator(SageMakerBaseOperator): - """Initiate a SageMaker hyperparameter tuning job. - - This operator returns The ARN of the tuning job created in Amazon SageMaker. + """ + Starts a hyperparameter tuning job. A hyperparameter tuning job finds the + best version of a model by running many training jobs on your dataset using + the algorithm you choose and values for hyperparameters within ranges that + you specify. It then chooses the hyperparameter values that result in a model + that performs the best, as measured by an objective metric that you choose. :param config: The configuration necessary to start a tuning job (templated). For details of the configuration parameter see :py:meth:`SageMaker.Client.create_hyper_parameter_tuning_job` - :param aws_conn_id: The AWS connection ID to use. :param wait_for_completion: Set to True to wait until the tuning job finishes. 
:param check_interval: If wait is set to True, the time interval, in seconds, that this operation waits to check the status of the tuning job. :param max_ingestion_time: If wait is set to True, the operation fails if the tuning job doesn't finish within max_ingestion_time seconds. If you set this parameter to None, the operation does not timeout. + :return Dict: Returns The ARN of the tuning job created in Amazon SageMaker. """ integer_fields = [ @@ -487,14 +499,20 @@ def execute(self, context: 'Context') -> dict: class SageMakerModelOperator(SageMakerBaseOperator): - """Create a SageMaker model. + """ + Creates a model in Amazon SageMaker. In the request, you name the model and + describe a primary container. For the primary container, you specify the Docker + image that contains inference code, artifacts (from prior training), and a custom + environment map that the inference code uses when you deploy the model for predictions. - This operator returns The ARN of the model created in Amazon SageMaker + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:SageMakerModelOperator` :param config: The configuration necessary to create a model. For details of the configuration parameter see :py:meth:`SageMaker.Client.create_model` - :param aws_conn_id: The AWS connection ID to use. + :return Dict: Returns The ARN of the model created in Amazon SageMaker. """ def __init__(self, *, config, **kwargs): @@ -518,14 +536,16 @@ def execute(self, context: 'Context') -> dict: class SageMakerTrainingOperator(SageMakerBaseOperator): """ - Initiate a SageMaker training job. + Starts a model training job. After training completes, Amazon SageMaker saves + the resulting model artifacts to an Amazon S3 location that you specify. - This operator returns The ARN of the training job created in Amazon SageMaker. + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:SageMakerTrainingOperator` :param config: The configuration necessary to start a training job (templated). For details of the configuration parameter see :py:meth:`SageMaker.Client.create_training_job` - :param aws_conn_id: The AWS connection ID to use. :param wait_for_completion: If wait is set to True, the time interval, in seconds, that the operation waits to check the status of the training job. :param print_log: if the operator should print the cloudwatch log during training @@ -539,6 +559,7 @@ class SageMakerTrainingOperator(SageMakerBaseOperator): :param action_if_job_exists: Behaviour if the job name already exists. Possible options are "increment" (default) and "fail". This is only relevant if check_if + :return Dict: Returns The ARN of the training job created in Amazon SageMaker. """ integer_fields = [ @@ -611,19 +632,19 @@ def _check_if_job_exists(self) -> None: class SageMakerDeleteModelOperator(SageMakerBaseOperator): - """Deletes a SageMaker model. + """ + Deletes a SageMaker model. - This operator deletes the Model entry created in SageMaker. + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:SageMakerDeleteModelOperator` :param config: The configuration necessary to delete the model. - For details of the configuration parameter see :py:meth:`SageMaker.Client.delete_model` - :param aws_conn_id: The AWS connection ID to use. 
""" - def __init__(self, *, config, aws_conn_id: str, **kwargs): + def __init__(self, *, config, **kwargs): super().__init__(config=config, **kwargs) - self.aws_conn_id = aws_conn_id self.config = config def execute(self, context: 'Context') -> Any: diff --git a/airflow/providers/amazon/aws/sensors/sagemaker.py b/airflow/providers/amazon/aws/sensors/sagemaker.py index 054b139cc2eb9..bf833c458879b 100644 --- a/airflow/providers/amazon/aws/sensors/sagemaker.py +++ b/airflow/providers/amazon/aws/sensors/sagemaker.py @@ -27,10 +27,10 @@ class SageMakerBaseSensor(BaseSensorOperator): - """Contains general sensor behavior for SageMaker. + """ + Contains general sensor behavior for SageMaker. - Subclasses should implement get_sagemaker_response() - and state_from_response() methods. + Subclasses should implement get_sagemaker_response() and state_from_response() methods. Subclasses should also implement NON_TERMINAL_STATES and FAILED_STATE methods. """ @@ -84,13 +84,11 @@ def state_from_response(self, response: dict) -> str: class SageMakerEndpointSensor(SageMakerBaseSensor): - """Asks for the state of the endpoint state until it reaches a - terminal state. - If it fails the sensor errors, the task fails. - - - :param job_name: job_name of the endpoint instance to check the state of + """ + Polls the endpoint state until it reaches a terminal state. Raises an + AirflowException with the failure reason if a failed state is reached. + :param endpoint_name: Name of the endpoint instance to watch. """ template_fields: Sequence[str] = ('endpoint_name',) @@ -118,15 +116,15 @@ def state_from_response(self, response): class SageMakerTransformSensor(SageMakerBaseSensor): - """Asks for the state of the transform state until it reaches a - terminal state. - The sensor will error if the job errors, throwing a - AirflowException - containing the failure reason. + """ + Polls the transform job until it reaches a terminal state. Raises an + AirflowException with the failure reason if a failed state is reached. - :param - job_name: job_name of the transform job instance to check the state of + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/sensor:SageMakerTransformSensor` + :param job_name: Name of the transform job to watch. """ template_fields: Sequence[str] = ('job_name',) @@ -154,16 +152,11 @@ def state_from_response(self, response): class SageMakerTuningSensor(SageMakerBaseSensor): - """Asks for the state of the tuning state until it reaches a terminal - state. - The sensor will error if the job errors, throwing a - AirflowException - containing the failure reason. - - :param - job_name: job_name of the tuning instance to check the state of - :type - job_name: str + """ + Asks for the state of the tuning state until it reaches a terminal state. + Raises an AirflowException with the failure reason if a failed state is reached. + + :param job_name: Name of the tuning instance to watch. """ template_fields: Sequence[str] = ('job_name',) @@ -191,14 +184,16 @@ def state_from_response(self, response): class SageMakerTrainingSensor(SageMakerBaseSensor): - """Asks for the state of the training state until it reaches a - terminal state. - If it fails the sensor errors, failing the task. - + """ + Polls the training job until it reaches a terminal state. Raises an + AirflowException with the failure reason if a failed state is reached. - :param job_name: name of the SageMaker training job to check the state of + .. 
seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/sensor:SageMakerTrainingSensor` - :param print_log: if the operator should print the cloudwatch log + :param job_name: Name of the training job to watch. + :param print_log: Prints the cloudwatch log if True; Defaults to True. """ template_fields: Sequence[str] = ('job_name',) diff --git a/docs/apache-airflow-providers-amazon/operators/sagemaker.rst b/docs/apache-airflow-providers-amazon/operators/sagemaker.rst index f44d258a2718f..8562d7916d734 100644 --- a/docs/apache-airflow-providers-amazon/operators/sagemaker.rst +++ b/docs/apache-airflow-providers-amazon/operators/sagemaker.rst @@ -15,52 +15,116 @@ specific language governing permissions and limitations under the License. - Amazon SageMaker Operators -======================================== +========================== + +`Amazon SageMaker `__ is a fully managed +machine learning service. With Amazon SageMaker, data scientists and developers +can quickly build and train machine learning models, and then deploy them into a +production-ready hosted environment. + +Airflow provides operators to create and interact with SageMaker Jobs. Prerequisite Tasks ------------------ .. include:: _partials/prerequisite_tasks.rst -Overview --------- +Manage Amazon SageMaker Jobs +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. _howto/operator:SageMakerTrainingOperator: + +Create an Amazon SageMaker Training Job +""""""""""""""""""""""""""""""""""""""" + +To create an Amazon Sagemaker training job you can use +:class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerTrainingOperator`. + +.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_sagemaker.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_sagemaker_training] + :end-before: [END howto_operator_sagemaker_training] + +.. _howto/operator:SageMakerModelOperator: + +Create an Amazon SageMaker Model +"""""""""""""""""""""""""""""""" + +To create an Amazon Sagemaker model you can use +:class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerModelOperator`. + +.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_sagemaker.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_sagemaker_model] + :end-before: [END howto_operator_sagemaker_model] + +.. _howto/operator:SageMakerDeleteModelOperator: -Airflow to Amazon SageMaker integration provides several operators to create and interact with -SageMaker Jobs. +Delete an Amazon SageMaker Model +"""""""""""""""""""""""""""""""" - - :class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerDeleteModelOperator` - - :class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerModelOperator` - - :class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerProcessingOperator` - - :class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerTrainingOperator` - - :class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerTransformOperator` - - :class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerTuningOperator` +To delete an Amazon Sagemaker model you can use +:class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerDeleteModelOperator`. -Purpose -""""""" +.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_sagemaker.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_sagemaker_delete_model] + :end-before: [END howto_operator_sagemaker_delete_model] + +.. 
_howto/operator:SageMakerTransformOperator: + +Create an Amazon SageMaker Transform Job +"""""""""""""""""""""""""""""""""""""""" + +To create an Amazon Sagemaker transform job you can use +:class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerTransformOperator`. + +.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_sagemaker.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_sagemaker_transform] + :end-before: [END howto_operator_sagemaker_transform] + + +Amazon SageMaker Sensors +^^^^^^^^^^^^^^^^^^^^^^^^ + +.. _howto/sensor:SageMakerTrainingSensor: + +Amazon SageMaker Training Sensor +"""""""""""""""""""""""""""""""" + +To check the state of an Amazon Sagemaker training job until it reaches a terminal state +you can use :class:`~airflow.providers.amazon.aws.sensors.sagemaker.SageMakerTrainingSensor`. + +.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_sagemaker.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_sagemaker_training_sensor] + :end-before: [END howto_operator_sagemaker_training_sensor] -This example DAG ``example_sagemaker.py`` uses ``SageMakerProcessingOperator``, ``SageMakerTrainingOperator``, -``SageMakerModelOperator``, ``SageMakerDeleteModelOperator`` and ``SageMakerTransformOperator`` to -create SageMaker processing job, run the training job, -generate the models artifact in s3, create the model, -, run SageMaker Batch inference and delete the model from SageMaker. +.. _howto/sensor:SageMakerTransformSensor: -Defining tasks -"""""""""""""" +Amazon SageMaker Transform Sensor +""""""""""""""""""""""""""""""""""" -In the following code we create a SageMaker processing, -training, Sagemaker Model, batch transform job and -then delete the model. +To check the state of an Amazon Sagemaker transform job until it reaches a terminal state +you can use :class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerTransformOperator`. .. 
exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_sagemaker.py :language: python - :start-after: [START howto_operator_sagemaker] - :end-before: [END howto_operator_sagemaker] + :dedent: 4 + :start-after: [START howto_operator_sagemaker_transform_sensor] + :end-before: [END howto_operator_sagemaker_transform_sensor] Reference ---------- +^^^^^^^^^ For further information, look at: * `Boto3 Library Documentation for Sagemaker `__ +* `Amazon SageMaker Developer Guide `__ diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index b9f2aad173a90..b55f42eaad761 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -909,6 +909,7 @@ https httpx hvac hyperparameter +hyperparameters iPython iTerm iam From 1386ca214a42a4709000bb45419ec2a35fa15c88 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Tue, 5 Apr 2022 16:37:32 -0700 Subject: [PATCH 2/6] Adds SageMaker Tuning to the Sample DAG and Docs --- .../aws/example_dags/example_sagemaker.py | 144 ++++++++++++++---- .../amazon/aws/operators/sagemaker.py | 4 + .../providers/amazon/aws/sensors/sagemaker.py | 8 +- .../operators/sagemaker.rst | 28 ++++ 4 files changed, 153 insertions(+), 31 deletions(-) diff --git a/airflow/providers/amazon/aws/example_dags/example_sagemaker.py b/airflow/providers/amazon/aws/example_dags/example_sagemaker.py index df0950526c662..9c28b53e100f4 100644 --- a/airflow/providers/amazon/aws/example_dags/example_sagemaker.py +++ b/airflow/providers/amazon/aws/example_dags/example_sagemaker.py @@ -18,6 +18,7 @@ import io import os from datetime import datetime +from typing import Mapping import numpy as np import pandas as pd @@ -31,25 +32,50 @@ SageMakerModelOperator, SageMakerTrainingOperator, SageMakerTransformOperator, + SageMakerTuningOperator, +) +from airflow.providers.amazon.aws.sensors.sagemaker import ( + SageMakerTrainingSensor, + SageMakerTransformSensor, + SageMakerTuningSensor, ) -from airflow.providers.amazon.aws.sensors.sagemaker import SageMakerTrainingSensor, SageMakerTransformSensor -# This Sample DAG demonstrates using SageMaker to identify various species of Iris flower. -# The Project Name variable below will be used to name various tasks and the required S3 keys. +# Project name will be used in naming the S3 buckets and various tasks. +# The dataset used in this example is identifying varieties of the Iris flower. PROJECT_NAME = 'iris' TIMESTAMP = '{{ ts_nodash }}' S3_BUCKET = os.getenv('S3_BUCKET', 'S3_bucket') -INPUT_S3_KEY = f'{PROJECT_NAME}/processed-input-data' -OUTPUT_S3_KEY = f'{PROJECT_NAME}/results' -MODEL_NAME = f'{PROJECT_NAME}-KNN-model' -TRAINING_JOB_NAME = f'{PROJECT_NAME}-train-{TIMESTAMP}' - ROLE_ARN = os.getenv( 'SAGEMAKER_ROLE_ARN', 'arn:aws:iam::1234567890:role/service-role/AmazonSageMaker-ExecutionRole', ) +INPUT_DATA_S3_KEY = f'{PROJECT_NAME}/processed-input-data' +TRAINING_DATA_SOURCE: Mapping[str, str] = { + "CompressionType": "None", + "ContentType": "text/csv", + "DataSource": { # type: ignore + "S3DataDistributionType": "FullyReplicated", + "S3DataType": "S3Prefix", + "S3Uri": f's3://{S3_BUCKET}/{INPUT_DATA_S3_KEY}/train.csv', + }, +} +TRAINING_OUTPUT_S3_KEY = f'{PROJECT_NAME}/results' +PREDICTION_OUTPUT_S3_KEY = f'{PROJECT_NAME}/transform' + +MODEL_NAME = f'{PROJECT_NAME}-KNN-model' +# Job names can't be reused, so appending a timestamp to ensure it is unique. 
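+# The model name, by contrast, is reused across runs; that works here because the DAG deletes
+# the model again in its final task.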
+TRAINING_JOB_NAME = f'{PROJECT_NAME}-train-{TIMESTAMP}' +TRANSFORM_JOB_NAME = f'{PROJECT_NAME}-transform-{TIMESTAMP}' +TUNING_JOB_NAME = f'{PROJECT_NAME}-tune-{TIMESTAMP}' + +RESOURCE_CONFIG = { + "InstanceCount": 1, + "InstanceType": "ml.m5.large", + "VolumeSizeInGB": 1, +} + # A Sample dataset hosted by UC Irvine's machine learning repository DATA_URL = 'https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data' @@ -73,22 +99,11 @@ "InputDataConfig": [ { "ChannelName": "train", - "DataSource": { - "S3DataSource": { - "S3DataType": "S3Prefix", - "S3Uri": f"s3://{S3_BUCKET}/{INPUT_S3_KEY}/train.csv", - } - }, - "ContentType": "text/csv", - "InputMode": "File", + **TRAINING_DATA_SOURCE, } ], - "OutputDataConfig": {"S3OutputPath": f"s3://{S3_BUCKET}/{OUTPUT_S3_KEY}/"}, - "ResourceConfig": { - "InstanceCount": 1, - "InstanceType": "ml.m5.large", - "VolumeSizeInGB": 1, - }, + "OutputDataConfig": {"S3OutputPath": f's3://{S3_BUCKET}/{TRAINING_OUTPUT_S3_KEY}/'}, + "ResourceConfig": RESOURCE_CONFIG, "RoleArn": ROLE_ARN, "StoppingCondition": {"MaxRuntimeInSeconds": 6000}, "TrainingJobName": TRAINING_JOB_NAME, @@ -100,24 +115,23 @@ "PrimaryContainer": { "Mode": "SingleModel", "Image": KNN_IMAGE_URI, - "ModelDataUrl": f"s3://{S3_BUCKET}/{OUTPUT_S3_KEY}/{TRAINING_JOB_NAME}/output/model.tar.gz", + "ModelDataUrl": f's3://{S3_BUCKET}/{TRAINING_OUTPUT_S3_KEY}/{TRAINING_JOB_NAME}/output/model.tar.gz', }, } TRANSFORM_CONFIG = { - # Transform job names can't be reused, so appending a full timestamp tp ensure it is unique. - "TransformJobName": f"test-{PROJECT_NAME}-knn-{TIMESTAMP}", + "TransformJobName": TRANSFORM_JOB_NAME, "TransformInput": { "DataSource": { "S3DataSource": { "S3DataType": "S3Prefix", - "S3Uri": f"s3://{S3_BUCKET}/{INPUT_S3_KEY}/test.csv", + "S3Uri": f's3://{S3_BUCKET}/{INPUT_DATA_S3_KEY}/test.csv', } }, "SplitType": "Line", "ContentType": "text/csv", }, - "TransformOutput": {"S3OutputPath": f"s3://{S3_BUCKET}/{OUTPUT_S3_KEY}"}, + "TransformOutput": {"S3OutputPath": f's3://{S3_BUCKET}/{PREDICTION_OUTPUT_S3_KEY}'}, "TransformResources": { "InstanceCount": 1, "InstanceType": "ml.m5.large", @@ -125,6 +139,59 @@ "ModelName": MODEL_NAME, } +TUNING_CONFIG = { + "HyperParameterTuningJobName": TUNING_JOB_NAME, + "HyperParameterTuningJobConfig": { + "Strategy": "Bayesian", + "HyperParameterTuningJobObjective": { + "MetricName": "test:accuracy", + "Type": "Maximize", + }, + "ResourceLimits": { + # You would bump these up in production as appropriate. + "MaxNumberOfTrainingJobs": 1, + "MaxParallelTrainingJobs": 1, + }, + "ParameterRanges": { + "CategoricalParameterRanges": [], + "IntegerParameterRanges": [ + # Set the min and max values of the hyperparameters you want to tune. 
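+                # For the built-in KNN algorithm, `k` is the number of neighbors consulted and
+                # `sample_size` is the number of training points sampled for the index, so
+                # sensible upper bounds depend on how much training data is available.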
+ { + "Name": "k", + "MinValue": "1", + "MaxValue": "1024", + }, + { + "Name": "sample_size", + "MinValue": "100", + "MaxValue": "2000", + }, + ], + }, + }, + "TrainingJobDefinition": { + "StaticHyperParameters": { + "predictor_type": "classifier", + "feature_dim": "4", + }, + "AlgorithmSpecification": {"TrainingImage": KNN_IMAGE_URI, "TrainingInputMode": "File"}, + "InputDataConfig": [ + { + "ChannelName": "train", + **TRAINING_DATA_SOURCE, + }, + { + "ChannelName": "test", + **TRAINING_DATA_SOURCE, + }, + ], + "OutputDataConfig": {"S3OutputPath": f's3://{S3_BUCKET}/{TRAINING_OUTPUT_S3_KEY}'}, + "ResourceConfig": RESOURCE_CONFIG, + "RoleArn": ROLE_ARN, + "StoppingCondition": {"MaxRuntimeInSeconds": 600}, + }, +} + @task def data_prep(data_url, s3_bucket, input_s3_key): @@ -197,6 +264,23 @@ def data_prep(data_url, s3_bucket, input_s3_key): ) # [END howto_operator_sagemaker_model] + # [START howto_operator_sagemaker_tuning] + tune_model = SageMakerTuningOperator( + task_id="tune_model", + config=TUNING_CONFIG, + # Waits by default, setting as False to demonstrate the Sensor below. + wait_for_completion=False, + do_xcom_push=False, + ) + # [END howto_operator_sagemaker_tuning] + + # [START howto_operator_sagemaker_tuning_sensor] + await_tune = SageMakerTuningSensor( + task_id="await_tuning", + job_name=TUNING_JOB_NAME, + ) + # [END howto_operator_sagemaker_tuning_sensor] + # [START howto_operator_sagemaker_transform] test_model = SageMakerTransformOperator( task_id='test_model', @@ -210,7 +294,7 @@ def data_prep(data_url, s3_bucket, input_s3_key): # [START howto_operator_sagemaker_transform_sensor] await_transform = SageMakerTransformSensor( task_id="await_transform", - job_name=f"test-{PROJECT_NAME}-knn-{TIMESTAMP}", + job_name=TRANSFORM_JOB_NAME, ) # [END howto_operator_sagemaker_transform_sensor] @@ -223,10 +307,12 @@ def data_prep(data_url, s3_bucket, input_s3_key): # [END howto_operator_sagemaker_delete_model] ( - data_prep(DATA_URL, S3_BUCKET, INPUT_S3_KEY) + data_prep(DATA_URL, S3_BUCKET, INPUT_DATA_S3_KEY) >> train_model >> await_training >> create_model + >> tune_model + >> await_tune >> test_model >> await_transform >> delete_model diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py b/airflow/providers/amazon/aws/operators/sagemaker.py index 3638f05a1a3d1..b7c858af63591 100644 --- a/airflow/providers/amazon/aws/operators/sagemaker.py +++ b/airflow/providers/amazon/aws/operators/sagemaker.py @@ -438,6 +438,10 @@ class SageMakerTuningOperator(SageMakerBaseOperator): you specify. It then chooses the hyperparameter values that result in a model that performs the best, as measured by an objective metric that you choose. + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:SageMakerTuningOperator` + :param config: The configuration necessary to start a tuning job (templated). For details of the configuration parameter see diff --git a/airflow/providers/amazon/aws/sensors/sagemaker.py b/airflow/providers/amazon/aws/sensors/sagemaker.py index bf833c458879b..c4e79c39be402 100644 --- a/airflow/providers/amazon/aws/sensors/sagemaker.py +++ b/airflow/providers/amazon/aws/sensors/sagemaker.py @@ -121,7 +121,7 @@ class SageMakerTransformSensor(SageMakerBaseSensor): AirflowException with the failure reason if a failed state is reached. .. 
seealso:: - For more information on how to use this operator, take a look at the guide: + For more information on how to use this sensor, take a look at the guide: :ref:`howto/sensor:SageMakerTransformSensor` :param job_name: Name of the transform job to watch. @@ -156,6 +156,10 @@ class SageMakerTuningSensor(SageMakerBaseSensor): Asks for the state of the tuning state until it reaches a terminal state. Raises an AirflowException with the failure reason if a failed state is reached. + .. seealso:: + For more information on how to use this sensor, take a look at the guide: + :ref:`howto/sensor:SageMakerTuningSensor` + :param job_name: Name of the tuning instance to watch. """ @@ -189,7 +193,7 @@ class SageMakerTrainingSensor(SageMakerBaseSensor): AirflowException with the failure reason if a failed state is reached. .. seealso:: - For more information on how to use this operator, take a look at the guide: + For more information on how to use this sensor, take a look at the guide: :ref:`howto/sensor:SageMakerTrainingSensor` :param job_name: Name of the training job to watch. diff --git a/docs/apache-airflow-providers-amazon/operators/sagemaker.rst b/docs/apache-airflow-providers-amazon/operators/sagemaker.rst index 8562d7916d734..68e737d859aa9 100644 --- a/docs/apache-airflow-providers-amazon/operators/sagemaker.rst +++ b/docs/apache-airflow-providers-amazon/operators/sagemaker.rst @@ -61,6 +61,20 @@ To create an Amazon Sagemaker model you can use :start-after: [START howto_operator_sagemaker_model] :end-before: [END howto_operator_sagemaker_model] +.. _howto/operator:SageMakerTuningOperator: + +Start a Hyperparameter Tuning Job +""""""""""""""""""""""""""""""""" + +To start a hyperparameter tuning job for an Amazon Sagemaker model you can use +:class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerTuningOperator`. + +.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_sagemaker.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_sagemaker_tuning] + :end-before: [END howto_operator_sagemaker_tuning] + .. _howto/operator:SageMakerDeleteModelOperator: Delete an Amazon SageMaker Model @@ -121,6 +135,20 @@ you can use :class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerT :start-after: [START howto_operator_sagemaker_transform_sensor] :end-before: [END howto_operator_sagemaker_transform_sensor] +.. _howto/sensor:SageMakerTuningSensor: + +Amazon SageMaker Tuning Sensor +"""""""""""""""""""""""""""""" + +To check the state of an Amazon Sagemaker hyperparameter tuning job until it reaches a terminal state +you can use :class:`~airflow.providers.amazon.aws.sensors.sagemaker.SageMakerTuningSensor`. + +.. 
exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_sagemaker.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_sagemaker_tuning_sensor] + :end-before: [END howto_operator_sagemaker_tuning_sensor] + Reference ^^^^^^^^^ From 5917292233cced61a61b32da0c609583fce67bb6 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Thu, 21 Apr 2022 14:15:27 -0700 Subject: [PATCH 3/6] Adds SageMaker Processing to the Sample DAG and Docs --- .../aws/example_dags/example_sagemaker.py | 380 ++++++++++++------ .../amazon/aws/operators/sagemaker.py | 4 + .../operators/sagemaker.rst | 15 + 3 files changed, 277 insertions(+), 122 deletions(-) diff --git a/airflow/providers/amazon/aws/example_dags/example_sagemaker.py b/airflow/providers/amazon/aws/example_dags/example_sagemaker.py index 9c28b53e100f4..73b4a85f87cdd 100644 --- a/airflow/providers/amazon/aws/example_dags/example_sagemaker.py +++ b/airflow/providers/amazon/aws/example_dags/example_sagemaker.py @@ -14,15 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - -import io +import base64 import os +import subprocess from datetime import datetime -from typing import Mapping +from tempfile import NamedTemporaryFile -import numpy as np -import pandas as pd -import requests +import boto3 from airflow import DAG from airflow.decorators import task @@ -30,6 +28,7 @@ from airflow.providers.amazon.aws.operators.sagemaker import ( SageMakerDeleteModelOperator, SageMakerModelOperator, + SageMakerProcessingOperator, SageMakerTrainingOperator, SageMakerTransformOperator, SageMakerTuningOperator, @@ -46,191 +45,317 @@ TIMESTAMP = '{{ ts_nodash }}' S3_BUCKET = os.getenv('S3_BUCKET', 'S3_bucket') -ROLE_ARN = os.getenv( - 'SAGEMAKER_ROLE_ARN', - 'arn:aws:iam::1234567890:role/service-role/AmazonSageMaker-ExecutionRole', -) - +RAW_DATA_S3_KEY = f'{PROJECT_NAME}/preprocessing/input.csv' INPUT_DATA_S3_KEY = f'{PROJECT_NAME}/processed-input-data' -TRAINING_DATA_SOURCE: Mapping[str, str] = { - "CompressionType": "None", - "ContentType": "text/csv", - "DataSource": { # type: ignore - "S3DataDistributionType": "FullyReplicated", - "S3DataType": "S3Prefix", - "S3Uri": f's3://{S3_BUCKET}/{INPUT_DATA_S3_KEY}/train.csv', - }, -} TRAINING_OUTPUT_S3_KEY = f'{PROJECT_NAME}/results' PREDICTION_OUTPUT_S3_KEY = f'{PROJECT_NAME}/transform' +PROCESSING_LOCAL_INPUT_PATH = '/opt/ml/processing/input' +PROCESSING_LOCAL_OUTPUT_PATH = '/opt/ml/processing/output' + MODEL_NAME = f'{PROJECT_NAME}-KNN-model' -# Job names can't be reused, so appending a timestamp to ensure it is unique. +# Job names can't be reused, so appending a timestamp ensures it is unique. 
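+# The training, transform, and tuning job names are shared between the operator configs below
+# and the matching sensors in the DAG, so both sides always refer to the same job.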
+PROCESSING_JOB_NAME = f'{PROJECT_NAME}-processing-{TIMESTAMP}' TRAINING_JOB_NAME = f'{PROJECT_NAME}-train-{TIMESTAMP}' TRANSFORM_JOB_NAME = f'{PROJECT_NAME}-transform-{TIMESTAMP}' TUNING_JOB_NAME = f'{PROJECT_NAME}-tune-{TIMESTAMP}' -RESOURCE_CONFIG = { - "InstanceCount": 1, - "InstanceType": "ml.m5.large", - "VolumeSizeInGB": 1, -} +ROLE_ARN = os.getenv( + 'SAGEMAKER_ROLE_ARN', + 'arn:aws:iam::1234567890:role/service-role/AmazonSageMaker-ExecutionRole', +) +ECR_REPOSITORY = os.getenv('ECR_REPOSITORY', '1234567890.dkr.ecr.us-west-2.amazonaws.com/process_data') +REGION = ECR_REPOSITORY.split('.')[3] -# A Sample dataset hosted by UC Irvine's machine learning repository -DATA_URL = 'https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data' +# For this example we are using a subset of Fischer's Iris Data Set. +# The full dataset can be found at UC Irvine's machine learning repository: +# https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data +DATASET = """ + 5.1,3.5,1.4,0.2,Iris-setosa + 4.9,3.0,1.4,0.2,Iris-setosa + 7.0,3.2,4.7,1.4,Iris-versicolor + 6.4,3.2,4.5,1.5,Iris-versicolor + 4.9,2.5,4.5,1.7,Iris-virginica + 7.3,2.9,6.3,1.8,Iris-virginica + """ +SAMPLE_SIZE = DATASET.count('\n') - 1 # The URI of an Amazon-provided docker image for handling KNN model training. This is a public ECR # repo cited in public SageMaker documentation, so the account number does not need to be redacted. # For more info see: https://docs.aws.amazon.com/sagemaker/latest/dg/ecr-us-west-2.html#knn-us-west-2.title KNN_IMAGE_URI = '174872318107.dkr.ecr.us-west-2.amazonaws.com/knn' -# Define configs for training, model creation, and batch transform jobs +TASK_TIMEOUT = {'MaxRuntimeInSeconds': 6 * 60} + +RESOURCE_CONFIG = { + 'InstanceCount': 1, + 'InstanceType': 'ml.m5.large', + 'VolumeSizeInGB': 1, +} + +TRAINING_DATA_SOURCE = { + 'CompressionType': 'None', + 'ContentType': 'text/csv', + 'DataSource': { # type: ignore + 'S3DataSource': { + 'S3DataDistributionType': 'FullyReplicated', + 'S3DataType': 'S3Prefix', + 'S3Uri': f's3://{S3_BUCKET}/{INPUT_DATA_S3_KEY}/train.csv', + } + }, +} + +# Define configs for processing, training, model creation, and batch transform jobs +SAGEMAKER_PROCESSING_JOB_CONFIG = { + 'ProcessingJobName': PROCESSING_JOB_NAME, + 'RoleArn': f'{ROLE_ARN}', + 'ProcessingInputs': [ + { + 'InputName': 'input', + 'AppManaged': False, + 'S3Input': { + 'S3Uri': f's3://{S3_BUCKET}/{RAW_DATA_S3_KEY}', + 'LocalPath': PROCESSING_LOCAL_INPUT_PATH, + 'S3DataType': 'S3Prefix', + 'S3InputMode': 'File', + 'S3DataDistributionType': 'FullyReplicated', + 'S3CompressionType': 'None', + }, + }, + ], + 'ProcessingOutputConfig': { + 'Outputs': [ + { + 'OutputName': 'output', + 'S3Output': { + 'S3Uri': f's3://{S3_BUCKET}/{INPUT_DATA_S3_KEY}', + 'LocalPath': PROCESSING_LOCAL_OUTPUT_PATH, + 'S3UploadMode': 'EndOfJob', + }, + 'AppManaged': False, + } + ] + }, + 'ProcessingResources': { + 'ClusterConfig': RESOURCE_CONFIG, + }, + 'StoppingCondition': TASK_TIMEOUT, + 'AppSpecification': { + 'ImageUri': ECR_REPOSITORY, + }, +} + TRAINING_CONFIG = { - "AlgorithmSpecification": { + 'TrainingJobName': TRAINING_JOB_NAME, + 'RoleArn': ROLE_ARN, + 'AlgorithmSpecification': { "TrainingImage": KNN_IMAGE_URI, "TrainingInputMode": "File", }, - "HyperParameters": { - "predictor_type": "classifier", - "feature_dim": "4", - "k": "3", - "sample_size": "150", + 'HyperParameters': { + 'predictor_type': 'classifier', + 'feature_dim': '4', + 'k': '3', + 'sample_size': str(SAMPLE_SIZE), }, - "InputDataConfig": 
[ + 'InputDataConfig': [ { - "ChannelName": "train", - **TRAINING_DATA_SOURCE, + 'ChannelName': 'train', + **TRAINING_DATA_SOURCE, # type: ignore [arg-type] } ], - "OutputDataConfig": {"S3OutputPath": f's3://{S3_BUCKET}/{TRAINING_OUTPUT_S3_KEY}/'}, - "ResourceConfig": RESOURCE_CONFIG, - "RoleArn": ROLE_ARN, - "StoppingCondition": {"MaxRuntimeInSeconds": 6000}, - "TrainingJobName": TRAINING_JOB_NAME, + 'OutputDataConfig': {'S3OutputPath': f's3://{S3_BUCKET}/{TRAINING_OUTPUT_S3_KEY}/'}, + 'ResourceConfig': RESOURCE_CONFIG, + 'StoppingCondition': TASK_TIMEOUT, } MODEL_CONFIG = { - "ExecutionRoleArn": ROLE_ARN, - "ModelName": MODEL_NAME, - "PrimaryContainer": { - "Mode": "SingleModel", - "Image": KNN_IMAGE_URI, - "ModelDataUrl": f's3://{S3_BUCKET}/{TRAINING_OUTPUT_S3_KEY}/{TRAINING_JOB_NAME}/output/model.tar.gz', + 'ModelName': MODEL_NAME, + 'ExecutionRoleArn': ROLE_ARN, + 'PrimaryContainer': { + 'Mode': 'SingleModel', + 'Image': KNN_IMAGE_URI, + 'ModelDataUrl': f's3://{S3_BUCKET}/{TRAINING_OUTPUT_S3_KEY}/{TRAINING_JOB_NAME}/output/model.tar.gz', }, } TRANSFORM_CONFIG = { - "TransformJobName": TRANSFORM_JOB_NAME, - "TransformInput": { - "DataSource": { - "S3DataSource": { - "S3DataType": "S3Prefix", - "S3Uri": f's3://{S3_BUCKET}/{INPUT_DATA_S3_KEY}/test.csv', + 'TransformJobName': TRANSFORM_JOB_NAME, + 'ModelName': MODEL_NAME, + 'TransformInput': { + 'DataSource': { + 'S3DataSource': { + 'S3DataType': 'S3Prefix', + 'S3Uri': f's3://{S3_BUCKET}/{INPUT_DATA_S3_KEY}/test.csv', } }, - "SplitType": "Line", - "ContentType": "text/csv", + 'SplitType': 'Line', + 'ContentType': 'text/csv', }, - "TransformOutput": {"S3OutputPath": f's3://{S3_BUCKET}/{PREDICTION_OUTPUT_S3_KEY}'}, - "TransformResources": { - "InstanceCount": 1, - "InstanceType": "ml.m5.large", + 'TransformOutput': {'S3OutputPath': f's3://{S3_BUCKET}/{PREDICTION_OUTPUT_S3_KEY}'}, + 'TransformResources': { + 'InstanceCount': 1, + 'InstanceType': 'ml.m5.large', }, - "ModelName": MODEL_NAME, } TUNING_CONFIG = { - "HyperParameterTuningJobName": TUNING_JOB_NAME, - "HyperParameterTuningJobConfig": { - "Strategy": "Bayesian", - "HyperParameterTuningJobObjective": { - "MetricName": "test:accuracy", - "Type": "Maximize", + 'HyperParameterTuningJobName': TUNING_JOB_NAME, + 'HyperParameterTuningJobConfig': { + 'Strategy': 'Bayesian', + 'HyperParameterTuningJobObjective': { + 'MetricName': 'test:accuracy', + 'Type': 'Maximize', }, - "ResourceLimits": { + 'ResourceLimits': { # You would bump these up in production as appropriate. - "MaxNumberOfTrainingJobs": 1, - "MaxParallelTrainingJobs": 1, + 'MaxNumberOfTrainingJobs': 1, + 'MaxParallelTrainingJobs': 1, }, - "ParameterRanges": { - "CategoricalParameterRanges": [], - "IntegerParameterRanges": [ + 'ParameterRanges': { + 'CategoricalParameterRanges': [], + 'IntegerParameterRanges': [ # Set the min and max values of the hyperparameters you want to tune. 
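+                # MaxValue is capped at SAMPLE_SIZE here because this demo trains on only a few
+                # rows; neither hyperparameter can usefully exceed the number of training samples.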
{ - "Name": "k", - "MinValue": "1", - "MaxValue": "1024", + 'Name': 'k', + 'MinValue': '1', + "MaxValue": str(SAMPLE_SIZE), }, { - "Name": "sample_size", - "MinValue": "100", - "MaxValue": "2000", + 'Name': 'sample_size', + 'MinValue': '1', + 'MaxValue': str(SAMPLE_SIZE), }, ], }, }, - "TrainingJobDefinition": { - "StaticHyperParameters": { - "predictor_type": "classifier", - "feature_dim": "4", + 'TrainingJobDefinition': { + 'StaticHyperParameters': { + 'predictor_type': 'classifier', + 'feature_dim': '4', }, - "AlgorithmSpecification": {"TrainingImage": KNN_IMAGE_URI, "TrainingInputMode": "File"}, - "InputDataConfig": [ + 'AlgorithmSpecification': {'TrainingImage': KNN_IMAGE_URI, 'TrainingInputMode': 'File'}, + 'InputDataConfig': [ { - "ChannelName": "train", - **TRAINING_DATA_SOURCE, + 'ChannelName': 'train', + **TRAINING_DATA_SOURCE, # type: ignore [arg-type] }, { - "ChannelName": "test", - **TRAINING_DATA_SOURCE, + 'ChannelName': 'test', + **TRAINING_DATA_SOURCE, # type: ignore [arg-type] }, ], - "OutputDataConfig": {"S3OutputPath": f's3://{S3_BUCKET}/{TRAINING_OUTPUT_S3_KEY}'}, - "ResourceConfig": RESOURCE_CONFIG, - "RoleArn": ROLE_ARN, - "StoppingCondition": {"MaxRuntimeInSeconds": 600}, + 'OutputDataConfig': {'S3OutputPath': f's3://{S3_BUCKET}/{TRAINING_OUTPUT_S3_KEY}'}, + 'ResourceConfig': RESOURCE_CONFIG, + 'StoppingCondition': TASK_TIMEOUT, + 'RoleArn': ROLE_ARN, }, } -@task -def data_prep(data_url, s3_bucket, input_s3_key): - """ - Grabs the Iris dataset from API, splits into train/test splits, and saves CSV's to S3 using S3 Hook +# This script will be the entrypoint for the docker image which will handle preprocessing the raw data +# NOTE: The following string must remain dedented as it is being written to a file. +PREPROCESS_SCRIPT = ( """ - # Get data from API - iris_response = requests.get(data_url).content +import boto3 +import numpy as np +import pandas as pd + +def main(): + # Load the Iris dataset from {input_path}/input.csv, split it into train/test + # subsets, and write them to {output_path}/ for the Processing Operator. 
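+    # These are the local paths that SageMaker Processing mounts in the container: the S3Input
+    # from the job config is downloaded to the input path before this script runs, and anything
+    # written to the output path is uploaded to S3 when the job finishes.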
+ columns = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'species'] - iris = pd.read_csv(io.StringIO(iris_response.decode('utf-8')), names=columns) + iris = pd.read_csv('{input_path}/input.csv', names=columns) # Process data - iris['species'] = iris['species'].replace({'Iris-virginica': 0, 'Iris-versicolor': 1, 'Iris-setosa': 2}) + iris['species'] = iris['species'].replace({{'Iris-virginica': 0, 'Iris-versicolor': 1, 'Iris-setosa': 2}}) iris = iris[['species', 'sepal_length', 'sepal_width', 'petal_length', 'petal_width']] # Split into test and train data iris_train, iris_test = np.split( iris.sample(frac=1, random_state=np.random.RandomState()), [int(0.7 * len(iris))] ) + + # Remove the "answers" from the test set iris_test.drop(['species'], axis=1, inplace=True) - # Save files to S3 - iris_train.to_csv('iris_train.csv', index=False, header=False) - iris_test.to_csv('iris_test.csv', index=False, header=False) - s3_hook = S3Hook(aws_conn_id='aws-sagemaker') - s3_hook.load_file( - 'iris_train.csv', - f'{input_s3_key}/train.csv', - bucket_name=s3_bucket, - replace=True, - ) - s3_hook.load_file( - 'iris_test.csv', - f'{input_s3_key}/test.csv', - bucket_name=s3_bucket, + # Write the splits to disk + iris_train.to_csv('{output_path}/train.csv', index=False, header=False) + iris_test.to_csv('{output_path}/test.csv', index=False, header=False) + + print('Preprocessing Done.') + +if __name__ == "__main__": + main() + + """ +).format(input_path=PROCESSING_LOCAL_INPUT_PATH, output_path=PROCESSING_LOCAL_OUTPUT_PATH) + + +@task +def upload_dataset_to_s3(): + """Uploads the provided dataset to a designated Amazon S3 bucket.""" + S3Hook().load_string( + string_data=DATASET, + bucket_name=S3_BUCKET, + key=RAW_DATA_S3_KEY, replace=True, ) +@task +def build_and_upload_docker_image(): + """ + We need a Docker image with the following requirements: + - Has numpy, pandas, requests, and boto3 installed + - Has our data preprocessing script mounted and set as the entry point + """ + + # Fetch and parse ECR Token to be used for the docker push + ecr_client = boto3.client('ecr', region_name=REGION) + token = ecr_client.get_authorization_token() + credentials = (base64.b64decode(token['authorizationData'][0]['authorizationToken'])).decode('utf-8') + username, password = credentials.split(':') + + with NamedTemporaryFile(mode='w+t') as preprocessing_script, NamedTemporaryFile(mode='w+t') as dockerfile: + + preprocessing_script.write(PREPROCESS_SCRIPT) + preprocessing_script.flush() + + dockerfile.write( + f""" + FROM amazonlinux + COPY {preprocessing_script.name.split('/')[2]} /preprocessing.py + ADD credentials /credentials + ENV AWS_SHARED_CREDENTIALS_FILE=/credentials + RUN yum install python3 pip -y + RUN pip3 install boto3 pandas requests + CMD [ "python3", "/preprocessing.py"] + """ + ) + dockerfile.flush() + + docker_build_and_push_commands = f""" + cp /root/.aws/credentials /tmp/credentials && + docker build -f {dockerfile.name} -t {ECR_REPOSITORY} /tmp && + rm /tmp/credentials && + aws ecr get-login-password --region {REGION} | + docker login --username {username} --password {password} {ECR_REPOSITORY} && + docker push {ECR_REPOSITORY} + """ + docker_build = subprocess.Popen( + docker_build_and_push_commands, + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + _, err = docker_build.communicate() + + if docker_build.returncode != 0: + raise RuntimeError(err) + + with DAG( dag_id='example_sagemaker', schedule_interval=None, @@ -239,6 +364,14 @@ def data_prep(data_url, 
s3_bucket, input_s3_key): catchup=False, ) as dag: + # [START howto_operator_sagemaker_processing] + preprocess_raw_data = SageMakerProcessingOperator( + task_id='preprocess_raw_data', + config=SAGEMAKER_PROCESSING_JOB_CONFIG, + do_xcom_push=False, + ) + # [END howto_operator_sagemaker_processing] + # [START howto_operator_sagemaker_training] train_model = SageMakerTrainingOperator( task_id='train_model', @@ -251,7 +384,7 @@ def data_prep(data_url, s3_bucket, input_s3_key): # [START howto_operator_sagemaker_training_sensor] await_training = SageMakerTrainingSensor( - task_id="await_training", + task_id='await_training', job_name=TRAINING_JOB_NAME, ) # [END howto_operator_sagemaker_training_sensor] @@ -266,7 +399,7 @@ def data_prep(data_url, s3_bucket, input_s3_key): # [START howto_operator_sagemaker_tuning] tune_model = SageMakerTuningOperator( - task_id="tune_model", + task_id='tune_model', config=TUNING_CONFIG, # Waits by default, setting as False to demonstrate the Sensor below. wait_for_completion=False, @@ -276,7 +409,7 @@ def data_prep(data_url, s3_bucket, input_s3_key): # [START howto_operator_sagemaker_tuning_sensor] await_tune = SageMakerTuningSensor( - task_id="await_tuning", + task_id='await_tuning', job_name=TUNING_JOB_NAME, ) # [END howto_operator_sagemaker_tuning_sensor] @@ -293,21 +426,24 @@ def data_prep(data_url, s3_bucket, input_s3_key): # [START howto_operator_sagemaker_transform_sensor] await_transform = SageMakerTransformSensor( - task_id="await_transform", + task_id='await_transform', job_name=TRANSFORM_JOB_NAME, ) # [END howto_operator_sagemaker_transform_sensor] + # Trigger rule set to "all_done" so clean up will run regardless of success on other tasks. # [START howto_operator_sagemaker_delete_model] delete_model = SageMakerDeleteModelOperator( - task_id="delete_model", + task_id='delete_model', config={'ModelName': MODEL_NAME}, trigger_rule='all_done', ) # [END howto_operator_sagemaker_delete_model] ( - data_prep(DATA_URL, S3_BUCKET, INPUT_DATA_S3_KEY) + upload_dataset_to_s3() + >> build_and_upload_docker_image() + >> preprocess_raw_data >> train_model >> await_training >> create_model diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py b/airflow/providers/amazon/aws/operators/sagemaker.py index b7c858af63591..650cffd228651 100644 --- a/airflow/providers/amazon/aws/operators/sagemaker.py +++ b/airflow/providers/amazon/aws/operators/sagemaker.py @@ -111,6 +111,10 @@ class SageMakerProcessingOperator(SageMakerBaseOperator): experience on SageMaker to run your data processing workloads, such as feature engineering, data validation, model evaluation, and model interpretation. + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:SageMakerProcessingOperator` + :param config: The configuration necessary to start a processing job (templated). For details of the configuration parameter see :py:meth:`SageMaker.Client.create_processing_job` :param wait_for_completion: If wait is set to True, the time interval, in seconds, diff --git a/docs/apache-airflow-providers-amazon/operators/sagemaker.rst b/docs/apache-airflow-providers-amazon/operators/sagemaker.rst index 68e737d859aa9..74db759450bfa 100644 --- a/docs/apache-airflow-providers-amazon/operators/sagemaker.rst +++ b/docs/apache-airflow-providers-amazon/operators/sagemaker.rst @@ -33,6 +33,21 @@ Prerequisite Tasks Manage Amazon SageMaker Jobs ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +.. 
_howto/operator:SageMakerProcessingOperator: + +Create an Amazon SageMaker Processing Job +""""""""""""""""""""""""""""""""""""""""" + +To create an Amazon Sagemaker processing job to sanitize your dataset you can use +:class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerProcessingOperator`. + +.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_sagemaker.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_sagemaker_processing] + :end-before: [END howto_operator_sagemaker_processing] + + .. _howto/operator:SageMakerTrainingOperator: Create an Amazon SageMaker Training Job From 86b4fba8f3e530bd3848eeef7f5d45f530bedf0a Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Tue, 26 Apr 2022 18:11:44 -0700 Subject: [PATCH 4/6] Adds cleanup script to the SageMaker sample DAG --- .../amazon/aws/example_dags/example_sagemaker.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/airflow/providers/amazon/aws/example_dags/example_sagemaker.py b/airflow/providers/amazon/aws/example_dags/example_sagemaker.py index 73b4a85f87cdd..df69013e1c473 100644 --- a/airflow/providers/amazon/aws/example_dags/example_sagemaker.py +++ b/airflow/providers/amazon/aws/example_dags/example_sagemaker.py @@ -356,6 +356,17 @@ def build_and_upload_docker_image(): raise RuntimeError(err) +@task(trigger_rule='all_done') +def cleanup(): + # Delete S3 Artifacts + client = boto3.client('s3') + object_keys = [ + key['Key'] for key in client.list_objects_v2(Bucket=S3_BUCKET, Prefix=PROJECT_NAME)['Contents'] + ] + for key in object_keys: + client.delete_objects(Bucket=S3_BUCKET, Delete={'Objects': [{'Key': key}]}) + + with DAG( dag_id='example_sagemaker', schedule_interval=None, @@ -451,5 +462,6 @@ def build_and_upload_docker_image(): >> await_tune >> test_model >> await_transform + >> cleanup() >> delete_model ) From e1c2581293047760070af9e7c0b7c6cbe4210838 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Mon, 25 Apr 2022 16:28:27 -0700 Subject: [PATCH 5/6] Adds a new sample DAG for SageMaker Endpoint and updates the existing sagemaker.rst --- .../example_sagemaker_endpoint.py | 230 ++++++++++++++++++ .../amazon/aws/operators/sagemaker.py | 8 + .../providers/amazon/aws/sensors/sagemaker.py | 4 + .../operators/sagemaker.rst | 42 ++++ 4 files changed, 284 insertions(+) create mode 100644 airflow/providers/amazon/aws/example_dags/example_sagemaker_endpoint.py diff --git a/airflow/providers/amazon/aws/example_dags/example_sagemaker_endpoint.py b/airflow/providers/amazon/aws/example_dags/example_sagemaker_endpoint.py new file mode 100644 index 0000000000000..b4207a9b163fc --- /dev/null +++ b/airflow/providers/amazon/aws/example_dags/example_sagemaker_endpoint.py @@ -0,0 +1,230 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
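+
+# This example DAG trains a KNN model on a small sample of the Iris dataset, registers the result
+# as a SageMaker model, deploys it behind a real-time inference endpoint, invokes the endpoint once
+# to verify it is serving predictions, and then tears the endpoint and its artifacts back down.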
+import json +import os +from datetime import datetime + +import boto3 + +from airflow import DAG +from airflow.decorators import task +from airflow.providers.amazon.aws.operators.s3 import S3CreateObjectOperator +from airflow.providers.amazon.aws.operators.sagemaker import ( + SageMakerDeleteModelOperator, + SageMakerEndpointConfigOperator, + SageMakerEndpointOperator, + SageMakerModelOperator, + SageMakerTrainingOperator, +) +from airflow.providers.amazon.aws.sensors.sagemaker import SageMakerEndpointSensor + +# Project name will be used in naming the S3 buckets and various tasks. +# The dataset used in this example is identifying varieties of the Iris flower. +PROJECT_NAME = 'iris' +TIMESTAMP = '{{ ts_nodash }}' + +S3_BUCKET = os.getenv('S3_BUCKET', 'S3_bucket') +ROLE_ARN = os.getenv( + 'SAGEMAKER_ROLE_ARN', + 'arn:aws:iam::1234567890:role/service-role/AmazonSageMaker-ExecutionRole', +) +INPUT_DATA_S3_KEY = f'{PROJECT_NAME}/processed-input-data' +TRAINING_OUTPUT_S3_KEY = f'{PROJECT_NAME}/training-results' + +MODEL_NAME = f'{PROJECT_NAME}-KNN-model' +ENDPOINT_NAME = f'{PROJECT_NAME}-endpoint' +# Job names can't be reused, so appending a timestamp ensures it is unique. +ENDPOINT_CONFIG_JOB_NAME = f'{PROJECT_NAME}-endpoint-config-{TIMESTAMP}' +TRAINING_JOB_NAME = f'{PROJECT_NAME}-train-{TIMESTAMP}' + +# For an example of how to obtain the following train and test data, please see +# https://github.com/apache/airflow/blob/main/airflow/providers/amazon/aws/example_dags/example_sagemaker.py +TRAIN_DATA = '0,4.9,2.5,4.5,1.7\n1,7.0,3.2,4.7,1.4\n0,7.3,2.9,6.3,1.8\n2,5.1,3.5,1.4,0.2\n' +SAMPLE_TEST_DATA = '6.4,3.2,4.5,1.5' + +# The URI of an Amazon-provided docker image for handling KNN model training. This is a public ECR +# repo cited in public SageMaker documentation, so the account number does not need to be redacted. 
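+# Note that the registry account is region-specific, so this URI is only valid for jobs that run
+# in us-west-2; jobs in other regions need the image URI listed for that region in the page below.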
+# For more info see: https://docs.aws.amazon.com/sagemaker/latest/dg/ecr-us-west-2.html#knn-us-west-2.title +KNN_IMAGE_URI = '174872318107.dkr.ecr.us-west-2.amazonaws.com/knn' + +# Define configs for processing, training, model creation, and batch transform jobs +TRAINING_CONFIG = { + 'TrainingJobName': TRAINING_JOB_NAME, + 'RoleArn': ROLE_ARN, + 'AlgorithmSpecification': { + "TrainingImage": KNN_IMAGE_URI, + "TrainingInputMode": "File", + }, + 'HyperParameters': { + 'predictor_type': 'classifier', + 'feature_dim': '4', + 'k': '3', + 'sample_size': '6', + }, + 'InputDataConfig': [ + { + 'ChannelName': 'train', + 'CompressionType': 'None', + 'ContentType': 'text/csv', + 'DataSource': { + 'S3DataSource': { + 'S3DataDistributionType': 'FullyReplicated', + 'S3DataType': 'S3Prefix', + 'S3Uri': f's3://{S3_BUCKET}/{INPUT_DATA_S3_KEY}/train.csv', + } + }, + } + ], + 'OutputDataConfig': {'S3OutputPath': f's3://{S3_BUCKET}/{TRAINING_OUTPUT_S3_KEY}/'}, + 'ResourceConfig': { + 'InstanceCount': 1, + 'InstanceType': 'ml.m5.large', + 'VolumeSizeInGB': 1, + }, + 'StoppingCondition': {'MaxRuntimeInSeconds': 6 * 60}, +} + +MODEL_CONFIG = { + 'ModelName': MODEL_NAME, + 'ExecutionRoleArn': ROLE_ARN, + 'PrimaryContainer': { + 'Mode': 'SingleModel', + 'Image': KNN_IMAGE_URI, + 'ModelDataUrl': f's3://{S3_BUCKET}/{TRAINING_OUTPUT_S3_KEY}/{TRAINING_JOB_NAME}/output/model.tar.gz', + }, +} + +ENDPOINT_CONFIG_CONFIG = { + 'EndpointConfigName': ENDPOINT_CONFIG_JOB_NAME, + 'ProductionVariants': [ + { + 'VariantName': f'{PROJECT_NAME}-demo', + 'ModelName': MODEL_NAME, + 'InstanceType': 'ml.t2.medium', + 'InitialInstanceCount': 1, + }, + ], +} + +DEPLOY_ENDPOINT_CONFIG = { + 'EndpointName': ENDPOINT_NAME, + 'EndpointConfigName': ENDPOINT_CONFIG_JOB_NAME, +} + + +@task +def call_endpoint(): + runtime = boto3.Session().client('sagemaker-runtime') + + response = runtime.invoke_endpoint( + EndpointName=ENDPOINT_NAME, + ContentType='text/csv', + Body=SAMPLE_TEST_DATA, + ) + + return json.loads(response["Body"].read().decode())['predictions'] + + +@task(trigger_rule='all_done') +def cleanup(): + # Delete Endpoint and Endpoint Config + client = boto3.client('sagemaker') + endpoint_config_name = client.list_endpoint_configs()['EndpointConfigs'][0]['EndpointConfigName'] + client.delete_endpoint_config(EndpointConfigName=endpoint_config_name) + client.delete_endpoint(EndpointName=ENDPOINT_NAME) + + # Delete S3 Artifacts + client = boto3.client('s3') + object_keys = [ + key['Key'] for key in client.list_objects_v2(Bucket=S3_BUCKET, Prefix=PROJECT_NAME)['Contents'] + ] + for key in object_keys: + client.delete_objects(Bucket=S3_BUCKET, Delete={'Objects': [{'Key': key}]}) + + +with DAG( + dag_id='example_sagemaker_endpoint', + schedule_interval=None, + start_date=datetime(2021, 1, 1), + tags=['example'], + catchup=False, +) as dag: + + upload_data = S3CreateObjectOperator( + task_id='upload_data', + s3_bucket=S3_BUCKET, + s3_key=f'{INPUT_DATA_S3_KEY}/train.csv', + data=TRAIN_DATA, + replace=True, + ) + + train_model = SageMakerTrainingOperator( + task_id='train_model', + config=TRAINING_CONFIG, + do_xcom_push=False, + ) + + create_model = SageMakerModelOperator( + task_id='create_model', + config=MODEL_CONFIG, + do_xcom_push=False, + ) + + # [START howto_operator_sagemaker_endpoint_config] + configure_endpoint = SageMakerEndpointConfigOperator( + task_id='configure_endpoint', + config=ENDPOINT_CONFIG_CONFIG, + do_xcom_push=False, + ) + # [END howto_operator_sagemaker_endpoint_config] + + # [START howto_operator_sagemaker_endpoint] + 
deploy_endpoint = SageMakerEndpointOperator( + task_id='deploy_endpoint', + config=DEPLOY_ENDPOINT_CONFIG, + # Waits by default, > train_model + >> create_model + >> configure_endpoint + >> deploy_endpoint + >> await_endpoint + >> call_endpoint() + >> cleanup() + >> delete_model + ) diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py b/airflow/providers/amazon/aws/operators/sagemaker.py index 650cffd228651..11be2e7a83c50 100644 --- a/airflow/providers/amazon/aws/operators/sagemaker.py +++ b/airflow/providers/amazon/aws/operators/sagemaker.py @@ -194,6 +194,10 @@ class SageMakerEndpointConfigOperator(SageMakerBaseOperator): one or more models, created using the CreateModel API, to deploy and the resources that you want Amazon SageMaker to provision. + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:SageMakerEndpointConfigOperator` + :param config: The configuration necessary to create an endpoint config. For details of the configuration parameter see :py:meth:`SageMaker.Client.create_endpoint_config` @@ -226,6 +230,10 @@ class SageMakerEndpointOperator(SageMakerBaseOperator): Requires an Endpoint Config. + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:SageMakerEndpointOperator` + :param config: The configuration necessary to create an endpoint. diff --git a/airflow/providers/amazon/aws/sensors/sagemaker.py b/airflow/providers/amazon/aws/sensors/sagemaker.py index c4e79c39be402..3cf6dceef154c 100644 --- a/airflow/providers/amazon/aws/sensors/sagemaker.py +++ b/airflow/providers/amazon/aws/sensors/sagemaker.py @@ -88,6 +88,10 @@ class SageMakerEndpointSensor(SageMakerBaseSensor): Polls the endpoint state until it reaches a terminal state. Raises an AirflowException with the failure reason if a failed state is reached. + .. seealso:: + For more information on how to use this sensor, take a look at the guide: + :ref:`howto/sensor:SageMakerEndpointSensor` + :param endpoint_name: Name of the endpoint instance to watch. """ diff --git a/docs/apache-airflow-providers-amazon/operators/sagemaker.rst b/docs/apache-airflow-providers-amazon/operators/sagemaker.rst index 74db759450bfa..a31527d4ca305 100644 --- a/docs/apache-airflow-providers-amazon/operators/sagemaker.rst +++ b/docs/apache-airflow-providers-amazon/operators/sagemaker.rst @@ -118,6 +118,34 @@ To create an Amazon Sagemaker transform job you can use :start-after: [START howto_operator_sagemaker_transform] :end-before: [END howto_operator_sagemaker_transform] +.. _howto/operator:SageMakerEndpointConfigOperator: + +Create an Amazon SageMaker Endpoint Config Job +"""""""""""""""""""""""""""""""""""""""""""""" + +To create an Amazon Sagemaker endpoint config job you can use +:class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerEndpointConfigOperator`. + +.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_sagemaker_endpoint.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_sagemaker_endpoint_config] + :end-before: [END howto_operator_sagemaker_endpoint_config] + +.. _howto/operator:SageMakerEndpointOperator: + +Create an Amazon SageMaker Endpoint Job +""""""""""""""""""""""""""""""""""""""" + +To create an Amazon Sagemaker endpoint you can use +:class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerEndpointOperator`. + +.. 
exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_sagemaker_endpoint.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_sagemaker_endpoint]
+    :end-before: [END howto_operator_sagemaker_endpoint]
+
 Amazon SageMaker Sensors
 ^^^^^^^^^^^^^^^^^^^^^^^^
 
@@ -164,6 +192,20 @@ you can use :class:`~airflow.providers.amazon.aws.sensors.sagemaker.SageMakerTun
    :start-after: [START howto_operator_sagemaker_tuning_sensor]
    :end-before: [END howto_operator_sagemaker_tuning_sensor]
 
+.. _howto/sensor:SageMakerEndpointSensor:
+
+Amazon SageMaker Endpoint Sensor
+""""""""""""""""""""""""""""""""
+
+To check the state of an Amazon Sagemaker endpoint until it reaches a terminal state
+you can use :class:`~airflow.providers.amazon.aws.sensors.sagemaker.SageMakerEndpointSensor`.
+
+.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_sagemaker_endpoint.py
+   :language: python
+   :dedent: 4
+   :start-after: [START howto_operator_sagemaker_endpoint_sensor]
+   :end-before: [END howto_operator_sagemaker_endpoint_sensor]
+
 Reference
 ^^^^^^^^^
 

From 63d202492b25b2362e52d38d18ed473e4a1d85b6 Mon Sep 17 00:00:00 2001
From: ferruzzi
Date: Tue, 26 Apr 2022 13:20:04 -0700
Subject: [PATCH 6/6] Fixes SageMaker operator return values

---
 .../aws/example_dags/example_sagemaker.py     |  5 ----
 .../example_sagemaker_endpoint.py             |  3 --
 .../amazon/aws/operators/sagemaker.py         | 29 +++++++++++++------
 3 files changed, 20 insertions(+), 17 deletions(-)

diff --git a/airflow/providers/amazon/aws/example_dags/example_sagemaker.py b/airflow/providers/amazon/aws/example_dags/example_sagemaker.py
index df69013e1c473..0ad5cf496ddc7 100644
--- a/airflow/providers/amazon/aws/example_dags/example_sagemaker.py
+++ b/airflow/providers/amazon/aws/example_dags/example_sagemaker.py
@@ -379,7 +379,6 @@ def cleanup():
     preprocess_raw_data = SageMakerProcessingOperator(
         task_id='preprocess_raw_data',
         config=SAGEMAKER_PROCESSING_JOB_CONFIG,
-        do_xcom_push=False,
     )
     # [END howto_operator_sagemaker_processing]
 
@@ -389,7 +388,6 @@ def cleanup():
         config=TRAINING_CONFIG,
         # Waits by default, setting as False to demonstrate the Sensor below.
         wait_for_completion=False,
-        do_xcom_push=False,
     )
     # [END howto_operator_sagemaker_training]
 
@@ -404,7 +402,6 @@ def cleanup():
     create_model = SageMakerModelOperator(
         task_id='create_model',
         config=MODEL_CONFIG,
-        do_xcom_push=False,
     )
     # [END howto_operator_sagemaker_model]
 
@@ -414,7 +411,6 @@ def cleanup():
         config=TUNING_CONFIG,
         # Waits by default, setting as False to demonstrate the Sensor below.
         wait_for_completion=False,
-        do_xcom_push=False,
     )
     # [END howto_operator_sagemaker_tuning]
 
@@ -431,7 +427,6 @@ def cleanup():
         config=TRANSFORM_CONFIG,
         # Waits by default, setting as False to demonstrate the Sensor below.
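         # The task exits as soon as the job is submitted; await_transform below polls the job
         # until it reaches a terminal state.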
wait_for_completion=False, - do_xcom_push=False, ) # [END howto_operator_sagemaker_transform] diff --git a/airflow/providers/amazon/aws/example_dags/example_sagemaker_endpoint.py b/airflow/providers/amazon/aws/example_dags/example_sagemaker_endpoint.py index b4207a9b163fc..c0aaa2abef41d 100644 --- a/airflow/providers/amazon/aws/example_dags/example_sagemaker_endpoint.py +++ b/airflow/providers/amazon/aws/example_dags/example_sagemaker_endpoint.py @@ -175,20 +175,17 @@ def cleanup(): train_model = SageMakerTrainingOperator( task_id='train_model', config=TRAINING_CONFIG, - do_xcom_push=False, ) create_model = SageMakerModelOperator( task_id='create_model', config=MODEL_CONFIG, - do_xcom_push=False, ) # [START howto_operator_sagemaker_endpoint_config] configure_endpoint = SageMakerEndpointConfigOperator( task_id='configure_endpoint', config=ENDPOINT_CONFIG_CONFIG, - do_xcom_push=False, ) # [END howto_operator_sagemaker_endpoint_config] diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py b/airflow/providers/amazon/aws/operators/sagemaker.py index 11be2e7a83c50..b07664b9aa85f 100644 --- a/airflow/providers/amazon/aws/operators/sagemaker.py +++ b/airflow/providers/amazon/aws/operators/sagemaker.py @@ -17,7 +17,7 @@ import json import sys -from typing import TYPE_CHECKING, Any, List, Optional, Sequence +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence from botocore.exceptions import ClientError @@ -25,6 +25,7 @@ from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook +from airflow.utils.json import AirflowJsonEncoder if sys.version_info >= (3, 8): from functools import cached_property @@ -35,6 +36,10 @@ from airflow.utils.context import Context +def serialize(result: Dict) -> str: + return json.dumps(result, cls=AirflowJsonEncoder) + + class SageMakerBaseOperator(BaseOperator): """This is the base operator for all SageMaker operators. 
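
The serialize() helper introduced above exists because raw boto3 describe_* responses include
datetime values (CreationTime, LastModifiedTime, and so on) that the standard JSON encoder rejects,
which appears to be why the sample DAGs previously had to set do_xcom_push=False. The following is
a minimal sketch of the failure mode and of an encoder-based fix; DemoEncoder is a stand-in invented
for this illustration and only approximates AirflowJsonEncoder, it is not the Airflow implementation.

import json
from datetime import datetime, timezone

# Trimmed-down stand-in for a describe_training_job() response, which mixes strings and datetimes.
describe_response = {
    'TrainingJobName': 'iris-train-20220426T000000',
    'TrainingJobStatus': 'InProgress',
    'CreationTime': datetime.now(timezone.utc),
}

try:
    json.dumps(describe_response)
except TypeError as error:
    # datetime objects are not JSON serializable, so returning the raw dict from execute()
    # and letting it be pushed to XCom would fail in the same way.
    print(f'plain json.dumps fails: {error}')


class DemoEncoder(json.JSONEncoder):
    """Illustration only: fall back to ISO-8601 strings for datetime values."""

    def default(self, o):
        if isinstance(o, datetime):
            return o.isoformat()
        return super().default(o)


print(json.dumps(describe_response, cls=DemoEncoder))

Serializing at return time keeps the operators' XCom payloads JSON-friendly regardless of which
XCom backend is configured.
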
@@ -184,7 +189,7 @@ def execute(self, context: 'Context') -> dict: ) if response['ResponseMetadata']['HTTPStatusCode'] != 200: raise AirflowException(f'Sagemaker Processing Job creation failed: {response}') - return {'Processing': self.hook.describe_processing_job(self.config['ProcessingJobName'])} + return {'Processing': serialize(self.hook.describe_processing_job(self.config['ProcessingJobName']))} class SageMakerEndpointConfigOperator(SageMakerBaseOperator): @@ -342,8 +347,10 @@ def execute(self, context: 'Context') -> dict: raise AirflowException(f'Sagemaker endpoint creation failed: {response}') else: return { - 'EndpointConfig': self.hook.describe_endpoint_config(endpoint_info['EndpointConfigName']), - 'Endpoint': self.hook.describe_endpoint(endpoint_info['EndpointName']), + 'EndpointConfig': serialize( + self.hook.describe_endpoint_config(endpoint_info['EndpointConfigName']) + ), + 'Endpoint': serialize(self.hook.describe_endpoint(endpoint_info['EndpointName'])), } @@ -437,8 +444,10 @@ def execute(self, context: 'Context') -> dict: raise AirflowException(f'Sagemaker transform Job creation failed: {response}') else: return { - 'Model': self.hook.describe_model(transform_config['ModelName']), - 'Transform': self.hook.describe_transform_job(transform_config['TransformJobName']), + 'Model': serialize(self.hook.describe_model(transform_config['ModelName'])), + 'Transform': serialize( + self.hook.describe_transform_job(transform_config['TransformJobName']) + ), } @@ -511,7 +520,9 @@ def execute(self, context: 'Context') -> dict: if response['ResponseMetadata']['HTTPStatusCode'] != 200: raise AirflowException(f'Sagemaker Tuning Job creation failed: {response}') else: - return {'Tuning': self.hook.describe_tuning_job(self.config['HyperParameterTuningJobName'])} + return { + 'Tuning': serialize(self.hook.describe_tuning_job(self.config['HyperParameterTuningJobName'])) + } class SageMakerModelOperator(SageMakerBaseOperator): @@ -547,7 +558,7 @@ def execute(self, context: 'Context') -> dict: if response['ResponseMetadata']['HTTPStatusCode'] != 200: raise AirflowException(f'Sagemaker model creation failed: {response}') else: - return {'Model': self.hook.describe_model(self.config['ModelName'])} + return {'Model': serialize(self.hook.describe_model(self.config['ModelName']))} class SageMakerTrainingOperator(SageMakerBaseOperator): @@ -630,7 +641,7 @@ def execute(self, context: 'Context') -> dict: if response['ResponseMetadata']['HTTPStatusCode'] != 200: raise AirflowException(f'Sagemaker Training Job creation failed: {response}') else: - return {'Training': self.hook.describe_training_job(self.config['TrainingJobName'])} + return {'Training': serialize(self.hook.describe_training_job(self.config['TrainingJobName']))} def _check_if_job_exists(self) -> None: training_job_name = self.config['TrainingJobName']
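
One practical consequence of these return-value changes is that the dictionaries the operators push
to XCom now carry JSON strings rather than raw boto3 responses, so a downstream task has to decode
them before use. The sketch below is a hypothetical consumer written for this note, not part of the
patch; it assumes the ModelArtifacts field that DescribeTrainingJob reports for a completed job.

import json

from airflow.decorators import task


@task
def report_model_artifact(training_result: dict):
    """Decode the serialized describe output returned by SageMakerTrainingOperator."""
    describe_response = json.loads(training_result['Training'])
    # ModelArtifacts.S3ModelArtifacts is where SageMaker reports the trained model tarball.
    artifact_uri = describe_response['ModelArtifacts']['S3ModelArtifacts']
    print(f'Trained model artifact: {artifact_uri}')


# Hypothetical wiring inside the example DAG, consuming the operator's XCom return value:
# report_model_artifact(train_model.output)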