Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 33 additions & 25 deletions cpp/src/arrow/compute/kernels/scalar_cast_nested.cc
Original file line number Diff line number Diff line change
Expand Up @@ -153,27 +153,32 @@ void AddListCast(CastFunction* func) {
struct CastStruct {
static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
const CastOptions& options = CastState::Get(ctx);
const StructType& in_type = checked_cast<const StructType&>(*batch[0].type());
const StructType& out_type = checked_cast<const StructType&>(*out->type());
const auto in_field_count = in_type.num_fields();

if (in_field_count != out_type.num_fields()) {
return Status::TypeError("struct field sizes do not match: ", in_type.ToString(),
" ", out_type.ToString());
}

for (int i = 0; i < in_field_count; ++i) {
const auto in_field = in_type.field(i);
const auto out_field = out_type.field(i);
if (in_field->name() != out_field->name()) {
return Status::TypeError("struct field names do not match: ", in_type.ToString(),
" ", out_type.ToString());
const auto& in_type = checked_cast<const StructType&>(*batch[0].type());
const auto& out_type = checked_cast<const StructType&>(*out->type());
const int in_field_count = in_type.num_fields();
const int out_field_count = out_type.num_fields();

std::vector<int> fields_to_select(out_field_count, -1);

int out_field_index = 0;
for (int in_field_index = 0;
in_field_index < in_field_count && out_field_index < out_field_count;
++in_field_index) {
const auto& in_field = in_type.field(in_field_index);
const auto& out_field = out_type.field(out_field_index);
if (in_field->name() == out_field->name()) {
if (in_field->nullable() && !out_field->nullable()) {
return Status::TypeError("cannot cast nullable field to non-nullable field: ",
in_type.ToString(), " ", out_type.ToString());
}
fields_to_select[out_field_index++] = in_field_index;
}
}

if (in_field->nullable() && !out_field->nullable()) {
return Status::TypeError("cannot cast nullable struct to non-nullable struct: ",
in_type.ToString(), " ", out_type.ToString());
}
if (out_field_index < out_field_count) {
return Status::TypeError(
"struct fields don't match or are in the wrong order: Input fields: ",
in_type.ToString(), " output fields: ", out_type.ToString());
}

if (out->kind() == Datum::SCALAR) {
Expand All @@ -182,9 +187,10 @@ struct CastStruct {

DCHECK(!out_scalar->is_valid);
if (in_scalar.is_valid) {
for (int i = 0; i < in_field_count; i++) {
auto values = in_scalar.value[i];
auto target_type = out->type()->field(i)->type();
out_field_index = 0;
for (int field_index : fields_to_select) {
const auto& values = in_scalar.value[field_index];
const auto& target_type = out->type()->field(out_field_index++)->type();
ARROW_ASSIGN_OR_RAISE(Datum cast_values,
Cast(values, target_type, options, ctx->exec_context()));
DCHECK_EQ(Datum::SCALAR, cast_values.kind());
Expand All @@ -204,9 +210,11 @@ struct CastStruct {
in_array.offset, in_array.length));
}

for (int i = 0; i < in_field_count; ++i) {
auto values = in_array.child_data[i]->Slice(in_array.offset, in_array.length);
auto target_type = out->type()->field(i)->type();
out_field_index = 0;
for (int field_index : fields_to_select) {
const auto& values =
in_array.child_data[field_index]->Slice(in_array.offset, in_array.length);
const auto& target_type = out->type()->field(out_field_index++)->type();

ARROW_ASSIGN_OR_RAISE(Datum cast_values,
Cast(values, target_type, options, ctx->exec_context()));
Expand Down
Loading