Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 41 additions & 8 deletions cpp/src/arrow/compute/cast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,14 @@

#include "arrow/compute/cast_internal.h"
#include "arrow/compute/kernel.h"
#include "arrow/compute/registry.h"

namespace arrow {
namespace compute {

namespace internal {

std::unordered_map<int, std::shared_ptr<const CastFunction>> g_cast_table;
std::unordered_map<int, std::shared_ptr<CastFunction>> g_cast_table;
static std::once_flag cast_table_initialized;

void AddCastFunctions(const std::vector<std::shared_ptr<CastFunction>>& funcs) {
Expand All @@ -51,6 +52,38 @@ void InitCastTable() {

void EnsureInitCastTable() { std::call_once(cast_table_initialized, InitCastTable); }

// A function that overrides Function::Execute to dispatch to the appropriate
// target-type-specific CastFunction
//
// This corresponds to the standard SQL CAST(expr AS target_type)
//
// As a "metafunction" this function has no kernels and is intended to be used
// through its Execute function
class CastMetaFunction : public ScalarFunction {
public:
CastMetaFunction() : ScalarFunction("cast", Arity::Unary()) {}

Result<Datum> Execute(const std::vector<Datum>& args, const FunctionOptions* options,
ExecContext* ctx) const override {
auto cast_options = static_cast<const CastOptions*>(options);
if (cast_options == nullptr || cast_options->to_type == nullptr) {
return Status::Invalid(
"Cast requires that options be passed with "
"the to_type populated");
}
if (args[0].type()->Equals(*cast_options->to_type)) {
return args[0];
}
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<CastFunction> cast_func,
GetCastFunction(cast_options->to_type));
return cast_func->Execute(args, options, ctx);
}
};

void RegisterScalarCast(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunction(std::make_shared<CastMetaFunction>()));
}

} // namespace internal

struct CastFunction::CastFunctionImpl {
Expand Down Expand Up @@ -138,16 +171,15 @@ Result<const ScalarKernel*> CastFunction::DispatchExact(
}
}

Result<Datum> Cast(const Datum& value, const CastOptions& options, ExecContext* ctx) {
return CallFunction("cast", {value}, &options, ctx);
}

Result<Datum> Cast(const Datum& value, std::shared_ptr<DataType> to_type,
const CastOptions& options, ExecContext* ctx) {
if (value.type()->Equals(*to_type)) {
return value;
}
CastOptions options_with_to_type = options;
options_with_to_type.to_type = to_type;
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<const CastFunction> cast_func,
GetCastFunction(to_type));
return cast_func->Execute({Datum(value)}, &options_with_to_type, ctx);
return Cast(value, options_with_to_type, ctx);
}

Result<std::shared_ptr<Array>> Cast(const Array& value, std::shared_ptr<DataType> to_type,
Expand All @@ -156,7 +188,7 @@ Result<std::shared_ptr<Array>> Cast(const Array& value, std::shared_ptr<DataType
return result.make_array();
}

Result<std::shared_ptr<const CastFunction>> GetCastFunction(
Result<std::shared_ptr<CastFunction>> GetCastFunction(
const std::shared_ptr<DataType>& to_type) {
internal::EnsureInitCastTable();
auto it = internal::g_cast_table.find(static_cast<int>(to_type->id()));
Expand All @@ -169,6 +201,7 @@ Result<std::shared_ptr<const CastFunction>> GetCastFunction(

bool CanCast(const DataType& from_type, const DataType& to_type) {
// TODO
internal::EnsureInitCastTable();
auto it = internal::g_cast_table.find(static_cast<int>(from_type.id()));
if (it == internal::g_cast_table.end()) {
return false;
Expand Down
22 changes: 17 additions & 5 deletions cpp/src/arrow/compute/cast.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ class CastFunction : public ScalarFunction {
};

ARROW_EXPORT
Result<std::shared_ptr<const CastFunction>> GetCastFunction(
Result<std::shared_ptr<CastFunction>> GetCastFunction(
const std::shared_ptr<DataType>& to_type);

/// \brief Return true if a cast function is defined
Expand All @@ -117,29 +117,41 @@ bool CanCast(const DataType& from_type, const DataType& to_type);
/// \param[in] value array to cast
/// \param[in] to_type type to cast to
/// \param[in] options casting options
/// \param[in] context the function execution context, optional
/// \param[in] ctx the function execution context, optional
/// \return the resulting array
///
/// \since 1.0.0
/// \note API not yet finalized
ARROW_EXPORT
Result<std::shared_ptr<Array>> Cast(const Array& value, std::shared_ptr<DataType> to_type,
const CastOptions& options = CastOptions::Safe(),
ExecContext* context = NULLPTR);
ExecContext* ctx = NULLPTR);

/// \brief Cast from one array type to another
/// \param[in] value array to cast
/// \param[in] options casting options. The "to_type" field must be populated
/// \param[in] ctx the function execution context, optional
/// \return the resulting array
///
/// \since 1.0.0
/// \note API not yet finalized
ARROW_EXPORT
Result<Datum> Cast(const Datum& value, const CastOptions& options,
ExecContext* ctx = NULLPTR);

/// \brief Cast from one value to another
/// \param[in] value datum to cast
/// \param[in] to_type type to cast to
/// \param[in] options casting options
/// \param[in] context the function execution context, optional
/// \param[in] ctx the function execution context, optional
/// \return the resulting datum
///
/// \since 1.0.0
/// \note API not yet finalized
ARROW_EXPORT
Result<Datum> Cast(const Datum& value, std::shared_ptr<DataType> to_type,
const CastOptions& options = CastOptions::Safe(),
ExecContext* context = NULLPTR);
ExecContext* ctx = NULLPTR);

} // namespace compute
} // namespace arrow
7 changes: 5 additions & 2 deletions cpp/src/arrow/compute/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,11 @@ class ARROW_EXPORT Function {

/// \brief Convenience for invoking a function with kernel dispatch and
/// memory allocation details taken care of
Result<Datum> Execute(const std::vector<Datum>& args, const FunctionOptions* options,
ExecContext* ctx = NULLPTR) const;
///
/// This function can be overridden in subclasses
virtual Result<Datum> Execute(const std::vector<Datum>& args,
const FunctionOptions* options,
ExecContext* ctx = NULLPTR) const;

protected:
Function(std::string name, Function::Kind kind, const Arity& arity)
Expand Down
1 change: 1 addition & 0 deletions cpp/src/arrow/compute/registry.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ static std::unique_ptr<FunctionRegistry> CreateBuiltInRegistry() {
// Scalar functions
RegisterScalarArithmetic(registry.get());
RegisterScalarBoolean(registry.get());
RegisterScalarCast(registry.get());
RegisterScalarComparison(registry.get());
RegisterScalarSetLookup(registry.get());

Expand Down
1 change: 1 addition & 0 deletions cpp/src/arrow/compute/registry_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ namespace internal {
// Built-in scalar / elementwise functions
void RegisterScalarArithmetic(FunctionRegistry* registry);
void RegisterScalarBoolean(FunctionRegistry* registry);
void RegisterScalarCast(FunctionRegistry* registry);
void RegisterScalarComparison(FunctionRegistry* registry);
void RegisterScalarSetLookup(FunctionRegistry* registry);

Expand Down
4 changes: 2 additions & 2 deletions python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -377,9 +377,9 @@ endif()
set(CYTHON_EXTENSIONS
lib
_fs
_compute
_csv
_json
_compute)
_json)

set(LINK_LIBS arrow_shared arrow_python_shared)

Expand Down
37 changes: 37 additions & 0 deletions python/pyarrow/_compute.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# cython: language_level = 3

from pyarrow.lib cimport *
from pyarrow.includes.common cimport *
from pyarrow.includes.libarrow cimport *


cdef class FunctionOptions:

cdef const CFunctionOptions* get_options(self) except NULL


cdef class CastOptions(FunctionOptions):
cdef:
CCastOptions options

@staticmethod
cdef wrap(CCastOptions options)

cdef inline CCastOptions unwrap(self) nogil
123 changes: 97 additions & 26 deletions python/pyarrow/_compute.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,9 @@

# cython: language_level = 3

from pyarrow.lib cimport (
Array,
wrap_datum,
check_status,
ChunkedArray,
ScalarValue
)
from pyarrow.compat import frombytes, tobytes, ordered_dict
from pyarrow.lib cimport *
from pyarrow.includes.libarrow cimport *
from pyarrow.includes.common cimport *

from pyarrow.compat import frombytes, tobytes


cdef wrap_scalar_function(const shared_ptr[CFunction]& sp_func):
cdef ScalarFunction func = ScalarFunction.__new__(ScalarFunction)
Expand Down Expand Up @@ -169,16 +160,16 @@ num_kernels: {}
def num_kernels(self):
return self.base_func.num_kernels()

def call(self, args, options=None):
def call(self, args, FunctionOptions options=None):
cdef:
const CFunctionOptions* c_options = NULL
vector[CDatum] c_args
CDatum result

_pack_compute_args(args, &c_args)

if isinstance(options, FunctionOptions):
c_options = (<FunctionOptions> options).options()
if options is not None:
c_options = options.get_options()

with nogil:
result = GetResultValue(self.base_func.Execute(c_args, c_options))
Expand Down Expand Up @@ -273,26 +264,106 @@ def call_function(name, args, options=None):

cdef class FunctionOptions:

cdef const CFunctionOptions* options(self) except NULL:
cdef const CFunctionOptions* get_options(self) except NULL:
raise NotImplementedError("Unimplemented base options")


cdef class CastOptions(FunctionOptions):
cdef:
CCastOptions cast_options

__slots__ = () # avoid mistakingly creating attributes

def __init__(self, DataType target_type=None, allow_int_overflow=None,
allow_time_truncate=None, allow_time_overflow=None,
allow_float_truncate=None, allow_invalid_utf8=None):
if allow_int_overflow is not None:
self.allow_int_overflow = allow_int_overflow
if allow_time_truncate is not None:
self.allow_time_truncate = allow_time_truncate
if allow_time_overflow is not None:
self.allow_time_overflow = allow_time_overflow
if allow_float_truncate is not None:
self.allow_float_truncate = allow_float_truncate
if allow_invalid_utf8 is not None:
self.allow_invalid_utf8 = allow_invalid_utf8

cdef const CFunctionOptions* get_options(self) except NULL:
return &self.options

@staticmethod
cdef wrap(CCastOptions options):
cdef CastOptions self = CastOptions.__new__(CastOptions)
self.options = options
return self

cdef inline CCastOptions unwrap(self) nogil:
return self.options

@staticmethod
def safe():
cdef CastOptions options = CastOptions()
options.cast_options = CCastOptions.Safe()
def safe(target_type=None):
options = CastOptions.wrap(CCastOptions.Safe())
options._set_type(target_type)
return options

@staticmethod
def unsafe():
cdef CastOptions options = CastOptions()
options.cast_options = CCastOptions.Unsafe()
def unsafe(target_type=None):
options = CastOptions.wrap(CCastOptions.Unsafe())
options._set_type(target_type)
return options

def _set_type(self, target_type=None):
if target_type is not None:
self.options.to_type = (
(<DataType> ensure_type(target_type)).sp_type
)

def is_safe(self):
return not (
self.options.allow_int_overflow or
self.options.allow_time_truncate or
self.options.allow_time_overflow or
self.options.allow_float_truncate or
self.options.allow_invalid_utf8
)

@property
def allow_int_overflow(self):
return self.options.allow_int_overflow

@allow_int_overflow.setter
def allow_int_overflow(self, bint flag):
self.options.allow_int_overflow = flag

@property
def allow_time_truncate(self):
return self.options.allow_time_truncate

@allow_time_truncate.setter
def allow_time_truncate(self, bint flag):
self.options.allow_time_truncate = flag

@property
def allow_time_overflow(self):
return self.options.allow_time_overflow

@allow_time_overflow.setter
def allow_time_overflow(self, bint flag):
self.options.allow_time_overflow = flag

@property
def allow_float_truncate(self):
return self.options.allow_float_truncate

@allow_float_truncate.setter
def allow_float_truncate(self, bint flag):
self.options.allow_float_truncate = flag

@property
def allow_invalid_utf8(self):
return self.options.allow_invalid_utf8

cdef const CFunctionOptions* options(self) except NULL:
return &self.cast_options
@allow_invalid_utf8.setter
def allow_invalid_utf8(self, bint flag):
self.options.allow_invalid_utf8 = flag


cdef class FilterOptions(FunctionOptions):
Expand All @@ -314,5 +385,5 @@ cdef class FilterOptions(FunctionOptions):
null_selection_behavior)
)

cdef const CFunctionOptions* options(self) except NULL:
cdef const CFunctionOptions* get_options(self) except NULL:
return &self.filter_options
Loading