From b826431247cb4f8e959f77064ff2a3d59720fa15 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Tue, 24 Mar 2020 16:14:46 +0100 Subject: [PATCH 1/6] make dataset expressions serializable --- cpp/src/arrow/dataset/filter.cc | 8 ++ cpp/src/arrow/dataset/filter.h | 2 + python/pyarrow/_dataset.pyx | 141 +++++++++++++++---- python/pyarrow/includes/libarrow.pxd | 4 + python/pyarrow/includes/libarrow_dataset.pxd | 39 ++--- python/pyarrow/tests/test_dataset.py | 26 ++-- 6 files changed, 166 insertions(+), 54 deletions(-) diff --git a/cpp/src/arrow/dataset/filter.cc b/cpp/src/arrow/dataset/filter.cc index ee401c81345..32ac4347f01 100644 --- a/cpp/src/arrow/dataset/filter.cc +++ b/cpp/src/arrow/dataset/filter.cc @@ -607,6 +607,14 @@ std::shared_ptr CastExpression::Assume(const Expression& given) cons return std::make_shared(std::move(operand), std::move(like), options_); } +std::shared_ptr CastExpression::to_type() const { + if (arrow::util::holds_alternative>(to_)) { + return arrow::util::get>(to_); + } else { + return std::shared_ptr(nullptr); + } +} + std::string FieldExpression::ToString() const { return name_; } std::string OperatorName(compute::CompareOperator op) { diff --git a/cpp/src/arrow/dataset/filter.h b/cpp/src/arrow/dataset/filter.h index be1ba34da18..2137cdb48c0 100644 --- a/cpp/src/arrow/dataset/filter.h +++ b/cpp/src/arrow/dataset/filter.h @@ -377,6 +377,8 @@ class ARROW_DS_EXPORT CastExpression final const compute::CastOptions& options() const { return options_; } + std::shared_ptr to_type() const; + private: util::variant, std::shared_ptr> to_; compute::CastOptions options_; diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx index bbb990392b6..fe7967bd87a 100644 --- a/python/pyarrow/_dataset.pyx +++ b/python/pyarrow/_dataset.pyx @@ -27,6 +27,8 @@ from pyarrow.lib cimport * from pyarrow.includes.libarrow_dataset cimport * from pyarrow.compat import frombytes, tobytes from pyarrow._fs cimport FileSystem, FileInfo, FileSelector +from pyarrow.types import (is_null, is_boolean, is_integer, is_floating, + is_string) def _forbid_instantiation(klass, subclasses_instead=True): @@ -1305,6 +1307,10 @@ cdef class UnaryExpression(Expression): Expression.init(self, sp) self.unary = sp.get() + @property + def operand(self): + return Expression.wrap(self.unary.operand()) + cdef class BinaryExpression(Expression): @@ -1332,7 +1338,9 @@ cdef class ScalarExpression(Expression): shared_ptr[CScalar] scalar shared_ptr[CScalarExpression] expr - if isinstance(value, bool): + if value is None: + scalar.reset(new CNullScalar()) + elif isinstance(value, bool): scalar = MakeScalar(value) elif isinstance(value, float): scalar = MakeScalar(value) @@ -1350,6 +1358,38 @@ cdef class ScalarExpression(Expression): Expression.init(self, sp) self.scalar = sp.get() + @property + def value(self): + cdef: + shared_ptr[CScalar] scalar = self.scalar.value() + DataType typ = pyarrow_wrap_data_type(scalar.get().type) + c_string val + + if is_null(typ): + return None + + val = scalar.get().ToString() + if is_integer(typ): + return int(val) + elif is_floating(typ): + return float(val) + elif is_string(typ): + return frombytes(val) + elif is_boolean(typ): + if val == b'true': + return True + elif val == b'false': + return False + else: + raise ValueError( + 'Unexpected boolean value: {}'.format(frombytes(val)) + ) + else: + raise TypeError('Not yet supported scalar type: {}'.format(typ)) + + def __reduce__(self): + return ScalarExpression, (self.value,) + cdef class FieldExpression(Expression): @@ -1366,9 +1406,13 @@ cdef class FieldExpression(Expression): Expression.init(self, sp) self.scalar = sp.get() + @property def name(self): return frombytes(self.scalar.name()) + def __reduce__(self): + return FieldExpression, (self.name,) + cpdef enum CompareOperator: Equal = CCompareOperator_EQUAL @@ -1399,9 +1443,15 @@ cdef class ComparisonExpression(BinaryExpression): BinaryExpression.init(self, sp) self.comparison = sp.get() + @property def op(self): return self.comparison.op() + def __reduce__(self): + return ComparisonExpression, ( + self.op, self.left_operand, self.right_operand + ) + cdef class IsValidExpression(UnaryExpression): @@ -1410,27 +1460,54 @@ cdef class IsValidExpression(UnaryExpression): expr = make_shared[CIsValidExpression](operand.unwrap()) self.init( expr) + def __reduce__(self): + return IsValidExpression, (self.operand,) + cdef class CastExpression(UnaryExpression): + cdef CCastExpression *cast + def __init__(self, Expression operand not None, DataType to not None, bint safe=True): - # TODO(kszucs): safe is consistently used across pyarrow, but on long - # term we should expose the CastOptions object cdef: CastOptions options shared_ptr[CExpression] expr options = CastOptions.safe() if safe else CastOptions.unsafe() - expr.reset(new CCastExpression( - operand.unwrap(), - pyarrow_unwrap_data_type(to), - options.unwrap() - )) + expr.reset( + new CCastExpression( + operand.unwrap(), + pyarrow_unwrap_data_type(to), + options.unwrap() + ) + ) self.init(expr) + cdef void init(self, const shared_ptr[CExpression]& sp): + UnaryExpression.init(self, sp) + self.cast = sp.get() + + @property + def to(self): + # safe to assume that CastExpression::to_ variant holds a DataType + # instance because the construction from python only allows that + return pyarrow_wrap_data_type(self.cast.to_type()) + + @property + def safe(self): + cdef CCastOptions options = self.cast.options() + # infer safeness from any of the allow_* properties of the cast option + return not options.allow_int_overflow + + def __reduce__(self): + return CastExpression, (self.operand, self.to, self.safe) + cdef class InExpression(UnaryExpression): + cdef: + CInExpression* inexpr + def __init__(self, Expression operand not None, Array haystack not None): cdef shared_ptr[CExpression] expr expr.reset( @@ -1438,6 +1515,17 @@ cdef class InExpression(UnaryExpression): ) self.init(expr) + cdef void init(self, const shared_ptr[CExpression]& sp): + UnaryExpression.init(self, sp) + self.inexpr = sp.get() + + @property + def values(self): + return pyarrow_wrap_array(self.inexpr.set()) + + def __reduce__(self): + return InExpression, (self.operand, self.values) + cdef class NotExpression(UnaryExpression): @@ -1446,30 +1534,27 @@ cdef class NotExpression(UnaryExpression): expr = CMakeNotExpression(operand.unwrap()) self.init( expr) + def __reduce__(self): + return NotExpression, (self.operand,) + cdef class AndExpression(BinaryExpression): - def __init__(self, Expression left not None, Expression right not None, - *additional_operands): - cdef: - Expression operand - vector[shared_ptr[CExpression]] exprs - exprs.push_back(left.unwrap()) - exprs.push_back(right.unwrap()) - for operand in additional_operands: - exprs.push_back(operand.unwrap()) - self.init(CMakeAndExpression(exprs)) + def __init__(self, Expression left not None, Expression right not None): + cdef shared_ptr[CAndExpression] expr + expr.reset(new CAndExpression(left.unwrap(), right.unwrap())) + self.init( expr) + + def __reduce__(self): + return AndExpression, (self.left_operand, self.right_operand) cdef class OrExpression(BinaryExpression): - def __init__(self, Expression left not None, Expression right not None, - *additional_operands): - cdef: - Expression operand - vector[shared_ptr[CExpression]] exprs - exprs.push_back(left.unwrap()) - exprs.push_back(right.unwrap()) - for operand in additional_operands: - exprs.push_back(operand.unwrap()) - self.init(CMakeOrExpression(exprs)) + def __init__(self, Expression left not None, Expression right not None): + cdef shared_ptr[COrExpression] expr + expr.reset(new COrExpression(left.unwrap(), right.unwrap())) + self.init( expr) + + def __reduce__(self): + return OrExpression, (self.left_operand, self.right_operand) diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index b3c1b17d89e..2b68bd12a67 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -814,6 +814,10 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil: cdef cppclass CScalar" arrow::Scalar": shared_ptr[CDataType] type c_bool is_valid + c_string ToString() const + + cdef cppclass CNullScalar" arrow::NullScalar"(CScalar): + CNullScalar() cdef cppclass CInt8Scalar" arrow::Int8Scalar"(CScalar): int8_t value diff --git a/python/pyarrow/includes/libarrow_dataset.pxd b/python/pyarrow/includes/libarrow_dataset.pxd index 9cac4be4b09..f9d52fdcdf9 100644 --- a/python/pyarrow/includes/libarrow_dataset.pxd +++ b/python/pyarrow/includes/libarrow_dataset.pxd @@ -49,13 +49,13 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: cdef cppclass CExpression "arrow::dataset::Expression": CExpression(CExpressionType type) - c_bool Equals(const CExpression & other) const - c_bool Equals(const shared_ptr[CExpression] & other) const + c_bool Equals(const CExpression& other) const + c_bool Equals(const shared_ptr[CExpression]& other) const c_bool IsNull() const - CResult[shared_ptr[CDataType]] Validate(const CSchema & schema) const - shared_ptr[CExpression] Assume(const CExpression & given) const + CResult[shared_ptr[CDataType]] Validate(const CSchema& schema) const + shared_ptr[CExpression] Assume(const CExpression& given) const shared_ptr[CExpression] Assume( - const shared_ptr[CExpression] & given) const + const shared_ptr[CExpression]& given) const c_string ToString() const CExpressionType type() const shared_ptr[CExpression] Copy() const @@ -65,17 +65,17 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: cdef cppclass CUnaryExpression "arrow::dataset::UnaryExpression"( CExpression): - const shared_ptr[CExpression] & operand() const + const shared_ptr[CExpression]& operand() const cdef cppclass CBinaryExpression "arrow::dataset::BinaryExpression"( CExpression): - const shared_ptr[CExpression] & left_operand() const - const shared_ptr[CExpression] & right_operand() const + const shared_ptr[CExpression]& left_operand() const + const shared_ptr[CExpression]& right_operand() const cdef cppclass CScalarExpression "arrow::dataset::ScalarExpression"( CExpression): - CScalarExpression(const shared_ptr[CScalar] & value) - const shared_ptr[CScalar] & value() const + CScalarExpression(const shared_ptr[CScalar]& value) + const shared_ptr[CScalar]& value() const cdef cppclass CFieldExpression "arrow::dataset::FieldExpression"( CExpression): @@ -91,11 +91,13 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: cdef cppclass CAndExpression "arrow::dataset::AndExpression"( CBinaryExpression): - pass + CAndExpression(shared_ptr[CExpression] left_operand, + shared_ptr[CExpression] right_operand) cdef cppclass COrExpression "arrow::dataset::OrExpression"( CBinaryExpression): - pass + COrExpression(shared_ptr[CExpression] left_operand, + shared_ptr[CExpression] right_operand) cdef cppclass CNotExpression "arrow::dataset::NotExpression"( CUnaryExpression): @@ -110,17 +112,20 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: CCastExpression(shared_ptr[CExpression] operand, shared_ptr[CDataType] to, CCastOptions options) + const CCastOptions& options() const + const shared_ptr[CDataType]& to_type() const cdef cppclass CInExpression "arrow::dataset::InExpression"( CUnaryExpression): CInExpression(shared_ptr[CExpression] operand, shared_ptr[CArray] set) + const shared_ptr[CArray]& set() const cdef shared_ptr[CNotExpression] CMakeNotExpression "arrow::dataset::not_"( shared_ptr[CExpression] operand) cdef shared_ptr[CExpression] CMakeAndExpression "arrow::dataset::and_"( - const CExpressionVector & subexpressions) + const CExpressionVector& subexpressions) cdef shared_ptr[CExpression] CMakeOrExpression "arrow::dataset::or_"( - const CExpressionVector & subexpressions) + const CExpressionVector& subexpressions) cdef CResult[shared_ptr[CExpression]] CInsertImplicitCasts \ "arrow::dataset::InsertImplicitCasts"( @@ -155,8 +160,8 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: cdef cppclass CScannerBuilder "arrow::dataset::ScannerBuilder": CScannerBuilder(shared_ptr[CDataset], shared_ptr[CScanContext] scan_context) - CStatus Project(const vector[c_string] & columns) - CStatus Filter(const CExpression & filter) + CStatus Project(const vector[c_string]& columns) + CStatus Filter(const CExpression& filter) CStatus Filter(shared_ptr[CExpression] filter) CStatus UseThreads(c_bool use_threads) CStatus BatchSize(int64_t batch_size) @@ -185,7 +190,7 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: CResult[vector[shared_ptr[CSchema]]] InspectSchemas() CResult[shared_ptr[CSchema]] Inspect() CResult[shared_ptr[CDataset]] FinishWithSchema "Finish"( - const shared_ptr[CSchema] & schema) + const shared_ptr[CSchema]& schema) CResult[shared_ptr[CDataset]] Finish() const shared_ptr[CExpression]& root_partition() CStatus SetRootPartition(shared_ptr[CExpression] partition) diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py index c4b40ca59e6..c5b57684b6e 100644 --- a/python/pyarrow/tests/test_dataset.py +++ b/python/pyarrow/tests/test_dataset.py @@ -18,6 +18,8 @@ import contextlib import operator import os +import pickle +import urllib import numpy as np import pytest @@ -337,9 +339,11 @@ def test_expression(): b = ds.ScalarExpression(1.1) c = ds.ScalarExpression(True) d = ds.ScalarExpression("string") + e = ds.ScalarExpression(None) equal = ds.ComparisonExpression(ds.CompareOperator.Equal, a, b) - assert equal.op() == ds.CompareOperator.Equal + greater = a > b + assert equal.op == ds.CompareOperator.Equal and_ = ds.AndExpression(a, b) assert and_.left_operand.equals(a) @@ -347,14 +351,12 @@ def test_expression(): assert and_.equals(ds.AndExpression(a, b)) assert and_.equals(and_) - ds.AndExpression(a, b, c) - ds.OrExpression(a, b) - ds.OrExpression(a, b, c, d) - ds.NotExpression(ds.OrExpression(a, b, c)) - ds.IsValidExpression(a) - ds.CastExpression(a, pa.int32()) - ds.CastExpression(a, pa.int32(), safe=True) - ds.InExpression(a, pa.array([1, 2, 3])) + or_ = ds.OrExpression(a, b) + not_ = ds.NotExpression(ds.OrExpression(a, b)) + is_valid = ds.IsValidExpression(a) + cast_unsafe = ds.CastExpression(a, pa.int32()) + cast_safe = ds.CastExpression(a, pa.int32(), safe=True) + in_ = ds.InExpression(a, pa.array([1, 2, 3])) condition = ds.ComparisonExpression( ds.CompareOperator.Greater, @@ -382,6 +384,12 @@ def test_expression(): assert str(condition) == "(i64 > 5:int64)" assert "(i64 > 5:int64)" in repr(condition) + all_exprs = [a, b, c, d, e, equal, greater, and_, or_, not_, is_valid, + cast_unsafe, cast_safe, in_, condition, i64_is_5, i64_is_7] + for expr in all_exprs: + restored = pickle.loads(pickle.dumps(expr)) + assert expr.equals(restored) + def test_expression_ergonomics(): zero = ds.scalar(0) From 8194d7670817567a3f2b1a9e7e48a251075a0e87 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Tue, 24 Mar 2020 17:35:09 +0100 Subject: [PATCH 2/6] use the existing scalar wrappers --- cpp/src/arrow/dataset/filter.h | 1 + python/pyarrow/_dataset.pyx | 42 ++++++++++---------------- python/pyarrow/includes/libarrow.pxd | 5 +++- python/pyarrow/scalar.pxi | 44 ++++++++++++++++++++++++++++ 4 files changed, 65 insertions(+), 27 deletions(-) diff --git a/cpp/src/arrow/dataset/filter.h b/cpp/src/arrow/dataset/filter.h index 2137cdb48c0..0d4a4f06d1c 100644 --- a/cpp/src/arrow/dataset/filter.h +++ b/cpp/src/arrow/dataset/filter.h @@ -377,6 +377,7 @@ class ARROW_DS_EXPORT CastExpression final const compute::CastOptions& options() const { return options_; } + /// Try to return with the DataType variant of the cast target. std::shared_ptr to_type() const; private: diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx index fe7967bd87a..db0e382e7a2 100644 --- a/python/pyarrow/_dataset.pyx +++ b/python/pyarrow/_dataset.pyx @@ -1360,32 +1360,8 @@ cdef class ScalarExpression(Expression): @property def value(self): - cdef: - shared_ptr[CScalar] scalar = self.scalar.value() - DataType typ = pyarrow_wrap_data_type(scalar.get().type) - c_string val - - if is_null(typ): - return None - - val = scalar.get().ToString() - if is_integer(typ): - return int(val) - elif is_floating(typ): - return float(val) - elif is_string(typ): - return frombytes(val) - elif is_boolean(typ): - if val == b'true': - return True - elif val == b'false': - return False - else: - raise ValueError( - 'Unexpected boolean value: {}'.format(frombytes(val)) - ) - else: - raise TypeError('Not yet supported scalar type: {}'.format(typ)) + cdef ScalarValue scalar = pyarrow_wrap_scalar(self.scalar.value()) + return scalar.as_py() def __reduce__(self): return ScalarExpression, (self.value,) @@ -1489,12 +1465,26 @@ cdef class CastExpression(UnaryExpression): @property def to(self): + """ + Target DataType of the cast operation. + + Returns + ------- + DataType + """ # safe to assume that CastExpression::to_ variant holds a DataType # instance because the construction from python only allows that return pyarrow_wrap_data_type(self.cast.to_type()) @property def safe(self): + """ + Whether to check for overflows or other unsafe conversions. + + Returns + ------- + bool + """ cdef CCastOptions options = self.cast.options() # infer safeness from any of the allow_* properties of the cast option return not options.allow_int_overflow diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 2b68bd12a67..b8f5abc056b 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -819,6 +819,9 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil: cdef cppclass CNullScalar" arrow::NullScalar"(CScalar): CNullScalar() + cdef cppclass CBooleanScalar" arrow::BooleanScalar"(CScalar): + c_bool value + cdef cppclass CInt8Scalar" arrow::Int8Scalar"(CScalar): int8_t value @@ -850,7 +853,7 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil: double value cdef cppclass CStringScalar" arrow::StringScalar"(CScalar): - pass + shared_ptr[CBuffer] value shared_ptr[CScalar] MakeScalar[Value](Value value) shared_ptr[CScalar] MakeStringScalar" arrow::MakeScalar"(c_string value) diff --git a/python/pyarrow/scalar.pxi b/python/pyarrow/scalar.pxi index faa385e0dd9..3538483c6cf 100644 --- a/python/pyarrow/scalar.pxi +++ b/python/pyarrow/scalar.pxi @@ -1032,6 +1032,31 @@ cdef class ScalarValue(Scalar): return hash(self.as_py()) +cdef class NullScalar(ScalarValue): + """ + Concrete class for null scalars. + """ + + def as_py(self): + """ + Return this value as a Python None. + """ + return None + + +cdef class BooleanScalar(ScalarValue): + """ + Concrete class for boolean scalars. + """ + + def as_py(self): + """ + Return this value as a Python bool. + """ + cdef CBooleanScalar* sp = self.sp_scalar.get() + return sp.value if sp.is_valid else None + + cdef class UInt8Scalar(ScalarValue): """ Concrete class for uint8 scalars. @@ -1162,7 +1187,25 @@ cdef class DoubleScalar(ScalarValue): return sp.value if sp.is_valid else None +cdef class StringScalar(ScalarValue): + """ + Concrete class for string scalars. + """ + + def as_py(self): + """ + Return this value as a Python string. + """ + cdef CStringScalar* sp = self.sp_scalar.get() + if sp.is_valid: + return frombytes(pyarrow_wrap_buffer(sp.value).to_pybytes()) + else: + return None + + cdef dict _scalar_classes = { + _Type_NA: NullScalar, + _Type_BOOL: BooleanScalar, _Type_UINT8: UInt8Scalar, _Type_UINT16: UInt16Scalar, _Type_UINT32: UInt32Scalar, @@ -1173,6 +1216,7 @@ cdef dict _scalar_classes = { _Type_INT64: Int64Scalar, _Type_FLOAT: FloatScalar, _Type_DOUBLE: DoubleScalar, + _Type_STRING: StringScalar, } cdef object box_scalar(DataType type, const shared_ptr[CArray]& sp_array, From 549bd8973797fcb1f945b307eee556f70d537bad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Tue, 24 Mar 2020 18:01:35 +0100 Subject: [PATCH 3/6] expose both variants of cast expression's target type --- cpp/src/arrow/dataset/filter.cc | 14 +++++++++++--- cpp/src/arrow/dataset/filter.h | 9 +++++++-- python/pyarrow/_dataset.pyx | 18 +++++++++++++----- python/pyarrow/includes/libarrow_dataset.pxd | 1 + 4 files changed, 32 insertions(+), 10 deletions(-) diff --git a/cpp/src/arrow/dataset/filter.cc b/cpp/src/arrow/dataset/filter.cc index 32ac4347f01..5004e5a3263 100644 --- a/cpp/src/arrow/dataset/filter.cc +++ b/cpp/src/arrow/dataset/filter.cc @@ -607,12 +607,20 @@ std::shared_ptr CastExpression::Assume(const Expression& given) cons return std::make_shared(std::move(operand), std::move(like), options_); } -std::shared_ptr CastExpression::to_type() const { +const std::shared_ptr& CastExpression::to_type() const { if (arrow::util::holds_alternative>(to_)) { return arrow::util::get>(to_); - } else { - return std::shared_ptr(nullptr); } + static std::shared_ptr null; + return null; +} + +const std::shared_ptr& CastExpression::like_expr() const { + if (arrow::util::holds_alternative>(to_)) { + return arrow::util::get>(to_); + } + static std::shared_ptr null; + return null; } std::string FieldExpression::ToString() const { return name_; } diff --git a/cpp/src/arrow/dataset/filter.h b/cpp/src/arrow/dataset/filter.h index 0d4a4f06d1c..e33b904bbfc 100644 --- a/cpp/src/arrow/dataset/filter.h +++ b/cpp/src/arrow/dataset/filter.h @@ -377,8 +377,13 @@ class ARROW_DS_EXPORT CastExpression final const compute::CastOptions& options() const { return options_; } - /// Try to return with the DataType variant of the cast target. - std::shared_ptr to_type() const; + /// Return the target type of this CastTo expression, or nullptr if this is a + /// CastLike expression. + const std::shared_ptr& to_type() const; + + /// Return the target expression of this CastLike expression, or nullptr if + /// this is a CastTo expression. + const std::shared_ptr& like_expr() const; private: util::variant, std::shared_ptr> to_; diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx index db0e382e7a2..464dbe024aa 100644 --- a/python/pyarrow/_dataset.pyx +++ b/python/pyarrow/_dataset.pyx @@ -1466,15 +1466,23 @@ cdef class CastExpression(UnaryExpression): @property def to(self): """ - Target DataType of the cast operation. + Target DataType or Expression of the cast operation. Returns ------- - DataType + DataType or Expression """ - # safe to assume that CastExpression::to_ variant holds a DataType - # instance because the construction from python only allows that - return pyarrow_wrap_data_type(self.cast.to_type()) + cdef: + shared_ptr[CDataType] typ = self.cast.to_type() + shared_ptr[CExpression] expr = self.cast.like_expr() + if typ.get() != nullptr: + return pyarrow_wrap_data_type(typ) + elif expr.get() != nullptr: + return Expression.wrap(expr) + else: + raise TypeError( + 'Cannot determine the target type of the cast expression' + ) @property def safe(self): diff --git a/python/pyarrow/includes/libarrow_dataset.pxd b/python/pyarrow/includes/libarrow_dataset.pxd index f9d52fdcdf9..9c67fcc54ab 100644 --- a/python/pyarrow/includes/libarrow_dataset.pxd +++ b/python/pyarrow/includes/libarrow_dataset.pxd @@ -114,6 +114,7 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: CCastOptions options) const CCastOptions& options() const const shared_ptr[CDataType]& to_type() const + const shared_ptr[CExpression]& like_expr() const cdef cppclass CInExpression "arrow::dataset::InExpression"( CUnaryExpression): From 53bce90ba4f5b726d47057caf858bccd876fd436 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Tue, 24 Mar 2020 18:17:47 +0100 Subject: [PATCH 4/6] address review comments --- python/pyarrow/_dataset.pyx | 13 ++++++------- python/pyarrow/compute.pxi | 9 +++++++++ python/pyarrow/tests/test_dataset.py | 10 ++++++++-- 3 files changed, 23 insertions(+), 9 deletions(-) diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx index 464dbe024aa..ca727c6f642 100644 --- a/python/pyarrow/_dataset.pyx +++ b/python/pyarrow/_dataset.pyx @@ -1493,9 +1493,8 @@ cdef class CastExpression(UnaryExpression): ------- bool """ - cdef CCastOptions options = self.cast.options() - # infer safeness from any of the allow_* properties of the cast option - return not options.allow_int_overflow + cdef CastOptions options = CastOptions.wrap(self.cast.options()) + return options.is_safe() def __reduce__(self): return CastExpression, (self.operand, self.to, self.safe) @@ -1506,10 +1505,10 @@ cdef class InExpression(UnaryExpression): cdef: CInExpression* inexpr - def __init__(self, Expression operand not None, Array haystack not None): + def __init__(self, Expression operand not None, Array set_ not None): cdef shared_ptr[CExpression] expr expr.reset( - new CInExpression(operand.unwrap(), pyarrow_unwrap_array(haystack)) + new CInExpression(operand.unwrap(), pyarrow_unwrap_array(set_)) ) self.init(expr) @@ -1518,11 +1517,11 @@ cdef class InExpression(UnaryExpression): self.inexpr = sp.get() @property - def values(self): + def set_(self): return pyarrow_wrap_array(self.inexpr.set()) def __reduce__(self): - return InExpression, (self.operand, self.values) + return InExpression, (self.operand, self.set_) cdef class NotExpression(UnaryExpression): diff --git a/python/pyarrow/compute.pxi b/python/pyarrow/compute.pxi index d5cc366bfed..d0c0d4d5826 100644 --- a/python/pyarrow/compute.pxi +++ b/python/pyarrow/compute.pxi @@ -48,6 +48,15 @@ cdef class CastOptions: def unsafe(): return CastOptions.wrap(CCastOptions.Unsafe()) + def is_safe(self): + return not ( + self.options.allow_int_overflow or + self.options.allow_time_truncate or + self.options.allow_time_overflow or + self.options.allow_float_truncate or + self.options.allow_invalid_utf8 + ) + cdef inline CCastOptions unwrap(self) nogil: return self.options diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py index c5b57684b6e..24f306fe68b 100644 --- a/python/pyarrow/tests/test_dataset.py +++ b/python/pyarrow/tests/test_dataset.py @@ -354,10 +354,16 @@ def test_expression(): or_ = ds.OrExpression(a, b) not_ = ds.NotExpression(ds.OrExpression(a, b)) is_valid = ds.IsValidExpression(a) - cast_unsafe = ds.CastExpression(a, pa.int32()) - cast_safe = ds.CastExpression(a, pa.int32(), safe=True) + cast_safe = ds.CastExpression(a, pa.int32()) + cast_unsafe = ds.CastExpression(a, pa.int32(), safe=False) in_ = ds.InExpression(a, pa.array([1, 2, 3])) + assert is_valid.operand == a + assert in_.set_.equals(pa.array([1, 2, 3])) + assert cast_unsafe.to == pa.int32() + assert cast_unsafe.safe is False + assert cast_safe.safe is True + condition = ds.ComparisonExpression( ds.CompareOperator.Greater, ds.FieldExpression('i64'), From cb2566a664f2d1cca1ce2314719096bb1c036be2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Wed, 25 Mar 2020 16:03:00 +0100 Subject: [PATCH 5/6] revert wrapping cast-like cast expression --- python/pyarrow/_dataset.pyx | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx index ca727c6f642..77f1c7a3953 100644 --- a/python/pyarrow/_dataset.pyx +++ b/python/pyarrow/_dataset.pyx @@ -1472,13 +1472,10 @@ cdef class CastExpression(UnaryExpression): ------- DataType or Expression """ - cdef: - shared_ptr[CDataType] typ = self.cast.to_type() - shared_ptr[CExpression] expr = self.cast.like_expr() + cdef shared_ptr[CDataType] typ = self.cast.to_type() + if typ.get() != nullptr: return pyarrow_wrap_data_type(typ) - elif expr.get() != nullptr: - return Expression.wrap(expr) else: raise TypeError( 'Cannot determine the target type of the cast expression' From 87fb1759ac4ae7ae74b1d0abf6a65f9f76bd0a82 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Wed, 25 Mar 2020 17:43:25 +0100 Subject: [PATCH 6/6] rebase --- python/pyarrow/tests/test_dataset.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py index 24f306fe68b..e4fa9720d2e 100644 --- a/python/pyarrow/tests/test_dataset.py +++ b/python/pyarrow/tests/test_dataset.py @@ -19,7 +19,6 @@ import operator import os import pickle -import urllib import numpy as np import pytest