Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/deps/untypy/test/impl/test_bound_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_bound_generic_caller_error(self):
with self.assertRaises(UntypyTypeError) as cm:
instance.insert("this should be an int")

self.assertTrue("instance.insert" in cm.exception.last_responsable().source_line)
self.assertTrue("instance.insert" in cm.exception.last_responsable().source_lines_span())

def test_bound_generic_protocol_style_ok(self):
instance = Aint()
Expand Down Expand Up @@ -99,4 +99,4 @@ def test_bound_generic_return_error(self):
with self.assertRaises(UntypyTypeError) as cm:
instance.some_string()

self.assertTrue("def some_string(self) -> T:" in cm.exception.last_responsable().source_line)
self.assertTrue("def some_string(self) -> T:" in cm.exception.last_responsable().source_lines_span())
4 changes: 2 additions & 2 deletions python/deps/untypy/test/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def wrap(self, err: UntypyTypeError) -> UntypyTypeError:
responsable=Location(
file="dummy",
line_no=0,
source_line="dummy"
line_span=1
)
))

Expand All @@ -27,7 +27,7 @@ def __init__(self, typevars: dict[TypeVar, Any] = dict()):
super().__init__(typevars.copy(), Location(
file="dummy",
line_no=0,
source_line="dummy"
line_span=1
), checkedpkgprefixes=["test"])


Expand Down
Empty file.
36 changes: 36 additions & 0 deletions python/deps/untypy/test/util_test/test_return_traces.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import ast
import unittest

from untypy.util.return_traces import ReturnTraceManager, ReturnTracesTransformer


class TestAstTransform(unittest.TestCase):

def test_ast_transform(self):
src = """
def foo(flag: bool) -> int:
print('Hello World')
if flag:
return 1
else:
return 'you stupid'
"""
target = """
def foo(flag: bool) -> int:
print('Hello World')
if flag:
untypy._before_return(0)
return 1
else:
untypy._before_return(1)
return 'you stupid'
"""

tree = ast.parse(src)
mgr = ReturnTraceManager()
ReturnTracesTransformer("<dummyfile>", mgr).visit(tree)
ast.fix_missing_locations(tree)
self.assertEqual(ast.unparse(tree).strip(), target.strip())
self.assertEqual(mgr.get(0), ("<dummyfile>", 5))
self.assertEqual(mgr.get(1), ("<dummyfile>", 7))

32 changes: 21 additions & 11 deletions python/deps/untypy/untypy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,21 @@
UntypyAstImportTransformer
from .patching.import_hook import install_import_hook
from .util.condition import FunctionCondition
from .util.return_traces import ReturnTracesTransformer, before_return, GlobalReturnTraceManager
from .util.tranformer_combinator import TransformerCombinator

GlobalConfig = DefaultConfig

"""
This function is called before any return statement, to store which was the last return.
For this the AST is transformed using ReturnTracesTransformer.
Must be in untypy so it can be used in transformed module.
Must also be in other module, so it can be used from inside (No circular imports).
"""
_before_return = before_return

_importhook_transformer_builder = lambda path, file: TransformerCombinator(UntypyAstTransformer(),
ReturnTracesTransformer(file))

def just_install_hook(prefixes=[]):
def predicate(module_name):
Expand All @@ -22,18 +34,15 @@ def predicate(module_name):
return True
return False

install_import_hook(predicate, lambda path: UntypyAstTransformer())

install_import_hook(predicate, _importhook_transformer_builder)

def just_transform(source, modname, symbol='exec'):
tree = compile(source, modname, symbol, flags=ast.PyCF_ONLY_AST, dont_inherit=True, optimize=-1)
transform_tree(tree)
return tree

def transform_tree(tree):
def transform_tree(tree, file):
UntypyAstTransformer().visit(tree)
ReturnTracesTransformer(file).visit(tree)
ast.fix_missing_locations(tree)


def enable(*, recursive: bool = True, root: Union[ModuleType, str, None] = None, prefixes: list[str] = []) -> None:
global GlobalConfig
caller = _find_calling_module()
Expand Down Expand Up @@ -66,9 +75,9 @@ def predicate(module_name):
else:
raise AssertionError("You cannot run 'untypy.enable()' twice!")

transformer = lambda path: UntypyAstTransformer()
transformer = _importhook_transformer_builder
install_import_hook(predicate, transformer)
_exec_module_patched(root, exit_after, transformer(caller.__name__.split(".")))
_exec_module_patched(root, exit_after, transformer(caller.__name__.split("."), caller.__file__))


def enable_on_imports(*prefixes):
Expand All @@ -88,9 +97,9 @@ def predicate(module_name: str):
else:
return False

