From 93398c3cce1930331d6c30b2b3da9484b4a4af94 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Saint-Jacques?= Date: Tue, 25 Feb 2020 22:15:35 -0500 Subject: [PATCH 01/21] Draft engine component --- cpp/src/arrow/CMakeLists.txt | 6 +- cpp/src/arrow/engine/CMakeLists.txt | 26 +++ cpp/src/arrow/engine/catalog.cc | 153 +++++++++++++++++ cpp/src/arrow/engine/catalog.h | 96 +++++++++++ cpp/src/arrow/engine/catalog_test.cc | 135 +++++++++++++++ cpp/src/arrow/engine/expression.cc | 173 ++++++++++++++++++++ cpp/src/arrow/engine/expression.h | 191 ++++++++++++++++++++++ cpp/src/arrow/engine/expression_test.cc | 71 ++++++++ cpp/src/arrow/engine/logical_plan.cc | 106 ++++++++++++ cpp/src/arrow/engine/logical_plan.h | 85 ++++++++++ cpp/src/arrow/engine/logical_plan_test.cc | 63 +++++++ cpp/src/arrow/engine/type_traits.h | 35 ++++ cpp/src/arrow/testing/gtest_util.cc | 55 +++++++ cpp/src/arrow/testing/gtest_util.h | 3 + 14 files changed, 1197 insertions(+), 1 deletion(-) create mode 100644 cpp/src/arrow/engine/CMakeLists.txt create mode 100644 cpp/src/arrow/engine/catalog.cc create mode 100644 cpp/src/arrow/engine/catalog.h create mode 100644 cpp/src/arrow/engine/catalog_test.cc create mode 100644 cpp/src/arrow/engine/expression.cc create mode 100644 cpp/src/arrow/engine/expression.h create mode 100644 cpp/src/arrow/engine/expression_test.cc create mode 100644 cpp/src/arrow/engine/logical_plan.cc create mode 100644 cpp/src/arrow/engine/logical_plan.h create mode 100644 cpp/src/arrow/engine/logical_plan_test.cc create mode 100644 cpp/src/arrow/engine/type_traits.h diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 3454cd0c87d..7552337bc72 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -339,7 +339,10 @@ if(ARROW_COMPUTE) compute/kernels/match.cc compute/kernels/util_internal.cc compute/operations/cast.cc - compute/operations/literal.cc) + compute/operations/literal.cc + engine/catalog.cc + engine/expression.cc + engine/logical_plan.cc) endif() if(ARROW_FILESYSTEM) @@ -576,6 +579,7 @@ endif() if(ARROW_COMPUTE) add_subdirectory(compute) + add_subdirectory(engine) endif() if(ARROW_CUDA) diff --git a/cpp/src/arrow/engine/CMakeLists.txt b/cpp/src/arrow/engine/CMakeLists.txt new file mode 100644 index 00000000000..adb9cd15598 --- /dev/null +++ b/cpp/src/arrow/engine/CMakeLists.txt @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +arrow_install_all_headers("arrow/engine") + +# +# Unit tests +# + +add_arrow_test(catalog_test PREFIX arrow-engine) +add_arrow_test(expression_test PREFIX arrow-engine) +add_arrow_test(logical_plan_test PREFIX arrow-engine) diff --git a/cpp/src/arrow/engine/catalog.cc b/cpp/src/arrow/engine/catalog.cc new file mode 100644 index 00000000000..8381d502437 --- /dev/null +++ b/cpp/src/arrow/engine/catalog.cc @@ -0,0 +1,153 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/engine/catalog.h" + +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/table.h" + +#include "arrow/dataset/dataset.h" + +namespace arrow { +namespace engine { + +// +// Catalog +// + +Catalog::Catalog(std::unordered_map tables) : tables_(std::move(tables)) {} + +Result Catalog::Get(const Key& key) const { + auto value = tables_.find(key); + if (value != tables_.end()) return value->second; + return Status::KeyError("Table '", key, "' not found in catalog."); +} + +Result> Catalog::GetSchema(const Key& key) const { + auto as_schema = [](const Value& v) -> Result> { + return v.schema(); + }; + return Get(key).Map(as_schema); +} + +Result> Catalog::Make(const std::vector& tables) { + CatalogBuilder builder; + + for (const auto& key_val : tables) { + RETURN_NOT_OK(builder.Add(key_val)); + } + + return builder.Finish(); +} + +// +// Catalog::Entry +// + +using Entry = Catalog::Entry; + +Entry::Entry(std::shared_ptr table) : entry_(std::move(table)) {} +Entry::Entry(std::shared_ptr dataset) : entry_(std::move(dataset)) {} + +Entry::Kind Entry::kind() const { + if (util::holds_alternative>(entry_)) { + return TABLE; + } + + if (util::holds_alternative>(entry_)) { + return DATASET; + } + + return UNKNOWN; +} + +std::shared_ptr
Entry::table() const { + if (kind() == TABLE) return util::get>(entry_); + return nullptr; +} + +std::shared_ptr Entry::dataset() const { + if (kind() == DATASET) return util::get>(entry_); + return nullptr; +} + +std::shared_ptr Entry::schema() const { + switch (kind()) { + case TABLE: + return table()->schema(); + case DATASET: + return dataset()->schema(); + default: + return nullptr; + } + + return nullptr; +} + +// +// CatalogBuilder +// + +Status CatalogBuilder::Add(const Key& key, const Value& value) { + if (key.empty()) { + return Status::Invalid("Key in catalog can't be empty"); + } + + switch (value.kind()) { + case Entry::TABLE: { + if (value.table() == nullptr) { + return Status::Invalid("Table entry can't be null."); + } + break; + } + case Entry::DATASET: { + if (value.dataset() == nullptr) { + return Status::Invalid("Table entry can't be null."); + } + break; + } + default: + return Status::NotImplemented("Unknown entry kind"); + } + + auto inserted = tables_.insert({key, value}); + if (!inserted.second) { + return Status::KeyError("Table '", key, "' already in catalog."); + } + + return Status::OK(); +} + +Status CatalogBuilder::Add(const Key& key, std::shared_ptr
table) { + return Add(key, Entry(std::move(table))); +} + +Status CatalogBuilder::Add(const Key& key, std::shared_ptr dataset) { + return Add(key, Entry(std::move(dataset))); +} + +Status CatalogBuilder::Add(const KeyValue& key_value) { + return Add(key_value.first, key_value.second); +} + +Result> CatalogBuilder::Finish() { + return std::shared_ptr(new Catalog(std::move(tables_))); +} + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/catalog.h b/cpp/src/arrow/engine/catalog.h new file mode 100644 index 00000000000..3f111e306a1 --- /dev/null +++ b/cpp/src/arrow/engine/catalog.h @@ -0,0 +1,96 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include + +#include "arrow/type_fwd.h" +#include "arrow/util/variant.h" + +namespace arrow { + +namespace dataset { +class Dataset; +} + +namespace engine { + +/// Catalog is made of named Table/Dataset to be referenced in LogicalPlans. +class Catalog { + public: + class Entry; + + using Key = std::string; + using Value = Entry; + using KeyValue = std::pair; + + static Result> Make(const std::vector& tables); + + Result Get(const Key& name) const; + Result> GetSchema(const Key& name) const; + + class Entry { + public: + enum Kind { + TABLE = 0, + DATASET, + UNKNOWN, + }; + + explicit Entry(std::shared_ptr
table); + explicit Entry(std::shared_ptr dataset); + + Kind kind() const; + std::shared_ptr
table() const; + std::shared_ptr dataset() const; + + std::shared_ptr schema() const; + + private: + util::variant, std::shared_ptr> entry_; + }; + + private: + friend class CatalogBuilder; + explicit Catalog(std::unordered_map tables); + + std::unordered_map tables_; +}; + +class CatalogBuilder { + public: + using Key = Catalog::Key; + using Value = Catalog::Value; + using KeyValue = Catalog::KeyValue; + + Status Add(const Key& key, const Value& value); + Status Add(const Key& key, std::shared_ptr
); + Status Add(const Key& key, std::shared_ptr); + Status Add(const KeyValue& key_value); + + Result> Finish(); + + private: + std::unordered_map tables_; +}; + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/catalog_test.cc b/cpp/src/arrow/engine/catalog_test.cc new file mode 100644 index 00000000000..4542d049b98 --- /dev/null +++ b/cpp/src/arrow/engine/catalog_test.cc @@ -0,0 +1,135 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include + +#include "arrow/testing/gtest_util.h" + +#include "arrow/engine/catalog.h" +#include "arrow/table.h" +#include "arrow/type.h" + +namespace arrow { +namespace engine { + +using Entry = Catalog::Entry; + +class TestCatalog : public testing::Test { + public: + std::shared_ptr schema_ = schema({field("f", int32())}); + std::shared_ptr
table(std::shared_ptr schema) const { + return MockTable(schema); + } + std::shared_ptr
table() const { return table(schema_); } +}; + +void AssertCatalogKeyIs(const std::shared_ptr& catalog, const Catalog::Key& key, + const std::shared_ptr
& expected) { + ASSERT_OK_AND_ASSIGN(auto t, catalog->Get(key)); + ASSERT_EQ(t.kind(), Catalog::Entry::Kind::TABLE); + AssertTablesEqual(*t.table(), *expected); + + ASSERT_OK_AND_ASSIGN(auto schema, catalog->GetSchema(key)); + AssertSchemaEqual(*schema, *expected->schema()); +} + +TEST_F(TestCatalog, EmptyCatalog) { + ASSERT_OK_AND_ASSIGN(auto empty_catalog, Catalog::Make({})); + ASSERT_RAISES(KeyError, empty_catalog->Get("")); + ASSERT_RAISES(KeyError, empty_catalog->Get("a_key")); +} + +TEST_F(TestCatalog, Make) { + auto key_1 = "a"; + auto table_1 = table(schema({field(key_1, int32())})); + auto key_2 = "b"; + auto table_2 = table(schema({field(key_2, int32())})); + auto key_3 = "c"; + auto table_3 = table(schema({field(key_3, int32())})); + + std::vector tables{ + {key_1, Entry(table_1)}, {key_2, Entry(table_2)}, {key_3, Entry(table_3)}}; + + ASSERT_OK_AND_ASSIGN(auto catalog, Catalog::Make(std::move(tables))); + AssertCatalogKeyIs(catalog, key_1, table_1); + AssertCatalogKeyIs(catalog, key_2, table_2); + AssertCatalogKeyIs(catalog, key_3, table_3); +} + +class TestCatalogBuilder : public TestCatalog {}; + +TEST_F(TestCatalogBuilder, EmptyCatalog) { + CatalogBuilder builder; + ASSERT_OK_AND_ASSIGN(auto empty_catalog, builder.Finish()); + ASSERT_RAISES(KeyError, empty_catalog->Get("a_key")); +} + +TEST_F(TestCatalogBuilder, Basic) { + auto key_1 = "a"; + auto table_1 = table(schema({field(key_1, int32())})); + auto key_2 = "b"; + auto table_2 = table(schema({field(key_2, int32())})); + auto key_3 = "c"; + auto table_3 = table(schema({field(key_3, int32())})); + + CatalogBuilder builder; + ASSERT_OK(builder.Add(key_1, table_1)); + ASSERT_OK(builder.Add(key_2, table_2)); + ASSERT_OK(builder.Add(key_3, table_3)); + ASSERT_OK_AND_ASSIGN(auto catalog, builder.Finish()); + + AssertCatalogKeyIs(catalog, key_1, table_1); + AssertCatalogKeyIs(catalog, key_2, table_2); + AssertCatalogKeyIs(catalog, key_3, table_3); + + ASSERT_RAISES(KeyError, catalog->Get("invalid_key")); +} + +TEST_F(TestCatalogBuilder, NullOrEmptyKeys) { + CatalogBuilder builder; + + auto invalid_key = ""; + // Invalid empty key + ASSERT_RAISES(Invalid, builder.Add(invalid_key, table())); + + auto valid_key = "valid_key"; + // Invalid nullptr Table + ASSERT_RAISES(Invalid, builder.Add(valid_key, std::shared_ptr
{})); + // Invalid nullptr Dataset + ASSERT_RAISES(Invalid, builder.Add(valid_key, std::shared_ptr{})); +} + +TEST_F(TestCatalogBuilder, DuplicateKeys) { + CatalogBuilder builder; + + auto key = "a_key"; + + ASSERT_OK(builder.Add(key, table())); + // Key already in catalog + ASSERT_RAISES(KeyError, builder.Add(key, table())); + + // Should still yield a valid catalog if requested. + ASSERT_OK_AND_ASSIGN(auto catalog, builder.Finish()); + + ASSERT_OK_AND_ASSIGN(auto t, catalog->Get(key)); + ASSERT_EQ(t.kind(), Catalog::Entry::Kind::TABLE); + AssertTablesEqual(*t.table(), *table()); +} + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/expression.cc b/cpp/src/arrow/engine/expression.cc new file mode 100644 index 00000000000..1f49c81cc8d --- /dev/null +++ b/cpp/src/arrow/engine/expression.cc @@ -0,0 +1,173 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/engine/expression.h" +#include "arrow/scalar.h" + +namespace arrow { +namespace engine { + +// +// ExprType +// + +ExprType ExprType::Scalar(std::shared_ptr type) { + return ExprType(std::move(type), Shape::SCALAR); +} + +ExprType ExprType::Array(std::shared_ptr type) { + return ExprType(std::move(type), Shape::ARRAY); +} + +ExprType ExprType::Table(std::shared_ptr schema) { + return ExprType(std::move(schema), Shape::TABLE); +} + +ExprType::ExprType(std::shared_ptr schema, Shape shape) + : type_(std::move(schema)), shape_(shape) { + DCHECK_EQ(shape, Shape::TABLE); +} + +ExprType::ExprType(std::shared_ptr type, Shape shape) + : type_(std::move(type)), shape_(shape) { + DCHECK_NE(shape, Shape::TABLE); +} + +std::shared_ptr ExprType::schema() const { + if (shape_ == TABLE) { + return util::get>(type_); + } + + return nullptr; +} + +std::shared_ptr ExprType::data_type() const { + if (shape_ != TABLE) { + return util::get>(type_); + } + + return nullptr; +} + +bool ExprType::Equals(const ExprType& type) const { + if (this == &type) { + return true; + } + + if (shape() != type.shape()) { + return false; + } + + switch (shape()) { + case SCALAR: + return data_type()->Equals(type.data_type()); + case ARRAY: + return data_type()->Equals(type.data_type()); + case TABLE: + return schema()->Equals(type.schema()); + default: + break; + } + + return false; +} + +bool ExprType::operator==(const ExprType& rhs) const { return Equals(rhs); } + +// +// Expr +// + +// +// ScalarExpr +// + +ScalarExpr::ScalarExpr(std::shared_ptr scalar) + : Expr(SCALAR), scalar_(std::move(scalar)) {} + +Result> ScalarExpr::Make(std::shared_ptr scalar) { + if (scalar == nullptr) { + return Status::Invalid("ScalarExpr's scalar must be non-null"); + } + + return std::shared_ptr(new ScalarExpr(std::move(scalar))); +} + +ExprType ScalarExpr::type() const { return ExprType::Scalar(scalar_->type); } + +// +// FieldRefExpr +// + +FieldRefExpr::FieldRefExpr(std::shared_ptr field) + : Expr(FIELD_REF), field_(std::move(field)) {} + +Result> FieldRefExpr::Make(std::shared_ptr field) { + if (field == nullptr) { + return Status::Invalid("FieldRefExpr's field must be non-null"); + } + + return std::shared_ptr(new FieldRefExpr(std::move(field))); +} + +ExprType FieldRefExpr::type() const { return ExprType::Scalar(field_->type()); } + +// +// ScanRelExpr +// + +ScanRelExpr::ScanRelExpr(Catalog::Entry input) + : Expr(SCAN_REL), input_(std::move(input)) {} + +Result> ScanRelExpr::Make(Catalog::Entry input) { + return std::shared_ptr(new ScanRelExpr(std::move(input))); +} + +ExprType ScanRelExpr::type() const { return ExprType::Table(input_.schema()); } + +// +// FilterRelExpr +// + +Result> FilterRelExpr::Make(std::shared_ptr input, + std::shared_ptr predicate) { + if (input == nullptr) { + return Status::Invalid("FilterRelExpr's input must be non-null."); + } + + if (predicate == nullptr) { + return Status::Invalid("FilterRelExpr's predicate must be non-null."); + } + + if (!predicate->type().Equals(ExprType::Scalar(boolean()))) { + return Status::Invalid("Filter's predicate expression must be a boolean scalar"); + } + + return std::shared_ptr( + new FilterRelExpr(std::move(input), std::move(predicate))); +} + +FilterRelExpr::FilterRelExpr(std::shared_ptr input, std::shared_ptr predicate) + : Expr(FILTER_REL), input_(std::move(input)), predicate_(std::move(predicate)) { + DCHECK_NE(input_, nullptr); + DCHECK_NE(predicate_, nullptr); +} + +ExprType FilterRelExpr::type() const { return ExprType::Table(input_->type().schema()); } + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/expression.h b/cpp/src/arrow/engine/expression.h new file mode 100644 index 00000000000..2db5def1603 --- /dev/null +++ b/cpp/src/arrow/engine/expression.h @@ -0,0 +1,191 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include + +#include "arrow/engine/catalog.h" +#include "arrow/type_fwd.h" +#include "arrow/util/variant.h" + +namespace arrow { +namespace engine { + +// Expression +class ARROW_EXPORT ExprType { + public: + enum Shape { + SCALAR, + ARRAY, + TABLE, + }; + + static ExprType Scalar(std::shared_ptr type); + static ExprType Array(std::shared_ptr type); + static ExprType Table(std::shared_ptr schema); + + /// \brief Shape of the expression. + Shape shape() const { return shape_; } + /// \brief Schema of the expression if of table shape. + std::shared_ptr schema() const; + std::shared_ptr data_type() const; + + bool Equals(const ExprType& type) const; + bool operator==(const ExprType& rhs) const; + + std::string ToString() const; + + private: + ExprType(std::shared_ptr schema, Shape shape); + ExprType(std::shared_ptr type, Shape shape); + + util::variant, std::shared_ptr> type_; + Shape shape_; +}; + +/// Represents an expression tree +class ARROW_EXPORT Expr { + public: + enum Kind { + // + SCALAR, + FIELD_REF, + // + EQ_OP, + // + SCAN_REL, + FILTER_REL, + }; + + Kind kind() const { return kind_; } + virtual ExprType type() const = 0; + + /// Returns true iff the expressions are identical; does not check for equivalence. + /// For example, (A and B) is not equal to (B and A) nor is (A and not A) equal to + /// (false). + bool Equals(const Expr& other) const; + bool Equals(const std::shared_ptr& other) const; + + /// Return a string representing the expression + std::string ToString() const; + + virtual ~Expr() = default; + + protected: + explicit Expr(Kind kind) : kind_(kind) {} + + private: + Kind kind_; +}; + +// +// Value Expressions +// + +// An unnamed scalar literal expression. +class ScalarExpr : public Expr { + public: + static Result> Make(std::shared_ptr scalar); + + const std::shared_ptr& scalar() const { return scalar_; } + + ExprType type() const override; + + private: + explicit ScalarExpr(std::shared_ptr scalar); + std::shared_ptr scalar_; +}; + +// References a column in a table/dataset +class FieldRefExpr : public Expr { + public: + static Result> Make(std::shared_ptr field); + + const std::shared_ptr& field() const { return field_; } + + ExprType type() const override; + + private: + explicit FieldRefExpr(std::shared_ptr field); + + std::shared_ptr field_; +}; + +// +// Operators expression +// + +using ExprVector = std::vector>; + +class OpExpr { + public: + const ExprVector& inputs() const { return inputs_; } + + protected: + explicit OpExpr(ExprVector inputs) : inputs_(std::move(inputs)) {} + ExprVector inputs_; +}; + +template +class BinaryOpExpr : public Expr, private OpExpr { + public: + const std::shared_ptr& left() const { return inputs_[0]; } + const std::shared_ptr& right() const { return inputs_[1]; } + + protected: + BinaryOpExpr(std::shared_ptr left, std::shared_ptr right) + : Expr(KIND), OpExpr({std::move(left), std::move(right)}) {} +}; + +class EqOpExpr : public BinaryOpExpr {}; + +// +// Relational Expressions +// + +class ScanRelExpr : public Expr { + public: + static Result> Make(Catalog::Entry input); + + ExprType type() const override; + + private: + explicit ScanRelExpr(Catalog::Entry input); + + Catalog::Entry input_; +}; + +class FilterRelExpr : public Expr { + public: + static Result> Make(std::shared_ptr input, + std::shared_ptr predicate); + + ExprType type() const override; + + private: + FilterRelExpr(std::shared_ptr input, std::shared_ptr predicate); + + std::shared_ptr input_; + std::shared_ptr predicate_; +}; + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/expression_test.cc b/cpp/src/arrow/engine/expression_test.cc new file mode 100644 index 00000000000..81fe0dadfbd --- /dev/null +++ b/cpp/src/arrow/engine/expression_test.cc @@ -0,0 +1,71 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// + +#include "arrow/engine/expression.h" +#include "arrow/scalar.h" +#include "arrow/testing/gtest_common.h" +#include "arrow/type.h" + +namespace arrow { +namespace engine { + +class ExprTest : public testing::Test {}; + +TEST_F(ExprTest, ExprType) { + auto i32 = int32(); + auto s = schema({field("i32", i32)}); + + auto scalar = ExprType::Scalar(i32); + EXPECT_EQ(scalar.shape(), ExprType::Shape::SCALAR); + EXPECT_TRUE(scalar.data_type()->Equals(i32)); + EXPECT_EQ(scalar.schema(), nullptr); + + auto array = ExprType::Array(i32); + EXPECT_EQ(array.shape(), ExprType::Shape::ARRAY); + EXPECT_TRUE(array.data_type()->Equals(i32)); + EXPECT_EQ(array.schema(), nullptr); + + auto table = ExprType::Table(s); + EXPECT_EQ(table.shape(), ExprType::Shape::TABLE); + EXPECT_EQ(table.data_type(), nullptr); + EXPECT_TRUE(table.schema()->Equals(s)); +} + +TEST_F(ExprTest, ScalarExpr) { + ASSERT_RAISES(Invalid, ScalarExpr::Make(nullptr)); + + auto i32 = int32(); + ASSERT_OK_AND_ASSIGN(auto value, MakeScalar(i32, 10)); + ASSERT_OK_AND_ASSIGN(auto expr, ScalarExpr::Make(value)); + EXPECT_EQ(expr->type(), ExprType::Scalar(i32)); + EXPECT_EQ(*expr->scalar(), *value); +} + +TEST_F(ExprTest, FieldRefExpr) { + ASSERT_RAISES(Invalid, FieldRefExpr::Make(nullptr)); + + auto i32 = int32(); + auto f = field("i32", i32); + + ASSERT_OK_AND_ASSIGN(auto expr, FieldRefExpr::Make(f)); + EXPECT_EQ(expr->type(), ExprType::Scalar(i32)); + EXPECT_TRUE(expr->field()->Equals(f)); +} + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/logical_plan.cc b/cpp/src/arrow/engine/logical_plan.cc new file mode 100644 index 00000000000..3208ad9ed70 --- /dev/null +++ b/cpp/src/arrow/engine/logical_plan.cc @@ -0,0 +1,106 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/engine/logical_plan.h" + +#include "arrow/engine/expression.h" +#include "arrow/result.h" +#include "arrow/util/logging.h" + +namespace arrow { +namespace engine { + +// +// LogicalPlan +// + +LogicalPlan::LogicalPlan(std::shared_ptr root) : root_(std::move(root)) { + DCHECK_NE(root_, nullptr); +} + +bool LogicalPlan::Equals(const LogicalPlan& other) const { + if (this == &other) { + return true; + } + + return root()->Equals(other.root()); +} + +std::string LogicalPlan::ToString() const { return root_->ToString(); } + +// +// LogicalPlanBuilder +// + +LogicalPlanBuilder::LogicalPlanBuilder(std::shared_ptr catalog) + : catalog_(std::move(catalog)) {} + +Status LogicalPlanBuilder::Scan(const std::string& table_name) { + if (catalog_ == nullptr) { + return Status::Invalid("Cannot scan from an empty catalog"); + } + + ARROW_ASSIGN_OR_RAISE(auto table, catalog_->Get(table_name)); + ARROW_ASSIGN_OR_RAISE(auto scan, ScanRelExpr::Make(std::move(table))); + return Push(std::move(scan)); +} + +Status LogicalPlanBuilder::Filter(std::shared_ptr predicate) { + ARROW_ASSIGN_OR_RAISE(auto input, Peek()); + ARROW_ASSIGN_OR_RAISE(auto filter, + FilterRelExpr::Make(std::move(input), std::move(predicate))); + return Push(std::move(filter)); +} + +Result> LogicalPlanBuilder::Finish() { + if (stack_.empty()) { + return Status::Invalid("LogicalPlan is empty, nothing to construct."); + } + + ARROW_ASSIGN_OR_RAISE(auto root, Pop()); + if (!stack_.empty()) { + return Status::Invalid("LogicalPlan is ignoring operators left on the stack."); + } + + return std::make_shared(root); +} + +Result> LogicalPlanBuilder::Peek() { + if (stack_.empty()) { + return Status::Invalid("No Expr left on stack"); + } + + return stack_.top(); +} + +Result> LogicalPlanBuilder::Pop() { + ARROW_ASSIGN_OR_RAISE(auto top, Peek()); + stack_.pop(); + return top; +} + +Status LogicalPlanBuilder::Push(std::shared_ptr node) { + if (node == nullptr) { + return Status::Invalid(__FUNCTION__, " can't push a nullptr node."); + } + + stack_.push(std::move(node)); + return Status::OK(); +} + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/logical_plan.h b/cpp/src/arrow/engine/logical_plan.h new file mode 100644 index 00000000000..95de46397e4 --- /dev/null +++ b/cpp/src/arrow/engine/logical_plan.h @@ -0,0 +1,85 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include +#include + +#include "arrow/type_fwd.h" +#include "arrow/util/compare.h" +#include "arrow/util/variant.h" + +namespace arrow { + +namespace dataset { +class Dataset; +} + +namespace engine { + +class Catalog; +class Expr; + +class LogicalPlan : public util::EqualityComparable { + public: + explicit LogicalPlan(std::shared_ptr root); + + std::shared_ptr root() const { return root_; } + + bool Equals(const LogicalPlan& other) const; + std::string ToString() const; + + private: + std::shared_ptr root_; +}; + +class LogicalPlanBuilder { + public: + explicit LogicalPlanBuilder(std::shared_ptr catalog); + + /// \defgroup leaf-nodes Leaf nodes in the logical plan + /// @{ + + // Anonymous values literal. + Status Scalar(const std::shared_ptr& array); + Status Array(const std::shared_ptr& array); + Status Table(const std::shared_ptr
& table); + + // Named values + Status Scan(const std::string& table_name); + + /// @} + + Status Filter(std::shared_ptr predicate); + + Result> Finish(); + + private: + Status Push(std::shared_ptr); + Result> Pop(); + Result> Peek(); + + std::shared_ptr catalog_; + std::stack> stack_; +}; + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/logical_plan_test.cc b/cpp/src/arrow/engine/logical_plan_test.cc new file mode 100644 index 00000000000..8fe9b841d53 --- /dev/null +++ b/cpp/src/arrow/engine/logical_plan_test.cc @@ -0,0 +1,63 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include "arrow/engine/catalog.h" +#include "arrow/engine/logical_plan.h" +#include "arrow/testing/gtest_common.h" + +namespace arrow { +namespace engine { + + +class LogicalPlanBuilderTest : public testing::Test { + protected: + void SetUp() override { + CatalogBuilder builder; + ASSERT_OK(builder.Add(table_1, MockTable(schema_1))); + ASSERT_OK_AND_ASSIGN(catalog, builder.Finish()); + } + std::string table_1 = "table_1"; + std::shared_ptr schema_1 = schema({field("i32", int32())}); + std::shared_ptr catalog; +}; + +TEST_F(LogicalPlanBuilderTest, BasicScan) { + LogicalPlanBuilder builder(catalog); + ASSERT_OK(builder.Scan(table_1)); + ASSERT_OK_AND_ASSIGN(auto plan, builder.Finish()); +} + +using testing::HasSubstr; + +TEST_F(LogicalPlanBuilderTest, ErrorEmptyFinish) { + LogicalPlanBuilder builder(catalog); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("LogicalPlan is empty"), + builder.Finish()); +} + +TEST_F(LogicalPlanBuilderTest, ErrorOperatorsLeftOnStack) { + LogicalPlanBuilder builder(catalog); + ASSERT_OK(builder.Scan(table_1)); + ASSERT_OK(builder.Scan(table_1)); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("LogicalPlan is ignoring operators"), + builder.Finish()); +} + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/type_traits.h b/cpp/src/arrow/engine/type_traits.h new file mode 100644 index 00000000000..b352b114a25 --- /dev/null +++ b/cpp/src/arrow/engine/type_traits.h @@ -0,0 +1,35 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include + +#include "arrow/engine/expression.h" + +namespace arrow { +namespace engine { + + +using Op = OpExpr; + + +} +} diff --git a/cpp/src/arrow/testing/gtest_util.cc b/cpp/src/arrow/testing/gtest_util.cc index 8caf3f1cec9..566377a8fdf 100644 --- a/cpp/src/arrow/testing/gtest_util.cc +++ b/cpp/src/arrow/testing/gtest_util.cc @@ -321,6 +321,61 @@ void CompareBatch(const RecordBatch& left, const RecordBatch& right, } } +namespace detail { + +class MockTable : public Table { + public: + explicit MockTable(std::shared_ptr schema, int num_rows = 0) { + schema_ = std::move(schema); + num_rows_ = num_rows; + } + + static std::shared_ptr
Make(std::shared_ptr schema) { + return std::make_shared(std::move(schema)); + } + + std::shared_ptr column(int i) const override { + return std::make_shared(ArrayVector{}, schema_->field(i)->type()); + } + std::shared_ptr
Slice(int64_t offset, int64_t length) const override { + return nullptr; + } + + Status RemoveColumn(int i, std::shared_ptr
* out) const override { + return Status::NotImplemented("MockTable does not implement ", __FUNCTION__); + } + + Status AddColumn(int i, std::shared_ptr field_arg, + std::shared_ptr column, + std::shared_ptr
* out) const override { + return Status::NotImplemented("MockTable does not implement ", __FUNCTION__); + } + + Status SetColumn(int i, std::shared_ptr field_arg, + std::shared_ptr column, + std::shared_ptr
* out) const override { + return Status::NotImplemented("MockTable does not implement ", __FUNCTION__); + } + + std::shared_ptr
ReplaceSchemaMetadata( + const std::shared_ptr& metadata) const override { + return nullptr; + } + + Status Flatten(MemoryPool* pool, std::shared_ptr
* out) const override { + return Status::NotImplemented("MockTable does not implement ", __FUNCTION__); + } + + Status Validate() const override { return Status::OK(); } + Status ValidateFull() const override { return Status::OK(); } +}; + +} // namespace detail + +std::shared_ptr
MockTable(std::shared_ptr schema) { + return detail::MockTable::Make(schema); +} + class LocaleGuard::Impl { public: explicit Impl(const char* new_locale) : global_locale_(std::locale()) { diff --git a/cpp/src/arrow/testing/gtest_util.h b/cpp/src/arrow/testing/gtest_util.h index 93ea12ddcf8..82ea1cb60d1 100644 --- a/cpp/src/arrow/testing/gtest_util.h +++ b/cpp/src/arrow/testing/gtest_util.h @@ -397,6 +397,9 @@ inline void BitmapFromVector(const std::vector& is_valid, ASSERT_OK(GetBitmapFromVector(is_valid, out)); } +// Returns a table with 0 rows of a given schema. +std::shared_ptr
MockTable(std::shared_ptr schema); + template void AssertSortedEquals(std::vector u, std::vector v) { std::sort(u.begin(), u.end()); From a3d5c67943aef4210d52d335bed5c8b06518ab5b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Saint-Jacques?= Date: Thu, 27 Feb 2020 10:33:26 -0500 Subject: [PATCH 02/21] update --- cpp/src/arrow/engine/expression.cc | 79 +++++++---- cpp/src/arrow/engine/expression.h | 155 +++++++++++++++++++--- cpp/src/arrow/engine/expression_test.cc | 65 ++++++++- cpp/src/arrow/engine/logical_plan.cc | 78 ++++++----- cpp/src/arrow/engine/logical_plan.h | 51 +++++-- cpp/src/arrow/engine/logical_plan_test.cc | 35 ++--- 6 files changed, 342 insertions(+), 121 deletions(-) diff --git a/cpp/src/arrow/engine/expression.cc b/cpp/src/arrow/engine/expression.cc index 1f49c81cc8d..8899f01c9b5 100644 --- a/cpp/src/arrow/engine/expression.cc +++ b/cpp/src/arrow/engine/expression.cc @@ -17,6 +17,7 @@ #include "arrow/engine/expression.h" #include "arrow/scalar.h" +#include "arrow/type.h" namespace arrow { namespace engine { @@ -88,21 +89,22 @@ bool ExprType::Equals(const ExprType& type) const { bool ExprType::operator==(const ExprType& rhs) const { return Equals(rhs); } -// -// Expr -// +#define PRECONDITION(cond, ...) \ + do { \ + if (ARROW_PREDICT_FALSE(!(cond))) { \ + return Status::Invalid(__VA_ARGS__); \ + } \ + } while (false) // // ScalarExpr // ScalarExpr::ScalarExpr(std::shared_ptr scalar) - : Expr(SCALAR), scalar_(std::move(scalar)) {} + : Expr(SCALAR_LITERAL), scalar_(std::move(scalar)) {} Result> ScalarExpr::Make(std::shared_ptr scalar) { - if (scalar == nullptr) { - return Status::Invalid("ScalarExpr's scalar must be non-null"); - } + PRECONDITION(scalar != nullptr, "ScalarExpr's scalar must be non-null"); return std::shared_ptr(new ScalarExpr(std::move(scalar))); } @@ -114,18 +116,43 @@ ExprType ScalarExpr::type() const { return ExprType::Scalar(scalar_->type); } // FieldRefExpr::FieldRefExpr(std::shared_ptr field) - : Expr(FIELD_REF), field_(std::move(field)) {} + : Expr(FIELD_REFERENCE), field_(std::move(field)) {} Result> FieldRefExpr::Make(std::shared_ptr field) { - if (field == nullptr) { - return Status::Invalid("FieldRefExpr's field must be non-null"); - } + PRECONDITION(field != nullptr, "FieldRefExpr's field must be non-null"); return std::shared_ptr(new FieldRefExpr(std::move(field))); } ExprType FieldRefExpr::type() const { return ExprType::Scalar(field_->type()); } +// +// Comparisons +// + +Status ValidateCompareOpInputs(std::shared_ptr left, std::shared_ptr right) { + PRECONDITION(left != nullptr, "EqualCmpExpr's left operand must be non-null"); + PRECONDITION(right != nullptr, "EqualCmpExpr's right operand must be non-null"); + // TODO(fsaintjacques): Add support for broadcast. + return Status::OK(); +} + +#define COMPARE_MAKE_IMPL(ExprClass) \ + Result> ExprClass::Make(std::shared_ptr left, \ + std::shared_ptr right) { \ + RETURN_NOT_OK(ValidateCompareOpInputs(left, right)); \ + return std::shared_ptr(new ExprClass(std::move(left), std::move(right))); \ + } + +COMPARE_MAKE_IMPL(EqualCmpExpr) +COMPARE_MAKE_IMPL(NotEqualCmpExpr) +COMPARE_MAKE_IMPL(GreaterThanCmpExpr) +COMPARE_MAKE_IMPL(GreaterEqualThanCmpExpr) +COMPARE_MAKE_IMPL(LowerThanCmpExpr) +COMPARE_MAKE_IMPL(LowerEqualThanCmpExpr) + +#undef COMPARE_MAKE_IMPL + // // ScanRelExpr // @@ -133,7 +160,7 @@ ExprType FieldRefExpr::type() const { return ExprType::Scalar(field_->type()); } ScanRelExpr::ScanRelExpr(Catalog::Entry input) : Expr(SCAN_REL), input_(std::move(input)) {} -Result> ScanRelExpr::Make(Catalog::Entry input) { +Result> ScanRelExpr::Make(Catalog::Entry input) { return std::shared_ptr(new ScanRelExpr(std::move(input))); } @@ -143,31 +170,27 @@ ExprType ScanRelExpr::type() const { return ExprType::Table(input_.schema()); } // FilterRelExpr // -Result> FilterRelExpr::Make(std::shared_ptr input, - std::shared_ptr predicate) { - if (input == nullptr) { - return Status::Invalid("FilterRelExpr's input must be non-null."); - } +Result> FilterRelExpr::Make( + std::shared_ptr input, std::shared_ptr predicate) { + PRECONDITION(input != nullptr, "FilterRelExpr's input must be non-null."); + PRECONDITION(input->type().IsTable(), "FilterRelExpr's input must be a table."); + PRECONDITION(predicate != nullptr, "FilterRelExpr's predicate must be non-null."); + PRECONDITION(predicate->type().IsPredicate(), + "FilterRelExpr's predicate must be a predicate"); - if (predicate == nullptr) { - return Status::Invalid("FilterRelExpr's predicate must be non-null."); - } - - if (!predicate->type().Equals(ExprType::Scalar(boolean()))) { - return Status::Invalid("Filter's predicate expression must be a boolean scalar"); - } + // TODO(fsaintjacques): check fields referenced in predicate are found in + // input. return std::shared_ptr( new FilterRelExpr(std::move(input), std::move(predicate))); } FilterRelExpr::FilterRelExpr(std::shared_ptr input, std::shared_ptr predicate) - : Expr(FILTER_REL), input_(std::move(input)), predicate_(std::move(predicate)) { - DCHECK_NE(input_, nullptr); - DCHECK_NE(predicate_, nullptr); -} + : Expr(FILTER_REL), input_(std::move(input)), predicate_(std::move(predicate)) {} ExprType FilterRelExpr::type() const { return ExprType::Table(input_->type().schema()); } +#undef PRECONDITION + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/expression.h b/cpp/src/arrow/engine/expression.h index 2db5def1603..c474d35ecb8 100644 --- a/cpp/src/arrow/engine/expression.h +++ b/cpp/src/arrow/engine/expression.h @@ -23,38 +23,76 @@ #include #include "arrow/engine/catalog.h" +#include "arrow/type.h" #include "arrow/type_fwd.h" #include "arrow/util/variant.h" namespace arrow { namespace engine { -// Expression +/// ExprType is a class representing the type of an Expression. The type is +/// composed of a shape and a DataType or a Schema depending on the shape. +/// +/// ExprType is mainly used to validate arguments for operator expressions, e.g. +/// relational operator expressions expect inputs of Table shape. +/// +/// The sum-type representation would be: +/// +/// enum ExprType { +/// ScalarType(DataType), +/// ArrayType(DataType), +/// TableType(Schema), +/// } class ARROW_EXPORT ExprType { public: enum Shape { + // The expression yields a Scalar, e.g. "1". SCALAR, + // The expression yields an Array, e.g. "[1, 2, 3]". ARRAY, + // The expression yields a Table, e.g. "{'a': [1, 2], 'b': [true, false]}" TABLE, }; + /// Construct a Scalar type. static ExprType Scalar(std::shared_ptr type); + /// Construct an Array type. static ExprType Array(std::shared_ptr type); + /// Construct a Table type. static ExprType Table(std::shared_ptr schema); /// \brief Shape of the expression. Shape shape() const { return shape_; } + + /// \brief DataType of the expression if a scalar or an array. + std::shared_ptr data_type() const; /// \brief Schema of the expression if of table shape. std::shared_ptr schema() const; - std::shared_ptr data_type() const; + + /// \brief Indicate if the type is a Scalar. + bool IsScalar() const { return shape_ == SCALAR; } + /// \brief Indicate if the type is an Array. + bool IsArray() const { return shape_ == ARRAY; } + /// \brief Indicate if the type is a Table. + bool IsTable() const { return shape_ == TABLE; } + + template + bool HasType() const { + return (shape_ == SCALAR || shape_ == ARRAY) && + util::get>(type_)->id() == TYPE_ID; + } + + /// \brief Indicate if the type is a predicate, i.e. a boolean scalar. + bool IsPredicate() const { return IsScalar() && HasType(); } bool Equals(const ExprType& type) const; bool operator==(const ExprType& rhs) const; - std::string ToString() const; private: + /// Table constructor ExprType(std::shared_ptr schema, Shape shape); + /// Scalar or Array constructor ExprType(std::shared_ptr type, Shape shape); util::variant, std::shared_ptr> type_; @@ -64,20 +102,37 @@ class ARROW_EXPORT ExprType { /// Represents an expression tree class ARROW_EXPORT Expr { public: + // Tag identifier for the expression type. enum Kind { - // - SCALAR, - FIELD_REF, - // - EQ_OP, - // + // A Scalar literal, i.e. a constant. + SCALAR_LITERAL, + // A Field reference in a schema. + FIELD_REFERENCE, + + // Equal compare operator + EQ_CMP_OP, + // Not-Equal compare operator + NE_CMP_OP, + // Greater-Than compare operator + GT_CMP_OP, + // Greater-Equal-Than compare operator + GE_CMP_OP, + // Lower-Than compare operator + LT_CMP_OP, + // Lower-Equal-Than compare operator + LE_CMP_OP, + + // Scan relational operator SCAN_REL, + // Filter relational operator FILTER_REL, }; - Kind kind() const { return kind_; } + // Return the type and shape of the resulting expression. virtual ExprType type() const = 0; + Kind kind() const { return kind_; } + /// Returns true iff the expressions are identical; does not check for equivalence. /// For example, (A and B) is not equal to (B and A) nor is (A and not A) equal to /// (false). @@ -111,6 +166,7 @@ class ScalarExpr : public Expr { private: explicit ScalarExpr(std::shared_ptr scalar); + std::shared_ptr scalar_; }; @@ -130,7 +186,7 @@ class FieldRefExpr : public Expr { }; // -// Operators expression +// Operator expressions // using ExprVector = std::vector>; @@ -147,15 +203,80 @@ class OpExpr { template class BinaryOpExpr : public Expr, private OpExpr { public: - const std::shared_ptr& left() const { return inputs_[0]; } - const std::shared_ptr& right() const { return inputs_[1]; } + const std::shared_ptr& left_operand() const { return inputs_[0]; } + const std::shared_ptr& right_operand() const { return inputs_[1]; } protected: BinaryOpExpr(std::shared_ptr left, std::shared_ptr right) : Expr(KIND), OpExpr({std::move(left), std::move(right)}) {} }; -class EqOpExpr : public BinaryOpExpr {}; +// +// Comparison expressions +// + +template +class CmpOpExpr : public BinaryOpExpr { + public: + ExprType type() const override { return ExprType::Scalar(boolean()); }; + + protected: + using BinaryOpExpr::BinaryOpExpr; +}; + +class EqualCmpExpr : public CmpOpExpr { + public: + static Result> Make(std::shared_ptr left, + std::shared_ptr right); + + protected: + using CmpOpExpr::CmpOpExpr; +}; + +class NotEqualCmpExpr : public CmpOpExpr { + public: + static Result> Make(std::shared_ptr left, + std::shared_ptr right); + + protected: + using CmpOpExpr::CmpOpExpr; +}; + +class GreaterThanCmpExpr : public CmpOpExpr { + public: + static Result> Make(std::shared_ptr left, + std::shared_ptr right); + + protected: + using CmpOpExpr::CmpOpExpr; +}; + +class GreaterEqualThanCmpExpr : public CmpOpExpr { + public: + static Result> Make( + std::shared_ptr left, std::shared_ptr right); + + protected: + using CmpOpExpr::CmpOpExpr; +}; + +class LowerThanCmpExpr : public CmpOpExpr { + public: + static Result> Make(std::shared_ptr left, + std::shared_ptr right); + + protected: + using CmpOpExpr::CmpOpExpr; +}; + +class LowerEqualThanCmpExpr : public CmpOpExpr { + public: + static Result> Make(std::shared_ptr left, + std::shared_ptr right); + + protected: + using CmpOpExpr::CmpOpExpr; +}; // // Relational Expressions @@ -163,7 +284,7 @@ class EqOpExpr : public BinaryOpExpr {}; class ScanRelExpr : public Expr { public: - static Result> Make(Catalog::Entry input); + static Result> Make(Catalog::Entry input); ExprType type() const override; @@ -175,8 +296,8 @@ class ScanRelExpr : public Expr { class FilterRelExpr : public Expr { public: - static Result> Make(std::shared_ptr input, - std::shared_ptr predicate); + static Result> Make(std::shared_ptr input, + std::shared_ptr predicate); ExprType type() const override; diff --git a/cpp/src/arrow/engine/expression_test.cc b/cpp/src/arrow/engine/expression_test.cc index 81fe0dadfbd..5b980292516 100644 --- a/cpp/src/arrow/engine/expression_test.cc +++ b/cpp/src/arrow/engine/expression_test.cc @@ -14,19 +14,21 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. -// #include "arrow/engine/expression.h" #include "arrow/scalar.h" +#include "arrow/testing/gmock.h" #include "arrow/testing/gtest_common.h" #include "arrow/type.h" +using testing::Pointee; + namespace arrow { namespace engine { -class ExprTest : public testing::Test {}; +class ExprTypeTest : public testing::Test {}; -TEST_F(ExprTest, ExprType) { +TEST_F(ExprTypeTest, Basic) { auto i32 = int32(); auto s = schema({field("i32", i32)}); @@ -34,37 +36,88 @@ TEST_F(ExprTest, ExprType) { EXPECT_EQ(scalar.shape(), ExprType::Shape::SCALAR); EXPECT_TRUE(scalar.data_type()->Equals(i32)); EXPECT_EQ(scalar.schema(), nullptr); + EXPECT_TRUE(scalar.IsScalar()); + EXPECT_FALSE(scalar.IsArray()); + EXPECT_FALSE(scalar.IsTable()); auto array = ExprType::Array(i32); EXPECT_EQ(array.shape(), ExprType::Shape::ARRAY); EXPECT_TRUE(array.data_type()->Equals(i32)); EXPECT_EQ(array.schema(), nullptr); + EXPECT_FALSE(array.IsScalar()); + EXPECT_TRUE(array.IsArray()); + EXPECT_FALSE(array.IsTable()); auto table = ExprType::Table(s); EXPECT_EQ(table.shape(), ExprType::Shape::TABLE); EXPECT_EQ(table.data_type(), nullptr); EXPECT_TRUE(table.schema()->Equals(s)); + EXPECT_FALSE(table.IsScalar()); + EXPECT_FALSE(table.IsArray()); + EXPECT_TRUE(table.IsTable()); +} + +TEST_F(ExprTypeTest, IsPredicate) { + auto bool_scalar = ExprType::Scalar(boolean()); + EXPECT_TRUE(bool_scalar.IsPredicate()); + + auto bool_array = ExprType::Array(boolean()); + EXPECT_FALSE(bool_array.IsPredicate()); + + auto bool_table = ExprType::Table(schema({field("b", boolean())})); + EXPECT_FALSE(bool_table.IsPredicate()); + + auto i32_scalar = ExprType::Scalar(int32()); + EXPECT_FALSE(i32_scalar.IsPredicate()); } +class ExprTest : public testing::Test {}; + TEST_F(ExprTest, ScalarExpr) { ASSERT_RAISES(Invalid, ScalarExpr::Make(nullptr)); auto i32 = int32(); ASSERT_OK_AND_ASSIGN(auto value, MakeScalar(i32, 10)); ASSERT_OK_AND_ASSIGN(auto expr, ScalarExpr::Make(value)); + EXPECT_EQ(expr->kind(), Expr::SCALAR_LITERAL); EXPECT_EQ(expr->type(), ExprType::Scalar(i32)); EXPECT_EQ(*expr->scalar(), *value); + } TEST_F(ExprTest, FieldRefExpr) { ASSERT_RAISES(Invalid, FieldRefExpr::Make(nullptr)); auto i32 = int32(); - auto f = field("i32", i32); + auto f_i32 = field("i32", i32); - ASSERT_OK_AND_ASSIGN(auto expr, FieldRefExpr::Make(f)); + ASSERT_OK_AND_ASSIGN(auto expr, FieldRefExpr::Make(f_i32)); + EXPECT_EQ(expr->kind(), Expr::FIELD_REFERENCE); EXPECT_EQ(expr->type(), ExprType::Scalar(i32)); - EXPECT_TRUE(expr->field()->Equals(f)); + EXPECT_THAT(expr->field(), IsPtrEqual(f_i32)); +} + +TEST_F(ExprTest, EqualCmpExpr) { + auto i32 = int32(); + + auto f_i32 = field("i32", i32); + ASSERT_OK_AND_ASSIGN(auto f_expr, FieldRefExpr::Make(f_i32)); + + ASSERT_OK_AND_ASSIGN(auto s_i32, MakeScalar(i32, 42)); + ASSERT_OK_AND_ASSIGN(auto s_expr, ScalarExpr::Make(s_i32)); + + ASSERT_RAISES(Invalid, EqualCmpExpr::Make(nullptr, nullptr)); + ASSERT_RAISES(Invalid, EqualCmpExpr::Make(s_expr, nullptr)); + ASSERT_RAISES(Invalid, EqualCmpExpr::Make(nullptr, f_expr)); + + ASSERT_OK_AND_ASSIGN(auto expr, EqualCmpExpr::Make(f_expr, s_expr)); + + EXPECT_EQ(expr->kind(), Expr::EQ_CMP_OP); + EXPECT_EQ(expr->type(), ExprType::Scalar(boolean())); + /* + EXPECT_THAT(expr->left_operand(), IsPtrEqual(f_expr)); + EXPECT_THAT(expr->right_operand(), Pointee(s_expr)); + */ } } // namespace engine diff --git a/cpp/src/arrow/engine/logical_plan.cc b/cpp/src/arrow/engine/logical_plan.cc index 3208ad9ed70..ea4d78ae449 100644 --- a/cpp/src/arrow/engine/logical_plan.cc +++ b/cpp/src/arrow/engine/logical_plan.cc @@ -46,61 +46,57 @@ std::string LogicalPlan::ToString() const { return root_->ToString(); } // LogicalPlanBuilder // -LogicalPlanBuilder::LogicalPlanBuilder(std::shared_ptr catalog) - : catalog_(std::move(catalog)) {} +LogicalPlanBuilder::LogicalPlanBuilder(LogicalPlanBuilderOptions options) + : catalog_(options.catalog) {} -Status LogicalPlanBuilder::Scan(const std::string& table_name) { - if (catalog_ == nullptr) { - return Status::Invalid("Cannot scan from an empty catalog"); - } +using ResultExpr = LogicalPlanBuilder::ResultExpr; - ARROW_ASSIGN_OR_RAISE(auto table, catalog_->Get(table_name)); - ARROW_ASSIGN_OR_RAISE(auto scan, ScanRelExpr::Make(std::move(table))); - return Push(std::move(scan)); -} +#define ERROR_IF(cond, ...) \ + do { \ + if (ARROW_PREDICT_FALSE(cond)) { \ + return Status::Invalid(__VA_ARGS__); \ + } \ + } while (false) + +// +// Leaf builder. +// -Status LogicalPlanBuilder::Filter(std::shared_ptr predicate) { - ARROW_ASSIGN_OR_RAISE(auto input, Peek()); - ARROW_ASSIGN_OR_RAISE(auto filter, - FilterRelExpr::Make(std::move(input), std::move(predicate))); - return Push(std::move(filter)); +ResultExpr LogicalPlanBuilder::Scalar(const std::shared_ptr& scalar) { + return ScalarExpr::Make(scalar); } -Result> LogicalPlanBuilder::Finish() { - if (stack_.empty()) { - return Status::Invalid("LogicalPlan is empty, nothing to construct."); - } +ResultExpr LogicalPlanBuilder::Field(const std::shared_ptr& input, + const std::string& field_name) { + ERROR_IF(input == nullptr, "Input expression must be non-null"); - ARROW_ASSIGN_OR_RAISE(auto root, Pop()); - if (!stack_.empty()) { - return Status::Invalid("LogicalPlan is ignoring operators left on the stack."); - } + auto expr_type = input->type(); + ERROR_IF(!expr_type.IsTable(), "Input expression does not have a Table shape."); - return std::make_shared(root); + auto field = expr_type.schema()->GetFieldByName(field_name); + ERROR_IF(field == nullptr, "Cannot reference field '", field_name, "' in schema."); + + return FieldRefExpr::Make(std::move(field)); } -Result> LogicalPlanBuilder::Peek() { - if (stack_.empty()) { - return Status::Invalid("No Expr left on stack"); - } +// +// Relational +// - return stack_.top(); +ResultExpr LogicalPlanBuilder::Scan(const std::string& table_name) { + ERROR_IF(catalog_ == nullptr, "Cannot scan from an empty catalog"); + ARROW_ASSIGN_OR_RAISE(auto table, catalog_->Get(table_name)); + return ScanRelExpr::Make(std::move(table)); } -Result> LogicalPlanBuilder::Pop() { - ARROW_ASSIGN_OR_RAISE(auto top, Peek()); - stack_.pop(); - return top; +ResultExpr LogicalPlanBuilder::Filter(const std::shared_ptr& input, + const std::shared_ptr& predicate) { + ERROR_IF(input == nullptr, "Input expression can't be null."); + ERROR_IF(predicate == nullptr, "Predicate expression can't be null."); + return FilterRelExpr::Make(std::move(input), std::move(predicate)); } -Status LogicalPlanBuilder::Push(std::shared_ptr node) { - if (node == nullptr) { - return Status::Invalid(__FUNCTION__, " can't push a nullptr node."); - } - - stack_.push(std::move(node)); - return Status::OK(); -} +#undef ERROR_IF } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/logical_plan.h b/cpp/src/arrow/engine/logical_plan.h index 95de46397e4..e22c48ccd60 100644 --- a/cpp/src/arrow/engine/logical_plan.h +++ b/cpp/src/arrow/engine/logical_plan.h @@ -51,34 +51,57 @@ class LogicalPlan : public util::EqualityComparable { std::shared_ptr root_; }; +struct LogicalPlanBuilderOptions { + std::shared_ptr catalog; +}; + class LogicalPlanBuilder { public: - explicit LogicalPlanBuilder(std::shared_ptr catalog); + using ResultExpr = Result>; + + explicit LogicalPlanBuilder(LogicalPlanBuilderOptions options = {}); /// \defgroup leaf-nodes Leaf nodes in the logical plan /// @{ - // Anonymous values literal. - Status Scalar(const std::shared_ptr& array); - Status Array(const std::shared_ptr& array); - Status Table(const std::shared_ptr
& table); + /// \brief Construct a Scalar literal. + ResultExpr Scalar(const std::shared_ptr& scalar); - // Named values - Status Scan(const std::string& table_name); + /// \brief References a field by name. + ResultExpr Field(const std::shared_ptr& input, const std::string& field_name); + + /// \brief Scan a Table/Dataset from the Catalog. + ResultExpr Scan(const std::string& table_name); /// @} - Status Filter(std::shared_ptr predicate); + /// \defgroup rel-nodes Relational operator nodes in the logical plan - Result> Finish(); + ResultExpr Filter(const std::shared_ptr& input, + const std::shared_ptr& predicate); - private: - Status Push(std::shared_ptr); - Result> Pop(); - Result> Peek(); + /* + /// \brief Project (mutate) columns with given expressions. + ResultExpr Project(const std::vector>& expressions); + ResultExpr Mutate(const std::vector>& expressions); + /// \brief Project (select) columns by names. + /// + /// This is a simplified version of Project where columns are selected by + /// names. Duplicate and ordering are preserved. + ResultExpr Project(const std::vector& column_names); + + /// \brief Project (select) columns by indices. + /// + /// This is a simplified version of Project where columns are selected by + /// indices. Duplicate and ordering are preserved. + ResultExpr Project(const std::vector& column_indices); + */ + + /// @} + + private: std::shared_ptr catalog_; - std::stack> stack_; }; } // namespace engine diff --git a/cpp/src/arrow/engine/logical_plan_test.cc b/cpp/src/arrow/engine/logical_plan_test.cc index 8fe9b841d53..200918b1979 100644 --- a/cpp/src/arrow/engine/logical_plan_test.cc +++ b/cpp/src/arrow/engine/logical_plan_test.cc @@ -21,10 +21,11 @@ #include "arrow/engine/logical_plan.h" #include "arrow/testing/gtest_common.h" +using testing::HasSubstr; + namespace arrow { namespace engine { - class LogicalPlanBuilderTest : public testing::Test { protected: void SetUp() override { @@ -37,26 +38,30 @@ class LogicalPlanBuilderTest : public testing::Test { std::shared_ptr catalog; }; +TEST_F(LogicalPlanBuilderTest, Scalar) { + LogicalPlanBuilder builder{{catalog}}; + auto forthy_two = MakeScalar(42); + EXPECT_OK_AND_ASSIGN(auto scalar, builder.Scalar(forthy_two)); +} + TEST_F(LogicalPlanBuilderTest, BasicScan) { - LogicalPlanBuilder builder(catalog); + LogicalPlanBuilder builder{{catalog}}; ASSERT_OK(builder.Scan(table_1)); - ASSERT_OK_AND_ASSIGN(auto plan, builder.Finish()); } -using testing::HasSubstr; +TEST_F(LogicalPlanBuilderTest, FieldReferenceByName) { + LogicalPlanBuilder builder{{catalog}}; -TEST_F(LogicalPlanBuilderTest, ErrorEmptyFinish) { - LogicalPlanBuilder builder(catalog); - EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("LogicalPlan is empty"), - builder.Finish()); -} + // Input must be non-null. + ASSERT_RAISES(Invalid, builder.Field(nullptr, "i32")); -TEST_F(LogicalPlanBuilderTest, ErrorOperatorsLeftOnStack) { - LogicalPlanBuilder builder(catalog); - ASSERT_OK(builder.Scan(table_1)); - ASSERT_OK(builder.Scan(table_1)); - EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("LogicalPlan is ignoring operators"), - builder.Finish()); + // The input must have a Table shape. + auto forthy_two = MakeScalar(42); + EXPECT_OK_AND_ASSIGN(auto scalar, builder.Scalar(forthy_two)); + ASSERT_RAISES(Invalid, builder.Field(scalar, "not_found")); + + EXPECT_OK_AND_ASSIGN(auto table_scan, builder.Scan(table_1)); + EXPECT_OK_AND_ASSIGN(auto field_ref, builder.Field(table_scan, "i32")); } } // namespace engine From 8691ace914ede1a44b382864e5c5ed1e294c4e80 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Saint-Jacques?= Date: Thu, 27 Feb 2020 22:38:54 -0500 Subject: [PATCH 03/21] Update --- cpp/src/arrow/engine/catalog.cc | 2 + cpp/src/arrow/engine/catalog.h | 3 + cpp/src/arrow/engine/expression.cc | 153 +++++++++++---- cpp/src/arrow/engine/expression.h | 251 +++++++++++++++++------- cpp/src/arrow/engine/expression_test.cc | 107 +++++++++- cpp/src/arrow/engine/logical_plan.cc | 2 + cpp/src/arrow/engine/logical_plan.h | 2 +- cpp/src/arrow/engine/type_fwd.h | 39 ++++ cpp/src/arrow/engine/type_traits.h | 20 +- cpp/src/arrow/testing/gmock.h | 28 +++ 10 files changed, 481 insertions(+), 126 deletions(-) create mode 100644 cpp/src/arrow/engine/type_fwd.h create mode 100644 cpp/src/arrow/testing/gmock.h diff --git a/cpp/src/arrow/engine/catalog.cc b/cpp/src/arrow/engine/catalog.cc index 8381d502437..be8f73e551c 100644 --- a/cpp/src/arrow/engine/catalog.cc +++ b/cpp/src/arrow/engine/catalog.cc @@ -86,6 +86,8 @@ std::shared_ptr Entry::dataset() const { return nullptr; } +bool Entry::operator==(const Entry& other) const { return entry_ == other.entry_; } + std::shared_ptr Entry::schema() const { switch (kind()) { case TABLE: diff --git a/cpp/src/arrow/engine/catalog.h b/cpp/src/arrow/engine/catalog.h index 3f111e306a1..e982862d940 100644 --- a/cpp/src/arrow/engine/catalog.h +++ b/cpp/src/arrow/engine/catalog.h @@ -20,6 +20,7 @@ #include #include #include +#include #include #include "arrow/type_fwd.h" @@ -64,6 +65,8 @@ class Catalog { std::shared_ptr schema() const; + bool operator==(const Entry& other) const; + private: util::variant, std::shared_ptr> entry_; }; diff --git a/cpp/src/arrow/engine/expression.cc b/cpp/src/arrow/engine/expression.cc index 8899f01c9b5..2b07cd01668 100644 --- a/cpp/src/arrow/engine/expression.cc +++ b/cpp/src/arrow/engine/expression.cc @@ -16,8 +16,10 @@ // under the License. #include "arrow/engine/expression.h" +#include "arrow/engine/type_traits.h" #include "arrow/scalar.h" #include "arrow/type.h" +#include "arrow/util/checked_cast.h" namespace arrow { namespace engine { @@ -87,15 +89,100 @@ bool ExprType::Equals(const ExprType& type) const { return false; } -bool ExprType::operator==(const ExprType& rhs) const { return Equals(rhs); } - -#define PRECONDITION(cond, ...) \ +#define ERROR_IF(cond, ...) \ do { \ - if (ARROW_PREDICT_FALSE(!(cond))) { \ + if (ARROW_PREDICT_FALSE(cond)) { \ return Status::Invalid(__VA_ARGS__); \ } \ } while (false) +// +// Expr +// + +std::string Expr::kind_name() const { + switch (kind_) { + case Expr::SCALAR_LITERAL: + return "scalar"; + case Expr::FIELD_REFERENCE: + return "field_ref"; + + case Expr::EQ_CMP_OP: + return "eq_cmp"; + case Expr::NE_CMP_OP: + return "ne_cmp"; + case Expr::GT_CMP_OP: + return "gt_cmp"; + case Expr::GE_CMP_OP: + return "ge_cmp"; + case Expr::LT_CMP_OP: + return "lt_cmp"; + case Expr::LE_CMP_OP: + return "le_cmp"; + + case Expr::EMPTY_REL: + return "empty_rel"; + case Expr::SCAN_REL: + return "scan_rel"; + case Expr::FILTER_REL: + return "filter_rel"; + } + + return "unknown expr"; +} + +struct ExprEqualityVisitor { + bool operator()(const ScalarExpr& rhs) const { + auto lhs_scalar = internal::checked_cast(lhs); + return lhs_scalar.scalar()->Equals(*rhs.scalar()); + } + + bool operator()(const FieldRefExpr& rhs) const { + auto lhs_field = internal::checked_cast(lhs); + return lhs_field.field()->Equals(*rhs.field()); + } + + template + enable_if_compare_expr operator()(const E& rhs) const { + auto lhs_cmp = internal::checked_cast(lhs); + return (lhs_cmp.left_operand()->Equals(rhs.left_operand()) && + lhs_cmp.right_operand()->Equals(rhs.right_operand())) || + (lhs_cmp.left_operand()->Equals(rhs.right_operand()) && + lhs_cmp.left_operand()->Equals(rhs.right_operand())); + } + + bool operator()(const EmptyRelExpr& rhs) const { + auto lhs_empty = internal::checked_cast(lhs); + return lhs_empty.schema()->Equals(rhs.schema()); + } + + bool operator()(const ScanRelExpr& rhs) const { + auto lhs_scan = internal::checked_cast(lhs); + // Performs a pointer equality on Table/Dataset + return lhs_scan.input() == rhs.input(); + } + + bool operator()(const Expr&) const { return false; } + + static bool Visit(const Expr& lhs, const Expr& rhs) { + return VisitExpr(rhs, ExprEqualityVisitor{lhs}); + } + + const Expr& lhs; +}; + +bool Expr::Equals(const Expr& other) const { + if (this == &other) { + return true; + } + + if (kind() != other.kind() || type() != other.type()) { + return false; + } + + return ExprEqualityVisitor::Visit(*this, other); +} + // // ScalarExpr // @@ -104,7 +191,7 @@ ScalarExpr::ScalarExpr(std::shared_ptr scalar) : Expr(SCALAR_LITERAL), scalar_(std::move(scalar)) {} Result> ScalarExpr::Make(std::shared_ptr scalar) { - PRECONDITION(scalar != nullptr, "ScalarExpr's scalar must be non-null"); + ERROR_IF(scalar == nullptr, "ScalarExpr's scalar must be non-null"); return std::shared_ptr(new ScalarExpr(std::move(scalar))); } @@ -119,7 +206,7 @@ FieldRefExpr::FieldRefExpr(std::shared_ptr field) : Expr(FIELD_REFERENCE), field_(std::move(field)) {} Result> FieldRefExpr::Make(std::shared_ptr field) { - PRECONDITION(field != nullptr, "FieldRefExpr's field must be non-null"); + ERROR_IF(field == nullptr, "FieldRefExpr's field must be non-null"); return std::shared_ptr(new FieldRefExpr(std::move(field))); } @@ -130,53 +217,49 @@ ExprType FieldRefExpr::type() const { return ExprType::Scalar(field_->type()); } // Comparisons // -Status ValidateCompareOpInputs(std::shared_ptr left, std::shared_ptr right) { - PRECONDITION(left != nullptr, "EqualCmpExpr's left operand must be non-null"); - PRECONDITION(right != nullptr, "EqualCmpExpr's right operand must be non-null"); +Status ValidateCompareOpInputs(const std::shared_ptr& left, + const std::shared_ptr& right) { + ERROR_IF(left == nullptr, "EqualCmpExpr's left operand must be non-null"); + ERROR_IF(right == nullptr, "EqualCmpExpr's right operand must be non-null"); + // TODO(fsaintjacques): Add support for broadcast. + ERROR_IF(left->type() != right->type(), + "Compare operator operands must be of same type."); + return Status::OK(); } -#define COMPARE_MAKE_IMPL(ExprClass) \ - Result> ExprClass::Make(std::shared_ptr left, \ - std::shared_ptr right) { \ - RETURN_NOT_OK(ValidateCompareOpInputs(left, right)); \ - return std::shared_ptr(new ExprClass(std::move(left), std::move(right))); \ - } - -COMPARE_MAKE_IMPL(EqualCmpExpr) -COMPARE_MAKE_IMPL(NotEqualCmpExpr) -COMPARE_MAKE_IMPL(GreaterThanCmpExpr) -COMPARE_MAKE_IMPL(GreaterEqualThanCmpExpr) -COMPARE_MAKE_IMPL(LowerThanCmpExpr) -COMPARE_MAKE_IMPL(LowerEqualThanCmpExpr) +// +// +// -#undef COMPARE_MAKE_IMPL +Result> EmptyRelExpr::Make(std::shared_ptr schema) { + ERROR_IF(schema == nullptr, "EmptyRelExpr schema must be non-null"); + return std::shared_ptr(new EmptyRelExpr(std::move(schema))); +} // // ScanRelExpr // ScanRelExpr::ScanRelExpr(Catalog::Entry input) - : Expr(SCAN_REL), input_(std::move(input)) {} + : RelExpr(input.schema()), input_(std::move(input)) {} Result> ScanRelExpr::Make(Catalog::Entry input) { return std::shared_ptr(new ScanRelExpr(std::move(input))); } -ExprType ScanRelExpr::type() const { return ExprType::Table(input_.schema()); } - // // FilterRelExpr // Result> FilterRelExpr::Make( std::shared_ptr input, std::shared_ptr predicate) { - PRECONDITION(input != nullptr, "FilterRelExpr's input must be non-null."); - PRECONDITION(input->type().IsTable(), "FilterRelExpr's input must be a table."); - PRECONDITION(predicate != nullptr, "FilterRelExpr's predicate must be non-null."); - PRECONDITION(predicate->type().IsPredicate(), - "FilterRelExpr's predicate must be a predicate"); + ERROR_IF(input == nullptr, "FilterRelExpr's input must be non-null."); + ERROR_IF(!input->type().IsTable(), "FilterRelExpr's input must be a table."); + ERROR_IF(predicate == nullptr, "FilterRelExpr's predicate must be non-null."); + ERROR_IF(!predicate->type().IsPredicate(), + "FilterRelExpr's predicate must be a predicate"); // TODO(fsaintjacques): check fields referenced in predicate are found in // input. @@ -186,11 +269,11 @@ Result> FilterRelExpr::Make( } FilterRelExpr::FilterRelExpr(std::shared_ptr input, std::shared_ptr predicate) - : Expr(FILTER_REL), input_(std::move(input)), predicate_(std::move(predicate)) {} - -ExprType FilterRelExpr::type() const { return ExprType::Table(input_->type().schema()); } + : UnaryOpExpr(std::move(input)), + RelExpr(operand()->type().schema()), + predicate_(std::move(predicate)) {} -#undef PRECONDITION +#undef ERROR_IF } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/expression.h b/cpp/src/arrow/engine/expression.h index c474d35ecb8..e799ac9e56e 100644 --- a/cpp/src/arrow/engine/expression.h +++ b/cpp/src/arrow/engine/expression.h @@ -19,12 +19,13 @@ #include #include -#include -#include +#include #include "arrow/engine/catalog.h" +#include "arrow/engine/type_fwd.h" #include "arrow/type.h" #include "arrow/type_fwd.h" +#include "arrow/util/compare.h" #include "arrow/util/variant.h" namespace arrow { @@ -43,9 +44,9 @@ namespace engine { /// ArrayType(DataType), /// TableType(Schema), /// } -class ARROW_EXPORT ExprType { +class ARROW_EXPORT ExprType : public util::EqualityComparable { public: - enum Shape { + enum Shape : uint8_t { // The expression yields a Scalar, e.g. "1". SCALAR, // The expression yields an Array, e.g. "[1, 2, 3]". @@ -86,7 +87,7 @@ class ARROW_EXPORT ExprType { bool IsPredicate() const { return IsScalar() && HasType(); } bool Equals(const ExprType& type) const; - bool operator==(const ExprType& rhs) const; + std::string ToString() const; private: @@ -100,7 +101,7 @@ class ARROW_EXPORT ExprType { }; /// Represents an expression tree -class ARROW_EXPORT Expr { +class ARROW_EXPORT Expr : public util::EqualityComparable { public: // Tag identifier for the expression type. enum Kind { @@ -117,29 +118,32 @@ class ARROW_EXPORT Expr { GT_CMP_OP, // Greater-Equal-Than compare operator GE_CMP_OP, - // Lower-Than compare operator + // Less-Than compare operator LT_CMP_OP, - // Lower-Equal-Than compare operator + // Less-Equal-Than compare operator LE_CMP_OP, + // Empty relation with a known schema. + EMPTY_REL, // Scan relational operator SCAN_REL, // Filter relational operator FILTER_REL, }; - // Return the type and shape of the resulting expression. - virtual ExprType type() const = 0; - + /// \brief Return the kind of the expression. Kind kind() const { return kind_; } + /// \brief Return a string representation of the kind. + std::string kind_name() const; + + /// \brief Return the type and shape of the resulting expression. + virtual ExprType type() const = 0; - /// Returns true iff the expressions are identical; does not check for equivalence. - /// For example, (A and B) is not equal to (B and A) nor is (A and not A) equal to - /// (false). + /// \brief Indicate if the expressions bool Equals(const Expr& other) const; - bool Equals(const std::shared_ptr& other) const; + using util::EqualityComparable::Equals; - /// Return a string representing the expression + /// \brief Return a string representing the expression std::string ToString() const; virtual ~Expr() = default; @@ -147,16 +151,76 @@ class ARROW_EXPORT Expr { protected: explicit Expr(Kind kind) : kind_(kind) {} - private: Kind kind_; }; +// The following traits are used to break cycle between CRTP base classes and +// their derived counterparts to extract the Expr::Kind and other static +// properties from the forward declared class. +template +struct expr_traits; + +template <> +struct expr_traits { + static constexpr Expr::Kind kind_id = Expr::SCALAR_LITERAL; +}; + +template <> +struct expr_traits { + static constexpr Expr::Kind kind_id = Expr::FIELD_REFERENCE; +}; + +template <> +struct expr_traits { + static constexpr Expr::Kind kind_id = Expr::EQ_CMP_OP; +}; + +template <> +struct expr_traits { + static constexpr Expr::Kind kind_id = Expr::NE_CMP_OP; +}; + +template <> +struct expr_traits { + static constexpr Expr::Kind kind_id = Expr::GT_CMP_OP; +}; + +template <> +struct expr_traits { + static constexpr Expr::Kind kind_id = Expr::GE_CMP_OP; +}; + +template <> +struct expr_traits { + static constexpr Expr::Kind kind_id = Expr::LT_CMP_OP; +}; + +template <> +struct expr_traits { + static constexpr Expr::Kind kind_id = Expr::LE_CMP_OP; +}; + +template <> +struct expr_traits { + static constexpr Expr::Kind kind_id = Expr::EMPTY_REL; +}; + +template <> +struct expr_traits { + static constexpr Expr::Kind kind_id = Expr::SCAN_REL; +}; + +template <> +struct expr_traits { + static constexpr Expr::Kind kind_id = Expr::FILTER_REL; +}; + // // Value Expressions // // An unnamed scalar literal expression. -class ScalarExpr : public Expr { +class ARROW_EXPORT ScalarExpr : public Expr { public: static Result> Make(std::shared_ptr scalar); @@ -171,7 +235,7 @@ class ScalarExpr : public Expr { }; // References a column in a table/dataset -class FieldRefExpr : public Expr { +class ARROW_EXPORT FieldRefExpr : public Expr { public: static Result> Make(std::shared_ptr field); @@ -189,104 +253,114 @@ class FieldRefExpr : public Expr { // Operator expressions // -using ExprVector = std::vector>; - -class OpExpr { +class ARROW_EXPORT UnaryOpExpr { public: - const ExprVector& inputs() const { return inputs_; } + const std::shared_ptr& operand() const { return operand_; } protected: - explicit OpExpr(ExprVector inputs) : inputs_(std::move(inputs)) {} - ExprVector inputs_; + explicit UnaryOpExpr(std::shared_ptr operand) : operand_(std::move(operand)) {} + + std::shared_ptr operand_; }; -template -class BinaryOpExpr : public Expr, private OpExpr { +class ARROW_EXPORT BinaryOpExpr { public: - const std::shared_ptr& left_operand() const { return inputs_[0]; } - const std::shared_ptr& right_operand() const { return inputs_[1]; } + const std::shared_ptr& left_operand() const { return left_operand_; } + const std::shared_ptr& right_operand() const { return right_operand_; } protected: BinaryOpExpr(std::shared_ptr left, std::shared_ptr right) - : Expr(KIND), OpExpr({std::move(left), std::move(right)}) {} + : left_operand_(std::move(left)), right_operand_(std::move(right)) {} + + std::shared_ptr left_operand_; + std::shared_ptr right_operand_; }; // // Comparison expressions // -template -class CmpOpExpr : public BinaryOpExpr { +Status ValidateCompareOpInputs(const std::shared_ptr& left, + const std::shared_ptr& right); + +template +class ARROW_EXPORT CmpOpExpr : public BinaryOpExpr, public Expr { public: ExprType type() const override { return ExprType::Scalar(boolean()); }; + static Result> Make(std::shared_ptr left, + std::shared_ptr right) { + ARROW_RETURN_NOT_OK(ValidateCompareOpInputs(left, right)); + return std::shared_ptr(new Self(std::move(left), std::move(right))); + } + protected: - using BinaryOpExpr::BinaryOpExpr; + CmpOpExpr(std::shared_ptr left, std::shared_ptr right) + : BinaryOpExpr(std::move(left), std::move(right)), + Expr(expr_traits::kind_id) {} }; -class EqualCmpExpr : public CmpOpExpr { - public: - static Result> Make(std::shared_ptr left, - std::shared_ptr right); - +class ARROW_EXPORT EqualCmpExpr : public CmpOpExpr { protected: - using CmpOpExpr::CmpOpExpr; + using CmpOpExpr::CmpOpExpr; }; -class NotEqualCmpExpr : public CmpOpExpr { - public: - static Result> Make(std::shared_ptr left, - std::shared_ptr right); - +class ARROW_EXPORT NotEqualCmpExpr : public CmpOpExpr { protected: - using CmpOpExpr::CmpOpExpr; + using CmpOpExpr::CmpOpExpr; }; -class GreaterThanCmpExpr : public CmpOpExpr { - public: - static Result> Make(std::shared_ptr left, - std::shared_ptr right); +class ARROW_EXPORT GreaterThanCmpExpr : public CmpOpExpr { + protected: + using CmpOpExpr::CmpOpExpr; +}; +class ARROW_EXPORT GreaterEqualThanCmpExpr : public CmpOpExpr { protected: - using CmpOpExpr::CmpOpExpr; + using CmpOpExpr::CmpOpExpr; }; -class GreaterEqualThanCmpExpr : public CmpOpExpr { - public: - static Result> Make( - std::shared_ptr left, std::shared_ptr right); +class ARROW_EXPORT LessThanCmpExpr : public CmpOpExpr { + protected: + using CmpOpExpr::CmpOpExpr; +}; +class ARROW_EXPORT LessEqualThanCmpExpr : public CmpOpExpr { protected: - using CmpOpExpr::CmpOpExpr; + using CmpOpExpr::CmpOpExpr; }; -class LowerThanCmpExpr : public CmpOpExpr { +// +// Relational Expressions +// + +template +class ARROW_EXPORT RelExpr : public Expr { public: - static Result> Make(std::shared_ptr left, - std::shared_ptr right); + ExprType type() const override { return ExprType::Table(schema_); } + + const std::shared_ptr& schema() const { return schema_; } protected: - using CmpOpExpr::CmpOpExpr; + explicit RelExpr(std::shared_ptr schema) + : Expr(expr_traits::kind_id), schema_(std::move(schema)) {} + + std::shared_ptr schema_; }; -class LowerEqualThanCmpExpr : public CmpOpExpr { +class ARROW_EXPORT EmptyRelExpr : public RelExpr { public: - static Result> Make(std::shared_ptr left, - std::shared_ptr right); + static Result> Make(std::shared_ptr schema); protected: - using CmpOpExpr::CmpOpExpr; + using RelExpr::RelExpr; }; -// -// Relational Expressions -// - -class ScanRelExpr : public Expr { +class ARROW_EXPORT ScanRelExpr : public RelExpr { public: static Result> Make(Catalog::Entry input); - ExprType type() const override; + const Catalog::Entry& input() const { return input_; } private: explicit ScanRelExpr(Catalog::Entry input); @@ -294,19 +368,52 @@ class ScanRelExpr : public Expr { Catalog::Entry input_; }; -class FilterRelExpr : public Expr { +class ARROW_EXPORT FilterRelExpr : public UnaryOpExpr, public RelExpr { public: static Result> Make(std::shared_ptr input, std::shared_ptr predicate); - ExprType type() const override; + const std::shared_ptr& predicate() const { return predicate_; } private: FilterRelExpr(std::shared_ptr input, std::shared_ptr predicate); - std::shared_ptr input_; std::shared_ptr predicate_; }; +template +auto VisitExpr(const Expr& expr, Visitor&& visitor) -> decltype(visitor(expr)) { + switch (expr.kind()) { + case Expr::SCALAR_LITERAL: + return visitor(internal::checked_cast(expr)); + case Expr::FIELD_REFERENCE: + return visitor(internal::checked_cast(expr)); + + case Expr::EQ_CMP_OP: + return visitor(internal::checked_cast(expr)); + case Expr::NE_CMP_OP: + return visitor(internal::checked_cast(expr)); + case Expr::GT_CMP_OP: + return visitor(internal::checked_cast(expr)); + case Expr::GE_CMP_OP: + return visitor(internal::checked_cast(expr)); + case Expr::LT_CMP_OP: + return visitor(internal::checked_cast(expr)); + case Expr::LE_CMP_OP: + return visitor(internal::checked_cast(expr)); + + case Expr::EMPTY_REL: + return visitor(internal::checked_cast(expr)); + case Expr::SCAN_REL: + return visitor(internal::checked_cast(expr)); + case Expr::FILTER_REL: + // LEAVE LAST or update the outer return cast by moving it here. This is + // required for older compiler support. + break; + } + + return visitor(internal::checked_cast(expr)); +} + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/expression_test.cc b/cpp/src/arrow/engine/expression_test.cc index 5b980292516..153ee3a999a 100644 --- a/cpp/src/arrow/engine/expression_test.cc +++ b/cpp/src/arrow/engine/expression_test.cc @@ -21,6 +21,8 @@ #include "arrow/testing/gtest_common.h" #include "arrow/type.h" +using testing::HasSubstr; +using testing::Not; using testing::Pointee; namespace arrow { @@ -82,7 +84,6 @@ TEST_F(ExprTest, ScalarExpr) { EXPECT_EQ(expr->kind(), Expr::SCALAR_LITERAL); EXPECT_EQ(expr->type(), ExprType::Scalar(i32)); EXPECT_EQ(*expr->scalar(), *value); - } TEST_F(ExprTest, FieldRefExpr) { @@ -97,7 +98,21 @@ TEST_F(ExprTest, FieldRefExpr) { EXPECT_THAT(expr->field(), IsPtrEqual(f_i32)); } -TEST_F(ExprTest, EqualCmpExpr) { +template +class CmpExprTest : public ExprTest { + public: + Expr::Kind kind() { return expr_traits::kind_id; } + + Result> Make(std::shared_ptr left, + std::shared_ptr right) { + return CmpClass::Make(std::move(left), std::move(right)); + } +}; + +using CompareExprs = ::testing::Types; + +TYPED_TEST_CASE(CmpExprTest, CompareExprs); +TYPED_TEST(CmpExprTest, BasicCompareExpr) { auto i32 = int32(); auto f_i32 = field("i32", i32); @@ -106,18 +121,90 @@ TEST_F(ExprTest, EqualCmpExpr) { ASSERT_OK_AND_ASSIGN(auto s_i32, MakeScalar(i32, 42)); ASSERT_OK_AND_ASSIGN(auto s_expr, ScalarExpr::Make(s_i32)); - ASSERT_RAISES(Invalid, EqualCmpExpr::Make(nullptr, nullptr)); - ASSERT_RAISES(Invalid, EqualCmpExpr::Make(s_expr, nullptr)); - ASSERT_RAISES(Invalid, EqualCmpExpr::Make(nullptr, f_expr)); + // Required fields + ASSERT_RAISES(Invalid, this->Make(nullptr, nullptr)); + ASSERT_RAISES(Invalid, this->Make(s_expr, nullptr)); + ASSERT_RAISES(Invalid, this->Make(nullptr, f_expr)); - ASSERT_OK_AND_ASSIGN(auto expr, EqualCmpExpr::Make(f_expr, s_expr)); + // Not type compatible + ASSERT_OK_AND_ASSIGN(auto s_i64, MakeScalar(int64(), 42L)); + ASSERT_OK_AND_ASSIGN(auto s_expr_i64, ScalarExpr::Make(s_i64)); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("operands must be of same type"), + this->Make(s_expr_i64, f_expr)); - EXPECT_EQ(expr->kind(), Expr::EQ_CMP_OP); + ASSERT_OK_AND_ASSIGN(auto expr, this->Make(f_expr, s_expr)); + EXPECT_EQ(expr->kind(), this->kind()); EXPECT_EQ(expr->type(), ExprType::Scalar(boolean())); - /* + EXPECT_TRUE(expr->type().IsPredicate()); + EXPECT_THAT(expr, IsPtrEqual(expr)); EXPECT_THAT(expr->left_operand(), IsPtrEqual(f_expr)); - EXPECT_THAT(expr->right_operand(), Pointee(s_expr)); - */ + EXPECT_THAT(expr->right_operand(), IsPtrEqual(s_expr)); + + ASSERT_OK_AND_ASSIGN(auto other, this->Make(f_expr, s_expr)); + EXPECT_THAT(expr, IsPtrEqual(other)); + // Compare operators supports commutativity + // TODO(fsaintjacques): what about floating point types? + ASSERT_OK_AND_ASSIGN(auto swapped, this->Make(s_expr, f_expr)); + EXPECT_THAT(expr, IsPtrEqual(swapped)); +} + +class RelExprTest : public ExprTest { + protected: + void SetUp() override { + CatalogBuilder builder; + ASSERT_OK(builder.Add(table_1, MockTable(schema_1))); + ASSERT_OK_AND_ASSIGN(catalog, builder.Finish()); + } + + std::string table_1 = "table_1"; + std::shared_ptr schema_1 = schema({field("i32", int32())}); + + std::shared_ptr catalog; +}; + +TEST_F(RelExprTest, EmptyRelExpr) { + ASSERT_RAISES(Invalid, EmptyRelExpr::Make(nullptr)); + + ASSERT_OK_AND_ASSIGN(auto empty, EmptyRelExpr::Make(schema_1)); + EXPECT_THAT(empty->type(), ExprType::Table(schema_1)); + EXPECT_THAT(empty->schema(), IsPtrEqual(schema_1)); + EXPECT_THAT(empty, IsPtrEqual(empty)); + + ASSERT_OK_AND_ASSIGN(auto other, EmptyRelExpr::Make(schema_1)); + EXPECT_THAT(other, IsPtrEqual(empty)); +} + +TEST_F(RelExprTest, ScanRelExpr) { + ASSERT_OK_AND_ASSIGN(auto table, catalog->Get(table_1)); + + ASSERT_OK_AND_ASSIGN(auto scan, ScanRelExpr::Make(table)); + EXPECT_THAT(scan, IsPtrEqual(scan)); + EXPECT_THAT(scan->type(), ExprType::Table(schema_1)); + EXPECT_THAT(scan->schema(), IsPtrEqual(schema_1)); + + ASSERT_OK_AND_ASSIGN(auto other, ScanRelExpr::Make(table)); + EXPECT_THAT(other, IsPtrEqual(scan)); +} + +TEST_F(RelExprTest, FilterRelExpr) { + ASSERT_OK_AND_ASSIGN(auto empty, EmptyRelExpr::Make(schema_1)); + ASSERT_OK_AND_ASSIGN(auto pred, ScalarExpr::Make(MakeScalar(true))); + + ASSERT_RAISES(Invalid, FilterRelExpr::Make(nullptr, nullptr)); + ASSERT_RAISES(Invalid, FilterRelExpr::Make(empty, nullptr)); + ASSERT_RAISES(Invalid, FilterRelExpr::Make(nullptr, pred)); + + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("input must be a table"), + FilterRelExpr::Make(pred, pred)); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("predicate must be a predicate"), + FilterRelExpr::Make(empty, empty)); + + ASSERT_OK_AND_ASSIGN(auto filter, FilterRelExpr::Make(empty, pred)); + EXPECT_THAT(filter, IsPtrEqual(filter)); + EXPECT_THAT(filter->type(), ExprType::Table(schema_1)); + EXPECT_THAT(filter->schema(), IsPtrEqual(schema_1)); + EXPECT_THAT(filter->operand(), IsPtrEqual(empty)); + EXPECT_THAT(filter->predicate(), IsPtrEqual(pred)); } } // namespace engine diff --git a/cpp/src/arrow/engine/logical_plan.cc b/cpp/src/arrow/engine/logical_plan.cc index ea4d78ae449..93a12bc65ad 100644 --- a/cpp/src/arrow/engine/logical_plan.cc +++ b/cpp/src/arrow/engine/logical_plan.cc @@ -17,6 +17,8 @@ #include "arrow/engine/logical_plan.h" +#include + #include "arrow/engine/expression.h" #include "arrow/result.h" #include "arrow/util/logging.h" diff --git a/cpp/src/arrow/engine/logical_plan.h b/cpp/src/arrow/engine/logical_plan.h index e22c48ccd60..2342dd7117e 100644 --- a/cpp/src/arrow/engine/logical_plan.h +++ b/cpp/src/arrow/engine/logical_plan.h @@ -42,7 +42,7 @@ class LogicalPlan : public util::EqualityComparable { public: explicit LogicalPlan(std::shared_ptr root); - std::shared_ptr root() const { return root_; } + const std::shared_ptr& root() const { return root_; } bool Equals(const LogicalPlan& other) const; std::string ToString() const; diff --git a/cpp/src/arrow/engine/type_fwd.h b/cpp/src/arrow/engine/type_fwd.h new file mode 100644 index 00000000000..a497866ba85 --- /dev/null +++ b/cpp/src/arrow/engine/type_fwd.h @@ -0,0 +1,39 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +namespace arrow { +namespace engine { + +class ExprType; + +class Expr; +class ScalarExpr; +class FieldRefExpr; +class EqualCmpExpr; +class NotEqualCmpExpr; +class GreaterThanCmpExpr; +class GreaterEqualThanCmpExpr; +class LessThanCmpExpr; +class LessEqualThanCmpExpr; +class EmptyRelExpr; +class ScanRelExpr; +class FilterRelExpr; + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/type_traits.h b/cpp/src/arrow/engine/type_traits.h index b352b114a25..2982ddbfb63 100644 --- a/cpp/src/arrow/engine/type_traits.h +++ b/cpp/src/arrow/engine/type_traits.h @@ -17,19 +17,23 @@ #pragma once -#include -#include -#include -#include - #include "arrow/engine/expression.h" +#include "arrow/type_traits.h" namespace arrow { namespace engine { +template +using is_compare_expr = std::is_base_of, E>; + +template +using enable_if_compare_expr = enable_if_t::value, Ret>; -using Op = OpExpr; +template +using is_relational_expr = std::is_base_of, E>; +template +using enable_if_relational_expr = enable_if_t::value, Ret>; -} -} +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/testing/gmock.h b/cpp/src/arrow/testing/gmock.h new file mode 100644 index 00000000000..f570f586031 --- /dev/null +++ b/cpp/src/arrow/testing/gmock.h @@ -0,0 +1,28 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +namespace arrow { + +MATCHER_P(IsEqual, other, "") { return arg.Equals(other); } + +MATCHER_P(IsPtrEqual, other, "") { return arg->Equals(*other); } + +} // namespace arrow From 77c5bf9de4065ed008afeaa8019b3e472722db07 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Saint-Jacques?= Date: Fri, 28 Feb 2020 10:07:17 -0500 Subject: [PATCH 04/21] Add key in Catalog::Entry --- cpp/src/arrow/engine/catalog.cc | 55 +++++++++++++++------------- cpp/src/arrow/engine/catalog.h | 35 ++++++++---------- cpp/src/arrow/engine/catalog_test.cc | 6 +-- cpp/src/arrow/engine/expression.cc | 2 +- 4 files changed, 49 insertions(+), 49 deletions(-) diff --git a/cpp/src/arrow/engine/catalog.cc b/cpp/src/arrow/engine/catalog.cc index be8f73e551c..fa1f21f4c90 100644 --- a/cpp/src/arrow/engine/catalog.cc +++ b/cpp/src/arrow/engine/catalog.cc @@ -30,22 +30,25 @@ namespace engine { // Catalog // -Catalog::Catalog(std::unordered_map tables) : tables_(std::move(tables)) {} +using Entry = Catalog::Entry; + +Catalog::Catalog(std::unordered_map tables) + : tables_(std::move(tables)) {} -Result Catalog::Get(const Key& key) const { +Result Catalog::Get(const std::string& key) const { auto value = tables_.find(key); if (value != tables_.end()) return value->second; return Status::KeyError("Table '", key, "' not found in catalog."); } -Result> Catalog::GetSchema(const Key& key) const { - auto as_schema = [](const Value& v) -> Result> { - return v.schema(); +Result> Catalog::GetSchema(const std::string& key) const { + auto as_schema = [](const Entry& entry) -> Result> { + return entry.schema(); }; return Get(key).Map(as_schema); } -Result> Catalog::Make(const std::vector& tables) { +Result> Catalog::Make(const std::vector& tables) { CatalogBuilder builder; for (const auto& key_val : tables) { @@ -59,10 +62,11 @@ Result> Catalog::Make(const std::vector& tabl // Catalog::Entry // -using Entry = Catalog::Entry; +Entry::Entry(std::shared_ptr
table, std::string name) + : entry_(std::move(table)), name_(std::move(name)) {} -Entry::Entry(std::shared_ptr
table) : entry_(std::move(table)) {} -Entry::Entry(std::shared_ptr dataset) : entry_(std::move(dataset)) {} +Entry::Entry(std::shared_ptr dataset, std::string name) + : entry_(std::move(dataset)), name_(std::move(name)) {} Entry::Kind Entry::kind() const { if (util::holds_alternative>(entry_)) { @@ -86,7 +90,11 @@ std::shared_ptr Entry::dataset() const { return nullptr; } -bool Entry::operator==(const Entry& other) const { return entry_ == other.entry_; } +bool Entry::operator==(const Entry& other) const { + // Entries are unique by name in a catalog, but we can still protect with + // pointer equality. + return name_ == other.name_ && entry_ == other.entry_; +} std::shared_ptr Entry::schema() const { switch (kind()) { @@ -105,20 +113,21 @@ std::shared_ptr Entry::schema() const { // CatalogBuilder // -Status CatalogBuilder::Add(const Key& key, const Value& value) { - if (key.empty()) { +Status CatalogBuilder::Add(Entry entry) { + const auto& name = entry.name(); + if (name.empty()) { return Status::Invalid("Key in catalog can't be empty"); } - switch (value.kind()) { + switch (entry.kind()) { case Entry::TABLE: { - if (value.table() == nullptr) { + if (entry.table() == nullptr) { return Status::Invalid("Table entry can't be null."); } break; } case Entry::DATASET: { - if (value.dataset() == nullptr) { + if (entry.dataset() == nullptr) { return Status::Invalid("Table entry can't be null."); } break; @@ -127,24 +136,20 @@ Status CatalogBuilder::Add(const Key& key, const Value& value) { return Status::NotImplemented("Unknown entry kind"); } - auto inserted = tables_.insert({key, value}); + auto inserted = tables_.emplace(name, std::move(entry)); if (!inserted.second) { - return Status::KeyError("Table '", key, "' already in catalog."); + return Status::KeyError("Table '", name, "' already in catalog."); } return Status::OK(); } -Status CatalogBuilder::Add(const Key& key, std::shared_ptr
table) { - return Add(key, Entry(std::move(table))); -} - -Status CatalogBuilder::Add(const Key& key, std::shared_ptr dataset) { - return Add(key, Entry(std::move(dataset))); +Status CatalogBuilder::Add(std::string name, std::shared_ptr
table) { + return Add(Entry(std::move(table), std::move(name))); } -Status CatalogBuilder::Add(const KeyValue& key_value) { - return Add(key_value.first, key_value.second); +Status CatalogBuilder::Add(std::string name, std::shared_ptr dataset) { + return Add(Entry(std::move(dataset), std::move(name))); } Result> CatalogBuilder::Finish() { diff --git a/cpp/src/arrow/engine/catalog.h b/cpp/src/arrow/engine/catalog.h index e982862d940..25fbc6fb652 100644 --- a/cpp/src/arrow/engine/catalog.h +++ b/cpp/src/arrow/engine/catalog.h @@ -39,14 +39,10 @@ class Catalog { public: class Entry; - using Key = std::string; - using Value = Entry; - using KeyValue = std::pair; + static Result> Make(const std::vector& tables); - static Result> Make(const std::vector& tables); - - Result Get(const Key& name) const; - Result> GetSchema(const Key& name) const; + Result Get(const std::string& name) const; + Result> GetSchema(const std::string& name) const; class Entry { public: @@ -56,10 +52,13 @@ class Catalog { UNKNOWN, }; - explicit Entry(std::shared_ptr
table); - explicit Entry(std::shared_ptr dataset); + Entry(std::shared_ptr
table, std::string name); + Entry(std::shared_ptr dataset, std::string name); Kind kind() const; + + const std::string& name() const { return name_; } + std::shared_ptr
table() const; std::shared_ptr dataset() const; @@ -69,30 +68,26 @@ class Catalog { private: util::variant, std::shared_ptr> entry_; + std::string name_; }; private: friend class CatalogBuilder; - explicit Catalog(std::unordered_map tables); + explicit Catalog(std::unordered_map tables); - std::unordered_map tables_; + std::unordered_map tables_; }; class CatalogBuilder { public: - using Key = Catalog::Key; - using Value = Catalog::Value; - using KeyValue = Catalog::KeyValue; - - Status Add(const Key& key, const Value& value); - Status Add(const Key& key, std::shared_ptr
); - Status Add(const Key& key, std::shared_ptr); - Status Add(const KeyValue& key_value); + Status Add(Catalog::Entry entry); + Status Add(std::string name, std::shared_ptr
); + Status Add(std::string name, std::shared_ptr); Result> Finish(); private: - std::unordered_map tables_; + std::unordered_map tables_; }; } // namespace engine diff --git a/cpp/src/arrow/engine/catalog_test.cc b/cpp/src/arrow/engine/catalog_test.cc index 4542d049b98..d1fc42de424 100644 --- a/cpp/src/arrow/engine/catalog_test.cc +++ b/cpp/src/arrow/engine/catalog_test.cc @@ -38,7 +38,7 @@ class TestCatalog : public testing::Test { std::shared_ptr
table() const { return table(schema_); } }; -void AssertCatalogKeyIs(const std::shared_ptr& catalog, const Catalog::Key& key, +void AssertCatalogKeyIs(const std::shared_ptr& catalog, const std::string& key, const std::shared_ptr
& expected) { ASSERT_OK_AND_ASSIGN(auto t, catalog->Get(key)); ASSERT_EQ(t.kind(), Catalog::Entry::Kind::TABLE); @@ -62,8 +62,8 @@ TEST_F(TestCatalog, Make) { auto key_3 = "c"; auto table_3 = table(schema({field(key_3, int32())})); - std::vector tables{ - {key_1, Entry(table_1)}, {key_2, Entry(table_2)}, {key_3, Entry(table_3)}}; + std::vector tables{Entry(table_1, key_1), Entry(table_2, key_2), + Entry(table_3, key_3)}; ASSERT_OK_AND_ASSIGN(auto catalog, Catalog::Make(std::move(tables))); AssertCatalogKeyIs(catalog, key_1, table_1); diff --git a/cpp/src/arrow/engine/expression.cc b/cpp/src/arrow/engine/expression.cc index 2b07cd01668..2a167f075ef 100644 --- a/cpp/src/arrow/engine/expression.cc +++ b/cpp/src/arrow/engine/expression.cc @@ -230,7 +230,7 @@ Status ValidateCompareOpInputs(const std::shared_ptr& left, } // -// +// EmptyRelExpr // Result> EmptyRelExpr::Make(std::shared_ptr schema) { From 1b5931229f0621596f6a06a4300e5ee2ff0a7af1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Saint-Jacques?= Date: Fri, 28 Feb 2020 13:42:06 -0500 Subject: [PATCH 05/21] Update --- cpp/src/arrow/engine/expression.cc | 117 ++++++++++++++++-------- cpp/src/arrow/engine/expression.h | 88 ++++++++++++------ cpp/src/arrow/engine/expression_test.cc | 79 +++++++++++----- cpp/src/arrow/testing/gmock.h | 7 +- 4 files changed, 202 insertions(+), 89 deletions(-) diff --git a/cpp/src/arrow/engine/expression.cc b/cpp/src/arrow/engine/expression.cc index 2a167f075ef..aaed724fb87 100644 --- a/cpp/src/arrow/engine/expression.cc +++ b/cpp/src/arrow/engine/expression.cc @@ -41,29 +41,46 @@ ExprType ExprType::Table(std::shared_ptr schema) { } ExprType::ExprType(std::shared_ptr schema, Shape shape) - : type_(std::move(schema)), shape_(shape) { + : schema_(std::move(schema)), shape_(shape) { DCHECK_EQ(shape, Shape::TABLE); } ExprType::ExprType(std::shared_ptr type, Shape shape) - : type_(std::move(type)), shape_(shape) { + : data_type_(std::move(type)), shape_(shape) { DCHECK_NE(shape, Shape::TABLE); } -std::shared_ptr ExprType::schema() const { - if (shape_ == TABLE) { - return util::get>(type_); +ExprType::ExprType(const ExprType& other) : shape_(other.shape()) { + switch (other.shape()) { + case SCALAR: + case ARRAY: + data_type_ = other.data_type(); + break; + case TABLE: + schema_ = other.schema(); } - - return nullptr; } -std::shared_ptr ExprType::data_type() const { - if (shape_ != TABLE) { - return util::get>(type_); +ExprType::ExprType(ExprType&& other) : shape_(other.shape()) { + switch (other.shape()) { + case SCALAR: + case ARRAY: + data_type_ = std::move(other.data_type()); + break; + case TABLE: + schema_ = std::move(other.schema()); } +} - return nullptr; +ExprType::~ExprType() { + switch (shape()) { + case SCALAR: + case ARRAY: + data_type_.reset(); + break; + case TABLE: + schema_.reset(); + } } bool ExprType::Equals(const ExprType& type) const { @@ -89,6 +106,51 @@ bool ExprType::Equals(const ExprType& type) const { return false; } +Result ExprType::CastTo(const std::shared_ptr& data_type) const { + switch (shape()) { + case SCALAR: + return ExprType::Scalar(data_type); + case ARRAY: + return ExprType::Array(data_type); + case TABLE: + return Status::Invalid("Cannot cast a TableType with a DataType"); + } + + return Status::UnknownError("unreachable"); +} +Result ExprType::CastTo(const std::shared_ptr& schema) const { + switch (shape()) { + case SCALAR: + return Status::Invalid("Cannot cast a ScalarType with a schema"); + case ARRAY: + return Status::Invalid("Cannot cast an ArrayType with a schema"); + case TABLE: + return ExprType::Table(schema); + } + + return Status::UnknownError("unreachable"); +} + +Result ExprType::Broadcast(const ExprType& lhs, const ExprType& rhs) { + if (lhs.IsTable() || rhs.IsTable()) { + return Status::Invalid("Broadcast operands must not be tables"); + } + + if (!lhs.data_type()->Equals(rhs.data_type())) { + return Status::Invalid("Broadcast operands must be of same type"); + } + + if (lhs.IsArray()) { + return lhs; + } + + if (rhs.IsArray()) { + return rhs; + } + + return lhs; +} + #define ERROR_IF(cond, ...) \ do { \ if (ARROW_PREDICT_FALSE(cond)) { \ @@ -183,52 +245,32 @@ bool Expr::Equals(const Expr& other) const { return ExprEqualityVisitor::Visit(*this, other); } +std::string Expr::ToString() const { return ""; } + // // ScalarExpr // ScalarExpr::ScalarExpr(std::shared_ptr scalar) - : Expr(SCALAR_LITERAL), scalar_(std::move(scalar)) {} + : Expr(SCALAR_LITERAL, ExprType::Scalar(scalar->type)), scalar_(std::move(scalar)) {} Result> ScalarExpr::Make(std::shared_ptr scalar) { ERROR_IF(scalar == nullptr, "ScalarExpr's scalar must be non-null"); - return std::shared_ptr(new ScalarExpr(std::move(scalar))); } -ExprType ScalarExpr::type() const { return ExprType::Scalar(scalar_->type); } - // // FieldRefExpr // -FieldRefExpr::FieldRefExpr(std::shared_ptr field) - : Expr(FIELD_REFERENCE), field_(std::move(field)) {} +FieldRefExpr::FieldRefExpr(std::shared_ptr f) + : Expr(FIELD_REFERENCE, ExprType::Array(f->type())), field_(std::move(f)) {} Result> FieldRefExpr::Make(std::shared_ptr field) { ERROR_IF(field == nullptr, "FieldRefExpr's field must be non-null"); - return std::shared_ptr(new FieldRefExpr(std::move(field))); } -ExprType FieldRefExpr::type() const { return ExprType::Scalar(field_->type()); } - -// -// Comparisons -// - -Status ValidateCompareOpInputs(const std::shared_ptr& left, - const std::shared_ptr& right) { - ERROR_IF(left == nullptr, "EqualCmpExpr's left operand must be non-null"); - ERROR_IF(right == nullptr, "EqualCmpExpr's right operand must be non-null"); - - // TODO(fsaintjacques): Add support for broadcast. - ERROR_IF(left->type() != right->type(), - "Compare operator operands must be of same type."); - - return Status::OK(); -} - // // EmptyRelExpr // @@ -261,9 +303,6 @@ Result> FilterRelExpr::Make( ERROR_IF(!predicate->type().IsPredicate(), "FilterRelExpr's predicate must be a predicate"); - // TODO(fsaintjacques): check fields referenced in predicate are found in - // input. - return std::shared_ptr( new FilterRelExpr(std::move(input), std::move(predicate))); } diff --git a/cpp/src/arrow/engine/expression.h b/cpp/src/arrow/engine/expression.h index e799ac9e56e..8a49332c81a 100644 --- a/cpp/src/arrow/engine/expression.h +++ b/cpp/src/arrow/engine/expression.h @@ -23,6 +23,7 @@ #include "arrow/engine/catalog.h" #include "arrow/engine/type_fwd.h" +#include "arrow/result.h" #include "arrow/type.h" #include "arrow/type_fwd.h" #include "arrow/util/compare.h" @@ -66,9 +67,11 @@ class ARROW_EXPORT ExprType : public util::EqualityComparable { Shape shape() const { return shape_; } /// \brief DataType of the expression if a scalar or an array. - std::shared_ptr data_type() const; + /// WARNING: You must ensure the proper shape before calling this accessor. + const std::shared_ptr& data_type() const { return data_type_; } /// \brief Schema of the expression if of table shape. - std::shared_ptr schema() const; + /// WARNING: You must ensure the proper shape before calling this accessor. + const std::shared_ptr& schema() const { return schema_; } /// \brief Indicate if the type is a Scalar. bool IsScalar() const { return shape_ == SCALAR; } @@ -78,25 +81,57 @@ class ARROW_EXPORT ExprType : public util::EqualityComparable { bool IsTable() const { return shape_ == TABLE; } template - bool HasType() const { - return (shape_ == SCALAR || shape_ == ARRAY) && - util::get>(type_)->id() == TYPE_ID; + bool IsTypedLike() const { + return (IsScalar() || IsArray()) && data_type_->id() == TYPE_ID; } /// \brief Indicate if the type is a predicate, i.e. a boolean scalar. - bool IsPredicate() const { return IsScalar() && HasType(); } + bool IsPredicate() const { return IsTypedLike(); } + + /// \brief Cast to DataType/Schema if the shape allows it. + Result CastTo(const std::shared_ptr& data_type) const; + Result CastTo(const std::shared_ptr& schema) const; + + /// \brief Broadcasting align two types to the largest shape. + /// + /// \param[in] lhs, first type to broadcast + /// \param[in] rhs, second type to broadcast + /// \return broadcasted type or an error why it can't be broadcasted. + /// + /// Broadcasting promotes the shape of the smallest type to the bigger one if + /// they share the same DataType. In functional pattern matching it would look + /// like: + /// + /// ``` + /// Broadcast(rhs, lhs) = match(lhs, rhs) { + /// case: ScalarType(t1), ScalarType(t2) if t1 == t2 => ScalarType(t) + /// case: ScalarType(t1), ArrayType(t2) if t1 == t2 => ArrayType(t) + /// case: ArrayType(t1), ScalarType(t2) if t1 == t2 => ArrayType(t) + /// case: ArrayType(t1), ArrayType(t2) if t1 == t2 => ArrayType(t) + /// case: _ => Error("Types not compatible for broadcasting") + /// } + /// ``` + static Result Broadcast(const ExprType& lhs, const ExprType& rhs); bool Equals(const ExprType& type) const; std::string ToString() const; + ExprType(const ExprType& copy); + ExprType(ExprType&& copy); + ~ExprType(); + private: /// Table constructor ExprType(std::shared_ptr schema, Shape shape); /// Scalar or Array constructor ExprType(std::shared_ptr type, Shape shape); - util::variant, std::shared_ptr> type_; + union { + // Zero initialize the pointer or Copy/Assign constructors will fail. + std::shared_ptr data_type_{}; + std::shared_ptr schema_; + }; Shape shape_; }; @@ -104,7 +139,7 @@ class ARROW_EXPORT ExprType : public util::EqualityComparable { class ARROW_EXPORT Expr : public util::EqualityComparable { public: // Tag identifier for the expression type. - enum Kind { + enum Kind : uint8_t { // A Scalar literal, i.e. a constant. SCALAR_LITERAL, // A Field reference in a schema. @@ -137,7 +172,7 @@ class ARROW_EXPORT Expr : public util::EqualityComparable { std::string kind_name() const; /// \brief Return the type and shape of the resulting expression. - virtual ExprType type() const = 0; + const ExprType& type() const { return type_; } /// \brief Indicate if the expressions bool Equals(const Expr& other) const; @@ -149,8 +184,9 @@ class ARROW_EXPORT Expr : public util::EqualityComparable { virtual ~Expr() = default; protected: - explicit Expr(Kind kind) : kind_(kind) {} + explicit Expr(Kind kind, ExprType type) : type_(std::move(type)), kind_(kind) {} + ExprType type_; Kind kind_; }; @@ -226,8 +262,6 @@ class ARROW_EXPORT ScalarExpr : public Expr { const std::shared_ptr& scalar() const { return scalar_; } - ExprType type() const override; - private: explicit ScalarExpr(std::shared_ptr scalar); @@ -241,8 +275,6 @@ class ARROW_EXPORT FieldRefExpr : public Expr { const std::shared_ptr& field() const { return field_; } - ExprType type() const override; - private: explicit FieldRefExpr(std::shared_ptr field); @@ -280,24 +312,29 @@ class ARROW_EXPORT BinaryOpExpr { // Comparison expressions // -Status ValidateCompareOpInputs(const std::shared_ptr& left, - const std::shared_ptr& right); - template class ARROW_EXPORT CmpOpExpr : public BinaryOpExpr, public Expr { public: - ExprType type() const override { return ExprType::Scalar(boolean()); }; - static Result> Make(std::shared_ptr left, std::shared_ptr right) { - ARROW_RETURN_NOT_OK(ValidateCompareOpInputs(left, right)); - return std::shared_ptr(new Self(std::move(left), std::move(right))); + if (left == NULLPTR || right == NULLPTR) { + return Status::Invalid("Compare operands must be non-nulls"); + } + + // Broadcast ensures that types are compatible in shape and type. + auto broadcast = ExprType::Broadcast(left->type(), right->type()); + // The type of comparison is always a boolean predicate. + auto cast = [](const ExprType& t) { return t.CastTo(boolean()); }; + ARROW_ASSIGN_OR_RAISE(auto type, broadcast.Map(cast)); + + return std::shared_ptr( + new Self(std::move(type), std::move(left), std::move(right))); } protected: - CmpOpExpr(std::shared_ptr left, std::shared_ptr right) + CmpOpExpr(ExprType type, std::shared_ptr left, std::shared_ptr right) : BinaryOpExpr(std::move(left), std::move(right)), - Expr(expr_traits::kind_id) {} + Expr(expr_traits::kind_id, std::move(type)) {} }; class ARROW_EXPORT EqualCmpExpr : public CmpOpExpr { @@ -337,13 +374,12 @@ class ARROW_EXPORT LessEqualThanCmpExpr : public CmpOpExpr template class ARROW_EXPORT RelExpr : public Expr { public: - ExprType type() const override { return ExprType::Table(schema_); } - const std::shared_ptr& schema() const { return schema_; } protected: explicit RelExpr(std::shared_ptr schema) - : Expr(expr_traits::kind_id), schema_(std::move(schema)) {} + : Expr(expr_traits::kind_id, ExprType::Table(schema)), + schema_(std::move(schema)) {} std::shared_ptr schema_; }; diff --git a/cpp/src/arrow/engine/expression_test.cc b/cpp/src/arrow/engine/expression_test.cc index 153ee3a999a..557e8a0d61b 100644 --- a/cpp/src/arrow/engine/expression_test.cc +++ b/cpp/src/arrow/engine/expression_test.cc @@ -37,7 +37,6 @@ TEST_F(ExprTypeTest, Basic) { auto scalar = ExprType::Scalar(i32); EXPECT_EQ(scalar.shape(), ExprType::Shape::SCALAR); EXPECT_TRUE(scalar.data_type()->Equals(i32)); - EXPECT_EQ(scalar.schema(), nullptr); EXPECT_TRUE(scalar.IsScalar()); EXPECT_FALSE(scalar.IsArray()); EXPECT_FALSE(scalar.IsTable()); @@ -45,14 +44,12 @@ TEST_F(ExprTypeTest, Basic) { auto array = ExprType::Array(i32); EXPECT_EQ(array.shape(), ExprType::Shape::ARRAY); EXPECT_TRUE(array.data_type()->Equals(i32)); - EXPECT_EQ(array.schema(), nullptr); EXPECT_FALSE(array.IsScalar()); EXPECT_TRUE(array.IsArray()); EXPECT_FALSE(array.IsTable()); auto table = ExprType::Table(s); EXPECT_EQ(table.shape(), ExprType::Shape::TABLE); - EXPECT_EQ(table.data_type(), nullptr); EXPECT_TRUE(table.schema()->Equals(s)); EXPECT_FALSE(table.IsScalar()); EXPECT_FALSE(table.IsArray()); @@ -64,7 +61,7 @@ TEST_F(ExprTypeTest, IsPredicate) { EXPECT_TRUE(bool_scalar.IsPredicate()); auto bool_array = ExprType::Array(boolean()); - EXPECT_FALSE(bool_array.IsPredicate()); + EXPECT_TRUE(bool_array.IsPredicate()); auto bool_table = ExprType::Table(schema({field("b", boolean())})); EXPECT_FALSE(bool_table.IsPredicate()); @@ -73,6 +70,43 @@ TEST_F(ExprTypeTest, IsPredicate) { EXPECT_FALSE(i32_scalar.IsPredicate()); } +TEST_F(ExprTypeTest, Broadcast) { + auto bool_scalar = ExprType::Scalar(boolean()); + auto bool_array = ExprType::Array(boolean()); + auto bool_table = ExprType::Table(schema({field("b", boolean())})); + auto i32_scalar = ExprType::Scalar(int32()); + + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("operands must be of same type"), + ExprType::Broadcast(bool_scalar, i32_scalar)); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("operands must not be tables"), + ExprType::Broadcast(bool_scalar, bool_table)); + + EXPECT_THAT(ExprType::Broadcast(bool_scalar, bool_scalar), OkAndEq(bool_scalar)); + EXPECT_THAT(ExprType::Broadcast(bool_scalar, bool_array), OkAndEq(bool_array)); + EXPECT_THAT(ExprType::Broadcast(bool_array, bool_scalar), OkAndEq(bool_array)); + EXPECT_THAT(ExprType::Broadcast(bool_array, bool_array), OkAndEq(bool_array)); +} + +TEST_F(ExprTypeTest, CastTo) { + auto bool_scalar = ExprType::Scalar(boolean()); + auto bool_array = ExprType::Array(boolean()); + auto bool_table = ExprType::Table(schema({field("b", boolean())})); + + auto i32 = int32(); + auto other = schema({field("a", i32)}); + + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("Cannot cast a ScalarType with"), + bool_scalar.CastTo(other)); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("Cannot cast an ArrayType with"), + bool_array.CastTo(other)); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("Cannot cast a TableType with"), + bool_table.CastTo(i32)); + + EXPECT_EQ(bool_scalar.CastTo(i32), ExprType::Scalar(i32)); + EXPECT_EQ(bool_array.CastTo(i32), ExprType::Array(i32)); + EXPECT_EQ(bool_table.CastTo(other), ExprType::Table(other)); +} + class ExprTest : public testing::Test {}; TEST_F(ExprTest, ScalarExpr) { @@ -94,8 +128,8 @@ TEST_F(ExprTest, FieldRefExpr) { ASSERT_OK_AND_ASSIGN(auto expr, FieldRefExpr::Make(f_i32)); EXPECT_EQ(expr->kind(), Expr::FIELD_REFERENCE); - EXPECT_EQ(expr->type(), ExprType::Scalar(i32)); - EXPECT_THAT(expr->field(), IsPtrEqual(f_i32)); + EXPECT_EQ(expr->type(), ExprType::Array(i32)); + EXPECT_THAT(expr->field(), PtrEquals(f_i32)); } template @@ -134,18 +168,19 @@ TYPED_TEST(CmpExprTest, BasicCompareExpr) { ASSERT_OK_AND_ASSIGN(auto expr, this->Make(f_expr, s_expr)); EXPECT_EQ(expr->kind(), this->kind()); - EXPECT_EQ(expr->type(), ExprType::Scalar(boolean())); + // Ensure type is broadcasted + EXPECT_EQ(expr->type(), ExprType::Array(boolean())); EXPECT_TRUE(expr->type().IsPredicate()); - EXPECT_THAT(expr, IsPtrEqual(expr)); - EXPECT_THAT(expr->left_operand(), IsPtrEqual(f_expr)); - EXPECT_THAT(expr->right_operand(), IsPtrEqual(s_expr)); + EXPECT_THAT(expr, PtrEquals(expr)); + EXPECT_THAT(expr->left_operand(), PtrEquals(f_expr)); + EXPECT_THAT(expr->right_operand(), PtrEquals(s_expr)); ASSERT_OK_AND_ASSIGN(auto other, this->Make(f_expr, s_expr)); - EXPECT_THAT(expr, IsPtrEqual(other)); + EXPECT_THAT(expr, PtrEquals(other)); // Compare operators supports commutativity // TODO(fsaintjacques): what about floating point types? ASSERT_OK_AND_ASSIGN(auto swapped, this->Make(s_expr, f_expr)); - EXPECT_THAT(expr, IsPtrEqual(swapped)); + EXPECT_THAT(expr, PtrEquals(swapped)); } class RelExprTest : public ExprTest { @@ -167,23 +202,23 @@ TEST_F(RelExprTest, EmptyRelExpr) { ASSERT_OK_AND_ASSIGN(auto empty, EmptyRelExpr::Make(schema_1)); EXPECT_THAT(empty->type(), ExprType::Table(schema_1)); - EXPECT_THAT(empty->schema(), IsPtrEqual(schema_1)); - EXPECT_THAT(empty, IsPtrEqual(empty)); + EXPECT_THAT(empty->schema(), PtrEquals(schema_1)); + EXPECT_THAT(empty, PtrEquals(empty)); ASSERT_OK_AND_ASSIGN(auto other, EmptyRelExpr::Make(schema_1)); - EXPECT_THAT(other, IsPtrEqual(empty)); + EXPECT_THAT(other, PtrEquals(empty)); } TEST_F(RelExprTest, ScanRelExpr) { ASSERT_OK_AND_ASSIGN(auto table, catalog->Get(table_1)); ASSERT_OK_AND_ASSIGN(auto scan, ScanRelExpr::Make(table)); - EXPECT_THAT(scan, IsPtrEqual(scan)); + EXPECT_THAT(scan, PtrEquals(scan)); EXPECT_THAT(scan->type(), ExprType::Table(schema_1)); - EXPECT_THAT(scan->schema(), IsPtrEqual(schema_1)); + EXPECT_THAT(scan->schema(), PtrEquals(schema_1)); ASSERT_OK_AND_ASSIGN(auto other, ScanRelExpr::Make(table)); - EXPECT_THAT(other, IsPtrEqual(scan)); + EXPECT_THAT(other, PtrEquals(scan)); } TEST_F(RelExprTest, FilterRelExpr) { @@ -200,11 +235,11 @@ TEST_F(RelExprTest, FilterRelExpr) { FilterRelExpr::Make(empty, empty)); ASSERT_OK_AND_ASSIGN(auto filter, FilterRelExpr::Make(empty, pred)); - EXPECT_THAT(filter, IsPtrEqual(filter)); + EXPECT_THAT(filter, PtrEquals(filter)); EXPECT_THAT(filter->type(), ExprType::Table(schema_1)); - EXPECT_THAT(filter->schema(), IsPtrEqual(schema_1)); - EXPECT_THAT(filter->operand(), IsPtrEqual(empty)); - EXPECT_THAT(filter->predicate(), IsPtrEqual(pred)); + EXPECT_THAT(filter->schema(), PtrEquals(schema_1)); + EXPECT_THAT(filter->operand(), PtrEquals(empty)); + EXPECT_THAT(filter->predicate(), PtrEquals(pred)); } } // namespace engine diff --git a/cpp/src/arrow/testing/gmock.h b/cpp/src/arrow/testing/gmock.h index f570f586031..5b2f2f19502 100644 --- a/cpp/src/arrow/testing/gmock.h +++ b/cpp/src/arrow/testing/gmock.h @@ -21,8 +21,11 @@ namespace arrow { -MATCHER_P(IsEqual, other, "") { return arg.Equals(other); } +using testing::Eq; +using testing::HasSubstr; -MATCHER_P(IsPtrEqual, other, "") { return arg->Equals(*other); } +MATCHER_P(Equals, other, "") { return arg.Equals(other); } +MATCHER_P(PtrEquals, other, "") { return arg->Equals(*other); } +MATCHER_P(OkAndEq, other, "") { return arg.ok() && arg.ValueOrDie() == other; } } // namespace arrow From 92bf42597d57f348e965fa332fce6ec0a9ccc651 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Saint-Jacques?= Date: Fri, 28 Feb 2020 16:44:03 -0500 Subject: [PATCH 06/21] Add projection operator --- cpp/src/arrow/engine/expression.cc | 52 +++++++++++++ cpp/src/arrow/engine/expression.h | 35 +++++++++ cpp/src/arrow/engine/expression_test.cc | 4 + cpp/src/arrow/engine/logical_plan.cc | 88 ++++++++++++++++++++-- cpp/src/arrow/engine/logical_plan.h | 44 +++++++++-- cpp/src/arrow/engine/logical_plan_test.cc | 91 ++++++++++++++++++----- cpp/src/arrow/engine/type_fwd.h | 1 + cpp/src/arrow/engine/type_traits.h | 2 +- 8 files changed, 285 insertions(+), 32 deletions(-) diff --git a/cpp/src/arrow/engine/expression.cc b/cpp/src/arrow/engine/expression.cc index aaed724fb87..9f6a5e98f50 100644 --- a/cpp/src/arrow/engine/expression.cc +++ b/cpp/src/arrow/engine/expression.cc @@ -186,6 +186,8 @@ std::string Expr::kind_name() const { return "empty_rel"; case Expr::SCAN_REL: return "scan_rel"; + case Expr::PROJECTION_REL: + return "projection_rel"; case Expr::FILTER_REL: return "filter_rel"; } @@ -224,6 +226,24 @@ struct ExprEqualityVisitor { return lhs_scan.input() == rhs.input(); } + bool operator()(const ProjectionRelExpr& rhs) const { + auto lhs_proj = internal::checked_cast(lhs); + + const auto& lhs_exprs = lhs_proj.expressions(); + const auto& rhs_exprs = rhs.expressions(); + if (lhs_exprs.size() != rhs_exprs.size()) { + return false; + } + + for (size_t i = 0; i < lhs_exprs.size(); i++) { + if (!lhs_exprs[i]->Equals(rhs_exprs[i])) { + return false; + } + } + + return true; + } + bool operator()(const Expr&) const { return false; } static bool Visit(const Expr& lhs, const Expr& rhs) { @@ -291,6 +311,38 @@ Result> ScanRelExpr::Make(Catalog::Entry input) { return std::shared_ptr(new ScanRelExpr(std::move(input))); } +// +// ProjectionRelExpr +// + +ProjectionRelExpr::ProjectionRelExpr(std::shared_ptr input, + std::shared_ptr schema, + std::vector> expressions) + : UnaryOpExpr(std::move(input)), + RelExpr(std::move(schema)), + expressions_(std::move(expressions)) {} + +Result> ProjectionRelExpr::Make( + std::shared_ptr input, std::vector> expressions) { + ERROR_IF(input == nullptr, "ProjectionRelExpr's input must be non-null."); + ERROR_IF(expressions.empty(), "Must project at least one column."); + + auto n_fields = expressions.size(); + std::vector> fields; + + for (size_t i = 0; i < n_fields; i++) { + const auto& expr = expressions[i]; + const auto& type = expr->type(); + ERROR_IF(!type.IsArray(), "Expression at position ", i, " not of Array type"); + // TODO(fsaintjacques): better name handling. Callers should be able to + // pass a vector of names. + fields.push_back(field("expr", type.data_type())); + } + + return std::shared_ptr(new ProjectionRelExpr( + std::move(input), arrow::schema(std::move(fields)), std::move(expressions))); +} + // // FilterRelExpr // diff --git a/cpp/src/arrow/engine/expression.h b/cpp/src/arrow/engine/expression.h index 8a49332c81a..75028d89da6 100644 --- a/cpp/src/arrow/engine/expression.h +++ b/cpp/src/arrow/engine/expression.h @@ -162,6 +162,8 @@ class ARROW_EXPORT Expr : public util::EqualityComparable { EMPTY_REL, // Scan relational operator SCAN_REL, + // Projection relational operator + PROJECTION_REL, // Filter relational operator FILTER_REL, }; @@ -246,6 +248,11 @@ struct expr_traits { static constexpr Expr::Kind kind_id = Expr::SCAN_REL; }; +template <> +struct expr_traits { + static constexpr Expr::Kind kind_id = Expr::PROJECTION_REL; +}; + template <> struct expr_traits { static constexpr Expr::Kind kind_id = Expr::FILTER_REL; @@ -308,6 +315,17 @@ class ARROW_EXPORT BinaryOpExpr { std::shared_ptr right_operand_; }; +class ARROW_EXPORT MultiAryOpExpr { + public: + const std::vector>& operands() const { return operands_; } + + protected: + explicit MultiAryOpExpr(std::vector> operands) + : operands_(std::move(operands)) {} + + std::vector> operands_; +}; + // // Comparison expressions // @@ -404,6 +422,21 @@ class ARROW_EXPORT ScanRelExpr : public RelExpr { Catalog::Entry input_; }; +class ARROW_EXPORT ProjectionRelExpr : public UnaryOpExpr, + public RelExpr { + public: + static Result> Make( + std::shared_ptr input, std::vector> expressions); + + const std::vector> expressions() const { return expressions_; } + + private: + ProjectionRelExpr(std::shared_ptr input, std::shared_ptr schema, + std::vector> expressions); + + std::vector> expressions_; +}; + class ARROW_EXPORT FilterRelExpr : public UnaryOpExpr, public RelExpr { public: static Result> Make(std::shared_ptr input, @@ -442,6 +475,8 @@ auto VisitExpr(const Expr& expr, Visitor&& visitor) -> decltype(visitor(expr)) { return visitor(internal::checked_cast(expr)); case Expr::SCAN_REL: return visitor(internal::checked_cast(expr)); + case Expr::PROJECTION_REL: + return visitor(internal::checked_cast(expr)); case Expr::FILTER_REL: // LEAVE LAST or update the outer return cast by moving it here. This is // required for older compiler support. diff --git a/cpp/src/arrow/engine/expression_test.cc b/cpp/src/arrow/engine/expression_test.cc index 557e8a0d61b..79b37e51b8c 100644 --- a/cpp/src/arrow/engine/expression_test.cc +++ b/cpp/src/arrow/engine/expression_test.cc @@ -221,6 +221,10 @@ TEST_F(RelExprTest, ScanRelExpr) { EXPECT_THAT(other, PtrEquals(scan)); } +TEST_F(RelExprTest, ProjectionRelExpr) { + // TODO(fsaintjacques): FILLME +} + TEST_F(RelExprTest, FilterRelExpr) { ASSERT_OK_AND_ASSIGN(auto empty, EmptyRelExpr::Make(schema_1)); ASSERT_OK_AND_ASSIGN(auto pred, ScalarExpr::Make(MakeScalar(true))); diff --git a/cpp/src/arrow/engine/logical_plan.cc b/cpp/src/arrow/engine/logical_plan.cc index 93a12bc65ad..b9fc01946c3 100644 --- a/cpp/src/arrow/engine/logical_plan.cc +++ b/cpp/src/arrow/engine/logical_plan.cc @@ -21,6 +21,7 @@ #include "arrow/engine/expression.h" #include "arrow/result.h" +#include "arrow/type.h" #include "arrow/util/logging.h" namespace arrow { @@ -34,6 +35,8 @@ LogicalPlan::LogicalPlan(std::shared_ptr root) : root_(std::move(root)) { DCHECK_NE(root_, nullptr); } +const ExprType& LogicalPlan::type() const { return root()->type(); } + bool LogicalPlan::Equals(const LogicalPlan& other) const { if (this == &other) { return true; @@ -53,13 +56,14 @@ LogicalPlanBuilder::LogicalPlanBuilder(LogicalPlanBuilderOptions options) using ResultExpr = LogicalPlanBuilder::ResultExpr; -#define ERROR_IF(cond, ...) \ - do { \ - if (ARROW_PREDICT_FALSE(cond)) { \ - return Status::Invalid(__VA_ARGS__); \ - } \ +#define ERROR_IF_TYPE(cond, ErrorType, ...) \ + do { \ + if (ARROW_PREDICT_FALSE(cond)) { \ + return Status::ErrorType(__VA_ARGS__); \ + } \ } while (false) +#define ERROR_IF(cond, ...) ERROR_IF_TYPE(cond, Invalid, __VA_ARGS__) // // Leaf builder. // @@ -76,7 +80,27 @@ ResultExpr LogicalPlanBuilder::Field(const std::shared_ptr& input, ERROR_IF(!expr_type.IsTable(), "Input expression does not have a Table shape."); auto field = expr_type.schema()->GetFieldByName(field_name); - ERROR_IF(field == nullptr, "Cannot reference field '", field_name, "' in schema."); + ERROR_IF_TYPE(field == nullptr, KeyError, "Cannot reference field '", field_name, + "' in schema."); + + return FieldRefExpr::Make(std::move(field)); +} + +ResultExpr LogicalPlanBuilder::Field(const std::shared_ptr& input, + int field_index) { + ERROR_IF(input == nullptr, "Input expression must be non-null"); + ERROR_IF_TYPE(field_index < 0, KeyError, "Field index must be positive"); + + auto expr_type = input->type(); + ERROR_IF(!expr_type.IsTable(), "Input expression does not have a Table shape."); + + auto schema = expr_type.schema(); + auto num_fields = schema->num_fields(); + ERROR_IF_TYPE(field_index >= num_fields, KeyError, "Field index ", field_index, + " out of bounds."); + + auto field = expr_type.schema()->field(field_index); + ERROR_IF_TYPE(field == nullptr, KeyError, "Field at index ", field_index, " is null."); return FieldRefExpr::Make(std::move(field)); } @@ -98,6 +122,58 @@ ResultExpr LogicalPlanBuilder::Filter(const std::shared_ptr& input, return FilterRelExpr::Make(std::move(input), std::move(predicate)); } +ResultExpr LogicalPlanBuilder::Project( + const std::shared_ptr& input, + const std::vector>& expressions) { + ERROR_IF(input == nullptr, "Input expression can't be null."); + ERROR_IF(expressions.empty(), "Must have at least one expression."); + return ProjectionRelExpr::Make(input, expressions); +} + +ResultExpr LogicalPlanBuilder::Mutate( + const std::shared_ptr& input, + const std::vector>& expressions) { + return Project(input, expressions); +} + +ResultExpr LogicalPlanBuilder::Project(const std::shared_ptr& input, + const std::vector& column_names) { + ERROR_IF(input == nullptr, "Input expression can't be null."); + ERROR_IF(column_names.empty(), "Must have at least one column name."); + + std::vector> expressions{column_names.size()}; + for (size_t i = 0; i < column_names.size(); i++) { + ARROW_ASSIGN_OR_RAISE(expressions[i], Field(input, column_names[i])); + } + + // TODO(fsaintjacques): preserve field names. + return Project(input, expressions); +} + +ResultExpr LogicalPlanBuilder::Select(const std::shared_ptr& input, + const std::vector& column_names) { + return Project(input, column_names); +} + +ResultExpr LogicalPlanBuilder::Project(const std::shared_ptr& input, + const std::vector& column_indices) { + ERROR_IF(input == nullptr, "Input expression can't be null."); + ERROR_IF(column_indices.empty(), "Must have at least one column index."); + + std::vector> expressions{column_indices.size()}; + for (size_t i = 0; i < column_indices.size(); i++) { + ARROW_ASSIGN_OR_RAISE(expressions[i], Field(input, column_indices[i])); + } + + // TODO(fsaintjacques): preserve field names. + return Project(input, expressions); +} + +ResultExpr LogicalPlanBuilder::Select(const std::shared_ptr& input, + const std::vector& column_indices) { + return Project(input, column_indices); +} + #undef ERROR_IF } // namespace engine diff --git a/cpp/src/arrow/engine/logical_plan.h b/cpp/src/arrow/engine/logical_plan.h index 2342dd7117e..5296947e1c0 100644 --- a/cpp/src/arrow/engine/logical_plan.h +++ b/cpp/src/arrow/engine/logical_plan.h @@ -18,9 +18,7 @@ #pragma once #include -#include #include -#include #include #include "arrow/type_fwd.h" @@ -37,12 +35,14 @@ namespace engine { class Catalog; class Expr; +class ExprType; class LogicalPlan : public util::EqualityComparable { public: explicit LogicalPlan(std::shared_ptr root); const std::shared_ptr& root() const { return root_; } + const ExprType& type() const; bool Equals(const LogicalPlan& other) const; std::string ToString() const; @@ -52,6 +52,7 @@ class LogicalPlan : public util::EqualityComparable { }; struct LogicalPlanBuilderOptions { + /// Catalog containing named tables. std::shared_ptr catalog; }; @@ -69,34 +70,61 @@ class LogicalPlanBuilder { /// \brief References a field by name. ResultExpr Field(const std::shared_ptr& input, const std::string& field_name); + /// \brief References a field by index. + ResultExpr Field(const std::shared_ptr& input, int field_index); /// \brief Scan a Table/Dataset from the Catalog. ResultExpr Scan(const std::string& table_name); /// @} + /// \defgroup comparator-nodes Comparison operators + /// @{ + + /* + TODO(fsaintjacques): This. + ResultExpr Equal(const std::shared_ptr& lhs, const std::shared_ptr& rhs); + ResultExpr NotEqual(const std::shared_ptr& lhs, const std::shared_ptr& rhs); + ResultExpr GreaterThan(const std::shared_ptr& lhs, + const std::shared_ptr& rhs); + ResultExpr GreaterEqualThan(const std::shared_ptr& lhs, + const std::shared_ptr& rhs); + ResultExpr LessThan(const std::shared_ptr& lhs, const std::shared_ptr& rhs); + ResultExpr LessEqualThan(const std::shared_ptr& lhs, + const std::shared_ptr& rhs); + */ + + /// @} + /// \defgroup rel-nodes Relational operator nodes in the logical plan + /// \brief Filter rows of a relation with the given predicate. ResultExpr Filter(const std::shared_ptr& input, const std::shared_ptr& predicate); - /* /// \brief Project (mutate) columns with given expressions. - ResultExpr Project(const std::vector>& expressions); - ResultExpr Mutate(const std::vector>& expressions); + ResultExpr Project(const std::shared_ptr& input, + const std::vector>& expressions); + ResultExpr Mutate(const std::shared_ptr& input, + const std::vector>& expressions); /// \brief Project (select) columns by names. /// /// This is a simplified version of Project where columns are selected by /// names. Duplicate and ordering are preserved. - ResultExpr Project(const std::vector& column_names); + ResultExpr Project(const std::shared_ptr& input, + const std::vector& column_names); + ResultExpr Select(const std::shared_ptr& input, + const std::vector& column_names); /// \brief Project (select) columns by indices. /// /// This is a simplified version of Project where columns are selected by /// indices. Duplicate and ordering are preserved. - ResultExpr Project(const std::vector& column_indices); - */ + ResultExpr Project(const std::shared_ptr& input, + const std::vector& column_indices); + ResultExpr Select(const std::shared_ptr& input, + const std::vector& column_indices); /// @} diff --git a/cpp/src/arrow/engine/logical_plan_test.cc b/cpp/src/arrow/engine/logical_plan_test.cc index 200918b1979..609e2f74d63 100644 --- a/cpp/src/arrow/engine/logical_plan_test.cc +++ b/cpp/src/arrow/engine/logical_plan_test.cc @@ -18,6 +18,7 @@ #include #include "arrow/engine/catalog.h" +#include "arrow/engine/expression.h" #include "arrow/engine/logical_plan.h" #include "arrow/testing/gtest_common.h" @@ -26,42 +27,98 @@ using testing::HasSubstr; namespace arrow { namespace engine { +using ResultExpr = LogicalPlanBuilder::ResultExpr; + class LogicalPlanBuilderTest : public testing::Test { protected: void SetUp() override { - CatalogBuilder builder; - ASSERT_OK(builder.Add(table_1, MockTable(schema_1))); - ASSERT_OK_AND_ASSIGN(catalog, builder.Finish()); + CatalogBuilder catalog_builder; + ASSERT_OK(catalog_builder.Add(table_1, MockTable(schema_1))); + ASSERT_OK_AND_ASSIGN(options.catalog, catalog_builder.Finish()); + builder = LogicalPlanBuilder{options}; + } + + ResultExpr scalar_expr() { + auto forthy_two = MakeScalar(42); + return builder.Scalar(forthy_two); + } + + ResultExpr scan_expr() { return builder.Scan(table_1); } + + template + ResultExpr field_expr(T key, std::shared_ptr input = nullptr) { + if (input == nullptr) { + ARROW_ASSIGN_OR_RAISE(input, scan_expr()); + } + return builder.Field(input, key); } + + ResultExpr predicate_expr() { return nullptr; } + std::string table_1 = "table_1"; - std::shared_ptr schema_1 = schema({field("i32", int32())}); - std::shared_ptr catalog; + std::shared_ptr schema_1 = schema({ + field("bool", boolean()), + field("i32", int32()), + field("u64", uint64()), + field("f32", uint32()), + }); + LogicalPlanBuilderOptions options{}; + LogicalPlanBuilder builder{}; }; TEST_F(LogicalPlanBuilderTest, Scalar) { - LogicalPlanBuilder builder{{catalog}}; auto forthy_two = MakeScalar(42); EXPECT_OK_AND_ASSIGN(auto scalar, builder.Scalar(forthy_two)); } +TEST_F(LogicalPlanBuilderTest, FieldReferences) { + ASSERT_RAISES(Invalid, builder.Field(nullptr, "i32")); + ASSERT_RAISES(Invalid, builder.Field(nullptr, 0)); + + // Can't lookup a scalar + EXPECT_OK_AND_ASSIGN(auto scalar, scalar_expr()); + ASSERT_RAISES(Invalid, builder.Field(scalar, "i32")); + + EXPECT_OK_AND_ASSIGN(auto table, scan_expr()); + ASSERT_RAISES(KeyError, builder.Field(table, "")); + ASSERT_RAISES(KeyError, builder.Field(table, -1)); + ASSERT_RAISES(KeyError, builder.Field(table, 9000)); + + EXPECT_OK_AND_ASSIGN(auto field_name_ref, builder.Field(table, "i32")); + EXPECT_OK_AND_ASSIGN(auto field_idx_ref, builder.Field(table, 0)); +} + TEST_F(LogicalPlanBuilderTest, BasicScan) { - LogicalPlanBuilder builder{{catalog}}; + LogicalPlanBuilder builder{options}; + ASSERT_RAISES(KeyError, builder.Scan("")); + ASSERT_RAISES(KeyError, builder.Scan("not_found")); ASSERT_OK(builder.Scan(table_1)); } -TEST_F(LogicalPlanBuilderTest, FieldReferenceByName) { - LogicalPlanBuilder builder{{catalog}}; +TEST_F(LogicalPlanBuilderTest, Filter) { + EXPECT_OK_AND_ASSIGN(auto table, scan_expr()); - // Input must be non-null. - ASSERT_RAISES(Invalid, builder.Field(nullptr, "i32")); + EXPECT_OK_AND_ASSIGN(auto field, field_expr("i32", table)); + EXPECT_OK_AND_ASSIGN(auto scalar, scalar_expr()); + EXPECT_OK_AND_ASSIGN(auto predicate, EqualCmpExpr::Make(field, scalar)); - // The input must have a Table shape. - auto forthy_two = MakeScalar(42); - EXPECT_OK_AND_ASSIGN(auto scalar, builder.Scalar(forthy_two)); - ASSERT_RAISES(Invalid, builder.Field(scalar, "not_found")); + EXPECT_OK_AND_ASSIGN(auto filter, builder.Filter(table, predicate)); +} + +TEST_F(LogicalPlanBuilderTest, ProjectionByNamesAndIndices) { + EXPECT_OK_AND_ASSIGN(auto table, scan_expr()); + + std::vector no_names{}; + ASSERT_RAISES(Invalid, builder.Project(table, no_names)); + std::vector invalid_names{"u64", "nope"}; + ASSERT_RAISES(KeyError, builder.Project(table, invalid_names)); + std::vector invalid_idx{42, 0}; + ASSERT_RAISES(KeyError, builder.Project(table, invalid_idx)); - EXPECT_OK_AND_ASSIGN(auto table_scan, builder.Scan(table_1)); - EXPECT_OK_AND_ASSIGN(auto field_ref, builder.Field(table_scan, "i32")); + std::vector valid_names{"u64", "f32"}; + ASSERT_OK(builder.Project(table, valid_names)); + std::vector valid_idx{3, 1, 1}; + ASSERT_OK(builder.Project(table, valid_idx)); } } // namespace engine diff --git a/cpp/src/arrow/engine/type_fwd.h b/cpp/src/arrow/engine/type_fwd.h index a497866ba85..a3292704c72 100644 --- a/cpp/src/arrow/engine/type_fwd.h +++ b/cpp/src/arrow/engine/type_fwd.h @@ -33,6 +33,7 @@ class LessThanCmpExpr; class LessEqualThanCmpExpr; class EmptyRelExpr; class ScanRelExpr; +class ProjectionRelExpr; class FilterRelExpr; } // namespace engine diff --git a/cpp/src/arrow/engine/type_traits.h b/cpp/src/arrow/engine/type_traits.h index 2982ddbfb63..52abeb6a99d 100644 --- a/cpp/src/arrow/engine/type_traits.h +++ b/cpp/src/arrow/engine/type_traits.h @@ -30,7 +30,7 @@ template using enable_if_compare_expr = enable_if_t::value, Ret>; template -using is_relational_expr = std::is_base_of, E>; +using is_relational_expr = std::is_base_of, E>; template using enable_if_relational_expr = enable_if_t::value, Ret>; From 8ea5ddd716dbb6f0f9373ef1759d4a58de1d0fa8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Saint-Jacques?= Date: Mon, 9 Mar 2020 09:46:11 -0400 Subject: [PATCH 07/21] Removes Table support from Catalog::Entry --- cpp/src/arrow/engine/catalog.cc | 87 +++++++--------------------- cpp/src/arrow/engine/catalog.h | 25 +++----- cpp/src/arrow/engine/catalog_test.cc | 6 +- 3 files changed, 32 insertions(+), 86 deletions(-) diff --git a/cpp/src/arrow/engine/catalog.cc b/cpp/src/arrow/engine/catalog.cc index fa1f21f4c90..a2f173142ce 100644 --- a/cpp/src/arrow/engine/catalog.cc +++ b/cpp/src/arrow/engine/catalog.cc @@ -32,12 +32,12 @@ namespace engine { using Entry = Catalog::Entry; -Catalog::Catalog(std::unordered_map tables) - : tables_(std::move(tables)) {} +Catalog::Catalog(std::unordered_map datasets) + : datasets_(std::move(datasets)) {} Result Catalog::Get(const std::string& key) const { - auto value = tables_.find(key); - if (value != tables_.end()) return value->second; + auto value = datasets_.find(key); + if (value != datasets_.end()) return value->second; return Status::KeyError("Table '", key, "' not found in catalog."); } @@ -48,10 +48,10 @@ Result> Catalog::GetSchema(const std::string& key) const return Get(key).Map(as_schema); } -Result> Catalog::Make(const std::vector& tables) { +Result> Catalog::Make(const std::vector& datasets) { CatalogBuilder builder; - for (const auto& key_val : tables) { + for (const auto& key_val : datasets) { RETURN_NOT_OK(builder.Add(key_val)); } @@ -62,52 +62,22 @@ Result> Catalog::Make(const std::vector& tables) // Catalog::Entry // -Entry::Entry(std::shared_ptr
table, std::string name) - : entry_(std::move(table)), name_(std::move(name)) {} - Entry::Entry(std::shared_ptr dataset, std::string name) - : entry_(std::move(dataset)), name_(std::move(name)) {} - -Entry::Kind Entry::kind() const { - if (util::holds_alternative>(entry_)) { - return TABLE; - } - - if (util::holds_alternative>(entry_)) { - return DATASET; - } + : dataset_(std::move(dataset)), name_(std::move(name)) {} - return UNKNOWN; -} - -std::shared_ptr
Entry::table() const { - if (kind() == TABLE) return util::get>(entry_); - return nullptr; -} +Entry::Entry(std::shared_ptr
table, std::string name) + : dataset_(std::make_shared(std::move(table))), + name_(std::move(name)) {} -std::shared_ptr Entry::dataset() const { - if (kind() == DATASET) return util::get>(entry_); - return nullptr; -} +const std::shared_ptr& Entry::dataset() const { return dataset_; } bool Entry::operator==(const Entry& other) const { // Entries are unique by name in a catalog, but we can still protect with // pointer equality. - return name_ == other.name_ && entry_ == other.entry_; + return name_ == other.name_ && dataset_ == other.dataset_; } -std::shared_ptr Entry::schema() const { - switch (kind()) { - case TABLE: - return table()->schema(); - case DATASET: - return dataset()->schema(); - default: - return nullptr; - } - - return nullptr; -} +const std::shared_ptr& Entry::schema() const { return dataset()->schema(); } // // CatalogBuilder @@ -119,41 +89,28 @@ Status CatalogBuilder::Add(Entry entry) { return Status::Invalid("Key in catalog can't be empty"); } - switch (entry.kind()) { - case Entry::TABLE: { - if (entry.table() == nullptr) { - return Status::Invalid("Table entry can't be null."); - } - break; - } - case Entry::DATASET: { - if (entry.dataset() == nullptr) { - return Status::Invalid("Table entry can't be null."); - } - break; - } - default: - return Status::NotImplemented("Unknown entry kind"); + if (entry.dataset() == nullptr) { + return Status::Invalid("Dataset entry can't be null."); } - auto inserted = tables_.emplace(name, std::move(entry)); + auto inserted = datasets_.emplace(name, std::move(entry)); if (!inserted.second) { - return Status::KeyError("Table '", name, "' already in catalog."); + return Status::KeyError("Dataset '", name, "' already in catalog."); } return Status::OK(); } -Status CatalogBuilder::Add(std::string name, std::shared_ptr
table) { - return Add(Entry(std::move(table), std::move(name))); -} - Status CatalogBuilder::Add(std::string name, std::shared_ptr dataset) { return Add(Entry(std::move(dataset), std::move(name))); } +Status CatalogBuilder::Add(std::string name, std::shared_ptr
table) { + return Add(Entry(std::move(table), std::move(name))); +} + Result> CatalogBuilder::Finish() { - return std::shared_ptr(new Catalog(std::move(tables_))); + return std::shared_ptr(new Catalog(std::move(datasets_))); } } // namespace engine diff --git a/cpp/src/arrow/engine/catalog.h b/cpp/src/arrow/engine/catalog.h index 25fbc6fb652..d4607320acd 100644 --- a/cpp/src/arrow/engine/catalog.h +++ b/cpp/src/arrow/engine/catalog.h @@ -46,48 +46,39 @@ class Catalog { class Entry { public: - enum Kind { - TABLE = 0, - DATASET, - UNKNOWN, - }; - - Entry(std::shared_ptr
table, std::string name); Entry(std::shared_ptr dataset, std::string name); - - Kind kind() const; + Entry(std::shared_ptr
table, std::string name); const std::string& name() const { return name_; } - std::shared_ptr
table() const; - std::shared_ptr dataset() const; + const std::shared_ptr& dataset() const; - std::shared_ptr schema() const; + const std::shared_ptr& schema() const; bool operator==(const Entry& other) const; private: - util::variant, std::shared_ptr> entry_; + std::shared_ptr dataset_; std::string name_; }; private: friend class CatalogBuilder; - explicit Catalog(std::unordered_map tables); + explicit Catalog(std::unordered_map datasets); - std::unordered_map tables_; + std::unordered_map datasets_; }; class CatalogBuilder { public: Status Add(Catalog::Entry entry); - Status Add(std::string name, std::shared_ptr
); Status Add(std::string name, std::shared_ptr); + Status Add(std::string name, std::shared_ptr
); Result> Finish(); private: - std::unordered_map tables_; + std::unordered_map datasets_; }; } // namespace engine diff --git a/cpp/src/arrow/engine/catalog_test.cc b/cpp/src/arrow/engine/catalog_test.cc index d1fc42de424..cb65109c722 100644 --- a/cpp/src/arrow/engine/catalog_test.cc +++ b/cpp/src/arrow/engine/catalog_test.cc @@ -41,8 +41,7 @@ class TestCatalog : public testing::Test { void AssertCatalogKeyIs(const std::shared_ptr& catalog, const std::string& key, const std::shared_ptr
& expected) { ASSERT_OK_AND_ASSIGN(auto t, catalog->Get(key)); - ASSERT_EQ(t.kind(), Catalog::Entry::Kind::TABLE); - AssertTablesEqual(*t.table(), *expected); + EXPECT_EQ(t.name(), key); ASSERT_OK_AND_ASSIGN(auto schema, catalog->GetSchema(key)); AssertSchemaEqual(*schema, *expected->schema()); @@ -127,8 +126,7 @@ TEST_F(TestCatalogBuilder, DuplicateKeys) { ASSERT_OK_AND_ASSIGN(auto catalog, builder.Finish()); ASSERT_OK_AND_ASSIGN(auto t, catalog->Get(key)); - ASSERT_EQ(t.kind(), Catalog::Entry::Kind::TABLE); - AssertTablesEqual(*t.table(), *table()); + EXPECT_EQ(t.name(), key); } } // namespace engine From 8b47016ab3559d865c7597fa2e32b83433e540c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Saint-Jacques?= Date: Wed, 11 Mar 2020 11:31:48 -0400 Subject: [PATCH 08/21] Add more comment --- cpp/src/arrow/engine/expression.cc | 37 ++-- cpp/src/arrow/engine/expression.h | 234 +++++++++++++++--------- cpp/src/arrow/engine/expression_test.cc | 18 +- 3 files changed, 177 insertions(+), 112 deletions(-) diff --git a/cpp/src/arrow/engine/expression.cc b/cpp/src/arrow/engine/expression.cc index 9f6a5e98f50..13511cc9c6e 100644 --- a/cpp/src/arrow/engine/expression.cc +++ b/cpp/src/arrow/engine/expression.cc @@ -40,6 +40,10 @@ ExprType ExprType::Table(std::shared_ptr schema) { return ExprType(std::move(schema), Shape::TABLE); } +ExprType ExprType::Table(std::vector> fields) { + return ExprType(arrow::schema(std::move(fields)), Shape::TABLE); +} + ExprType::ExprType(std::shared_ptr schema, Shape shape) : schema_(std::move(schema)), shape_(shape) { DCHECK_EQ(shape, Shape::TABLE); @@ -54,7 +58,7 @@ ExprType::ExprType(const ExprType& other) : shape_(other.shape()) { switch (other.shape()) { case SCALAR: case ARRAY: - data_type_ = other.data_type(); + data_type_ = other.type(); break; case TABLE: schema_ = other.schema(); @@ -65,7 +69,7 @@ ExprType::ExprType(ExprType&& other) : shape_(other.shape()) { switch (other.shape()) { case SCALAR: case ARRAY: - data_type_ = std::move(other.data_type()); + data_type_ = std::move(other.type()); break; case TABLE: schema_ = std::move(other.schema()); @@ -83,22 +87,22 @@ ExprType::~ExprType() { } } -bool ExprType::Equals(const ExprType& type) const { - if (this == &type) { +bool ExprType::Equals(const ExprType& other) const { + if (this == &other) { return true; } - if (shape() != type.shape()) { + if (shape() != other.shape()) { return false; } switch (shape()) { case SCALAR: - return data_type()->Equals(type.data_type()); + return type()->Equals(other.type()); case ARRAY: - return data_type()->Equals(type.data_type()); + return type()->Equals(other.type()); case TABLE: - return schema()->Equals(type.schema()); + return schema()->Equals(other.schema()); default: break; } @@ -106,7 +110,7 @@ bool ExprType::Equals(const ExprType& type) const { return false; } -Result ExprType::CastTo(const std::shared_ptr& data_type) const { +Result ExprType::WithType(const std::shared_ptr& data_type) const { switch (shape()) { case SCALAR: return ExprType::Scalar(data_type); @@ -118,7 +122,7 @@ Result ExprType::CastTo(const std::shared_ptr& data_type) co return Status::UnknownError("unreachable"); } -Result ExprType::CastTo(const std::shared_ptr& schema) const { +Result ExprType::WithSchema(const std::shared_ptr& schema) const { switch (shape()) { case SCALAR: return Status::Invalid("Cannot cast a ScalarType with a schema"); @@ -136,7 +140,7 @@ Result ExprType::Broadcast(const ExprType& lhs, const ExprType& rhs) { return Status::Invalid("Broadcast operands must not be tables"); } - if (!lhs.data_type()->Equals(rhs.data_type())) { + if (!lhs.type()->Equals(rhs.type())) { return Status::Invalid("Broadcast operands must be of same type"); } @@ -318,7 +322,7 @@ Result> ScanRelExpr::Make(Catalog::Entry input) { ProjectionRelExpr::ProjectionRelExpr(std::shared_ptr input, std::shared_ptr schema, std::vector> expressions) - : UnaryOpExpr(std::move(input)), + : UnaryOpMixin(std::move(input)), RelExpr(std::move(schema)), expressions_(std::move(expressions)) {} @@ -332,11 +336,12 @@ Result> ProjectionRelExpr::Make( for (size_t i = 0; i < n_fields; i++) { const auto& expr = expressions[i]; - const auto& type = expr->type(); - ERROR_IF(!type.IsArray(), "Expression at position ", i, " not of Array type"); + const auto& expr_type = expr->type(); + ERROR_IF(!expr_type.HasType(), "Expression at position ", i, + " should not be have a table shape"); // TODO(fsaintjacques): better name handling. Callers should be able to // pass a vector of names. - fields.push_back(field("expr", type.data_type())); + fields.push_back(field("expr", expr_type.type())); } return std::shared_ptr(new ProjectionRelExpr( @@ -360,7 +365,7 @@ Result> FilterRelExpr::Make( } FilterRelExpr::FilterRelExpr(std::shared_ptr input, std::shared_ptr predicate) - : UnaryOpExpr(std::move(input)), + : UnaryOpMixin(std::move(input)), RelExpr(operand()->type().schema()), predicate_(std::move(predicate)) {} diff --git a/cpp/src/arrow/engine/expression.h b/cpp/src/arrow/engine/expression.h index 75028d89da6..874cc643e0e 100644 --- a/cpp/src/arrow/engine/expression.h +++ b/cpp/src/arrow/engine/expression.h @@ -20,6 +20,7 @@ #include #include #include +#include #include "arrow/engine/catalog.h" #include "arrow/engine/type_fwd.h" @@ -27,7 +28,6 @@ #include "arrow/type.h" #include "arrow/type_fwd.h" #include "arrow/util/compare.h" -#include "arrow/util/variant.h" namespace arrow { namespace engine { @@ -62,13 +62,14 @@ class ARROW_EXPORT ExprType : public util::EqualityComparable { static ExprType Array(std::shared_ptr type); /// Construct a Table type. static ExprType Table(std::shared_ptr schema); + static ExprType Table(std::vector> fields); /// \brief Shape of the expression. Shape shape() const { return shape_; } /// \brief DataType of the expression if a scalar or an array. /// WARNING: You must ensure the proper shape before calling this accessor. - const std::shared_ptr& data_type() const { return data_type_; } + const std::shared_ptr& type() const { return data_type_; } /// \brief Schema of the expression if of table shape. /// WARNING: You must ensure the proper shape before calling this accessor. const std::shared_ptr& schema() const { return schema_; } @@ -80,19 +81,27 @@ class ARROW_EXPORT ExprType : public util::EqualityComparable { /// \brief Indicate if the type is a Table. bool IsTable() const { return shape_ == TABLE; } + bool HasType() const { return IsScalar() || IsArray(); } + bool HasSchema() const { return IsTable(); } + + bool IsTypedLike(Type::type type_id) const { + return HasType() && data_type_->id() == type_id; + } + + /// \brief Static version of IsTypedLike template bool IsTypedLike() const { - return (IsScalar() || IsArray()) && data_type_->id() == TYPE_ID; + return HasType() && data_type_->id() == TYPE_ID; } /// \brief Indicate if the type is a predicate, i.e. a boolean scalar. bool IsPredicate() const { return IsTypedLike(); } - /// \brief Cast to DataType/Schema if the shape allows it. - Result CastTo(const std::shared_ptr& data_type) const; - Result CastTo(const std::shared_ptr& schema) const; + /// \brief Cast the inner DataType/Schema while preserving the shape. + Result WithType(const std::shared_ptr& data_type) const; + Result WithSchema(const std::shared_ptr& schema) const; - /// \brief Broadcasting align two types to the largest shape. + /// \brief Expand the smallest shape to the bigger one if possible. /// /// \param[in] lhs, first type to broadcast /// \param[in] rhs, second type to broadcast @@ -128,7 +137,7 @@ class ARROW_EXPORT ExprType : public util::EqualityComparable { ExprType(std::shared_ptr type, Shape shape); union { - // Zero initialize the pointer or Copy/Assign constructors will fail. + /// Zero initialize the pointer or Copy/Assign constructors will fail. std::shared_ptr data_type_{}; std::shared_ptr schema_; }; @@ -138,33 +147,33 @@ class ARROW_EXPORT ExprType : public util::EqualityComparable { /// Represents an expression tree class ARROW_EXPORT Expr : public util::EqualityComparable { public: - // Tag identifier for the expression type. + /// Tag identifier for the expression type. enum Kind : uint8_t { - // A Scalar literal, i.e. a constant. + /// A Scalar literal, i.e. a constant. SCALAR_LITERAL, - // A Field reference in a schema. + /// A Field reference in a schema. FIELD_REFERENCE, - // Equal compare operator + /// Equal compare operator EQ_CMP_OP, - // Not-Equal compare operator + /// Not-Equal compare operator NE_CMP_OP, - // Greater-Than compare operator + /// Greater-Than compare operator GT_CMP_OP, - // Greater-Equal-Than compare operator + /// Greater-Equal-Than compare operator GE_CMP_OP, - // Less-Than compare operator + /// Less-Than compare operator LT_CMP_OP, - // Less-Equal-Than compare operator + /// Less-Equal-Than compare operator LE_CMP_OP, - // Empty relation with a known schema. + /// Empty relation with a known schema. EMPTY_REL, - // Scan relational operator + /// Scan relational operator SCAN_REL, - // Projection relational operator + /// Projection relational operator PROJECTION_REL, - // Filter relational operator + /// Filter relational operator FILTER_REL, }; @@ -176,7 +185,7 @@ class ARROW_EXPORT Expr : public util::EqualityComparable { /// \brief Return the type and shape of the resulting expression. const ExprType& type() const { return type_; } - /// \brief Indicate if the expressions + /// \brief Indicate if the expressions are equal. bool Equals(const Expr& other) const; using util::EqualityComparable::Equals; @@ -192,9 +201,9 @@ class ARROW_EXPORT Expr : public util::EqualityComparable { Kind kind_; }; -// The following traits are used to break cycle between CRTP base classes and -// their derived counterparts to extract the Expr::Kind and other static -// properties from the forward declared class. +/// The following traits are used to break cycle between CRTP base classes and +/// their derived counterparts to extract the Expr::Kind and other static +/// properties from the forward declared class. template struct expr_traits; @@ -258,101 +267,101 @@ struct expr_traits { static constexpr Expr::Kind kind_id = Expr::FILTER_REL; }; -// -// Value Expressions -// - -// An unnamed scalar literal expression. -class ARROW_EXPORT ScalarExpr : public Expr { - public: - static Result> Make(std::shared_ptr scalar); - - const std::shared_ptr& scalar() const { return scalar_; } - - private: - explicit ScalarExpr(std::shared_ptr scalar); - - std::shared_ptr scalar_; -}; - -// References a column in a table/dataset -class ARROW_EXPORT FieldRefExpr : public Expr { - public: - static Result> Make(std::shared_ptr field); - - const std::shared_ptr& field() const { return field_; } - - private: - explicit FieldRefExpr(std::shared_ptr field); - - std::shared_ptr field_; -}; - -// -// Operator expressions -// +/// +/// Operator expressions mixin. +/// -class ARROW_EXPORT UnaryOpExpr { +class ARROW_EXPORT UnaryOpMixin { public: const std::shared_ptr& operand() const { return operand_; } protected: - explicit UnaryOpExpr(std::shared_ptr operand) : operand_(std::move(operand)) {} + explicit UnaryOpMixin(std::shared_ptr operand) : operand_(std::move(operand)) {} std::shared_ptr operand_; }; -class ARROW_EXPORT BinaryOpExpr { +class ARROW_EXPORT BinaryOpMixin { public: const std::shared_ptr& left_operand() const { return left_operand_; } const std::shared_ptr& right_operand() const { return right_operand_; } protected: - BinaryOpExpr(std::shared_ptr left, std::shared_ptr right) + BinaryOpMixin(std::shared_ptr left, std::shared_ptr right) : left_operand_(std::move(left)), right_operand_(std::move(right)) {} std::shared_ptr left_operand_; std::shared_ptr right_operand_; }; -class ARROW_EXPORT MultiAryOpExpr { +class ARROW_EXPORT MultiAryOpMixin { public: const std::vector>& operands() const { return operands_; } protected: - explicit MultiAryOpExpr(std::vector> operands) + explicit MultiAryOpMixin(std::vector> operands) : operands_(std::move(operands)) {} std::vector> operands_; }; -// -// Comparison expressions -// +/// +/// Value Expressions +/// + +/// An unnamed scalar literal expression. +class ARROW_EXPORT ScalarExpr : public Expr { + public: + static Result> Make(std::shared_ptr scalar); + + const std::shared_ptr& scalar() const { return scalar_; } + + private: + explicit ScalarExpr(std::shared_ptr scalar); + + std::shared_ptr scalar_; +}; + +/// References a column in a table/dataset +class ARROW_EXPORT FieldRefExpr : public Expr { + public: + static Result> Make(std::shared_ptr field); + + const std::shared_ptr& field() const { return field_; } + + private: + explicit FieldRefExpr(std::shared_ptr field); + + std::shared_ptr field_; +}; + +/// +/// Comparison expressions +/// -template -class ARROW_EXPORT CmpOpExpr : public BinaryOpExpr, public Expr { +template +class ARROW_EXPORT CmpOpExpr : public BinaryOpMixin, public Expr { public: - static Result> Make(std::shared_ptr left, - std::shared_ptr right) { + static Result> Make(std::shared_ptr left, + std::shared_ptr right) { if (left == NULLPTR || right == NULLPTR) { return Status::Invalid("Compare operands must be non-nulls"); } - // Broadcast ensures that types are compatible in shape and type. - auto broadcast = ExprType::Broadcast(left->type(), right->type()); - // The type of comparison is always a boolean predicate. - auto cast = [](const ExprType& t) { return t.CastTo(boolean()); }; - ARROW_ASSIGN_OR_RAISE(auto type, broadcast.Map(cast)); + // Broadcast the comparison to the biggest shape. + ARROW_ASSIGN_OR_RAISE(auto broadcast, + ExprType::Broadcast(left->type(), right->type())); + // And change this shape's type to boolean. + ARROW_ASSIGN_OR_RAISE(auto type, broadcast.WithType(boolean())); - return std::shared_ptr( - new Self(std::move(type), std::move(left), std::move(right))); + return std::shared_ptr( + new Derived(std::move(type), std::move(left), std::move(right))); } protected: CmpOpExpr(ExprType type, std::shared_ptr left, std::shared_ptr right) - : BinaryOpExpr(std::move(left), std::move(right)), - Expr(expr_traits::kind_id, std::move(type)) {} + : BinaryOpMixin(std::move(left), std::move(right)), + Expr(expr_traits::kind_id, std::move(type)) {} }; class ARROW_EXPORT EqualCmpExpr : public CmpOpExpr { @@ -385,23 +394,34 @@ class ARROW_EXPORT LessEqualThanCmpExpr : public CmpOpExpr using CmpOpExpr::CmpOpExpr; }; -// -// Relational Expressions -// +/// +/// Relational Expressions +/// -template +/// \brief Relational Expressions that acts on relations (arrow::Table). +template class ARROW_EXPORT RelExpr : public Expr { public: const std::shared_ptr& schema() const { return schema_; } protected: explicit RelExpr(std::shared_ptr schema) - : Expr(expr_traits::kind_id, ExprType::Table(schema)), + : Expr(expr_traits::kind_id, ExprType::Table(schema)), schema_(std::move(schema)) {} std::shared_ptr schema_; }; +/// \brief An empty relation that returns/contains no rows. +/// +/// An EmptyRelExpr is usually not found in user constructed logical plan but +/// can appear due to optimization passes, e.g. replacing a FilterRelExpr with +/// an always false predicate. It is also subsequently used in constant +/// propagation-like optimizations, e.g Filter(EmptyRel) => EmptyRel, or +/// InnerJoin(_, EmptyRel) => EmptyRel. +/// +/// \input schema, the schema of the empty relation +/// \ouput relation with no rows of the given input schema class ARROW_EXPORT EmptyRelExpr : public RelExpr { public: static Result> Make(std::shared_ptr schema); @@ -410,6 +430,19 @@ class ARROW_EXPORT EmptyRelExpr : public RelExpr { using RelExpr::RelExpr; }; +/// \brief Materialize a relation from a dataset. +/// +/// The ScanRelExpr are found in the leaves of the Expr tree. A Scan materialize +/// the relation from a datasets. In essence, it is a relational operator that +/// has no relation input (except some auxiliary information like a catalog +/// entry), and output a relation. +/// +/// \input table, a catalog entry pointing to a dataset +/// \ouput relation from the materialized dataset +/// +/// ``` +/// SELECT * FROM table; +/// ``` class ARROW_EXPORT ScanRelExpr : public RelExpr { public: static Result> Make(Catalog::Entry input); @@ -422,7 +455,23 @@ class ARROW_EXPORT ScanRelExpr : public RelExpr { Catalog::Entry input_; }; -class ARROW_EXPORT ProjectionRelExpr : public UnaryOpExpr, +/// \brief Project columns based on expressions. +/// +/// A projection creates a relation with new columns based on expressions of +/// the input's columns. It could be a simple permutation or selection of +/// column via FieldRefExpr or more complex expressions like the sum of two +/// columns. The projection operator will usually change the output schema of +/// the input relation due to the expressions without changing the number of +/// rows. +/// +/// \input relation, the input relation to compute the expressions from +/// \input expressions, the expressions to compute +/// \output relation where the columns are the expressions computed +/// +/// ``` +/// SELECT a, b, a + b, 1, mean(a) > b FROM relation; +/// ``` +class ARROW_EXPORT ProjectionRelExpr : public UnaryOpMixin, public RelExpr { public: static Result> Make( @@ -437,7 +486,18 @@ class ARROW_EXPORT ProjectionRelExpr : public UnaryOpExpr, std::vector> expressions_; }; -class ARROW_EXPORT FilterRelExpr : public UnaryOpExpr, public RelExpr { +/// \brief Filter the rows of a relation according to a predicate. +/// +/// A filter removes rows that don't match a predicate or a mask column. +/// +/// \input relation, the input relation to filter the rows from +/// \input predicate, a predicate to evaluate for each filter +/// \output relation where the rows are filtered according to the predicate +/// +/// ``` +/// SELECT * FROM relation WHERE predicate +/// ``` +class ARROW_EXPORT FilterRelExpr : public UnaryOpMixin, public RelExpr { public: static Result> Make(std::shared_ptr input, std::shared_ptr predicate); diff --git a/cpp/src/arrow/engine/expression_test.cc b/cpp/src/arrow/engine/expression_test.cc index 79b37e51b8c..ec4836f611b 100644 --- a/cpp/src/arrow/engine/expression_test.cc +++ b/cpp/src/arrow/engine/expression_test.cc @@ -36,14 +36,14 @@ TEST_F(ExprTypeTest, Basic) { auto scalar = ExprType::Scalar(i32); EXPECT_EQ(scalar.shape(), ExprType::Shape::SCALAR); - EXPECT_TRUE(scalar.data_type()->Equals(i32)); + EXPECT_TRUE(scalar.type()->Equals(i32)); EXPECT_TRUE(scalar.IsScalar()); EXPECT_FALSE(scalar.IsArray()); EXPECT_FALSE(scalar.IsTable()); auto array = ExprType::Array(i32); EXPECT_EQ(array.shape(), ExprType::Shape::ARRAY); - EXPECT_TRUE(array.data_type()->Equals(i32)); + EXPECT_TRUE(array.type()->Equals(i32)); EXPECT_FALSE(array.IsScalar()); EXPECT_TRUE(array.IsArray()); EXPECT_FALSE(array.IsTable()); @@ -87,7 +87,7 @@ TEST_F(ExprTypeTest, Broadcast) { EXPECT_THAT(ExprType::Broadcast(bool_array, bool_array), OkAndEq(bool_array)); } -TEST_F(ExprTypeTest, CastTo) { +TEST_F(ExprTypeTest, WithTypeOrSchema) { auto bool_scalar = ExprType::Scalar(boolean()); auto bool_array = ExprType::Array(boolean()); auto bool_table = ExprType::Table(schema({field("b", boolean())})); @@ -96,15 +96,15 @@ TEST_F(ExprTypeTest, CastTo) { auto other = schema({field("a", i32)}); EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("Cannot cast a ScalarType with"), - bool_scalar.CastTo(other)); + bool_scalar.WithSchema(other)); EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("Cannot cast an ArrayType with"), - bool_array.CastTo(other)); + bool_array.WithSchema(other)); EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("Cannot cast a TableType with"), - bool_table.CastTo(i32)); + bool_table.WithType(i32)); - EXPECT_EQ(bool_scalar.CastTo(i32), ExprType::Scalar(i32)); - EXPECT_EQ(bool_array.CastTo(i32), ExprType::Array(i32)); - EXPECT_EQ(bool_table.CastTo(other), ExprType::Table(other)); + EXPECT_EQ(bool_scalar.WithType(i32), ExprType::Scalar(i32)); + EXPECT_EQ(bool_array.WithType(i32), ExprType::Array(i32)); + EXPECT_EQ(bool_table.WithSchema(other), ExprType::Table(other)); } class ExprTest : public testing::Test {}; From 236ff3cfa4140390dcd440dac6df8981d267c312 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Saint-Jacques?= Date: Wed, 11 Mar 2020 14:32:56 -0400 Subject: [PATCH 09/21] Move compare enums into a single enum --- cpp/src/arrow/engine/expression.cc | 30 +-- cpp/src/arrow/engine/expression.h | 226 ++++++++-------------- cpp/src/arrow/engine/expression_test.cc | 18 +- cpp/src/arrow/engine/logical_plan_test.cc | 2 +- cpp/src/arrow/engine/type_fwd.h | 50 ++++- cpp/src/arrow/engine/type_traits.h | 79 +++++++- 6 files changed, 219 insertions(+), 186 deletions(-) diff --git a/cpp/src/arrow/engine/expression.cc b/cpp/src/arrow/engine/expression.cc index 13511cc9c6e..e06267ba6ce 100644 --- a/cpp/src/arrow/engine/expression.cc +++ b/cpp/src/arrow/engine/expression.cc @@ -16,7 +16,6 @@ // under the License. #include "arrow/engine/expression.h" -#include "arrow/engine/type_traits.h" #include "arrow/scalar.h" #include "arrow/type.h" #include "arrow/util/checked_cast.h" @@ -122,6 +121,7 @@ Result ExprType::WithType(const std::shared_ptr& data_type) return Status::UnknownError("unreachable"); } + Result ExprType::WithSchema(const std::shared_ptr& schema) const { switch (shape()) { case SCALAR: @@ -168,31 +168,19 @@ Result ExprType::Broadcast(const ExprType& lhs, const ExprType& rhs) { std::string Expr::kind_name() const { switch (kind_) { - case Expr::SCALAR_LITERAL: + case ExprKind::SCALAR_LITERAL: return "scalar"; - case Expr::FIELD_REFERENCE: + case ExprKind::FIELD_REFERENCE: return "field_ref"; - - case Expr::EQ_CMP_OP: - return "eq_cmp"; - case Expr::NE_CMP_OP: - return "ne_cmp"; - case Expr::GT_CMP_OP: - return "gt_cmp"; - case Expr::GE_CMP_OP: - return "ge_cmp"; - case Expr::LT_CMP_OP: - return "lt_cmp"; - case Expr::LE_CMP_OP: - return "le_cmp"; - - case Expr::EMPTY_REL: + case ExprKind::COMPARE_OP: + return "compare_op"; + case ExprKind::EMPTY_REL: return "empty_rel"; - case Expr::SCAN_REL: + case ExprKind::SCAN_REL: return "scan_rel"; - case Expr::PROJECTION_REL: + case ExprKind::PROJECTION_REL: return "projection_rel"; - case Expr::FILTER_REL: + case ExprKind::FILTER_REL: return "filter_rel"; } diff --git a/cpp/src/arrow/engine/expression.h b/cpp/src/arrow/engine/expression.h index 874cc643e0e..a6bab480457 100644 --- a/cpp/src/arrow/engine/expression.h +++ b/cpp/src/arrow/engine/expression.h @@ -24,6 +24,7 @@ #include "arrow/engine/catalog.h" #include "arrow/engine/type_fwd.h" +#include "arrow/engine/type_traits.h" #include "arrow/result.h" #include "arrow/type.h" #include "arrow/type_fwd.h" @@ -147,38 +148,8 @@ class ARROW_EXPORT ExprType : public util::EqualityComparable { /// Represents an expression tree class ARROW_EXPORT Expr : public util::EqualityComparable { public: - /// Tag identifier for the expression type. - enum Kind : uint8_t { - /// A Scalar literal, i.e. a constant. - SCALAR_LITERAL, - /// A Field reference in a schema. - FIELD_REFERENCE, - - /// Equal compare operator - EQ_CMP_OP, - /// Not-Equal compare operator - NE_CMP_OP, - /// Greater-Than compare operator - GT_CMP_OP, - /// Greater-Equal-Than compare operator - GE_CMP_OP, - /// Less-Than compare operator - LT_CMP_OP, - /// Less-Equal-Than compare operator - LE_CMP_OP, - - /// Empty relation with a known schema. - EMPTY_REL, - /// Scan relational operator - SCAN_REL, - /// Projection relational operator - PROJECTION_REL, - /// Filter relational operator - FILTER_REL, - }; - /// \brief Return the kind of the expression. - Kind kind() const { return kind_; } + ExprKind kind() const { return kind_; } /// \brief Return a string representation of the kind. std::string kind_name() const; @@ -195,76 +166,10 @@ class ARROW_EXPORT Expr : public util::EqualityComparable { virtual ~Expr() = default; protected: - explicit Expr(Kind kind, ExprType type) : type_(std::move(type)), kind_(kind) {} + explicit Expr(ExprKind kind, ExprType type) : type_(std::move(type)), kind_(kind) {} ExprType type_; - Kind kind_; -}; - -/// The following traits are used to break cycle between CRTP base classes and -/// their derived counterparts to extract the Expr::Kind and other static -/// properties from the forward declared class. -template -struct expr_traits; - -template <> -struct expr_traits { - static constexpr Expr::Kind kind_id = Expr::SCALAR_LITERAL; -}; - -template <> -struct expr_traits { - static constexpr Expr::Kind kind_id = Expr::FIELD_REFERENCE; -}; - -template <> -struct expr_traits { - static constexpr Expr::Kind kind_id = Expr::EQ_CMP_OP; -}; - -template <> -struct expr_traits { - static constexpr Expr::Kind kind_id = Expr::NE_CMP_OP; -}; - -template <> -struct expr_traits { - static constexpr Expr::Kind kind_id = Expr::GT_CMP_OP; -}; - -template <> -struct expr_traits { - static constexpr Expr::Kind kind_id = Expr::GE_CMP_OP; -}; - -template <> -struct expr_traits { - static constexpr Expr::Kind kind_id = Expr::LT_CMP_OP; -}; - -template <> -struct expr_traits { - static constexpr Expr::Kind kind_id = Expr::LE_CMP_OP; -}; - -template <> -struct expr_traits { - static constexpr Expr::Kind kind_id = Expr::EMPTY_REL; -}; - -template <> -struct expr_traits { - static constexpr Expr::Kind kind_id = Expr::SCAN_REL; -}; - -template <> -struct expr_traits { - static constexpr Expr::Kind kind_id = Expr::PROJECTION_REL; -}; - -template <> -struct expr_traits { - static constexpr Expr::Kind kind_id = Expr::FILTER_REL; + ExprKind kind_; }; /// @@ -339,66 +244,86 @@ class ARROW_EXPORT FieldRefExpr : public Expr { /// Comparison expressions /// -template -class ARROW_EXPORT CmpOpExpr : public BinaryOpMixin, public Expr { +class ARROW_EXPORT CompareOpExpr : public BinaryOpMixin, public Expr { public: - static Result> Make(std::shared_ptr left, - std::shared_ptr right) { - if (left == NULLPTR || right == NULLPTR) { - return Status::Invalid("Compare operands must be non-nulls"); + CompareKind compare_kind() const { return compare_kind_; } + + /// This inner-class is required because `using` statements can't use derived + /// methods. + template + struct MakeMixin { + static Result> Make(std::shared_ptr left, + std::shared_ptr right) { + if (left == NULLPTR || right == NULLPTR) { + return Status::Invalid("Compare operands must be non-nulls"); + } + + // Broadcast the comparison to the biggest shape. + ARROW_ASSIGN_OR_RAISE(auto broadcast, + ExprType::Broadcast(left->type(), right->type())); + // And change this shape's type to boolean. + ARROW_ASSIGN_OR_RAISE(auto type, broadcast.WithType(boolean())); + + return std::shared_ptr(new Derived(std::move(type), + expr_traits::compare_kind_id, + std::move(left), std::move(right))); } + }; - // Broadcast the comparison to the biggest shape. - ARROW_ASSIGN_OR_RAISE(auto broadcast, - ExprType::Broadcast(left->type(), right->type())); - // And change this shape's type to boolean. - ARROW_ASSIGN_OR_RAISE(auto type, broadcast.WithType(boolean())); + protected: + CompareOpExpr(ExprType type, CompareKind op, std::shared_ptr left, + std::shared_ptr right) + : BinaryOpMixin(std::move(left), std::move(right)), + Expr(COMPARE_OP, std::move(type)), + compare_kind_(op) {} - return std::shared_ptr( - new Derived(std::move(type), std::move(left), std::move(right))); - } + CompareKind compare_kind_; +}; + +template +class BaseCompareExpr : public CompareOpExpr, private CompareOpExpr::MakeMixin { + public: + using CompareOpExpr::MakeMixin::Make; protected: - CmpOpExpr(ExprType type, std::shared_ptr left, std::shared_ptr right) - : BinaryOpMixin(std::move(left), std::move(right)), - Expr(expr_traits::kind_id, std::move(type)) {} + using CompareOpExpr::CompareOpExpr; }; -class ARROW_EXPORT EqualCmpExpr : public CmpOpExpr { +class ARROW_EXPORT EqualExpr : public BaseCompareExpr { protected: - using CmpOpExpr::CmpOpExpr; + using BaseCompareExpr::BaseCompareExpr; }; -class ARROW_EXPORT NotEqualCmpExpr : public CmpOpExpr { +class ARROW_EXPORT NotEqualExpr : public BaseCompareExpr { protected: - using CmpOpExpr::CmpOpExpr; + using BaseCompareExpr::BaseCompareExpr; }; -class ARROW_EXPORT GreaterThanCmpExpr : public CmpOpExpr { +class ARROW_EXPORT GreaterThanExpr : public BaseCompareExpr { protected: - using CmpOpExpr::CmpOpExpr; + using BaseCompareExpr::BaseCompareExpr; }; -class ARROW_EXPORT GreaterEqualThanCmpExpr : public CmpOpExpr { +class ARROW_EXPORT GreaterThanEqualExpr : public BaseCompareExpr { protected: - using CmpOpExpr::CmpOpExpr; + using BaseCompareExpr::BaseCompareExpr; }; -class ARROW_EXPORT LessThanCmpExpr : public CmpOpExpr { +class ARROW_EXPORT LessThanExpr : public BaseCompareExpr { protected: - using CmpOpExpr::CmpOpExpr; + using BaseCompareExpr::BaseCompareExpr; }; -class ARROW_EXPORT LessEqualThanCmpExpr : public CmpOpExpr { +class ARROW_EXPORT LessThanEqualExpr : public BaseCompareExpr { protected: - using CmpOpExpr::CmpOpExpr; + using BaseCompareExpr::BaseCompareExpr; }; /// /// Relational Expressions /// -/// \brief Relational Expressions that acts on relations (arrow::Table). +/// \brief Relational Expressions that acts on tables. template class ARROW_EXPORT RelExpr : public Expr { public: @@ -513,31 +438,36 @@ class ARROW_EXPORT FilterRelExpr : public UnaryOpMixin, public RelExpr auto VisitExpr(const Expr& expr, Visitor&& visitor) -> decltype(visitor(expr)) { switch (expr.kind()) { - case Expr::SCALAR_LITERAL: + case ExprKind::SCALAR_LITERAL: return visitor(internal::checked_cast(expr)); - case Expr::FIELD_REFERENCE: + case ExprKind::FIELD_REFERENCE: return visitor(internal::checked_cast(expr)); - case Expr::EQ_CMP_OP: - return visitor(internal::checked_cast(expr)); - case Expr::NE_CMP_OP: - return visitor(internal::checked_cast(expr)); - case Expr::GT_CMP_OP: - return visitor(internal::checked_cast(expr)); - case Expr::GE_CMP_OP: - return visitor(internal::checked_cast(expr)); - case Expr::LT_CMP_OP: - return visitor(internal::checked_cast(expr)); - case Expr::LE_CMP_OP: - return visitor(internal::checked_cast(expr)); - - case Expr::EMPTY_REL: + case ExprKind::COMPARE_OP: { + const auto& cmp_expr = static_cast(expr); + switch (cmp_expr.compare_kind()) { + case CompareKind::EQUAL: + return visitor(internal::checked_cast(expr)); + case CompareKind::NOT_EQUAL: + return visitor(internal::checked_cast(expr)); + case CompareKind::GREATER_THAN: + return visitor(internal::checked_cast(expr)); + case CompareKind::GREATER_THAN_EQUAL: + return visitor(internal::checked_cast(expr)); + case CompareKind::LESS_THAN: + return visitor(internal::checked_cast(expr)); + case CompareKind::LESS_THAN_EQUAL: + return visitor(internal::checked_cast(expr)); + } + } + + case ExprKind::EMPTY_REL: return visitor(internal::checked_cast(expr)); - case Expr::SCAN_REL: + case ExprKind::SCAN_REL: return visitor(internal::checked_cast(expr)); - case Expr::PROJECTION_REL: + case ExprKind::PROJECTION_REL: return visitor(internal::checked_cast(expr)); - case Expr::FILTER_REL: + case ExprKind::FILTER_REL: // LEAVE LAST or update the outer return cast by moving it here. This is // required for older compiler support. break; diff --git a/cpp/src/arrow/engine/expression_test.cc b/cpp/src/arrow/engine/expression_test.cc index ec4836f611b..efaa2a7e74e 100644 --- a/cpp/src/arrow/engine/expression_test.cc +++ b/cpp/src/arrow/engine/expression_test.cc @@ -115,7 +115,7 @@ TEST_F(ExprTest, ScalarExpr) { auto i32 = int32(); ASSERT_OK_AND_ASSIGN(auto value, MakeScalar(i32, 10)); ASSERT_OK_AND_ASSIGN(auto expr, ScalarExpr::Make(value)); - EXPECT_EQ(expr->kind(), Expr::SCALAR_LITERAL); + EXPECT_EQ(expr->kind(), ExprKind::SCALAR_LITERAL); EXPECT_EQ(expr->type(), ExprType::Scalar(i32)); EXPECT_EQ(*expr->scalar(), *value); } @@ -127,15 +127,16 @@ TEST_F(ExprTest, FieldRefExpr) { auto f_i32 = field("i32", i32); ASSERT_OK_AND_ASSIGN(auto expr, FieldRefExpr::Make(f_i32)); - EXPECT_EQ(expr->kind(), Expr::FIELD_REFERENCE); + EXPECT_EQ(expr->kind(), ExprKind::FIELD_REFERENCE); EXPECT_EQ(expr->type(), ExprType::Array(i32)); EXPECT_THAT(expr->field(), PtrEquals(f_i32)); } template -class CmpExprTest : public ExprTest { +class CompareExprTest : public ExprTest { public: - Expr::Kind kind() { return expr_traits::kind_id; } + ExprKind kind() { return expr_traits::kind_id; } + CompareKind compare_kind() { return expr_traits::compare_kind_id; } Result> Make(std::shared_ptr left, std::shared_ptr right) { @@ -143,10 +144,12 @@ class CmpExprTest : public ExprTest { } }; -using CompareExprs = ::testing::Types; +using CompareExprs = + ::testing::Types; -TYPED_TEST_CASE(CmpExprTest, CompareExprs); -TYPED_TEST(CmpExprTest, BasicCompareExpr) { +TYPED_TEST_CASE(CompareExprTest, CompareExprs); +TYPED_TEST(CompareExprTest, BasicCompareExpr) { auto i32 = int32(); auto f_i32 = field("i32", i32); @@ -168,6 +171,7 @@ TYPED_TEST(CmpExprTest, BasicCompareExpr) { ASSERT_OK_AND_ASSIGN(auto expr, this->Make(f_expr, s_expr)); EXPECT_EQ(expr->kind(), this->kind()); + EXPECT_EQ(expr->compare_kind(), this->compare_kind()); // Ensure type is broadcasted EXPECT_EQ(expr->type(), ExprType::Array(boolean())); EXPECT_TRUE(expr->type().IsPredicate()); diff --git a/cpp/src/arrow/engine/logical_plan_test.cc b/cpp/src/arrow/engine/logical_plan_test.cc index 609e2f74d63..10ec42b8a6f 100644 --- a/cpp/src/arrow/engine/logical_plan_test.cc +++ b/cpp/src/arrow/engine/logical_plan_test.cc @@ -100,7 +100,7 @@ TEST_F(LogicalPlanBuilderTest, Filter) { EXPECT_OK_AND_ASSIGN(auto field, field_expr("i32", table)); EXPECT_OK_AND_ASSIGN(auto scalar, scalar_expr()); - EXPECT_OK_AND_ASSIGN(auto predicate, EqualCmpExpr::Make(field, scalar)); + EXPECT_OK_AND_ASSIGN(auto predicate, EqualExpr::Make(field, scalar)); EXPECT_OK_AND_ASSIGN(auto filter, builder.Filter(table, predicate)); } diff --git a/cpp/src/arrow/engine/type_fwd.h b/cpp/src/arrow/engine/type_fwd.h index a3292704c72..682f6d1ae8d 100644 --- a/cpp/src/arrow/engine/type_fwd.h +++ b/cpp/src/arrow/engine/type_fwd.h @@ -17,20 +17,58 @@ #pragma once +#include + namespace arrow { namespace engine { class ExprType; +/// Tag identifier for the expression type. +enum ExprKind : uint8_t { + /// A Scalar literal, i.e. a constant. + SCALAR_LITERAL, + /// A Field reference in a schema. + FIELD_REFERENCE, + + // Comparison operators, + COMPARE_OP, + + /// Empty relation with a known schema. + EMPTY_REL, + /// Scan relational operator + SCAN_REL, + /// Projection relational operator + PROJECTION_REL, + /// Filter relational operator + FILTER_REL, +}; + class Expr; class ScalarExpr; class FieldRefExpr; -class EqualCmpExpr; -class NotEqualCmpExpr; -class GreaterThanCmpExpr; -class GreaterEqualThanCmpExpr; -class LessThanCmpExpr; -class LessEqualThanCmpExpr; + +/// Tag identifier for comparison operators +enum CompareKind : uint8_t { + EQUAL, + NOT_EQUAL, + GREATER_THAN, + GREATER_THAN_EQUAL, + LESS_THAN, + LESS_THAN_EQUAL, +}; + +class CompareOpExpr; +class EqualExpr; +class NotEqualExpr; +class GreaterThanExpr; +class GreaterThanEqualExpr; +class LessThanExpr; +class LessThanEqualExpr; + +template +class RelExpr; + class EmptyRelExpr; class ScanRelExpr; class ProjectionRelExpr; diff --git a/cpp/src/arrow/engine/type_traits.h b/cpp/src/arrow/engine/type_traits.h index 52abeb6a99d..e8f3cc00f29 100644 --- a/cpp/src/arrow/engine/type_traits.h +++ b/cpp/src/arrow/engine/type_traits.h @@ -17,14 +17,87 @@ #pragma once -#include "arrow/engine/expression.h" -#include "arrow/type_traits.h" +#include + +#include "arrow/engine/type_fwd.h" namespace arrow { namespace engine { +template +using enable_if_t = typename std::enable_if::type; + +template +struct expr_traits; + +template <> +struct expr_traits { + static constexpr auto kind_id = ExprKind::SCALAR_LITERAL; +}; + +template <> +struct expr_traits { + static constexpr auto kind_id = ExprKind::FIELD_REFERENCE; +}; + +template <> +struct expr_traits { + static constexpr auto kind_id = ExprKind::COMPARE_OP; + static constexpr auto compare_kind_id = CompareKind::EQUAL; +}; + +template <> +struct expr_traits { + static constexpr auto kind_id = ExprKind::COMPARE_OP; + static constexpr auto compare_kind_id = CompareKind::NOT_EQUAL; +}; + +template <> +struct expr_traits { + static constexpr auto kind_id = ExprKind::COMPARE_OP; + static constexpr auto compare_kind_id = CompareKind::GREATER_THAN; +}; + +template <> +struct expr_traits { + static constexpr auto kind_id = ExprKind::COMPARE_OP; + static constexpr auto compare_kind_id = CompareKind::GREATER_THAN_EQUAL; +}; + +template <> +struct expr_traits { + static constexpr auto kind_id = ExprKind::COMPARE_OP; + static constexpr auto compare_kind_id = CompareKind::LESS_THAN; +}; + +template <> +struct expr_traits { + static constexpr auto kind_id = ExprKind::COMPARE_OP; + static constexpr auto compare_kind_id = CompareKind::LESS_THAN_EQUAL; +}; + +template <> +struct expr_traits { + static constexpr auto kind_id = ExprKind::EMPTY_REL; +}; + +template <> +struct expr_traits { + static constexpr auto kind_id = ExprKind::SCAN_REL; +}; + +template <> +struct expr_traits { + static constexpr auto kind_id = ExprKind::PROJECTION_REL; +}; + +template <> +struct expr_traits { + static constexpr auto kind_id = ExprKind::FILTER_REL; +}; + template -using is_compare_expr = std::is_base_of, E>; +using is_compare_expr = std::is_base_of; template using enable_if_compare_expr = enable_if_t::value, Ret>; From 52d5a9b05d83fc7472b3d163ffd641806ecf4458 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Saint-Jacques?= Date: Wed, 11 Mar 2020 22:09:26 -0400 Subject: [PATCH 10/21] FieldRef should preserve the input --- cpp/src/arrow/engine/expression.cc | 74 +++++++++++++++++++++---- cpp/src/arrow/engine/expression.h | 22 +++++--- cpp/src/arrow/engine/expression_test.cc | 23 ++++++-- cpp/src/arrow/engine/logical_plan.cc | 49 ++-------------- cpp/src/arrow/engine/logical_plan.h | 10 +--- cpp/src/arrow/util/macros.h | 6 ++ 6 files changed, 104 insertions(+), 80 deletions(-) diff --git a/cpp/src/arrow/engine/expression.cc b/cpp/src/arrow/engine/expression.cc index e06267ba6ce..76d8ddcab72 100644 --- a/cpp/src/arrow/engine/expression.cc +++ b/cpp/src/arrow/engine/expression.cc @@ -27,6 +27,19 @@ namespace engine { // ExprType // +std::string ShapeToString(ExprType::Shape shape) { + switch (shape) { + case ExprType::SCALAR: + return "scalar"; + case ExprType::ARRAY: + return "array"; + case ExprType::TABLE: + return "table"; + } + + return ""; +} + ExprType ExprType::Scalar(std::shared_ptr type) { return ExprType(std::move(type), Shape::SCALAR); } @@ -155,13 +168,15 @@ Result ExprType::Broadcast(const ExprType& lhs, const ExprType& rhs) { return lhs; } -#define ERROR_IF(cond, ...) \ - do { \ - if (ARROW_PREDICT_FALSE(cond)) { \ - return Status::Invalid(__VA_ARGS__); \ - } \ +#define ERROR_IF_TYPE(cond, ErrorType, ...) \ + do { \ + if (ARROW_PREDICT_FALSE(cond)) { \ + return Status::ErrorType(__VA_ARGS__); \ + } \ } while (false) +#define ERROR_IF(cond, ...) ERROR_IF_TYPE(cond, Invalid, __VA_ARGS__) + // // Expr // @@ -195,7 +210,8 @@ struct ExprEqualityVisitor { bool operator()(const FieldRefExpr& rhs) const { auto lhs_field = internal::checked_cast(lhs); - return lhs_field.field()->Equals(*rhs.field()); + return lhs_field.index() == rhs.index() && + lhs_field.operand()->Equals(*rhs.operand()); } template @@ -275,12 +291,45 @@ Result> ScalarExpr::Make(std::shared_ptr sca // FieldRefExpr // -FieldRefExpr::FieldRefExpr(std::shared_ptr f) - : Expr(FIELD_REFERENCE, ExprType::Array(f->type())), field_(std::move(f)) {} +FieldRefExpr::FieldRefExpr(std::shared_ptr input, int index) + : UnaryOpMixin(std::move(input)), + Expr(FIELD_REFERENCE, + ExprType::Array(operand()->type().schema()->field(index)->type())), + index_(index) {} + +Result> FieldRefExpr::Make(std::shared_ptr input, + int index) { + ERROR_IF(input == nullptr, "FieldRefExpr's input must be non-null"); + + auto expr_type = input->type(); + ERROR_IF(!expr_type.IsTable(), "FieldRefExpr's input must have a table shape, got '", + ShapeToString(expr_type.shape()), "'"); + + auto schema = expr_type.schema(); + ERROR_IF_TYPE(index < 0 || index >= schema->num_fields(), KeyError, + "FieldRefExpr's index is out of bound, '", index, "' not in range [0, ", + schema->num_fields(), ")"); + + return std::shared_ptr(new FieldRefExpr(std::move(input), index)); +} + +Result> FieldRefExpr::Make(std::shared_ptr input, + std::string field_name) { + ERROR_IF(input == nullptr, "FieldRefExpr's input must be non-null"); + + auto expr_type = input->type(); + ERROR_IF(!expr_type.IsTable(), "FieldRefExpr's input must have a table shape, got '", + ShapeToString(expr_type.shape()), "'"); + + auto schema = expr_type.schema(); + auto field = schema->GetFieldByName(field_name); + ERROR_IF_TYPE(field == nullptr, KeyError, + "FieldRefExpr's can't reference with field name '", field_name, "'"); + + auto index = schema->GetFieldIndex(field_name); + ERROR_IF(index == -1, "FieldRefExpr's index by name is invalid."); -Result> FieldRefExpr::Make(std::shared_ptr field) { - ERROR_IF(field == nullptr, "FieldRefExpr's field must be non-null"); - return std::shared_ptr(new FieldRefExpr(std::move(field))); + return std::shared_ptr(new FieldRefExpr(std::move(input), index)); } // @@ -317,7 +366,7 @@ ProjectionRelExpr::ProjectionRelExpr(std::shared_ptr input, Result> ProjectionRelExpr::Make( std::shared_ptr input, std::vector> expressions) { ERROR_IF(input == nullptr, "ProjectionRelExpr's input must be non-null."); - ERROR_IF(expressions.empty(), "Must project at least one column."); + ERROR_IF(expressions.empty(), "Must project at least one expression."); auto n_fields = expressions.size(); std::vector> fields; @@ -358,6 +407,7 @@ FilterRelExpr::FilterRelExpr(std::shared_ptr input, std::shared_ptr predicate_(std::move(predicate)) {} #undef ERROR_IF +#undef ERROR_IF_TYPE } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/expression.h b/cpp/src/arrow/engine/expression.h index a6bab480457..c7281761440 100644 --- a/cpp/src/arrow/engine/expression.h +++ b/cpp/src/arrow/engine/expression.h @@ -29,6 +29,7 @@ #include "arrow/type.h" #include "arrow/type_fwd.h" #include "arrow/util/compare.h" +#include "arrow/util/macros.h" namespace arrow { namespace engine { @@ -228,16 +229,19 @@ class ARROW_EXPORT ScalarExpr : public Expr { }; /// References a column in a table/dataset -class ARROW_EXPORT FieldRefExpr : public Expr { +class ARROW_EXPORT FieldRefExpr : public UnaryOpMixin, public Expr { public: - static Result> Make(std::shared_ptr field); + static Result> Make(std::shared_ptr input, + int index); + static Result> Make(std::shared_ptr input, + std::string field_name); - const std::shared_ptr& field() const { return field_; } + int index() const { return index_; } private: - explicit FieldRefExpr(std::shared_ptr field); + FieldRefExpr(std::shared_ptr input, int index); - std::shared_ptr field_; + int index_; }; /// @@ -459,6 +463,8 @@ auto VisitExpr(const Expr& expr, Visitor&& visitor) -> decltype(visitor(expr)) { case CompareKind::LESS_THAN_EQUAL: return visitor(internal::checked_cast(expr)); } + + ARROW_UNREACHABLE; } case ExprKind::EMPTY_REL: @@ -468,12 +474,10 @@ auto VisitExpr(const Expr& expr, Visitor&& visitor) -> decltype(visitor(expr)) { case ExprKind::PROJECTION_REL: return visitor(internal::checked_cast(expr)); case ExprKind::FILTER_REL: - // LEAVE LAST or update the outer return cast by moving it here. This is - // required for older compiler support. - break; + return visitor(internal::checked_cast(expr)); } - return visitor(internal::checked_cast(expr)); + ARROW_UNREACHABLE; } } // namespace engine diff --git a/cpp/src/arrow/engine/expression_test.cc b/cpp/src/arrow/engine/expression_test.cc index efaa2a7e74e..17e5c754e35 100644 --- a/cpp/src/arrow/engine/expression_test.cc +++ b/cpp/src/arrow/engine/expression_test.cc @@ -121,15 +121,25 @@ TEST_F(ExprTest, ScalarExpr) { } TEST_F(ExprTest, FieldRefExpr) { - ASSERT_RAISES(Invalid, FieldRefExpr::Make(nullptr)); - auto i32 = int32(); auto f_i32 = field("i32", i32); + auto schema = arrow::schema({f_i32}); + ASSERT_OK_AND_ASSIGN(auto input, EmptyRelExpr::Make(schema)); + + ASSERT_RAISES(Invalid, FieldRefExpr::Make(nullptr, 0)); + ASSERT_RAISES(Invalid, FieldRefExpr::Make(input, -1)); + ASSERT_RAISES(Invalid, FieldRefExpr::Make(input, 1)); + ASSERT_RAISES(Invalid, FieldRefExpr::Make(input, "not_present")); + + ASSERT_OK_AND_ASSIGN(auto expr, FieldRefExpr::Make(input, 0)); + EXPECT_EQ(expr->kind(), ExprKind::FIELD_REFERENCE); + EXPECT_EQ(expr->type(), ExprType::Array(i32)); + EXPECT_THAT(expr->index(), 0); - ASSERT_OK_AND_ASSIGN(auto expr, FieldRefExpr::Make(f_i32)); + ASSERT_OK_AND_ASSIGN(expr, FieldRefExpr::Make(input, "i32")); EXPECT_EQ(expr->kind(), ExprKind::FIELD_REFERENCE); EXPECT_EQ(expr->type(), ExprType::Array(i32)); - EXPECT_THAT(expr->field(), PtrEquals(f_i32)); + EXPECT_THAT(expr->index(), 0); } template @@ -151,10 +161,11 @@ using CompareExprs = TYPED_TEST_CASE(CompareExprTest, CompareExprs); TYPED_TEST(CompareExprTest, BasicCompareExpr) { auto i32 = int32(); - auto f_i32 = field("i32", i32); - ASSERT_OK_AND_ASSIGN(auto f_expr, FieldRefExpr::Make(f_i32)); + auto schema = arrow::schema({f_i32}); + ASSERT_OK_AND_ASSIGN(auto input, EmptyRelExpr::Make(schema)); + ASSERT_OK_AND_ASSIGN(auto f_expr, FieldRefExpr::Make(input, "i32")); ASSERT_OK_AND_ASSIGN(auto s_i32, MakeScalar(i32, 42)); ASSERT_OK_AND_ASSIGN(auto s_expr, ScalarExpr::Make(s_i32)); diff --git a/cpp/src/arrow/engine/logical_plan.cc b/cpp/src/arrow/engine/logical_plan.cc index b9fc01946c3..aecad0dcfbb 100644 --- a/cpp/src/arrow/engine/logical_plan.cc +++ b/cpp/src/arrow/engine/logical_plan.cc @@ -64,6 +64,7 @@ using ResultExpr = LogicalPlanBuilder::ResultExpr; } while (false) #define ERROR_IF(cond, ...) ERROR_IF_TYPE(cond, Invalid, __VA_ARGS__) + // // Leaf builder. // @@ -74,35 +75,12 @@ ResultExpr LogicalPlanBuilder::Scalar(const std::shared_ptr& scal ResultExpr LogicalPlanBuilder::Field(const std::shared_ptr& input, const std::string& field_name) { - ERROR_IF(input == nullptr, "Input expression must be non-null"); - - auto expr_type = input->type(); - ERROR_IF(!expr_type.IsTable(), "Input expression does not have a Table shape."); - - auto field = expr_type.schema()->GetFieldByName(field_name); - ERROR_IF_TYPE(field == nullptr, KeyError, "Cannot reference field '", field_name, - "' in schema."); - - return FieldRefExpr::Make(std::move(field)); + return FieldRefExpr::Make(input, field_name); } ResultExpr LogicalPlanBuilder::Field(const std::shared_ptr& input, int field_index) { - ERROR_IF(input == nullptr, "Input expression must be non-null"); - ERROR_IF_TYPE(field_index < 0, KeyError, "Field index must be positive"); - - auto expr_type = input->type(); - ERROR_IF(!expr_type.IsTable(), "Input expression does not have a Table shape."); - - auto schema = expr_type.schema(); - auto num_fields = schema->num_fields(); - ERROR_IF_TYPE(field_index >= num_fields, KeyError, "Field index ", field_index, - " out of bounds."); - - auto field = expr_type.schema()->field(field_index); - ERROR_IF_TYPE(field == nullptr, KeyError, "Field at index ", field_index, " is null."); - - return FieldRefExpr::Make(std::move(field)); + return FieldRefExpr::Make(input, field_index); } // @@ -117,25 +95,15 @@ ResultExpr LogicalPlanBuilder::Scan(const std::string& table_name) { ResultExpr LogicalPlanBuilder::Filter(const std::shared_ptr& input, const std::shared_ptr& predicate) { - ERROR_IF(input == nullptr, "Input expression can't be null."); - ERROR_IF(predicate == nullptr, "Predicate expression can't be null."); return FilterRelExpr::Make(std::move(input), std::move(predicate)); } ResultExpr LogicalPlanBuilder::Project( const std::shared_ptr& input, const std::vector>& expressions) { - ERROR_IF(input == nullptr, "Input expression can't be null."); - ERROR_IF(expressions.empty(), "Must have at least one expression."); return ProjectionRelExpr::Make(input, expressions); } -ResultExpr LogicalPlanBuilder::Mutate( - const std::shared_ptr& input, - const std::vector>& expressions) { - return Project(input, expressions); -} - ResultExpr LogicalPlanBuilder::Project(const std::shared_ptr& input, const std::vector& column_names) { ERROR_IF(input == nullptr, "Input expression can't be null."); @@ -150,11 +118,6 @@ ResultExpr LogicalPlanBuilder::Project(const std::shared_ptr& input, return Project(input, expressions); } -ResultExpr LogicalPlanBuilder::Select(const std::shared_ptr& input, - const std::vector& column_names) { - return Project(input, column_names); -} - ResultExpr LogicalPlanBuilder::Project(const std::shared_ptr& input, const std::vector& column_indices) { ERROR_IF(input == nullptr, "Input expression can't be null."); @@ -169,12 +132,8 @@ ResultExpr LogicalPlanBuilder::Project(const std::shared_ptr& input, return Project(input, expressions); } -ResultExpr LogicalPlanBuilder::Select(const std::shared_ptr& input, - const std::vector& column_indices) { - return Project(input, column_indices); -} - #undef ERROR_IF +#undef ERROR_IF_TYPE } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/logical_plan.h b/cpp/src/arrow/engine/logical_plan.h index 5296947e1c0..9b460dca5ae 100644 --- a/cpp/src/arrow/engine/logical_plan.h +++ b/cpp/src/arrow/engine/logical_plan.h @@ -68,10 +68,10 @@ class LogicalPlanBuilder { /// \brief Construct a Scalar literal. ResultExpr Scalar(const std::shared_ptr& scalar); - /// \brief References a field by name. - ResultExpr Field(const std::shared_ptr& input, const std::string& field_name); /// \brief References a field by index. ResultExpr Field(const std::shared_ptr& input, int field_index); + /// \brief References a field by name. + ResultExpr Field(const std::shared_ptr& input, const std::string& field_name); /// \brief Scan a Table/Dataset from the Catalog. ResultExpr Scan(const std::string& table_name); @@ -105,8 +105,6 @@ class LogicalPlanBuilder { /// \brief Project (mutate) columns with given expressions. ResultExpr Project(const std::shared_ptr& input, const std::vector>& expressions); - ResultExpr Mutate(const std::shared_ptr& input, - const std::vector>& expressions); /// \brief Project (select) columns by names. /// @@ -114,8 +112,6 @@ class LogicalPlanBuilder { /// names. Duplicate and ordering are preserved. ResultExpr Project(const std::shared_ptr& input, const std::vector& column_names); - ResultExpr Select(const std::shared_ptr& input, - const std::vector& column_names); /// \brief Project (select) columns by indices. /// @@ -123,8 +119,6 @@ class LogicalPlanBuilder { /// indices. Duplicate and ordering are preserved. ResultExpr Project(const std::shared_ptr& input, const std::vector& column_indices); - ResultExpr Select(const std::shared_ptr& input, - const std::vector& column_indices); /// @} diff --git a/cpp/src/arrow/util/macros.h b/cpp/src/arrow/util/macros.h index 7d04a80e802..90760a4c23d 100644 --- a/cpp/src/arrow/util/macros.h +++ b/cpp/src/arrow/util/macros.h @@ -61,6 +61,12 @@ #define ARROW_PREFETCH(addr) #endif +#if defined(__GNUC__) +#define ARROW_UNREACHABLE __builtin_unreachable() +#elif defined(_MSC_VER) +#define ARROW_UNREACHABLE __assume(0) +#endif + #if (defined(__GNUC__) || defined(__APPLE__)) #define ARROW_MUST_USE_RESULT __attribute__((warn_unused_result)) #elif defined(_MSC_VER) From 7d78abe8123f2207aa63df82d6a5dcbf6171c9f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Saint-Jacques?= Date: Thu, 12 Mar 2020 09:39:21 -0400 Subject: [PATCH 11/21] Add explicit libarrow_engine.so --- ci/cpp-msvc-build-main.bat | 1 + ci/docker/conda-cpp.dockerfile | 1 + ci/docker/conda-integration.dockerfile | 1 + ci/docker/cuda-10.0-cpp.dockerfile | 1 + ci/docker/cuda-10.1-cpp.dockerfile | 1 + ci/docker/cuda-9.1-cpp.dockerfile | 1 + ci/docker/debian-10-cpp.dockerfile | 1 + ci/docker/fedora-29-cpp.dockerfile | 1 + ci/docker/ubuntu-14.04-cpp.dockerfile | 1 + ci/docker/ubuntu-16.04-cpp.dockerfile | 1 + ci/docker/ubuntu-18.04-cpp.dockerfile | 1 + ci/scripts/PKGBUILD | 1 + ci/scripts/cpp_build.sh | 1 + cpp/CMakeLists.txt | 5 + cpp/cmake_modules/DefineOptions.cmake | 2 + cpp/cmake_modules/FindArrowEngine.cmake | 98 +++++++++++++++++++ cpp/src/arrow/CMakeLists.txt | 10 +- .../arrow/engine/ArrowEngineConfig.cmake.in | 37 +++++++ cpp/src/arrow/engine/CMakeLists.txt | 74 +++++++++++++- cpp/src/arrow/engine/api.h | 22 +++++ cpp/src/arrow/engine/arrow-engine.pc.in | 25 +++++ cpp/src/arrow/engine/catalog.cc | 4 + cpp/src/arrow/engine/expression.h | 4 +- cpp/src/arrow/engine/expression_test.cc | 6 +- cpp/src/arrow/engine/pch.h | 24 +++++ dev/archery/archery/cli.py | 2 + dev/archery/archery/lang/cpp.py | 8 +- 27 files changed, 318 insertions(+), 16 deletions(-) create mode 100644 cpp/cmake_modules/FindArrowEngine.cmake create mode 100644 cpp/src/arrow/engine/ArrowEngineConfig.cmake.in create mode 100644 cpp/src/arrow/engine/api.h create mode 100644 cpp/src/arrow/engine/arrow-engine.pc.in create mode 100644 cpp/src/arrow/engine/pch.h diff --git a/ci/cpp-msvc-build-main.bat b/ci/cpp-msvc-build-main.bat index 735073c49cc..fee46e1843c 100644 --- a/ci/cpp-msvc-build-main.bat +++ b/ci/cpp-msvc-build-main.bat @@ -78,6 +78,7 @@ cmake -G "%GENERATOR%" %CMAKE_ARGS% ^ -DARROW_FLIGHT=%ARROW_BUILD_FLIGHT% ^ -DARROW_GANDIVA=%ARROW_BUILD_GANDIVA% ^ -DARROW_DATASET=ON ^ + -DARROW_ENGINE=ON ^ -DARROW_S3=%ARROW_S3% ^ -DARROW_MIMALLOC=ON ^ -DARROW_PARQUET=ON ^ diff --git a/ci/docker/conda-cpp.dockerfile b/ci/docker/conda-cpp.dockerfile index 0e35b6caf6d..190a5fcc0a2 100644 --- a/ci/docker/conda-cpp.dockerfile +++ b/ci/docker/conda-cpp.dockerfile @@ -65,6 +65,7 @@ ENTRYPOINT [ "/bin/bash", "-c", "-l" ] ENV ARROW_BUILD_TESTS=ON \ ARROW_DEPENDENCY_SOURCE=CONDA \ ARROW_DATASET=ON \ + ARROW_ENGINE=ON \ ARROW_FLIGHT=ON \ ARROW_GANDIVA=ON \ ARROW_HOME=$CONDA_PREFIX \ diff --git a/ci/docker/conda-integration.dockerfile b/ci/docker/conda-integration.dockerfile index 5672cb8def3..86c33a6cebe 100644 --- a/ci/docker/conda-integration.dockerfile +++ b/ci/docker/conda-integration.dockerfile @@ -45,6 +45,7 @@ ENV ARROW_BUILD_INTEGRATION=ON \ ARROW_FLIGHT=ON \ ARROW_ORC=OFF \ ARROW_DATASET=OFF \ + ARROW_ENGINE=OFF \ ARROW_GANDIVA=OFF \ ARROW_PLASMA=OFF \ ARROW_FILESYSTEM=OFF \ diff --git a/ci/docker/cuda-10.0-cpp.dockerfile b/ci/docker/cuda-10.0-cpp.dockerfile index 0697513d30d..cb077b8e92f 100644 --- a/ci/docker/cuda-10.0-cpp.dockerfile +++ b/ci/docker/cuda-10.0-cpp.dockerfile @@ -76,6 +76,7 @@ ENV ARROW_BUILD_STATIC=OFF \ ARROW_CSV=OFF \ ARROW_CUDA=ON \ ARROW_DATASET=OFF \ + ARROW_ENGINE=OFF \ ARROW_DEPENDENCY_SOURCE=SYSTEM \ ARROW_FILESYSTEM=OFF \ ARROW_FLIGHT=OFF \ diff --git a/ci/docker/cuda-10.1-cpp.dockerfile b/ci/docker/cuda-10.1-cpp.dockerfile index 6e86ca97fc5..73ef0603e82 100644 --- a/ci/docker/cuda-10.1-cpp.dockerfile +++ b/ci/docker/cuda-10.1-cpp.dockerfile @@ -76,6 +76,7 @@ ENV ARROW_BUILD_STATIC=OFF \ ARROW_CSV=OFF \ ARROW_CUDA=ON \ ARROW_DATASET=OFF \ + ARROW_ENGINE=OFF \ ARROW_DEPENDENCY_SOURCE=SYSTEM \ ARROW_FILESYSTEM=OFF \ ARROW_FLIGHT=OFF \ diff --git a/ci/docker/cuda-9.1-cpp.dockerfile b/ci/docker/cuda-9.1-cpp.dockerfile index bf3242ea180..b7c302d822c 100644 --- a/ci/docker/cuda-9.1-cpp.dockerfile +++ b/ci/docker/cuda-9.1-cpp.dockerfile @@ -76,6 +76,7 @@ ENV ARROW_BUILD_STATIC=OFF \ ARROW_CSV=OFF \ ARROW_CUDA=ON \ ARROW_DATASET=OFF \ + ARROW_ENGINE=OFF \ ARROW_DEPENDENCY_SOURCE=SYSTEM \ ARROW_FILESYSTEM=OFF \ ARROW_FLIGHT=OFF \ diff --git a/ci/docker/debian-10-cpp.dockerfile b/ci/docker/debian-10-cpp.dockerfile index e51f482d842..11d1b10c45d 100644 --- a/ci/docker/debian-10-cpp.dockerfile +++ b/ci/docker/debian-10-cpp.dockerfile @@ -61,6 +61,7 @@ RUN export DEBIAN_FRONTEND=noninteractive && \ ENV ARROW_BUILD_TESTS=ON \ ARROW_DEPENDENCY_SOURCE=SYSTEM \ ARROW_DATASET=ON \ + ARROW_ENGINE=ON \ ARROW_FLIGHT=ON \ ARROW_GANDIVA=ON \ ARROW_HOME=/usr/local \ diff --git a/ci/docker/fedora-29-cpp.dockerfile b/ci/docker/fedora-29-cpp.dockerfile index 94adb5e7762..0dc1b603fee 100644 --- a/ci/docker/fedora-29-cpp.dockerfile +++ b/ci/docker/fedora-29-cpp.dockerfile @@ -59,6 +59,7 @@ RUN dnf update -y && \ ENV ARROW_BUILD_TESTS=ON \ ARROW_DEPENDENCY_SOURCE=SYSTEM \ ARROW_DATASET=ON \ + ARROW_ENGINE=ON \ ARROW_FLIGHT=ON \ ARROW_GANDIVA_JAVA=ON \ ARROW_GANDIVA=OFF \ diff --git a/ci/docker/ubuntu-14.04-cpp.dockerfile b/ci/docker/ubuntu-14.04-cpp.dockerfile index 5f24f9a353b..4266f23ef10 100644 --- a/ci/docker/ubuntu-14.04-cpp.dockerfile +++ b/ci/docker/ubuntu-14.04-cpp.dockerfile @@ -58,6 +58,7 @@ RUN apt-get update -y -q && \ ENV ARROW_BUILD_TESTS=ON \ ARROW_DEPENDENCY_SOURCE=SYSTEM \ ARROW_DATASET=ON \ + ARROW_ENGINE=ON \ ARROW_FLIGHT=OFF \ ARROW_GANDIVA_JAVA=OFF \ ARROW_GANDIVA=OFF \ diff --git a/ci/docker/ubuntu-16.04-cpp.dockerfile b/ci/docker/ubuntu-16.04-cpp.dockerfile index c6773662717..eba39219059 100644 --- a/ci/docker/ubuntu-16.04-cpp.dockerfile +++ b/ci/docker/ubuntu-16.04-cpp.dockerfile @@ -66,6 +66,7 @@ RUN apt-get update -y -q && \ ENV ARROW_BUILD_BENCHMARKS=OFF \ ARROW_BUILD_TESTS=ON \ ARROW_DATASET=ON \ + ARROW_ENGINE=ON \ ARROW_DEPENDENCY_SOURCE=SYSTEM \ ARROW_GANDIVA_JAVA=OFF \ ARROW_GANDIVA=ON \ diff --git a/ci/docker/ubuntu-18.04-cpp.dockerfile b/ci/docker/ubuntu-18.04-cpp.dockerfile index 92c75445e62..64b106e33d2 100644 --- a/ci/docker/ubuntu-18.04-cpp.dockerfile +++ b/ci/docker/ubuntu-18.04-cpp.dockerfile @@ -95,6 +95,7 @@ RUN apt-get update -y -q && \ ENV ARROW_BUILD_TESTS=ON \ ARROW_DEPENDENCY_SOURCE=SYSTEM \ ARROW_DATASET=ON \ + ARROW_ENGINE=ON \ ARROW_FLIGHT=OFF \ ARROW_GANDIVA=ON \ ARROW_HDFS=ON \ diff --git a/ci/scripts/PKGBUILD b/ci/scripts/PKGBUILD index 133da043eb6..f95ba9fce6d 100644 --- a/ci/scripts/PKGBUILD +++ b/ci/scripts/PKGBUILD @@ -82,6 +82,7 @@ build() { -DARROW_COMPUTE=ON \ -DARROW_CSV=ON \ -DARROW_DATASET=ON \ + -DARROW_ENGINE=ON \ -DARROW_FILESYSTEM=ON \ -DARROW_HDFS=OFF \ -DARROW_JEMALLOC=OFF \ diff --git a/ci/scripts/cpp_build.sh b/ci/scripts/cpp_build.sh index 0286987caa1..0233f975914 100755 --- a/ci/scripts/cpp_build.sh +++ b/ci/scripts/cpp_build.sh @@ -58,6 +58,7 @@ cmake -G "${CMAKE_GENERATOR:-Ninja}" \ -DARROW_CUDA=${ARROW_CUDA:-OFF} \ -DARROW_CXXFLAGS=${ARROW_CXXFLAGS:-} \ -DARROW_DATASET=${ARROW_DATASET:-ON} \ + -DARROW_ENGINE=${ARROW_ENGINE:-ON} \ -DARROW_DEPENDENCY_SOURCE=${ARROW_DEPENDENCY_SOURCE:-AUTO} \ -DARROW_EXTRA_ERROR_CONTEXT=${ARROW_EXTRA_ERROR_CONTEXT:-OFF} \ -DARROW_ENABLE_TIMING_TESTS=${ARROW_ENABLE_TIMING_TESTS:-ON} \ diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 553dee28244..b86acef1d8e 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -296,7 +296,12 @@ if(ARROW_CUDA OR ARROW_FLIGHT OR ARROW_PARQUET OR ARROW_BUILD_TESTS) set(ARROW_IPC ON) endif() +if(ARROW_ENGINE) + set(ARROW_DATASET ON) +endif() + if(ARROW_DATASET) + set(ARROW_PARQUET ON) set(ARROW_FILESYSTEM ON) endif() diff --git a/cpp/cmake_modules/DefineOptions.cmake b/cpp/cmake_modules/DefineOptions.cmake index f307a6d10da..eece40c859d 100644 --- a/cpp/cmake_modules/DefineOptions.cmake +++ b/cpp/cmake_modules/DefineOptions.cmake @@ -173,6 +173,8 @@ if("${CMAKE_SOURCE_DIR}" STREQUAL "${CMAKE_CURRENT_SOURCE_DIR}") define_option(ARROW_DATASET "Build the Arrow Dataset Modules" OFF) + define_option(ARROW_ENGINE "Build the Arrow Query Engine Modules" OFF) + define_option(ARROW_FILESYSTEM "Build the Arrow Filesystem Layer" OFF) define_option(ARROW_FLIGHT diff --git a/cpp/cmake_modules/FindArrowEngine.cmake b/cpp/cmake_modules/FindArrowEngine.cmake new file mode 100644 index 00000000000..200fcaa5427 --- /dev/null +++ b/cpp/cmake_modules/FindArrowEngine.cmake @@ -0,0 +1,98 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# - Find Arrow Query Engine (arrow/engine/api.h, libarrow_engine.a, libarrow_engine.so) +# +# This module requires Arrow from which it uses +# arrow_find_package() +# +# This module defines +# ARROW_ENGINE_FOUND, whether Arrow Query Engine has been found +# ARROW_ENGINE_IMPORT_LIB, +# path to libarrow_engine's import library (Windows only) +# ARROW_ENGINE_INCLUDE_DIR, directory containing headers +# ARROW_ENGINE_LIB_DIR, directory containing Arrow Query Engine libraries +# ARROW_ENGINE_SHARED_LIB, path to libarrow_engine's shared library +# ARROW_ENGINE_STATIC_LIB, path to libarrow_engine.a + +if(DEFINED ARROW_ENGINE_FOUND) + return() +endif() + +set(find_package_arguments) +if(${CMAKE_FIND_PACKAGE_NAME}_FIND_VERSION) + list(APPEND find_package_arguments + "${${CMAKE_FIND_PACKAGE_NAME}_FIND_VERSION}") +endif() +if(${CMAKE_FIND_PACKAGE_NAME}_FIND_REQUIRED) + list(APPEND find_package_arguments REQUIRED) +endif() +if(${CMAKE_FIND_PACKAGE_NAME}_FIND_QUIETLY) + list(APPEND find_package_arguments QUIET) +endif() +find_package(Arrow ${find_package_arguments}) +find_package(ArrowEngine ${find_package_arguments}) + +if(ARROW_FOUND AND ARROW_DATASET_FOUND) + arrow_find_package(ARROW_ENGINE + "${ARROW_HOME}" + arrow_engine + arrow/engine/api.h + ArrowEngine + arrow-engine) + if(NOT ARROW_ENGINE_VERSION) + set(ARROW_ENGINE_VERSION "${ARROW_VERSION}") + endif() +endif() + +if("${ARROW_ENGINE_VERSION}" VERSION_EQUAL "${ARROW_VERSION}") + set(ARROW_ENGINE_VERSION_MATCH TRUE) +else() + set(ARROW_ENGINE_VERSION_MATCH FALSE) +endif() + +mark_as_advanced(ARROW_ENGINE_IMPORT_LIB + ARROW_ENGINE_INCLUDE_DIR + ARROW_ENGINE_LIBS + ARROW_ENGINE_LIB_DIR + ARROW_ENGINE_SHARED_IMP_LIB + ARROW_ENGINE_SHARED_LIB + ARROW_ENGINE_STATIC_LIB + ARROW_ENGINE_VERSION + ARROW_ENGINE_VERSION_MATCH) + +find_package_handle_standard_args(ArrowEngine + REQUIRED_VARS + ARROW_ENGINE_INCLUDE_DIR + ARROW_ENGINE_LIB_DIR + ARROW_ENGINE_VERSION_MATCH + VERSION_VAR + ARROW_ENGINE_VERSION) +set(ARROW_ENGINE_FOUND ${ArrowEngine_FOUND}) + +if(ArrowEngine_FOUND AND NOT ArrowEngine_FIND_QUIETLY) + message(STATUS "Found the Arrow Engine by ${ARROW_ENGINE_FIND_APPROACH}") + message( + STATUS "Found the Arrow Engine shared library: ${ARROW_ENGINE_SHARED_LIB}" + ) + message( + STATUS "Found the Arrow Engine import library: ${ARROW_ENGINE_IMPORT_LIB}" + ) + message( + STATUS "Found the Arrow Engine static library: ${ARROW_ENGINE_STATIC_LIB}" + ) +endif() diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 7552337bc72..8d2ba524620 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -339,10 +339,7 @@ if(ARROW_COMPUTE) compute/kernels/match.cc compute/kernels/util_internal.cc compute/operations/cast.cc - compute/operations/literal.cc - engine/catalog.cc - engine/expression.cc - engine/logical_plan.cc) + compute/operations/literal.cc) endif() if(ARROW_FILESYSTEM) @@ -579,7 +576,6 @@ endif() if(ARROW_COMPUTE) add_subdirectory(compute) - add_subdirectory(engine) endif() if(ARROW_CUDA) @@ -590,6 +586,10 @@ if(ARROW_DATASET) add_subdirectory(dataset) endif() +if(ARROW_ENGINE) + add_subdirectory(engine) +endif() + if(ARROW_FILESYSTEM) add_subdirectory(filesystem) endif() diff --git a/cpp/src/arrow/engine/ArrowEngineConfig.cmake.in b/cpp/src/arrow/engine/ArrowEngineConfig.cmake.in new file mode 100644 index 00000000000..43cce1be535 --- /dev/null +++ b/cpp/src/arrow/engine/ArrowEngineConfig.cmake.in @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# This config sets the following variables in your project:: +# +# ArrowEngine_FOUND - true if Arrow Query engine is found on the system +# +# This config sets the following targets in your project:: +# +# arrow_engine_shared - for linked as shared library if shared library is built +# arrow_engine_static - for linked as static library if static library is built + +@PACKAGE_INIT@ + +include(CMakeFindDependencyMacro) +find_dependency(Arrow) +find_dependency(ArrowDataset) + +# Load targets only once. If we load targets multiple times, CMake reports +# already existent target error. +if(NOT (TARGET arrow_engine_shared OR TARGET arrow_engine_static)) + include("${CMAKE_CURRENT_LIST_DIR}/ArrowEngineTargets.cmake") +endif() diff --git a/cpp/src/arrow/engine/CMakeLists.txt b/cpp/src/arrow/engine/CMakeLists.txt index adb9cd15598..2a9adcfc5ce 100644 --- a/cpp/src/arrow/engine/CMakeLists.txt +++ b/cpp/src/arrow/engine/CMakeLists.txt @@ -15,12 +15,80 @@ # specific language governing permissions and limitations # under the License. +add_custom_target(arrow_engine) + arrow_install_all_headers("arrow/engine") +set(ARROW_ENGINE_SRCS catalog.cc expression.cc logical_plan.cc) + +set(ARROW_ENGINE_LINK_STATIC arrow_static arrow_dataset_static) +set(ARROW_ENGINE_LINK_SHARED arrow_shared arrow_dataset_shared) + +add_arrow_lib(arrow_engine + CMAKE_PACKAGE_NAME + ArrowEngine + PKG_CONFIG_NAME + arrow-engine + OUTPUTS + ARROW_ENGINE_LIBRARIES + SOURCES + ${ARROW_ENGINE_SRCS} + PRECOMPILED_HEADERS + "$<$:arrow/engine/pch.h>" + PRIVATE_INCLUDES + ${ARROW_ENGINE_PRIVATE_INCLUDES} + SHARED_LINK_LIBS + ${ARROW_ENGINE_LINK_SHARED} + STATIC_LINK_LIBS + ${ARROW_ENGINE_LINK_STATIC}) + +if(ARROW_TEST_LINKAGE STREQUAL "static") + set(ARROW_ENGINE_TEST_LINK_LIBS arrow_engine_static ${ARROW_TEST_STATIC_LINK_LIBS}) +else() + set(ARROW_ENGINE_TEST_LINK_LIBS arrow_engine_shared ${ARROW_TEST_SHARED_LINK_LIBS}) +endif() + +foreach(LIB_TARGET ${ARROW_ENGINE_LIBRARIES}) + target_compile_definitions(${LIB_TARGET} PRIVATE ARROW_DS_EXPORTING) +endforeach() + +# Adding unit tests part of the "engine" portion of the test suite +function(ADD_ARROW_ENGINE_TEST REL_TEST_NAME) + set(options) + set(one_value_args PREFIX) + set(multi_value_args LABELS) + cmake_parse_arguments(ARG + "${options}" + "${one_value_args}" + "${multi_value_args}" + ${ARGN}) + + if(ARG_PREFIX) + set(PREFIX ${ARG_PREFIX}) + else() + set(PREFIX "arrow-engine") + endif() + + if(ARG_LABELS) + set(LABELS ${ARG_LABELS}) + else() + set(LABELS "arrow_engine") + endif() + + add_arrow_test(${REL_TEST_NAME} + EXTRA_LINK_LIBS + ${ARROW_ENGINE_TEST_LINK_LIBS} + PREFIX + ${PREFIX} + LABELS + ${LABELS} + ${ARG_UNPARSED_ARGUMENTS}) +endfunction() + # # Unit tests # -add_arrow_test(catalog_test PREFIX arrow-engine) -add_arrow_test(expression_test PREFIX arrow-engine) -add_arrow_test(logical_plan_test PREFIX arrow-engine) +add_arrow_engine_test(catalog_test PREFIX arrow-engine) +add_arrow_engine_test(expression_test PREFIX arrow-engine) +add_arrow_engine_test(logical_plan_test PREFIX arrow-engine) diff --git a/cpp/src/arrow/engine/api.h b/cpp/src/arrow/engine/api.h new file mode 100644 index 00000000000..85ac052b217 --- /dev/null +++ b/cpp/src/arrow/engine/api.h @@ -0,0 +1,22 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "arrow/engine/catalog.h" +#include "arrow/engine/expression.h" +#include "arrow/engine/logical_plan.h" diff --git a/cpp/src/arrow/engine/arrow-engine.pc.in b/cpp/src/arrow/engine/arrow-engine.pc.in new file mode 100644 index 00000000000..0ceabbedf68 --- /dev/null +++ b/cpp/src/arrow/engine/arrow-engine.pc.in @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +libdir=@CMAKE_INSTALL_FULL_LIBDIR@ +includedir=@CMAKE_INSTALL_FULL_INCLUDEDIR@ + +Name: Apache Arrow Query Engine +Description: Apache Arrow Query Engine provides an API to execute queries on Arrow table and datasets +Version: @ARROW_VERSION@ +Requires: arrow arrow-dataset +Libs: -L${libdir} -larrow_dataset diff --git a/cpp/src/arrow/engine/catalog.cc b/cpp/src/arrow/engine/catalog.cc index a2f173142ce..d5479c222d2 100644 --- a/cpp/src/arrow/engine/catalog.cc +++ b/cpp/src/arrow/engine/catalog.cc @@ -106,6 +106,10 @@ Status CatalogBuilder::Add(std::string name, std::shared_ptr d } Status CatalogBuilder::Add(std::string name, std::shared_ptr
table) { + if (table == nullptr) { + return Status::Invalid("Table entry can't be null."); + } + return Add(Entry(std::move(table), std::move(name))); } diff --git a/cpp/src/arrow/engine/expression.h b/cpp/src/arrow/engine/expression.h index c7281761440..abdc9c2e409 100644 --- a/cpp/src/arrow/engine/expression.h +++ b/cpp/src/arrow/engine/expression.h @@ -105,8 +105,8 @@ class ARROW_EXPORT ExprType : public util::EqualityComparable { /// \brief Expand the smallest shape to the bigger one if possible. /// - /// \param[in] lhs, first type to broadcast - /// \param[in] rhs, second type to broadcast + /// \param[in] lhs first type to broadcast + /// \param[in] rhs second type to broadcast /// \return broadcasted type or an error why it can't be broadcasted. /// /// Broadcasting promotes the shape of the smallest type to the bigger one if diff --git a/cpp/src/arrow/engine/expression_test.cc b/cpp/src/arrow/engine/expression_test.cc index 17e5c754e35..7cd8e2726a3 100644 --- a/cpp/src/arrow/engine/expression_test.cc +++ b/cpp/src/arrow/engine/expression_test.cc @@ -127,9 +127,9 @@ TEST_F(ExprTest, FieldRefExpr) { ASSERT_OK_AND_ASSIGN(auto input, EmptyRelExpr::Make(schema)); ASSERT_RAISES(Invalid, FieldRefExpr::Make(nullptr, 0)); - ASSERT_RAISES(Invalid, FieldRefExpr::Make(input, -1)); - ASSERT_RAISES(Invalid, FieldRefExpr::Make(input, 1)); - ASSERT_RAISES(Invalid, FieldRefExpr::Make(input, "not_present")); + ASSERT_RAISES(KeyError, FieldRefExpr::Make(input, -1)); + ASSERT_RAISES(KeyError, FieldRefExpr::Make(input, 1)); + ASSERT_RAISES(KeyError, FieldRefExpr::Make(input, "not_present")); ASSERT_OK_AND_ASSIGN(auto expr, FieldRefExpr::Make(input, 0)); EXPECT_EQ(expr->kind(), ExprKind::FIELD_REFERENCE); diff --git a/cpp/src/arrow/engine/pch.h b/cpp/src/arrow/engine/pch.h new file mode 100644 index 00000000000..2014fc9a1f2 --- /dev/null +++ b/cpp/src/arrow/engine/pch.h @@ -0,0 +1,24 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Often-used headers, for precompiling. +// If updating this header, please make sure you check compilation speed +// before checking in. Adding headers which are not used extremely often +// may incur a slowdown, since it makes the precompiled header heavier to load. + +#include "arrow/engine/expression.h" +#include "arrow/pch.h" diff --git a/dev/archery/archery/cli.py b/dev/archery/archery/cli.py index 744e92405de..9366631e547 100644 --- a/dev/archery/archery/cli.py +++ b/dev/archery/archery/cli.py @@ -155,6 +155,8 @@ def _apply_options(cmd, options): help="Build with compute kernels support.") @click.option("--with-dataset", default=False, type=BOOL, help="Build with dataset support.") +@click.option("--with-engine", default=False, type=BOOL, + help="Build with query engine support.") @click.option("--use-sanitizers", default=False, type=BOOL, help="Toggles ARROW_USE_*SAN sanitizers.") @click.option("--with-fuzzing", default=False, type=BOOL, diff --git a/dev/archery/archery/lang/cpp.py b/dev/archery/archery/lang/cpp.py index 607581b3c71..5eb0507bd88 100644 --- a/dev/archery/archery/lang/cpp.py +++ b/dev/archery/archery/lang/cpp.py @@ -46,7 +46,7 @@ def __init__(self, with_parquet=False, # Components with_gandiva=False, with_compute=False, with_dataset=False, - with_plasma=False, with_flight=False, + with_engine=False, with_plasma=False, with_flight=False, # extras with_lint_only=False, with_fuzzing=False, use_gold_linker=True, use_sanitizers=True, @@ -65,12 +65,13 @@ def __init__(self, self.with_benchmarks = with_benchmarks self.with_examples = with_examples self.with_python = with_python - self.with_parquet = with_parquet or with_dataset self.with_gandiva = with_gandiva self.with_plasma = with_plasma self.with_flight = with_flight self.with_compute = with_compute - self.with_dataset = with_dataset + self.with_engine = with_engine + self.with_dataset = with_dataset or self.with_engine + self.with_parquet = with_parquet or self.with_dataset self.with_lint_only = with_lint_only self.with_fuzzing = with_fuzzing @@ -147,6 +148,7 @@ def _gen_defs(self): yield ("ARROW_FLIGHT", truthifier(self.with_flight)) yield ("ARROW_COMPUTE", truthifier(self.with_compute)) yield ("ARROW_DATASET", truthifier(self.with_dataset)) + yield ("ARROW_ENGINE", truthifier(self.with_engine)) if self.use_sanitizers or self.with_fuzzing: yield ("ARROW_USE_ASAN", "ON") From 6fbac47b76634c57065255385657c751b5072432 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Saint-Jacques?= Date: Thu, 12 Mar 2020 13:54:23 -0400 Subject: [PATCH 12/21] Fix msvc --- cpp/src/arrow/engine/expression.cc | 12 ++++++------ cpp/src/arrow/engine/expression.h | 1 + cpp/src/arrow/engine/expression_test.cc | 2 +- cpp/src/arrow/engine/type_fwd.h | 6 ++++-- cpp/src/arrow/util/compare.h | 1 + 5 files changed, 13 insertions(+), 9 deletions(-) diff --git a/cpp/src/arrow/engine/expression.cc b/cpp/src/arrow/engine/expression.cc index 76d8ddcab72..1552cda3f0c 100644 --- a/cpp/src/arrow/engine/expression.cc +++ b/cpp/src/arrow/engine/expression.cc @@ -41,29 +41,29 @@ std::string ShapeToString(ExprType::Shape shape) { } ExprType ExprType::Scalar(std::shared_ptr type) { - return ExprType(std::move(type), Shape::SCALAR); + return ExprType(std::move(type), SCALAR); } ExprType ExprType::Array(std::shared_ptr type) { - return ExprType(std::move(type), Shape::ARRAY); + return ExprType(std::move(type), ARRAY); } ExprType ExprType::Table(std::shared_ptr schema) { - return ExprType(std::move(schema), Shape::TABLE); + return ExprType(std::move(schema), TABLE); } ExprType ExprType::Table(std::vector> fields) { - return ExprType(arrow::schema(std::move(fields)), Shape::TABLE); + return ExprType(arrow::schema(std::move(fields)), TABLE); } ExprType::ExprType(std::shared_ptr schema, Shape shape) : schema_(std::move(schema)), shape_(shape) { - DCHECK_EQ(shape, Shape::TABLE); + DCHECK_EQ(shape, TABLE); } ExprType::ExprType(std::shared_ptr type, Shape shape) : data_type_(std::move(type)), shape_(shape) { - DCHECK_NE(shape, Shape::TABLE); + DCHECK_NE(shape, TABLE); } ExprType::ExprType(const ExprType& other) : shape_(other.shape()) { diff --git a/cpp/src/arrow/engine/expression.h b/cpp/src/arrow/engine/expression.h index abdc9c2e409..dfb5ac8fb42 100644 --- a/cpp/src/arrow/engine/expression.h +++ b/cpp/src/arrow/engine/expression.h @@ -30,6 +30,7 @@ #include "arrow/type_fwd.h" #include "arrow/util/compare.h" #include "arrow/util/macros.h" +#include "arrow/util/visibility.h" namespace arrow { namespace engine { diff --git a/cpp/src/arrow/engine/expression_test.cc b/cpp/src/arrow/engine/expression_test.cc index 7cd8e2726a3..798780a241e 100644 --- a/cpp/src/arrow/engine/expression_test.cc +++ b/cpp/src/arrow/engine/expression_test.cc @@ -175,7 +175,7 @@ TYPED_TEST(CompareExprTest, BasicCompareExpr) { ASSERT_RAISES(Invalid, this->Make(nullptr, f_expr)); // Not type compatible - ASSERT_OK_AND_ASSIGN(auto s_i64, MakeScalar(int64(), 42L)); + ASSERT_OK_AND_ASSIGN(auto s_i64, MakeScalar(int64(), 42LL)); ASSERT_OK_AND_ASSIGN(auto s_expr_i64, ScalarExpr::Make(s_i64)); EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("operands must be of same type"), this->Make(s_expr_i64, f_expr)); diff --git a/cpp/src/arrow/engine/type_fwd.h b/cpp/src/arrow/engine/type_fwd.h index 682f6d1ae8d..c4d70caf64a 100644 --- a/cpp/src/arrow/engine/type_fwd.h +++ b/cpp/src/arrow/engine/type_fwd.h @@ -19,13 +19,15 @@ #include +#include "arrow/util/visibility.h" + namespace arrow { namespace engine { class ExprType; /// Tag identifier for the expression type. -enum ExprKind : uint8_t { +enum ARROW_EXPORT ExprKind : uint8_t { /// A Scalar literal, i.e. a constant. SCALAR_LITERAL, /// A Field reference in a schema. @@ -49,7 +51,7 @@ class ScalarExpr; class FieldRefExpr; /// Tag identifier for comparison operators -enum CompareKind : uint8_t { +enum ARROW_EXPORT CompareKind : uint8_t { EQUAL, NOT_EQUAL, GREATER_THAN, diff --git a/cpp/src/arrow/util/compare.h b/cpp/src/arrow/util/compare.h index 287a30d03b2..3be798f69b8 100644 --- a/cpp/src/arrow/util/compare.h +++ b/cpp/src/arrow/util/compare.h @@ -22,6 +22,7 @@ #include #include "arrow/util/macros.h" +#include "arrow/util/visibility.h" namespace arrow { namespace util { From 4a478094cb72f124d4f45ac825f23d908a5eccbd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Saint-Jacques?= Date: Thu, 12 Mar 2020 15:11:58 -0400 Subject: [PATCH 13/21] Add custom ARROW_EN_EXPORT for engine --- cpp/src/arrow/engine/CMakeLists.txt | 2 +- cpp/src/arrow/engine/catalog.h | 5 +-- cpp/src/arrow/engine/expression.h | 43 +++++++++++++------------- cpp/src/arrow/engine/logical_plan.h | 11 +++---- cpp/src/arrow/engine/type_fwd.h | 8 +++-- cpp/src/arrow/engine/visibility.h | 48 +++++++++++++++++++++++++++++ 6 files changed, 83 insertions(+), 34 deletions(-) create mode 100644 cpp/src/arrow/engine/visibility.h diff --git a/cpp/src/arrow/engine/CMakeLists.txt b/cpp/src/arrow/engine/CMakeLists.txt index 2a9adcfc5ce..b4420320ad6 100644 --- a/cpp/src/arrow/engine/CMakeLists.txt +++ b/cpp/src/arrow/engine/CMakeLists.txt @@ -49,7 +49,7 @@ else() endif() foreach(LIB_TARGET ${ARROW_ENGINE_LIBRARIES}) - target_compile_definitions(${LIB_TARGET} PRIVATE ARROW_DS_EXPORTING) + target_compile_definitions(${LIB_TARGET} PRIVATE ARROW_EN_EXPORTING) endforeach() # Adding unit tests part of the "engine" portion of the test suite diff --git a/cpp/src/arrow/engine/catalog.h b/cpp/src/arrow/engine/catalog.h index d4607320acd..8d831708515 100644 --- a/cpp/src/arrow/engine/catalog.h +++ b/cpp/src/arrow/engine/catalog.h @@ -25,6 +25,7 @@ #include "arrow/type_fwd.h" #include "arrow/util/variant.h" +#include "arrow/engine/visibility.h" namespace arrow { @@ -35,7 +36,7 @@ class Dataset; namespace engine { /// Catalog is made of named Table/Dataset to be referenced in LogicalPlans. -class Catalog { +class ARROW_EN_EXPORT Catalog { public: class Entry; @@ -69,7 +70,7 @@ class Catalog { std::unordered_map datasets_; }; -class CatalogBuilder { +class ARROW_EN_EXPORT CatalogBuilder { public: Status Add(Catalog::Entry entry); Status Add(std::string name, std::shared_ptr); diff --git a/cpp/src/arrow/engine/expression.h b/cpp/src/arrow/engine/expression.h index dfb5ac8fb42..2bf69679d33 100644 --- a/cpp/src/arrow/engine/expression.h +++ b/cpp/src/arrow/engine/expression.h @@ -25,12 +25,12 @@ #include "arrow/engine/catalog.h" #include "arrow/engine/type_fwd.h" #include "arrow/engine/type_traits.h" +#include "arrow/engine/visibility.h" #include "arrow/result.h" #include "arrow/type.h" #include "arrow/type_fwd.h" #include "arrow/util/compare.h" #include "arrow/util/macros.h" -#include "arrow/util/visibility.h" namespace arrow { namespace engine { @@ -48,7 +48,7 @@ namespace engine { /// ArrayType(DataType), /// TableType(Schema), /// } -class ARROW_EXPORT ExprType : public util::EqualityComparable { +class ARROW_EN_EXPORT ExprType : public util::EqualityComparable { public: enum Shape : uint8_t { // The expression yields a Scalar, e.g. "1". @@ -148,7 +148,7 @@ class ARROW_EXPORT ExprType : public util::EqualityComparable { }; /// Represents an expression tree -class ARROW_EXPORT Expr : public util::EqualityComparable { +class ARROW_EN_EXPORT Expr : public util::EqualityComparable { public: /// \brief Return the kind of the expression. ExprKind kind() const { return kind_; } @@ -178,7 +178,7 @@ class ARROW_EXPORT Expr : public util::EqualityComparable { /// Operator expressions mixin. /// -class ARROW_EXPORT UnaryOpMixin { +class ARROW_EN_EXPORT UnaryOpMixin { public: const std::shared_ptr& operand() const { return operand_; } @@ -188,7 +188,7 @@ class ARROW_EXPORT UnaryOpMixin { std::shared_ptr operand_; }; -class ARROW_EXPORT BinaryOpMixin { +class ARROW_EN_EXPORT BinaryOpMixin { public: const std::shared_ptr& left_operand() const { return left_operand_; } const std::shared_ptr& right_operand() const { return right_operand_; } @@ -201,7 +201,7 @@ class ARROW_EXPORT BinaryOpMixin { std::shared_ptr right_operand_; }; -class ARROW_EXPORT MultiAryOpMixin { +class ARROW_EN_EXPORT MultiAryOpMixin { public: const std::vector>& operands() const { return operands_; } @@ -217,7 +217,7 @@ class ARROW_EXPORT MultiAryOpMixin { /// /// An unnamed scalar literal expression. -class ARROW_EXPORT ScalarExpr : public Expr { +class ARROW_EN_EXPORT ScalarExpr : public Expr { public: static Result> Make(std::shared_ptr scalar); @@ -230,7 +230,7 @@ class ARROW_EXPORT ScalarExpr : public Expr { }; /// References a column in a table/dataset -class ARROW_EXPORT FieldRefExpr : public UnaryOpMixin, public Expr { +class ARROW_EN_EXPORT FieldRefExpr : public UnaryOpMixin, public Expr { public: static Result> Make(std::shared_ptr input, int index); @@ -249,7 +249,7 @@ class ARROW_EXPORT FieldRefExpr : public UnaryOpMixin, public Expr { /// Comparison expressions /// -class ARROW_EXPORT CompareOpExpr : public BinaryOpMixin, public Expr { +class ARROW_EN_EXPORT CompareOpExpr : public BinaryOpMixin, public Expr { public: CompareKind compare_kind() const { return compare_kind_; } @@ -294,32 +294,33 @@ class BaseCompareExpr : public CompareOpExpr, private CompareOpExpr::MakeMixin { +class ARROW_EN_EXPORT EqualExpr : public BaseCompareExpr { protected: using BaseCompareExpr::BaseCompareExpr; }; -class ARROW_EXPORT NotEqualExpr : public BaseCompareExpr { +class ARROW_EN_EXPORT NotEqualExpr : public BaseCompareExpr { protected: using BaseCompareExpr::BaseCompareExpr; }; -class ARROW_EXPORT GreaterThanExpr : public BaseCompareExpr { +class ARROW_EN_EXPORT GreaterThanExpr : public BaseCompareExpr { protected: using BaseCompareExpr::BaseCompareExpr; }; -class ARROW_EXPORT GreaterThanEqualExpr : public BaseCompareExpr { +class ARROW_EN_EXPORT GreaterThanEqualExpr + : public BaseCompareExpr { protected: using BaseCompareExpr::BaseCompareExpr; }; -class ARROW_EXPORT LessThanExpr : public BaseCompareExpr { +class ARROW_EN_EXPORT LessThanExpr : public BaseCompareExpr { protected: using BaseCompareExpr::BaseCompareExpr; }; -class ARROW_EXPORT LessThanEqualExpr : public BaseCompareExpr { +class ARROW_EN_EXPORT LessThanEqualExpr : public BaseCompareExpr { protected: using BaseCompareExpr::BaseCompareExpr; }; @@ -330,7 +331,7 @@ class ARROW_EXPORT LessThanEqualExpr : public BaseCompareExpr /// \brief Relational Expressions that acts on tables. template -class ARROW_EXPORT RelExpr : public Expr { +class ARROW_EN_EXPORT RelExpr : public Expr { public: const std::shared_ptr& schema() const { return schema_; } @@ -352,7 +353,7 @@ class ARROW_EXPORT RelExpr : public Expr { /// /// \input schema, the schema of the empty relation /// \ouput relation with no rows of the given input schema -class ARROW_EXPORT EmptyRelExpr : public RelExpr { +class ARROW_EN_EXPORT EmptyRelExpr : public RelExpr { public: static Result> Make(std::shared_ptr schema); @@ -373,7 +374,7 @@ class ARROW_EXPORT EmptyRelExpr : public RelExpr { /// ``` /// SELECT * FROM table; /// ``` -class ARROW_EXPORT ScanRelExpr : public RelExpr { +class ARROW_EN_EXPORT ScanRelExpr : public RelExpr { public: static Result> Make(Catalog::Entry input); @@ -401,8 +402,8 @@ class ARROW_EXPORT ScanRelExpr : public RelExpr { /// ``` /// SELECT a, b, a + b, 1, mean(a) > b FROM relation; /// ``` -class ARROW_EXPORT ProjectionRelExpr : public UnaryOpMixin, - public RelExpr { +class ARROW_EN_EXPORT ProjectionRelExpr : public UnaryOpMixin, + public RelExpr { public: static Result> Make( std::shared_ptr input, std::vector> expressions); @@ -427,7 +428,7 @@ class ARROW_EXPORT ProjectionRelExpr : public UnaryOpMixin, /// ``` /// SELECT * FROM relation WHERE predicate /// ``` -class ARROW_EXPORT FilterRelExpr : public UnaryOpMixin, public RelExpr { +class ARROW_EN_EXPORT FilterRelExpr : public UnaryOpMixin, public RelExpr { public: static Result> Make(std::shared_ptr input, std::shared_ptr predicate); diff --git a/cpp/src/arrow/engine/logical_plan.h b/cpp/src/arrow/engine/logical_plan.h index 9b460dca5ae..00016f0e547 100644 --- a/cpp/src/arrow/engine/logical_plan.h +++ b/cpp/src/arrow/engine/logical_plan.h @@ -21,9 +21,10 @@ #include #include +#include "arrow/engine/type_fwd.h" +#include "arrow/engine/visibility.h" #include "arrow/type_fwd.h" #include "arrow/util/compare.h" -#include "arrow/util/variant.h" namespace arrow { @@ -33,11 +34,7 @@ class Dataset; namespace engine { -class Catalog; -class Expr; -class ExprType; - -class LogicalPlan : public util::EqualityComparable { +class ARROW_EN_EXPORT LogicalPlan : public util::EqualityComparable { public: explicit LogicalPlan(std::shared_ptr root); @@ -56,7 +53,7 @@ struct LogicalPlanBuilderOptions { std::shared_ptr catalog; }; -class LogicalPlanBuilder { +class ARROW_EN_EXPORT LogicalPlanBuilder { public: using ResultExpr = Result>; diff --git a/cpp/src/arrow/engine/type_fwd.h b/cpp/src/arrow/engine/type_fwd.h index c4d70caf64a..6cb2f56857c 100644 --- a/cpp/src/arrow/engine/type_fwd.h +++ b/cpp/src/arrow/engine/type_fwd.h @@ -19,7 +19,7 @@ #include -#include "arrow/util/visibility.h" +#include "arrow/engine/visibility.h" namespace arrow { namespace engine { @@ -27,7 +27,7 @@ namespace engine { class ExprType; /// Tag identifier for the expression type. -enum ARROW_EXPORT ExprKind : uint8_t { +enum ExprKind : uint8_t { /// A Scalar literal, i.e. a constant. SCALAR_LITERAL, /// A Field reference in a schema. @@ -51,7 +51,7 @@ class ScalarExpr; class FieldRefExpr; /// Tag identifier for comparison operators -enum ARROW_EXPORT CompareKind : uint8_t { +enum CompareKind : uint8_t { EQUAL, NOT_EQUAL, GREATER_THAN, @@ -76,5 +76,7 @@ class ScanRelExpr; class ProjectionRelExpr; class FilterRelExpr; +class Catalog; + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/visibility.h b/cpp/src/arrow/engine/visibility.h new file mode 100644 index 00000000000..0598aee3802 --- /dev/null +++ b/cpp/src/arrow/engine/visibility.h @@ -0,0 +1,48 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#if defined(_WIN32) || defined(__CYGWIN__) +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4251) +#else +#pragma GCC diagnostic ignored "-Wattributes" +#endif + +#ifdef ARROW_EN_STATIC +#define ARROW_EN_EXPORT +#elif defined(ARROW_EN_EXPORTING) +#define ARROW_EN_EXPORT __declspec(dllexport) +#else +#define ARROW_EN_EXPORT __declspec(dllimport) +#endif + +#define ARROW_EN_NO_EXPORT +#else // Not Windows +#ifndef ARROW_EN_EXPORT +#define ARROW_EN_EXPORT __attribute__((visibility("default"))) +#endif +#ifndef ARROW_EN_NO_EXPORT +#define ARROW_EN_NO_EXPORT __attribute__((visibility("hidden"))) +#endif +#endif // Non-Windows + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif From 5679f83a8d178ee9411e8b083025990d8247feb3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Saint-Jacques?= Date: Thu, 12 Mar 2020 15:33:40 -0400 Subject: [PATCH 14/21] Fix constant type --- cpp/src/arrow/engine/catalog.h | 2 +- cpp/src/arrow/engine/expression_test.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/engine/catalog.h b/cpp/src/arrow/engine/catalog.h index 8d831708515..a188eb4d283 100644 --- a/cpp/src/arrow/engine/catalog.h +++ b/cpp/src/arrow/engine/catalog.h @@ -23,9 +23,9 @@ #include #include +#include "arrow/engine/visibility.h" #include "arrow/type_fwd.h" #include "arrow/util/variant.h" -#include "arrow/engine/visibility.h" namespace arrow { diff --git a/cpp/src/arrow/engine/expression_test.cc b/cpp/src/arrow/engine/expression_test.cc index 798780a241e..07ea194291e 100644 --- a/cpp/src/arrow/engine/expression_test.cc +++ b/cpp/src/arrow/engine/expression_test.cc @@ -175,7 +175,7 @@ TYPED_TEST(CompareExprTest, BasicCompareExpr) { ASSERT_RAISES(Invalid, this->Make(nullptr, f_expr)); // Not type compatible - ASSERT_OK_AND_ASSIGN(auto s_i64, MakeScalar(int64(), 42LL)); + ASSERT_OK_AND_ASSIGN(auto s_i64, MakeScalar(int64(), static_cast(42))); ASSERT_OK_AND_ASSIGN(auto s_expr_i64, ScalarExpr::Make(s_i64)); EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("operands must be of same type"), this->Make(s_expr_i64, f_expr)); From a9fc553de869f8f9a8f5448c6796adf2ec2c17bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Saint-Jacques?= Date: Thu, 12 Mar 2020 15:55:06 -0400 Subject: [PATCH 15/21] Make msvc happy --- cpp/src/arrow/engine/expression.cc | 9 ++++++--- cpp/src/arrow/engine/expression.h | 23 +++++++++++++---------- cpp/src/arrow/engine/type_fwd.h | 1 - cpp/src/arrow/engine/type_traits.h | 2 +- cpp/src/arrow/testing/gtest_util.h | 2 +- 5 files changed, 21 insertions(+), 16 deletions(-) diff --git a/cpp/src/arrow/engine/expression.cc b/cpp/src/arrow/engine/expression.cc index 1552cda3f0c..29dc5732318 100644 --- a/cpp/src/arrow/engine/expression.cc +++ b/cpp/src/arrow/engine/expression.cc @@ -336,6 +336,9 @@ Result> FieldRefExpr::Make(std::shared_ptr i // EmptyRelExpr // +EmptyRelExpr::EmptyRelExpr(std::shared_ptr schema) + : RelExpr(ExprKind::EMPTY_REL, std::move(schema)) {} + Result> EmptyRelExpr::Make(std::shared_ptr schema) { ERROR_IF(schema == nullptr, "EmptyRelExpr schema must be non-null"); return std::shared_ptr(new EmptyRelExpr(std::move(schema))); @@ -346,7 +349,7 @@ Result> EmptyRelExpr::Make(std::shared_ptr // ScanRelExpr::ScanRelExpr(Catalog::Entry input) - : RelExpr(input.schema()), input_(std::move(input)) {} + : RelExpr(ExprKind::SCAN_REL, input.schema()), input_(std::move(input)) {} Result> ScanRelExpr::Make(Catalog::Entry input) { return std::shared_ptr(new ScanRelExpr(std::move(input))); @@ -360,7 +363,7 @@ ProjectionRelExpr::ProjectionRelExpr(std::shared_ptr input, std::shared_ptr schema, std::vector> expressions) : UnaryOpMixin(std::move(input)), - RelExpr(std::move(schema)), + RelExpr(ExprKind::PROJECTION_REL, std::move(schema)), expressions_(std::move(expressions)) {} Result> ProjectionRelExpr::Make( @@ -403,7 +406,7 @@ Result> FilterRelExpr::Make( FilterRelExpr::FilterRelExpr(std::shared_ptr input, std::shared_ptr predicate) : UnaryOpMixin(std::move(input)), - RelExpr(operand()->type().schema()), + RelExpr(ExprKind::FILTER_REL, operand()->type().schema()), predicate_(std::move(predicate)) {} #undef ERROR_IF diff --git a/cpp/src/arrow/engine/expression.h b/cpp/src/arrow/engine/expression.h index 2bf69679d33..36506de4f38 100644 --- a/cpp/src/arrow/engine/expression.h +++ b/cpp/src/arrow/engine/expression.h @@ -297,32 +297,38 @@ class BaseCompareExpr : public CompareOpExpr, private CompareOpExpr::MakeMixin { protected: using BaseCompareExpr::BaseCompareExpr; + friend CompareOpExpr::MakeMixin; }; class ARROW_EN_EXPORT NotEqualExpr : public BaseCompareExpr { protected: using BaseCompareExpr::BaseCompareExpr; + friend CompareOpExpr::MakeMixin; }; class ARROW_EN_EXPORT GreaterThanExpr : public BaseCompareExpr { protected: using BaseCompareExpr::BaseCompareExpr; + friend CompareOpExpr::MakeMixin; }; class ARROW_EN_EXPORT GreaterThanEqualExpr : public BaseCompareExpr { protected: using BaseCompareExpr::BaseCompareExpr; + friend CompareOpExpr::MakeMixin; }; class ARROW_EN_EXPORT LessThanExpr : public BaseCompareExpr { protected: using BaseCompareExpr::BaseCompareExpr; + friend CompareOpExpr::MakeMixin; }; class ARROW_EN_EXPORT LessThanEqualExpr : public BaseCompareExpr { protected: using BaseCompareExpr::BaseCompareExpr; + friend CompareOpExpr::MakeMixin; }; /// @@ -330,15 +336,13 @@ class ARROW_EN_EXPORT LessThanEqualExpr : public BaseCompareExpr class ARROW_EN_EXPORT RelExpr : public Expr { public: const std::shared_ptr& schema() const { return schema_; } protected: - explicit RelExpr(std::shared_ptr schema) - : Expr(expr_traits::kind_id, ExprType::Table(schema)), - schema_(std::move(schema)) {} + explicit RelExpr(ExprKind kind, std::shared_ptr schema) + : Expr(kind, ExprType::Table(schema)), schema_(std::move(schema)) {} std::shared_ptr schema_; }; @@ -353,12 +357,12 @@ class ARROW_EN_EXPORT RelExpr : public Expr { /// /// \input schema, the schema of the empty relation /// \ouput relation with no rows of the given input schema -class ARROW_EN_EXPORT EmptyRelExpr : public RelExpr { +class ARROW_EN_EXPORT EmptyRelExpr : public RelExpr { public: static Result> Make(std::shared_ptr schema); protected: - using RelExpr::RelExpr; + explicit EmptyRelExpr(std::shared_ptr schema); }; /// \brief Materialize a relation from a dataset. @@ -374,7 +378,7 @@ class ARROW_EN_EXPORT EmptyRelExpr : public RelExpr { /// ``` /// SELECT * FROM table; /// ``` -class ARROW_EN_EXPORT ScanRelExpr : public RelExpr { +class ARROW_EN_EXPORT ScanRelExpr : public RelExpr { public: static Result> Make(Catalog::Entry input); @@ -402,8 +406,7 @@ class ARROW_EN_EXPORT ScanRelExpr : public RelExpr { /// ``` /// SELECT a, b, a + b, 1, mean(a) > b FROM relation; /// ``` -class ARROW_EN_EXPORT ProjectionRelExpr : public UnaryOpMixin, - public RelExpr { +class ARROW_EN_EXPORT ProjectionRelExpr : public UnaryOpMixin, public RelExpr { public: static Result> Make( std::shared_ptr input, std::vector> expressions); @@ -428,7 +431,7 @@ class ARROW_EN_EXPORT ProjectionRelExpr : public UnaryOpMixin, /// ``` /// SELECT * FROM relation WHERE predicate /// ``` -class ARROW_EN_EXPORT FilterRelExpr : public UnaryOpMixin, public RelExpr { +class ARROW_EN_EXPORT FilterRelExpr : public UnaryOpMixin, public RelExpr { public: static Result> Make(std::shared_ptr input, std::shared_ptr predicate); diff --git a/cpp/src/arrow/engine/type_fwd.h b/cpp/src/arrow/engine/type_fwd.h index 6cb2f56857c..6ca1b47bd1a 100644 --- a/cpp/src/arrow/engine/type_fwd.h +++ b/cpp/src/arrow/engine/type_fwd.h @@ -68,7 +68,6 @@ class GreaterThanEqualExpr; class LessThanExpr; class LessThanEqualExpr; -template class RelExpr; class EmptyRelExpr; diff --git a/cpp/src/arrow/engine/type_traits.h b/cpp/src/arrow/engine/type_traits.h index e8f3cc00f29..3303273547e 100644 --- a/cpp/src/arrow/engine/type_traits.h +++ b/cpp/src/arrow/engine/type_traits.h @@ -103,7 +103,7 @@ template using enable_if_compare_expr = enable_if_t::value, Ret>; template -using is_relational_expr = std::is_base_of, E>; +using is_relational_expr = std::is_base_of; template using enable_if_relational_expr = enable_if_t::value, Ret>; diff --git a/cpp/src/arrow/testing/gtest_util.h b/cpp/src/arrow/testing/gtest_util.h index 82ea1cb60d1..e2be60c5908 100644 --- a/cpp/src/arrow/testing/gtest_util.h +++ b/cpp/src/arrow/testing/gtest_util.h @@ -398,7 +398,7 @@ inline void BitmapFromVector(const std::vector& is_valid, } // Returns a table with 0 rows of a given schema. -std::shared_ptr
MockTable(std::shared_ptr schema); +ARROW_EXPORT std::shared_ptr
MockTable(std::shared_ptr schema); template void AssertSortedEquals(std::vector u, std::vector v) { From f35c500e5e8f64f63fecd4ee19298d839c5d8885 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Saint-Jacques?= Date: Fri, 13 Mar 2020 08:18:46 -0400 Subject: [PATCH 16/21] Make inner class exportable --- cpp/src/arrow/engine/catalog.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/engine/catalog.h b/cpp/src/arrow/engine/catalog.h index a188eb4d283..fc16f9c7c0e 100644 --- a/cpp/src/arrow/engine/catalog.h +++ b/cpp/src/arrow/engine/catalog.h @@ -45,7 +45,7 @@ class ARROW_EN_EXPORT Catalog { Result Get(const std::string& name) const; Result> GetSchema(const std::string& name) const; - class Entry { + class ARROW_EN_EXPORT Entry { public: Entry(std::shared_ptr dataset, std::string name); Entry(std::shared_ptr
table, std::string name); From 369af820c224741392a3b6235994923344cc3d25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Saint-Jacques?= Date: Fri, 13 Mar 2020 10:17:44 -0400 Subject: [PATCH 17/21] Try msvc --- cpp/src/arrow/engine/expression.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cpp/src/arrow/engine/expression.h b/cpp/src/arrow/engine/expression.h index 36506de4f38..73989932b7f 100644 --- a/cpp/src/arrow/engine/expression.h +++ b/cpp/src/arrow/engine/expression.h @@ -286,7 +286,8 @@ class ARROW_EN_EXPORT CompareOpExpr : public BinaryOpMixin, public Expr { }; template -class BaseCompareExpr : public CompareOpExpr, private CompareOpExpr::MakeMixin { +class BaseCompareExpr : public CompareOpExpr, + protected CompareOpExpr::MakeMixin { public: using CompareOpExpr::MakeMixin::Make; From 0423241a6a025fef0a74fbd47d9fcabbbb1223d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Saint-Jacques?= Date: Mon, 16 Mar 2020 10:59:20 -0400 Subject: [PATCH 18/21] Add Aggregate operators with Count and Sum --- cpp/src/arrow/compute/kernels/sum_internal.h | 26 +++++++++ cpp/src/arrow/engine/expression.cc | 39 ++++++++++++++ cpp/src/arrow/engine/expression.h | 46 ++++++++++++++++ cpp/src/arrow/engine/expression_test.cc | 57 ++++++++++++++++++++ cpp/src/arrow/engine/logical_plan.cc | 12 +++++ cpp/src/arrow/engine/logical_plan.h | 15 +++++- cpp/src/arrow/engine/logical_plan_test.cc | 22 +++++++- cpp/src/arrow/engine/type_fwd.h | 14 +++++ cpp/src/arrow/engine/type_traits.h | 12 +++++ cpp/src/arrow/testing/gmock.h | 1 + cpp/src/arrow/type_traits.h | 4 ++ 11 files changed, 245 insertions(+), 3 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/sum_internal.h b/cpp/src/arrow/compute/kernels/sum_internal.h index 302d004e399..21f28da4db5 100644 --- a/cpp/src/arrow/compute/kernels/sum_internal.h +++ b/cpp/src/arrow/compute/kernels/sum_internal.h @@ -54,6 +54,32 @@ struct FindAccumulatorType> { using Type = DoubleType; }; +#define ACCUMULATOR_TYPE_CASE(ID, TYPE) \ + case Type::ID: \ + return TypeTraits::Type>::type_singleton(); + +static inline std::shared_ptr GetAccumulatorType( + const std::shared_ptr& type) { + switch (type->id()) { + ACCUMULATOR_TYPE_CASE(INT8, Int8Type) + ACCUMULATOR_TYPE_CASE(INT16, Int16Type) + ACCUMULATOR_TYPE_CASE(INT32, Int32Type) + ACCUMULATOR_TYPE_CASE(INT64, Int64Type) + ACCUMULATOR_TYPE_CASE(UINT8, UInt8Type) + ACCUMULATOR_TYPE_CASE(UINT16, UInt16Type) + ACCUMULATOR_TYPE_CASE(UINT32, UInt32Type) + ACCUMULATOR_TYPE_CASE(UINT64, UInt64Type) + ACCUMULATOR_TYPE_CASE(FLOAT, FloatType) + ACCUMULATOR_TYPE_CASE(DOUBLE, DoubleType) + default: + return nullptr; + } + + ARROW_UNREACHABLE; +} + +#undef ACCUMULATOR_TYPE_CASE + template class SumAggregateFunction final : public AggregateFunctionStaticState { using CType = typename TypeTraits::CType; diff --git a/cpp/src/arrow/engine/expression.cc b/cpp/src/arrow/engine/expression.cc index 29dc5732318..121f86d84d3 100644 --- a/cpp/src/arrow/engine/expression.cc +++ b/cpp/src/arrow/engine/expression.cc @@ -16,6 +16,8 @@ // under the License. #include "arrow/engine/expression.h" + +#include "arrow/compute/kernels/sum_internal.h" #include "arrow/scalar.h" #include "arrow/type.h" #include "arrow/util/checked_cast.h" @@ -189,6 +191,8 @@ std::string Expr::kind_name() const { return "field_ref"; case ExprKind::COMPARE_OP: return "compare_op"; + case ExprKind::AGGREGATE_FN_OP: + return "aggregate_fn_op"; case ExprKind::EMPTY_REL: return "empty_rel"; case ExprKind::SCAN_REL: @@ -332,6 +336,41 @@ Result> FieldRefExpr::Make(std::shared_ptr i return std::shared_ptr(new FieldRefExpr(std::move(input), index)); } +// +// CountExpr +// + +CountExpr::CountExpr(std::shared_ptr input) + : UnaryOpMixin(std::move(input)), + AggregateFnExpr(ExprType::Scalar(int64()), AggregateFnKind::COUNT) {} + +Result> CountExpr::Make(std::shared_ptr input) { + ERROR_IF(input == nullptr, "CountExpr's input must be non-null"); + return std::shared_ptr(new CountExpr(std::move(input))); +} + +// +// SumExpr +// + +SumExpr::SumExpr(std::shared_ptr input) + : UnaryOpMixin(std::move(input)), + AggregateFnExpr( + ExprType::Scalar(arrow::compute::GetAccumulatorType(operand()->type().type())), + AggregateFnKind::SUM) {} + +Result> SumExpr::Make(std::shared_ptr input) { + ERROR_IF(input == nullptr, "SumExpr's input must be non-null"); + + auto expr_type = input->type(); + ERROR_IF(!expr_type.HasType(), "SumExpr's input must be a Scalar or an Array"); + + auto type = expr_type.type(); + ERROR_IF(!is_numeric(type->id()), "SumExpr's require an input with numeric type"); + + return std::shared_ptr(new SumExpr(std::move(input))); +} + // // EmptyRelExpr // diff --git a/cpp/src/arrow/engine/expression.h b/cpp/src/arrow/engine/expression.h index 73989932b7f..d88b47dc608 100644 --- a/cpp/src/arrow/engine/expression.h +++ b/cpp/src/arrow/engine/expression.h @@ -332,6 +332,40 @@ class ARROW_EN_EXPORT LessThanEqualExpr : public BaseCompareExpr; }; +/// +/// Aggregate Functions +/// + +/// \brief Aggregate function operators collapse arrays and scalars to scalar. +class ARROW_EN_EXPORT AggregateFnExpr : public Expr { + public: + AggregateFnKind aggregate_kind() const { return aggregate_kind_; } + + protected: + AggregateFnExpr(ExprType type, AggregateFnKind kind) + : Expr(AGGREGATE_FN_OP, std::move(type)), aggregate_kind_(kind) {} + + AggregateFnKind aggregate_kind_; +}; + +/// \brief Count the number of values in the input expression. +class ARROW_EN_EXPORT CountExpr : public UnaryOpMixin, public AggregateFnExpr { + public: + static Result> Make(std::shared_ptr input); + + protected: + explicit CountExpr(std::shared_ptr input); +}; + +/// \brief Sum the input values. +class ARROW_EN_EXPORT SumExpr : public UnaryOpMixin, public AggregateFnExpr { + public: + static Result> Make(std::shared_ptr input); + + protected: + explicit SumExpr(std::shared_ptr input); +}; + /// /// Relational Expressions /// @@ -473,6 +507,18 @@ auto VisitExpr(const Expr& expr, Visitor&& visitor) -> decltype(visitor(expr)) { ARROW_UNREACHABLE; } + case ExprKind::AGGREGATE_FN_OP: { + const auto& agg_expr = static_cast(expr); + switch (agg_expr.aggregate_kind()) { + case AggregateFnKind::COUNT: + return visitor(internal::checked_cast(expr)); + case AggregateFnKind::SUM: + return visitor(internal::checked_cast(expr)); + } + + ARROW_UNREACHABLE; + } + case ExprKind::EMPTY_REL: return visitor(internal::checked_cast(expr)); case ExprKind::SCAN_REL: diff --git a/cpp/src/arrow/engine/expression_test.cc b/cpp/src/arrow/engine/expression_test.cc index 07ea194291e..037327e12f6 100644 --- a/cpp/src/arrow/engine/expression_test.cc +++ b/cpp/src/arrow/engine/expression_test.cc @@ -198,6 +198,63 @@ TYPED_TEST(CompareExprTest, BasicCompareExpr) { EXPECT_THAT(expr, PtrEquals(swapped)); } +TEST_F(ExprTest, CountExpr) { + ASSERT_RAISES(Invalid, CountExpr::Make(nullptr)); + + // Counting scalar is permitted. + ASSERT_OK_AND_ASSIGN(auto i32_lit, ScalarExpr::Make(MakeScalar(42))); + EXPECT_THAT(CountExpr::Make(i32_lit), Ok()); + + // Counting a string scalar is permitted + ASSERT_OK_AND_ASSIGN(auto str_lit, ScalarExpr::Make(MakeScalar("hi"))); + EXPECT_THAT(CountExpr::Make(str_lit), Ok()); + + auto schema = arrow::schema({field("i32", int32()), field("str", utf8())}); + ASSERT_OK_AND_ASSIGN(auto input, EmptyRelExpr::Make(schema)); + + // Counting an int column should be supported. + ASSERT_OK_AND_ASSIGN(auto i32_column, FieldRefExpr::Make(input, 0)); + EXPECT_THAT(CountExpr::Make(i32_column), Ok()); + + // Counting a string column should be supported. + ASSERT_OK_AND_ASSIGN(auto str_column, FieldRefExpr::Make(input, 1)); + EXPECT_THAT(CountExpr::Make(str_column), Ok()); + + // Counting a table should be supported + EXPECT_THAT(CountExpr::Make(input), Ok()); +} + +TEST_F(ExprTest, SumExpr) { + ASSERT_RAISES(Invalid, SumExpr::Make(nullptr)); + + // Summing a scalar is permitted. + ASSERT_OK_AND_ASSIGN(auto i32_lit, ScalarExpr::Make(MakeScalar(42))); + EXPECT_THAT(SumExpr::Make(i32_lit), Ok()); + + // Summing a string is not permitted. + ASSERT_OK_AND_ASSIGN(auto str_lit, ScalarExpr::Make(MakeScalar("hi"))); + ASSERT_RAISES(Invalid, SumExpr::Make(str_lit)); + + auto schema = arrow::schema( + {field("i32", int32()), field("str", utf8()), field("list_i32", list(int32()))}); + ASSERT_OK_AND_ASSIGN(auto input, EmptyRelExpr::Make(schema)); + + // Summing an integer column should be supported. + ASSERT_OK_AND_ASSIGN(auto i32_column, FieldRefExpr::Make(input, 0)); + EXPECT_THAT(SumExpr::Make(i32_column), Ok()); + + // Summing a string column should not be supported. + ASSERT_OK_AND_ASSIGN(auto str_column, FieldRefExpr::Make(input, 1)); + ASSERT_RAISES(Invalid, SumExpr::Make(str_column)); + + // Summing a list column should not be supported (yet). + ASSERT_OK_AND_ASSIGN(auto list_i32_column, FieldRefExpr::Make(input, 2)); + ASSERT_RAISES(Invalid, SumExpr::Make(list_i32_column)); + + // Summing a table should not be supported + ASSERT_RAISES(Invalid, SumExpr::Make(input)); +} + class RelExprTest : public ExprTest { protected: void SetUp() override { diff --git a/cpp/src/arrow/engine/logical_plan.cc b/cpp/src/arrow/engine/logical_plan.cc index aecad0dcfbb..d06269b3302 100644 --- a/cpp/src/arrow/engine/logical_plan.cc +++ b/cpp/src/arrow/engine/logical_plan.cc @@ -83,6 +83,18 @@ ResultExpr LogicalPlanBuilder::Field(const std::shared_ptr& input, return FieldRefExpr::Make(input, field_index); } +// +// Count +// + +ResultExpr LogicalPlanBuilder::Count(const std::shared_ptr& input) { + return CountExpr::Make(input); +} + +ResultExpr LogicalPlanBuilder::Sum(const std::shared_ptr& input) { + return SumExpr::Make(input); +} + // // Relational // diff --git a/cpp/src/arrow/engine/logical_plan.h b/cpp/src/arrow/engine/logical_plan.h index 00016f0e547..8a279548608 100644 --- a/cpp/src/arrow/engine/logical_plan.h +++ b/cpp/src/arrow/engine/logical_plan.h @@ -93,6 +93,17 @@ class ARROW_EN_EXPORT LogicalPlanBuilder { /// @} + /// \defgroup Aggregate function operators + /// @{ + + /// \brief Count the number of elements in the input. + ResultExpr Count(const std::shared_ptr& input); + + /// \brief Sum the elements of the input. + ResultExpr Sum(const std::shared_ptr& input); + + /// @} + /// \defgroup rel-nodes Relational operator nodes in the logical plan /// \brief Filter rows of a relation with the given predicate. @@ -106,14 +117,14 @@ class ARROW_EN_EXPORT LogicalPlanBuilder { /// \brief Project (select) columns by names. /// /// This is a simplified version of Project where columns are selected by - /// names. Duplicate and ordering are preserved. + /// names. Duplicates and ordering are preserved. ResultExpr Project(const std::shared_ptr& input, const std::vector& column_names); /// \brief Project (select) columns by indices. /// /// This is a simplified version of Project where columns are selected by - /// indices. Duplicate and ordering are preserved. + /// indices. Duplicates and ordering are preserved. ResultExpr Project(const std::shared_ptr& input, const std::vector& column_indices); diff --git a/cpp/src/arrow/engine/logical_plan_test.cc b/cpp/src/arrow/engine/logical_plan_test.cc index 10ec42b8a6f..d1e30d1becc 100644 --- a/cpp/src/arrow/engine/logical_plan_test.cc +++ b/cpp/src/arrow/engine/logical_plan_test.cc @@ -60,7 +60,8 @@ class LogicalPlanBuilderTest : public testing::Test { field("bool", boolean()), field("i32", int32()), field("u64", uint64()), - field("f32", uint32()), + field("f32", float32()), + field("utf8", utf8()), }); LogicalPlanBuilderOptions options{}; LogicalPlanBuilder builder{}; @@ -95,6 +96,25 @@ TEST_F(LogicalPlanBuilderTest, BasicScan) { ASSERT_OK(builder.Scan(table_1)); } +TEST_F(LogicalPlanBuilderTest, Count) { + EXPECT_OK_AND_ASSIGN(auto table, scan_expr()); + EXPECT_OK_AND_ASSIGN(auto field, field_expr("i32", table)); + + EXPECT_OK_AND_ASSIGN(auto f_count, builder.Count(field)); + EXPECT_OK_AND_ASSIGN(auto t_count, builder.Count(table)); +} + +TEST_F(LogicalPlanBuilderTest, Sum) { + EXPECT_OK_AND_ASSIGN(auto table, scan_expr()); + + EXPECT_OK_AND_ASSIGN(auto i32_field, field_expr("i32", table)); + EXPECT_OK_AND_ASSIGN(auto f_count, builder.Sum(i32_field)); + + EXPECT_OK_AND_ASSIGN(auto str_field, field_expr("utf8", table)); + ASSERT_RAISES(Invalid, builder.Sum(str_field)); + ASSERT_RAISES(Invalid, builder.Sum(table)); +} + TEST_F(LogicalPlanBuilderTest, Filter) { EXPECT_OK_AND_ASSIGN(auto table, scan_expr()); diff --git a/cpp/src/arrow/engine/type_fwd.h b/cpp/src/arrow/engine/type_fwd.h index 6ca1b47bd1a..ebb35cbb975 100644 --- a/cpp/src/arrow/engine/type_fwd.h +++ b/cpp/src/arrow/engine/type_fwd.h @@ -35,6 +35,8 @@ enum ExprKind : uint8_t { // Comparison operators, COMPARE_OP, + // Aggregate function operators, + AGGREGATE_FN_OP, /// Empty relation with a known schema. EMPTY_REL, @@ -68,6 +70,18 @@ class GreaterThanEqualExpr; class LessThanExpr; class LessThanEqualExpr; +/// Tag identifier for aggregate function operators +enum AggregateFnKind : uint8_t { + // Count the number of elements in the input array + COUNT, + // Sum the elements of the input array. + SUM, +}; + +class AggregateFnExpr; +class CountExpr; +class SumExpr; + class RelExpr; class EmptyRelExpr; diff --git a/cpp/src/arrow/engine/type_traits.h b/cpp/src/arrow/engine/type_traits.h index 3303273547e..949b9508b26 100644 --- a/cpp/src/arrow/engine/type_traits.h +++ b/cpp/src/arrow/engine/type_traits.h @@ -76,6 +76,12 @@ struct expr_traits { static constexpr auto compare_kind_id = CompareKind::LESS_THAN_EQUAL; }; +template <> +struct expr_traits { + static constexpr auto kind_id = ExprKind::AGGREGATE_FN_OP; + static constexpr auto aggregate_kind_id = AggregateFnKind::COUNT; +}; + template <> struct expr_traits { static constexpr auto kind_id = ExprKind::EMPTY_REL; @@ -102,6 +108,12 @@ using is_compare_expr = std::is_base_of; template using enable_if_compare_expr = enable_if_t::value, Ret>; +template +using is_aggregate_fn_expr = std::is_base_of; + +template +using enable_if_aggregate_fn_expr = enable_if_t::value, Ret>; + template using is_relational_expr = std::is_base_of; diff --git a/cpp/src/arrow/testing/gmock.h b/cpp/src/arrow/testing/gmock.h index 5b2f2f19502..4b801679ec1 100644 --- a/cpp/src/arrow/testing/gmock.h +++ b/cpp/src/arrow/testing/gmock.h @@ -26,6 +26,7 @@ using testing::HasSubstr; MATCHER_P(Equals, other, "") { return arg.Equals(other); } MATCHER_P(PtrEquals, other, "") { return arg->Equals(*other); } +MATCHER(Ok, "") { return arg.ok(); } MATCHER_P(OkAndEq, other, "") { return arg.ok() && arg.ValueOrDie() == other; } } // namespace arrow diff --git a/cpp/src/arrow/type_traits.h b/cpp/src/arrow/type_traits.h index 021c9985c1e..06746fdf4cf 100644 --- a/cpp/src/arrow/type_traits.h +++ b/cpp/src/arrow/type_traits.h @@ -684,6 +684,10 @@ static inline bool is_floating(Type::type type_id) { return false; } +static inline bool is_numeric(Type::type type_id) { + return is_integer(type_id) || is_floating(type_id); +} + static inline bool is_primitive(Type::type type_id) { switch (type_id) { case Type::NA: From 24675c6b97a0a5776aebd3b784dd9ed8c43e5dcc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Saint-Jacques?= Date: Mon, 16 Mar 2020 14:00:56 -0400 Subject: [PATCH 19/21] Add Compare expressions to LogicalPlanBuilder --- cpp/src/arrow/engine/logical_plan.cc | 55 ++++++++++++++++++++++- cpp/src/arrow/engine/logical_plan.h | 22 ++++++--- cpp/src/arrow/engine/logical_plan_test.cc | 14 +++++- 3 files changed, 83 insertions(+), 8 deletions(-) diff --git a/cpp/src/arrow/engine/logical_plan.cc b/cpp/src/arrow/engine/logical_plan.cc index d06269b3302..bf3e745fd50 100644 --- a/cpp/src/arrow/engine/logical_plan.cc +++ b/cpp/src/arrow/engine/logical_plan.cc @@ -83,6 +83,57 @@ ResultExpr LogicalPlanBuilder::Field(const std::shared_ptr& input, return FieldRefExpr::Make(input, field_index); } +ResultExpr LogicalPlanBuilder::Compare(CompareKind compare_kind, + const std::shared_ptr& lhs, + const std::shared_ptr& rhs) { + switch (compare_kind) { + case (CompareKind::EQUAL): + return EqualExpr::Make(lhs, rhs); + case (CompareKind::NOT_EQUAL): + return NotEqualExpr::Make(lhs, rhs); + case (CompareKind::GREATER_THAN): + return GreaterThanExpr::Make(lhs, rhs); + case (CompareKind::GREATER_THAN_EQUAL): + return GreaterThanEqualExpr::Make(lhs, rhs); + case (CompareKind::LESS_THAN): + return LessThanExpr::Make(lhs, rhs); + case (CompareKind::LESS_THAN_EQUAL): + return LessThanEqualExpr::Make(lhs, rhs); + } + + ARROW_UNREACHABLE; +} + +ResultExpr LogicalPlanBuilder::Equal(const std::shared_ptr& lhs, + const std::shared_ptr& rhs) { + return Compare(CompareKind::EQUAL, lhs, rhs); +} + +ResultExpr LogicalPlanBuilder::NotEqual(const std::shared_ptr& lhs, + const std::shared_ptr& rhs) { + return Compare(CompareKind::NOT_EQUAL, lhs, rhs); +} + +ResultExpr LogicalPlanBuilder::GreaterThan(const std::shared_ptr& lhs, + const std::shared_ptr& rhs) { + return Compare(CompareKind::GREATER_THAN, lhs, rhs); +} + +ResultExpr LogicalPlanBuilder::GreaterThanEqual(const std::shared_ptr& lhs, + const std::shared_ptr& rhs) { + return Compare(CompareKind::GREATER_THAN_EQUAL, lhs, rhs); +} + +ResultExpr LogicalPlanBuilder::LessThan(const std::shared_ptr& lhs, + const std::shared_ptr& rhs) { + return Compare(CompareKind::LESS_THAN, lhs, rhs); +} + +ResultExpr LogicalPlanBuilder::LessThanEqual(const std::shared_ptr& lhs, + const std::shared_ptr& rhs) { + return Compare(CompareKind::LESS_THAN_EQUAL, lhs, rhs); +} + // // Count // @@ -102,12 +153,12 @@ ResultExpr LogicalPlanBuilder::Sum(const std::shared_ptr& input) { ResultExpr LogicalPlanBuilder::Scan(const std::string& table_name) { ERROR_IF(catalog_ == nullptr, "Cannot scan from an empty catalog"); ARROW_ASSIGN_OR_RAISE(auto table, catalog_->Get(table_name)); - return ScanRelExpr::Make(std::move(table)); + return ScanRelExpr::Make(table); } ResultExpr LogicalPlanBuilder::Filter(const std::shared_ptr& input, const std::shared_ptr& predicate) { - return FilterRelExpr::Make(std::move(input), std::move(predicate)); + return FilterRelExpr::Make(input, predicate); } ResultExpr LogicalPlanBuilder::Project( diff --git a/cpp/src/arrow/engine/logical_plan.h b/cpp/src/arrow/engine/logical_plan.h index 8a279548608..f5ef46809f5 100644 --- a/cpp/src/arrow/engine/logical_plan.h +++ b/cpp/src/arrow/engine/logical_plan.h @@ -78,18 +78,30 @@ class ARROW_EN_EXPORT LogicalPlanBuilder { /// \defgroup comparator-nodes Comparison operators /// @{ - /* - TODO(fsaintjacques): This. + /// \brief Compare inputs. + ResultExpr Compare(CompareKind compare_kind, const std::shared_ptr& lhs, + const std::shared_ptr& rhs); + + /// \brief Compare if inputs are equal. ResultExpr Equal(const std::shared_ptr& lhs, const std::shared_ptr& rhs); + + /// \brief Compare if inputs are not equal. ResultExpr NotEqual(const std::shared_ptr& lhs, const std::shared_ptr& rhs); + + /// \brief Compare if lhs is greater than rhs. ResultExpr GreaterThan(const std::shared_ptr& lhs, const std::shared_ptr& rhs); - ResultExpr GreaterEqualThan(const std::shared_ptr& lhs, + + /// \brief Compare if lhs is greater than equal rhs. + ResultExpr GreaterThanEqual(const std::shared_ptr& lhs, const std::shared_ptr& rhs); + + /// \brief Compare if lhs is less than rhs. ResultExpr LessThan(const std::shared_ptr& lhs, const std::shared_ptr& rhs); - ResultExpr LessEqualThan(const std::shared_ptr& lhs, + + /// \brief Compare if lhs is less than equal rhs. + ResultExpr LessThanEqual(const std::shared_ptr& lhs, const std::shared_ptr& rhs); - */ /// @} diff --git a/cpp/src/arrow/engine/logical_plan_test.cc b/cpp/src/arrow/engine/logical_plan_test.cc index d1e30d1becc..05daade44df 100644 --- a/cpp/src/arrow/engine/logical_plan_test.cc +++ b/cpp/src/arrow/engine/logical_plan_test.cc @@ -90,12 +90,24 @@ TEST_F(LogicalPlanBuilderTest, FieldReferences) { } TEST_F(LogicalPlanBuilderTest, BasicScan) { - LogicalPlanBuilder builder{options}; ASSERT_RAISES(KeyError, builder.Scan("")); ASSERT_RAISES(KeyError, builder.Scan("not_found")); ASSERT_OK(builder.Scan(table_1)); } +TEST_F(LogicalPlanBuilderTest, Comparisons) { + EXPECT_OK_AND_ASSIGN(auto table, scan_expr()); + EXPECT_OK_AND_ASSIGN(auto field, field_expr("i32", table)); + EXPECT_OK_AND_ASSIGN(auto scalar, scalar_expr()); + + EXPECT_OK_AND_ASSIGN(auto eq, builder.Equal(field, scalar)); + EXPECT_OK_AND_ASSIGN(auto ne, builder.NotEqual(field, scalar)); + EXPECT_OK_AND_ASSIGN(auto gt, builder.GreaterThan(field, scalar)); + EXPECT_OK_AND_ASSIGN(auto ge, builder.GreaterThanEqual(field, scalar)); + EXPECT_OK_AND_ASSIGN(auto lt, builder.LessThan(field, scalar)); + EXPECT_OK_AND_ASSIGN(auto le, builder.LessThanEqual(field, scalar)); +} + TEST_F(LogicalPlanBuilderTest, Count) { EXPECT_OK_AND_ASSIGN(auto table, scan_expr()); EXPECT_OK_AND_ASSIGN(auto field, field_expr("i32", table)); From ef10404a020a68134c0012c8390a648a33498bd3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Saint-Jacques?= Date: Mon, 16 Mar 2020 14:03:53 -0400 Subject: [PATCH 20/21] Add missing forward --- cpp/src/arrow/engine/type_fwd.h | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cpp/src/arrow/engine/type_fwd.h b/cpp/src/arrow/engine/type_fwd.h index ebb35cbb975..bf9892f8c13 100644 --- a/cpp/src/arrow/engine/type_fwd.h +++ b/cpp/src/arrow/engine/type_fwd.h @@ -91,5 +91,8 @@ class FilterRelExpr; class Catalog; +class LogicalPlan; +class LogicalPlanBuilder; + } // namespace engine } // namespace arrow From cf58f34b9b648bf6d4c83f80ba0001997ecea6d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Saint-Jacques?= Date: Mon, 16 Mar 2020 17:01:40 -0400 Subject: [PATCH 21/21] Add IsA --- cpp/src/arrow/engine/expression.h | 39 +++++++++++++++++++++++ cpp/src/arrow/engine/logical_plan_test.cc | 37 +++++++++++++++++++-- cpp/src/arrow/engine/type_fwd.h | 11 ++++--- cpp/src/arrow/engine/type_traits.h | 17 ++++++++++ 4 files changed, 96 insertions(+), 8 deletions(-) diff --git a/cpp/src/arrow/engine/expression.h b/cpp/src/arrow/engine/expression.h index d88b47dc608..a007c70d98f 100644 --- a/cpp/src/arrow/engine/expression.h +++ b/cpp/src/arrow/engine/expression.h @@ -532,5 +532,44 @@ auto VisitExpr(const Expr& expr, Visitor&& visitor) -> decltype(visitor(expr)) { ARROW_UNREACHABLE; } +/// +/// RTTI utilities +/// + +/// \defgroup isa-expr Family of functions to introspect if an expression of a +/// given expression class. +/// @{ + +template +enable_if_simple_expr IsA(const Expr& expr) { + return expr.kind() == expr_traits::kind_id; +} + +template +enable_if_compare_expr IsA(const Expr& expr) { + if (expr.kind() != ExprKind::COMPARE_OP) { + return false; + } + const auto& cmp = internal::checked_cast(expr); + return cmp.compare_kind() == expr_traits::compare_kind_id; +} + +template +enable_if_aggregate_fn_expr IsA(const Expr& expr) { + if (expr.kind() != ExprKind::AGGREGATE_FN_OP) { + return false; + } + const auto& agg = internal::checked_cast(expr); + return agg.aggregate_kind() == expr_traits::aggregate_kind_id; +} + +template +bool IsA(const std::shared_ptr& expr) { + if (!expr) return false; + return IsA(*expr); +} + +/// @} + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/logical_plan_test.cc b/cpp/src/arrow/engine/logical_plan_test.cc index 05daade44df..28d4607ed41 100644 --- a/cpp/src/arrow/engine/logical_plan_test.cc +++ b/cpp/src/arrow/engine/logical_plan_test.cc @@ -70,6 +70,7 @@ class LogicalPlanBuilderTest : public testing::Test { TEST_F(LogicalPlanBuilderTest, Scalar) { auto forthy_two = MakeScalar(42); EXPECT_OK_AND_ASSIGN(auto scalar, builder.Scalar(forthy_two)); + ASSERT_TRUE(IsA(scalar)); } TEST_F(LogicalPlanBuilderTest, FieldReferences) { @@ -86,13 +87,18 @@ TEST_F(LogicalPlanBuilderTest, FieldReferences) { ASSERT_RAISES(KeyError, builder.Field(table, 9000)); EXPECT_OK_AND_ASSIGN(auto field_name_ref, builder.Field(table, "i32")); + ASSERT_TRUE(IsA(field_name_ref)); + EXPECT_OK_AND_ASSIGN(auto field_idx_ref, builder.Field(table, 0)); + ASSERT_TRUE(IsA(field_idx_ref)); } TEST_F(LogicalPlanBuilderTest, BasicScan) { ASSERT_RAISES(KeyError, builder.Scan("")); ASSERT_RAISES(KeyError, builder.Scan("not_found")); - ASSERT_OK(builder.Scan(table_1)); + + EXPECT_OK_AND_ASSIGN(auto scan, builder.Scan(table_1)); + ASSERT_TRUE(IsA(scan)); } TEST_F(LogicalPlanBuilderTest, Comparisons) { @@ -101,26 +107,49 @@ TEST_F(LogicalPlanBuilderTest, Comparisons) { EXPECT_OK_AND_ASSIGN(auto scalar, scalar_expr()); EXPECT_OK_AND_ASSIGN(auto eq, builder.Equal(field, scalar)); + ASSERT_TRUE(IsA(eq)); + EXPECT_OK_AND_ASSIGN(auto ne, builder.NotEqual(field, scalar)); + ASSERT_TRUE(IsA(ne)); + EXPECT_OK_AND_ASSIGN(auto gt, builder.GreaterThan(field, scalar)); + ASSERT_TRUE(IsA(gt)); + EXPECT_OK_AND_ASSIGN(auto ge, builder.GreaterThanEqual(field, scalar)); + ASSERT_TRUE(IsA(ge)); + EXPECT_OK_AND_ASSIGN(auto lt, builder.LessThan(field, scalar)); + ASSERT_TRUE(IsA(lt)); + EXPECT_OK_AND_ASSIGN(auto le, builder.LessThanEqual(field, scalar)); + ASSERT_TRUE(IsA(le)); } TEST_F(LogicalPlanBuilderTest, Count) { EXPECT_OK_AND_ASSIGN(auto table, scan_expr()); EXPECT_OK_AND_ASSIGN(auto field, field_expr("i32", table)); + EXPECT_OK_AND_ASSIGN(auto scalar, scalar_expr()); + EXPECT_OK_AND_ASSIGN(auto s_count, builder.Count(scalar)); + ASSERT_TRUE(IsA(s_count)); + EXPECT_OK_AND_ASSIGN(auto f_count, builder.Count(field)); + ASSERT_TRUE(IsA(f_count)); + EXPECT_OK_AND_ASSIGN(auto t_count, builder.Count(table)); + ASSERT_TRUE(IsA(t_count)); } TEST_F(LogicalPlanBuilderTest, Sum) { EXPECT_OK_AND_ASSIGN(auto table, scan_expr()); + EXPECT_OK_AND_ASSIGN(auto scalar, scalar_expr()); + EXPECT_OK_AND_ASSIGN(auto s_sum, builder.Sum(scalar)); + ASSERT_TRUE(IsA(s_sum)); + EXPECT_OK_AND_ASSIGN(auto i32_field, field_expr("i32", table)); - EXPECT_OK_AND_ASSIGN(auto f_count, builder.Sum(i32_field)); + EXPECT_OK_AND_ASSIGN(auto f_sum, builder.Sum(i32_field)); + ASSERT_TRUE(IsA(s_sum)); EXPECT_OK_AND_ASSIGN(auto str_field, field_expr("utf8", table)); ASSERT_RAISES(Invalid, builder.Sum(str_field)); @@ -135,6 +164,7 @@ TEST_F(LogicalPlanBuilderTest, Filter) { EXPECT_OK_AND_ASSIGN(auto predicate, EqualExpr::Make(field, scalar)); EXPECT_OK_AND_ASSIGN(auto filter, builder.Filter(table, predicate)); + ASSERT_TRUE(IsA(filter)); } TEST_F(LogicalPlanBuilderTest, ProjectionByNamesAndIndices) { @@ -150,7 +180,8 @@ TEST_F(LogicalPlanBuilderTest, ProjectionByNamesAndIndices) { std::vector valid_names{"u64", "f32"}; ASSERT_OK(builder.Project(table, valid_names)); std::vector valid_idx{3, 1, 1}; - ASSERT_OK(builder.Project(table, valid_idx)); + EXPECT_OK_AND_ASSIGN(auto project, builder.Project(table, valid_idx)); + ASSERT_TRUE(IsA(project)); } } // namespace engine diff --git a/cpp/src/arrow/engine/type_fwd.h b/cpp/src/arrow/engine/type_fwd.h index bf9892f8c13..55d10720e68 100644 --- a/cpp/src/arrow/engine/type_fwd.h +++ b/cpp/src/arrow/engine/type_fwd.h @@ -33,18 +33,19 @@ enum ExprKind : uint8_t { /// A Field reference in a schema. FIELD_REFERENCE, - // Comparison operators, + // Comparison operators, see CompareKind. COMPARE_OP, - // Aggregate function operators, + + // Aggregate function operators, see AggregateFnKind. AGGREGATE_FN_OP, /// Empty relation with a known schema. EMPTY_REL, - /// Scan relational operator + /// Scan relational operator. SCAN_REL, - /// Projection relational operator + /// Projection relational operator. PROJECTION_REL, - /// Filter relational operator + /// Filter relational operator. FILTER_REL, }; diff --git a/cpp/src/arrow/engine/type_traits.h b/cpp/src/arrow/engine/type_traits.h index 949b9508b26..6e44cb6d826 100644 --- a/cpp/src/arrow/engine/type_traits.h +++ b/cpp/src/arrow/engine/type_traits.h @@ -82,6 +82,12 @@ struct expr_traits { static constexpr auto aggregate_kind_id = AggregateFnKind::COUNT; }; +template <> +struct expr_traits { + static constexpr auto kind_id = ExprKind::AGGREGATE_FN_OP; + static constexpr auto aggregate_kind_id = AggregateFnKind::SUM; +}; + template <> struct expr_traits { static constexpr auto kind_id = ExprKind::EMPTY_REL; @@ -102,6 +108,12 @@ struct expr_traits { static constexpr auto kind_id = ExprKind::FILTER_REL; }; +template +using is_expr = std::is_base_of; + +template +using enable_if_expr = enable_if_t::value, Ret>; + template using is_compare_expr = std::is_base_of; @@ -120,5 +132,10 @@ using is_relational_expr = std::is_base_of; template using enable_if_relational_expr = enable_if_t::value, Ret>; +// Catch-all used by `IsA` pattern matcher. +template +using enable_if_simple_expr = + enable_if_t::value && !is_aggregate_fn_expr::value, Ret>; + } // namespace engine } // namespace arrow