diff --git a/include/bbp/sonata/report_reader.h b/include/bbp/sonata/report_reader.h index 8e33966e..0fa7c255 100644 --- a/include/bbp/sonata/report_reader.h +++ b/include/bbp/sonata/report_reader.h @@ -127,8 +127,29 @@ class SONATA_API ReportReader * Return true if the data is sorted. */ bool getSorted() const; + + /** + * Return all the node ids. + */ std::vector getNodeIds() const; + /** + * Return the ElementIds for the passed Node. + * The return type will depend on the report reader: + * - For Soma report reader, the return value will be the Node ID to which the report + * value belongs to. + * - For Element/full compartment readers, the return value will be an array with 2 + * elements, the first element is the Node ID and the second element is the + * compartment ID of the given Node. + * + * \param node_ids limit the report to the given selection. If nullptr, all nodes in the + * report are used + * \param fn lambda applied to all ranges for all node ids + */ + typename DataFrame::DataType getNodeIdElementIdMapping( + const nonstd::optional& node_ids = nonstd::nullopt, + std::function fn = nullptr) const; + /** * \param node_ids limit the report to the given selection. * \param tstart return voltages occurring on or after tstart. tstart=nonstd::nullopt @@ -154,6 +175,8 @@ class SONATA_API ReportReader std::string time_units_; std::string data_units_; bool nodes_ids_sorted_ = false; + Selection::Values node_ids_from_selection( + const nonstd::optional& node_ids = nonstd::nullopt) const; friend ReportReader; }; diff --git a/python/bindings.cpp b/python/bindings.cpp index 25b5d163..40514f9b 100644 --- a/python/bindings.cpp +++ b/python/bindings.cpp @@ -364,6 +364,14 @@ void bindReportReader(py::module& m, const std::string& prefix) { .def("get_node_ids", &ReportType::Population::getNodeIds, "Return the list of nodes ids for this population") + .def( + "get_node_id_element_id_mapping", + [](const typename ReportType::Population& population, + const nonstd::optional& selection) { + return population.getNodeIdElementIdMapping(selection, nullptr); + }, + DOC_REPORTREADER_POP(getNodeIdElementIdMapping), + "selection"_a = nonstd::nullopt) .def_property_readonly("sorted", &ReportType::Population::getSorted, DOC_REPORTREADER_POP(getSorted)) diff --git a/python/generated/docstrings.h b/python/generated/docstrings.h index 3491a096..e7d690f3 100644 --- a/python/generated/docstrings.h +++ b/python/generated/docstrings.h @@ -431,7 +431,22 @@ static const char *__doc_bbp_sonata_ReportReader_Population_getDataUnits = R"doc static const char *__doc_bbp_sonata_ReportReader_Population_getIndex = R"doc()doc"; -static const char *__doc_bbp_sonata_ReportReader_Population_getNodeIds = R"doc()doc"; +static const char *__doc_bbp_sonata_ReportReader_Population_getNodeIdElementIdMapping = +R"doc(Return the ElementIds for the passed Node. The return type will depend +on the report reader: - For Soma report reader, the return value will +be the Node ID to which the report value belongs to. - For +Element/full compartment readers, the return value will be an array +with 2 elements, the first element is the Node ID and the second +element is the compartment ID of the given Node. + +Parameter ``node_ids``: + limit the report to the given selection. If nullptr, all nodes in + the report are used + +Parameter ``fn``: + lambda applied to all ranges for all node ids)doc"; + +static const char *__doc_bbp_sonata_ReportReader_Population_getNodeIds = R"doc(Return all the node ids.)doc"; static const char *__doc_bbp_sonata_ReportReader_Population_getSorted = R"doc(Return true if the data is sorted.)doc"; @@ -439,6 +454,8 @@ static const char *__doc_bbp_sonata_ReportReader_Population_getTimeUnits = R"doc static const char *__doc_bbp_sonata_ReportReader_Population_getTimes = R"doc(Return (tstart, tstop, tstep) of the population)doc"; +static const char *__doc_bbp_sonata_ReportReader_Population_node_ids_from_selection = R"doc()doc"; + static const char *__doc_bbp_sonata_ReportReader_Population_nodes_ids = R"doc()doc"; static const char *__doc_bbp_sonata_ReportReader_Population_nodes_ids_sorted = R"doc()doc"; diff --git a/python/tests/test.py b/python/tests/test.py index 9fe70269..3fdc7e30 100644 --- a/python/tests/test.py +++ b/python/tests/test.py @@ -294,6 +294,7 @@ def test_get_spikes_from_population(self): def test_getTimes_from_population(self): self.assertEqual(self.test_obj['All'].times, (0.1, 1.3)) + class TestSomaReportPopulation(unittest.TestCase): def setUp(self): path = os.path.join(PATH, "somas.h5") @@ -332,6 +333,10 @@ def test_get_reports_from_population(self): sel_empty = self.test_obj['All'].get(node_ids=[]) np.testing.assert_allclose(sel_empty.data, np.empty(shape=(0, 0))) + def test_get_node_id_element_id_mapping(self): + self.assertEqual(self.test_obj['All'].get_node_id_element_id_mapping([[3, 5]]), + [3, 4]) + class TestElementReportPopulation(unittest.TestCase): def setUp(self): @@ -384,6 +389,10 @@ def test_get_reports_from_population(self): np.testing.assert_allclose(self.test_obj['All'].get(node_ids=[3, 4], tstride=4).data[2], [81.0, 81.1, 81.2, 81.3, 81.4, 81.5, 81.6, 81.7, 81.8, 81.9], 1e-6, 0) + def test_get_node_id_element_id_mapping(self): + self.assertEqual(self.test_obj['All'].get_node_id_element_id_mapping([[3, 5]]), + [[3, 5], [3, 5], [3, 6], [3, 6], [3, 7], [4, 7], [4, 8], [4, 8], [4, 9], [4, 9]]) + class TestNodePopulationFailure(unittest.TestCase): def test_CorrectStructure(self): diff --git a/src/report_reader.cpp b/src/report_reader.cpp index 6df810b8..3723c7b2 100644 --- a/src/report_reader.cpp +++ b/src/report_reader.cpp @@ -283,6 +283,27 @@ std::vector ReportReader::Population::getNodeIds() const { return nodes_ids_; } +template +Selection::Values ReportReader::Population::node_ids_from_selection( + const nonstd::optional& selection) const { + Selection::Values node_ids; + + if (!selection) { // Take all nodes in this case + node_ids.reserve(nodes_pointers_.size()); + std::transform(nodes_pointers_.begin(), + nodes_pointers_.end(), + std::back_inserter(node_ids), + [](const std::pair& node_pointer) { + return node_pointer.first; + }); + } else if (selection->empty()) { + return {}; + } else { + node_ids = selection->flatten(); + } + return node_ids; +} + template std::pair ReportReader::Population::getIndex( const nonstd::optional& tstart, const nonstd::optional& tstop) const { @@ -318,6 +339,34 @@ std::pair ReportReader::Population::getIndex( } +template +typename DataFrame::DataType ReportReader::Population::getNodeIdElementIdMapping( + const nonstd::optional& selection, std::function fn) const { + typename DataFrame::DataType ids{}; + + Selection::Values node_ids = node_ids_from_selection(selection); + + auto dataset_elem_ids = pop_group_.getGroup("mapping").getDataSet("element_ids"); + for (const auto& node_id : node_ids) { + const auto it = nodes_pointers_.find(node_id); + if (it == nodes_pointers_.end()) { + continue; + } + + std::vector element_ids(it->second.second - it->second.first); + dataset_elem_ids.select({it->second.first}, {it->second.second - it->second.first}) + .read(element_ids.data()); + for (const auto& elem : element_ids) { + ids.push_back(make_key(node_id, elem)); + } + + if (fn) { + fn(it->second); + } + } + return ids; +} + template DataFrame ReportReader::Population::get(const nonstd::optional& selection, const nonstd::optional& tstart, @@ -339,48 +388,16 @@ DataFrame ReportReader::Population::get(const nonstd::optional& data_frame.times.push_back(times_index_[i].second); } - // Simplify selection - // We should remove duplicates - // And when we can work with ranges let's sort them - // auto nodes_ids_ = Selection::fromValues(node_ids.flatten().sort()); - Selection::Values node_ids; - - if (!selection) { // Take all nodes in this case - node_ids.reserve(nodes_pointers_.size()); - std::transform(nodes_pointers_.begin(), - nodes_pointers_.end(), - std::back_inserter(node_ids), - [](const std::pair& node_pointer) { - return node_pointer.first; - }); - } else if (selection->empty()) { - return DataFrame{{}, {}, {}}; - } else { - node_ids = selection->flatten(); - } - - Ranges positions; // min and max offsets of the node_ids requested are calculated // to reduce the amount of IO that is brought to memory + Ranges positions; uint64_t min = std::numeric_limits::max(); uint64_t max = std::numeric_limits::min(); - auto dataset_elem_ids = pop_group_.getGroup("mapping").getDataSet("element_ids"); - for (const auto& node_id : node_ids) { - const auto it = nodes_pointers_.find(node_id); - if (it == nodes_pointers_.end()) { - continue; - } - min = std::min(it->second.first, min); - max = std::max(it->second.second, max); - positions.emplace_back(it->second.first, it->second.second); - - std::vector element_ids(it->second.second - it->second.first); - dataset_elem_ids.select({it->second.first}, {it->second.second - it->second.first}) - .read(element_ids.data()); - for (const auto& elem : element_ids) { - data_frame.ids.push_back(make_key(node_id, elem)); - } - } + data_frame.ids = getNodeIdElementIdMapping(selection, [&](const Range& range) { + min = std::min(range.first, min); + max = std::max(range.second, max); + positions.emplace_back(range.first, range.second); + }); if (data_frame.ids.empty()) { // At the end no data available (wrong node_ids?) return DataFrame{{}, {}, {}}; } diff --git a/tests/test_report_reader.cpp b/tests/test_report_reader.cpp index ce881a41..16c73283 100644 --- a/tests/test_report_reader.cpp +++ b/tests/test_report_reader.cpp @@ -94,6 +94,9 @@ TEST_CASE("SomaReportReader", "[base]") { auto data_empty = pop.get(Selection({})); REQUIRE(data_empty.data == std::vector{}); + + auto ids = pop.getNodeIdElementIdMapping(Selection({{3, 5}})); + REQUIRE(ids == std::vector{3, 4}); } TEST_CASE("ElementReportReader limits", "[base]") { @@ -155,4 +158,7 @@ TEST_CASE("ElementReportReader", "[base]") { // Select only one time REQUIRE(pop.get(Selection({{1, 2}}), 0.6, 0.6).data == std::vector{30.0f, 30.1f, 30.2f, 30.3f, 30.4f}); + + auto ids = pop.getNodeIdElementIdMapping(Selection({{3, 5}})); + REQUIRE(ids == std::vector{{3, 5}, {3, 5}, {3, 6}, {3, 6}, {3, 7}, {4, 7}, {4, 8}, {4, 8}, {4, 9}, {4, 9}}); }