diff --git a/airflow/providers/amazon/aws/example_dags/example_sagemaker.py b/airflow/providers/amazon/aws/example_dags/example_sagemaker.py index 6162b061db465..0ad5cf496ddc7 100644 --- a/airflow/providers/amazon/aws/example_dags/example_sagemaker.py +++ b/airflow/providers/amazon/aws/example_dags/example_sagemaker.py @@ -14,164 +14,449 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +import base64 +import os +import subprocess from datetime import datetime -from os import environ +from tempfile import NamedTemporaryFile + +import boto3 from airflow import DAG +from airflow.decorators import task +from airflow.providers.amazon.aws.hooks.s3 import S3Hook from airflow.providers.amazon.aws.operators.sagemaker import ( SageMakerDeleteModelOperator, SageMakerModelOperator, SageMakerProcessingOperator, SageMakerTrainingOperator, SageMakerTransformOperator, + SageMakerTuningOperator, +) +from airflow.providers.amazon.aws.sensors.sagemaker import ( + SageMakerTrainingSensor, + SageMakerTransformSensor, + SageMakerTuningSensor, +) + +# Project name will be used in naming the S3 buckets and various tasks. +# The dataset used in this example is identifying varieties of the Iris flower. +PROJECT_NAME = 'iris' +TIMESTAMP = '{{ ts_nodash }}' + +S3_BUCKET = os.getenv('S3_BUCKET', 'S3_bucket') +RAW_DATA_S3_KEY = f'{PROJECT_NAME}/preprocessing/input.csv' +INPUT_DATA_S3_KEY = f'{PROJECT_NAME}/processed-input-data' +TRAINING_OUTPUT_S3_KEY = f'{PROJECT_NAME}/results' +PREDICTION_OUTPUT_S3_KEY = f'{PROJECT_NAME}/transform' + +PROCESSING_LOCAL_INPUT_PATH = '/opt/ml/processing/input' +PROCESSING_LOCAL_OUTPUT_PATH = '/opt/ml/processing/output' + +MODEL_NAME = f'{PROJECT_NAME}-KNN-model' +# Job names can't be reused, so appending a timestamp ensures it is unique. +PROCESSING_JOB_NAME = f'{PROJECT_NAME}-processing-{TIMESTAMP}' +TRAINING_JOB_NAME = f'{PROJECT_NAME}-train-{TIMESTAMP}' +TRANSFORM_JOB_NAME = f'{PROJECT_NAME}-transform-{TIMESTAMP}' +TUNING_JOB_NAME = f'{PROJECT_NAME}-tune-{TIMESTAMP}' + +ROLE_ARN = os.getenv( + 'SAGEMAKER_ROLE_ARN', + 'arn:aws:iam::1234567890:role/service-role/AmazonSageMaker-ExecutionRole', ) +ECR_REPOSITORY = os.getenv('ECR_REPOSITORY', '1234567890.dkr.ecr.us-west-2.amazonaws.com/process_data') +REGION = ECR_REPOSITORY.split('.')[3] + +# For this example we are using a subset of Fischer's Iris Data Set. +# The full dataset can be found at UC Irvine's machine learning repository: +# https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data +DATASET = """ + 5.1,3.5,1.4,0.2,Iris-setosa + 4.9,3.0,1.4,0.2,Iris-setosa + 7.0,3.2,4.7,1.4,Iris-versicolor + 6.4,3.2,4.5,1.5,Iris-versicolor + 4.9,2.5,4.5,1.7,Iris-virginica + 7.3,2.9,6.3,1.8,Iris-virginica + """ +SAMPLE_SIZE = DATASET.count('\n') - 1 -MODEL_NAME = "sample_model" -TRAINING_JOB_NAME = "sample_training" -IMAGE_URI = environ.get("ECR_IMAGE_URI", "123456789012.dkr.ecr.us-east-1.amazonaws.com/repo_name") -S3_BUCKET = environ.get("BUCKET_NAME", "test-airflow-12345") -ROLE = environ.get("SAGEMAKER_ROLE_ARN", "arn:aws:iam::123456789012:role/role_name") +# The URI of an Amazon-provided docker image for handling KNN model training. This is a public ECR +# repo cited in public SageMaker documentation, so the account number does not need to be redacted. 
+# For more info see: https://docs.aws.amazon.com/sagemaker/latest/dg/ecr-us-west-2.html#knn-us-west-2.title +KNN_IMAGE_URI = '174872318107.dkr.ecr.us-west-2.amazonaws.com/knn' +TASK_TIMEOUT = {'MaxRuntimeInSeconds': 6 * 60} + +RESOURCE_CONFIG = { + 'InstanceCount': 1, + 'InstanceType': 'ml.m5.large', + 'VolumeSizeInGB': 1, +} + +TRAINING_DATA_SOURCE = { + 'CompressionType': 'None', + 'ContentType': 'text/csv', + 'DataSource': { # type: ignore + 'S3DataSource': { + 'S3DataDistributionType': 'FullyReplicated', + 'S3DataType': 'S3Prefix', + 'S3Uri': f's3://{S3_BUCKET}/{INPUT_DATA_S3_KEY}/train.csv', + } + }, +} + +# Define configs for processing, training, model creation, and batch transform jobs SAGEMAKER_PROCESSING_JOB_CONFIG = { - "ProcessingJobName": "sample_processing_job", - "ProcessingInputs": [ + 'ProcessingJobName': PROCESSING_JOB_NAME, + 'RoleArn': f'{ROLE_ARN}', + 'ProcessingInputs': [ { - "InputName": "input", - "AppManaged": False, - "S3Input": { - "S3Uri": f"s3://{S3_BUCKET}/preprocessing/input/", - "LocalPath": "/opt/ml/processing/input/", - "S3DataType": "S3Prefix", - "S3InputMode": "File", - "S3DataDistributionType": "FullyReplicated", - "S3CompressionType": "None", + 'InputName': 'input', + 'AppManaged': False, + 'S3Input': { + 'S3Uri': f's3://{S3_BUCKET}/{RAW_DATA_S3_KEY}', + 'LocalPath': PROCESSING_LOCAL_INPUT_PATH, + 'S3DataType': 'S3Prefix', + 'S3InputMode': 'File', + 'S3DataDistributionType': 'FullyReplicated', + 'S3CompressionType': 'None', }, }, ], - "ProcessingOutputConfig": { - "Outputs": [ + 'ProcessingOutputConfig': { + 'Outputs': [ { - "OutputName": "output", - "S3Output": { - "S3Uri": f"s3://{S3_BUCKET}/preprocessing/output/", - "LocalPath": "/opt/ml/processing/output/", - "S3UploadMode": "EndOfJob", + 'OutputName': 'output', + 'S3Output': { + 'S3Uri': f's3://{S3_BUCKET}/{INPUT_DATA_S3_KEY}', + 'LocalPath': PROCESSING_LOCAL_OUTPUT_PATH, + 'S3UploadMode': 'EndOfJob', }, - "AppManaged": False, + 'AppManaged': False, } ] }, - "ProcessingResources": { - "ClusterConfig": { - "InstanceCount": 1, - "InstanceType": "ml.m5.large", - "VolumeSizeInGB": 5, - } + 'ProcessingResources': { + 'ClusterConfig': RESOURCE_CONFIG, }, - "StoppingCondition": {"MaxRuntimeInSeconds": 3600}, - "AppSpecification": { - "ImageUri": f"{IMAGE_URI}", - "ContainerEntrypoint": ["python3", "./preprocessing.py"], + 'StoppingCondition': TASK_TIMEOUT, + 'AppSpecification': { + 'ImageUri': ECR_REPOSITORY, }, - "RoleArn": f"{ROLE}", } -SAGEMAKER_TRAINING_JOB_CONFIG = { - "AlgorithmSpecification": { - "TrainingImage": f"{IMAGE_URI}", +TRAINING_CONFIG = { + 'TrainingJobName': TRAINING_JOB_NAME, + 'RoleArn': ROLE_ARN, + 'AlgorithmSpecification': { + "TrainingImage": KNN_IMAGE_URI, "TrainingInputMode": "File", }, - "InputDataConfig": [ - { - "ChannelName": "config", - "DataSource": { - "S3DataSource": { - "S3DataType": "S3Prefix", - "S3Uri": f"s3://{S3_BUCKET}/config/", - "S3DataDistributionType": "FullyReplicated", - } - }, - "CompressionType": "None", - "RecordWrapperType": "None", - }, - ], - "OutputDataConfig": { - "KmsKeyId": "", - "S3OutputPath": f"s3://{S3_BUCKET}/training/", - }, - "ResourceConfig": { - "InstanceType": "ml.m5.large", - "InstanceCount": 1, - "VolumeSizeInGB": 5, + 'HyperParameters': { + 'predictor_type': 'classifier', + 'feature_dim': '4', + 'k': '3', + 'sample_size': str(SAMPLE_SIZE), }, - "StoppingCondition": {"MaxRuntimeInSeconds": 6000}, - "RoleArn": f"{ROLE}", - "EnableNetworkIsolation": False, - "EnableInterContainerTrafficEncryption": False, - "EnableManagedSpotTraining": 
False, - "TrainingJobName": TRAINING_JOB_NAME, -} - -SAGEMAKER_CREATE_MODEL_CONFIG = { - "ModelName": MODEL_NAME, - "Containers": [ + 'InputDataConfig': [ { - "Image": f"{IMAGE_URI}", - "Mode": "SingleModel", - "ModelDataUrl": f"s3://{S3_BUCKET}/training/{TRAINING_JOB_NAME}/output/model.tar.gz", + 'ChannelName': 'train', + **TRAINING_DATA_SOURCE, # type: ignore [arg-type] } ], - "ExecutionRoleArn": f"{ROLE}", - "EnableNetworkIsolation": False, + 'OutputDataConfig': {'S3OutputPath': f's3://{S3_BUCKET}/{TRAINING_OUTPUT_S3_KEY}/'}, + 'ResourceConfig': RESOURCE_CONFIG, + 'StoppingCondition': TASK_TIMEOUT, } -SAGEMAKER_INFERENCE_CONFIG = { - "TransformJobName": "sample_transform_job", - "ModelName": MODEL_NAME, - "TransformInput": { - "DataSource": { - "S3DataSource": { - "S3DataType": "S3Prefix", - "S3Uri": f"s3://{S3_BUCKET}/config/config_date.yml", +MODEL_CONFIG = { + 'ModelName': MODEL_NAME, + 'ExecutionRoleArn': ROLE_ARN, + 'PrimaryContainer': { + 'Mode': 'SingleModel', + 'Image': KNN_IMAGE_URI, + 'ModelDataUrl': f's3://{S3_BUCKET}/{TRAINING_OUTPUT_S3_KEY}/{TRAINING_JOB_NAME}/output/model.tar.gz', + }, +} + +TRANSFORM_CONFIG = { + 'TransformJobName': TRANSFORM_JOB_NAME, + 'ModelName': MODEL_NAME, + 'TransformInput': { + 'DataSource': { + 'S3DataSource': { + 'S3DataType': 'S3Prefix', + 'S3Uri': f's3://{S3_BUCKET}/{INPUT_DATA_S3_KEY}/test.csv', } }, - "ContentType": "application/x-yaml", - "CompressionType": "None", - "SplitType": "None", + 'SplitType': 'Line', + 'ContentType': 'text/csv', + }, + 'TransformOutput': {'S3OutputPath': f's3://{S3_BUCKET}/{PREDICTION_OUTPUT_S3_KEY}'}, + 'TransformResources': { + 'InstanceCount': 1, + 'InstanceType': 'ml.m5.large', }, - "TransformOutput": {"S3OutputPath": f"s3://{S3_BUCKET}/inferencing/output/"}, - "TransformResources": {"InstanceType": "ml.m5.large", "InstanceCount": 1}, } -# [START howto_operator_sagemaker] +TUNING_CONFIG = { + 'HyperParameterTuningJobName': TUNING_JOB_NAME, + 'HyperParameterTuningJobConfig': { + 'Strategy': 'Bayesian', + 'HyperParameterTuningJobObjective': { + 'MetricName': 'test:accuracy', + 'Type': 'Maximize', + }, + 'ResourceLimits': { + # You would bump these up in production as appropriate. + 'MaxNumberOfTrainingJobs': 1, + 'MaxParallelTrainingJobs': 1, + }, + 'ParameterRanges': { + 'CategoricalParameterRanges': [], + 'IntegerParameterRanges': [ + # Set the min and max values of the hyperparameters you want to tune. + { + 'Name': 'k', + 'MinValue': '1', + "MaxValue": str(SAMPLE_SIZE), + }, + { + 'Name': 'sample_size', + 'MinValue': '1', + 'MaxValue': str(SAMPLE_SIZE), + }, + ], + }, + }, + 'TrainingJobDefinition': { + 'StaticHyperParameters': { + 'predictor_type': 'classifier', + 'feature_dim': '4', + }, + 'AlgorithmSpecification': {'TrainingImage': KNN_IMAGE_URI, 'TrainingInputMode': 'File'}, + 'InputDataConfig': [ + { + 'ChannelName': 'train', + **TRAINING_DATA_SOURCE, # type: ignore [arg-type] + }, + { + 'ChannelName': 'test', + **TRAINING_DATA_SOURCE, # type: ignore [arg-type] + }, + ], + 'OutputDataConfig': {'S3OutputPath': f's3://{S3_BUCKET}/{TRAINING_OUTPUT_S3_KEY}'}, + 'ResourceConfig': RESOURCE_CONFIG, + 'StoppingCondition': TASK_TIMEOUT, + 'RoleArn': ROLE_ARN, + }, +} + + +# This script will be the entrypoint for the docker image which will handle preprocessing the raw data +# NOTE: The following string must remain dedented as it is being written to a file. 
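+# The {input_path} and {output_path} placeholders below are filled in by the .format() call at the end of this string; literal braces inside the script (such as the dict passed to replace()) are escaped as {{ }}.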
+PREPROCESS_SCRIPT = ( + """ +import boto3 +import numpy as np +import pandas as pd + +def main(): + # Load the Iris dataset from {input_path}/input.csv, split it into train/test + # subsets, and write them to {output_path}/ for the Processing Operator. + + columns = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'species'] + iris = pd.read_csv('{input_path}/input.csv', names=columns) + + # Process data + iris['species'] = iris['species'].replace({{'Iris-virginica': 0, 'Iris-versicolor': 1, 'Iris-setosa': 2}}) + iris = iris[['species', 'sepal_length', 'sepal_width', 'petal_length', 'petal_width']] + + # Split into test and train data + iris_train, iris_test = np.split( + iris.sample(frac=1, random_state=np.random.RandomState()), [int(0.7 * len(iris))] + ) + + # Remove the "answers" from the test set + iris_test.drop(['species'], axis=1, inplace=True) + + # Write the splits to disk + iris_train.to_csv('{output_path}/train.csv', index=False, header=False) + iris_test.to_csv('{output_path}/test.csv', index=False, header=False) + + print('Preprocessing Done.') + +if __name__ == "__main__": + main() + + """ +).format(input_path=PROCESSING_LOCAL_INPUT_PATH, output_path=PROCESSING_LOCAL_OUTPUT_PATH) + + +@task +def upload_dataset_to_s3(): + """Uploads the provided dataset to a designated Amazon S3 bucket.""" + S3Hook().load_string( + string_data=DATASET, + bucket_name=S3_BUCKET, + key=RAW_DATA_S3_KEY, + replace=True, + ) + + +@task +def build_and_upload_docker_image(): + """ + We need a Docker image with the following requirements: + - Has numpy, pandas, requests, and boto3 installed + - Has our data preprocessing script mounted and set as the entry point + """ + + # Fetch and parse ECR Token to be used for the docker push + ecr_client = boto3.client('ecr', region_name=REGION) + token = ecr_client.get_authorization_token() + credentials = (base64.b64decode(token['authorizationData'][0]['authorizationToken'])).decode('utf-8') + username, password = credentials.split(':') + + with NamedTemporaryFile(mode='w+t') as preprocessing_script, NamedTemporaryFile(mode='w+t') as dockerfile: + + preprocessing_script.write(PREPROCESS_SCRIPT) + preprocessing_script.flush() + + dockerfile.write( + f""" + FROM amazonlinux + COPY {preprocessing_script.name.split('/')[2]} /preprocessing.py + ADD credentials /credentials + ENV AWS_SHARED_CREDENTIALS_FILE=/credentials + RUN yum install python3 pip -y + RUN pip3 install boto3 pandas requests + CMD [ "python3", "/preprocessing.py"] + """ + ) + dockerfile.flush() + + docker_build_and_push_commands = f""" + cp /root/.aws/credentials /tmp/credentials && + docker build -f {dockerfile.name} -t {ECR_REPOSITORY} /tmp && + rm /tmp/credentials && + aws ecr get-login-password --region {REGION} | + docker login --username {username} --password {password} {ECR_REPOSITORY} && + docker push {ECR_REPOSITORY} + """ + docker_build = subprocess.Popen( + docker_build_and_push_commands, + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + _, err = docker_build.communicate() + + if docker_build.returncode != 0: + raise RuntimeError(err) + + +@task(trigger_rule='all_done') +def cleanup(): + # Delete S3 Artifacts + client = boto3.client('s3') + object_keys = [ + key['Key'] for key in client.list_objects_v2(Bucket=S3_BUCKET, Prefix=PROJECT_NAME)['Contents'] + ] + for key in object_keys: + client.delete_objects(Bucket=S3_BUCKET, Delete={'Objects': [{'Key': key}]}) + + with DAG( - "sample_sagemaker_dag", + dag_id='example_sagemaker', schedule_interval=None, - 
start_date=datetime(2022, 2, 21), + start_date=datetime(2021, 1, 1), + tags=['example'], catchup=False, ) as dag: - sagemaker_processing_task = SageMakerProcessingOperator( + + # [START howto_operator_sagemaker_processing] + preprocess_raw_data = SageMakerProcessingOperator( + task_id='preprocess_raw_data', config=SAGEMAKER_PROCESSING_JOB_CONFIG, - aws_conn_id="aws_default", - task_id="sagemaker_preprocessing_task", ) + # [END howto_operator_sagemaker_processing] - training_task = SageMakerTrainingOperator( - config=SAGEMAKER_TRAINING_JOB_CONFIG, aws_conn_id="aws_default", task_id="sagemaker_training_task" + # [START howto_operator_sagemaker_training] + train_model = SageMakerTrainingOperator( + task_id='train_model', + config=TRAINING_CONFIG, + # Waits by default, setting as False to demonstrate the Sensor below. + wait_for_completion=False, ) + # [END howto_operator_sagemaker_training] - model_create_task = SageMakerModelOperator( - config=SAGEMAKER_CREATE_MODEL_CONFIG, aws_conn_id="aws_default", task_id="sagemaker_create_model_task" + # [START howto_operator_sagemaker_training_sensor] + await_training = SageMakerTrainingSensor( + task_id='await_training', + job_name=TRAINING_JOB_NAME, ) + # [END howto_operator_sagemaker_training_sensor] - inference_task = SageMakerTransformOperator( - config=SAGEMAKER_INFERENCE_CONFIG, aws_conn_id="aws_default", task_id="sagemaker_inference_task" + # [START howto_operator_sagemaker_model] + create_model = SageMakerModelOperator( + task_id='create_model', + config=MODEL_CONFIG, ) + # [END howto_operator_sagemaker_model] - model_delete_task = SageMakerDeleteModelOperator( - task_id="sagemaker_delete_model_task", config={'ModelName': MODEL_NAME}, aws_conn_id="aws_default" + # [START howto_operator_sagemaker_tuning] + tune_model = SageMakerTuningOperator( + task_id='tune_model', + config=TUNING_CONFIG, + # Waits by default, setting as False to demonstrate the Sensor below. + wait_for_completion=False, ) + # [END howto_operator_sagemaker_tuning] - sagemaker_processing_task >> training_task >> model_create_task >> inference_task >> model_delete_task - # [END howto_operator_sagemaker] + # [START howto_operator_sagemaker_tuning_sensor] + await_tune = SageMakerTuningSensor( + task_id='await_tuning', + job_name=TUNING_JOB_NAME, + ) + # [END howto_operator_sagemaker_tuning_sensor] + + # [START howto_operator_sagemaker_transform] + test_model = SageMakerTransformOperator( + task_id='test_model', + config=TRANSFORM_CONFIG, + # Waits by default, setting as False to demonstrate the Sensor below. + wait_for_completion=False, + ) + # [END howto_operator_sagemaker_transform] + + # [START howto_operator_sagemaker_transform_sensor] + await_transform = SageMakerTransformSensor( + task_id='await_transform', + job_name=TRANSFORM_JOB_NAME, + ) + # [END howto_operator_sagemaker_transform_sensor] + + # Trigger rule set to "all_done" so clean up will run regardless of success on other tasks. 
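+ # delete_model is kept last in the task chain below so the model still exists when the transform job runs.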
+ # [START howto_operator_sagemaker_delete_model] + delete_model = SageMakerDeleteModelOperator( + task_id='delete_model', + config={'ModelName': MODEL_NAME}, + trigger_rule='all_done', + ) + # [END howto_operator_sagemaker_delete_model] + + ( + upload_dataset_to_s3() + >> build_and_upload_docker_image() + >> preprocess_raw_data + >> train_model + >> await_training + >> create_model + >> tune_model + >> await_tune + >> test_model + >> await_transform + >> cleanup() + >> delete_model + ) diff --git a/airflow/providers/amazon/aws/example_dags/example_sagemaker_endpoint.py b/airflow/providers/amazon/aws/example_dags/example_sagemaker_endpoint.py new file mode 100644 index 0000000000000..c0aaa2abef41d --- /dev/null +++ b/airflow/providers/amazon/aws/example_dags/example_sagemaker_endpoint.py @@ -0,0 +1,227 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import json +import os +from datetime import datetime + +import boto3 + +from airflow import DAG +from airflow.decorators import task +from airflow.providers.amazon.aws.operators.s3 import S3CreateObjectOperator +from airflow.providers.amazon.aws.operators.sagemaker import ( + SageMakerDeleteModelOperator, + SageMakerEndpointConfigOperator, + SageMakerEndpointOperator, + SageMakerModelOperator, + SageMakerTrainingOperator, +) +from airflow.providers.amazon.aws.sensors.sagemaker import SageMakerEndpointSensor + +# Project name will be used in naming the S3 buckets and various tasks. +# The dataset used in this example is identifying varieties of the Iris flower. +PROJECT_NAME = 'iris' +TIMESTAMP = '{{ ts_nodash }}' + +S3_BUCKET = os.getenv('S3_BUCKET', 'S3_bucket') +ROLE_ARN = os.getenv( + 'SAGEMAKER_ROLE_ARN', + 'arn:aws:iam::1234567890:role/service-role/AmazonSageMaker-ExecutionRole', +) +INPUT_DATA_S3_KEY = f'{PROJECT_NAME}/processed-input-data' +TRAINING_OUTPUT_S3_KEY = f'{PROJECT_NAME}/training-results' + +MODEL_NAME = f'{PROJECT_NAME}-KNN-model' +ENDPOINT_NAME = f'{PROJECT_NAME}-endpoint' +# Job names can't be reused, so appending a timestamp ensures it is unique. +ENDPOINT_CONFIG_JOB_NAME = f'{PROJECT_NAME}-endpoint-config-{TIMESTAMP}' +TRAINING_JOB_NAME = f'{PROJECT_NAME}-train-{TIMESTAMP}' + +# For an example of how to obtain the following train and test data, please see +# https://github.com/apache/airflow/blob/main/airflow/providers/amazon/aws/example_dags/example_sagemaker.py +TRAIN_DATA = '0,4.9,2.5,4.5,1.7\n1,7.0,3.2,4.7,1.4\n0,7.3,2.9,6.3,1.8\n2,5.1,3.5,1.4,0.2\n' +SAMPLE_TEST_DATA = '6.4,3.2,4.5,1.5' + +# The URI of an Amazon-provided docker image for handling KNN model training. This is a public ECR +# repo cited in public SageMaker documentation, so the account number does not need to be redacted. 
+# For more info see: https://docs.aws.amazon.com/sagemaker/latest/dg/ecr-us-west-2.html#knn-us-west-2.title +KNN_IMAGE_URI = '174872318107.dkr.ecr.us-west-2.amazonaws.com/knn' + +# Define configs for processing, training, model creation, and batch transform jobs +TRAINING_CONFIG = { + 'TrainingJobName': TRAINING_JOB_NAME, + 'RoleArn': ROLE_ARN, + 'AlgorithmSpecification': { + "TrainingImage": KNN_IMAGE_URI, + "TrainingInputMode": "File", + }, + 'HyperParameters': { + 'predictor_type': 'classifier', + 'feature_dim': '4', + 'k': '3', + 'sample_size': '6', + }, + 'InputDataConfig': [ + { + 'ChannelName': 'train', + 'CompressionType': 'None', + 'ContentType': 'text/csv', + 'DataSource': { + 'S3DataSource': { + 'S3DataDistributionType': 'FullyReplicated', + 'S3DataType': 'S3Prefix', + 'S3Uri': f's3://{S3_BUCKET}/{INPUT_DATA_S3_KEY}/train.csv', + } + }, + } + ], + 'OutputDataConfig': {'S3OutputPath': f's3://{S3_BUCKET}/{TRAINING_OUTPUT_S3_KEY}/'}, + 'ResourceConfig': { + 'InstanceCount': 1, + 'InstanceType': 'ml.m5.large', + 'VolumeSizeInGB': 1, + }, + 'StoppingCondition': {'MaxRuntimeInSeconds': 6 * 60}, +} + +MODEL_CONFIG = { + 'ModelName': MODEL_NAME, + 'ExecutionRoleArn': ROLE_ARN, + 'PrimaryContainer': { + 'Mode': 'SingleModel', + 'Image': KNN_IMAGE_URI, + 'ModelDataUrl': f's3://{S3_BUCKET}/{TRAINING_OUTPUT_S3_KEY}/{TRAINING_JOB_NAME}/output/model.tar.gz', + }, +} + +ENDPOINT_CONFIG_CONFIG = { + 'EndpointConfigName': ENDPOINT_CONFIG_JOB_NAME, + 'ProductionVariants': [ + { + 'VariantName': f'{PROJECT_NAME}-demo', + 'ModelName': MODEL_NAME, + 'InstanceType': 'ml.t2.medium', + 'InitialInstanceCount': 1, + }, + ], +} + +DEPLOY_ENDPOINT_CONFIG = { + 'EndpointName': ENDPOINT_NAME, + 'EndpointConfigName': ENDPOINT_CONFIG_JOB_NAME, +} + + +@task +def call_endpoint(): + runtime = boto3.Session().client('sagemaker-runtime') + + response = runtime.invoke_endpoint( + EndpointName=ENDPOINT_NAME, + ContentType='text/csv', + Body=SAMPLE_TEST_DATA, + ) + + return json.loads(response["Body"].read().decode())['predictions'] + + +@task(trigger_rule='all_done') +def cleanup(): + # Delete Endpoint and Endpoint Config + client = boto3.client('sagemaker') + endpoint_config_name = client.list_endpoint_configs()['EndpointConfigs'][0]['EndpointConfigName'] + client.delete_endpoint_config(EndpointConfigName=endpoint_config_name) + client.delete_endpoint(EndpointName=ENDPOINT_NAME) + + # Delete S3 Artifacts + client = boto3.client('s3') + object_keys = [ + key['Key'] for key in client.list_objects_v2(Bucket=S3_BUCKET, Prefix=PROJECT_NAME)['Contents'] + ] + for key in object_keys: + client.delete_objects(Bucket=S3_BUCKET, Delete={'Objects': [{'Key': key}]}) + + +with DAG( + dag_id='example_sagemaker_endpoint', + schedule_interval=None, + start_date=datetime(2021, 1, 1), + tags=['example'], + catchup=False, +) as dag: + + upload_data = S3CreateObjectOperator( + task_id='upload_data', + s3_bucket=S3_BUCKET, + s3_key=f'{INPUT_DATA_S3_KEY}/train.csv', + data=TRAIN_DATA, + replace=True, + ) + + train_model = SageMakerTrainingOperator( + task_id='train_model', + config=TRAINING_CONFIG, + ) + + create_model = SageMakerModelOperator( + task_id='create_model', + config=MODEL_CONFIG, + ) + + # [START howto_operator_sagemaker_endpoint_config] + configure_endpoint = SageMakerEndpointConfigOperator( + task_id='configure_endpoint', + config=ENDPOINT_CONFIG_CONFIG, + ) + # [END howto_operator_sagemaker_endpoint_config] + + # [START howto_operator_sagemaker_endpoint] + deploy_endpoint = SageMakerEndpointOperator( + 
task_id='deploy_endpoint', + config=DEPLOY_ENDPOINT_CONFIG, + # Waits by default, setting as False to demonstrate the Sensor below. + wait_for_completion=False, + ) + # [END howto_operator_sagemaker_endpoint] + + # [START howto_operator_sagemaker_endpoint_sensor] + await_endpoint = SageMakerEndpointSensor( + task_id='await_endpoint', + endpoint_name=ENDPOINT_NAME, + ) + # [END howto_operator_sagemaker_endpoint_sensor] + + delete_model = SageMakerDeleteModelOperator( + task_id='delete_model', + config={'ModelName': MODEL_NAME}, + trigger_rule='all_done', + ) + + ( + upload_data + >> train_model + >> create_model + >> configure_endpoint + >> deploy_endpoint + >> await_endpoint + >> call_endpoint() + >> cleanup() + >> delete_model + ) diff --git a/airflow/providers/amazon/aws/hooks/sagemaker.py b/airflow/providers/amazon/aws/hooks/sagemaker.py index 73348c8555aa6..2c8c28a738ec3 100644 --- a/airflow/providers/amazon/aws/hooks/sagemaker.py +++ b/airflow/providers/amazon/aws/hooks/sagemaker.py @@ -310,7 +310,8 @@ def create_training_job( max_ingestion_time: Optional[int] = None, ): """ - Create a training job + Starts a model training job. After training completes, Amazon SageMaker saves + the resulting model artifacts to an Amazon S3 location that you specify. :param config: the config for training :param wait_for_completion: if the program should keep running until job finishes @@ -357,7 +358,11 @@ def create_tuning_job( max_ingestion_time: Optional[int] = None, ): """ - Create a tuning job + Starts a hyperparameter tuning job. A hyperparameter tuning job finds the + best version of a model by running many training jobs on your dataset using + the algorithm you choose and values for hyperparameters within ranges that + you specify. It then chooses the hyperparameter values that result in a model + that performs the best, as measured by an objective metric that you choose. :param config: the config for tuning :param wait_for_completion: if the program should keep running until job finishes @@ -389,7 +394,8 @@ def create_transform_job( max_ingestion_time: Optional[int] = None, ): """ - Create a transform job + Starts a transform job. A transform job uses a trained model to get inferences + on a dataset and saves these results to an Amazon S3 location that you specify. :param config: the config for transform job :param wait_for_completion: if the program should keep running until job finishes @@ -422,7 +428,10 @@ def create_processing_job( max_ingestion_time: Optional[int] = None, ): """ - Create a processing job + Use Amazon SageMaker Processing to analyze data and evaluate machine learning + models on Amazon SageMaker. With Processing, you can use a simplified, managed + experience on SageMaker to run your data processing workloads, such as feature + engineering, data validation, model evaluation, and model interpretation. :param config: the config for processing job :param wait_for_completion: if the program should keep running until job finishes @@ -446,7 +455,10 @@ def create_processing_job( def create_model(self, config: dict): """ - Create a model job + Creates a model in Amazon SageMaker. In the request, you name the model and + describe a primary container. For the primary container, you specify the Docker + image that contains inference code, artifacts (from prior training), and a custom + environment map that the inference code uses when you deploy the model for predictions. :param config: the config for model :return: A response to model creation @@ -455,7 +467,14 @@ def create_model(self, config: dict): def create_endpoint_config(self, config: dict): """ - Create an endpoint config + Creates an endpoint configuration that Amazon SageMaker hosting + services uses to deploy models. In the configuration, you identify + one or more models, created using the CreateModel API, to deploy and + the resources that you want Amazon SageMaker to provision. + + ..
seealso:: + :class:`~airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook.create_model` + :class:`~airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook.create_endpoint` :param config: the config for endpoint-config :return: A response to endpoint config creation @@ -470,7 +489,15 @@ def create_endpoint( max_ingestion_time: Optional[int] = None, ): """ - Create an endpoint + When you create a serverless endpoint, SageMaker provisions and manages + the compute resources for you. Then, you can make inference requests to + the endpoint and receive model predictions in response. SageMaker scales + the compute resources up and down as needed to handle your request traffic. + + Requires an Endpoint Config. + .. seealso:: + :class:`~airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook.create_endpoint_config` + :param config: the config for endpoint :param wait_for_completion: if the program should keep running until job finishes @@ -501,7 +528,9 @@ def update_endpoint( max_ingestion_time: Optional[int] = None, ): """ - Update an endpoint + Deploys the new EndpointConfig specified in the request, switches to using + newly created endpoint, and then deletes resources provisioned for the + endpoint using the previous EndpointConfig (there is no availability loss). :param config: the config for endpoint :param wait_for_completion: if the program should keep running until job finishes diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py b/airflow/providers/amazon/aws/operators/sagemaker.py index 071c167e9a0d1..b07664b9aa85f 100644 --- a/airflow/providers/amazon/aws/operators/sagemaker.py +++ b/airflow/providers/amazon/aws/operators/sagemaker.py @@ -17,7 +17,7 @@ import json import sys -from typing import TYPE_CHECKING, Any, List, Optional, Sequence +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence from botocore.exceptions import ClientError @@ -25,6 +25,7 @@ from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook +from airflow.utils.json import AirflowJsonEncoder if sys.version_info >= (3, 8): from functools import cached_property @@ -35,6 +36,10 @@ from airflow.utils.context import Context +def serialize(result: Dict) -> str: + return json.dumps(result, cls=AirflowJsonEncoder) + + class SageMakerBaseOperator(BaseOperator): """This is the base operator for all SageMaker operators. @@ -105,14 +110,18 @@ def hook(self): class SageMakerProcessingOperator(SageMakerBaseOperator): - """Initiate a SageMaker processing job. + """ + Use Amazon SageMaker Processing to analyze data and evaluate machine learning + models on Amazon SageMake. With Processing, you can use a simplified, managed + experience on SageMaker to run your data processing workloads, such as feature + engineering, data validation, model evaluation, and model interpretation. - This operator returns The ARN of the processing job created in Amazon SageMaker. + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:SageMakerProcessingOperator` :param config: The configuration necessary to start a processing job (templated). - For details of the configuration parameter see :py:meth:`SageMaker.Client.create_processing_job` - :param aws_conn_id: The AWS connection ID to use. 
:param wait_for_completion: If wait is set to True, the time interval, in seconds, that the operation waits to check the status of the processing job. :param print_log: if the operator should print the cloudwatch log during processing @@ -123,13 +132,13 @@ class SageMakerProcessingOperator(SageMakerBaseOperator): the operation does not timeout. :param action_if_job_exists: Behaviour if the job name already exists. Possible options are "increment" (default) and "fail". + :return Dict: Returns The ARN of the processing job created in Amazon SageMaker. """ def __init__( self, *, config: dict, - aws_conn_id: str, wait_for_completion: bool = True, print_log: bool = True, check_interval: int = 30, @@ -137,7 +146,7 @@ def __init__( action_if_job_exists: str = 'increment', **kwargs, ): - super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs) + super().__init__(config=config, **kwargs) if action_if_job_exists not in ('increment', 'fail'): raise AirflowException( f"Argument action_if_job_exists accepts only 'increment' and 'fail'. \ @@ -180,19 +189,25 @@ def execute(self, context: 'Context') -> dict: ) if response['ResponseMetadata']['HTTPStatusCode'] != 200: raise AirflowException(f'Sagemaker Processing Job creation failed: {response}') - return {'Processing': self.hook.describe_processing_job(self.config['ProcessingJobName'])} + return {'Processing': serialize(self.hook.describe_processing_job(self.config['ProcessingJobName']))} class SageMakerEndpointConfigOperator(SageMakerBaseOperator): """ - Create a SageMaker endpoint config. + Creates an endpoint configuration that Amazon SageMaker hosting + services uses to deploy models. In the configuration, you identify + one or more models, created using the CreateModel API, to deploy and + the resources that you want Amazon SageMaker to provision. - This operator returns The ARN of the endpoint config created in Amazon SageMaker + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:SageMakerEndpointConfigOperator` :param config: The configuration necessary to create an endpoint config. For details of the configuration parameter see :py:meth:`SageMaker.Client.create_endpoint_config` :param aws_conn_id: The AWS connection ID to use. + :return Dict: Returns The ARN of the endpoint config created in Amazon SageMaker. """ integer_fields = [['ProductionVariants', 'InitialInstanceCount']] @@ -213,9 +228,16 @@ def execute(self, context: 'Context') -> dict: class SageMakerEndpointOperator(SageMakerBaseOperator): """ - Create a SageMaker endpoint. + When you create a serverless endpoint, SageMaker provisions and manages + the compute resources for you. Then, you can make inference requests to + the endpoint and receive model predictions in response. SageMaker scales + the compute resources up and down as needed to handle your request traffic. - This operator returns The ARN of the endpoint created in Amazon SageMaker + Requires an Endpoint Config. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:SageMakerEndpointOperator` :param config: The configuration necessary to create an endpoint. @@ -242,13 +264,13 @@ class SageMakerEndpointOperator(SageMakerBaseOperator): For details of the configuration parameter of endpoint_configuration see :py:meth:`SageMaker.Client.create_endpoint` - :param aws_conn_id: The AWS connection ID to use. 
:param wait_for_completion: Whether the operator should wait until the endpoint creation finishes. :param check_interval: If wait is set to True, this is the time interval, in seconds, that this operation waits before polling the status of the endpoint creation. :param max_ingestion_time: If wait is set to True, this operation fails if the endpoint creation doesn't finish within max_ingestion_time seconds. If you set this parameter to None it never times out. :param operation: Whether to create an endpoint or update an endpoint. Must be either 'create or 'update'. + :return Dict: Returns The ARN of the endpoint created in Amazon SageMaker. """ def __init__( @@ -325,15 +347,21 @@ def execute(self, context: 'Context') -> dict: raise AirflowException(f'Sagemaker endpoint creation failed: {response}') else: return { - 'EndpointConfig': self.hook.describe_endpoint_config(endpoint_info['EndpointConfigName']), - 'Endpoint': self.hook.describe_endpoint(endpoint_info['EndpointName']), + 'EndpointConfig': serialize( + self.hook.describe_endpoint_config(endpoint_info['EndpointConfigName']) + ), + 'Endpoint': serialize(self.hook.describe_endpoint(endpoint_info['EndpointName'])), } class SageMakerTransformOperator(SageMakerBaseOperator): - """Initiate a SageMaker transform job. + """ + Starts a transform job. A transform job uses a trained model to get inferences + on a dataset and saves these results to an Amazon S3 location that you specify. - This operator returns The ARN of the model created in Amazon SageMaker. + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:SageMakerTransformOperator` :param config: The configuration necessary to start a transform job (templated). @@ -354,13 +382,13 @@ class SageMakerTransformOperator(SageMakerBaseOperator): For details of the configuration parameter of model_config, See: :py:meth:`SageMaker.Client.create_model` - :param aws_conn_id: The AWS connection ID to use. :param wait_for_completion: Set to True to wait until the transform job finishes. :param check_interval: If wait is set to True, the time interval, in seconds, that this operation waits to check the status of the transform job. :param max_ingestion_time: If wait is set to True, the operation fails if the transform job doesn't finish within max_ingestion_time seconds. If you set this parameter to None, the operation does not timeout. + :return Dict: Returns The ARN of the model created in Amazon SageMaker. """ def __init__( @@ -416,27 +444,36 @@ def execute(self, context: 'Context') -> dict: raise AirflowException(f'Sagemaker transform Job creation failed: {response}') else: return { - 'Model': self.hook.describe_model(transform_config['ModelName']), - 'Transform': self.hook.describe_transform_job(transform_config['TransformJobName']), + 'Model': serialize(self.hook.describe_model(transform_config['ModelName'])), + 'Transform': serialize( + self.hook.describe_transform_job(transform_config['TransformJobName']) + ), } class SageMakerTuningOperator(SageMakerBaseOperator): - """Initiate a SageMaker hyperparameter tuning job. + """ + Starts a hyperparameter tuning job. A hyperparameter tuning job finds the + best version of a model by running many training jobs on your dataset using + the algorithm you choose and values for hyperparameters within ranges that + you specify. It then chooses the hyperparameter values that result in a model + that performs the best, as measured by an objective metric that you choose. 
- This operator returns The ARN of the tuning job created in Amazon SageMaker. + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:SageMakerTuningOperator` :param config: The configuration necessary to start a tuning job (templated). For details of the configuration parameter see :py:meth:`SageMaker.Client.create_hyper_parameter_tuning_job` - :param aws_conn_id: The AWS connection ID to use. :param wait_for_completion: Set to True to wait until the tuning job finishes. :param check_interval: If wait is set to True, the time interval, in seconds, that this operation waits to check the status of the tuning job. :param max_ingestion_time: If wait is set to True, the operation fails if the tuning job doesn't finish within max_ingestion_time seconds. If you set this parameter to None, the operation does not timeout. + :return Dict: Returns The ARN of the tuning job created in Amazon SageMaker. """ integer_fields = [ @@ -483,18 +520,26 @@ def execute(self, context: 'Context') -> dict: if response['ResponseMetadata']['HTTPStatusCode'] != 200: raise AirflowException(f'Sagemaker Tuning Job creation failed: {response}') else: - return {'Tuning': self.hook.describe_tuning_job(self.config['HyperParameterTuningJobName'])} + return { + 'Tuning': serialize(self.hook.describe_tuning_job(self.config['HyperParameterTuningJobName'])) + } class SageMakerModelOperator(SageMakerBaseOperator): - """Create a SageMaker model. + """ + Creates a model in Amazon SageMaker. In the request, you name the model and + describe a primary container. For the primary container, you specify the Docker + image that contains inference code, artifacts (from prior training), and a custom + environment map that the inference code uses when you deploy the model for predictions. - This operator returns The ARN of the model created in Amazon SageMaker + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:SageMakerModelOperator` :param config: The configuration necessary to create a model. For details of the configuration parameter see :py:meth:`SageMaker.Client.create_model` - :param aws_conn_id: The AWS connection ID to use. + :return Dict: Returns The ARN of the model created in Amazon SageMaker. """ def __init__(self, *, config, **kwargs): @@ -513,19 +558,21 @@ def execute(self, context: 'Context') -> dict: if response['ResponseMetadata']['HTTPStatusCode'] != 200: raise AirflowException(f'Sagemaker model creation failed: {response}') else: - return {'Model': self.hook.describe_model(self.config['ModelName'])} + return {'Model': serialize(self.hook.describe_model(self.config['ModelName']))} class SageMakerTrainingOperator(SageMakerBaseOperator): """ - Initiate a SageMaker training job. + Starts a model training job. After training completes, Amazon SageMaker saves + the resulting model artifacts to an Amazon S3 location that you specify. - This operator returns The ARN of the training job created in Amazon SageMaker. + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:SageMakerTrainingOperator` :param config: The configuration necessary to start a training job (templated). For details of the configuration parameter see :py:meth:`SageMaker.Client.create_training_job` - :param aws_conn_id: The AWS connection ID to use. 
:param wait_for_completion: If wait is set to True, the time interval, in seconds, that the operation waits to check the status of the training job. :param print_log: if the operator should print the cloudwatch log during training @@ -539,6 +586,7 @@ class SageMakerTrainingOperator(SageMakerBaseOperator): :param action_if_job_exists: Behaviour if the job name already exists. Possible options are "increment" (default) and "fail". This is only relevant if check_if + :return Dict: Returns The ARN of the training job created in Amazon SageMaker. """ integer_fields = [ @@ -593,7 +641,7 @@ def execute(self, context: 'Context') -> dict: if response['ResponseMetadata']['HTTPStatusCode'] != 200: raise AirflowException(f'Sagemaker Training Job creation failed: {response}') else: - return {'Training': self.hook.describe_training_job(self.config['TrainingJobName'])} + return {'Training': serialize(self.hook.describe_training_job(self.config['TrainingJobName']))} def _check_if_job_exists(self) -> None: training_job_name = self.config['TrainingJobName'] @@ -611,19 +659,19 @@ def _check_if_job_exists(self) -> None: class SageMakerDeleteModelOperator(SageMakerBaseOperator): - """Deletes a SageMaker model. + """ + Deletes a SageMaker model. - This operator deletes the Model entry created in SageMaker. + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:SageMakerDeleteModelOperator` :param config: The configuration necessary to delete the model. - For details of the configuration parameter see :py:meth:`SageMaker.Client.delete_model` - :param aws_conn_id: The AWS connection ID to use. """ - def __init__(self, *, config, aws_conn_id: str, **kwargs): + def __init__(self, *, config, **kwargs): super().__init__(config=config, **kwargs) - self.aws_conn_id = aws_conn_id self.config = config def execute(self, context: 'Context') -> Any: diff --git a/airflow/providers/amazon/aws/sensors/sagemaker.py b/airflow/providers/amazon/aws/sensors/sagemaker.py index 054b139cc2eb9..3cf6dceef154c 100644 --- a/airflow/providers/amazon/aws/sensors/sagemaker.py +++ b/airflow/providers/amazon/aws/sensors/sagemaker.py @@ -27,10 +27,10 @@ class SageMakerBaseSensor(BaseSensorOperator): - """Contains general sensor behavior for SageMaker. + """ + Contains general sensor behavior for SageMaker. - Subclasses should implement get_sagemaker_response() - and state_from_response() methods. + Subclasses should implement get_sagemaker_response() and state_from_response() methods. Subclasses should also implement NON_TERMINAL_STATES and FAILED_STATE methods. """ @@ -84,13 +84,15 @@ def state_from_response(self, response: dict) -> str: class SageMakerEndpointSensor(SageMakerBaseSensor): - """Asks for the state of the endpoint state until it reaches a - terminal state. - If it fails the sensor errors, the task fails. - + """ + Polls the endpoint state until it reaches a terminal state. Raises an + AirflowException with the failure reason if a failed state is reached. - :param job_name: job_name of the endpoint instance to check the state of + .. seealso:: + For more information on how to use this sensor, take a look at the guide: + :ref:`howto/sensor:SageMakerEndpointSensor` + :param endpoint_name: Name of the endpoint instance to watch. 
""" template_fields: Sequence[str] = ('endpoint_name',) @@ -118,15 +120,15 @@ def state_from_response(self, response): class SageMakerTransformSensor(SageMakerBaseSensor): - """Asks for the state of the transform state until it reaches a - terminal state. - The sensor will error if the job errors, throwing a - AirflowException - containing the failure reason. + """ + Polls the transform job until it reaches a terminal state. Raises an + AirflowException with the failure reason if a failed state is reached. - :param - job_name: job_name of the transform job instance to check the state of + .. seealso:: + For more information on how to use this sensor, take a look at the guide: + :ref:`howto/sensor:SageMakerTransformSensor` + :param job_name: Name of the transform job to watch. """ template_fields: Sequence[str] = ('job_name',) @@ -154,16 +156,15 @@ def state_from_response(self, response): class SageMakerTuningSensor(SageMakerBaseSensor): - """Asks for the state of the tuning state until it reaches a terminal - state. - The sensor will error if the job errors, throwing a - AirflowException - containing the failure reason. - - :param - job_name: job_name of the tuning instance to check the state of - :type - job_name: str + """ + Asks for the state of the tuning state until it reaches a terminal state. + Raises an AirflowException with the failure reason if a failed state is reached. + + .. seealso:: + For more information on how to use this sensor, take a look at the guide: + :ref:`howto/sensor:SageMakerTuningSensor` + + :param job_name: Name of the tuning instance to watch. """ template_fields: Sequence[str] = ('job_name',) @@ -191,14 +192,16 @@ def state_from_response(self, response): class SageMakerTrainingSensor(SageMakerBaseSensor): - """Asks for the state of the training state until it reaches a - terminal state. - If it fails the sensor errors, failing the task. - + """ + Polls the training job until it reaches a terminal state. Raises an + AirflowException with the failure reason if a failed state is reached. - :param job_name: name of the SageMaker training job to check the state of + .. seealso:: + For more information on how to use this sensor, take a look at the guide: + :ref:`howto/sensor:SageMakerTrainingSensor` - :param print_log: if the operator should print the cloudwatch log + :param job_name: Name of the training job to watch. + :param print_log: Prints the cloudwatch log if True; Defaults to True. """ template_fields: Sequence[str] = ('job_name',) diff --git a/docs/apache-airflow-providers-amazon/operators/sagemaker.rst b/docs/apache-airflow-providers-amazon/operators/sagemaker.rst index f44d258a2718f..a31527d4ca305 100644 --- a/docs/apache-airflow-providers-amazon/operators/sagemaker.rst +++ b/docs/apache-airflow-providers-amazon/operators/sagemaker.rst @@ -15,52 +15,201 @@ specific language governing permissions and limitations under the License. - Amazon SageMaker Operators -======================================== +========================== + +`Amazon SageMaker `__ is a fully managed +machine learning service. With Amazon SageMaker, data scientists and developers +can quickly build and train machine learning models, and then deploy them into a +production-ready hosted environment. + +Airflow provides operators to create and interact with SageMaker Jobs. Prerequisite Tasks ------------------ .. include:: _partials/prerequisite_tasks.rst -Overview --------- +Manage Amazon SageMaker Jobs +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. 
_howto/operator:SageMakerProcessingOperator: + +Create an Amazon SageMaker Processing Job +""""""""""""""""""""""""""""""""""""""""" + +To create an Amazon Sagemaker processing job to sanitize your dataset you can use +:class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerProcessingOperator`. + +.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_sagemaker.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_sagemaker_processing] + :end-before: [END howto_operator_sagemaker_processing] + + +.. _howto/operator:SageMakerTrainingOperator: + +Create an Amazon SageMaker Training Job +""""""""""""""""""""""""""""""""""""""" + +To create an Amazon Sagemaker training job you can use +:class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerTrainingOperator`. + +.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_sagemaker.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_sagemaker_training] + :end-before: [END howto_operator_sagemaker_training] + +.. _howto/operator:SageMakerModelOperator: + +Create an Amazon SageMaker Model +"""""""""""""""""""""""""""""""" + +To create an Amazon Sagemaker model you can use +:class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerModelOperator`. + +.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_sagemaker.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_sagemaker_model] + :end-before: [END howto_operator_sagemaker_model] + +.. _howto/operator:SageMakerTuningOperator: + +Start a Hyperparameter Tuning Job +""""""""""""""""""""""""""""""""" + +To start a hyperparameter tuning job for an Amazon Sagemaker model you can use +:class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerTuningOperator`. + +.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_sagemaker.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_sagemaker_tuning] + :end-before: [END howto_operator_sagemaker_tuning] + +.. _howto/operator:SageMakerDeleteModelOperator: + +Delete an Amazon SageMaker Model +"""""""""""""""""""""""""""""""" + +To delete an Amazon Sagemaker model you can use +:class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerDeleteModelOperator`. + +.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_sagemaker.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_sagemaker_delete_model] + :end-before: [END howto_operator_sagemaker_delete_model] + +.. _howto/operator:SageMakerTransformOperator: + +Create an Amazon SageMaker Transform Job +"""""""""""""""""""""""""""""""""""""""" -Airflow to Amazon SageMaker integration provides several operators to create and interact with -SageMaker Jobs. +To create an Amazon Sagemaker transform job you can use +:class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerTransformOperator`. - - :class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerDeleteModelOperator` - - :class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerModelOperator` - - :class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerProcessingOperator` - - :class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerTrainingOperator` - - :class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerTransformOperator` - - :class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerTuningOperator` +.. 
exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_sagemaker.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_sagemaker_transform] + :end-before: [END howto_operator_sagemaker_transform] -Purpose -""""""" +.. _howto/operator:SageMakerEndpointConfigOperator: + +Create an Amazon SageMaker Endpoint Config Job +"""""""""""""""""""""""""""""""""""""""""""""" + +To create an Amazon Sagemaker endpoint config job you can use +:class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerEndpointConfigOperator`. + +.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_sagemaker_endpoint.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_sagemaker_endpoint_config] + :end-before: [END howto_operator_sagemaker_endpoint_config] -This example DAG ``example_sagemaker.py`` uses ``SageMakerProcessingOperator``, ``SageMakerTrainingOperator``, -``SageMakerModelOperator``, ``SageMakerDeleteModelOperator`` and ``SageMakerTransformOperator`` to -create SageMaker processing job, run the training job, -generate the models artifact in s3, create the model, -, run SageMaker Batch inference and delete the model from SageMaker. +.. _howto/operator:SageMakerEndpointOperator: -Defining tasks -"""""""""""""" +Create an Amazon SageMaker Endpoint Job +""""""""""""""""""""""""""""""""""""""" -In the following code we create a SageMaker processing, -training, Sagemaker Model, batch transform job and -then delete the model. +To create an Amazon Sagemaker endpoint you can use +:class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerEndpointOperator`. + +.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_sagemaker_endpoint.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_sagemaker_endpoint] + :end-before: [END howto_operator_sagemaker_endpoint] + + +Amazon SageMaker Sensors +^^^^^^^^^^^^^^^^^^^^^^^^ + +.. _howto/sensor:SageMakerTrainingSensor: + +Amazon SageMaker Training Sensor +"""""""""""""""""""""""""""""""" + +To check the state of an Amazon Sagemaker training job until it reaches a terminal state +you can use :class:`~airflow.providers.amazon.aws.sensors.sagemaker.SageMakerTrainingSensor`. .. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_sagemaker.py :language: python - :start-after: [START howto_operator_sagemaker] - :end-before: [END howto_operator_sagemaker] + :dedent: 4 + :start-after: [START howto_operator_sagemaker_training_sensor] + :end-before: [END howto_operator_sagemaker_training_sensor] + +.. _howto/sensor:SageMakerTransformSensor: + +Amazon SageMaker Transform Sensor +""""""""""""""""""""""""""""""""""" + +To check the state of an Amazon Sagemaker transform job until it reaches a terminal state +you can use :class:`~airflow.providers.amazon.aws.sensors.sagemaker.SageMakerTransformSensor`. + +.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_sagemaker.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_sagemaker_transform_sensor] + :end-before: [END howto_operator_sagemaker_transform_sensor] + +.. _howto/sensor:SageMakerTuningSensor: + +Amazon SageMaker Tuning Sensor +"""""""""""""""""""""""""""""" + +To check the state of an Amazon Sagemaker hyperparameter tuning job until it reaches a terminal state +you can use :class:`~airflow.providers.amazon.aws.sensors.sagemaker.SageMakerTuningSensor`. + +..
exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_sagemaker.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_sagemaker_tuning_sensor] + :end-before: [END howto_operator_sagemaker_tuning_sensor] + +.. _howto/sensor:SageMakerEndpointSensor: + +Amazon SageMaker Endpoint Sensor +"""""""""""""""""""""""""""""""" + +To check the state of an Amazon Sagemaker endpoint until it reaches a terminal state +you can use :class:`~airflow.providers.amazon.aws.sensors.sagemaker.SageMakerEndpointSensor`. + +.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_sagemaker_endpoint.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_sagemaker_endpoint_sensor] + :end-before: [END howto_operator_sagemaker_endpoint_sensor] Reference ---------- +^^^^^^^^^ For further information, look at: * `Boto3 Library Documentation for Sagemaker `__ +* `Amazon SageMaker Developer Guide `__ diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index b9f2aad173a90..b55f42eaad761 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -909,6 +909,7 @@ https httpx hvac hyperparameter +hyperparameters iPython iTerm iam