Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion pragma/collapse_literals.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
import ast
import logging

from .core import TrackedContextTransformer, make_function_transformer, primitive_ast_types
from .core import TrackedContextTransformer, make_function_transformer, primitive_ast_types, iterable_ast_types

log = logging.getLogger(__name__)


# noinspection PyPep8Naming
class CollapseTransformer(TrackedContextTransformer):
collapse_iterables = False

def visit_Name(self, node):
res = self.resolve_literal(node)
if isinstance(res, primitive_ast_types):
return res
if isinstance(res, iterable_ast_types):
if self.collapse_iterables:
return res
else:
log.debug("Not collapsing iterable {}. Change this setting with collapse_literals(collapse_iterables=True)".format(res))
return node

def visit_BinOp(self, node):
Expand Down
2 changes: 2 additions & 0 deletions pragma/core/resolve/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,13 @@ def resolve_name_or_attribute(node, ctxt):
float_types = (float,)

primitive_types = tuple([str, bytes, bool, type(None)] + list(num_types) + list(float_types))
iterable_types = (list, tuple)

try:
primitive_ast_types = (ast.Num, ast.Str, ast.Bytes, ast.NameConstant, ast.Constant, ast.JoinedStr)
except AttributeError: # Python <3.6
primitive_ast_types = (ast.Num, ast.Str, ast.Bytes, ast.NameConstant)
iterable_ast_types = (ast.List, ast.Tuple)


def make_binop(op):
Expand Down
23 changes: 17 additions & 6 deletions pragma/core/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,14 +358,18 @@ def visit_ExceptHandler(self, node):
def make_function_transformer(transformer_type, name, description, **transformer_kwargs):
@optional_argument_decorator
@magic_contract
def transform(return_source=False, save_source=True, function_globals=None, **kwargs):
def transform(return_source=False, save_source=True, function_globals=None, collapse_iterables=False, explicit_only=False, **kwargs):
"""
:param return_source: Returns the transformed function's source code instead of compiling it
:type return_source: bool
:param save_source: Saves the function source code to a tempfile to make it inspectable
:type save_source: bool
:param function_globals: Overridden global name assignments to use when processing the function
:type function_globals: dict|None
:param collapse_iterables: Collapse iterable types
:type collapse_iterables: bool
:param explicit_only: Whether to use global variables or just keyword and function_globals in the replacement context
:type explicit_only: bool
:param kwargs: Any other environmental variables to provide during unrolling
:type kwargs: dict
:return: The transformed function, or its source code if requested
Expand All @@ -375,16 +379,23 @@ def transform(return_source=False, save_source=True, function_globals=None, **kw
@magic_contract(f='Callable', returns='Callable|str')
def inner(f):
f_mod, f_body, f_file = function_ast(f)
# Grab function globals
glbls = f.__globals__.copy()
# Grab function closure variables
if isinstance(f.__closure__, tuple):
glbls.update({k: v.cell_contents for k, v in zip(f.__code__.co_freevars, f.__closure__)})
if not explicit_only:
# Grab function globals
glbls = f.__globals__.copy()
# Grab function closure variables
if isinstance(f.__closure__, tuple):
glbls.update({k: v.cell_contents for k, v in zip(f.__code__.co_freevars, f.__closure__)})
else:
# Initialize empty context
if function_globals is None and len(kwargs) == 0:
log.warning("No global context nor function context. No collapse will occur")
glbls = dict()
# Apply manual globals override
if function_globals is not None:
glbls.update(function_globals)
# print({k: v for k, v in glbls.items() if k not in globals()})
trans = transformer_type(DictStack(glbls, kwargs), **transformer_kwargs)
trans.collapse_iterables = collapse_iterables
f_mod.body[0].decorator_list = []
f_mod = trans.visit(f_mod)
# print(astor.dump_tree(f_mod))
Expand Down
4 changes: 4 additions & 0 deletions tests/pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[pytest]
filterwarnings =
ignore::DeprecationWarning
addopts = -s --log-cli-level 30
38 changes: 38 additions & 0 deletions tests/test_collapse_literals.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,20 @@ def f():

self.assertSourceEqual(f, result)

def test_iterable_option(self):
a = [1, 2, 3, 4]

@pragma.collapse_literals(collapse_iterables=True)
def f():
x = a

result = '''
def f():
x = [1, 2, 3, 4]
'''

self.assertSourceEqual(f, result)

def test_indexable_operations(self):
dct = dict(a=1, b=2, c=3, d=4)

Expand Down Expand Up @@ -468,3 +482,27 @@ def f():

self.assertSourceEqual(f, result)
self.assertEqual(f(), 4)


def test_explicit_collapse(self):
a = 2
b = 3
@pragma.collapse_literals(explicit_only=True, b=b)
def f():
x = a
y = b
result = '''
def f():
x = a
y = 3
'''
self.assertSourceEqual(f, result)

@pragma.collapse_literals(explicit_only=True)
def f():
x = a
result = '''
def f():
x = a
'''
self.assertSourceEqual(f, result)