Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 13 additions & 12 deletions cpp/src/arrow/compute/api_scalar.cc
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,10 @@ static auto kSliceOptionsType = GetFunctionOptionsType<SliceOptions>(
DataMember("step", &SliceOptions::step));
static auto kCompareOptionsType =
GetFunctionOptionsType<CompareOptions>(DataMember("op", &CompareOptions::op));
static auto kProjectOptionsType = GetFunctionOptionsType<ProjectOptions>(
DataMember("field_names", &ProjectOptions::field_names),
DataMember("field_nullability", &ProjectOptions::field_nullability),
DataMember("field_metadata", &ProjectOptions::field_metadata));
static auto kMakeStructOptionsType = GetFunctionOptionsType<MakeStructOptions>(
DataMember("field_names", &MakeStructOptions::field_names),
DataMember("field_nullability", &MakeStructOptions::field_nullability),
DataMember("field_metadata", &MakeStructOptions::field_metadata));
static auto kDayOfWeekOptionsType = GetFunctionOptionsType<DayOfWeekOptions>(
DataMember("one_based_numbering", &DayOfWeekOptions::one_based_numbering),
DataMember("week_start", &DayOfWeekOptions::week_start));
Expand Down Expand Up @@ -265,21 +265,22 @@ CompareOptions::CompareOptions(CompareOperator op)
CompareOptions::CompareOptions() : CompareOptions(CompareOperator::EQUAL) {}
constexpr char CompareOptions::kTypeName[];

ProjectOptions::ProjectOptions(std::vector<std::string> n, std::vector<bool> r,
std::vector<std::shared_ptr<const KeyValueMetadata>> m)
: FunctionOptions(internal::kProjectOptionsType),
MakeStructOptions::MakeStructOptions(
std::vector<std::string> n, std::vector<bool> r,
std::vector<std::shared_ptr<const KeyValueMetadata>> m)
: FunctionOptions(internal::kMakeStructOptionsType),
field_names(std::move(n)),
field_nullability(std::move(r)),
field_metadata(std::move(m)) {}

ProjectOptions::ProjectOptions(std::vector<std::string> n)
: FunctionOptions(internal::kProjectOptionsType),
MakeStructOptions::MakeStructOptions(std::vector<std::string> n)
: FunctionOptions(internal::kMakeStructOptionsType),
field_names(std::move(n)),
field_nullability(field_names.size(), true),
field_metadata(field_names.size(), NULLPTR) {}

ProjectOptions::ProjectOptions() : ProjectOptions(std::vector<std::string>()) {}
constexpr char ProjectOptions::kTypeName[];
MakeStructOptions::MakeStructOptions() : MakeStructOptions(std::vector<std::string>()) {}
constexpr char MakeStructOptions::kTypeName[];

DayOfWeekOptions::DayOfWeekOptions(bool one_based_numbering, uint32_t week_start)
: FunctionOptions(internal::kDayOfWeekOptionsType),
Expand All @@ -304,7 +305,7 @@ void RegisterScalarOptions(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunctionOptionsType(kTrimOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kSliceOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kCompareOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kProjectOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kMakeStructOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kDayOfWeekOptionsType));
}
} // namespace internal
Expand Down
12 changes: 6 additions & 6 deletions cpp/src/arrow/compute/api_scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -226,13 +226,13 @@ class ARROW_EXPORT CompareOptions : public FunctionOptions {
enum CompareOperator op;
};

