Skip to content

Commit e2744f9

Browse files
committed
Implement type assignment, add type handling for lists, tuples, sets
1 parent 076f5e7 commit e2744f9

File tree

3 files changed

+120
-20
lines changed

3 files changed

+120
-20
lines changed

IPython/core/completer.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1346,7 +1346,9 @@ def _evaluate_expr(self, expr):
13461346
),
13471347
)
13481348
done = True
1349-
except (SyntaxError, TypeError):
1349+
except (SyntaxError, TypeError) as e:
1350+
if self.debug:
1351+
warnings.warn(f"Trimming because of {e}")
13501352
# TypeError can show up with something like `+ d`
13511353
# where `d` is a dictionary.
13521354

@@ -1357,8 +1359,10 @@ def _evaluate_expr(self, expr):
13571359
expr = self._trim_expr(expr)
13581360
except Exception as e:
13591361
if self.debug:
1360-
print("Evaluation exception", e)
1362+
warnings.warn(f"Evaluation exception {e}")
13611363
done = True
1364+
if self.debug:
1365+
warnings.warn(f"Resolved to {obj}")
13621366
return obj
13631367

13641368
@property

IPython/core/guarded_eval.py

Lines changed: 79 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import dataclasses
2323
import operator
2424
import sys
25+
import typing
2526
import warnings
2627
from functools import cached_property
2728
from dataclasses import dataclass, field
@@ -227,6 +228,8 @@ class SelectivePolicy(EvaluationPolicy):
227228
allowed_operations: set = field(default_factory=set)
228229
allowed_operations_external: set[tuple[str, ...] | str] = field(default_factory=set)
229230

231+
allow_getitem_on_types: bool = field(default_factory=bool)
232+
230233
_operation_methods_cache: dict[str, set[Callable]] = field(
231234
default_factory=dict, init=False
232235
)
@@ -290,6 +293,13 @@ def can_get_attr(self, value, attr):
290293
def can_get_item(self, value, item):
291294
"""Allow accessing `__getiitem__` of allow-listed instances unless it was not modified."""
292295
allowed_getitem_external = _coerce_path_to_tuples(self.allowed_getitem_external)
296+
if self.allow_getitem_on_types:
297+
# e.g. Union[str, int] or Literal[True, 1]
298+
if isinstance(value, (typing._SpecialForm, typing._BaseGenericAlias)):
299+
return True
300+
# PEP 560 e.g. list[str]
301+
if isinstance(value, type) and hasattr(value, "__class_getitem__"):
302+
return True
293303
return _has_original_dunder(
294304
value,
295305
allowed_types=self.allowed_getitem,
@@ -478,8 +488,8 @@ class _Duck:
478488
"""A dummy class used to create objects pretending to have given attributes"""
479489

480490
def __init__(self, attributes: Optional[dict] = None, items: Optional[dict] = None):
481-
self.attributes = attributes or {}
482-
self.items = items or {}
491+
self.attributes = attributes if attributes is not None else {}
492+
self.items = items if items is not None else {}
483493

484494
def __getattr__(self, attr: str):
485495
return self.attributes[attr]
@@ -640,6 +650,14 @@ def dummy_function(*args, **kwargs):
640650
return None
641651
if isinstance(node, ast.Assign):
642652
return _handle_assign(node, context)
653+
if isinstance(node, ast.AnnAssign):
654+
if not node.simple:
655+
# for now only handle simple annotations
656+
return None
657+
context.transient_locals[node.target.id] = _resolve_annotation(
658+
eval_node(node.annotation, context), context
659+
)
660+
return None
643661
if isinstance(node, ast.Expression):
644662
return eval_node(node.body, context)
645663
if isinstance(node, ast.Expr):
@@ -793,7 +811,13 @@ def dummy_function(*args, **kwargs):
793811
func, # not joined to avoid calling `repr`
794812
f"not allowed in {context.evaluation} mode",
795813
)
796-
raise ValueError("Unhandled node", ast.dump(node))
814+
if isinstance(node, ast.Assert):
815+
# message is always the second item, so if it is defined user would be completing
816+
# on the message, not on the assertion test
817+
if node.msg:
818+
return eval_node(node.msg, context)
819+
return eval_node(node.test, context)
820+
raise SyntaxError(f"Unhandled node: {ast.dump(node)}")
797821

