Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 3 additions & 9 deletions pyiceberg/expressions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ def __init__(self, term: Union[str, UnboundTerm[Any]]):

def __eq__(self, other: Any) -> bool:
"""Return the equality of two instances of the UnboundPredicate class."""
return self.term == other.term if isinstance(other, UnboundPredicate) else False
return self.term == other.term if isinstance(other, self.__class__) else False

@abstractmethod
def bind(self, schema: Schema, case_sensitive: bool = True) -> BooleanExpression:
Expand Down Expand Up @@ -531,7 +531,7 @@ def __repr__(self) -> str:

def __eq__(self, other: Any) -> bool:
"""Return the equality of two instances of the SetPredicate class."""
return self.term == other.term and self.literals == other.literals if isinstance(other, SetPredicate) else False
return self.term == other.term and self.literals == other.literals if isinstance(other, self.__class__) else False

def __getnewargs__(self) -> Tuple[UnboundTerm[L], Set[Literal[L]]]:
"""Pickle the SetPredicate class."""
Expand Down Expand Up @@ -664,12 +664,6 @@ def __invert__(self) -> In[L]:
"""Transform the Expression into its negated version."""
return In[L](self.term, self.literals)

def __eq__(self, other: Any) -> bool:
"""Return the equality of two instances of the NotIn class."""
if isinstance(other, NotIn):
return self.term == other.term and self.literals == other.literals
return False

@property
def as_bound(self) -> Type[BoundNotIn[L]]:
return BoundNotIn[L]
Expand Down Expand Up @@ -701,7 +695,7 @@ def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundLiteralPredi

def __eq__(self, other: Any) -> bool:
"""Return the equality of two instances of the LiteralPredicate class."""
if isinstance(other, LiteralPredicate):
if isinstance(other, self.__class__):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fix is for the LiteralPredicate, but there are more:

  • Line 367 should also use self.__class__ for IsNull, IsNaN etc.
  • Can you add self.__class__ to the SetPredicate on line 534.
  • The __eq__ of NotIn on line 667 can go (In doesn't have one either).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Raised a PR danielcweeks#1 to get these fixed :)

return self.term == other.term and self.literal == other.literal
return False

Expand Down
4 changes: 2 additions & 2 deletions tests/expressions/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ def test_greater_than() -> None:


def test_greater_than_or_equal() -> None:
assert GreaterThanOrEqual("foo", 5) == parser.parse("foo <= 5")
assert GreaterThanOrEqual("foo", "a") == parser.parse("'a' >= foo")
assert GreaterThanOrEqual("foo", 5) == parser.parse("foo >= 5")
assert GreaterThanOrEqual("foo", "a") == parser.parse("'a' <= foo")


def test_equal_to() -> None:
Expand Down
20 changes: 16 additions & 4 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,15 +830,27 @@ def test_projection_truncate_string_literal_eq(bound_reference_str: BoundReferen


def test_projection_truncate_string_literal_gt(bound_reference_str: BoundReference[str]) -> None:
assert TruncateTransform(2).project("name", BoundGreaterThan(term=bound_reference_str, literal=literal("data"))) == EqualTo(
term="name", literal=literal("da")
)
assert TruncateTransform(2).project(
"name", BoundGreaterThan(term=bound_reference_str, literal=literal("data"))
) == GreaterThanOrEqual(term="name", literal=literal("da"))


def test_projection_truncate_string_literal_gte(bound_reference_str: BoundReference[str]) -> None:
assert TruncateTransform(2).project(
"name", BoundGreaterThanOrEqual(term=bound_reference_str, literal=literal("data"))
) == EqualTo(term="name", literal=literal("da"))
) == GreaterThanOrEqual(term="name", literal=literal("da"))


def test_projection_truncate_string_literal_lt(bound_reference_str: BoundReference[str]) -> None:
assert TruncateTransform(2).project(
"name", BoundLessThan(term=bound_reference_str, literal=literal("data"))
) == LessThanOrEqual(term="name", literal=literal("da"))


def test_projection_truncate_string_literal_lte(bound_reference_str: BoundReference[str]) -> None:
assert TruncateTransform(2).project(
"name", BoundLessThanOrEqual(term=bound_reference_str, literal=literal("data"))
) == LessThanOrEqual(term="name", literal=literal("da"))


def test_projection_truncate_string_set_same_result(bound_reference_str: BoundReference[str]) -> None:
Expand Down