From 3561712e4ca6644b68fb24dcaf1656e1b4b23b98 Mon Sep 17 00:00:00 2001 From: Lauren Yu <6631887+laurenyu@users.noreply.github.com> Date: Fri, 5 Jun 2020 11:49:22 -0700 Subject: [PATCH] change: update v2 migration tool to rename TFS classes/imports --- .../cli/compatibility/v2/ast_transformer.py | 39 ++++++ .../compatibility/v2/modifiers/__init__.py | 1 + .../cli/compatibility/v2/modifiers/tfs.py | 121 ++++++++++++++++++ .../v2/modifiers/ast_converter.py | 4 + .../compatibility/v2/modifiers/test_tfs.py | 107 ++++++++++++++++ 5 files changed, 272 insertions(+) create mode 100644 src/sagemaker/cli/compatibility/v2/modifiers/tfs.py create mode 100644 tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_tfs.py diff --git a/src/sagemaker/cli/compatibility/v2/ast_transformer.py b/src/sagemaker/cli/compatibility/v2/ast_transformer.py index 3acd05a473..128ddbe235 100644 --- a/src/sagemaker/cli/compatibility/v2/ast_transformer.py +++ b/src/sagemaker/cli/compatibility/v2/ast_transformer.py @@ -22,8 +22,13 @@ modifiers.tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader(), modifiers.tf_legacy_mode.TensorBoardParameterRemover(), modifiers.deprecated_params.TensorFlowScriptModeParameterRemover(), + modifiers.tfs.TensorFlowServingConstructorRenamer(), ] +IMPORT_MODIFIERS = [modifiers.tfs.TensorFlowServingImportRenamer()] + +IMPORT_FROM_MODIFIERS = [modifiers.tfs.TensorFlowServingImportFromRenamer()] + class ASTTransformer(ast.NodeTransformer): """An ``ast.NodeTransformer`` subclass that walks the abstract syntax tree and @@ -46,3 +51,37 @@ def visit_Call(self, node): ast.fix_missing_locations(node) return node + + def visit_Import(self, node): + """Visits an ``ast.Import`` node and returns a modified node, if needed. + See https://docs.python.org/3/library/ast.html#ast.NodeTransformer. + + Args: + node (ast.Import): a node that represents an import statement. + + Returns: + ast.Import: a node that represents an import statement, which has + potentially been modified from the original input. + """ + for import_checker in IMPORT_MODIFIERS: + import_checker.check_and_modify_node(node) + + ast.fix_missing_locations(node) + return node + + def visit_ImportFrom(self, node): + """Visits an ``ast.ImportFrom`` node and returns a modified node, if needed. + See https://docs.python.org/3/library/ast.html#ast.NodeTransformer. + + Args: + node (ast.ImportFrom): a node that represents an import statement. + + Returns: + ast.ImportFrom: a node that represents an import statement, which has + potentially been modified from the original input. + """ + for import_checker in IMPORT_FROM_MODIFIERS: + import_checker.check_and_modify_node(node) + + ast.fix_missing_locations(node) + return node diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/__init__.py b/src/sagemaker/cli/compatibility/v2/modifiers/__init__.py index a0bb4763fb..3a68a0b735 100644 --- a/src/sagemaker/cli/compatibility/v2/modifiers/__init__.py +++ b/src/sagemaker/cli/compatibility/v2/modifiers/__init__.py @@ -17,4 +17,5 @@ deprecated_params, framework_version, tf_legacy_mode, + tfs, ) diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/tfs.py b/src/sagemaker/cli/compatibility/v2/modifiers/tfs.py new file mode 100644 index 0000000000..2bf36d6898 --- /dev/null +++ b/src/sagemaker/cli/compatibility/v2/modifiers/tfs.py @@ -0,0 +1,121 @@ +# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Classes to modify TensorFlow Serving code to be compatible with SageMaker Python SDK v2.""" +from __future__ import absolute_import + +import ast + +from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier + + +class TensorFlowServingConstructorRenamer(Modifier): + """A class to rename TensorFlow Serving classes.""" + + def node_should_be_modified(self, node): + """Checks if the ``ast.Call`` node instantiates a TensorFlow Serving class. + + This looks for the following calls: + + - ``sagemaker.tensorflow.serving.Model`` + - ``sagemaker.tensorflow.serving.Predictor`` + - ``Predictor`` + + Because ``Model`` can refer to either ``sagemaker.tensorflow.serving.Model`` + or :class:`~sagemaker.model.Model`, ``Model`` on its own is not sufficient + for indicating a TFS Model object. + + Args: + node (ast.Call): a node that represents a function call. For more, + see https://docs.python.org/3/library/ast.html#abstract-grammar. + + Returns: + bool: If the ``ast.Call`` instantiates a TensorFlow Serving class. + """ + if isinstance(node.func, ast.Name): + return node.func.id == "Predictor" + + if not (isinstance(node.func, ast.Attribute) and node.func.attr in ("Model", "Predictor")): + return False + + return ( + isinstance(node.func.value, ast.Attribute) + and node.func.value.attr == "serving" + and isinstance(node.func.value.value, ast.Attribute) + and node.func.value.value.attr == "tensorflow" + and isinstance(node.func.value.value.value, ast.Name) + and node.func.value.value.value.id == "sagemaker" + ) + + def modify_node(self, node): + """Modifies the ``ast.Call`` node to use the v2 classes for TensorFlow Serving: + + - ``sagemaker.tensorflow.TensorFlowModel`` + - ``sagemaker.tensorflow.TensorFlowPredictor`` + + Args: + node (ast.Call): a node that represents a TensorFlow Serving constructor. + """ + if isinstance(node.func, ast.Name): + node.func.id = self._new_cls_name(node.func.id) + else: + node.func.attr = self._new_cls_name(node.func.attr) + node.func.value = node.func.value.value + + def _new_cls_name(self, cls_name): + """Returns the v2 class name.""" + return "TensorFlow{}".format(cls_name) + + +class TensorFlowServingImportFromRenamer(Modifier): + """A class to update import statements starting with ``from sagemaker.tensorflow.serving``.""" + + def node_should_be_modified(self, node): + """Checks if the import statement imports from the ``sagemaker.tensorflow.serving`` module. + + Args: + node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement. + For more, see https://docs.python.org/3/library/ast.html#abstract-grammar. + + Returns: + bool: If the ``ast.ImportFrom`` uses the ``sagemaker.tensorflow.serving`` module. + """ + return node.module == "sagemaker.tensorflow.serving" + + def modify_node(self, node): + """Changes the ``ast.ImportFrom`` node's module to ``sagemaker.tensorflow`` and updates the + imported class names to ``TensorFlowModel`` and ``TensorFlowPredictor``, as applicable. + + Args: + node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement. + For more, see https://docs.python.org/3/library/ast.html#abstract-grammar. + """ + node.module = "sagemaker.tensorflow" + + for cls in node.names: + cls.name = "TensorFlow{}".format(cls.name) + + +class TensorFlowServingImportRenamer(Modifier): + """A class to update ``import sagemaker.tensorflow.serving``.""" + + def check_and_modify_node(self, node): + """Checks if the ``ast.Import`` node imports the ``sagemaker.tensorflow.serving`` module + and, if so, changes it to ``sagemaker.tensorflow``. + + Args: + node (ast.Import): a node that represents an import statement. For more, + see https://docs.python.org/3/library/ast.html#abstract-grammar. + """ + for module in node.names: + if module.name == "sagemaker.tensorflow.serving": + module.name = "sagemaker.tensorflow" diff --git a/tests/unit/sagemaker/cli/compatibility/v2/modifiers/ast_converter.py b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/ast_converter.py index 09c0f5d834..a3894edc6b 100644 --- a/tests/unit/sagemaker/cli/compatibility/v2/modifiers/ast_converter.py +++ b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/ast_converter.py @@ -17,3 +17,7 @@ def ast_call(code): return pasta.parse(code).body[0].value + + +def ast_import(code): + return pasta.parse(code).body[0] diff --git a/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_tfs.py b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_tfs.py new file mode 100644 index 0000000000..a5052fa729 --- /dev/null +++ b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_tfs.py @@ -0,0 +1,107 @@ +# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pasta + +from sagemaker.cli.compatibility.v2.modifiers import tfs +from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call, ast_import + + +def test_constructor_node_should_be_modified_tfs_constructor(): + tfs_constructors = ( + "sagemaker.tensorflow.serving.Model()", + "sagemaker.tensorflow.serving.Predictor()", + "Predictor()", + ) + + modifier = tfs.TensorFlowServingConstructorRenamer() + + for constructor in tfs_constructors: + node = ast_call(constructor) + assert modifier.node_should_be_modified(node) is True + + +def test_constructor_node_should_be_modified_random_function_call(): + modifier = tfs.TensorFlowServingConstructorRenamer() + node = ast_call("Model()") + assert modifier.node_should_be_modified(node) is False + + +def test_constructor_modify_node(): + modifier = tfs.TensorFlowServingConstructorRenamer() + + node = ast_call("sagemaker.tensorflow.serving.Model()") + modifier.modify_node(node) + assert "sagemaker.tensorflow.TensorFlowModel()" == pasta.dump(node) + + node = ast_call("sagemaker.tensorflow.serving.Predictor()") + modifier.modify_node(node) + assert "sagemaker.tensorflow.TensorFlowPredictor()" == pasta.dump(node) + + node = ast_call("Predictor()") + modifier.modify_node(node) + assert "TensorFlowPredictor()" == pasta.dump(node) + + +def test_import_from_node_should_be_modified_tfs_module(): + import_statements = ( + "from sagemaker.tensorflow.serving import Model, Predictor", + "from sagemaker.tensorflow.serving import Predictor", + "from sagemaker.tensorflow.serving import Model as tfsModel", + ) + + modifier = tfs.TensorFlowServingImportFromRenamer() + + for import_from in import_statements: + node = ast_import(import_from) + assert modifier.node_should_be_modified(node) is True + + +def test_import_from_node_should_be_modified_random_import(): + modifier = tfs.TensorFlowServingImportFromRenamer() + node = ast_import("from sagemaker import Session") + assert modifier.node_should_be_modified(node) is False + + +def test_import_from_modify_node(): + modifier = tfs.TensorFlowServingImportFromRenamer() + + node = ast_import("from sagemaker.tensorflow.serving import Model, Predictor") + modifier.modify_node(node) + expected_result = "from sagemaker.tensorflow import TensorFlowModel, TensorFlowPredictor" + assert expected_result == pasta.dump(node) + + node = ast_import("from sagemaker.tensorflow.serving import Predictor") + modifier.modify_node(node) + assert "from sagemaker.tensorflow import TensorFlowPredictor" == pasta.dump(node) + + node = ast_import("from sagemaker.tensorflow.serving import Model as tfsModel") + modifier.modify_node(node) + assert "from sagemaker.tensorflow import TensorFlowModel as tfsModel" == pasta.dump(node) + + +def test_import_check_and_modify_node_tfs_import(): + modifier = tfs.TensorFlowServingImportRenamer() + node = ast_import("import sagemaker.tensorflow.serving") + modifier.check_and_modify_node(node) + assert "import sagemaker.tensorflow" == pasta.dump(node) + + +def test_import_check_and_modify_node_random_import(): + modifier = tfs.TensorFlowServingImportRenamer() + + import_statement = "import random" + node = ast_import(import_statement) + modifier.check_and_modify_node(node) + assert import_statement == pasta.dump(node)