transformer = lambda path: UntypyAstImportTransformer(predicate, path)
transformer = _importhook_transformer_builder
install_import_hook(predicate, transformer)
_exec_module_patched(caller, True, transformer(caller.__name__.split(".")))
_exec_module_patched(caller, True, transformer(caller.__name__.split("."), caller.__file__))


def _exec_module_patched(mod: ModuleType, exit_after: bool, transformer: ast.NodeTransformer):
Expand All @@ -104,6 +113,7 @@ def _exec_module_patched(mod: ModuleType, exit_after: bool, transformer: ast.Nod
"\tuntypy.enable()")

transformer.visit(tree)
ReturnTracesTransformer(mod.__file__).visit(tree)
ast.fix_missing_locations(tree)
patched_mod = compile(tree, mod.__file__, 'exec', dont_inherit=True, optimize=-1)
stack = list(map(lambda s: s.frame, inspect.stack()))
Expand Down
86 changes: 68 additions & 18 deletions python/deps/untypy/untypy/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,66 @@

import inspect
from enum import Enum
from os.path import relpath
from typing import Any, Optional, Tuple, Iterable


class Location:
file: str
line_no: int
source_line: str
line_span : int
source_lines: Optional[str]

def __init__(self, file: str, line_no: int, source_line: str):
def __init__(self, file: str, line_no: int, line_span : int):
self.file = file
self.line_no = line_no
self.source_line = source_line
self.line_span = line_span
self.source_lines = None

def source(self) -> Optional[str]:
if self.source_lines is None:
try:
with open(self.file, "r") as f:
self.source_lines = f.read()
except OSError:
pass

return self.source_lines

def source_lines_span(self) -> Optional[str]:
# This is still used for unit testing

source = self.source()
if source is None:
return None

buf = ""

for i, line in enumerate(source.splitlines()):
if (i + 1) in range(self.line_no, self.line_no + self.line_span):
buf += f"\n{line}"

return buf

def __str__(self):
buf = f"{self.file}:{self.line_no}"
if self.source_line:
for i, line in enumerate(self.source_line.splitlines()):
if i < 5:
buf += f"\n{'{:3}'.format(self.line_no + i)} | {line}"
if i >= 5:
buf += "\n | ..."
buf = f"{relpath(self.file)}:{self.line_no}"
source = self.source()
if source is None:
buf += f"\n{'{:3}'.format(self.line_no)} | <source code not found>"
return buf

start = max(self.line_no - 2, 1)
end = start + 5
for i, line in enumerate(source.splitlines()):
if (i + 1) == self.line_no:
buf += f"\n{'{:3}'.format(i + 1)} > {line}"
elif (i + 1) in range(start, end):
buf += f"\n{'{:3}'.format(i + 1)} | {line}"

return buf

def __repr__(self):
return f"Location(file={self.file.__repr__()}, line_no={self.line_no.__repr__()}, source_line={repr(self.source_line)})"
return f"Location(file={self.file.__repr__()}, line_no={self.line_no.__repr__()}, line_span={self.line_span})"

def __eq__(self, other):
if not isinstance(other, Location):
Expand All @@ -39,13 +74,13 @@ def from_code(obj) -> Location:
return Location(
file=inspect.getfile(obj),
line_no=inspect.getsourcelines(obj)[1],
source_line="".join(inspect.getsourcelines(obj)[0]),
line_span=len(inspect.getsourcelines(obj)[0]),
)
except Exception:
return Location(
file=inspect.getfile(obj),
line_no=1,
source_line=repr(obj)
line_span=1,
)

@staticmethod
Expand All @@ -55,29 +90,44 @@ def from_stack(stack) -> Location:
return Location(
file=stack.filename,
line_no=stack.lineno,
source_line=stack.code_context[0]
line_span=1
)
except Exception:
return Location(
file=stack.filename,
line_no=stack.lineno,
source_line=None
line_span=1
)
else: # assume sys._getframe(...)
try:
source_line = inspect.findsource(stack.f_code)[0][stack.f_lineno - 1]
return Location(
file=stack.f_code.co_filename,
line_no=stack.f_lineno,
source_line=source_line
line_span=1
)
except Exception:
return Location(
file=stack.f_code.co_filename,
line_no=stack.f_lineno,
source_line=None
line_span=1
)

def narrow_in_span(self, reti_loc : Tuple[str, int]):
"""
Use new Location if inside of span of this Location
:param reti_loc: filename and line_no
:return: a new Location, else self
"""
file, line = reti_loc
if self.file == file and line in range(self.line_no, self.line_no + self.line_span):
return Location(
file=file,
line_no=line,
line_span=1
)
else:
return self


