diff --git a/src/attr/_make.py b/src/attr/_make.py index 8bc8634d6..76b1c62f4 100644 --- a/src/attr/_make.py +++ b/src/attr/_make.py @@ -286,6 +286,36 @@ def attrib( ) +def _compile_and_eval(script, globs, locs=None, filename=""): + """ + "Exec" the script with the given global (globs) and local (locs) variables. + """ + bytecode = compile(script, filename, "exec") + eval(bytecode, globs, locs) + + +def _make_method(name, script, filename, globs=None): + """ + Create the method with the script given and return the method object. + """ + locs = {} + if globs is None: + globs = {} + + _compile_and_eval(script, globs, locs, filename) + + # In order of debuggers like PDB being able to step through the code, + # we add a fake linecache entry. + linecache.cache[filename] = ( + len(script), + None, + script.splitlines(True), + filename, + ) + + return locs[name] + + def _make_attr_tuple_class(cls_name, attr_names): """ Create a tuple subclass to hold `Attribute`s for an `attrs` class. @@ -309,8 +339,7 @@ class MyClassAttributes(tuple): else: attr_class_template.append(" pass") globs = {"_attrs_itemgetter": itemgetter, "_attrs_property": property} - eval(compile("\n".join(attr_class_template), "", "exec"), globs) - + _compile_and_eval("\n".join(attr_class_template), globs) return globs[attr_class_name] @@ -1591,21 +1620,7 @@ def append_hash_computation_lines(prefix, indent): append_hash_computation_lines("return ", tab) script = "\n".join(method_lines) - globs = {} - locs = {} - bytecode = compile(script, unique_filename, "exec") - eval(bytecode, globs, locs) - - # In order of debuggers like PDB being able to step through the code, - # we add a fake linecache entry. - linecache.cache[unique_filename] = ( - len(script), - None, - script.splitlines(True), - unique_filename, - ) - - return locs["__hash__"] + return _make_method("__hash__", script, unique_filename) def _add_hash(cls, attrs): @@ -1661,20 +1676,7 @@ def _make_eq(cls, attrs): lines.append(" return True") script = "\n".join(lines) - globs = {} - locs = {} - bytecode = compile(script, unique_filename, "exec") - eval(bytecode, globs, locs) - - # In order of debuggers like PDB being able to step through the code, - # we add a fake linecache entry. - linecache.cache[unique_filename] = ( - len(script), - None, - script.splitlines(True), - unique_filename, - ) - return locs["__eq__"] + return _make_method("__eq__", script, unique_filename) def _make_order(cls, attrs): @@ -1949,8 +1951,10 @@ def _make_init( has_global_on_setattr, attrs_init, ) - locs = {} - bytecode = compile(script, unique_filename, "exec") + if cls.__module__ in sys.modules: + # This makes typing.get_type_hints(CLS.__init__) resolve string types. + globs.update(sys.modules[cls.__module__].__dict__) + globs.update({"NOTHING": NOTHING, "attr_dict": attr_dict}) if needs_cached_setattr: @@ -1958,18 +1962,12 @@ def _make_init( # setattr hooks. globs["_cached_setattr"] = _obj_setattr - eval(bytecode, globs, locs) - - # In order of debuggers like PDB being able to step through the code, - # we add a fake linecache entry. - linecache.cache[unique_filename] = ( - len(script), - None, - script.splitlines(True), + init = _make_method( + "__attrs_init__" if attrs_init else "__init__", + script, unique_filename, + globs, ) - - init = locs["__attrs_init__"] if attrs_init else locs["__init__"] init.__annotations__ = annotations return init diff --git a/tests/test_annotations.py b/tests/test_annotations.py index 3e19aa14f..0b2709936 100644 --- a/tests/test_annotations.py +++ b/tests/test_annotations.py @@ -578,3 +578,32 @@ class B: assert typing.List[B] == attr.fields(A).a.type assert A == attr.fields(B).a.type + + def test_init_type_hints(self): + """ + Forward references in __init__ can be automatically resolved. + """ + + @attr.s + class C: + x = attr.ib(type="typing.List[int]") + + assert typing.get_type_hints(C.__init__) == { + "return": type(None), + "x": typing.List[int], + } + + def test_init_type_hints_fake_module(self): + """ + If you somehow set the __module__ to something that doesn't exist + you'll lose __init__ resolution. + """ + + class C: + x = attr.ib(type="typing.List[int]") + + C.__module__ = "totally fake" + C = attr.s(C) + + with pytest.raises(NameError): + typing.get_type_hints(C.__init__)