diff --git a/pragma/collapse_literals.py b/pragma/collapse_literals.py index 51968b3..1bea702 100644 --- a/pragma/collapse_literals.py +++ b/pragma/collapse_literals.py @@ -1,17 +1,24 @@ import ast import logging -from .core import TrackedContextTransformer, make_function_transformer, primitive_ast_types +from .core import TrackedContextTransformer, make_function_transformer, primitive_ast_types, iterable_ast_types log = logging.getLogger(__name__) # noinspection PyPep8Naming class CollapseTransformer(TrackedContextTransformer): + collapse_iterables = False + def visit_Name(self, node): res = self.resolve_literal(node) if isinstance(res, primitive_ast_types): return res + if isinstance(res, iterable_ast_types): + if self.collapse_iterables: + return res + else: + log.debug("Not collapsing iterable {}. Change this setting with collapse_literals(collapse_iterables=True)".format(res)) return node def visit_BinOp(self, node): diff --git a/pragma/core/resolve/__init__.py b/pragma/core/resolve/__init__.py index 567eb7c..2c88e09 100644 --- a/pragma/core/resolve/__init__.py +++ b/pragma/core/resolve/__init__.py @@ -89,11 +89,13 @@ def resolve_name_or_attribute(node, ctxt): float_types = (float,) primitive_types = tuple([str, bytes, bool, type(None)] + list(num_types) + list(float_types)) +iterable_types = (list, tuple) try: primitive_ast_types = (ast.Num, ast.Str, ast.Bytes, ast.NameConstant, ast.Constant, ast.JoinedStr) except AttributeError: # Python <3.6 primitive_ast_types = (ast.Num, ast.Str, ast.Bytes, ast.NameConstant) +iterable_ast_types = (ast.List, ast.Tuple) def make_binop(op): diff --git a/pragma/core/transformer.py b/pragma/core/transformer.py index eca1b4c..5b59148 100644 --- a/pragma/core/transformer.py +++ b/pragma/core/transformer.py @@ -358,7 +358,7 @@ def visit_ExceptHandler(self, node): def make_function_transformer(transformer_type, name, description, **transformer_kwargs): @optional_argument_decorator @magic_contract - def transform(return_source=False, save_source=True, function_globals=None, **kwargs): + def transform(return_source=False, save_source=True, function_globals=None, collapse_iterables=False, explicit_only=False, **kwargs): """ :param return_source: Returns the transformed function's source code instead of compiling it :type return_source: bool @@ -366,6 +366,10 @@ def transform(return_source=False, save_source=True, function_globals=None, **kw :type save_source: bool :param function_globals: Overridden global name assignments to use when processing the function :type function_globals: dict|None + :param collapse_iterables: Collapse iterable types + :type collapse_iterables: bool + :param explicit_only: Whether to use global variables or just keyword and function_globals in the replacement context + :type explicit_only: bool :param kwargs: Any other environmental variables to provide during unrolling :type kwargs: dict :return: The transformed function, or its source code if requested @@ -375,16 +379,23 @@ def transform(return_source=False, save_source=True, function_globals=None, **kw @magic_contract(f='Callable', returns='Callable|str') def inner(f): f_mod, f_body, f_file = function_ast(f) - # Grab function globals - glbls = f.__globals__.copy() - # Grab function closure variables - if isinstance(f.__closure__, tuple): - glbls.update({k: v.cell_contents for k, v in zip(f.__code__.co_freevars, f.__closure__)}) + if not explicit_only: + # Grab function globals + glbls = f.__globals__.copy() + # Grab function closure variables + if isinstance(f.__closure__, tuple): + glbls.update({k: v.cell_contents for k, v in zip(f.__code__.co_freevars, f.__closure__)}) + else: + # Initialize empty context + if function_globals is None and len(kwargs) == 0: + log.warning("No global context nor function context. No collapse will occur") + glbls = dict() # Apply manual globals override if function_globals is not None: glbls.update(function_globals) # print({k: v for k, v in glbls.items() if k not in globals()}) trans = transformer_type(DictStack(glbls, kwargs), **transformer_kwargs) + trans.collapse_iterables = collapse_iterables f_mod.body[0].decorator_list = [] f_mod = trans.visit(f_mod) # print(astor.dump_tree(f_mod)) diff --git a/tests/pytest.ini b/tests/pytest.ini new file mode 100644 index 0000000..397ef63 --- /dev/null +++ b/tests/pytest.ini @@ -0,0 +1,4 @@ +[pytest] +filterwarnings = + ignore::DeprecationWarning +addopts = -s --log-cli-level 30 \ No newline at end of file diff --git a/tests/test_collapse_literals.py b/tests/test_collapse_literals.py index 593cb7d..10eb008 100644 --- a/tests/test_collapse_literals.py +++ b/tests/test_collapse_literals.py @@ -381,6 +381,20 @@ def f(): self.assertSourceEqual(f, result) + def test_iterable_option(self): + a = [1, 2, 3, 4] + + @pragma.collapse_literals(collapse_iterables=True) + def f(): + x = a + + result = ''' + def f(): + x = [1, 2, 3, 4] + ''' + + self.assertSourceEqual(f, result) + def test_indexable_operations(self): dct = dict(a=1, b=2, c=3, d=4) @@ -468,3 +482,27 @@ def f(): self.assertSourceEqual(f, result) self.assertEqual(f(), 4) + + + def test_explicit_collapse(self): + a = 2 + b = 3 + @pragma.collapse_literals(explicit_only=True, b=b) + def f(): + x = a + y = b + result = ''' + def f(): + x = a + y = 3 + ''' + self.assertSourceEqual(f, result) + + @pragma.collapse_literals(explicit_only=True) + def f(): + x = a + result = ''' + def f(): + x = a + ''' + self.assertSourceEqual(f, result)