class ARROW_EXPORT ProjectOptions : public FunctionOptions {
class ARROW_EXPORT MakeStructOptions : public FunctionOptions {
public:
ProjectOptions(std::vector<std::string> n, std::vector<bool> r,
std::vector<std::shared_ptr<const KeyValueMetadata>> m);
explicit ProjectOptions(std::vector<std::string> n);
ProjectOptions();
constexpr static char const kTypeName[] = "ProjectOptions";
MakeStructOptions(std::vector<std::string> n, std::vector<bool> r,
std::vector<std::shared_ptr<const KeyValueMetadata>> m);
explicit MakeStructOptions(std::vector<std::string> n);
MakeStructOptions();
constexpr static char const kTypeName[] = "MakeStructOptions";

/// Names for wrapped columns
std::vector<std::string> field_names;
Expand Down
5 changes: 3 additions & 2 deletions cpp/src/arrow/compute/exec/expression.cc
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ std::string Expression::ToString() const {
return binary(std::move(op));
}

if (auto options = GetProjectOptions(*call)) {
if (auto options = GetMakeStructOptions(*call)) {
std::string out = "{";
auto argument = call->arguments.begin();
for (const auto& field_name : options->field_names) {
Expand Down Expand Up @@ -1122,7 +1122,8 @@ Result<Expression> Deserialize(std::shared_ptr<Buffer> buffer) {
}

Expression project(std::vector<Expression> values, std::vector<std::string> names) {
return call("project", std::move(values), compute::ProjectOptions{std::move(names)});
return call("make_struct", std::move(values),
compute::MakeStructOptions{std::move(names)});
}

Expression equal(Expression lhs, Expression rhs) {
Expand Down
7 changes: 4 additions & 3 deletions cpp/src/arrow/compute/exec/expression_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,10 @@ inline bool IsSetLookup(const std::string& function) {
return function == "is_in" || function == "index_in";
}

inline const compute::ProjectOptions* GetProjectOptions(const Expression::Call& call) {
if (call.function_name != "project") return nullptr;
return checked_cast<const compute::ProjectOptions*>(call.options.get());
inline const compute::MakeStructOptions* GetMakeStructOptions(
const Expression::Call& call) {
if (call.function_name != "make_struct") return nullptr;
return checked_cast<const compute::MakeStructOptions*>(call.options.get());
}

/// A helper for unboxing an Expression composed of associative function calls.
Expand Down
6 changes: 3 additions & 3 deletions cpp/src/arrow/compute/function_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,10 @@ TEST(FunctionOptions, Equality) {
options.emplace_back(new CompareOptions(CompareOperator::EQUAL));
options.emplace_back(new CompareOptions(CompareOperator::LESS));
// N.B. we never actually use field_nullability or field_metadata in Arrow
options.emplace_back(new ProjectOptions({"col1"}, {true}, {}));
options.emplace_back(new ProjectOptions({"col1"}, {false}, {}));
options.emplace_back(new MakeStructOptions({"col1"}, {true}, {}));
options.emplace_back(new MakeStructOptions({"col1"}, {false}, {}));
options.emplace_back(
new ProjectOptions({"col1"}, {false}, {key_value_metadata({{"key", "val"}})}));
new MakeStructOptions({"col1"}, {false}, {key_value_metadata({{"key", "val"}})}));
options.emplace_back(new DayOfWeekOptions(false, 1));
options.emplace_back(new CastOptions(CastOptions::Safe(boolean())));
options.emplace_back(new CastOptions(CastOptions::Unsafe(int64())));
Expand Down
8 changes: 1 addition & 7 deletions cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -324,13 +324,7 @@ class TestCaseWhenNumeric : public ::testing::Test {};
TYPED_TEST_SUITE(TestCaseWhenNumeric, NumericBasedTypes);

Datum MakeStruct(const std::vector<Datum>& conds) {
ProjectOptions options;
options.field_names.resize(conds.size());
options.field_metadata.resize(conds.size());
for (const auto& datum : conds) {
options.field_nullability.push_back(datum.null_count() > 0);
}
EXPECT_OK_AND_ASSIGN(auto result, CallFunction("project", conds, &options));
EXPECT_OK_AND_ASSIGN(auto result, CallFunction("make_struct", conds));
return result;
}

Expand Down
57 changes: 34 additions & 23 deletions cpp/src/arrow/compute/kernels/scalar_nested.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,23 @@ const FunctionDoc list_value_length_doc{
"Null values emit a null in the output."),
{"lists"}};

Result<ValueDescr> ProjectResolve(KernelContext* ctx,
const std::vector<ValueDescr>& descrs) {
const auto& names = OptionsWrapper<ProjectOptions>::Get(ctx).field_names;
const auto& nullable = OptionsWrapper<ProjectOptions>::Get(ctx).field_nullability;
const auto& metadata = OptionsWrapper<ProjectOptions>::Get(ctx).field_metadata;

if (names.size() != descrs.size() || nullable.size() != descrs.size() ||
metadata.size() != descrs.size()) {
return Status::Invalid("project() was passed ", descrs.size(), " arguments but ",
Result<ValueDescr> MakeStructResolve(KernelContext* ctx,
const std::vector<ValueDescr>& descrs) {
auto names = OptionsWrapper<MakeStructOptions>::Get(ctx).field_names;
auto nullable = OptionsWrapper<MakeStructOptions>::Get(ctx).field_nullability;
auto metadata = OptionsWrapper<MakeStructOptions>::Get(ctx).field_metadata;

if (names.size() == 0) {
names.resize(descrs.size());
nullable.resize(descrs.size(), true);
metadata.resize(descrs.size(), nullptr);
int i = 0;
for (auto& name : names) {
name = std::to_string(i++);
}
} else if (names.size() != descrs.size() || nullable.size() != descrs.size() ||
metadata.size() != descrs.size()) {
return Status::Invalid("make_struct() was passed ", descrs.size(), " arguments but ",
names.size(), " field names, ", nullable.size(),
" nullability bits, and ", metadata.size(),
" metadata dictionaries.");
Expand All @@ -94,15 +102,16 @@ Result<ValueDescr> ProjectResolve(KernelContext* ctx,
}
}

fields[i] = field(names[i], descr.type, nullable[i], metadata[i]);
fields[i] =
field(std::move(names[i]), descr.type, nullable[i], std::move(metadata[i]));
++i;
}

return ValueDescr{struct_(std::move(fields)), shape};
}

Status ProjectExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
ARROW_ASSIGN_OR_RAISE(auto descr, ProjectResolve(ctx, batch.GetDescriptors()));
Status MakeStructExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
ARROW_ASSIGN_OR_RAISE(auto descr, MakeStructResolve(ctx, batch.GetDescriptors()));

for (int i = 0; i < batch.num_values(); ++i) {
const auto& field = checked_cast<const StructType&>(*descr.type).field(i);
Expand Down Expand Up @@ -139,11 +148,11 @@ Status ProjectExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
return Status::OK();
}

const FunctionDoc project_doc{"Wrap Arrays into a StructArray",
("Names of the StructArray's fields are\n"
"specified through ProjectOptions."),
{"*args"},
"ProjectOptions"};
const FunctionDoc make_struct_doc{"Wrap Arrays into a StructArray",
("Names of the StructArray's fields are\n"
"specified through MakeStructOptions."),
{"*args"},
"MakeStructOptions"};

} // namespace

Expand All @@ -156,15 +165,17 @@ void RegisterScalarNested(FunctionRegistry* registry) {
ListValueLength<LargeListType>));
DCHECK_OK(registry->AddFunction(std::move(list_value_length)));

auto project_function =
std::make_shared<ScalarFunction>("project", Arity::VarArgs(), &project_doc);
ScalarKernel kernel{KernelSignature::Make({InputType{}}, OutputType{ProjectResolve},
static MakeStructOptions kDefaultMakeStructOptions;
auto make_struct_function = std::make_shared<ScalarFunction>(
"make_struct", Arity::VarArgs(), &make_struct_doc, &kDefaultMakeStructOptions);

ScalarKernel kernel{KernelSignature::Make({InputType{}}, OutputType{MakeStructResolve},
/*is_varargs=*/true),
ProjectExec, OptionsWrapper<ProjectOptions>::Init};
MakeStructExec, OptionsWrapper<MakeStructOptions>::Init};
kernel.null_handling = NullHandling::OUTPUT_NOT_NULL;
kernel.mem_allocation = MemAllocation::NO_PREALLOCATE;
DCHECK_OK(project_function->AddKernel(std::move(kernel)));
DCHECK_OK(registry->AddFunction(std::move(project_function)));
DCHECK_OK(make_struct_function->AddKernel(std::move(kernel)));
DCHECK_OK(registry->AddFunction(std::move(make_struct_function)));
}

} // namespace internal
Expand Down
65 changes: 37 additions & 28 deletions cpp/src/arrow/compute/kernels/scalar_nested_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "arrow/compute/kernels/test_util.h"
#include "arrow/result.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/testing/matchers.h"
#include "arrow/util/key_value_metadata.h"

namespace arrow {
Expand All @@ -39,48 +40,55 @@ TEST(TestScalarNested, ListValueLength) {
}

struct {
Result<Datum> operator()(std::vector<Datum> args) {
return CallFunction("make_struct", args);
}

template <typename... Options>
Result<Datum> operator()(std::vector<Datum> args, std::vector<std::string> field_names,
Options... options) {
ProjectOptions opts{field_names, options...};
return CallFunction("project", args, &opts);
MakeStructOptions opts{field_names, options...};
return CallFunction("make_struct", args, &opts);
}
} Project;
} MakeStruct;

TEST(Project, Scalar) {
TEST(MakeStruct, Scalar) {
auto i32 = MakeScalar(1);
auto f64 = MakeScalar(2.5);
auto str = MakeScalar("yo");

ASSERT_OK_AND_ASSIGN(auto expected,
StructScalar::Make({i32, f64, str}, {"i", "f", "s"}));
ASSERT_OK_AND_EQ(Datum(expected), Project({i32, f64, str}, {"i", "f", "s"}));
EXPECT_THAT(MakeStruct({i32, f64, str}, {"i", "f", "s"}),
ResultWith(Datum(*StructScalar::Make({i32, f64, str}, {"i", "f", "s"}))));

// Three field names but one input value
ASSERT_RAISES(Invalid, Project({str}, {"i", "f", "s"}));
// Names default to field_index
EXPECT_THAT(MakeStruct({i32, f64, str}),
ResultWith(Datum(*StructScalar::Make({i32, f64, str}, {"0", "1", "2"}))));

// No field names or input values is fine
expected.reset(new StructScalar{{}, struct_({})});
ASSERT_OK_AND_EQ(Datum(expected), Project(/*args=*/{}, /*field_names=*/{}));
EXPECT_THAT(MakeStruct({}), ResultWith(Datum(*StructScalar::Make({}, {}))));

// Three field names but one input value
EXPECT_THAT(MakeStruct({str}, {"i", "f", "s"}), Raises(StatusCode::Invalid));
}

TEST(Project, Array) {
TEST(MakeStruct, Array) {
std::vector<std::string> field_names{"i", "s"};

auto i32 = ArrayFromJSON(int32(), "[42, 13, 7]");
auto str = ArrayFromJSON(utf8(), R"(["aa", "aa", "aa"])");
ASSERT_OK_AND_ASSIGN(Datum expected, StructArray::Make({i32, str}, field_names));

ASSERT_OK_AND_EQ(expected, Project({i32, str}, field_names));
EXPECT_THAT(MakeStruct({i32, str}, {"i", "s"}),
ResultWith(Datum(*StructArray::Make({i32, str}, field_names))));

// Scalars are broadcast to the length of the arrays
ASSERT_OK_AND_EQ(expected, Project({i32, MakeScalar("aa")}, field_names));
EXPECT_THAT(MakeStruct({i32, MakeScalar("aa")}, {"i", "s"}),
ResultWith(Datum(*StructArray::Make({i32, str}, field_names))));

// Array length mismatch
ASSERT_RAISES(Invalid, Project({i32->Slice(1), str}, field_names));
EXPECT_THAT(MakeStruct({i32->Slice(1), str}, field_names), Raises(StatusCode::Invalid));
}

TEST(Project, NullableMetadataPassedThru) {
TEST(MakeStruct, NullableMetadataPassedThru) {
auto i32 = ArrayFromJSON(int32(), "[42, 13, 7]");
auto str = ArrayFromJSON(utf8(), R"(["aa", "aa", "aa"])");

Expand All @@ -90,19 +98,20 @@ TEST(Project, NullableMetadataPassedThru) {
key_value_metadata({"a", "b"}, {"ALPHA", "BRAVO"}), nullptr};

ASSERT_OK_AND_ASSIGN(auto proj,
Project({i32, str}, field_names, nullability, metadata));
MakeStruct({i32, str}, field_names, nullability, metadata));

AssertTypeEqual(*proj.type(), StructType({
field("i", int32(), /*nullable=*/true, metadata[0]),
field("s", utf8(), /*nullable=*/false, nullptr),
}));

// error: projecting an array containing nulls with nullable=false
str = ArrayFromJSON(utf8(), R"(["aa", null, "aa"])");
ASSERT_RAISES(Invalid, Project({i32, str}, field_names, nullability, metadata));
EXPECT_THAT(MakeStruct({i32, ArrayFromJSON(utf8(), R"(["aa", null, "aa"])")},
field_names, nullability, metadata),
Raises(StatusCode::Invalid));
}

TEST(Project, ChunkedArray) {
TEST(MakeStruct, ChunkedArray) {
std::vector<std::string> field_names{"i", "s"};

auto i32_0 = ArrayFromJSON(int32(), "[42, 13, 7]");
Expand All @@ -122,16 +131,16 @@ TEST(Project, ChunkedArray) {
ASSERT_OK_AND_ASSIGN(Datum expected,
ChunkedArray::Make({expected_0, expected_1, expected_2}));

ASSERT_OK_AND_EQ(expected, Project({i32, str}, field_names));
ASSERT_OK_AND_EQ(expected, MakeStruct({i32, str}, field_names));

// Scalars are broadcast to the length of the arrays
ASSERT_OK_AND_EQ(expected, Project({i32, MakeScalar("aa")}, field_names));
ASSERT_OK_AND_EQ(expected, MakeStruct({i32, MakeScalar("aa")}, field_names));

// Array length mismatch
ASSERT_RAISES(Invalid, Project({i32->Slice(1), str}, field_names));
ASSERT_RAISES(Invalid, MakeStruct({i32->Slice(1), str}, field_names));
}

TEST(Project, ChunkedArrayDifferentChunking) {
TEST(MakeStruct, ChunkedArrayDifferentChunking) {
std::vector<std::string> field_names{"i", "s"};

auto i32_0 = ArrayFromJSON(int32(), "[42, 13, 7]");
Expand Down Expand Up @@ -159,13 +168,13 @@ TEST(Project, ChunkedArrayDifferentChunking) {

ASSERT_OK_AND_ASSIGN(Datum expected, ChunkedArray::Make(expected_chunks));

ASSERT_OK_AND_EQ(expected, Project({i32, str}, field_names));
ASSERT_OK_AND_EQ(expected, MakeStruct({i32, str}, field_names));

// Scalars are broadcast to the length of the arrays
ASSERT_OK_AND_EQ(expected, Project({i32, MakeScalar("aa")}, field_names));
ASSERT_OK_AND_EQ(expected, MakeStruct({i32, MakeScalar("aa")}, field_names));

// Array length mismatch
ASSERT_RAISES(Invalid, Project({i32->Slice(1), str}, field_names));
ASSERT_RAISES(Invalid, MakeStruct({i32->Slice(1), str}, field_names));
}

} // namespace compute
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/arrow/dataset/scanner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,7 @@ Result<EnumeratedRecordBatchGenerator> AsyncScanner::ScanBatchesUnorderedAsync(
compute::MakeFilterNode(scan, "filter", scan_options_->filter));

auto exprs = scan_options_->projection.call()->arguments;
auto names = checked_cast<const compute::ProjectOptions*>(
auto names = checked_cast<const compute::MakeStructOptions*>(
scan_options_->projection.call()->options.get())
->field_names;
ARROW_ASSIGN_OR_RAISE(
Expand Down
Loading