2222import dataclasses
2323import operator
2424import sys
25+ import typing
2526import warnings
2627from functools import cached_property
2728from dataclasses import dataclass , field
@@ -227,6 +228,8 @@ class SelectivePolicy(EvaluationPolicy):
227228 allowed_operations : set = field (default_factory = set )
228229 allowed_operations_external : set [tuple [str , ...] | str ] = field (default_factory = set )
229230
231+ allow_getitem_on_types : bool = field (default_factory = bool )
232+
230233 _operation_methods_cache : dict [str , set [Callable ]] = field (
231234 default_factory = dict , init = False
232235 )
@@ -290,6 +293,13 @@ def can_get_attr(self, value, attr):
290293 def can_get_item (self , value , item ):
291294 """Allow accessing `__getiitem__` of allow-listed instances unless it was not modified."""
292295 allowed_getitem_external = _coerce_path_to_tuples (self .allowed_getitem_external )
296+ if self .allow_getitem_on_types :
297+ # e.g. Union[str, int] or Literal[True, 1]
298+ if isinstance (value , (typing ._SpecialForm , typing ._BaseGenericAlias )):
299+ return True
300+ # PEP 560 e.g. list[str]
301+ if isinstance (value , type ) and hasattr (value , "__class_getitem__" ):
302+ return True
293303 return _has_original_dunder (
294304 value ,
295305 allowed_types = self .allowed_getitem ,
@@ -478,8 +488,8 @@ class _Duck:
478488 """A dummy class used to create objects pretending to have given attributes"""
479489
480490 def __init__ (self , attributes : Optional [dict ] = None , items : Optional [dict ] = None ):
481- self .attributes = attributes or {}
482- self .items = items or {}
491+ self .attributes = attributes if attributes is not None else {}
492+ self .items = items if items is not None else {}
483493
484494 def __getattr__ (self , attr : str ):
485495 return self .attributes [attr ]
@@ -640,6 +650,14 @@ def dummy_function(*args, **kwargs):
640650 return None
641651 if isinstance (node , ast .Assign ):
642652 return _handle_assign (node , context )
653+ if isinstance (node , ast .AnnAssign ):
654+ if not node .simple :
655+ # for now only handle simple annotations
656+ return None
657+ context .transient_locals [node .target .id ] = _resolve_annotation (
658+ eval_node (node .annotation , context ), context
659+ )
660+ return None
643661 if isinstance (node , ast .Expression ):
644662 return eval_node (node .body , context )
645663 if isinstance (node , ast .Expr ):
@@ -793,7 +811,13 @@ def dummy_function(*args, **kwargs):
793811 func , # not joined to avoid calling `repr`
794812 f"not allowed in { context .evaluation } mode" ,
795813 )
796- raise ValueError ("Unhandled node" , ast .dump (node ))
814+ if isinstance (node , ast .Assert ):
815+ # message is always the second item, so if it is defined user would be completing
816+ # on the message, not on the assertion test
817+ if node .msg :
818+ return eval_node (node .msg , context )
819+ return eval_node (node .test , context )
820+ raise SyntaxError (f"Unhandled node: { ast .dump (node )} " )
797821
798822
799823def _eval_return_type (func : Callable , node : ast .Call , context : EvaluationContext ):
@@ -809,7 +833,7 @@ def _eval_return_type(func: Callable, node: ast.Call, context: EvaluationContext
809833 # but resolved by signature call we know the return type
810834 not_empty = sig .return_annotation is not Signature .empty
811835 if not_empty :
812- return _resolve_annotation (sig .return_annotation , sig , func , node , context )
836+ return _resolve_annotation (sig .return_annotation , context , sig , func , node )
813837 return NOT_EVALUATED
814838
815839
@@ -824,12 +848,26 @@ def _eval_annotation(
824848 )
825849
826850
851+ class _GetItemDuck (dict ):
852+ """A dict subclass that always returns the factory instance and claims to have any item."""
853+
854+ def __init__ (self , factory , * args , ** kwargs ):
855+ super ().__init__ (* args , ** kwargs )
856+ self ._factory = factory
857+
858+ def __getitem__ (self , key ):
859+ return self ._factory ()
860+
861+ def __contains__ (self , key ):
862+ return True
863+
864+
827865def _resolve_annotation (
828- annotation : str ,
829- sig : Signature ,
830- func : Callable ,
831- node : ast .Call ,
866+ annotation : object | str ,
832867 context : EvaluationContext ,
868+ sig : Signature | None = None ,
869+ func : Callable | None = None ,
870+ node : ast .Call | None = None ,
833871):
834872 """Resolve annotation created by user with `typing` module and custom objects."""
835873 annotation = _eval_annotation (annotation , context )
@@ -844,7 +882,7 @@ def _resolve_annotation(
844882 return ""
845883 elif annotation is AnyStr :
846884 index = None
847- if hasattr (func , "__node__" ):
885+ if func and hasattr (func , "__node__" ):
848886 def_node = func .__node__
849887 for i , arg in enumerate (def_node .args .args ):
850888 if not arg .annotation :
@@ -858,7 +896,7 @@ def _resolve_annotation(
858896 )
859897 if index and is_bound_method :
860898 index -= 1
861- else :
899+ elif sig :
862900 for i , (key , value ) in enumerate (sig .parameters .items ()):
863901 if value .annotation is AnyStr :
864902 index = i
@@ -870,32 +908,53 @@ def _resolve_annotation(
870908 return eval_node (node .args [index ], context )
871909 elif origin is TypeGuard :
872910 return False
911+ elif origin is set or origin is list :
912+ # only one type argument allowed
913+ attributes = [
914+ attr
915+ for attr in dir (
916+ _resolve_annotation (get_args (annotation )[0 ], context , sig , func , node )
917+ )
918+ ]
919+ duck = _Duck (attributes = dict .fromkeys (attributes ))
920+ return _Duck (
921+ attributes = dict .fromkeys (dir (origin ())),
922+ # items are not strrictly needed for set
923+ items = _GetItemDuck (lambda : duck ),
924+ )
925+ elif origin is tuple :
926+ # multiple type arguments
927+ return tuple (
928+ _resolve_annotation (arg , context , sig , func , node )
929+ for arg in get_args (annotation )
930+ )
873931 elif origin is Union :
932+ # multiple type arguments
874933 attributes = [
875934 attr
876935 for type_arg in get_args (annotation )
877- for attr in dir (_resolve_annotation (type_arg , sig , func , node , context ))
936+ for attr in dir (_resolve_annotation (type_arg , context , sig , func , node ))
878937 ]
879938 return _Duck (attributes = dict .fromkeys (attributes ))
880939 elif is_typeddict (annotation ):
881940 return _Duck (
882941 attributes = dict .fromkeys (dir (dict ())),
883942 items = {
884- k : _resolve_annotation (v , sig , func , node , context )
943+ k : _resolve_annotation (v , context , sig , func , node )
885944 for k , v in annotation .__annotations__ .items ()
886945 },
887946 )
888947 elif hasattr (annotation , "_is_protocol" ):
889948 return _Duck (attributes = dict .fromkeys (dir (annotation )))
890949 elif origin is Annotated :
891950 type_arg = get_args (annotation )[0 ]
892- return _resolve_annotation (type_arg , sig , func , node , context )
951+ return _resolve_annotation (type_arg , context , sig , func , node )
893952 elif isinstance (annotation , NewType ):
894- return _eval_or_create_duck (annotation .__supertype__ , node , context )
953+ return _eval_or_create_duck (annotation .__supertype__ , context )
895954 elif isinstance (annotation , TypeAliasType ):
896- return _eval_or_create_duck (annotation .__value__ , node , context )
955+ return _eval_or_create_duck (annotation .__value__ , context )
897956 else :
898- return _eval_or_create_duck (annotation , node , context )
957+ return _eval_or_create_duck (annotation , context )
899958
900959
901960def _eval_node_name (node_id : str , context : EvaluationContext ):
@@ -919,7 +978,7 @@ def _eval_node_name(node_id: str, context: EvaluationContext):
919978 raise NameError (f"{ node_id } not found in locals, globals, nor builtins" )
920979
921980
922- def _eval_or_create_duck (duck_type , node : ast . Call , context : EvaluationContext ):
981+ def _eval_or_create_duck (duck_type , context : EvaluationContext ):
923982 policy = get_policy (context )
924983 # if allow-listed builtin is on type annotation, instantiate it
925984 if policy .can_call (duck_type ):
@@ -959,6 +1018,8 @@ def _create_duck_for_heap_type(duck_type):
9591018 bytes , # type: ignore[arg-type]
9601019 list ,
9611020 tuple ,
1021+ type , # for type annotations like list[str]
1022+ _Duck ,
9621023 collections .defaultdict ,
9631024 collections .deque ,
9641025 collections .OrderedDict ,
@@ -1063,6 +1124,7 @@ def _list_methods(cls, source=None):
10631124 allow_builtins_access = True ,
10641125 allow_locals_access = True ,
10651126 allow_globals_access = True ,
1127+ allow_getitem_on_types = True ,
10661128 allowed_calls = ALLOWED_CALLS ,
10671129 ),
10681130 "unsafe" : EvaluationPolicy (
0 commit comments