Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions cpp/src/arrow/dataset/filter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,22 @@ std::shared_ptr<Expression> CastExpression::Assume(const Expression& given) cons
return std::make_shared<CastExpression>(std::move(operand), std::move(like), options_);
}

const std::shared_ptr<DataType>& CastExpression::to_type() const {
if (arrow::util::holds_alternative<std::shared_ptr<DataType>>(to_)) {
return arrow::util::get<std::shared_ptr<DataType>>(to_);
}
static std::shared_ptr<DataType> null;
return null;
}

const std::shared_ptr<Expression>& CastExpression::like_expr() const {
if (arrow::util::holds_alternative<std::shared_ptr<Expression>>(to_)) {
return arrow::util::get<std::shared_ptr<Expression>>(to_);
}
static std::shared_ptr<Expression> null;
return null;
}

std::string FieldExpression::ToString() const { return name_; }

std::string OperatorName(compute::CompareOperator op) {
Expand Down
8 changes: 8 additions & 0 deletions cpp/src/arrow/dataset/filter.h
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,14 @@ class ARROW_DS_EXPORT CastExpression final

const compute::CastOptions& options() const { return options_; }

/// Return the target type of this CastTo expression, or nullptr if this is a
/// CastLike expression.
const std::shared_ptr<DataType>& to_type() const;

/// Return the target expression of this CastLike expression, or nullptr if
/// this is a CastTo expression.
const std::shared_ptr<Expression>& like_expr() const;

private:
util::variant<std::shared_ptr<DataType>, std::shared_ptr<Expression>> to_;
compute::CastOptions options_;
Expand Down
139 changes: 109 additions & 30 deletions python/pyarrow/_dataset.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ from pyarrow.lib cimport *
from pyarrow.includes.libarrow_dataset cimport *
from pyarrow.compat import frombytes, tobytes
from pyarrow._fs cimport FileSystem, FileInfo, FileSelector
from pyarrow.types import (is_null, is_boolean, is_integer, is_floating,
is_string)


def _forbid_instantiation(klass, subclasses_instead=True):
Expand Down Expand Up @@ -1305,6 +1307,10 @@ cdef class UnaryExpression(Expression):
Expression.init(self, sp)
self.unary = <CUnaryExpression*> sp.get()

@property
def operand(self):
return Expression.wrap(self.unary.operand())


cdef class BinaryExpression(Expression):

Expand Down Expand Up @@ -1332,7 +1338,9 @@ cdef class ScalarExpression(Expression):
shared_ptr[CScalar] scalar
shared_ptr[CScalarExpression] expr

if isinstance(value, bool):
if value is None:
scalar.reset(new CNullScalar())
elif isinstance(value, bool):
scalar = MakeScalar(<c_bool>value)
elif isinstance(value, float):
scalar = MakeScalar(<double>value)
Expand All @@ -1350,6 +1358,14 @@ cdef class ScalarExpression(Expression):
Expression.init(self, sp)
self.scalar = <CScalarExpression*> sp.get()

@property
def value(self):
cdef ScalarValue scalar = pyarrow_wrap_scalar(self.scalar.value())
return scalar.as_py()

def __reduce__(self):
return ScalarExpression, (self.value,)


cdef class FieldExpression(Expression):

Expand All @@ -1366,9 +1382,13 @@ cdef class FieldExpression(Expression):
Expression.init(self, sp)
self.scalar = <CFieldExpression*> sp.get()

@property
def name(self):
return frombytes(self.scalar.name())

def __reduce__(self):
return FieldExpression, (self.name,)


cpdef enum CompareOperator:
Equal = <int8_t> CCompareOperator_EQUAL
Expand Down Expand Up @@ -1399,9 +1419,15 @@ cdef class ComparisonExpression(BinaryExpression):
BinaryExpression.init(self, sp)
self.comparison = <CComparisonExpression*> sp.get()

@property
def op(self):
return <CompareOperator> self.comparison.op()

def __reduce__(self):
return ComparisonExpression, (
self.op, self.left_operand, self.right_operand
)


cdef class IsValidExpression(UnaryExpression):

Expand All @@ -1410,34 +1436,90 @@ cdef class IsValidExpression(UnaryExpression):
expr = make_shared[CIsValidExpression](operand.unwrap())
self.init(<shared_ptr[CExpression]> expr)

def __reduce__(self):
return IsValidExpression, (self.operand,)


cdef class CastExpression(UnaryExpression):

cdef CCastExpression *cast

def __init__(self, Expression operand not None, DataType to not None,
bint safe=True):
# TODO(kszucs): safe is consistently used across pyarrow, but on long
# term we should expose the CastOptions object
cdef:
CastOptions options
shared_ptr[CExpression] expr
options = CastOptions.safe() if safe else CastOptions.unsafe()
expr.reset(new CCastExpression(
operand.unwrap(),
pyarrow_unwrap_data_type(to),
options.unwrap()
))
expr.reset(
new CCastExpression(
operand.unwrap(),
pyarrow_unwrap_data_type(to),
options.unwrap()
)
)
self.init(expr)

cdef void init(self, const shared_ptr[CExpression]& sp):
UnaryExpression.init(self, sp)
self.cast = <CCastExpression*> sp.get()

@property
def to(self):
"""
Target DataType or Expression of the cast operation.

Returns
-------
DataType or Expression
"""
cdef shared_ptr[CDataType] typ = self.cast.to_type()

if typ.get() != nullptr:
return pyarrow_wrap_data_type(typ)
else:
raise TypeError(
'Cannot determine the target type of the cast expression'
)

@property
def safe(self):
"""
Whether to check for overflows or other unsafe conversions.

Returns
-------
bool
"""
cdef CastOptions options = CastOptions.wrap(self.cast.options())
return options.is_safe()

def __reduce__(self):
return CastExpression, (self.operand, self.to, self.safe)


cdef class InExpression(UnaryExpression):

def __init__(self, Expression operand not None, Array haystack not None):
cdef:
CInExpression* inexpr

def __init__(self, Expression operand not None, Array set_ not None):
cdef shared_ptr[CExpression] expr
expr.reset(
new CInExpression(operand.unwrap(), pyarrow_unwrap_array(haystack))
new CInExpression(operand.unwrap(), pyarrow_unwrap_array(set_))
)
self.init(expr)

cdef void init(self, const shared_ptr[CExpression]& sp):
UnaryExpression.init(self, sp)
self.inexpr = <CInExpression*> sp.get()

@property
def set_(self):
return pyarrow_wrap_array(self.inexpr.set())

def __reduce__(self):
return InExpression, (self.operand, self.set_)


cdef class NotExpression(UnaryExpression):

Expand All @@ -1446,30 +1528,27 @@ cdef class NotExpression(UnaryExpression):
expr = CMakeNotExpression(operand.unwrap())
self.init(<shared_ptr[CExpression]> expr)

def __reduce__(self):
return NotExpression, (self.operand,)


cdef class AndExpression(BinaryExpression):

def __init__(self, Expression left not None, Expression right not None,
*additional_operands):
cdef:
Expression operand
vector[shared_ptr[CExpression]] exprs
exprs.push_back(left.unwrap())
exprs.push_back(right.unwrap())
for operand in additional_operands:
exprs.push_back(operand.unwrap())
self.init(CMakeAndExpression(exprs))
def __init__(self, Expression left not None, Expression right not None):
cdef shared_ptr[CAndExpression] expr
expr.reset(new CAndExpression(left.unwrap(), right.unwrap()))
self.init(<shared_ptr[CExpression]> expr)

def __reduce__(self):
return AndExpression, (self.left_operand, self.right_operand)


cdef class OrExpression(BinaryExpression):

def __init__(self, Expression left not None, Expression right not None,
*additional_operands):
cdef:
Expression operand
vector[shared_ptr[CExpression]] exprs
exprs.push_back(left.unwrap())
exprs.push_back(right.unwrap())
for operand in additional_operands:
exprs.push_back(operand.unwrap())
self.init(CMakeOrExpression(exprs))
def __init__(self, Expression left not None, Expression right not None):
cdef shared_ptr[COrExpression] expr
expr.reset(new COrExpression(left.unwrap(), right.unwrap()))
self.init(<shared_ptr[CExpression]> expr)

def __reduce__(self):
return OrExpression, (self.left_operand, self.right_operand)
9 changes: 9 additions & 0 deletions python/pyarrow/compute.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,15 @@ cdef class CastOptions:
def unsafe():
return CastOptions.wrap(CCastOptions.Unsafe())

def is_safe(self):
return not (
self.options.allow_int_overflow or
self.options.allow_time_truncate or
self.options.allow_time_overflow or
self.options.allow_float_truncate or
self.options.allow_invalid_utf8
)

cdef inline CCastOptions unwrap(self) nogil:
return self.options

Expand Down
9 changes: 8 additions & 1 deletion python/pyarrow/includes/libarrow.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,13 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:
cdef cppclass CScalar" arrow::Scalar":
shared_ptr[CDataType] type
c_bool is_valid
c_string ToString() const

cdef cppclass CNullScalar" arrow::NullScalar"(CScalar):
CNullScalar()

cdef cppclass CBooleanScalar" arrow::BooleanScalar"(CScalar):
c_bool value

cdef cppclass CInt8Scalar" arrow::Int8Scalar"(CScalar):
int8_t value
Expand Down Expand Up @@ -846,7 +853,7 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:
double value

cdef cppclass CStringScalar" arrow::StringScalar"(CScalar):
pass
shared_ptr[CBuffer] value

shared_ptr[CScalar] MakeScalar[Value](Value value)
shared_ptr[CScalar] MakeStringScalar" arrow::MakeScalar"(c_string value)
Expand Down
Loading