From 911fd82588d5146bff13946b1d48cb55e19e9744 Mon Sep 17 00:00:00 2001 From: Anton Alkin Date: Mon, 19 Jul 2021 15:53:21 +0200 Subject: [PATCH 1/4] DPL Analysis: Introducing conditional expressions * `ifnode(condition, then, else)` operation is added to expressions * these can be nested * all three arguments can be arbitrary valid expressions * `condition` needs to have boolean result, `then` and `else` should return similar types (ideally the same - both floats, or both boolean, etc.) * Added a test * Added `conditionalExpressions.cxx` tutorial example (note that it uses bitwise operations in filter expression and thus will only work as is with arrow > 3) --- Analysis/Tutorials/CMakeLists.txt | 6 + .../Tutorials/src/conditionalExpressions.cxx | 56 +++++ Framework/Core/include/Framework/BasicOps.h | 3 +- .../include/Framework/ExpressionHelpers.h | 23 +- .../Core/include/Framework/Expressions.h | 208 +++++++++++++----- Framework/Core/src/Expressions.cxx | 109 +++++++-- Framework/Core/test/test_Expressions.cxx | 69 +++++- 7 files changed, 382 insertions(+), 92 deletions(-) create mode 100644 Analysis/Tutorials/src/conditionalExpressions.cxx diff --git a/Analysis/Tutorials/CMakeLists.txt b/Analysis/Tutorials/CMakeLists.txt index 3bc60f6219bdc..52bb47e3f52a1 100644 --- a/Analysis/Tutorials/CMakeLists.txt +++ b/Analysis/Tutorials/CMakeLists.txt @@ -228,3 +228,9 @@ o2_add_dpl_workflow(multiprocess-example JOB_POOL analysis PUBLIC_LINK_LIBRARIES O2::Framework O2::AnalysisCore O2::AnalysisDataModel COMPONENT_NAME AnalysisTutorial) + +o2_add_dpl_workflow(conditional-expressions + SOURCES src/conditionalExpressions.cxx + JOB_POOL analysis + PUBLIC_LINK_LIBRARIES O2::Framework O2::AnalysisCore O2::AnalysisDataModel + COMPONENT_NAME AnalysisTutorial) diff --git a/Analysis/Tutorials/src/conditionalExpressions.cxx b/Analysis/Tutorials/src/conditionalExpressions.cxx new file mode 100644 index 0000000000000..7582bcc4f28f2 --- /dev/null +++ b/Analysis/Tutorials/src/conditionalExpressions.cxx @@ -0,0 +1,56 @@ +// Copyright 2019-2020 CERN and copyright holders of ALICE O2. +// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders. +// All rights not expressly granted are reserved. +// +// This software is distributed under the terms of the GNU General Public +// License v3 (GPL Version 3), copied verbatim in the file "COPYING". +// +// In applying this license CERN does not waive the privileges and immunities +// granted to it by virtue of its status as an Intergovernmental Organization +// or submit itself to any jurisdiction. +/// +/// \brief Demonstration of conditions in filter expressions + +#include "Framework/runDataProcessing.h" +#include "Framework/AnalysisTask.h" + +using namespace o2; +using namespace o2::framework; +using namespace o2::framework::expressions; + +struct ConditionalExpressions { + Configurable useFlags{"useFlags", false, "Switch to enable using track flags for selection"}; + Filter trackFilter = nabs(aod::track::eta) < 0.9f && aod::track::pt > 0.5f && ifnode(useFlags == true, (aod::track::flags & static_cast(o2::aod::track::ITSrefit)) != 0u, true); + OutputObj etapt{TH2F("etapt", ";#eta;#p_{T}", 201, -2.1, 2.1, 601, 0, 60.1)}; + void process(aod::Collision const&, soa::Filtered> const& tracks) + { + for (auto& track : tracks) { + etapt->Fill(track.eta(), track.pt()); + } + } +}; + +struct BasicOperations { + Configurable useFlags{"useFlags", false, "Switch to enable using track flags for selection"}; + Filter trackFilter = nabs(aod::track::eta) < 0.9f && aod::track::pt > 0.5f; + OutputObj etapt{TH2F("etapt", ";#eta;#p_{T}", 201, -2.1, 2.1, 601, 0, 60.1)}; + void process(aod::Collision const&, soa::Filtered> const& tracks) + { + for (auto& track : tracks) { + if (useFlags) { + if ((track.flags() & o2::aod::track::ITSrefit) != 0u) { + etapt->Fill(track.eta(), track.pt()); + } + } else { + etapt->Fill(track.eta(), track.pt()); + } + } + } +}; + +WorkflowSpec defineDataProcessing(ConfigContext const& cfgc) +{ + return WorkflowSpec{ + adaptAnalysisTask(cfgc), + adaptAnalysisTask(cfgc)}; +} diff --git a/Framework/Core/include/Framework/BasicOps.h b/Framework/Core/include/Framework/BasicOps.h index d2a0aacb43c77..230936832f85b 100644 --- a/Framework/Core/include/Framework/BasicOps.h +++ b/Framework/Core/include/Framework/BasicOps.h @@ -41,7 +41,8 @@ enum BasicOp : unsigned int { Acos, Atan, Abs, - BitwiseNot + BitwiseNot, + Conditional }; } // namespace o2::framework diff --git a/Framework/Core/include/Framework/ExpressionHelpers.h b/Framework/Core/include/Framework/ExpressionHelpers.h index d16338dcd64b2..7630e2f799bce 100644 --- a/Framework/Core/include/Framework/ExpressionHelpers.h +++ b/Framework/Core/include/Framework/ExpressionHelpers.h @@ -20,7 +20,7 @@ namespace o2::framework::expressions { /// a map between BasicOp and gandiva node definitions /// note that logical 'and' and 'or' are created separately -static std::array basicOperationsMap = { +static std::array basicOperationsMap = { "and", "or", "add", @@ -48,7 +48,8 @@ static std::array basicOperationsMap = { "acosf", "atanf", "absf", - "bitwise_not"}; + "bitwise_not", + "if"}; struct DatumSpec { /// datum spec either contains an index, a value of a literal or a binding label @@ -72,17 +73,21 @@ bool operator==(DatumSpec const& lhs, DatumSpec const& rhs); std::ostream& operator<<(std::ostream& os, DatumSpec const& spec); struct ColumnOperationSpec { + size_t index = 0; BasicOp op; DatumSpec left; DatumSpec right; + DatumSpec condition; DatumSpec result; atype::type type = atype::NA; ColumnOperationSpec() = default; - // TODO: extend this to support unary ops seamlessly - explicit ColumnOperationSpec(BasicOp op_) : op{op_}, - left{}, - right{}, - result{} + explicit ColumnOperationSpec(BasicOp op_, size_t index_ = 0) + : index{index_}, + op{op_}, + left{}, + right{}, + condition{}, + result{} { switch (op) { case BasicOp::LogicalOr: @@ -110,6 +115,10 @@ struct NodeRecord { Node* node_ptr = nullptr; size_t index = 0; explicit NodeRecord(Node* node_, size_t index_) : node_ptr(node_), index{index_} {} + bool operator!=(NodeRecord const& rhs) + { + return this->node_ptr != rhs.node_ptr; + } }; } // namespace o2::framework::expressions diff --git a/Framework/Core/include/Framework/Expressions.h b/Framework/Core/include/Framework/Expressions.h index a05bcf110990a..b739b913d52d6 100644 --- a/Framework/Core/include/Framework/Expressions.h +++ b/Framework/Core/include/Framework/Expressions.h @@ -146,93 +146,117 @@ struct PlaceholderNode : LiteralNode { LiteralNode::var_t (*retrieve)(InitContext&, std::string const& name); }; +/// A conditional node +struct ConditionalNode { +}; + /// A generic tree node struct Node { - Node(LiteralNode v) : self{v}, left{nullptr}, right{nullptr} + Node(LiteralNode v) : self{v}, left{nullptr}, right{nullptr}, condition{nullptr} { } - Node(PlaceholderNode v) : self{v}, left{nullptr}, right{nullptr} + Node(PlaceholderNode v) : self{v}, left{nullptr}, right{nullptr}, condition{nullptr} { } - Node(Node&& n) : self{n.self}, left{std::move(n.left)}, right{std::move(n.right)} + Node(Node&& n) : self{n.self}, left{std::move(n.left)}, right{std::move(n.right)}, condition{std::move(n.condition)} { } - Node(BindingNode n) : self{n}, left{nullptr}, right{nullptr} + Node(BindingNode n) : self{n}, left{nullptr}, right{nullptr}, condition{nullptr} { } + Node(ConditionalNode op, Node&& then_, Node&& else_, Node&& condition_) + : self{op}, + left{std::make_unique(std::move(then_))}, + right{std::make_unique(std::move(else_))}, + condition{std::make_unique(std::move(condition_))} {} + Node(OpNode op, Node&& l, Node&& r) : self{op}, left{std::make_unique(std::move(l))}, - right{std::make_unique(std::move(r))} {} + right{std::make_unique(std::move(r))}, + condition{nullptr} {} Node(OpNode op, Node&& l) : self{op}, left{std::make_unique(std::move(l))}, - right{nullptr} {} + right{nullptr}, + condition{nullptr} {} /// variant with possible nodes - using self_t = std::variant; + using self_t = std::variant; self_t self; + size_t index = 0; /// pointers to children std::unique_ptr left; std::unique_ptr right; + std::unique_ptr condition; }; /// overloaded operators to build the tree from an expression -#define BINARY_OP_NODES(_operator_, _operation_) \ - template \ - inline Node operator _operator_(Node left, T right) \ - { \ - return Node{OpNode{BasicOp::_operation_}, std::move(left), LiteralNode{right}}; \ - } \ - template \ - inline Node operator _operator_(T left, Node right) \ - { \ - return Node{OpNode{BasicOp::_operation_}, LiteralNode{left}, std::move(right)}; \ - } \ - template \ - inline Node operator _operator_(Node left, Configurable right) \ - { \ - return Node{OpNode{BasicOp::_operation_}, std::move(left), PlaceholderNode{right}}; \ - } \ - template \ - inline Node operator _operator_(Configurable left, Node right) \ - { \ - return Node{OpNode{BasicOp::_operation_}, PlaceholderNode{left}, std::move(right)}; \ - } \ - inline Node operator _operator_(Node left, Node right) \ - { \ - return Node{OpNode{BasicOp::_operation_}, std::move(left), std::move(right)}; \ - } \ - inline Node operator _operator_(BindingNode left, BindingNode right) \ - { \ - return Node{OpNode{BasicOp::_operation_}, left, right}; \ - } \ - template <> \ - inline Node operator _operator_(BindingNode left, Node right) \ - { \ - return Node{OpNode{BasicOp::_operation_}, left, std::move(right)}; \ - } \ - template <> \ - inline Node operator _operator_(Node left, BindingNode right) \ - { \ - return Node{OpNode{BasicOp::_operation_}, std::move(left), right}; \ - } \ - \ - template \ - inline Node operator _operator_(Configurable left, BindingNode right) \ - { \ - return Node{OpNode{BasicOp::_operation_}, PlaceholderNode{left}, right}; \ - } \ - template \ - inline Node operator _operator_(BindingNode left, Configurable right) \ - { \ - return Node{OpNode{BasicOp::_operation_}, left, PlaceholderNode{right}}; \ +#define BINARY_OP_NODES(_operator_, _operation_) \ + template \ + inline Node operator _operator_(Node left, T right) \ + { \ + return Node{OpNode{BasicOp::_operation_}, std::move(left), LiteralNode{right}}; \ + } \ + template \ + inline Node operator _operator_(T left, Node right) \ + { \ + return Node{OpNode{BasicOp::_operation_}, LiteralNode{left}, std::move(right)}; \ + } \ + template \ + inline Node operator _operator_(Node left, Configurable right) \ + { \ + return Node{OpNode{BasicOp::_operation_}, std::move(left), PlaceholderNode{right}}; \ + } \ + template \ + inline Node operator _operator_(Configurable left, Node right) \ + { \ + return Node{OpNode{BasicOp::_operation_}, PlaceholderNode{left}, std::move(right)}; \ + } \ + inline Node operator _operator_(Node left, Node right) \ + { \ + return Node{OpNode{BasicOp::_operation_}, std::move(left), std::move(right)}; \ + } \ + inline Node operator _operator_(BindingNode left, BindingNode right) \ + { \ + return Node{OpNode{BasicOp::_operation_}, left, right}; \ + } \ + template <> \ + inline Node operator _operator_(BindingNode left, Node right) \ + { \ + return Node{OpNode{BasicOp::_operation_}, left, std::move(right)}; \ + } \ + template <> \ + inline Node operator _operator_(Node left, BindingNode right) \ + { \ + return Node{OpNode{BasicOp::_operation_}, std::move(left), right}; \ + } \ + \ + template \ + inline Node operator _operator_(Configurable left, BindingNode right) \ + { \ + return Node{OpNode{BasicOp::_operation_}, PlaceholderNode{left}, right}; \ + } \ + template \ + inline Node operator _operator_(BindingNode left, Configurable right) \ + { \ + return Node{OpNode{BasicOp::_operation_}, left, PlaceholderNode{right}}; \ + } \ + template \ + inline Node operator _operator_(Configurable left, L right) \ + { \ + return Node{OpNode{BasicOp::_operation_}, PlaceholderNode{left}, LiteralNode{right}}; \ + } \ + template \ + inline Node operator _operator_(L left, Configurable right) \ + { \ + return Node{OpNode{BasicOp::_operation_}, left, PlaceholderNode{right}}; \ } BINARY_OP_NODES(&, BitwiseAnd); @@ -319,20 +343,84 @@ inline Node nbitwise_not(Node left) return Node{OpNode{BasicOp::BitwiseNot}, std::move(left)}; } +/// conditionals +template +inline Node ifnode(C condition_, T then_, E else_) +{ + return Node{ConditionalNode{}, std::move(then_), std::move(else_), std::move(condition_)}; +} + +template <> +inline Node ifnode(Node condition_, Node then_, Node else_) +{ + return Node{ConditionalNode{}, std::move(then_), std::move(else_), std::move(condition_)}; +} + +template +inline Node ifnode(Node condition_, Node then_, L else_) +{ + return Node{ConditionalNode{}, std::move(then_), LiteralNode{else_}, std::move(condition_)}; +} + +template +inline Node ifnode(Node condition_, L then_, Node else_) +{ + return Node{ConditionalNode{}, LiteralNode{then_}, std::move(else_), std::move(condition_)}; +} + +template +inline Node ifnode(Node condition_, L1 then_, L2 else_) +{ + return Node{ConditionalNode{}, LiteralNode{then_}, LiteralNode{else_}, std::move(condition_)}; +} + +template +inline Node ifnode(Configurable condition_, Node then_, Node else_) +{ + return Node{ConditionalNode{}, std::move(then_), std::move(else_), PlaceholderNode{condition_}}; +} + +template +inline Node ifnode(Node condition_, Node then_, Configurable else_) +{ + return Node{ConditionalNode{}, std::move(then_), PlaceholderNode{else_}, std::move(condition_)}; +} + +template +inline Node ifnode(Node condition_, Configurable then_, Node else_) +{ + return Node{ConditionalNode{}, PlaceholderNode{then_}, std::move(else_), std::move(condition_)}; +} + +template +inline Node ifnode(Node condition_, Configurable then_, Configurable else_) +{ + return Node{ConditionalNode{}, PlaceholderNode{then_}, PlaceholderNode{else_}, std::move(condition_)}; +} + /// A struct, containing the root of the expression tree struct Filter { - Filter(Node&& node_) : node{std::make_unique(std::move(node_))} {} - Filter(Filter&& other) : node{std::move(other.node)} {} + Filter(Node&& node_) : node{std::make_unique(std::move(node_))} + { + (void)designateSubtrees(node.get()); + } + + Filter(Filter&& other) : node{std::move(other.node)} + { + (void)designateSubtrees(node.get()); + } std::unique_ptr node; + + size_t designateSubtrees(Node* node, size_t index = 0); }; using Projector = Filter; using Selection = std::shared_ptr; /// Function for creating gandiva selection from our internal filter tree -Selection createSelection(std::shared_ptr table, Filter const& expression); +Selection createSelection(std::shared_ptr const& table, Filter const& expression); /// Function for creating gandiva selection from prepared gandiva expressions tree -Selection createSelection(std::shared_ptr table, std::shared_ptr gfilter); +Selection createSelection(std::shared_ptr const& table, std::shared_ptr gfilter); struct ColumnOperationSpec; using Operations = std::vector; diff --git a/Framework/Core/src/Expressions.cxx b/Framework/Core/src/Expressions.cxx index cd1756a9ea3d4..0d19e21a1eacb 100644 --- a/Framework/Core/src/Expressions.cxx +++ b/Framework/Core/src/Expressions.cxx @@ -26,6 +26,36 @@ using namespace o2::framework; namespace o2::framework::expressions { + +size_t Filter::designateSubtrees(Node* node, size_t index) +{ + std::stack path; + auto local_index = index; + path.emplace(node, 0); + + while (path.empty() == false) { + auto& top = path.top(); + top.node_ptr->index = local_index; + path.pop(); + if (top.node_ptr->condition != nullptr) { + // start new subtrees + index = designateSubtrees(top.node_ptr->left.get(), local_index + 1); + index = designateSubtrees(top.node_ptr->condition.get(), index + 1); + index = designateSubtrees(top.node_ptr->right.get(), index + 1); + } else { + // continue current subtree + if (top.node_ptr->left != nullptr) { + path.emplace(top.node_ptr->left.get(), 0); + } + if (top.node_ptr->right != nullptr) { + path.emplace(top.node_ptr->right.get(), 0); + } + } + } + + return index; +} + namespace { struct LiteralNodeHelper { @@ -142,8 +172,9 @@ void updatePlaceholders(Filter& filter, InitContext& context) auto& top = path.top(); updateNode(top.node_ptr); - auto leftp = top.node_ptr->left.get(); - auto rightp = top.node_ptr->right.get(); + auto* leftp = top.node_ptr->left.get(); + auto* rightp = top.node_ptr->right.get(); + auto* condp = top.node_ptr->condition.get(); path.pop(); if (leftp != nullptr) { @@ -152,6 +183,9 @@ void updatePlaceholders(Filter& filter, InitContext& context) if (rightp != nullptr) { path.emplace(rightp, 0); } + if (condp != nullptr) { + path.emplace(condp, 0); + } } } @@ -185,14 +219,15 @@ Operations createOperations(Filter const& expression) auto operationSpec = std::visit( overloaded{ - [](OpNode node) { return ColumnOperationSpec{node.op}; }, + [&](OpNode node) { return ColumnOperationSpec{node.op, top.node_ptr->index}; }, + [&](ConditionalNode) { return ColumnOperationSpec{BasicOp::Conditional, top.node_ptr->index}; }, [](auto&&) { return ColumnOperationSpec{}; }}, top.node_ptr->self); operationSpec.result = DatumSpec{top.index, operationSpec.type}; path.pop(); - auto left = top.node_ptr->left.get(); + auto* left = top.node_ptr->left.get(); bool leftLeaf = isLeaf(left); size_t li = 0; if (leftLeaf) { @@ -224,6 +259,23 @@ Operations createOperations(Filter const& expression) } } + decltype(left) condition = nullptr; + if (top.node_ptr->condition != nullptr) { + condition = top.node_ptr->condition.get(); + } + bool condleaf = condition != nullptr ? isLeaf(condition) : true; + size_t ci = 0; + if (condition != nullptr) { + if (condleaf) { + operationSpec.condition = processLeaf(condition); + } else { + ci = index; + operationSpec.condition = DatumSpec{index++, atype::BOOL}; + } + } else { + operationSpec.condition = DatumSpec{}; + } + OperationSpecs.push_back(std::move(operationSpec)); if (!leftLeaf) { path.emplace(left, li); @@ -231,6 +283,9 @@ Operations createOperations(Filter const& expression) if (!isUnary && !rightLeaf) { path.emplace(right, ri); } + if (!condleaf) { + path.emplace(condition, ci); + } } // at this stage the operations vector is created, but the field types are // only set for the logical operations and leaf nodes @@ -303,6 +358,7 @@ Operations createOperations(Filter const& expression) if (it->type == atype::NA) { it->type = type; } + it->result.type = it->type; resultTypes[std::get(it->result.datum)] = it->type; } @@ -312,12 +368,12 @@ Operations createOperations(Filter const& expression) gandiva::ConditionPtr makeCondition(gandiva::NodePtr node) { - return gandiva::TreeExprBuilder::MakeCondition(node); + return gandiva::TreeExprBuilder::MakeCondition(std::move(node)); } gandiva::ExpressionPtr makeExpression(gandiva::NodePtr node, gandiva::FieldPtr result) { - return gandiva::TreeExprBuilder::MakeExpression(node, result); + return gandiva::TreeExprBuilder::MakeExpression(std::move(node), std::move(result)); } std::shared_ptr @@ -338,7 +394,7 @@ std::shared_ptr { std::shared_ptr filter; auto s = gandiva::Filter::Make(Schema, - condition, + std::move(condition), &filter); if (!s.ok()) { throw runtime_error_f("Failed to create filter: %s", s.ToString().c_str()); @@ -351,7 +407,7 @@ std::shared_ptr { std::shared_ptr projector; auto s = gandiva::Projector::Make(Schema, - {makeExpression(createExpressionTree(opSpecs, Schema), result)}, + {makeExpression(createExpressionTree(opSpecs, Schema), std::move(result))}, &projector); if (!s.ok()) { throw runtime_error_f("Failed to create projector: %s", s.ToString().c_str()); @@ -362,10 +418,10 @@ std::shared_ptr std::shared_ptr createProjector(gandiva::SchemaPtr const& Schema, Projector&& p, gandiva::FieldPtr result) { - return createProjector(Schema, createOperations(std::move(p)), std::move(result)); + return createProjector(Schema, createOperations(p), std::move(result)); } -Selection createSelection(std::shared_ptr table, std::shared_ptr gfilter) +Selection createSelection(std::shared_ptr const& table, std::shared_ptr gfilter) { Selection selection; auto s = gandiva::SelectionVector::MakeInt64(table->num_rows(), @@ -396,13 +452,13 @@ Selection createSelection(std::shared_ptr table, std::shared_ptr table, - const Filter& expression) +Selection createSelection(std::shared_ptr const& table, + Filter const& expression) { return createSelection(table, createFilter(table->schema(), createOperations(std::move(expression)))); } -auto createProjection(std::shared_ptr table, std::shared_ptr gprojector) +auto createProjection(std::shared_ptr const& table, std::shared_ptr const& gprojector) { arrow::TableBatchReader reader(*table); std::shared_ptr batch; @@ -430,6 +486,7 @@ gandiva::NodePtr createExpressionTree(Operations const& opSpecs, opNodes.resize(opSpecs.size()); std::fill(opNodes.begin(), opNodes.end(), nullptr); std::unordered_map fieldNodes; + std::unordered_map subtrees; auto datumNode = [Schema, &opNodes, &fieldNodes](DatumSpec const& spec) { if (spec.datum.index() == 0) { @@ -490,6 +547,7 @@ gandiva::NodePtr createExpressionTree(Operations const& opSpecs, for (auto it = opSpecs.rbegin(); it != opSpecs.rend(); ++it) { auto leftNode = datumNode(it->left); auto rightNode = datumNode(it->right); + auto condNode = datumNode(it->condition); auto insertUpcastNode = [&](gandiva::NodePtr node, atype::type t) { if (t != it->type) { @@ -509,12 +567,17 @@ gandiva::NodePtr createExpressionTree(Operations const& opSpecs, } }; + gandiva::NodePtr temp_node; + switch (it->op) { case BasicOp::LogicalOr: - tree = gandiva::TreeExprBuilder::MakeOr({leftNode, rightNode}); + temp_node = gandiva::TreeExprBuilder::MakeOr({leftNode, rightNode}); break; case BasicOp::LogicalAnd: - tree = gandiva::TreeExprBuilder::MakeAnd({leftNode, rightNode}); + temp_node = gandiva::TreeExprBuilder::MakeAnd({leftNode, rightNode}); + break; + case BasicOp::Conditional: + temp_node = gandiva::TreeExprBuilder::MakeIf(condNode, leftNode, rightNode, concreteArrowType(it->type)); break; default: if (it->op < BasicOp::Sqrt) { @@ -524,14 +587,24 @@ gandiva::NodePtr createExpressionTree(Operations const& opSpecs, } else if (it->op == BasicOp::Equal || it->op == BasicOp::NotEqual) { insertEqualizeUpcastNode(leftNode, rightNode, it->left.type, it->right.type); } - tree = gandiva::TreeExprBuilder::MakeFunction(basicOperationsMap[it->op], {leftNode, rightNode}, concreteArrowType(it->type)); + temp_node = gandiva::TreeExprBuilder::MakeFunction(basicOperationsMap[it->op], {leftNode, rightNode}, concreteArrowType(it->type)); } else { leftNode = insertUpcastNode(leftNode, it->left.type); - tree = gandiva::TreeExprBuilder::MakeFunction(basicOperationsMap[it->op], {leftNode}, concreteArrowType(it->type)); + temp_node = gandiva::TreeExprBuilder::MakeFunction(basicOperationsMap[it->op], {leftNode}, concreteArrowType(it->type)); } break; } - opNodes[std::get(it->result.datum)] = tree; + if (it->index == 0) { + tree = temp_node; + } else { + auto subtree = subtrees.find(it->index); + if (subtree == subtrees.end()) { + subtrees.insert({it->index, temp_node}); + } else { + subtree->second = temp_node; + } + } + opNodes[std::get(it->result.datum)] = temp_node; } return tree; diff --git a/Framework/Core/test/test_Expressions.cxx b/Framework/Core/test/test_Expressions.cxx index 57c61a636671d..a8012abfc70ce 100644 --- a/Framework/Core/test/test_Expressions.cxx +++ b/Framework/Core/test/test_Expressions.cxx @@ -37,12 +37,12 @@ static BindingNode testInt{"testInt", 6, atype::INT32}; namespace o2::aod::track { DECLARE_SOA_EXPRESSION_COLUMN(Pze, pz, float, o2::aod::track::tgl*(1.f / o2::aod::track::signed1Pt)); -} +} // namespace o2::aod::track BOOST_AUTO_TEST_CASE(TestTreeParsing) { expressions::Filter f = ((nodes::phi > 1) && (nodes::phi < 2)) && (nodes::eta < 1); - auto specs = createOperations(std::move(f)); + auto specs = createOperations(f); BOOST_REQUIRE_EQUAL(specs[0].left, (DatumSpec{1u, atype::BOOL})); BOOST_REQUIRE_EQUAL(specs[0].right, (DatumSpec{2u, atype::BOOL})); BOOST_REQUIRE_EQUAL(specs[0].result, (DatumSpec{0u, atype::BOOL})); @@ -64,7 +64,7 @@ BOOST_AUTO_TEST_CASE(TestTreeParsing) BOOST_REQUIRE_EQUAL(specs[4].result, (DatumSpec{3u, atype::BOOL})); expressions::Filter g = ((nodes::eta + 2.f) > 0.5) || ((nodes::phi - M_PI) < 3); - auto gspecs = createOperations(std::move(g)); + auto gspecs = createOperations(g); BOOST_REQUIRE_EQUAL(gspecs[0].left, (DatumSpec{1u, atype::BOOL})); BOOST_REQUIRE_EQUAL(gspecs[0].right, (DatumSpec{2u, atype::BOOL})); BOOST_REQUIRE_EQUAL(gspecs[0].result, (DatumSpec{0u, atype::BOOL})); @@ -86,7 +86,7 @@ BOOST_AUTO_TEST_CASE(TestTreeParsing) BOOST_REQUIRE_EQUAL(gspecs[4].result, (DatumSpec{4u, atype::FLOAT})); expressions::Filter h = (nodes::phi == 0) || (nodes::phi == 3); - auto hspecs = createOperations(std::move(h)); + auto hspecs = createOperations(h); BOOST_REQUIRE_EQUAL(hspecs[0].left, (DatumSpec{1u, atype::BOOL})); BOOST_REQUIRE_EQUAL(hspecs[0].right, (DatumSpec{2u, atype::BOOL})); @@ -140,7 +140,7 @@ BOOST_AUTO_TEST_CASE(TestTreeParsing) BOOST_AUTO_TEST_CASE(TestGandivaTreeCreation) { Projector pze = o2::aod::track::Pze::Projector(); - auto pzspecs = createOperations(std::move(pze)); + auto pzspecs = createOperations(pze); BOOST_REQUIRE_EQUAL(pzspecs[0].left, (DatumSpec{std::string{"fTgl"}, typeid(o2::aod::track::Tgl).hash_code(), atype::FLOAT})); BOOST_REQUIRE_EQUAL(pzspecs[0].right, (DatumSpec{1u, atype::FLOAT})); BOOST_REQUIRE_EQUAL(pzspecs[0].result, (DatumSpec{0u, atype::FLOAT})); @@ -159,7 +159,7 @@ BOOST_AUTO_TEST_CASE(TestGandivaTreeCreation) auto projector = createProjector(schema, pzspecs, resfield); Projector pte = o2::aod::track::Pt::Projector(); - auto ptespecs = createOperations(std::move(pte)); + auto ptespecs = createOperations(pte); BOOST_REQUIRE_EQUAL(ptespecs[0].left, (DatumSpec{1u, atype::FLOAT})); BOOST_REQUIRE_EQUAL(ptespecs[0].right, (DatumSpec{})); BOOST_REQUIRE_EQUAL(ptespecs[0].result, (DatumSpec{0u, atype::FLOAT})); @@ -203,3 +203,60 @@ BOOST_AUTO_TEST_CASE(TestGandivaTreeCreation) BOOST_REQUIRE(s.ok()); #endif } + +BOOST_AUTO_TEST_CASE(TestConditionalExpressions) +{ + // simple conditional + Filter cf = nabs(o2::aod::track::eta) < 1.0f && ifnode((o2::aod::track::pt < 1.0f), (o2::aod::track::phiraw > (float)(M_PI / 2.)), (o2::aod::track::phiraw < (float)(M_PI / 2.))); + auto cfspecs = createOperations(cf); + BOOST_REQUIRE_EQUAL(cfspecs[0].left, (DatumSpec{1u, atype::BOOL})); + BOOST_REQUIRE_EQUAL(cfspecs[0].right, (DatumSpec{2u, atype::BOOL})); + BOOST_REQUIRE_EQUAL(cfspecs[0].result, (DatumSpec{0u, atype::BOOL})); + + BOOST_REQUIRE_EQUAL(cfspecs[1].left, (DatumSpec{3u, atype::BOOL})); + BOOST_REQUIRE_EQUAL(cfspecs[1].right, (DatumSpec{4u, atype::BOOL})); + BOOST_REQUIRE_EQUAL(cfspecs[1].condition, (DatumSpec{5u, atype::BOOL})); + BOOST_REQUIRE_EQUAL(cfspecs[1].result, (DatumSpec{2u, atype::BOOL})); + + BOOST_REQUIRE_EQUAL(cfspecs[2].left, (DatumSpec{std::string{"fPt"}, typeid(o2::aod::track::Pt).hash_code(), atype::FLOAT})); + BOOST_REQUIRE_EQUAL(cfspecs[2].right, (DatumSpec{LiteralNode::var_t{1.0f}, atype::FLOAT})); + BOOST_REQUIRE_EQUAL(cfspecs[2].result, (DatumSpec{5u, atype::BOOL})); + + BOOST_REQUIRE_EQUAL(cfspecs[3].left, (DatumSpec{std::string{"fRawPhi"}, typeid(o2::aod::track::RawPhi).hash_code(), atype::FLOAT})); + BOOST_REQUIRE_EQUAL(cfspecs[3].right, (DatumSpec{LiteralNode::var_t{(float)(M_PI / 2.)}, atype::FLOAT})); + BOOST_REQUIRE_EQUAL(cfspecs[3].result, (DatumSpec{4u, atype::BOOL})); + + BOOST_REQUIRE_EQUAL(cfspecs[4].left, (DatumSpec{std::string{"fRawPhi"}, typeid(o2::aod::track::RawPhi).hash_code(), atype::FLOAT})); + BOOST_REQUIRE_EQUAL(cfspecs[4].right, (DatumSpec{LiteralNode::var_t{(float)(M_PI / 2.)}, atype::FLOAT})); + BOOST_REQUIRE_EQUAL(cfspecs[4].result, (DatumSpec{3u, atype::BOOL})); + + BOOST_REQUIRE_EQUAL(cfspecs[5].left, (DatumSpec{6u, atype::FLOAT})); + BOOST_REQUIRE_EQUAL(cfspecs[5].right, (DatumSpec{LiteralNode::var_t{1.0f}, atype::FLOAT})); + BOOST_REQUIRE_EQUAL(cfspecs[5].result, (DatumSpec{1u, atype::BOOL})); + + BOOST_REQUIRE_EQUAL(cfspecs[6].left, (DatumSpec{std::string{"fEta"}, typeid(o2::aod::track::Eta).hash_code(), atype::FLOAT})); + BOOST_REQUIRE_EQUAL(cfspecs[6].right, (DatumSpec{})); + BOOST_REQUIRE_EQUAL(cfspecs[6].result, (DatumSpec{6u, atype::FLOAT})); + + auto infield1 = o2::aod::track::Pt::asArrowField(); + auto infield2 = o2::aod::track::Eta::asArrowField(); + auto infield3 = o2::aod::track::RawPhi::asArrowField(); + auto schema = std::make_shared(std::vector{infield1, infield2, infield3}); + auto gandiva_tree = createExpressionTree(cfspecs, schema); + auto gandiva_condition = makeCondition(gandiva_tree); + auto gandiva_filter = createFilter(schema, gandiva_condition); + + BOOST_CHECK_EQUAL(gandiva_tree->ToString(), "bool less_than(float absf((float) fEta), (const float) 1 raw(3f800000)) && if (bool less_than((float) fPt, (const float) 1 raw(3f800000))) { bool greater_than((float) fRawPhi, (const float) 1.5708 raw(3fc90fdb)) } else { bool less_than((float) fRawPhi, (const float) 1.5708 raw(3fc90fdb)) }"); + + // nested conditional + Filter cfn = o2::aod::track::signed1Pt > 0.f && ifnode(std::move(*cf.node), nabs(o2::aod::track::x) > 1.0f, nabs(o2::aod::track::y) > 1.0f); + auto cfnspecs = createOperations(cfn); + auto infield4 = o2::aod::track::Signed1Pt::asArrowField(); + auto infield5 = o2::aod::track::X::asArrowField(); + auto infield6 = o2::aod::track::Y::asArrowField(); + auto schema2 = std::make_shared(std::vector{infield1, infield2, infield3, infield4, infield5, infield6}); + auto gandiva_tree2 = createExpressionTree(cfnspecs, schema2); + auto gandiva_condition2 = makeCondition(gandiva_tree2); + auto gandiva_filter2 = createFilter(schema2, gandiva_condition2); + std::cout << gandiva_tree2->ToString() << std::endl; +} From f5dcf98251e5bc3b396467b200173719c0fada08 Mon Sep 17 00:00:00 2001 From: Anton Alkin Date: Mon, 26 Jul 2021 21:02:52 +0200 Subject: [PATCH 2/4] prevent ambiguous templates --- Framework/Core/include/Framework/Expressions.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Framework/Core/include/Framework/Expressions.h b/Framework/Core/include/Framework/Expressions.h index b739b913d52d6..8ef79767f5f97 100644 --- a/Framework/Core/include/Framework/Expressions.h +++ b/Framework/Core/include/Framework/Expressions.h @@ -356,19 +356,19 @@ inline Node ifnode(Node condition_, Node then_, Node else_) return Node{ConditionalNode{}, std::move(then_), std::move(else_), std::move(condition_)}; } -template +template ::value || std::is_floating_point::value, bool> = true> inline Node ifnode(Node condition_, Node then_, L else_) { return Node{ConditionalNode{}, std::move(then_), LiteralNode{else_}, std::move(condition_)}; } -template +template ::value || std::is_floating_point::value, bool> = true> inline Node ifnode(Node condition_, L then_, Node else_) { return Node{ConditionalNode{}, LiteralNode{then_}, std::move(else_), std::move(condition_)}; } -template +template ::value || std::is_floating_point::value) && (std::is_integral::value || std::is_floating_point::value), bool> = true> inline Node ifnode(Node condition_, L1 then_, L2 else_) { return Node{ConditionalNode{}, LiteralNode{then_}, LiteralNode{else_}, std::move(condition_)}; From 5851b96ff90bfe4dd397265eeaa4769c9b138335 Mon Sep 17 00:00:00 2001 From: Anton Alkin Date: Mon, 26 Jul 2021 21:24:24 +0200 Subject: [PATCH 3/4] update test --- Framework/Core/test/test_Expressions.cxx | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Framework/Core/test/test_Expressions.cxx b/Framework/Core/test/test_Expressions.cxx index a8012abfc70ce..c3fdea4883045 100644 --- a/Framework/Core/test/test_Expressions.cxx +++ b/Framework/Core/test/test_Expressions.cxx @@ -258,5 +258,6 @@ BOOST_AUTO_TEST_CASE(TestConditionalExpressions) auto gandiva_tree2 = createExpressionTree(cfnspecs, schema2); auto gandiva_condition2 = makeCondition(gandiva_tree2); auto gandiva_filter2 = createFilter(schema2, gandiva_condition2); - std::cout << gandiva_tree2->ToString() << std::endl; + BOOST_REQUIRE_EQUAL(gandiva_tree2->ToString(), + "bool greater_than((float) fSigned1Pt, (const float) 0 raw(0)) && if (bool less_than(float absf((float) fEta), (const float) 1 raw(3f800000)) && if (bool less_than((float) fPt, (const float) 1 raw(3f800000))) { bool greater_than((float) fRawPhi, (const float) 1.5708 raw(3fc90fdb)) } else { bool less_than((float) fRawPhi, (const float) 1.5708 raw(3fc90fdb)) }) { bool greater_than(float absf((float) fX), (const float) 1 raw(3f800000)) } else { bool greater_than(float absf((float) fY), (const float) 1 raw(3f800000)) }"); } From 9c59c111978e86a4d862453ef4ebc37c2dd0776e Mon Sep 17 00:00:00 2001 From: Anton Alkin Date: Tue, 27 Jul 2021 10:33:07 +0200 Subject: [PATCH 4/4] prevent template confusion; add explicit method to create a node from a configurable; --- .../Tutorials/src/conditionalExpressions.cxx | 2 +- .../Core/include/Framework/Configurable.h | 9 ++ .../Core/include/Framework/Expressions.h | 110 ++++++++---------- 3 files changed, 60 insertions(+), 61 deletions(-) diff --git a/Analysis/Tutorials/src/conditionalExpressions.cxx b/Analysis/Tutorials/src/conditionalExpressions.cxx index 7582bcc4f28f2..07b9dbac403ed 100644 --- a/Analysis/Tutorials/src/conditionalExpressions.cxx +++ b/Analysis/Tutorials/src/conditionalExpressions.cxx @@ -20,7 +20,7 @@ using namespace o2::framework::expressions; struct ConditionalExpressions { Configurable useFlags{"useFlags", false, "Switch to enable using track flags for selection"}; - Filter trackFilter = nabs(aod::track::eta) < 0.9f && aod::track::pt > 0.5f && ifnode(useFlags == true, (aod::track::flags & static_cast(o2::aod::track::ITSrefit)) != 0u, true); + Filter trackFilter = nabs(aod::track::eta) < 0.9f && aod::track::pt > 0.5f && ifnode(useFlags.node() == true, (aod::track::flags & static_cast(o2::aod::track::ITSrefit)) != 0u, true); OutputObj etapt{TH2F("etapt", ";#eta;#p_{T}", 201, -2.1, 2.1, 601, 0, 60.1)}; void process(aod::Collision const&, soa::Filtered> const& tracks) { diff --git a/Framework/Core/include/Framework/Configurable.h b/Framework/Core/include/Framework/Configurable.h index 3aad0f7154d90..eed9db89836dc 100644 --- a/Framework/Core/include/Framework/Configurable.h +++ b/Framework/Core/include/Framework/Configurable.h @@ -15,6 +15,11 @@ #include namespace o2::framework { +namespace expressions +{ +struct PlaceholderNode; +} + template struct ConfigurableBase { ConfigurableBase(std::string const& name, T&& defaultValue, std::string const& help) @@ -68,6 +73,10 @@ struct Configurable : IP { : IP{name, std::forward(defaultValue), help} { } + auto node() + { + return expressions::PlaceholderNode{*this}; + } }; template diff --git a/Framework/Core/include/Framework/Expressions.h b/Framework/Core/include/Framework/Expressions.h index 8ef79767f5f97..404ef4dc8bbf4 100644 --- a/Framework/Core/include/Framework/Expressions.h +++ b/Framework/Core/include/Framework/Expressions.h @@ -128,7 +128,7 @@ struct OpNode { /// A placeholder node for simple type configurable struct PlaceholderNode : LiteralNode { template - PlaceholderNode(Configurable v) : LiteralNode{v.value}, name{v.name} + PlaceholderNode(Configurable const& v) : LiteralNode{v.value}, name{v.name} { if constexpr (variant_trait_v::type> != VariantType::Unknown) { retrieve = [](InitContext& context, std::string const& name) { return LiteralNode::var_t{context.options().get(name.c_str())}; }; @@ -198,65 +198,55 @@ struct Node { /// overloaded operators to build the tree from an expression -#define BINARY_OP_NODES(_operator_, _operation_) \ - template \ - inline Node operator _operator_(Node left, T right) \ - { \ - return Node{OpNode{BasicOp::_operation_}, std::move(left), LiteralNode{right}}; \ - } \ - template \ - inline Node operator _operator_(T left, Node right) \ - { \ - return Node{OpNode{BasicOp::_operation_}, LiteralNode{left}, std::move(right)}; \ - } \ - template \ - inline Node operator _operator_(Node left, Configurable right) \ - { \ - return Node{OpNode{BasicOp::_operation_}, std::move(left), PlaceholderNode{right}}; \ - } \ - template \ - inline Node operator _operator_(Configurable left, Node right) \ - { \ - return Node{OpNode{BasicOp::_operation_}, PlaceholderNode{left}, std::move(right)}; \ - } \ - inline Node operator _operator_(Node left, Node right) \ - { \ - return Node{OpNode{BasicOp::_operation_}, std::move(left), std::move(right)}; \ - } \ - inline Node operator _operator_(BindingNode left, BindingNode right) \ - { \ - return Node{OpNode{BasicOp::_operation_}, left, right}; \ - } \ - template <> \ - inline Node operator _operator_(BindingNode left, Node right) \ - { \ - return Node{OpNode{BasicOp::_operation_}, left, std::move(right)}; \ - } \ - template <> \ - inline Node operator _operator_(Node left, BindingNode right) \ - { \ - return Node{OpNode{BasicOp::_operation_}, std::move(left), right}; \ - } \ - \ - template \ - inline Node operator _operator_(Configurable left, BindingNode right) \ - { \ - return Node{OpNode{BasicOp::_operation_}, PlaceholderNode{left}, right}; \ - } \ - template \ - inline Node operator _operator_(BindingNode left, Configurable right) \ - { \ - return Node{OpNode{BasicOp::_operation_}, left, PlaceholderNode{right}}; \ - } \ - template \ - inline Node operator _operator_(Configurable left, L right) \ - { \ - return Node{OpNode{BasicOp::_operation_}, PlaceholderNode{left}, LiteralNode{right}}; \ - } \ - template \ - inline Node operator _operator_(L left, Configurable right) \ - { \ - return Node{OpNode{BasicOp::_operation_}, left, PlaceholderNode{right}}; \ +#define BINARY_OP_NODES(_operator_, _operation_) \ + template \ + inline Node operator _operator_(Node left, T right) \ + { \ + return Node{OpNode{BasicOp::_operation_}, std::move(left), LiteralNode{right}}; \ + } \ + template \ + inline Node operator _operator_(T left, Node right) \ + { \ + return Node{OpNode{BasicOp::_operation_}, LiteralNode{left}, std::move(right)}; \ + } \ + template \ + inline Node operator _operator_(Node left, Configurable right) \ + { \ + return Node{OpNode{BasicOp::_operation_}, std::move(left), PlaceholderNode{right}}; \ + } \ + template \ + inline Node operator _operator_(Configurable left, Node right) \ + { \ + return Node{OpNode{BasicOp::_operation_}, PlaceholderNode{left}, std::move(right)}; \ + } \ + inline Node operator _operator_(Node left, Node right) \ + { \ + return Node{OpNode{BasicOp::_operation_}, std::move(left), std::move(right)}; \ + } \ + inline Node operator _operator_(BindingNode left, BindingNode right) \ + { \ + return Node{OpNode{BasicOp::_operation_}, left, right}; \ + } \ + template <> \ + inline Node operator _operator_(BindingNode left, Node right) \ + { \ + return Node{OpNode{BasicOp::_operation_}, left, std::move(right)}; \ + } \ + template <> \ + inline Node operator _operator_(Node left, BindingNode right) \ + { \ + return Node{OpNode{BasicOp::_operation_}, std::move(left), right}; \ + } \ + \ + template \ + inline Node operator _operator_(Configurable left, BindingNode right) \ + { \ + return Node{OpNode{BasicOp::_operation_}, PlaceholderNode{left}, right}; \ + } \ + template \ + inline Node operator _operator_(BindingNode left, Configurable right) \ + { \ + return Node{OpNode{BasicOp::_operation_}, left, PlaceholderNode{right}}; \ } BINARY_OP_NODES(&, BitwiseAnd);