Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
1fb628b
ARROW-8919: [C++][Compute] Add Function::DispatchBest
bkietz Jan 12, 2021
e6f0840
support implicit casts in Function::Execute, CallFunction
bkietz Jan 14, 2021
5cdd710
first pass at integrating DispatchBest into Expressions
bkietz Jan 15, 2021
6385032
add DispatchBest to SetLookup kernels
bkietz Jan 15, 2021
058e15a
repair implicit cast is_in execution test
bkietz Jan 19, 2021
b8525a4
add support for null -> * cast to arithmetic and compare
bkietz Jan 19, 2021
e60e555
use explicit schema to avoid inferring bool as str
bkietz Jan 19, 2021
2528d95
apply implicit casts to R binding
bkietz Jan 19, 2021
ecd778c
ensure value_set is cast to the input type
bkietz Jan 21, 2021
003ef40
always check for an exact match first
bkietz Jan 21, 2021
7ebb067
add implicit cast between timestamp-like types to comparison
bkietz Jan 21, 2021
8100d21
support dictionary(X) -> Y casts if X -> Y
bkietz Jan 22, 2021
c1de51d
describe implicit cast behavior in compute.rst
bkietz Jan 22, 2021
db5ae2f
msvc: linkage fix
bkietz Jan 22, 2021
0852305
review comments
bkietz Jan 27, 2021
ff9cde2
unskip implicit casting comparison test
bkietz Jan 27, 2021
c761233
Revert "unskip implicit casting comparison test"
bkietz Jan 27, 2021
dd68342
review comments
bkietz Feb 5, 2021
282dac5
expand common numeric type when signed/unsigned
bkietz Feb 5, 2021
66aa801
add test case for stripping casts from uint32 to signed integer types
bkietz Feb 8, 2021
62a6b5e
Nits + fix compile error (hopefully)
pitrou Feb 10, 2021
6ded65f
inline InitKernelState, ensure KernelInitArgs::inputs is bound to a n…
bkietz Feb 10, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 49 additions & 42 deletions cpp/src/arrow/compute/cast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,26 +124,15 @@ void RegisterScalarCast(FunctionRegistry* registry) {

} // namespace internal

struct CastFunction::CastFunctionImpl {
Type::type out_type;
std::unordered_set<int> in_types;
};

CastFunction::CastFunction(std::string name, Type::type out_type)
: ScalarFunction(std::move(name), Arity::Unary(), /*doc=*/nullptr) {
impl_.reset(new CastFunctionImpl());
impl_->out_type = out_type;
}

CastFunction::~CastFunction() = default;

Type::type CastFunction::out_type_id() const { return impl_->out_type; }
CastFunction::CastFunction(std::string name, Type::type out_type_id)
: ScalarFunction(std::move(name), Arity::Unary(), /*doc=*/nullptr),
out_type_id_(out_type_id) {}

Status CastFunction::AddKernel(Type::type in_type_id, ScalarKernel kernel) {
// We use the same KernelInit for every cast
kernel.init = internal::CastState::Init;
RETURN_NOT_OK(ScalarFunction::AddKernel(kernel));
impl_->in_types.insert(static_cast<int>(in_type_id));
in_type_ids_.push_back(in_type_id);
return Status::OK();
}

Expand All @@ -159,19 +148,10 @@ Status CastFunction::AddKernel(Type::type in_type_id, std::vector<InputType> in_
return AddKernel(in_type_id, std::move(kernel));
}

bool CastFunction::CanCastTo(const DataType& out_type) const {
return impl_->in_types.find(static_cast<int>(out_type.id())) != impl_->in_types.end();
}

