diff --git a/changelog.d/435.change.rst b/changelog.d/435.change.rst new file mode 100644 index 000000000..9286afddc --- /dev/null +++ b/changelog.d/435.change.rst @@ -0,0 +1 @@ +It's now possible to customize the behavior of ``eq`` and ``order`` by passing in a callable. diff --git a/docs/api.rst b/docs/api.rst index 5cc2f2d60..08067d548 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -102,7 +102,7 @@ Core ... class C(object): ... x = attr.ib() >>> attr.fields(C).x - Attribute(name='x', default=NOTHING, validator=None, repr=True, eq=True, order=True, hash=None, init=True, metadata=mappingproxy({}), type=None, converter=None, kw_only=False, inherited=False, on_setattr=None) + Attribute(name='x', default=NOTHING, validator=None, repr=True, eq=True, eq_key=None, order=True, order_key=None, hash=None, init=True, metadata=mappingproxy({}), type=None, converter=None, kw_only=False, inherited=False, on_setattr=None) .. autofunction:: attr.make_class @@ -180,9 +180,9 @@ Helpers ... x = attr.ib() ... y = attr.ib() >>> attr.fields(C) - (Attribute(name='x', default=NOTHING, validator=None, repr=True, eq=True, order=True, hash=None, init=True, metadata=mappingproxy({}), type=None, converter=None, kw_only=False, inherited=False, on_setattr=None), Attribute(name='y', default=NOTHING, validator=None, repr=True, eq=True, order=True, hash=None, init=True, metadata=mappingproxy({}), type=None, converter=None, kw_only=False, inherited=False, on_setattr=None)) + (Attribute(name='x', default=NOTHING, validator=None, repr=True, eq=True, eq_key=None, order=True, order_key=None, hash=None, init=True, metadata=mappingproxy({}), type=None, converter=None, kw_only=False, inherited=False, on_setattr=None), Attribute(name='y', default=NOTHING, validator=None, repr=True, eq=True, eq_key=None, order=True, order_key=None, hash=None, init=True, metadata=mappingproxy({}), type=None, converter=None, kw_only=False, inherited=False, on_setattr=None)) >>> attr.fields(C)[1] - Attribute(name='y', default=NOTHING, validator=None, repr=True, eq=True, order=True, hash=None, init=True, metadata=mappingproxy({}), type=None, converter=None, kw_only=False, inherited=False, on_setattr=None) + Attribute(name='y', default=NOTHING, validator=None, repr=True, eq=True, eq_key=None, order=True, order_key=None, hash=None, init=True, metadata=mappingproxy({}), type=None, converter=None, kw_only=False, inherited=False, on_setattr=None) >>> attr.fields(C).y is attr.fields(C)[1] True @@ -197,9 +197,9 @@ Helpers ... x = attr.ib() ... y = attr.ib() >>> attr.fields_dict(C) - {'x': Attribute(name='x', default=NOTHING, validator=None, repr=True, eq=True, order=True, hash=None, init=True, metadata=mappingproxy({}), type=None, converter=None, kw_only=False, inherited=False, on_setattr=None), 'y': Attribute(name='y', default=NOTHING, validator=None, repr=True, eq=True, order=True, hash=None, init=True, metadata=mappingproxy({}), type=None, converter=None, kw_only=False, inherited=False, on_setattr=None)} + {'x': Attribute(name='x', default=NOTHING, validator=None, repr=True, eq=True, eq_key=None, order=True, order_key=None, hash=None, init=True, metadata=mappingproxy({}), type=None, converter=None, kw_only=False, inherited=False, on_setattr=None), 'y': Attribute(name='y', default=NOTHING, validator=None, repr=True, eq=True, eq_key=None, order=True, order_key=None, hash=None, init=True, metadata=mappingproxy({}), type=None, converter=None, kw_only=False, inherited=False, on_setattr=None)} >>> attr.fields_dict(C)['y'] - Attribute(name='y', default=NOTHING, validator=None, repr=True, eq=True, order=True, hash=None, init=True, metadata=mappingproxy({}), type=None, converter=None, kw_only=False, inherited=False, on_setattr=None) + Attribute(name='y', default=NOTHING, validator=None, repr=True, eq=True, eq_key=None, order=True, order_key=None, hash=None, init=True, metadata=mappingproxy({}), type=None, converter=None, kw_only=False, inherited=False, on_setattr=None) >>> attr.fields_dict(C)['y'] is attr.fields(C).y True diff --git a/docs/extending.rst b/docs/extending.rst index fed39a306..25765ee6c 100644 --- a/docs/extending.rst +++ b/docs/extending.rst @@ -16,7 +16,7 @@ So it is fairly simple to build your own decorators on top of ``attrs``: ... @attr.s ... class C(object): ... a = attr.ib() - (Attribute(name='a', default=NOTHING, validator=None, repr=True, eq=True, order=True, hash=None, init=True, metadata=mappingproxy({}), type=None, converter=None, kw_only=False, inherited=False, on_setattr=None),) + (Attribute(name='a', default=NOTHING, validator=None, repr=True, eq=True, eq_key=None, order=True, order_key=None, hash=None, init=True, metadata=mappingproxy({}), type=None, converter=None, kw_only=False, inherited=False, on_setattr=None),) .. warning:: diff --git a/src/attr/__init__.pyi b/src/attr/__init__.pyi index eb301b431..f3e8be801 100644 --- a/src/attr/__init__.pyi +++ b/src/attr/__init__.pyi @@ -37,6 +37,7 @@ __copyright__: str _T = TypeVar("_T") _C = TypeVar("_C", bound=type) +_EqOrderType = Union[bool, Callable[[Any], Any]] _ValidatorType = Callable[[Any, Attribute[_T], _T], Any] _ConverterType = Callable[[Any], Any] _FilterType = Callable[[Attribute[_T], _T], bool] @@ -72,9 +73,9 @@ class Attribute(Generic[_T]): default: Optional[_T] validator: Optional[_ValidatorType[_T]] repr: _ReprArgType - cmp: bool - eq: bool - order: bool + cmp: _EqOrderType + eq: _EqOrderType + order: _EqOrderType hash: Optional[bool] init: bool converter: Optional[_ConverterType] @@ -114,7 +115,7 @@ def attrib( default: None = ..., validator: None = ..., repr: _ReprArgType = ..., - cmp: Optional[bool] = ..., + cmp: Optional[_EqOrderType] = ..., hash: Optional[bool] = ..., init: bool = ..., metadata: Optional[Mapping[Any, Any]] = ..., @@ -122,8 +123,8 @@ def attrib( converter: None = ..., factory: None = ..., kw_only: bool = ..., - eq: Optional[bool] = ..., - order: Optional[bool] = ..., + eq: Optional[_EqOrderType] = ..., + order: Optional[_EqOrderType] = ..., on_setattr: Optional[_OnSetAttrArgType] = ..., ) -> Any: ... @@ -134,7 +135,7 @@ def attrib( default: None = ..., validator: Optional[_ValidatorArgType[_T]] = ..., repr: _ReprArgType = ..., - cmp: Optional[bool] = ..., + cmp: Optional[_EqOrderType] = ..., hash: Optional[bool] = ..., init: bool = ..., metadata: Optional[Mapping[Any, Any]] = ..., @@ -142,8 +143,8 @@ def attrib( converter: Optional[_ConverterType] = ..., factory: Optional[Callable[[], _T]] = ..., kw_only: bool = ..., - eq: Optional[bool] = ..., - order: Optional[bool] = ..., + eq: Optional[_EqOrderType] = ..., + order: Optional[_EqOrderType] = ..., on_setattr: Optional[_OnSetAttrArgType] = ..., ) -> _T: ... @@ -153,7 +154,7 @@ def attrib( default: _T, validator: Optional[_ValidatorArgType[_T]] = ..., repr: _ReprArgType = ..., - cmp: Optional[bool] = ..., + cmp: Optional[_EqOrderType] = ..., hash: Optional[bool] = ..., init: bool = ..., metadata: Optional[Mapping[Any, Any]] = ..., @@ -161,8 +162,8 @@ def attrib( converter: Optional[_ConverterType] = ..., factory: Optional[Callable[[], _T]] = ..., kw_only: bool = ..., - eq: Optional[bool] = ..., - order: Optional[bool] = ..., + eq: Optional[_EqOrderType] = ..., + order: Optional[_EqOrderType] = ..., on_setattr: Optional[_OnSetAttrArgType] = ..., ) -> _T: ... @@ -172,7 +173,7 @@ def attrib( default: Optional[_T] = ..., validator: Optional[_ValidatorArgType[_T]] = ..., repr: _ReprArgType = ..., - cmp: Optional[bool] = ..., + cmp: Optional[_EqOrderType] = ..., hash: Optional[bool] = ..., init: bool = ..., metadata: Optional[Mapping[Any, Any]] = ..., @@ -180,8 +181,8 @@ def attrib( converter: Optional[_ConverterType] = ..., factory: Optional[Callable[[], _T]] = ..., kw_only: bool = ..., - eq: Optional[bool] = ..., - order: Optional[bool] = ..., + eq: Optional[_EqOrderType] = ..., + order: Optional[_EqOrderType] = ..., on_setattr: Optional[_OnSetAttrArgType] = ..., ) -> Any: ... @overload @@ -215,8 +216,8 @@ def field( converter: Optional[_ConverterType] = ..., factory: Optional[Callable[[], _T]] = ..., kw_only: bool = ..., - eq: Optional[bool] = ..., - order: Optional[bool] = ..., + eq: Optional[_EqOrderType] = ..., + order: Optional[_EqOrderType] = ..., on_setattr: Optional[_OnSetAttrArgType] = ..., ) -> _T: ... @@ -233,8 +234,8 @@ def field( converter: Optional[_ConverterType] = ..., factory: Optional[Callable[[], _T]] = ..., kw_only: bool = ..., - eq: Optional[bool] = ..., - order: Optional[bool] = ..., + eq: Optional[_EqOrderType] = ..., + order: Optional[_EqOrderType] = ..., on_setattr: Optional[_OnSetAttrArgType] = ..., ) -> _T: ... @@ -251,8 +252,8 @@ def field( converter: Optional[_ConverterType] = ..., factory: Optional[Callable[[], _T]] = ..., kw_only: bool = ..., - eq: Optional[bool] = ..., - order: Optional[bool] = ..., + eq: Optional[_EqOrderType] = ..., + order: Optional[_EqOrderType] = ..., on_setattr: Optional[_OnSetAttrArgType] = ..., ) -> Any: ... @overload @@ -261,7 +262,7 @@ def attrs( these: Optional[Dict[str, Any]] = ..., repr_ns: Optional[str] = ..., repr: bool = ..., - cmp: Optional[bool] = ..., + cmp: Optional[_EqOrderType] = ..., hash: Optional[bool] = ..., init: bool = ..., slots: bool = ..., @@ -272,8 +273,8 @@ def attrs( kw_only: bool = ..., cache_hash: bool = ..., auto_exc: bool = ..., - eq: Optional[bool] = ..., - order: Optional[bool] = ..., + eq: Optional[_EqOrderType] = ..., + order: Optional[_EqOrderType] = ..., auto_detect: bool = ..., collect_by_mro: bool = ..., getstate_setstate: Optional[bool] = ..., @@ -286,7 +287,7 @@ def attrs( these: Optional[Dict[str, Any]] = ..., repr_ns: Optional[str] = ..., repr: bool = ..., - cmp: Optional[bool] = ..., + cmp: Optional[_EqOrderType] = ..., hash: Optional[bool] = ..., init: bool = ..., slots: bool = ..., @@ -297,8 +298,8 @@ def attrs( kw_only: bool = ..., cache_hash: bool = ..., auto_exc: bool = ..., - eq: Optional[bool] = ..., - order: Optional[bool] = ..., + eq: Optional[_EqOrderType] = ..., + order: Optional[_EqOrderType] = ..., auto_detect: bool = ..., collect_by_mro: bool = ..., getstate_setstate: Optional[bool] = ..., @@ -377,7 +378,7 @@ def make_class( bases: Tuple[type, ...] = ..., repr_ns: Optional[str] = ..., repr: bool = ..., - cmp: Optional[bool] = ..., + cmp: Optional[_EqOrderType] = ..., hash: Optional[bool] = ..., init: bool = ..., slots: bool = ..., @@ -388,8 +389,8 @@ def make_class( kw_only: bool = ..., cache_hash: bool = ..., auto_exc: bool = ..., - eq: Optional[bool] = ..., - order: Optional[bool] = ..., + eq: Optional[_EqOrderType] = ..., + order: Optional[_EqOrderType] = ..., collect_by_mro: bool = ..., on_setattr: Optional[_OnSetAttrArgType] = ..., field_transformer: Optional[_FieldTransformer] = ..., diff --git a/src/attr/_make.py b/src/attr/_make.py index 76b1c62f4..44362c85d 100644 --- a/src/attr/_make.py +++ b/src/attr/_make.py @@ -178,13 +178,26 @@ def attrib( as-is, i.e. it will be used directly *instead* of calling ``repr()`` (the default). :type repr: a `bool` or a `callable` to use a custom function. - :param bool eq: If ``True`` (default), include this attribute in the + + :param eq: If ``True`` (default), include this attribute in the generated ``__eq__`` and ``__ne__`` methods that check two instances - for equality. - :param bool order: If ``True`` (default), include this attributes in the + for equality. To override how the attribute value is compared, + pass a ``callable`` that takes a single value and returns the value + to be compared. + :type eq: a `bool` or a `callable`. + + :param order: If ``True`` (default), include this attributes in the generated ``__lt__``, ``__le__``, ``__gt__`` and ``__ge__`` methods. - :param bool cmp: Setting to ``True`` is equivalent to setting ``eq=True, - order=True``. Deprecated in favor of *eq* and *order*. + To override how the attribute value is ordered, + pass a ``callable`` that takes a single value and returns the value + to be ordered. + :type order: a `bool` or a `callable`. + + :param cmp: Setting to ``True`` is equivalent to setting ``eq=True, + order=True``. Can also be set to a ``callable``. + Deprecated in favor of *eq* and *order*. + :type cmp: a `bool` or a `callable`. + :param Optional[bool] hash: Include this attribute in the generated ``__hash__`` method. If ``None`` (default), mirror *eq*'s value. This is the correct behavior according the Python spec. Setting this value @@ -232,14 +245,17 @@ def attrib( .. versionadded:: 18.1.0 ``factory=f`` is syntactic sugar for ``default=attr.Factory(f)``. .. versionadded:: 18.2.0 *kw_only* - .. versionchanged:: 19.2.0 *convert* keyword argument removed + .. versionchanged:: 19.2.0 *convert* keyword argument removed. .. versionchanged:: 19.2.0 *repr* also accepts a custom callable. .. deprecated:: 19.2.0 *cmp* Removal on or after 2021-06-01. .. versionadded:: 19.2.0 *eq* and *order* .. versionadded:: 20.1.0 *on_setattr* .. versionchanged:: 20.3.0 *kw_only* backported to Python 2 + .. versionchanged:: 21.1.0 *eq* and *order* also accept a custom callable. """ - eq, order = _determine_eq_order(cmp, eq, order, True) + eq, eq_key, order, order_key = _determine_attrib_eq_order( + cmp, eq, order, True + ) if hash is not None and hash is not True and hash is not False: raise TypeError( @@ -281,7 +297,9 @@ def attrib( type=type, kw_only=kw_only, eq=eq, + eq_key=eq_key, order=order, + order_key=order_key, on_setattr=on_setattr, ) @@ -1042,8 +1060,13 @@ def _add_method_dunders(self, method): "2021-06-01. Please use `eq` and `order` instead." ) +_EQ_ORDER_CUSTOMIZATION = ( + "You have customized the behaviour of `eq` but not of `order`. " + "This is probably a bug." +) + -def _determine_eq_order(cmp, eq, order, default_eq): +def _determine_attrs_eq_order(cmp, eq, order, default_eq): """ Validate the combination of *cmp*, *eq*, and *order*. Derive the effective values of eq and order. If *eq* is None, set it to *default_eq*. @@ -1071,6 +1094,52 @@ def _determine_eq_order(cmp, eq, order, default_eq): return eq, order +def _determine_attrib_eq_order(cmp, eq, order, default_eq): + """ + Validate the combination of *cmp*, *eq*, and *order*. Derive the effective + values of eq and order. If *eq* is None, set it to *default_eq*. + """ + if cmp is not None and any((eq is not None, order is not None)): + raise ValueError("Don't mix `cmp` with `eq' and `order`.") + + def decide_callable_or_boolean(value): + """ + Decide whether a key function is used. + """ + if callable(value): + value, key = True, value + else: + key = None + return value, key + + # cmp takes precedence due to bw-compatibility. + if cmp is not None: + warnings.warn(_CMP_DEPRECATION, DeprecationWarning, stacklevel=3) + + cmp, cmp_key = decide_callable_or_boolean(cmp) + return cmp, cmp_key, cmp, cmp_key + + # If left None, equality is set to the specified default and ordering + # mirrors equality. + if eq is None: + eq, eq_key = default_eq, None + else: + eq, eq_key = decide_callable_or_boolean(eq) + + if order is None: + if eq_key is not None: + warnings.warn(_EQ_ORDER_CUSTOMIZATION, SyntaxWarning, stacklevel=3) + + order, order_key = eq, eq_key + else: + order, order_key = decide_callable_or_boolean(order) + + if eq is False and order is True: + raise ValueError("`order` can only be True if `eq` is True too.") + + return eq, eq_key, order, order_key + + def _determine_whether_to_implement( cls, flag, auto_detect, dunders, default=True ): @@ -1369,7 +1438,7 @@ def attrs( "auto_detect only works on Python 3 and later." ) - eq_, order_ = _determine_eq_order(cmp, eq, order, None) + eq_, order_ = _determine_attrs_eq_order(cmp, eq, order, None) hash_ = hash # work around the lack of nonlocal if isinstance(on_setattr, (list, tuple)): @@ -1520,13 +1589,6 @@ def _has_frozen_base_class(cls): return cls.__setattr__ == _frozen_setattrs -def _attrs_to_tuple(obj, attrs): - """ - Create a tuple of all values of *obj*'s *attrs*. - """ - return tuple(getattr(obj, a.name) for a in attrs) - - def _generate_unique_filename(cls, func_name): """ Create a "filename" suitable for a function being generated. @@ -1662,21 +1724,44 @@ def _make_eq(cls, attrs): " if other.__class__ is not self.__class__:", " return NotImplemented", ] + # We can't just do a big self.x = other.x and... clause due to # irregularities like nan == nan is false but (nan,) == (nan,) is true. + globs = {} if attrs: lines.append(" return (") others = [" ) == ("] for a in attrs: - lines.append(" self.%s," % (a.name,)) - others.append(" other.%s," % (a.name,)) + if a.eq_key: + cmp_name = "_%s_key" % (a.name,) + # Add the key function to the global namespace + # of the evaluated function. + globs[cmp_name] = a.eq_key + lines.append( + " %s(self.%s)," + % ( + cmp_name, + a.name, + ) + ) + others.append( + " %s(other.%s)," + % ( + cmp_name, + a.name, + ) + ) + else: + lines.append(" self.%s," % (a.name,)) + others.append(" other.%s," % (a.name,)) lines += others + [" )"] else: lines.append(" return True") script = "\n".join(lines) - return _make_method("__eq__", script, unique_filename) + + return _make_method("__eq__", script, unique_filename, globs) def _make_order(cls, attrs): @@ -1689,7 +1774,12 @@ def attrs_to_tuple(obj): """ Save us some typing. """ - return _attrs_to_tuple(obj, attrs) + return tuple( + key(value) if key else value + for value, key in ( + (getattr(obj, a.name), a.order_key) for a in attrs + ) + ) def __lt__(self, other): """ @@ -2394,6 +2484,7 @@ class Attribute(object): .. versionadded:: 20.1.0 *on_setattr* .. versionchanged:: 20.2.0 *inherited* is not taken into account for equality checks and hashing anymore. + .. versionadded:: 20.X.Y *eq_key* and *order_key* For the full version history of the fields, see `attr.ib`. """ @@ -2404,7 +2495,9 @@ class Attribute(object): "validator", "repr", "eq", + "eq_key", "order", + "order_key", "hash", "init", "metadata", @@ -2430,10 +2523,14 @@ def __init__( converter=None, kw_only=False, eq=None, + eq_key=None, order=None, + order_key=None, on_setattr=None, ): - eq, order = _determine_eq_order(cmp, eq, order, True) + eq, eq_key, order, order_key = _determine_attrib_eq_order( + cmp, eq_key or eq, order_key or order, True + ) # Cache this descriptor here to speed things up later. bound_setattr = _obj_setattr.__get__(self, Attribute) @@ -2445,7 +2542,9 @@ def __init__( bound_setattr("validator", validator) bound_setattr("repr", repr) bound_setattr("eq", eq) + bound_setattr("eq_key", eq_key) bound_setattr("order", order) + bound_setattr("order_key", order_key) bound_setattr("hash", hash) bound_setattr("init", init) bound_setattr("converter", converter) @@ -2561,7 +2660,9 @@ def _setattrs(self, name_values_pairs): repr=True, cmp=None, eq=True, + # eq_key=None, order=False, + # order_key=None, hash=(name != "metadata"), init=True, inherited=False, @@ -2592,7 +2693,9 @@ class _CountingAttr(object): "_default", "repr", "eq", + "eq_key", "order", + "order_key", "hash", "init", "metadata", @@ -2613,7 +2716,9 @@ class _CountingAttr(object): init=True, kw_only=False, eq=True, + eq_key=None, order=False, + order_key=None, inherited=False, on_setattr=None, ) @@ -2638,7 +2743,9 @@ class _CountingAttr(object): init=True, kw_only=False, eq=True, + eq_key=None, order=False, + order_key=None, inherited=False, on_setattr=None, ), @@ -2658,7 +2765,9 @@ def __init__( type, kw_only, eq, + eq_key, order, + order_key, on_setattr, ): _CountingAttr.cls_counter += 1 @@ -2668,7 +2777,9 @@ def __init__( self.converter = converter self.repr = repr self.eq = eq + self.eq_key = eq_key self.order = order + self.order_key = order_key self.hash = hash self.init = init self.metadata = metadata @@ -2830,7 +2941,7 @@ def make_class(name, attrs, bases=(object,), **attributes_arguments): ( attributes_arguments["eq"], attributes_arguments["order"], - ) = _determine_eq_order( + ) = _determine_attrs_eq_order( cmp, attributes_arguments.get("eq"), attributes_arguments.get("order"), diff --git a/tests/test_dunders.py b/tests/test_dunders.py index a34f8f481..3ff4c1672 100644 --- a/tests/test_dunders.py +++ b/tests/test_dunders.py @@ -36,6 +36,31 @@ ReprC = simple_class(repr=True) ReprCSlots = simple_class(repr=True, slots=True) + +@attr.s(eq=True) +class EqCallableC(object): + a = attr.ib(eq=str.lower, order=False) + b = attr.ib(eq=True) + + +@attr.s(eq=True, slots=True) +class EqCallableCSlots(object): + a = attr.ib(eq=str.lower, order=False) + b = attr.ib(eq=True) + + +@attr.s(order=True) +class OrderCallableC(object): + a = attr.ib(eq=True, order=str.lower) + b = attr.ib(order=True) + + +@attr.s(order=True, slots=True) +class OrderCallableCSlots(object): + a = attr.ib(eq=True, order=str.lower) + b = attr.ib(order=True) + + # HashC is hashable by explicit definition while HashCSlots is hashable # implicitly. The "Cached" versions are the same, except with hash code # caching enabled @@ -106,6 +131,16 @@ def test_equal(self, cls): assert cls(1, 2) == cls(1, 2) assert not (cls(1, 2) != cls(1, 2)) + @pytest.mark.parametrize("cls", [EqCallableC, EqCallableCSlots]) + def test_equal_callable(self, cls): + """ + Equal objects are detected as equal. + """ + assert cls("Test", 1) == cls("test", 1) + assert cls("Test", 1) != cls("test", 2) + assert not (cls("Test", 1) != cls("test", 1)) + assert not (cls("Test", 1) == cls("test", 2)) + @pytest.mark.parametrize("cls", [EqC, EqCSlots]) def test_unequal_same_class(self, cls): """ @@ -114,7 +149,17 @@ def test_unequal_same_class(self, cls): assert cls(1, 2) != cls(2, 1) assert not (cls(1, 2) == cls(2, 1)) - @pytest.mark.parametrize("cls", [EqC, EqCSlots]) + @pytest.mark.parametrize("cls", [EqCallableC, EqCallableCSlots]) + def test_unequal_same_class_callable(self, cls): + """ + Unequal objects of correct type are detected as unequal. + """ + assert cls("Test", 1) != cls("foo", 2) + assert not (cls("Test", 1) == cls("foo", 2)) + + @pytest.mark.parametrize( + "cls", [EqC, EqCSlots, EqCallableC, EqCallableCSlots] + ) def test_unequal_different_class(self, cls): """ Unequal objects of different type are detected even if their attributes @@ -140,7 +185,21 @@ def test_lt(self, cls): ]: assert cls(*a) < cls(*b) - @pytest.mark.parametrize("cls", [OrderC, OrderCSlots]) + @pytest.mark.parametrize("cls", [OrderCallableC, OrderCallableCSlots]) + def test_lt_callable(self, cls): + """ + __lt__ compares objects as tuples of attribute values. + """ + # Note: "A" < "a" + for a, b in [ + (("test1", 1), ("Test1", 2)), + (("test0", 1), ("Test1", 1)), + ]: + assert cls(*a) < cls(*b) + + @pytest.mark.parametrize( + "cls", [OrderC, OrderCSlots, OrderCallableC, OrderCallableCSlots] + ) def test_lt_unordable(self, cls): """ __lt__ returns NotImplemented if classes differ. @@ -161,7 +220,23 @@ def test_le(self, cls): ]: assert cls(*a) <= cls(*b) - @pytest.mark.parametrize("cls", [OrderC, OrderCSlots]) + @pytest.mark.parametrize("cls", [OrderCallableC, OrderCallableCSlots]) + def test_le_callable(self, cls): + """ + __le__ compares objects as tuples of attribute values. + """ + # Note: "A" < "a" + for a, b in [ + (("test1", 1), ("Test1", 1)), + (("test1", 1), ("Test1", 2)), + (("test0", 1), ("Test1", 1)), + (("test0", 2), ("Test1", 1)), + ]: + assert cls(*a) <= cls(*b) + + @pytest.mark.parametrize( + "cls", [OrderC, OrderCSlots, OrderCallableC, OrderCallableCSlots] + ) def test_le_unordable(self, cls): """ __le__ returns NotImplemented if classes differ. @@ -180,7 +255,21 @@ def test_gt(self, cls): ]: assert cls(*a) > cls(*b) - @pytest.mark.parametrize("cls", [OrderC, OrderCSlots]) + @pytest.mark.parametrize("cls", [OrderCallableC, OrderCallableCSlots]) + def test_gt_callable(self, cls): + """ + __gt__ compares objects as tuples of attribute values. + """ + # Note: "A" < "a" + for a, b in [ + (("Test1", 2), ("test1", 1)), + (("Test1", 1), ("test0", 1)), + ]: + assert cls(*a) > cls(*b) + + @pytest.mark.parametrize( + "cls", [OrderC, OrderCSlots, OrderCallableC, OrderCallableCSlots] + ) def test_gt_unordable(self, cls): """ __gt__ returns NotImplemented if classes differ. @@ -201,7 +290,23 @@ def test_ge(self, cls): ]: assert cls(*a) >= cls(*b) - @pytest.mark.parametrize("cls", [OrderC, OrderCSlots]) + @pytest.mark.parametrize("cls", [OrderCallableC, OrderCallableCSlots]) + def test_ge_callable(self, cls): + """ + __ge__ compares objects as tuples of attribute values. + """ + # Note: "A" < "a" + for a, b in [ + (("Test1", 1), ("test1", 1)), + (("Test1", 2), ("test1", 1)), + (("Test1", 1), ("test0", 1)), + (("Test1", 1), ("test0", 2)), + ]: + assert cls(*a) >= cls(*b) + + @pytest.mark.parametrize( + "cls", [OrderC, OrderCSlots, OrderCallableC, OrderCallableCSlots] + ) def test_ge_unordable(self, cls): """ __ge__ returns NotImplemented if classes differ. diff --git a/tests/test_make.py b/tests/test_make.py index 4ba413ad5..c9085c77b 100644 --- a/tests/test_make.py +++ b/tests/test_make.py @@ -29,7 +29,8 @@ _Attributes, _ClassBuilder, _CountingAttr, - _determine_eq_order, + _determine_attrib_eq_order, + _determine_attrs_eq_order, _determine_whether_to_implement, _transform_attrs, and_, @@ -222,7 +223,8 @@ class C(object): "No mandatory attributes allowed after an attribute with a " "default value or factory. Attribute in question: Attribute" "(name='y', default=NOTHING, validator=None, repr=True, " - "eq=True, order=True, hash=None, init=True, " + "eq=True, eq_key=None, order=True, order_key=None, " + "hash=None, init=True, " "metadata=mappingproxy({}), type=None, converter=None, " "kw_only=False, inherited=False, on_setattr=None)", ) == e.value.args @@ -1809,19 +1811,19 @@ class B(A): a > b -class TestDetermineEqOrder(object): +class TestDetermineAttrsEqOrder(object): def test_default(self): """ If all are set to None, set both eq and order to the passed default. """ - assert (42, 42) == _determine_eq_order(None, None, None, 42) + assert (42, 42) == _determine_attrs_eq_order(None, None, None, 42) @pytest.mark.parametrize("eq", [True, False]) def test_order_mirrors_eq_by_default(self, eq): """ If order is None, it mirrors eq. """ - assert (eq, eq) == _determine_eq_order(None, eq, None, True) + assert (eq, eq) == _determine_attrs_eq_order(None, eq, None, True) def test_order_without_eq(self): """ @@ -1830,7 +1832,7 @@ def test_order_without_eq(self): with pytest.raises( ValueError, match="`order` can only be True if `eq` is True too." ): - _determine_eq_order(None, False, True, True) + _determine_attrs_eq_order(None, False, True, True) @given(cmp=booleans(), eq=optional_bool, order=optional_bool) def test_mix(self, cmp, eq, order): @@ -1842,7 +1844,7 @@ def test_mix(self, cmp, eq, order): with pytest.raises( ValueError, match="Don't mix `cmp` with `eq' and `order`." ): - _determine_eq_order(cmp, eq, order, True) + _determine_attrs_eq_order(cmp, eq, order, True) def test_cmp_deprecated(self): """ @@ -1863,6 +1865,133 @@ class C(object): ) +class TestDetermineAttribEqOrder(object): + def test_default(self): + """ + If all are set to None, set both eq and order to the passed default. + """ + assert (42, None, 42, None) == _determine_attrib_eq_order( + None, None, None, 42 + ) + + def test_eq_callable_order_boolean(self): + """ + eq=callable or order=callable need to transformed into eq/eq_key + or order/order_key. + """ + assert (True, str.lower, False, None) == _determine_attrib_eq_order( + None, str.lower, False, True + ) + + def test_eq_callable_order_callable(self): + """ + eq=callable or order=callable need to transformed into eq/eq_key + or order/order_key. + """ + assert (True, str.lower, True, abs) == _determine_attrib_eq_order( + None, str.lower, abs, True + ) + + def test_eq_boolean_order_callable(self): + """ + eq=callable or order=callable need to transformed into eq/eq_key + or order/order_key. + """ + assert (True, None, True, str.lower) == _determine_attrib_eq_order( + None, True, str.lower, True + ) + + @pytest.mark.parametrize("eq", [True, False]) + def test_order_mirrors_eq_by_default(self, eq): + """ + If order is None, it mirrors eq. + """ + assert (eq, None, eq, None) == _determine_attrib_eq_order( + None, eq, None, True + ) + + def test_order_missing_and_custom_eq(self): + """ + If eq is customized and order is missing, order mirrors eq + but a warning is raised. + """ + with pytest.warns(None) as wr: + + assert ( + True, + str.lower, + True, + str.lower, + ) == _determine_attrib_eq_order(None, str.lower, None, True) + + (w,) = wr.list + + assert ( + "You have customized the behaviour of `eq` but not of `order`. " + "This is probably a bug." == w.message.args[0] + ) + + def test_order_without_eq(self): + """ + eq=False, order=True raises a meaningful ValueError. + """ + with pytest.raises( + ValueError, match="`order` can only be True if `eq` is True too." + ): + _determine_attrib_eq_order(None, False, True, True) + + @given(cmp=booleans(), eq=optional_bool, order=optional_bool) + def test_mix(self, cmp, eq, order): + """ + If cmp is not None, eq and order must be None and vice versa. + """ + assume(eq is not None or order is not None) + + with pytest.raises( + ValueError, match="Don't mix `cmp` with `eq' and `order`." + ): + _determine_attrib_eq_order(cmp, eq, order, True) + + def test_boolean_cmp_deprecated(self): + """ + Passing a cmp that is not None raises a DeprecationWarning. + """ + with pytest.deprecated_call() as dc: + + assert (True, None, True, None) == _determine_attrib_eq_order( + True, None, None, True + ) + + (w,) = dc.list + + assert ( + "The usage of `cmp` is deprecated and will be removed on or after " + "2021-06-01. Please use `eq` and `order` instead." + == w.message.args[0] + ) + + def test_callable_cmp_deprecated(self): + """ + Passing a cmp that is not None raises a DeprecationWarning. + """ + with pytest.deprecated_call() as dc: + + assert ( + True, + str.lower, + True, + str.lower, + ) == _determine_attrib_eq_order(str.lower, None, None, True) + + (w,) = dc.list + + assert ( + "The usage of `cmp` is deprecated and will be removed on or after " + "2021-06-01. Please use `eq` and `order` instead." + == w.message.args[0] + ) + + class TestDocs: @pytest.mark.parametrize( "meth_name",