diff --git a/include/bbp/sonata/report_reader.h b/include/bbp/sonata/report_reader.h index e724a778..66a54538 100644 --- a/include/bbp/sonata/report_reader.h +++ b/include/bbp/sonata/report_reader.h @@ -8,8 +8,8 @@ #include -#include #include +#include namespace H5 = HighFive; @@ -97,6 +97,9 @@ template class SONATA_API ReportReader { public: + using Range = std::pair; + using Ranges = std::vector; + class Population { public: @@ -123,19 +126,22 @@ class SONATA_API ReportReader /** * \param node_ids limit the report to the given selection. - * \param tstart return spikes occurring on or after tstart. tstart=nonstd::nullopt - * indicates no limit. \param tstop return spikes occurring on or before tstop. - * tstop=nonstd::nullopt indicates no limit. + * \param tstart return voltages occurring on or after tstart. tstart=nonstd::nullopt + * indicates no limit. \param tstop return voltages occurring on or before tstop. + * tstop=nonstd::nullopt indicates no limit. \param tstride indicates every how many + * timesteps we read data. tstride=nonstd::nullopt indicates that all timesteps are read. */ DataFrame get(const nonstd::optional& node_ids = nonstd::nullopt, const nonstd::optional& tstart = nonstd::nullopt, - const nonstd::optional& tstop = nonstd::nullopt) const; + const nonstd::optional& tstop = nonstd::nullopt, + const nonstd::optional& tstride = nonstd::nullopt) const; private: Population(const H5::File& file, const std::string& populationName); - std::pair getIndex(const nonstd::optional& tstart, const nonstd::optional& tstop) const; + std::pair getIndex(const nonstd::optional& tstart, + const nonstd::optional& tstop) const; - std::vector>> nodes_pointers_; + std::map nodes_pointers_; H5::Group pop_group_; std::vector nodes_ids_; double tstart_, tstop_, tstep_; diff --git a/python/bindings.cpp b/python/bindings.cpp index 5acc9da5..e94af2ba 100644 --- a/python/bindings.cpp +++ b/python/bindings.cpp @@ -343,10 +343,12 @@ void bindReportReader(py::module& m, const std::string& prefix) { "A population inside a ReportReader") .def("get", &ReportType::Population::get, - "Return reports with all those node_ids between 'tstart' and 'tstop'", + "Return reports with all those node_ids between 'tstart' and 'tstop' with a stride " + "tstride", "node_ids"_a = nonstd::nullopt, "tstart"_a = nonstd::nullopt, - "tstop"_a = nonstd::nullopt) + "tstop"_a = nonstd::nullopt, + "tstride"_a = nonstd::nullopt) .def("get_node_ids", &ReportType::Population::getNodeIds, "Return the list of nodes ids for this population") diff --git a/python/generated/docstrings.h b/python/generated/docstrings.h index 0b25fe02..0821d499 100644 --- a/python/generated/docstrings.h +++ b/python/generated/docstrings.h @@ -255,12 +255,16 @@ R"doc(Parameter ``node_ids``: limit the report to the given selection. Parameter ``tstart``: - return spikes occurring on or after tstart. tstart=nonstd::nullopt - indicates no limit. + return voltages occurring on or after tstart. + tstart=nonstd::nullopt indicates no limit. Parameter ``tstop``: - return spikes occurring on or before tstop. tstop=nonstd::nullopt - indicates no limit.)doc"; + return voltages occurring on or before tstop. + tstop=nonstd::nullopt indicates no limit. + +Parameter ``tstride``: + indicates every how many timesteps we read data. + tstride=nonstd::nullopt indicates that all timesteps are read.)doc"; static const char *__doc_bbp_sonata_ReportReader_Population_getDataUnits = R"doc(Return the unit of data.)doc"; diff --git a/python/tests/test.py b/python/tests/test.py index 791dba72..5c602af2 100644 --- a/python/tests/test.py +++ b/python/tests/test.py @@ -275,13 +275,21 @@ def test_get_reports_from_population(self): self.assertEqual(self.test_obj['All'].times, (0., 1., 0.1)) self.assertEqual(self.test_obj['All'].time_units, 'ms') self.assertEqual(self.test_obj['All'].data_units, 'mV') - self.assertTrue(self.test_obj['All'].sorted) + self.assertFalse(self.test_obj['All'].sorted) self.assertEqual(len(self.test_obj['All'].get().ids), 20) # Number of nodes self.assertEqual(len(self.test_obj['All'].get().times), 10) # number of times self.assertEqual(len(self.test_obj['All'].get().data), 10) # should be the same + sel = self.test_obj['All'].get(node_ids=[13, 14], tstart=0.8, tstop=1.0) self.assertEqual(len(sel.times), 2) # Number of timestamp (0.8 and 0.9) self.assertEqual(list(sel.ids), [13, 14]) + np.testing.assert_allclose(sel.data, [[13.8, 14.8], [13.9, 14.9]]) + + sel_all = self.test_obj['All'].get() + self.assertEqual(sel_all.ids, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]) + + sel_empty = self.test_obj['All'].get(node_ids=[]) + np.testing.assert_allclose(sel_empty.data, np.empty(shape=(0, 0))) class TestElementReportPopulation(unittest.TestCase): def setUp(self): @@ -308,8 +316,8 @@ def test_get_reports_from_population(self): self.assertEqual(self.test_obj['All'].time_units, 'ms') self.assertEqual(self.test_obj['All'].data_units, 'mV') self.assertTrue(self.test_obj['All'].sorted) - self.assertEqual(len(self.test_obj['All'].get().data), 20) # Number of times in this range - self.assertEqual(len(self.test_obj['All'].get().times), 20) # Should be the same + self.assertEqual(len(self.test_obj['All'].get(tstride=2).data), 10) # Number of times in this range + self.assertEqual(len(self.test_obj['All'].get(tstride=2).times), 10) # Should be the same self.assertEqual(len(self.test_obj['All'].get().ids), 100) sel = self.test_obj['All'].get(node_ids=[13, 14], tstart=0.8, tstop=1.2) keys = list(sel.ids) @@ -327,6 +335,7 @@ def test_get_reports_from_population(self): # check following calls succeed (no memory destroyed) np.testing.assert_allclose(self.test_obj['All'].get(node_ids=[1, 2], tstart=3., tstop=3.).data[0], [150.0, 150.1, 150.2, 150.3, 150.4, 150.5, 150.6, 150.7, 150.8, 150.9]) np.testing.assert_allclose(self.test_obj['All'].get(node_ids=[3, 4], tstart=0.2, tstop=0.4).data[0], [11.0, 11.1, 11.2, 11.3, 11.4, 11.5, 11.6, 11.7, 11.8, 11.9], 1e-6, 0) + np.testing.assert_allclose(self.test_obj['All'].get(node_ids=[3, 4], tstride=4).data[2], [81.0, 81.1, 81.2, 81.3, 81.4, 81.5, 81.6, 81.7, 81.8, 81.9], 1e-6, 0) if __name__ == '__main__': unittest.main() diff --git a/src/report_reader.cpp b/src/report_reader.cpp index 232242d5..68785f27 100644 --- a/src/report_reader.cpp +++ b/src/report_reader.cpp @@ -1,4 +1,5 @@ #include +#include constexpr double EPSILON = 1e-6; @@ -225,8 +226,8 @@ ReportReader::Population::Population(const H5::File& file, const std::string& mapping_group.getDataSet("index_pointers").read(index_pointers); for (size_t i = 0; i < nodes_ids_.size(); ++i) { - nodes_pointers_.emplace_back(nodes_ids_[i], - std::make_pair(index_pointers[i], index_pointers[i + 1])); + nodes_pointers_.emplace(nodes_ids_[i], + std::make_pair(index_pointers[i], index_pointers[i + 1])); } { // Get times @@ -315,18 +316,21 @@ std::pair ReportReader::Population::getIndex( template DataFrame ReportReader::Population::get(const nonstd::optional& selection, const nonstd::optional& tstart, - const nonstd::optional& tstop) const { + const nonstd::optional& tstop, + const nonstd::optional& tstride) const { DataFrame data_frame; - size_t index_start = 0; size_t index_stop = 0; std::tie(index_start, index_stop) = getIndex(tstart, tstop); - + const size_t stride = tstride.value_or(1); + if (stride == 0) { + throw SonataError("tstride should be > 0"); + } if (index_start > index_stop) { throw SonataError("tstart should be <= to tstop"); } - for (size_t i = index_start; i <= index_stop; ++i) { + for (size_t i = index_start; i <= index_stop; i += stride) { data_frame.times.push_back(times_index_[i].second); } @@ -337,10 +341,11 @@ DataFrame ReportReader::Population::get(const nonstd::optional& Selection::Values node_ids; if (!selection) { // Take all nodes in this case + node_ids.reserve(nodes_pointers_.size()); std::transform(nodes_pointers_.begin(), nodes_pointers_.end(), std::back_inserter(node_ids), - [](const std::pair>& node_pointer) { + [](const std::pair& node_pointer) { return node_pointer.first; }); } else if (selection->empty()) { @@ -349,22 +354,24 @@ DataFrame ReportReader::Population::get(const nonstd::optional& node_ids = selection->flatten(); } + Ranges positions; + // min and max offsets of the node_ids requested are calculated + // to reduce the amount of IO that is brought to memory + uint64_t min = std::numeric_limits::max(); + uint64_t max = std::numeric_limits::min(); + auto dataset_elem_ids = pop_group_.getGroup("mapping").getDataSet("element_ids"); for (const auto& node_id : node_ids) { - const auto it = std::find_if( - nodes_pointers_.begin(), - nodes_pointers_.end(), - [&node_id](const std::pair>& node_pointer) { - return node_pointer.first == node_id; - }); + const auto it = nodes_pointers_.find(node_id); if (it == nodes_pointers_.end()) { continue; } + min = std::min(it->second.first, min); + max = std::max(it->second.second, max); + positions.emplace_back(it->second.first, it->second.second); - std::vector element_ids; - pop_group_.getGroup("mapping") - .getDataSet("element_ids") - .select({it->second.first}, {it->second.second - it->second.first}) - .read(element_ids); + std::vector element_ids(it->second.second - it->second.first); + dataset_elem_ids.select({it->second.first}, {it->second.second - it->second.first}) + .read(element_ids.data()); for (const auto& elem : element_ids) { data_frame.ids.push_back(make_key(node_id, elem)); } @@ -374,43 +381,40 @@ DataFrame ReportReader::Population::get(const nonstd::optional& } // Fill .data member - - auto n_time_entries = index_stop - index_start + 1; - auto n_ids = data_frame.ids.size(); + size_t n_time_entries = ((index_stop - index_start) / stride) + 1; + size_t n_ids = data_frame.ids.size(); data_frame.data.resize(n_time_entries * n_ids); - // FIXME: It will be good to do it for ranges but if node_ids are not sorted it is not easy - // TODO: specialized this function for sorted node_ids? - int ids_index = 0; - for (const auto& node_id : node_ids) { - const auto it = std::find_if( - nodes_pointers_.begin(), - nodes_pointers_.end(), - [&node_id](const std::pair>& node_pointer) { - return node_pointer.first == node_id; - }); - if (it == nodes_pointers_.end()) { - continue; - } - - // elems are by timestamp and by Nodes_id - std::vector> data; - pop_group_.getDataSet("data") - .select({index_start, it->second.first}, - {index_stop - index_start + 1, it->second.second - it->second.first}) - .read(data); - - int timer_index = 0; - - for (const std::vector& datum : data) { - std::copy(datum.data(), - datum.data() + datum.size(), - &data_frame.data[timer_index * n_ids + ids_index]); - ++timer_index; + auto dataset = pop_group_.getDataSet("data"); + auto dataset_type = dataset.getDataType(); + if (dataset_type.getClass() != HighFive::DataTypeClass::Float || dataset_type.getSize() != 4) { + throw SonataError( + fmt::format("DataType of dataset 'data' should be Float32 ('{}' was found)", + dataset_type.string())); + } + std::vector buffer(max - min); + for (size_t timer_index = index_start; timer_index <= index_stop; timer_index += stride) { + // Note: The code assumes that the file is chunked by rows and not by columns + // (i.e., if the chunking changes in the future, the reading method must also be adapted) + dataset.select({timer_index, min}, {1, max - min}).read(buffer.data()); + + off_t offset = 0; + off_t data_offset = (timer_index - index_start) / stride; + auto data_ptr = &data_frame.data[data_offset * n_ids]; + for (const auto& position : positions) { + uint64_t elements_per_gid = position.second - position.first; + uint64_t gid_start = position.first - min; + + // Soma report + if (elements_per_gid == 1) { + data_ptr[offset] = buffer[gid_start]; + } else { // Elements report + uint64_t gid_end = position.second - min; + std::copy(&buffer[gid_start], &buffer[gid_end], &data_ptr[offset]); + } + offset += elements_per_gid; } - ids_index += data[0].size(); } - return data_frame; } diff --git a/tests/data/generate.py b/tests/data/generate.py index 9923bd81..e03e798c 100755 --- a/tests/data/generate.py +++ b/tests/data/generate.py @@ -134,7 +134,7 @@ def write_edges(filepath): def write_soma_report(filepath): population_names = ['All', 'soma1', 'soma2'] - node_ids = np.arange(1, 21) + node_ids = np.concatenate((np.arange(10, 21), np.arange(1, 10)), axis=None) index_pointers = np.arange(0, 21) element_ids = np.zeros(20) times = (0.0, 1.0, 0.1) @@ -148,13 +148,20 @@ def write_soma_report(filepath): gmapping = h5f.create_group('/report/' + population_names[0] + '/mapping') dnodes = gmapping.create_dataset('node_ids', data=node_ids, dtype=np.uint64) - dnodes.attrs.create('sorted', data=True, dtype=np.uint8) gmapping.create_dataset('index_pointers', data=index_pointers, dtype=np.uint64) gmapping.create_dataset('element_ids', data=element_ids, dtype=np.uint32) dtimes = gmapping.create_dataset('time', data=times, dtype=np.double) dtimes.attrs.create('units', data="ms", dtype=string_dtype) gpop_soma1 = h5f.create_group('/report/' + population_names[1]) + ddata2 = gpop_soma1.create_dataset('data', data=data, dtype=np.float64) + ddata2.attrs.create('units', data="mV", dtype=string_dtype) + gmapping2 = h5f.create_group('/report/' + population_names[1] + '/mapping') + gmapping2.create_dataset('node_ids', data=node_ids, dtype=np.uint64) + gmapping2.create_dataset('index_pointers', data=index_pointers, dtype=np.uint64) + gmapping2.create_dataset('element_ids', data=element_ids, dtype=np.uint32) + dtimes2 = gmapping2.create_dataset('time', data=times, dtype=np.double) + dtimes2.attrs.create('units', data="ms", dtype=string_dtype) gpop_soma2 = h5f.create_group('/report/' + population_names[2]) diff --git a/tests/data/somas.h5 b/tests/data/somas.h5 index fce18061..7540da0d 100644 Binary files a/tests/data/somas.h5 and b/tests/data/somas.h5 differ diff --git a/tests/test_report_reader.cpp b/tests/test_report_reader.cpp index 6e36ba51..de1cc25d 100644 --- a/tests/test_report_reader.cpp +++ b/tests/test_report_reader.cpp @@ -57,6 +57,10 @@ TEST_CASE("SomaReportReader limits", "[base]") { // Negatives times REQUIRE_THROWS(pop.get(Selection({{1, 2}}), -1., -2.)); + + // DataType of dataset 'data' should be Float32 + auto pop2 = reader.openPopulation("soma1"); + REQUIRE_THROWS(pop2.get()); } TEST_CASE("SomaReportReader", "[base]") { @@ -72,15 +76,22 @@ TEST_CASE("SomaReportReader", "[base]") { REQUIRE(pop.getDataUnits() == "mV"); - REQUIRE(pop.getSorted()); + REQUIRE(pop.getSorted() == false); - REQUIRE(pop.getNodeIds() == std::vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, - 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}); + REQUIRE(pop.getNodeIds() == std::vector{10, 11, 12, 13, 14, 15, 16, 17, 18, 19, + 20, 1, 2, 3, 4, 5, 6, 7, 8, 9}); auto data = pop.get(Selection({{3, 5}}), 0.2, 0.5); REQUIRE(data.ids == DataFrame::DataType{{3, 4}}); testTimes(data.times, 0.2, 0.1, 4); REQUIRE(data.data == std::vector{3.2f, 4.2f, 3.3f, 4.3f, 3.4f, 4.4f, 3.5f, 4.5f}); + + auto data_all = pop.get(); + REQUIRE(data_all.ids == DataFrame::DataType{{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}}); + + auto data_empty = pop.get(Selection({})); + REQUIRE(data_empty.data == std::vector{}); } TEST_CASE("ElementReportReader limits", "[base]") { @@ -106,6 +117,9 @@ TEST_CASE("ElementReportReader limits", "[base]") { // Negatives times REQUIRE_THROWS(pop.get(Selection({{1, 2}}), -1., -2.)); + + // Stride = 0 + REQUIRE_THROWS(pop.get(Selection({{1, 2}}), 0.1, 0.2, 0)); } TEST_CASE("ElementReportReader", "[base]") {