Result<const Kernel*> CastFunction::DispatchExact(
const std::vector<ValueDescr>& values) const {
const int passed_num_args = static_cast<int>(values.size());
RETURN_NOT_OK(CheckArity(values));

// Validate arity
if (passed_num_args != 1) {
return Status::Invalid("Cast functions accept 1 argument but passed ",
passed_num_args);
}
std::vector<const ScalarKernel*> candidate_kernels;
for (const auto& kernel : kernels_) {
if (kernel.signature->MatchesInputs(values)) {
Expand All @@ -181,25 +161,28 @@ Result<const Kernel*> CastFunction::DispatchExact(

if (candidate_kernels.size() == 0) {
return Status::NotImplemented("Unsupported cast from ", values[0].type->ToString(),
" to ", ToTypeName(impl_->out_type), " using function ",
" to ", ToTypeName(out_type_id_), " using function ",
this->name());
} else if (candidate_kernels.size() == 1) {
}

if (candidate_kernels.size() == 1) {
// One match, return it
return candidate_kernels[0];
} else {
// Now we are in a casting scenario where we may have both a EXACT_TYPE and
// a SAME_TYPE_ID. So we will see if there is an exact match among the
// candidate kernels and if not we will just return the first one
for (auto kernel : candidate_kernels) {
const InputType& arg0 = kernel->signature->in_types()[0];
if (arg0.kind() == InputType::EXACT_TYPE) {
// Bingo. Return it
return kernel;
}
}

// Now we are in a casting scenario where we may have both a EXACT_TYPE and
// a SAME_TYPE_ID. So we will see if there is an exact match among the
// candidate kernels and if not we will just return the first one
for (auto kernel : candidate_kernels) {
const InputType& arg0 = kernel->signature->in_types()[0];
if (arg0.kind() == InputType::EXACT_TYPE) {
// Bingo. Return it
return kernel;
}
// We didn't find an exact match. So just return some kernel that matches
return candidate_kernels[0];
}

// We didn't find an exact match. So just return some kernel that matches
return candidate_kernels[0];
}

Result<Datum> Cast(const Datum& value, const CastOptions& options, ExecContext* ctx) {
Expand All @@ -225,13 +208,37 @@ Result<std::shared_ptr<CastFunction>> GetCastFunction(
}

bool CanCast(const DataType& from_type, const DataType& to_type) {
// TODO
internal::EnsureInitCastTable();
auto it = internal::g_cast_table.find(static_cast<int>(from_type.id()));
auto it = internal::g_cast_table.find(static_cast<int>(to_type.id()));
if (it == internal::g_cast_table.end()) {
return false;
}
return it->second->CanCastTo(to_type);

const CastFunction* function = it->second.get();
DCHECK_EQ(function->out_type_id(), to_type.id());

for (auto from_id : function->in_type_ids()) {
// XXX should probably check the output type as well
if (from_type.id() == from_id) return true;
}

return false;
}

Result<std::vector<Datum>> Cast(std::vector<Datum> datums, std::vector<ValueDescr> descrs,
ExecContext* ctx) {
for (size_t i = 0; i != datums.size(); ++i) {
if (descrs[i] != datums[i].descr()) {
if (descrs[i].shape != datums[i].shape()) {
return Status::NotImplemented("casting between Datum shapes");
}

ARROW_ASSIGN_OR_RAISE(datums[i],
Cast(datums[i], CastOptions::Safe(descrs[i].type), ctx));
}
}

return datums;
}

} // namespace compute
Expand Down
24 changes: 17 additions & 7 deletions cpp/src/arrow/compute/cast.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,10 @@ struct ARROW_EXPORT CastOptions : public FunctionOptions {
// the same execution machinery
class CastFunction : public ScalarFunction {
public:
CastFunction(std::string name, Type::type out_type);
~CastFunction() override;
CastFunction(std::string name, Type::type out_type_id);

Type::type out_type_id() const;
Type::type out_type_id() const { return out_type_id_; }
const std::vector<Type::type>& in_type_ids() const { return in_type_ids_; }

Status AddKernel(Type::type in_type_id, std::vector<InputType> in_types,
OutputType out_type, ArrayKernelExec exec,
Expand All @@ -96,14 +96,12 @@ class CastFunction : public ScalarFunction {
// function to CastInit
Status AddKernel(Type::type in_type_id, ScalarKernel kernel);

bool CanCastTo(const DataType& out_type) const;

Result<const Kernel*> DispatchExact(
const std::vector<ValueDescr>& values) const override;

private:
struct CastFunctionImpl;
std::unique_ptr<CastFunctionImpl> impl_;
std::vector<Type::type> in_type_ids_;
const Type::type out_type_id_;
};

ARROW_EXPORT
Expand Down Expand Up @@ -157,5 +155,17 @@ Result<Datum> Cast(const Datum& value, std::shared_ptr<DataType> to_type,
const CastOptions& options = CastOptions::Safe(),
ExecContext* ctx = NULLPTR);

/// \brief Cast several values simultaneously. Safe cast options are used.
/// \param[in] values datums to cast
/// \param[in] descrs ValueDescrs to cast to
/// \param[in] ctx the function execution context, optional
/// \return the resulting datums
///
/// \since 4.0.0
/// \note API not yet finalized
ARROW_EXPORT
Result<std::vector<Datum>> Cast(std::vector<Datum> values, std::vector<ValueDescr> descrs,
ExecContext* ctx = NULLPTR);

} // namespace compute
} // namespace arrow
Loading