Skip to content

Commit 1416fec

Browse files
committed
Apply suggestions: explicitly raise TypeError when the flag value is None
1 parent e81c2ea commit 1416fec

2 files changed

Lines changed: 41 additions & 20 deletions

File tree

Lib/enum.py

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1560,37 +1560,50 @@ def __str__(self):
15601560
def __bool__(self):
15611561
return bool(self._value_)
15621562

1563+
def _get_value(self, flag):
1564+
if isinstance(flag, self.__class__):
1565+
return flag._value_
1566+
elif self._member_type_ is not object and isinstance(flag, self._member_type_):
1567+
return flag
1568+
return NotImplemented
1569+
15631570
def __or__(self, other):
1564-
if isinstance(other, self.__class__):
1565-
other = other._value_
1566-
elif self._member_type_ is not object and isinstance(other, self._member_type_):
1567-
other = other
1568-
else:
1571+
other_value = self._get_value(other)
1572+
if other_value is NotImplemented:
15691573
return NotImplemented
1574+
1575+
for flag in self, other:
1576+
if self._get_value(flag) is None:
1577+
raise TypeError(f"'{flag}' cannot be combined with other flags with |")
15701578
value = self._value_
1571-
return self.__class__(value | other)
1579+
return self.__class__(value | other_value)
15721580

15731581
def __and__(self, other):
1574-
if isinstance(other, self.__class__):
1575-
other = other._value_
1576-
elif self._member_type_ is not object and isinstance(other, self._member_type_):
1577-
other = other
1578-
else:
1582+
other_value = self._get_value(other)
1583+
if other_value is NotImplemented:
15791584
return NotImplemented
1585+
1586+
for flag in self, other:
1587+
if self._get_value(flag) is None:
1588+
raise TypeError(f"'{flag}' cannot be combined with other flags with &")
15801589
value = self._value_
1581-
return self.__class__(value & other)
1590+
return self.__class__(value & other_value)
15821591

15831592
def __xor__(self, other):
1584-
if isinstance(other, self.__class__):
1585-
other = other._value_
1586-
elif self._member_type_ is not object and isinstance(other, self._member_type_):
1587-
other = other
1588-
else:
1593+
other_value = self._get_value(other)
1594+
if other_value is NotImplemented:
15891595
return NotImplemented
1596+
1597+
for flag in self, other:
1598+
if self._get_value(flag) is None:
1599+
raise TypeError(f"'{flag}' cannot be combined with other flags with ^")
15901600
value = self._value_
1591-
return self.__class__(value ^ other)
1601+
return self.__class__(value ^ other_value)
15921602

15931603
def __invert__(self):
1604+
if self._get_value(self) is None:
1605+
raise TypeError(f"'{self}' cannot be inverted")
1606+
15941607
if self._inverted_ is None:
15951608
if self._boundary_ in (EJECT, KEEP):
15961609
self._inverted_ = self.__class__(~self._value_)

Lib/test/test_enum.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1047,14 +1047,22 @@ class TestPlainEnumFunction(_EnumTests, _PlainOutputTests, unittest.TestCase):
10471047

10481048
class TestPlainFlagClass(_EnumTests, _PlainOutputTests, _FlagTests, unittest.TestCase):
10491049
enum_type = Flag
1050-
1050+
10511051
def test_none_member(self):
10521052
class FlagWithNoneMember(Flag):
10531053
A = 1
10541054
E = None
1055-
1055+
10561056
self.assertEqual(FlagWithNoneMember.A.value, 1)
10571057
self.assertIs(FlagWithNoneMember.E.value, None)
1058+
with self.assertRaisesRegex(TypeError, r"'FlagWithNoneMember.E' cannot be combined with other flags with |"):
1059+
FlagWithNoneMember.A | FlagWithNoneMember.E
1060+
with self.assertRaisesRegex(TypeError, r"'FlagWithNoneMember.E' cannot be combined with other flags with &"):
1061+
FlagWithNoneMember.E & FlagWithNoneMember.A
1062+
with self.assertRaisesRegex(TypeError, r"'FlagWithNoneMember.E' cannot be combined with other flags with \^"):
1063+
FlagWithNoneMember.A ^ FlagWithNoneMember.E
1064+
with self.assertRaisesRegex(TypeError, r"'FlagWithNoneMember.E' cannot be inverted"):
1065+
~FlagWithNoneMember.E
10581066

10591067

10601068
class TestPlainFlagFunction(_EnumTests, _PlainOutputTests, _FlagTests, unittest.TestCase):

0 commit comments

Comments
 (0)