From 872ba3967a4b1e92fe42075a107a0b8939e3717f Mon Sep 17 00:00:00 2001 From: owahab Date: Sat, 30 May 2020 14:06:26 +0200 Subject: [PATCH 1/3] Major WIP. --- tests/unit/v2/__init__.py | 0 tests/unit/v2/samples/simple.txt | 4 +++ tests/unit/v2/test_framework_version.py | 13 ++++++++ tests/unit/v2/test_transformer.py | 32 +++++++++++++++++++ tests/unit/v2/utils.py | 9 ++++++ tools/compatibility/v2/ast_transformer.py | 2 +- .../v2/modifiers/framework_version.py | 2 +- 7 files changed, 60 insertions(+), 2 deletions(-) create mode 100644 tests/unit/v2/__init__.py create mode 100644 tests/unit/v2/samples/simple.txt create mode 100644 tests/unit/v2/test_framework_version.py create mode 100644 tests/unit/v2/test_transformer.py create mode 100644 tests/unit/v2/utils.py diff --git a/tests/unit/v2/__init__.py b/tests/unit/v2/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/v2/samples/simple.txt b/tests/unit/v2/samples/simple.txt new file mode 100644 index 0000000000..cfac17cbbb --- /dev/null +++ b/tests/unit/v2/samples/simple.txt @@ -0,0 +1,4 @@ +TensorFlow(entry_point="foo.py") +sagemaker.tensorflow.TensorFlow() +m = MXNet() +sagemaker.mxnet.MXNet() \ No newline at end of file diff --git a/tests/unit/v2/test_framework_version.py b/tests/unit/v2/test_framework_version.py new file mode 100644 index 0000000000..c2d9998d02 --- /dev/null +++ b/tests/unit/v2/test_framework_version.py @@ -0,0 +1,13 @@ +import unittest + + +class FrameworkVersion(unittest.TestCase): + def setUp(self) -> None: + pass + + def test_something(self): + self.assertEqual(True, False) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/unit/v2/test_transformer.py b/tests/unit/v2/test_transformer.py new file mode 100644 index 0000000000..6db8947cc0 --- /dev/null +++ b/tests/unit/v2/test_transformer.py @@ -0,0 +1,32 @@ +import ast +import unittest + +from tests.unit.v2.utils import get_sample_file +from tools.compatibility.v2.ast_transformer import ASTTransformer +import pasta + + +class TransformerTest(unittest.TestCase): + def setUp(self) -> None: + self.transformer_class = ASTTransformer() + + def test_simple_transform(self): + sample = get_sample_file('simple.txt') + rewrite = self.transformer_class.visit( + ast.parse( + sample + ) + ) + + expected = """TensorFlow(entry_point='foo.py', framework_version='1.11.0') +sagemaker.tensorflow.TensorFlow(framework_version='1.11.0') +m = MXNet(framework_version='1.2.0') +sagemaker.mxnet.MXNet(framework_version='1.2.0')\n""" + + self.assertEqual(pasta.dump(rewrite), expected) + + +if __name__ == '__main__': + unittest.main() + + diff --git a/tests/unit/v2/utils.py b/tests/unit/v2/utils.py new file mode 100644 index 0000000000..1c3cf519f5 --- /dev/null +++ b/tests/unit/v2/utils.py @@ -0,0 +1,9 @@ +from os.path import join + +SAMPLES_DIRECTORY = "/Users/owahab/Desktop/personal/sagemaker-python-sdk/tests/unit/v2/samples/" + + +def get_sample_file(filename): + file_path = join(SAMPLES_DIRECTORY, filename) + with open(file_path) as file_content: + return file_content.read() diff --git a/tools/compatibility/v2/ast_transformer.py b/tools/compatibility/v2/ast_transformer.py index 87d7dddcb7..7171840ad0 100644 --- a/tools/compatibility/v2/ast_transformer.py +++ b/tools/compatibility/v2/ast_transformer.py @@ -15,7 +15,7 @@ import ast -from modifiers import framework_version +from tools.compatibility.v2.modifiers import framework_version FUNCTION_CALL_MODIFIERS = [framework_version.FrameworkVersionEnforcer()] diff --git a/tools/compatibility/v2/modifiers/framework_version.py b/tools/compatibility/v2/modifiers/framework_version.py index 2c2a440ba7..0115526549 100644 --- a/tools/compatibility/v2/modifiers/framework_version.py +++ b/tools/compatibility/v2/modifiers/framework_version.py @@ -15,7 +15,7 @@ import ast -from modifiers.modifier import Modifier +from tools.compatibility.v2.modifiers.modifier import Modifier FRAMEWORK_DEFAULTS = { "Chainer": "4.1.0", From 64c66f22b338a62117d967817fdba471365e7916 Mon Sep 17 00:00:00 2001 From: owahab Date: Mon, 8 Jun 2020 13:31:25 +0200 Subject: [PATCH 2/3] feature: Add tests for ast_transformer --- tests/unit/v2/samples/simple.txt | 4 -- tests/unit/v2/test_framework_version.py | 13 ------ tests/unit/v2/test_transformer.py | 55 ++++++++++++++++--------- tests/unit/v2/utils.py | 9 ---- 4 files changed, 35 insertions(+), 46 deletions(-) delete mode 100644 tests/unit/v2/samples/simple.txt delete mode 100644 tests/unit/v2/test_framework_version.py delete mode 100644 tests/unit/v2/utils.py diff --git a/tests/unit/v2/samples/simple.txt b/tests/unit/v2/samples/simple.txt deleted file mode 100644 index cfac17cbbb..0000000000 --- a/tests/unit/v2/samples/simple.txt +++ /dev/null @@ -1,4 +0,0 @@ -TensorFlow(entry_point="foo.py") -sagemaker.tensorflow.TensorFlow() -m = MXNet() -sagemaker.mxnet.MXNet() \ No newline at end of file diff --git a/tests/unit/v2/test_framework_version.py b/tests/unit/v2/test_framework_version.py deleted file mode 100644 index c2d9998d02..0000000000 --- a/tests/unit/v2/test_framework_version.py +++ /dev/null @@ -1,13 +0,0 @@ -import unittest - - -class FrameworkVersion(unittest.TestCase): - def setUp(self) -> None: - pass - - def test_something(self): - self.assertEqual(True, False) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/unit/v2/test_transformer.py b/tests/unit/v2/test_transformer.py index 6db8947cc0..589f95a3dd 100644 --- a/tests/unit/v2/test_transformer.py +++ b/tests/unit/v2/test_transformer.py @@ -1,32 +1,47 @@ import ast -import unittest - -from tests.unit.v2.utils import get_sample_file -from tools.compatibility.v2.ast_transformer import ASTTransformer +from sagemaker.tools.compatibility.v2.ast_transformer import ASTTransformer import pasta -class TransformerTest(unittest.TestCase): - def setUp(self) -> None: - self.transformer_class = ASTTransformer() - - def test_simple_transform(self): - sample = get_sample_file('simple.txt') - rewrite = self.transformer_class.visit( - ast.parse( - sample - ) +def test_code_needs_transform(): + simple = """ +TensorFlow(entry_point="foo.py") +sagemaker.tensorflow.TensorFlow() +m = MXNet() +sagemaker.mxnet.MXNet() +""" + transformer_class = ASTTransformer() + rewrite = transformer_class.visit( + ast.parse( + simple ) - - expected = """TensorFlow(entry_point='foo.py', framework_version='1.11.0') + ) + expected = """TensorFlow(entry_point='foo.py', framework_version='1.11.0') sagemaker.tensorflow.TensorFlow(framework_version='1.11.0') m = MXNet(framework_version='1.2.0') sagemaker.mxnet.MXNet(framework_version='1.2.0')\n""" - self.assertEqual(pasta.dump(rewrite), expected) + assert pasta.dump( + rewrite + ) == expected -if __name__ == '__main__': - unittest.main() - +def test_code_does_not_need_transform(): + simple = """TensorFlow(entry_point='foo.py', framework_version='1.11.0') +sagemaker.tensorflow.TensorFlow(framework_version='1.11.0') +m = MXNet(framework_version='1.2.0') +sagemaker.mxnet.MXNet(framework_version='1.2.0')\n""" + transformer_class = ASTTransformer() + rewrite = transformer_class.visit( + ast.parse( + simple + ) + ) + expected = """TensorFlow(entry_point='foo.py', framework_version='1.11.0') +sagemaker.tensorflow.TensorFlow(framework_version='1.11.0') +m = MXNet(framework_version='1.2.0') +sagemaker.mxnet.MXNet(framework_version='1.2.0')\n""" + assert pasta.dump( + rewrite + ) == expected diff --git a/tests/unit/v2/utils.py b/tests/unit/v2/utils.py deleted file mode 100644 index 1c3cf519f5..0000000000 --- a/tests/unit/v2/utils.py +++ /dev/null @@ -1,9 +0,0 @@ -from os.path import join - -SAMPLES_DIRECTORY = "/Users/owahab/Desktop/personal/sagemaker-python-sdk/tests/unit/v2/samples/" - - -def get_sample_file(filename): - file_path = join(SAMPLES_DIRECTORY, filename) - with open(file_path) as file_content: - return file_content.read() From fa86f399a18f60fd63e139053ad136a912fd0281 Mon Sep 17 00:00:00 2001 From: owahab Date: Tue, 16 Jun 2020 09:58:02 +0200 Subject: [PATCH 3/3] Cleaned up failing tests. --- tests/unit/v2/test_transformer.py | 23 +++++++---------------- 1 file changed, 7 insertions(+), 16 deletions(-) diff --git a/tests/unit/v2/test_transformer.py b/tests/unit/v2/test_transformer.py index 589f95a3dd..e8487de0b1 100644 --- a/tests/unit/v2/test_transformer.py +++ b/tests/unit/v2/test_transformer.py @@ -1,3 +1,5 @@ +from __future__ import absolute_import + import ast from sagemaker.tools.compatibility.v2.ast_transformer import ASTTransformer import pasta @@ -10,20 +12,15 @@ def test_code_needs_transform(): m = MXNet() sagemaker.mxnet.MXNet() """ + transformer_class = ASTTransformer() - rewrite = transformer_class.visit( - ast.parse( - simple - ) - ) + rewrite = transformer_class.visit(ast.parse(simple)) expected = """TensorFlow(entry_point='foo.py', framework_version='1.11.0') sagemaker.tensorflow.TensorFlow(framework_version='1.11.0') m = MXNet(framework_version='1.2.0') sagemaker.mxnet.MXNet(framework_version='1.2.0')\n""" - assert pasta.dump( - rewrite - ) == expected + assert pasta.dump(rewrite) == expected def test_code_does_not_need_transform(): @@ -32,16 +29,10 @@ def test_code_does_not_need_transform(): m = MXNet(framework_version='1.2.0') sagemaker.mxnet.MXNet(framework_version='1.2.0')\n""" transformer_class = ASTTransformer() - rewrite = transformer_class.visit( - ast.parse( - simple - ) - ) + rewrite = transformer_class.visit(ast.parse(simple)) expected = """TensorFlow(entry_point='foo.py', framework_version='1.11.0') sagemaker.tensorflow.TensorFlow(framework_version='1.11.0') m = MXNet(framework_version='1.2.0') sagemaker.mxnet.MXNet(framework_version='1.2.0')\n""" - assert pasta.dump( - rewrite - ) == expected + assert pasta.dump(rewrite) == expected