Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cpp/src/arrow/array/array_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -544,10 +544,12 @@ static ScalarVector GetScalars() {
sparse_union_ty),
std::make_shared<SparseUnionScalar>(std::make_shared<Int32Scalar>(100), 42,
sparse_union_ty),
std::make_shared<SparseUnionScalar>(42, sparse_union_ty),
std::make_shared<DenseUnionScalar>(std::make_shared<Int32Scalar>(101), 6,
dense_union_ty),
std::make_shared<DenseUnionScalar>(std::make_shared<Int32Scalar>(101), 42,
dense_union_ty),
std::make_shared<DenseUnionScalar>(42, dense_union_ty),
DictionaryScalar::Make(ScalarFromJSON(int8(), "1"),
ArrayFromJSON(utf8(), R"(["foo", "bar"])")),
DictionaryScalar::Make(ScalarFromJSON(uint8(), "1"),
Expand Down
3 changes: 2 additions & 1 deletion cpp/src/arrow/array/util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -775,7 +775,8 @@ Result<std::shared_ptr<Array>> MakeArrayOfNull(const std::shared_ptr<DataType>&

Result<std::shared_ptr<Array>> MakeArrayFromScalar(const Scalar& scalar, int64_t length,
MemoryPool* pool) {
if (!scalar.is_valid) {
// Null union scalars still have a type code associated
if (!scalar.is_valid && !is_union(scalar.type->id())) {
return MakeArrayOfNull(scalar.type, length, pool);
}
return RepeatedArrayFactory(pool, scalar, length).Create();
Expand Down
50 changes: 33 additions & 17 deletions cpp/src/arrow/compute/kernels/codegen_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,29 +66,45 @@ Result<ValueDescr> FirstType(KernelContext*, const std::vector<ValueDescr>& desc
return result;
}

// Reusable type resolver: the output descriptor is a copy of the final
// argument's descriptor whose shape is overwritten with the value returned by
// GetBroadcastShape() for the full argument list.
//
// `descrs` must be non-empty, since back() is called unconditionally.
Result<ValueDescr> LastType(KernelContext*, const std::vector<ValueDescr>& descrs) {
  ValueDescr last_descr = descrs.back();
  last_descr.shape = GetBroadcastShape(descrs);
  return last_descr;
}

// Reusable type resolver: treats the first argument's type as a BaseListType
// and yields its value type, paired with the shape returned by
// GetBroadcastShape() for the full argument list.
//
// `args` must hold at least one element, since args[0] is read unconditionally.
Result<ValueDescr> ListValuesType(KernelContext*, const std::vector<ValueDescr>& args) {
  const auto& first_arg_type = *args[0].type;
  const auto& as_list = checked_cast<const BaseListType&>(first_arg_type);
  return ValueDescr(as_list.value_type(), GetBroadcastShape(args));
}

void EnsureDictionaryDecoded(std::vector<ValueDescr>* descrs) {
for (ValueDescr& descr : *descrs) {
if (descr.type->id() == Type::DICTIONARY) {
descr.type = checked_cast<const DictionaryType&>(*descr.type).value_type();
EnsureDictionaryDecoded(descrs->data(), descrs->size());
}

// In-place rewrite of `count` descriptors starting at `begin`: every
// descriptor whose type id is Type::DICTIONARY has its type replaced by the
// dictionary type's value_type(). Descriptors of any other type are left
// untouched.
void EnsureDictionaryDecoded(ValueDescr* begin, size_t count) {
  for (size_t i = 0; i < count; ++i) {
    ValueDescr& descr = begin[i];
    if (descr.type->id() != Type::DICTIONARY) {
      continue;
    }
    descr.type = checked_cast<const DictionaryType&>(*descr.type).value_type();
  }
}

void ReplaceNullWithOtherType(std::vector<ValueDescr>* descrs) {
DCHECK_EQ(descrs->size(), 2);
ReplaceNullWithOtherType(descrs->data(), descrs->size());
}

void ReplaceNullWithOtherType(ValueDescr* first, size_t count) {
DCHECK_EQ(count, 2);

if (descrs->at(0).type->id() == Type::NA) {
descrs->at(0).type = descrs->at(1).type;
ValueDescr* second = first++;
if (first->type->id() == Type::NA) {
first->type = second->type;
return;
}

if (descrs->at(1).type->id() == Type::NA) {
descrs->at(1).type = descrs->at(0).type;
if (second->type->id() == Type::NA) {
second->type = first->type;
return;
}
}
Expand Down Expand Up @@ -164,14 +180,15 @@ std::shared_ptr<DataType> CommonNumeric(const ValueDescr* begin, size_t count) {
return int8();
}

std::shared_ptr<DataType> CommonTemporal(const std::vector<ValueDescr>& descrs) {
std::shared_ptr<DataType> CommonTemporal(const ValueDescr* begin, size_t count) {
TimeUnit::type finest_unit = TimeUnit::SECOND;
const std::string* timezone = nullptr;
bool saw_date32 = false;
bool saw_date64 = false;

for (const auto& descr : descrs) {
auto id = descr.type->id();
const ValueDescr* end = begin + count;
for (auto it = begin; it != end; it++) {
auto id = it->type->id();
// a common timestamp is only possible if all types are timestamp like
switch (id) {
case Type::DATE32:
Expand All @@ -183,9 +200,7 @@ std::shared_ptr<DataType> CommonTemporal(const std::vector<ValueDescr>& descrs)
saw_date64 = true;
continue;
case Type::TIMESTAMP: {
const auto& ty = checked_cast<const TimestampType&>(*descr.type);
// Don't cast to common timezone by default (may not make
// sense for all kernels)
const auto& ty = checked_cast<const TimestampType&>(*it->type);
if (timezone && *timezone != ty.timezone()) return nullptr;
timezone = &ty.timezone();
finest_unit = std::max(finest_unit, ty.unit());
Expand All @@ -207,11 +222,12 @@ std::shared_ptr<DataType> CommonTemporal(const std::vector<ValueDescr>& descrs)
return nullptr;
}

std::shared_ptr<DataType> CommonBinary(const std::vector<ValueDescr>& descrs) {
std::shared_ptr<DataType> CommonBinary(const ValueDescr* begin, size_t count) {
bool all_utf8 = true, all_offset32 = true;

for (const auto& descr : descrs) {
auto id = descr.type->id();
const ValueDescr* end = begin + count;
for (auto it = begin; it != end; ++it) {
auto id = it->type->id();
// a common varbinary type is only possible if all types are binary like
switch (id) {
case Type::STRING:
Expand Down
15 changes: 11 additions & 4 deletions cpp/src/arrow/compute/kernels/codegen_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -296,14 +296,14 @@ struct UnboxScalar<Type, enable_if_has_string_view<Type>> {

template <>
struct UnboxScalar<Decimal128Type> {
static Decimal128 Unbox(const Scalar& val) {
static const Decimal128& Unbox(const Scalar& val) {
return checked_cast<const Decimal128Scalar&>(val).value;
}
};

template <>
struct UnboxScalar<Decimal256Type> {
static Decimal256 Unbox(const Scalar& val) {
static const Decimal256& Unbox(const Scalar& val) {
return checked_cast<const Decimal256Scalar&>(val).value;
}
};
Expand Down Expand Up @@ -397,6 +397,7 @@ static void VisitTwoArrayValuesInline(const ArrayData& arr0, const ArrayData& ar
// Reusable type resolvers

Result<ValueDescr> FirstType(KernelContext*, const std::vector<ValueDescr>& descrs);
Result<ValueDescr> LastType(KernelContext*, const std::vector<ValueDescr>& descrs);
Result<ValueDescr> ListValuesType(KernelContext*, const std::vector<ValueDescr>& args);

// ----------------------------------------------------------------------
Expand Down Expand Up @@ -1279,9 +1280,15 @@ ArrayKernelExec GenerateDecimal(detail::GetTypeId get_id) {
ARROW_EXPORT
void EnsureDictionaryDecoded(std::vector<ValueDescr>* descrs);

ARROW_EXPORT
void EnsureDictionaryDecoded(ValueDescr* begin, size_t count);

ARROW_EXPORT
void ReplaceNullWithOtherType(std::vector<ValueDescr>* descrs);

ARROW_EXPORT
void ReplaceNullWithOtherType(ValueDescr* begin, size_t count);

ARROW_EXPORT
void ReplaceTypes(const std::shared_ptr<DataType>&, std::vector<ValueDescr>* descrs);

Expand All @@ -1295,10 +1302,10 @@ ARROW_EXPORT
std::shared_ptr<DataType> CommonNumeric(const ValueDescr* begin, size_t count);

ARROW_EXPORT
std::shared_ptr<DataType> CommonTemporal(const std::vector<ValueDescr>& descrs);
std::shared_ptr<DataType> CommonTemporal(const ValueDescr* begin, size_t count);

ARROW_EXPORT
std::shared_ptr<DataType> CommonBinary(const std::vector<ValueDescr>& descrs);
std::shared_ptr<DataType> CommonBinary(const ValueDescr* begin, size_t count);

/// How to promote decimal precision/scale in CastBinaryDecimalArgs.
enum class DecimalPromotion : uint8_t {
Expand Down
42 changes: 25 additions & 17 deletions cpp/src/arrow/compute/kernels/codegen_internal_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -130,24 +130,32 @@ TEST(TestDispatchBest, CastDecimalArgs) {
}

TEST(TestDispatchBest, CommonTemporal) {
AssertTypeEqual(timestamp(TimeUnit::NANO), CommonTemporal({timestamp(TimeUnit::SECOND),
timestamp(TimeUnit::NANO)}));
std::vector<ValueDescr> args;

args = {timestamp(TimeUnit::SECOND), timestamp(TimeUnit::NANO)};
AssertTypeEqual(timestamp(TimeUnit::NANO), CommonTemporal(args.data(), args.size()));
args = {timestamp(TimeUnit::SECOND, "UTC"), timestamp(TimeUnit::NANO, "UTC")};
AssertTypeEqual(timestamp(TimeUnit::NANO, "UTC"),
CommonTemporal({timestamp(TimeUnit::SECOND, "UTC"),
timestamp(TimeUnit::NANO, "UTC")}));
AssertTypeEqual(timestamp(TimeUnit::NANO),
CommonTemporal({date32(), timestamp(TimeUnit::NANO)}));
AssertTypeEqual(timestamp(TimeUnit::MILLI),
CommonTemporal({date64(), timestamp(TimeUnit::SECOND)}));
AssertTypeEqual(date32(), CommonTemporal({date32(), date32()}));
AssertTypeEqual(date64(), CommonTemporal({date64(), date64()}));
AssertTypeEqual(date64(), CommonTemporal({date32(), date64()}));
ASSERT_EQ(nullptr, CommonTemporal({}));
ASSERT_EQ(nullptr, CommonTemporal({float64(), int32()}));
ASSERT_EQ(nullptr, CommonTemporal({timestamp(TimeUnit::SECOND),
timestamp(TimeUnit::SECOND, "UTC")}));
ASSERT_EQ(nullptr, CommonTemporal({timestamp(TimeUnit::SECOND, "America/Phoenix"),
timestamp(TimeUnit::SECOND, "UTC")}));
CommonTemporal(args.data(), args.size()));
args = {date32(), timestamp(TimeUnit::NANO)};
AssertTypeEqual(timestamp(TimeUnit::NANO), CommonTemporal(args.data(), args.size()));
args = {date64(), timestamp(TimeUnit::SECOND)};
AssertTypeEqual(timestamp(TimeUnit::MILLI), CommonTemporal(args.data(), args.size()));
args = {date32(), date32()};
AssertTypeEqual(date32(), CommonTemporal(args.data(), args.size()));
args = {date64(), date64()};
AssertTypeEqual(date64(), CommonTemporal(args.data(), args.size()));
args = {date32(), date64()};
AssertTypeEqual(date64(), CommonTemporal(args.data(), args.size()));
args = {};
ASSERT_EQ(nullptr, CommonTemporal(args.data(), args.size()));
args = {float64(), int32()};
ASSERT_EQ(nullptr, CommonTemporal(args.data(), args.size()));
args = {timestamp(TimeUnit::SECOND), timestamp(TimeUnit::SECOND, "UTC")};
ASSERT_EQ(nullptr, CommonTemporal(args.data(), args.size()));
args = {timestamp(TimeUnit::SECOND, "America/Phoenix"),
timestamp(TimeUnit::SECOND, "UTC")};
ASSERT_EQ(nullptr, CommonTemporal(args.data(), args.size()));
}

} // namespace internal
Expand Down
6 changes: 3 additions & 3 deletions cpp/src/arrow/compute/kernels/scalar_compare.cc
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,9 @@ struct CompareFunction : ScalarFunction {

if (auto type = CommonNumeric(*values)) {
ReplaceTypes(type, values);
} else if (auto type = CommonTemporal(*values)) {
} else if (auto type = CommonTemporal(values->data(), values->size())) {
ReplaceTypes(type, values);
} else if (auto type = CommonBinary(*values)) {
} else if (auto type = CommonBinary(values->data(), values->size())) {
ReplaceTypes(type, values);
}

Expand All @@ -195,7 +195,7 @@ struct VarArgsCompareFunction : ScalarFunction {

if (auto type = CommonNumeric(*values)) {
ReplaceTypes(type, values);
} else if (auto type = CommonTemporal(*values)) {
} else if (auto type = CommonTemporal(values->data(), values->size())) {
ReplaceTypes(type, values);
}

Expand Down
Loading