class Frame:
type_declared: str
Expand Down
2 changes: 1 addition & 1 deletion python/deps/untypy/untypy/impl/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def responsable(self) -> Optional[Location]:
return Location(
file=inspect.getfile(self.generator.gi_frame),
line_no=inspect.getsourcelines(self.generator.gi_frame)[1],
source_line="\n".join(inspect.getsourcelines(self.generator.gi_frame)[0]),
line_span=len(inspect.getsourcelines(self.generator.gi_frame)[0]),
)
except OSError: # this call does not work all the time
pass
Expand Down
2 changes: 1 addition & 1 deletion python/deps/untypy/untypy/impl/iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def responsable(self) -> Optional[Location]:
return Location(
file=inspect.getfile(self.iter.gi_frame),
line_no=inspect.getsourcelines(self.iter.gi_frame)[1],
source_line="\n".join(inspect.getsourcelines(self.iter.gi_frame)[0]),
line_span=len(inspect.getsourcelines(self.iter.gi_frame)[0]),
)
except OSError: # this call does not work all the time
pass
Expand Down
2 changes: 1 addition & 1 deletion python/deps/untypy/untypy/impl/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def _find_bound_typevars(clas: type) -> (type, Dict[TypeVar, Any]):
[Location(
file=inspect.getfile(clas),
line_no=inspect.getsourcelines(clas)[1],
source_line="".join(inspect.getsourcelines(clas)[0]))])
line_span=len(inspect.getsourcelines(clas)[0]))])
return (clas.__origin__, dict(zip(keys, values)))


Expand Down
2 changes: 1 addition & 1 deletion python/deps/untypy/untypy/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def find_location(fn) -> Optional[Location]:
return Location(
file=inspect.getfile(fn),
line_no=inspect.getsourcelines(fn)[1],
source_line="".join(inspect.getsourcelines(fn)[0]),
line_span=len(inspect.getsourcelines(fn)[0]),
)
except: # Failes on builtins
return None
4 changes: 2 additions & 2 deletions python/deps/untypy/untypy/patching/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@ def patch_class(clas: type, cfg: Config):
declared_location=Location(
file=inspect.getfile(clas),
line_no=inspect.getsourcelines(clas)[1],
source_line="".join(inspect.getsourcelines(clas)[0]),
line_span=len(inspect.getsourcelines(clas)[0]),
), checkedpkgprefixes=cfg.checkedprefixes)
except (TypeError, OSError) as e: # Built in types
ctx = DefaultCreationContext(
typevars=dict(),
declared_location=Location(
file="<not found>",
line_no=0,
source_line="<not found>",
line_span=1
), checkedpkgprefixes=cfg.checkedprefixes)

setattr(clas, '__patched', True)
Expand Down
2 changes: 1 addition & 1 deletion python/deps/untypy/untypy/patching/ast_transformer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import ast
from typing import Callable, List, Optional
from typing import Callable, List, Optional, Any


class UntypyAstTransformer(ast.NodeTransformer):
Expand Down
8 changes: 4 additions & 4 deletions python/deps/untypy/untypy/patching/import_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from importlib.util import decode_source


def install_import_hook(should_patch_predicate: Callable[[str], bool],
def install_import_hook(should_patch_predicate: Callable[[str, str], bool],
transformer: Callable[[str], ast.NodeTransformer]):
import sys

Expand All @@ -20,7 +20,7 @@ def install_import_hook(should_patch_predicate: Callable[[str], bool],

class UntypyFinder(MetaPathFinder):

def __init__(self, inner_finder: MetaPathFinder, should_patch_predicate: Callable[[str], bool],
def __init__(self, inner_finder: MetaPathFinder, should_patch_predicate: Callable[[str, str], bool],
transformer: Callable[[str], ast.NodeTransformer]):
self.inner_finder = inner_finder
self.should_patch_predicate = should_patch_predicate
Expand All @@ -41,15 +41,15 @@ def should_instrument(self, module_name: str) -> bool:

class UntypyLoader(SourceFileLoader):

def __init__(self, fullname, path, transformer: Callable[[str], ast.NodeTransformer]):
def __init__(self, fullname, path, transformer: Callable[[str, str], ast.NodeTransformer]):
super().__init__(fullname, path)
self.transformer = transformer

def source_to_code(self, data, path, *, _optimize=-1):
source = decode_source(data)
tree = compile(source, path, 'exec', ast.PyCF_ONLY_AST,
dont_inherit=True, optimize=_optimize)
self.transformer(self.name.split('.')).visit(tree)
self.transformer(self.name.split('.'), self.path).visit(tree)
ast.fix_missing_locations(tree)
return compile(tree, path, 'exec', dont_inherit=True, optimize=_optimize)

Expand Down
Loading