Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion pyiceberg/expressions/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
LongLiteral,
StringLiteral,
)
from pyiceberg.transforms import UnboundTransform
from pyiceberg.typedef import L

ParserElement.enablePackrat()
Expand All @@ -74,6 +75,7 @@
NULL = CaselessKeyword("null")
NAN = CaselessKeyword("nan")
LIKE = CaselessKeyword("like")
CAST = CaselessKeyword("cast")

identifier = Word(alphas, alphanums + "_$").set_results_name("identifier")
column = DelimitedList(identifier, delim=".", combine=False).set_results_name("column")
Expand Down Expand Up @@ -240,7 +242,17 @@ def _evaluate_like_statement(result: ParseResults) -> BooleanExpression:
return EqualTo(result.column, StringLiteral(literal_like.value.replace('\\%', '%')))


predicate = (comparison | in_check | null_check | nan_check | starts_check | boolean).set_results_name("predicate")
# Grammar for `CAST(<column> AS <identifier>)`.
# - The AS keyword is matched case-insensitively and as a whole word, consistent with the
#   other keywords of this grammar (NULL, NAN, LIKE, CAST are all CaselessKeyword).
# - Both parentheses are suppressed, so the last token of the match is the target identifier.
cast_expression = (
    CAST + Suppress("(") + column + CaselessKeyword("as") + identifier + Suppress(")")
).set_results_name("cast")


@cast_expression.set_parse_action
def _(result: ParseResults) -> UnboundTransform[L]:
    """Turn a parsed cast expression into an UnboundTransform.

    Args:
        result: Parse results of ``cast_expression``; the named ``column`` result and the
            last token are used.

    Returns:
        An UnboundTransform built from the column and the last token.
    """
    source_column = result.column
    # The closing parenthesis is suppressed by the grammar, so the last token is the target identifier.
    # NOTE(review): this passes the identifier token as-is, while UnboundTransform annotates its
    # second argument as a Transform — confirm whether a conversion step is still missing.
    target = result[-1]
    return UnboundTransform(source_column, target)


# Top-level predicate grammar: the alternatives are combined with `|`, with cast_expression
# listed first, and the combined match is exposed under the results name "predicate".
# NOTE(review): cast_expression's parse action returns an UnboundTransform (a term), whereas the
# other alternatives presumably produce boolean expressions — confirm it belongs at this level.
predicate = (cast_expression | comparison | in_check | null_check | nan_check | starts_check | boolean).set_results_name(
"predicate"
)


def handle_not(result: ParseResults) -> Not:
Expand Down
38 changes: 36 additions & 2 deletions pyiceberg/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from abc import ABC, abstractmethod
from enum import IntEnum
from functools import singledispatch
from typing import Any, Callable, Generic, Optional, TypeVar
from typing import Any, Callable, Generic, Optional, Type, TypeVar
from typing import Literal as LiteralType
from uuid import UUID

Expand All @@ -38,6 +38,7 @@
BoundNotIn,
BoundNotStartsWith,
BoundPredicate,
BoundReference,
BoundSetPredicate,
BoundStartsWith,
BoundTerm,
Expand All @@ -49,6 +50,7 @@
Reference,
StartsWith,
UnboundPredicate,
UnboundTerm,
)
from pyiceberg.expressions.literals import (
DateLiteral,
Expand All @@ -58,7 +60,8 @@
TimestampLiteral,
literal,
)
from pyiceberg.typedef import IcebergRootModel, L
from pyiceberg.schema import Schema
from pyiceberg.typedef import IcebergRootModel, L, StructProtocol
from pyiceberg.types import (
BinaryType,
DateType,
Expand Down Expand Up @@ -821,3 +824,34 @@ class BoundTransform(BoundTerm[L]):
def __init__(self, term: BoundTerm[L], transform: Transform[L, Any]):
    """Store the bound term together with the transform associated with it.

    Args:
        term: Bound term wrapped by this object.
        transform: Transform kept alongside the term.
    """
    self.transform = transform
    self.term: BoundTerm[L] = term
Comment on lines 824 to 826
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def __init__(self, term: BoundTerm[L], transform: Transform[L, Any]):
self.term: BoundTerm[L] = term
self.transform = transform
def __init__(self, term: BoundTerm[L], transform_func: Callable[Optional[L], Optional[Any]]):
self.term: BoundTerm[L] = term
self.transform = transform


def eval(self, struct: StructProtocol) -> L:
    """Return the transformed value of the wrapped term for the given struct.

    Previously the raw value of the wrapped term was returned and the transform was
    never applied, which made this class behave exactly like the term it wraps.

    Args:
        struct: Object abiding by the StructProtocol that the wrapped term is evaluated against.

    Returns:
        The result of applying this transform to the value produced by the wrapped term.
    """
    # The stored object is a Transform; its `transform(source_type)` method is expected to
    # return the callable that converts values of the bound field's type.
    source_type = self.term.ref().field.field_type
    transform_func = self.transform.transform(source_type)
    return transform_func(self.term.eval(struct))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return self.term.eval(struct)
return self.transform(self.term.eval(struct))


def ref(self) -> BoundReference[L]:
    """Return the bound reference of the wrapped term."""
    wrapped_reference = self.term.ref()
    return wrapped_reference


class UnboundTransform(UnboundTerm[L]):
    """An unbound transform expression."""

    # Transform paired with the wrapped term; checked against the field type in bind().
    transform: Transform[L, Any]

    def __init__(self, term: UnboundTerm[L], transform: Transform[L, Any]):
        """Store the unbound term and the transform associated with it.

        Args:
            term: Unbound term that is bound later via ``bind``.
            transform: Transform kept alongside the term.
        """
        self.term: UnboundTerm[L] = term
        self.transform = transform

    def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundTransform[L]:
        """Bind the wrapped term to a schema and pair the result with this transform.

        Args:
            schema: Schema the wrapped term is bound against.
            case_sensitive: Forwarded unchanged to the wrapped term's ``bind``.

        Returns:
            A BoundTransform wrapping the bound term and this transform.

        Raises:
            ValueError: If ``can_transform`` reports that the transform cannot be applied
                to the type of the bound field.
        """
        bound_term = self.term.bind(schema, case_sensitive)
        source_type = bound_term.ref().field.field_type

        # Fail early with an actionable message instead of a placeholder text.
        if not self.transform.can_transform(source_type):
            raise ValueError(f"Transform {self.transform} cannot be applied to source type {source_type}")

        return BoundTransform(bound_term, self.transform)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should actually instantiate the callable of the transform:

Suggested change
return BoundTransform(bound_term, self.transform)
return BoundTransform(bound_term, self.transform.transform(bound_term.ref().field.field_type))


@property
def as_bound(self) -> Type[BoundTerm[L]]:
    """Return the bound counterpart class of this unbound transform.

    BoundTransform is returned rather than the generic BoundTerm so that this property
    agrees with ``bind``, which produces BoundTransform instances.
    """
    return BoundTransform[L]
6 changes: 6 additions & 0 deletions tests/expressions/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
StartsWith,
parser,
)
from pyiceberg.transforms import Reference


def test_true() -> None:
Expand Down Expand Up @@ -199,3 +200,8 @@ def test_with_function() -> None:
parser.parse("foo = 1 and lower(bar) = '2'")

assert "Expected end of text, found 'and'" in str(exc_info)


def test_cast() -> None:
cast = parser.parse("CAST(created_at as date)")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would love to see some more tests here. The most important part that's missing here is converting date to a DayTransform. Other options such as year, month, etc. should be added as well.

assert cast.term == Reference("created_at")