798822

799823
def _eval_return_type(func: Callable, node: ast.Call, context: EvaluationContext):
@@ -809,7 +833,7 @@ def _eval_return_type(func: Callable, node: ast.Call, context: EvaluationContext
809833
# but resolved by signature call we know the return type
810834
not_empty = sig.return_annotation is not Signature.empty
811835
if not_empty:
812-
return _resolve_annotation(sig.return_annotation, sig, func, node, context)
836+
return _resolve_annotation(sig.return_annotation, context, sig, func, node)
813837
return NOT_EVALUATED
814838

815839

@@ -824,12 +848,26 @@ def _eval_annotation(
824848
)
825849

826850

851+
class _GetItemDuck(dict):
852+
"""A dict subclass that always returns the factory instance and claims to have any item."""
853+
854+
def __init__(self, factory, *args, **kwargs):
855+
super().__init__(*args, **kwargs)
856+
self._factory = factory
857+
858+
def __getitem__(self, key):
859+
return self._factory()
860+
861+
def __contains__(self, key):
862+
return True
863+
864+
827865
def _resolve_annotation(
828-
annotation: str,
829-
sig: Signature,
830-
func: Callable,
831-
node: ast.Call,
866+
annotation: object | str,
832867
context: EvaluationContext,
868+
sig: Signature | None = None,
869+
func: Callable | None = None,
870+
node: ast.Call | None = None,
833871
):
834872
"""Resolve annotation created by user with `typing` module and custom objects."""
835873
annotation = _eval_annotation(annotation, context)
@@ -844,7 +882,7 @@ def _resolve_annotation(
844882
return ""
845883
elif annotation is AnyStr:
846884
index = None
847-
if hasattr(func, "__node__"):
885+
if func and hasattr(func, "__node__"):
848886
def_node = func.__node__
849887
for i, arg in enumerate(def_node.args.args):
850888
if not arg.annotation:
@@ -858,7 +896,7 @@ def _resolve_annotation(
858896
)
859897
if index and is_bound_method:
860898
index -= 1
861-
else:
899+
elif sig:
862900
for i, (key, value) in enumerate(sig.parameters.items()):
863901
if value.annotation is AnyStr:
864902
index = i
@@ -870,32 +908,53 @@ def _resolve_annotation(
870908
return eval_node(node.args[index], context)
871909
elif origin is TypeGuard:
872910
return False
911+
elif origin is set or origin is list:
912+
# only one type argument allowed
913+
attributes = [
914+
attr
915+
for attr in dir(
916+
_resolve_annotation(get_args(annotation)[0], context, sig, func, node)
917+
)
918+
]
919+
duck = _Duck(attributes=dict.fromkeys(attributes))
920+
return _Duck(
921+
attributes=dict.fromkeys(dir(origin())),
922+
# items are not strrictly needed for set
923+
items=_GetItemDuck(lambda: duck),
924+
)
925+
elif origin is tuple:
926+
# multiple type arguments
927+
return tuple(
928+
_resolve_annotation(arg, context, sig, func, node)
929+
for arg in get_args(annotation)
930+
)
873931
elif origin is Union:
932+
# multiple type arguments
874933
attributes = [
875934
attr
876935
for type_arg in get_args(annotation)
877-
for attr in dir(_resolve_annotation(type_arg, sig, func, node, context))
936+
for attr in dir(_resolve_annotation(type_arg, context, sig, func, node))
878937
]
879938
return _Duck(attributes=dict.fromkeys(attributes))
880939
elif is_typeddict(annotation):
881940
return _Duck(
882941
attributes=dict.fromkeys(dir(dict())),
883942
items={
884-
k: _resolve_annotation(v, sig, func, node, context)
943+
k: _resolve_annotation(v, context, sig, func, node)
885944
for k, v in annotation.__annotations__.items()
886945
},
887946
)
888947
elif hasattr(annotation, "_is_protocol"):
889948
return _Duck(attributes=dict.fromkeys(dir(annotation)))
890949
elif origin is Annotated:
891950
type_arg = get_args(annotation)[0]
892-
return _resolve_annotation(type_arg, sig, func, node, context)
951+
return _resolve_annotation(type_arg, context, sig, func, node)
893952
elif isinstance(annotation, NewType):
894-
return _eval_or_create_duck(annotation.__supertype__, node, context)
953+
return _eval_or_create_duck(annotation.__supertype__, context)
895954
elif isinstance(annotation, TypeAliasType):
896-
return _eval_or_create_duck(annotation.__value__, node, context)
955+
return _eval_or_create_duck(annotation.__value__, context)
897956
else:
898-
return _eval_or_create_duck(annotation, node, context)
957+
return _eval_or_create_duck(annotation, context)
899958

900959

901960
def _eval_node_name(node_id: str, context: EvaluationContext):
@@ -919,7 +978,7 @@ def _eval_node_name(node_id: str, context: EvaluationContext):
919978
raise NameError(f"{node_id} not found in locals, globals, nor builtins")
920979

921980

922-
def _eval_or_create_duck(duck_type, node: ast.Call, context: EvaluationContext):
981+
def _eval_or_create_duck(duck_type, context: EvaluationContext):
923982
policy = get_policy(context)
924983
# if allow-listed builtin is on type annotation, instantiate it
925984
if policy.can_call(duck_type):
@@ -959,6 +1018,8 @@ def _create_duck_for_heap_type(duck_type):
9591018
bytes, # type: ignore[arg-type]
9601019
list,
9611020
tuple,
1021+
type, # for type annotations like list[str]
1022+
_Duck,
9621023
collections.defaultdict,
9631024
collections.deque,
9641025
collections.OrderedDict,
@@ -1063,6 +1124,7 @@ def _list_methods(cls, source=None):
10631124
allow_builtins_access=True,
10641125
allow_locals_access=True,
10651126
allow_globals_access=True,
1127+
allow_getitem_on_types=True,
10661128
allowed_calls=ALLOWED_CALLS,
10671129
),
10681130
"unsafe": EvaluationPolicy(

tests/test_guarded_eval.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -545,11 +545,45 @@ def test_mock_class_and_func_instances(code, expected):
545545
],
546546
)
547547
def test_evaluates_assignments(code, expected):
548-
context = limited(TypedClass=TypedClass, AnyStr=AnyStr)
548+
context = limited()
549549
value = guarded_eval(code, context)
550550
assert isinstance(value, expected)
551551

552552

553+
def equals(a, b):
554+
return a == b
555+
556+
557+
def quacks_like(test_duck, reference_duck):
558+
return set(dir(reference_duck)) - set(dir(test_duck)) == set()
559+
560+
561+
@pytest.mark.parametrize(
562+
"code,expected,check",
563+
[
564+
["\n".join(["a: Literal[True]", "a"]), True, equals],
565+
["\n".join(["a: bool", "a"]), bool, isinstance],
566+
["\n".join(["a: str", "a"]), str, isinstance],
567+
# for lists we need quacking as we do not know:
568+
# - how many elements in the list
569+
# - which element is of which type
570+
["\n".join(["a: list[str]", "a"]), list, quacks_like],
571+
["\n".join(["a: list[str]", "a[0]"]), str, quacks_like],
572+
["\n".join(["a: list[str]", "a[999]"]), str, quacks_like],
573+
# set
574+
["\n".join(["a: set[str]", "a"]), set, quacks_like],
575+
# for tuples we do know which element is which
576+
["\n".join(["a: tuple[str, int]", "a"]), tuple, isinstance],
577+
["\n".join(["a: tuple[str, int]", "a[0]"]), str, isinstance],
578+
["\n".join(["a: tuple[str, int]", "a[1]"]), int, isinstance],
579+
],
580+
)
581+
def test_evaluates_type_assignments(code, expected, check):
582+
context = limited(Literal=Literal)
583+
value = guarded_eval(code, context)
584+
assert check(value, expected)
585+
586+
553587
@pytest.mark.parametrize(
554588
"data,bad",
555589
[

0 commit comments

Comments
 (0)