diff --git a/Lib/dataclasses.py b/Lib/dataclasses.py index 84f8d68ce092a4..e55132daafa94e 100644 --- a/Lib/dataclasses.py +++ b/Lib/dataclasses.py @@ -446,8 +446,7 @@ def _tuple_str(obj_name, fields): return f'({",".join([f"{obj_name}.{f.name}" for f in fields])},)' -def _create_fn(name, args, body, *, globals=None, locals=None, - return_type=MISSING): +def _create_fn_def(name, args, body, *, locals=None, return_type=MISSING): # Note that we may mutate locals. Callers beware! # The only callers are internal to this module, so no # worries about external callers. @@ -455,24 +454,31 @@ def _create_fn(name, args, body, *, globals=None, locals=None, locals = {} return_annotation = '' if return_type is not MISSING: - locals['__dataclass_return_type__'] = return_type - return_annotation = '->__dataclass_return_type__' + fn_name = name.replace("__", "") + locals[f'__dataclass_{fn_name}_return_type__'] = return_type + return_annotation = f'->__dataclass_{fn_name}_return_type__' args = ','.join(args) body = '\n'.join(f' {b}' for b in body) # Compute the text of the entire function. - txt = f' def {name}({args}){return_annotation}:\n{body}' + txt = f'def {name}({args}){return_annotation}:\n{body}' + return (name, txt, locals) + +def _exec_fn_defs(fn_defs, globals=None): # Free variables in exec are resolved in the global namespace. # The global namespace we have is user-provided, so we can't modify it for # our purposes. So we put the things we need into locals and introduce a # scope to allow the function we're creating to close over them. - local_vars = ', '.join(locals.keys()) - txt = f"def __create_fn__({local_vars}):\n{txt}\n return {name}" + locals_dict = {k: v for _, _, locals_ in fn_defs + for k, v in locals_.items()} + local_vars = ', '.join(locals_dict.keys()) + fn_names = ", ".join(name for name, _, _ in fn_defs) + txt = "\n".join(f" {txt}" for _, txt, _ in fn_defs) + txt = f"def __create_fn__({local_vars}):\n{txt}\n return {fn_names}" ns = {} exec(txt, globals, ns) - return ns['__create_fn__'](**locals) - + return ns['__create_fn__'](**locals_dict) def _field_assign(frozen, name, value, self_name): # If we're a frozen class, then assign to our fields in __init__ @@ -566,7 +572,7 @@ def _init_param(f): def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init, - self_name, globals, slots): + self_name, slots): # fields contains both real fields and InitVar pseudo-fields. # Make sure we don't have fields without defaults following fields @@ -616,68 +622,61 @@ def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init, # (instead of just concatenting the lists together). _init_params += ['*'] _init_params += [_init_param(f) for f in kw_only_fields] - return _create_fn('__init__', + return _create_fn_def('__init__', [self_name] + _init_params, body_lines, locals=locals, - globals=globals, return_type=None) -def _repr_fn(fields, globals): - fn = _create_fn('__repr__', +def _repr_fn(fields): + return _create_fn_def('__repr__', ('self',), ['return f"{self.__class__.__qualname__}(' + ', '.join([f"{f.name}={{self.{f.name}!r}}" for f in fields]) + - ')"'], - globals=globals) - return _recursive_repr(fn) + ')"'],) -def _frozen_get_del_attr(cls, fields, globals): +def _frozen_get_del_attr(cls, fields): locals = {'cls': cls, 'FrozenInstanceError': FrozenInstanceError} condition = 'type(self) is cls' if fields: condition += ' or name in {' + ', '.join(repr(f.name) for f in fields) + '}' - return (_create_fn('__setattr__', + return (_create_fn_def('__setattr__', ('self', 'name', 'value'), (f'if {condition}:', ' raise FrozenInstanceError(f"cannot assign to field {name!r}")', f'super(cls, self).__setattr__(name, value)'), - locals=locals, - globals=globals), - _create_fn('__delattr__', + locals=locals), + _create_fn_def('__delattr__', ('self', 'name'), (f'if {condition}:', ' raise FrozenInstanceError(f"cannot delete field {name!r}")', f'super(cls, self).__delattr__(name)'), - locals=locals, - globals=globals), + locals=locals), ) -def _cmp_fn(name, op, self_tuple, other_tuple, globals): +def _cmp_fn(name, op, self_tuple, other_tuple): # Create a comparison function. If the fields in the object are # named 'x' and 'y', then self_tuple is the string # '(self.x,self.y)' and other_tuple is the string # '(other.x,other.y)'. - return _create_fn(name, + return _create_fn_def(name, ('self', 'other'), [ 'if other.__class__ is self.__class__:', f' return {self_tuple}{op}{other_tuple}', - 'return NotImplemented'], - globals=globals) + 'return NotImplemented'],) -def _hash_fn(fields, globals): +def _hash_fn(fields): self_tuple = _tuple_str('self', fields) - return _create_fn('__hash__', + return _create_fn_def('__hash__', ('self',), - [f'return hash({self_tuple})'], - globals=globals) + [f'return hash({self_tuple})'],) def _is_classvar(a_type, typing): @@ -855,7 +854,7 @@ def _get_field(cls, a_name, a_type, default_kw_only): return f def _set_qualname(cls, value): - # Ensure that the functions returned from _create_fn uses the proper + # Ensure that the functions returned from _exec_fn_defs uses the proper # __qualname__ (the class they belong to). if isinstance(value, FunctionType): value.__qualname__ = f"{cls.__qualname__}.{value.__name__}" @@ -879,9 +878,9 @@ def _set_new_attribute(cls, name, value): def _hash_set_none(cls, fields, globals): return None -def _hash_add(cls, fields, globals): - flds = [f for f in fields if (f.compare if f.hash is None else f.hash)] - return _set_qualname(cls, _hash_fn(flds, globals)) +class _HASH_ADD: + pass +_hash_add = _HASH_ADD def _hash_exception(cls, fields, globals): # Raise an exception. @@ -925,6 +924,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen, # derived class fields overwrite base class fields, but the order # is defined by the base class, which is found first. fields = {} + fn_defs = [] # store txt defs to exec combined if cls.__module__ in sys.modules: globals = sys.modules[cls.__module__].__dict__ @@ -1059,8 +1059,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen, # Does this class have a post-init function? has_post_init = hasattr(cls, _POST_INIT_NAME) - _set_new_attribute(cls, '__init__', - _init_fn(all_init_fields, + fn_defs.append(_init_fn(all_init_fields, std_init_fields, kw_only_init_fields, frozen, @@ -1070,7 +1069,6 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen, # if possible. '__dataclass_self__' if 'self' in fields else 'self', - globals, slots, )) _set_new_attribute(cls, '__replace__', _replace) @@ -1081,7 +1079,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen, if repr: flds = [f for f in field_list if f.repr] - _set_new_attribute(cls, '__repr__', _repr_fn(flds, globals)) + fn_defs.append(_repr_fn(flds)) if eq: # Create __eq__ method. There's no need for a __ne__ method, @@ -1092,41 +1090,55 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen, body = [f'if other.__class__ is self.__class__:', f' return {field_comparisons}', f'return NotImplemented'] - func = _create_fn('__eq__', - ('self', 'other'), - body, - globals=globals) - _set_new_attribute(cls, '__eq__', func) + fn_defs.append(_create_fn_def('__eq__',('self', 'other'), body,)) if order: # Create and set the ordering methods. flds = [f for f in field_list if f.compare] self_tuple = _tuple_str('self', flds) other_tuple = _tuple_str('other', flds) - for name, op in [('__lt__', '<'), - ('__le__', '<='), - ('__gt__', '>'), - ('__ge__', '>='), - ]: - if _set_new_attribute(cls, name, - _cmp_fn(name, op, self_tuple, other_tuple, - globals=globals)): - raise TypeError(f'Cannot overwrite attribute {name} ' - f'in class {cls.__name__}. Consider using ' - 'functools.total_ordering') + order_flds = {'__lt__' : '<', + '__le__' : '<=', + '__gt__' : '>', + '__ge__' : '>=', + } + for name, op in order_flds.items(): + fn_defs.append(_cmp_fn(name, op, self_tuple, other_tuple)) if frozen: - for fn in _frozen_get_del_attr(cls, field_list, globals): - if _set_new_attribute(cls, fn.__name__, fn): - raise TypeError(f'Cannot overwrite attribute {fn.__name__} ' - f'in class {cls.__name__}') + fn_defs.extend(_frozen_get_del_attr(cls, field_list)) # Decide if/how we're going to create a hash function. hash_action = _hash_action[bool(unsafe_hash), bool(eq), bool(frozen), has_explicit_hash] - if hash_action: + + if hash_action == _hash_add: + flds = [f for f in field_list if (f.compare if f.hash is None else f.hash)] + fn_defs.append(_hash_fn(field_list)) + hash_action = None # assign when iterating + + # exec functions and assign + functions_objects = _exec_fn_defs(fn_defs, globals=globals) + for fn in functions_objects: + name = fn.__name__ + if name == '__repr__': + fn = _recursive_repr(fn) + + if name == '__hash__': + cls.__hash__ = _set_qualname(cls, fn) + else: + if _set_new_attribute(cls, name, fn): + if order and name in order_flds: + raise TypeError(f'Cannot overwrite attribute {name} ' + f'in class {cls.__name__}. Consider using ' + 'functools.total_ordering') + elif frozen and name in ['__setattr__','__delattr__']: + raise TypeError(f'Cannot overwrite attribute {name} ' + f'in class {cls.__name__}') + + if hash_action: # for _hash_set_none, _hash_exception # No need to call _set_new_attribute here, since by the time # we're here the overwriting is unconditional. cls.__hash__ = hash_action(cls, field_list, globals)