From 87f1920b3e957d3df8c48a435fa576f45957c9e4 Mon Sep 17 00:00:00 2001 From: Aakash Pydi Date: Wed, 10 Jun 2020 16:37:24 -0500 Subject: [PATCH] feature: Support for multi variant endpoint invocation with target variant param --- src/sagemaker/local/local_session.py | 4 + src/sagemaker/predictor.py | 13 +- tests/integ/test_multi_variant_endpoint.py | 309 +++++++++++++++++++++ tests/unit/test_predictor.py | 26 ++ 4 files changed, 349 insertions(+), 3 deletions(-) create mode 100644 tests/integ/test_multi_variant_endpoint.py diff --git a/src/sagemaker/local/local_session.py b/src/sagemaker/local/local_session.py index 1e80f6e1b4..6c23cea45a 100644 --- a/src/sagemaker/local/local_session.py +++ b/src/sagemaker/local/local_session.py @@ -343,6 +343,7 @@ def invoke_endpoint( Accept=None, CustomAttributes=None, TargetModel=None, + TargetVariant=None, ): """ @@ -370,6 +371,9 @@ def invoke_endpoint( if TargetModel is not None: headers["X-Amzn-SageMaker-Target-Model"] = TargetModel + if TargetVariant is not None: + headers["X-Amzn-SageMaker-Target-Variant"] = TargetVariant + r = self.http.request("POST", url, body=Body, preload_content=False, headers=headers) return {"Body": r, "ContentType": Accept} diff --git a/src/sagemaker/predictor.py b/src/sagemaker/predictor.py index 80da8a551c..0f1efb4e18 100644 --- a/src/sagemaker/predictor.py +++ b/src/sagemaker/predictor.py @@ -83,7 +83,7 @@ def __init__( self._endpoint_config_name = self._get_endpoint_config_name() self._model_names = self._get_model_names() - def predict(self, data, initial_args=None, target_model=None): + def predict(self, data, initial_args=None, target_model=None, target_variant=None): """Return the inference from the specified endpoint. Args: @@ -98,6 +98,9 @@ def predict(self, data, initial_args=None, target_model=None): target_model (str): S3 model artifact path to run an inference request on, in case of a multi model endpoint. Does not apply to endpoints hosting single model (Default: None) + target_variant (str): The name of the production variant to run an inference + request on (Default: None). Note that the ProductionVariant identifies the model + you want to host and the resources you want to deploy for hosting it. Returns: object: Inference for the given input. If a deserializer was specified when creating @@ -106,7 +109,7 @@ def predict(self, data, initial_args=None, target_model=None): as is. """ - request_args = self._create_request_args(data, initial_args, target_model) + request_args = self._create_request_args(data, initial_args, target_model, target_variant) response = self.sagemaker_session.sagemaker_runtime_client.invoke_endpoint(**request_args) return self._handle_response(response) @@ -123,12 +126,13 @@ def _handle_response(self, response): response_body.close() return data - def _create_request_args(self, data, initial_args=None, target_model=None): + def _create_request_args(self, data, initial_args=None, target_model=None, target_variant=None): """ Args: data: initial_args: target_model: + target_variant: """ args = dict(initial_args) if initial_args else {} @@ -144,6 +148,9 @@ def _create_request_args(self, data, initial_args=None, target_model=None): if target_model: args["TargetModel"] = target_model + if target_variant: + args["TargetVariant"] = target_variant + if self.serializer is not None: data = self.serializer(data) diff --git a/tests/integ/test_multi_variant_endpoint.py b/tests/integ/test_multi_variant_endpoint.py new file mode 100644 index 0000000000..7d1fab7645 --- /dev/null +++ b/tests/integ/test_multi_variant_endpoint.py @@ -0,0 +1,309 @@ +# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import json +import os +import math +import pytest +import scipy.stats as st + +from sagemaker.s3 import S3Uploader +from sagemaker.session import production_variant +from sagemaker.sparkml import SparkMLModel +from sagemaker.utils import sagemaker_timestamp +from sagemaker.content_types import CONTENT_TYPE_CSV +from sagemaker.utils import unique_name_from_base +from sagemaker.amazon.amazon_estimator import get_image_uri +from sagemaker.predictor import csv_serializer, RealTimePredictor + + +import tests.integ + + +ROLE = "SageMakerRole" +MODEL_NAME = "test-xgboost-model-{}".format(sagemaker_timestamp()) +ENDPOINT_NAME = unique_name_from_base("integ-test-multi-variant-endpoint") +DEFAULT_REGION = "us-west-2" +DEFAULT_INSTANCE_TYPE = "ml.m5.xlarge" +DEFAULT_INSTANCE_COUNT = 1 +XG_BOOST_MODEL_LOCAL_PATH = os.path.join(tests.integ.DATA_DIR, "xgboost_model", "xgb_model.tar.gz") + +TEST_VARIANT_1 = "Variant1" +TEST_VARIANT_1_WEIGHT = 0.3 + +TEST_VARIANT_2 = "Variant2" +TEST_VARIANT_2_WEIGHT = 0.7 + +VARIANT_TRAFFIC_SAMPLING_COUNT = 100 +DESIRED_CONFIDENCE_FOR_VARIANT_TRAFFIC_DISTRIBUTION = 0.999 + +TEST_CSV_DATA = "42,42,42,42,42,42,42" + +SPARK_ML_MODEL_LOCAL_PATH = os.path.join( + tests.integ.DATA_DIR, "sparkml_model", "mleap_model.tar.gz" +) +SPARK_ML_MODEL_ENDPOINT_NAME = unique_name_from_base("integ-test-target-variant-sparkml") +SPARK_ML_DEFAULT_VARIANT_NAME = ( + "AllTraffic" +) # default defined in src/sagemaker/session.py def production_variant +SPARK_ML_WRONG_VARIANT_NAME = "WRONG_VARIANT" +SPARK_ML_TEST_DATA = "1.0,C,38.0,71.5,1.0,female" +SPARK_ML_MODEL_SCHEMA = json.dumps( + { + "input": [ + {"name": "Pclass", "type": "float"}, + {"name": "Embarked", "type": "string"}, + {"name": "Age", "type": "float"}, + {"name": "Fare", "type": "float"}, + {"name": "SibSp", "type": "float"}, + {"name": "Sex", "type": "string"}, + ], + "output": {"name": "features", "struct": "vector", "type": "double"}, + } +) + + +@pytest.fixture(scope="module") +def multi_variant_endpoint(sagemaker_session): + """ + Sets up the multi variant endpoint before the integration tests run. + Cleans up the multi variant endpoint after the integration tests run. + """ + + with tests.integ.timeout.timeout_and_delete_endpoint_by_name( + endpoint_name=ENDPOINT_NAME, sagemaker_session=sagemaker_session, hours=2 + ): + + # Creating a model + bucket = sagemaker_session.default_bucket() + prefix = "sagemaker/DEMO-VariantTargeting" + model_url = S3Uploader.upload( + local_path=XG_BOOST_MODEL_LOCAL_PATH, + desired_s3_uri="s3://" + bucket + "/" + prefix, + session=sagemaker_session, + ) + + image_uri = get_image_uri(sagemaker_session.boto_session.region_name, "xgboost", "0.90-1") + + multi_variant_endpoint_model = sagemaker_session.create_model( + name=MODEL_NAME, + role=ROLE, + container_defs={"Image": image_uri, "ModelDataUrl": model_url}, + ) + + # Creating a multi variant endpoint + variant1 = production_variant( + model_name=MODEL_NAME, + instance_type=DEFAULT_INSTANCE_TYPE, + initial_instance_count=DEFAULT_INSTANCE_COUNT, + variant_name=TEST_VARIANT_1, + initial_weight=TEST_VARIANT_1_WEIGHT, + ) + variant2 = production_variant( + model_name=MODEL_NAME, + instance_type=DEFAULT_INSTANCE_TYPE, + initial_instance_count=DEFAULT_INSTANCE_COUNT, + variant_name=TEST_VARIANT_2, + initial_weight=TEST_VARIANT_2_WEIGHT, + ) + sagemaker_session.endpoint_from_production_variants( + name=ENDPOINT_NAME, production_variants=[variant1, variant2] + ) + + # Yield to run the integration tests + yield multi_variant_endpoint + + # Cleanup resources + sagemaker_session.delete_model(multi_variant_endpoint_model) + sagemaker_session.sagemaker_client.delete_endpoint_config(EndpointConfigName=ENDPOINT_NAME) + + # Validate resource cleanup + with pytest.raises(Exception) as exception: + sagemaker_session.sagemaker_client.describe_model( + ModelName=multi_variant_endpoint_model.name + ) + assert "Could not find model" in str(exception.value) + sagemaker_session.sagemaker_client.describe_endpoint_config(name=ENDPOINT_NAME) + assert "Could not find endpoint" in str(exception.value) + + +def test_target_variant_invocation(sagemaker_session, multi_variant_endpoint): + + response = sagemaker_session.sagemaker_runtime_client.invoke_endpoint( + EndpointName=ENDPOINT_NAME, + Body=TEST_CSV_DATA, + ContentType=CONTENT_TYPE_CSV, + Accept=CONTENT_TYPE_CSV, + TargetVariant=TEST_VARIANT_1, + ) + assert response["InvokedProductionVariant"] == TEST_VARIANT_1 + + response = sagemaker_session.sagemaker_runtime_client.invoke_endpoint( + EndpointName=ENDPOINT_NAME, + Body=TEST_CSV_DATA, + ContentType=CONTENT_TYPE_CSV, + Accept=CONTENT_TYPE_CSV, + TargetVariant=TEST_VARIANT_2, + ) + assert response["InvokedProductionVariant"] == TEST_VARIANT_2 + + +def test_predict_invocation_with_target_variant(sagemaker_session, multi_variant_endpoint): + predictor = RealTimePredictor( + endpoint=ENDPOINT_NAME, + sagemaker_session=sagemaker_session, + serializer=csv_serializer, + content_type=CONTENT_TYPE_CSV, + accept=CONTENT_TYPE_CSV, + ) + + # Validate that no exception is raised when the target_variant is specified. + predictor.predict(TEST_CSV_DATA, target_variant=TEST_VARIANT_1) + predictor.predict(TEST_CSV_DATA, target_variant=TEST_VARIANT_2) + + +def test_variant_traffic_distribution(sagemaker_session, multi_variant_endpoint): + variant_1_invocation_count = 0 + variant_2_invocation_count = 0 + + for i in range(0, VARIANT_TRAFFIC_SAMPLING_COUNT): + response = sagemaker_session.sagemaker_runtime_client.invoke_endpoint( + EndpointName=ENDPOINT_NAME, + Body=TEST_CSV_DATA, + ContentType=CONTENT_TYPE_CSV, + Accept=CONTENT_TYPE_CSV, + ) + if response["InvokedProductionVariant"] == TEST_VARIANT_1: + variant_1_invocation_count += 1 + elif response["InvokedProductionVariant"] == TEST_VARIANT_2: + variant_2_invocation_count += 1 + + assert variant_1_invocation_count + variant_2_invocation_count == VARIANT_TRAFFIC_SAMPLING_COUNT + + variant_1_invocation_percentage = float(variant_1_invocation_count) / float( + VARIANT_TRAFFIC_SAMPLING_COUNT + ) + variant_1_margin_of_error = _compute_and_retrieve_margin_of_error(TEST_VARIANT_1_WEIGHT) + assert variant_1_invocation_percentage < TEST_VARIANT_1_WEIGHT + variant_1_margin_of_error + assert variant_1_invocation_percentage > TEST_VARIANT_1_WEIGHT - variant_1_margin_of_error + + variant_2_invocation_percentage = float(variant_2_invocation_count) / float( + VARIANT_TRAFFIC_SAMPLING_COUNT + ) + variant_2_margin_of_error = _compute_and_retrieve_margin_of_error(TEST_VARIANT_2_WEIGHT) + assert variant_2_invocation_percentage < TEST_VARIANT_2_WEIGHT + variant_2_margin_of_error + assert variant_2_invocation_percentage > TEST_VARIANT_2_WEIGHT - variant_2_margin_of_error + + +def test_spark_ml_predict_invocation_with_target_variant(sagemaker_session): + model_data = sagemaker_session.upload_data( + path=SPARK_ML_MODEL_LOCAL_PATH, key_prefix="integ-test-data/sparkml/model" + ) + + with tests.integ.timeout.timeout_and_delete_endpoint_by_name( + SPARK_ML_MODEL_ENDPOINT_NAME, sagemaker_session + ): + spark_ml_model = SparkMLModel( + model_data=model_data, + role=ROLE, + sagemaker_session=sagemaker_session, + env={"SAGEMAKER_SPARKML_SCHEMA": SPARK_ML_MODEL_SCHEMA}, + ) + + predictor = spark_ml_model.deploy( + DEFAULT_INSTANCE_COUNT, + DEFAULT_INSTANCE_TYPE, + endpoint_name=SPARK_ML_MODEL_ENDPOINT_NAME, + ) + + # Validate that no exception is raised when the target_variant is specified. + predictor.predict(SPARK_ML_TEST_DATA, target_variant=SPARK_ML_DEFAULT_VARIANT_NAME) + + with pytest.raises(Exception) as exception_info: + predictor.predict(SPARK_ML_TEST_DATA, target_variant=SPARK_ML_WRONG_VARIANT_NAME) + + assert "ValidationError" in str(exception_info.value) + assert SPARK_ML_WRONG_VARIANT_NAME in str(exception_info.value) + + # cleanup resources + spark_ml_model.delete_model() + sagemaker_session.sagemaker_client.delete_endpoint_config( + EndpointConfigName=SPARK_ML_MODEL_ENDPOINT_NAME + ) + + # Validate resource cleanup + with pytest.raises(Exception) as exception: + sagemaker_session.sagemaker_client.describe_model(ModelName=spark_ml_model.name) + assert "Could not find model" in str(exception.value) + sagemaker_session.sagemaker_client.describe_endpoint_config( + name=SPARK_ML_MODEL_ENDPOINT_NAME + ) + assert "Could not find endpoint" in str(exception.value) + + +@pytest.mark.local_mode +def test_target_variant_invocation_local_mode(sagemaker_session, multi_variant_endpoint): + + if sagemaker_session._region_name is None: + sagemaker_session._region_name = DEFAULT_REGION + + response = sagemaker_session.sagemaker_runtime_client.invoke_endpoint( + EndpointName=ENDPOINT_NAME, + Body=TEST_CSV_DATA, + ContentType=CONTENT_TYPE_CSV, + Accept=CONTENT_TYPE_CSV, + TargetVariant=TEST_VARIANT_1, + ) + assert response["InvokedProductionVariant"] == TEST_VARIANT_1 + + response = sagemaker_session.sagemaker_runtime_client.invoke_endpoint( + EndpointName=ENDPOINT_NAME, + Body=TEST_CSV_DATA, + ContentType=CONTENT_TYPE_CSV, + Accept=CONTENT_TYPE_CSV, + TargetVariant=TEST_VARIANT_2, + ) + assert response["InvokedProductionVariant"] == TEST_VARIANT_2 + + +@pytest.mark.local_mode +def test_predict_invocation_with_target_variant_local_mode( + sagemaker_session, multi_variant_endpoint +): + + if sagemaker_session._region_name is None: + sagemaker_session._region_name = DEFAULT_REGION + + predictor = RealTimePredictor( + endpoint=ENDPOINT_NAME, + sagemaker_session=sagemaker_session, + serializer=csv_serializer, + content_type=CONTENT_TYPE_CSV, + accept=CONTENT_TYPE_CSV, + ) + + # Validate that no exception is raised when the target_variant is specified. + predictor.predict(TEST_CSV_DATA, target_variant=TEST_VARIANT_1) + predictor.predict(TEST_CSV_DATA, target_variant=TEST_VARIANT_2) + + +def _compute_and_retrieve_margin_of_error(variant_weight): + """ + Computes the margin of error using the Wald method for computing the confidence + intervals of a binomial distribution. + """ + z_value = st.norm.ppf(DESIRED_CONFIDENCE_FOR_VARIANT_TRAFFIC_DISTRIBUTION) + margin_of_error = (variant_weight * (1 - variant_weight)) / VARIANT_TRAFFIC_SAMPLING_COUNT + margin_of_error = z_value * math.sqrt(margin_of_error) + return margin_of_error diff --git a/tests/unit/test_predictor.py b/tests/unit/test_predictor.py index 9c7feb0a83..b74f36e2b3 100644 --- a/tests/unit/test_predictor.py +++ b/tests/unit/test_predictor.py @@ -346,6 +346,7 @@ def test_numpy_deser_from_npy_object_array(): CSV_CONTENT_TYPE = "text/csv" RETURN_VALUE = 0 CSV_RETURN_VALUE = "1,2,3\r\n" +PRODUCTION_VARIANT_1 = "PRODUCTION_VARIANT_1" ENDPOINT_DESC = {"EndpointConfigName": ENDPOINT} @@ -407,6 +408,31 @@ def test_predict_call_with_headers(): assert result == RETURN_VALUE +def test_predict_call_with_target_variant(): + sagemaker_session = empty_sagemaker_session() + predictor = RealTimePredictor( + ENDPOINT, sagemaker_session, content_type=DEFAULT_CONTENT_TYPE, accept=DEFAULT_CONTENT_TYPE + ) + + data = "untouched" + result = predictor.predict(data, target_variant=PRODUCTION_VARIANT_1) + + assert sagemaker_session.sagemaker_runtime_client.invoke_endpoint.called + + expected_request_args = { + "Accept": DEFAULT_CONTENT_TYPE, + "Body": data, + "ContentType": DEFAULT_CONTENT_TYPE, + "EndpointName": ENDPOINT, + "TargetVariant": PRODUCTION_VARIANT_1, + } + + call_args, kwargs = sagemaker_session.sagemaker_runtime_client.invoke_endpoint.call_args + assert kwargs == expected_request_args + + assert result == RETURN_VALUE + + def test_multi_model_predict_call_with_headers(): sagemaker_session = empty_sagemaker_session() predictor = RealTimePredictor(