From 9921a078e277c73052fd322001ec3933dc05b3f5 Mon Sep 17 00:00:00 2001 From: Lauren Yu <6631887+laurenyu@users.noreply.github.com> Date: Thu, 2 Jul 2020 16:21:21 -0700 Subject: [PATCH 1/2] change: handle image_uri rename for estimators and models in v2 migration tool --- .../cli/compatibility/v2/ast_transformer.py | 1 + .../v2/modifiers/framework_version.py | 14 +-- .../v2/modifiers/renamed_params.py | 62 +++++++++ .../v2/modifiers/tf_legacy_mode.py | 4 +- .../v2/modifiers/test_framework_version.py | 4 +- .../v2/modifiers/test_image_uri.py | 118 ++++++++++++++++++ .../v2/modifiers/test_tf_legacy_mode.py | 4 +- 7 files changed, 193 insertions(+), 14 deletions(-) create mode 100644 tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_image_uri.py diff --git a/src/sagemaker/cli/compatibility/v2/ast_transformer.py b/src/sagemaker/cli/compatibility/v2/ast_transformer.py index 09bd63d792..4e99f39308 100644 --- a/src/sagemaker/cli/compatibility/v2/ast_transformer.py +++ b/src/sagemaker/cli/compatibility/v2/ast_transformer.py @@ -18,6 +18,7 @@ from sagemaker.cli.compatibility.v2 import modifiers FUNCTION_CALL_MODIFIERS = [ + modifiers.renamed_params.EstimatorImageURIRenamer(), modifiers.framework_version.FrameworkVersionEnforcer(), modifiers.tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader(), modifiers.tf_legacy_mode.TensorBoardParameterRemover(), diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/framework_version.py b/src/sagemaker/cli/compatibility/v2/modifiers/framework_version.py index 9480a4af3d..4f144067aa 100644 --- a/src/sagemaker/cli/compatibility/v2/modifiers/framework_version.py +++ b/src/sagemaker/cli/compatibility/v2/modifiers/framework_version.py @@ -19,6 +19,7 @@ from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier FRAMEWORK_ARG = "framework_version" +IMAGE_ARG = "image_uri" PY_ARG = "py_version" FRAMEWORK_DEFAULTS = { @@ -70,11 +71,8 @@ def node_should_be_modified(self, node): bool: If the ``ast.Call`` is instantiating a framework class that should specify ``framework_version``, but doesn't. """ - if matching.matches_any(node, ESTIMATORS): - return _version_args_needed(node, "image_name") - - if matching.matches_any(node, MODELS): - return _version_args_needed(node, "image") + if matching.matches_any(node, ESTIMATORS) or matching.matches_any(node, MODELS): + return _version_args_needed(node) return False @@ -169,14 +167,14 @@ def _framework_from_node(node): return framework, is_model -def _version_args_needed(node, image_arg): +def _version_args_needed(node): """Determines if image_arg or version_arg was supplied Applies similar logic as ``validate_version_or_image_args`` """ # if image_arg is present, no need to supply version arguments - image_name = _arg_value(node, image_arg) - if image_name: + image_uri = _arg_value(node, IMAGE_ARG) + if image_uri: return False # if framework_version is None, need args diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/renamed_params.py b/src/sagemaker/cli/compatibility/v2/modifiers/renamed_params.py index c63eb95d51..da24fb7c97 100644 --- a/src/sagemaker/cli/compatibility/v2/modifiers/renamed_params.py +++ b/src/sagemaker/cli/compatibility/v2/modifiers/renamed_params.py @@ -167,3 +167,65 @@ def node_should_be_modified(self, node): return False return super(S3SessionRenamer, self).node_should_be_modified(node) + + +class EstimatorImageURIRenamer(ParamRenamer): + """A class to rename the ``image_name`` attribute to ``image_uri`` in estimators.""" + + @property + def calls_to_modify(self): + """A dictionary mapping estimators with the ``image_name`` attribute to their + respective namespaces. + """ + return { + "Chainer": ("sagemaker.chainer", "sagemaker.chainer.estimator"), + "Estimator": ("sagemaker.estimator",), + "Framework": ("sagemaker.estimator",), + "MXNet": ("sagemaker.mxnet", "sagemaker.mxnet.estimator"), + "PyTorch": ("sagemaker.pytorch", "sagemaker.pytorch.estimator"), + "RLEstimator": ("sagemaker.rl", "sagemaker.rl.estimator"), + "SKLearn": ("sagemaker.sklearn", "sagemaker.sklearn.estimator"), + "TensorFlow": ("sagemaker.tensorflow", "sagemaker.tensorflow.estimator"), + "XGBoost": ("sagemaker.xgboost", "sagemaker.xgboost.estimator"), + } + + @property + def old_param_name(self): + """The previous name for the image URI argument.""" + return "image_name" + + @property + def new_param_name(self): + """The new name for the image URI argument.""" + return "image_uri" + + +class ModelImageURIRenamer(ParamRenamer): + """A class to rename the ``image`` attribute to ``image_uri`` in models.""" + + @property + def calls_to_modify(self): + """A dictionary mapping models with the ``image`` attribute to their + respective namespaces. + """ + return { + "ChainerModel": ("sagemaker.chainer", "sagemaker.chainer.model"), + "Model": ("sagemaker.model",), + "MultiDataModel": ("sagemaker.multidatamodel",), + "FrameworkModel": ("sagemaker.model",), + "MXNetModel": ("sagemaker.mxnet", "sagemaker.mxnet.model"), + "PyTorchModel": ("sagemaker.pytorch", "sagemaker.pytorch.model"), + "SKLearnModel": ("sagemaker.sklearn", "sagemaker.sklearn.model"), + "TensorFlowModel": ("sagemaker.tensorflow", "sagemaker.tensorflow.model"), + "XGBoostModel": ("sagemaker.xgboost", "sagemaker.xgboost.model"), + } + + @property + def old_param_name(self): + """The previous name for the image URI argument.""" + return "image" + + @property + def new_param_name(self): + """The new name for the image URI argument.""" + return "image_uri" diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/tf_legacy_mode.py b/src/sagemaker/cli/compatibility/v2/modifiers/tf_legacy_mode.py index 99c221ef4e..794c4909ee 100644 --- a/src/sagemaker/cli/compatibility/v2/modifiers/tf_legacy_mode.py +++ b/src/sagemaker/cli/compatibility/v2/modifiers/tf_legacy_mode.py @@ -116,7 +116,7 @@ def modify_node(self, node): hp_key = self._hyperparameter_key_for_param(kw.arg) additional_hps[hp_key] = kw.value kw_to_remove.append(kw) - if kw.arg == "image_name": + if kw.arg == "image_uri": add_image_uri = False self._remove_keywords(node, kw_to_remove) @@ -124,7 +124,7 @@ def modify_node(self, node): if add_image_uri: image_uri = self._image_uri_from_args(node.keywords) - node.keywords.append(ast.keyword(arg="image_name", value=ast.Str(s=image_uri))) + node.keywords.append(ast.keyword(arg="image_uri", value=ast.Str(s=image_uri))) node.keywords.append(ast.keyword(arg="model_dir", value=ast.NameConstant(value=False))) diff --git a/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_framework_version.py b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_framework_version.py index 82fe36738b..994bcb3f2a 100644 --- a/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_framework_version.py +++ b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_framework_version.py @@ -57,7 +57,7 @@ def _templates(self, model=False): def _frameworks(self, versions=False, image=False): keywords = dict() if image: - keywords["image_name"] = "my:image" + keywords["image_uri"] = "my:image" if versions: keywords["framework_version"] = self.framework_version keywords["py_version"] = self.py_version @@ -66,7 +66,7 @@ def _frameworks(self, versions=False, image=False): def _models(self, versions=False, image=False): keywords = dict() if image: - keywords["image"] = "my:image" + keywords["image_uri"] = "my:image" if versions: keywords["framework_version"] = self.framework_version if self.py_version_for_model: diff --git a/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_image_uri.py b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_image_uri.py new file mode 100644 index 0000000000..cd464c0204 --- /dev/null +++ b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_image_uri.py @@ -0,0 +1,118 @@ +# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pasta + +from sagemaker.cli.compatibility.v2.modifiers import renamed_params +from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call + +ESTIMATORS = { + "Chainer": ("sagemaker.chainer", "sagemaker.chainer.estimator"), + "Estimator": ("sagemaker.estimator",), + "Framework": ("sagemaker.estimator",), + "MXNet": ("sagemaker.mxnet", "sagemaker.mxnet.estimator"), + "PyTorch": ("sagemaker.pytorch", "sagemaker.pytorch.estimator"), + "RLEstimator": ("sagemaker.rl", "sagemaker.rl.estimator"), + "SKLearn": ("sagemaker.sklearn", "sagemaker.sklearn.estimator"), + "TensorFlow": ("sagemaker.tensorflow", "sagemaker.tensorflow.estimator"), + "XGBoost": ("sagemaker.xgboost", "sagemaker.xgboost.estimator"), +} + +MODELS = { + "ChainerModel": ("sagemaker.chainer", "sagemaker.chainer.model"), + "Model": ("sagemaker.model",), + "MultiDataModel": ("sagemaker.multidatamodel",), + "FrameworkModel": ("sagemaker.model",), + "MXNetModel": ("sagemaker.mxnet", "sagemaker.mxnet.model"), + "PyTorchModel": ("sagemaker.pytorch", "sagemaker.pytorch.model"), + "SKLearnModel": ("sagemaker.sklearn", "sagemaker.sklearn.model"), + "TensorFlowModel": ("sagemaker.tensorflow", "sagemaker.tensorflow.model"), + "XGBoostModel": ("sagemaker.xgboost", "sagemaker.xgboost.model"), +} + + +def test_estimator_node_should_be_modified(): + modifier = renamed_params.EstimatorImageURIRenamer() + + for estimator, namespaces in ESTIMATORS.items(): + call = "{}(image_name='my-image:latest')".format(estimator) + assert modifier.node_should_be_modified(ast_call(call)) + + for namespace in namespaces: + call = "{}.{}(image_name='my-image:latest')".format(namespace, estimator) + assert modifier.node_should_be_modified(ast_call(call)) + + +def test_estimator_node_should_be_modified_no_distribution(): + modifier = renamed_params.EstimatorImageURIRenamer() + + for estimator, namespaces in ESTIMATORS.items(): + call = "{}()".format(estimator) + assert not modifier.node_should_be_modified(ast_call(call)) + + for namespace in namespaces: + call = "{}.{}()".format(namespace, estimator) + assert not modifier.node_should_be_modified(ast_call(call)) + + +def test_estimator_node_should_be_modified_random_function_call(): + modifier = renamed_params.EstimatorImageURIRenamer() + assert not modifier.node_should_be_modified(ast_call("Session()")) + + +def test_estimator_modify_node(): + node = ast_call("TensorFlow(image_name=my_image)") + modifier = renamed_params.EstimatorImageURIRenamer() + modifier.modify_node(node) + + expected = "TensorFlow(image_uri=my_image)" + assert expected == pasta.dump(node) + + +def test_model_node_should_be_modified(): + modifier = renamed_params.ModelImageURIRenamer() + + for model, namespaces in MODELS.items(): + call = "{}(image='my-image:latest')".format(model) + assert modifier.node_should_be_modified(ast_call(call)) + + for namespace in namespaces: + call = "{}.{}(image='my-image:latest')".format(namespace, model) + assert modifier.node_should_be_modified(ast_call(call)) + + +def test_model_node_should_be_modified_no_distribution(): + modifier = renamed_params.ModelImageURIRenamer() + + for model, namespaces in MODELS.items(): + call = "{}()".format(model) + assert not modifier.node_should_be_modified(ast_call(call)) + + for namespace in namespaces: + call = "{}.{}()".format(namespace, model) + assert not modifier.node_should_be_modified(ast_call(call)) + + +def test_model_node_should_be_modified_random_function_call(): + modifier = renamed_params.ModelImageURIRenamer() + assert not modifier.node_should_be_modified(ast_call("Session()")) + + +def test_model_modify_node(): + node = ast_call("TensorFlowModel(image=my_image)") + modifier = renamed_params.ModelImageURIRenamer() + modifier.modify_node(node) + + expected = "TensorFlowModel(image_uri=my_image)" + assert expected == pasta.dump(node) diff --git a/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_tf_legacy_mode.py b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_tf_legacy_mode.py index 6b049a510a..9d018b2727 100644 --- a/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_tf_legacy_mode.py +++ b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_tf_legacy_mode.py @@ -90,7 +90,7 @@ def test_modify_node_set_model_dir_and_image_name(create_image_uri, boto_session node = ast_call(constructor) modifier.modify_node(node) - assert "TensorFlow(image_name='{}', model_dir=False)".format(IMAGE_URI) == pasta.dump(node) + assert "TensorFlow(image_uri='{}', model_dir=False)".format(IMAGE_URI) == pasta.dump(node) create_image_uri.assert_called_with( REGION_NAME, "tensorflow", "ml.m4.xlarge", "1.11.0", "py2" ) @@ -111,7 +111,7 @@ def test_modify_node_set_image_name_from_args(create_image_uri, boto_session): expected_string = ( "TensorFlow(train_instance_type='ml.p2.xlarge', framework_version='1.4.0', " - "image_name='{}', model_dir=False)".format(IMAGE_URI) + "image_uri='{}', model_dir=False)".format(IMAGE_URI) ) assert expected_string == pasta.dump(node) From 2130f0b9ae5d970b8dec8e062293e54e8fac07b7 Mon Sep 17 00:00:00 2001 From: Lauren Yu <6631887+laurenyu@users.noreply.github.com> Date: Tue, 7 Jul 2020 15:04:55 -0700 Subject: [PATCH 2/2] fix --- src/sagemaker/cli/compatibility/v2/ast_transformer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/sagemaker/cli/compatibility/v2/ast_transformer.py b/src/sagemaker/cli/compatibility/v2/ast_transformer.py index 3abc14837f..199efe1a3e 100644 --- a/src/sagemaker/cli/compatibility/v2/ast_transformer.py +++ b/src/sagemaker/cli/compatibility/v2/ast_transformer.py @@ -19,6 +19,7 @@ FUNCTION_CALL_MODIFIERS = [ modifiers.renamed_params.EstimatorImageURIRenamer(), + modifiers.renamed_params.ModelImageURIRenamer(), modifiers.framework_version.FrameworkVersionEnforcer(), modifiers.tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader(), modifiers.tf_legacy_mode.TensorBoardParameterRemover(),