From 1470733e805097488b4e7bb77e365a6c5398a2fc Mon Sep 17 00:00:00 2001 From: MaxVorosh Date: Thu, 7 Dec 2023 11:06:55 -0500 Subject: [PATCH 1/7] First implementation --- server/api/DataLoader.cpp | 51 ++++++++++++++++++++++++++++++ server/api/DataLoader.h | 21 ++++++++++++ server/api/DataMarker.cpp | 51 ++++++++++++++++++++++++++++++ server/api/DataMarker.h | 21 ++++++++++++ server/api/UnshuffledCsvLoader.cpp | 38 ++++++++++++++++++++++ server/api/UnshuffledCsvLoader.h | 19 +++++++++++ server/api/UnshuffledDataLoader.h | 16 ++++++++++ 7 files changed, 217 insertions(+) create mode 100644 server/api/DataLoader.cpp create mode 100644 server/api/DataLoader.h create mode 100644 server/api/DataMarker.cpp create mode 100644 server/api/DataMarker.h create mode 100644 server/api/UnshuffledCsvLoader.cpp create mode 100644 server/api/UnshuffledCsvLoader.h create mode 100644 server/api/UnshuffledDataLoader.h diff --git a/server/api/DataLoader.cpp b/server/api/DataLoader.cpp new file mode 100644 index 00000000..b08f06a9 --- /dev/null +++ b/server/api/DataLoader.cpp @@ -0,0 +1,51 @@ +#include +#include +#include +#include "DataLoader.h" +#include "Blob.h" + +void generate_rearrangement(std::vector& rearrangement, std::size_t size) { + rearrangement.resize(size); + for (int i = 0; i < rearrangement.size(); ++i) { + rearrangement[i] = i; + } + // Some shuffle magic from StackOverflow + auto rd = std::random_device {}; + auto rng = std::default_random_engine { rd() }; + std::shuffle(rearrangement.begin(), rearrangement.end(), rng); +} + +DataLoader::DataLoader(UnshuffledDataLoader* _loader): loader(_loader) { + generate_rearrangement(rearrangement, loader->size()); +} + +DataLoader::DataLoader(UnshuffledDataLoader* _loader, std::string path): loader(_loader) { + loader->load_data(path); + generate_rearrangement(rearrangement, loader->size()); +} + +void DataLoader::load_data(std::string path) { + loader->load_data(path); +} + +std::pair DataLoader::operator[](std::size_t index) const { + if (index >= loader->size()) { + throw std::out_of_range("Index out of range"); + } + return (*loader)[rearrangement[index]]; +} + +std::size_t DataLoader::size() const { + return loader->size(); +} + +void DataLoader::add_data(std::pair, float> instance) { + loader->add_data(instance); +} + +std::pair, float> DataLoader::get_raw(std::size_t index) const { + if (index >= loader->size()) { + throw std::out_of_range("Index out of range"); + } + return loader->get_raw(index); +} diff --git a/server/api/DataLoader.h b/server/api/DataLoader.h new file mode 100644 index 00000000..98c19c8b --- /dev/null +++ b/server/api/DataLoader.h @@ -0,0 +1,21 @@ +#pragma once + +#include "UnshuffledDataLoader.h" +#include + +void generate_rearrangement(std::vector& rearrangement, std::size_t size); + +class DataLoader { +private: + UnshuffledDataLoader* loader; + std::vector rearrangement; +public: + DataLoader() = default; + DataLoader(UnshuffledDataLoader* _loader); + DataLoader(UnshuffledDataLoader* _loader, std::string path); + void load_data(std::string path); + std::pair operator[](std::size_t index) const; + void add_data(std::pair, float> instance); + std::size_t size() const; + std::pair, float> get_raw(std::size_t index) const; +}; diff --git a/server/api/DataMarker.cpp b/server/api/DataMarker.cpp new file mode 100644 index 00000000..1b9945bb --- /dev/null +++ b/server/api/DataMarker.cpp @@ -0,0 +1,51 @@ +#include +#include "DataMarker.h" +#include "UnshuffledCsvLoader.h" +#include "Blob.h" + +DataMarker::DataMarker(std::string path, FileExtension type, int percentage_for_train) { + if (percentage_for_train > 100 || percentage_for_train < 0) { + throw std::logic_error("Wrong percentage"); + } + DataLoader file_loader; + UnshuffledCsvLoader file_unshuffled_loader; + if (type == FileExtension::Csv) { + file_unshuffled_loader = UnshuffledCsvLoader(); + + train_unshuffled_loader = new UnshuffledCsvLoader; + check_unshuffled_loader = new UnshuffledCsvLoader; + } + else if (type == FileExtension::Png) { + throw std::logic_error("Not implemented"); + } + else { + throw std::logic_error("Unsupported type"); + } + file_loader = DataLoader(&file_unshuffled_loader, path); + std::vector rearrangement; + generate_rearrangement(rearrangement, file_loader.size()); + train_loader = DataLoader(train_unshuffled_loader); + check_loader = DataLoader(check_unshuffled_loader); + int instances_for_train = percentage_for_train * (file_loader.size()) / 100; + for (int i = 0; i < file_loader.size(); ++i) { + if (i < instances_for_train) { + train_loader.add_data(file_loader.get_raw(rearrangement[i])); + } + else { + check_loader.add_data(file_loader.get_raw(rearrangement[i])); + } + } +} + +DataMarker::~DataMarker() { + delete train_unshuffled_loader; + delete check_unshuffled_loader; +} + +DataLoader DataMarker::get_check_loader() { + return check_loader; +} + +DataLoader DataMarker::get_train_loader() { + return train_loader; +} \ No newline at end of file diff --git a/server/api/DataMarker.h b/server/api/DataMarker.h new file mode 100644 index 00000000..0eaba1db --- /dev/null +++ b/server/api/DataMarker.h @@ -0,0 +1,21 @@ +#pragma once + +#include +#include "UnshuffledDataLoader.h" +#include "DataLoader.h" + +enum class FileExtension {Csv, Png}; + +class DataMarker { +private: + UnshuffledDataLoader* train_unshuffled_loader; + DataLoader train_loader; + UnshuffledDataLoader* check_unshuffled_loader; + DataLoader check_loader; +public: + DataMarker() = default; + DataMarker(std::string path, FileExtension file_type, int percentage_for_train); + ~DataMarker(); + DataLoader get_train_loader(); + DataLoader get_check_loader(); +}; diff --git a/server/api/UnshuffledCsvLoader.cpp b/server/api/UnshuffledCsvLoader.cpp new file mode 100644 index 00000000..96565c22 --- /dev/null +++ b/server/api/UnshuffledCsvLoader.cpp @@ -0,0 +1,38 @@ +#include +#include "UnshuffledCsvLoader.h" +#include "CsvLoader.h" +#include "Allocator.h" +#include "Blob.h" + +void UnshuffledCsvLoader::load_data(std::string path) { + data.clear(); + auto file_data = CsvLoader::load_csv(path); + data.resize(file_data.size()); + for (int i = 0; i < file_data.size(); ++i) { + float result = file_data[i].back(); + file_data[i].pop_back(); + data[i] = {file_data[i], result}; + } +} + +std::pair UnshuffledCsvLoader::operator[](std::size_t index) const { + if (index >= data.size()) { + throw std::out_of_range("Index out of range"); + } + return {Blob::constBlob(Shape({0, 0, 1, data[index].first.size()}), data[index].first.data()), data[index].second}; +} + +void UnshuffledCsvLoader::add_data(std::pair, float> instance) { + data.push_back(instance); +} + +std::size_t UnshuffledCsvLoader::size() const { + return data.size(); +} + +std::pair, float> UnshuffledCsvLoader::get_raw(std::size_t index) const { + if (index >= data.size()) { + throw std::out_of_range("Index out of range"); + } + return {data[index].first, data[index].second}; +} \ No newline at end of file diff --git a/server/api/UnshuffledCsvLoader.h b/server/api/UnshuffledCsvLoader.h new file mode 100644 index 00000000..8e8913ab --- /dev/null +++ b/server/api/UnshuffledCsvLoader.h @@ -0,0 +1,19 @@ +#pragma once + +#include +#include +#include "UnshuffledDataLoader.h" +#include "Allocator.h" + + +class UnshuffledCsvLoader: public UnshuffledDataLoader { +private: + std::vector, float>> data; +public: + UnshuffledCsvLoader() = default; + void load_data(std::string path) override; + std::pair operator[](std::size_t index) const override; + void add_data(std::pair, float> instance) override; + std::size_t size() const override; + virtual std::pair, float> get_raw(std::size_t index) const override; +}; diff --git a/server/api/UnshuffledDataLoader.h b/server/api/UnshuffledDataLoader.h new file mode 100644 index 00000000..2daf80e6 --- /dev/null +++ b/server/api/UnshuffledDataLoader.h @@ -0,0 +1,16 @@ +#pragma once + +#include +#include +#include "Allocator.h" + +class UnshuffledDataLoader { +public: + UnshuffledDataLoader() = default; + virtual ~UnshuffledDataLoader() = default; + virtual void load_data(std::string path) = 0; + virtual std::pair operator[](std::size_t index) const = 0; + virtual void add_data(std::pair, float> instance) = 0; + virtual std::size_t size() const = 0; + virtual std::pair, float> get_raw(std::size_t index) const = 0; +}; From 5db37f436c56475673dd061ba2b4ca0f3344a0e4 Mon Sep 17 00:00:00 2001 From: MaxVorosh Date: Sat, 9 Dec 2023 09:56:59 -0500 Subject: [PATCH 2/7] Add tests for csv loading --- server/tests/DataMarkerTests.cpp | 48 ++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 server/tests/DataMarkerTests.cpp diff --git a/server/tests/DataMarkerTests.cpp b/server/tests/DataMarkerTests.cpp new file mode 100644 index 00000000..c2bd6644 --- /dev/null +++ b/server/tests/DataMarkerTests.cpp @@ -0,0 +1,48 @@ +#define DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN +#include "doctest.h" +#include "DataMarker.h" +#include + +void check_vectors(std::vector, float>>& ans, std::vector, float>>& res) { + CHECK(ans.size() == res.size()); + for (int i = 0; i < ans.size(); ++i) { + CHECK(ans[i].first.size() == res[i].first.size()); + CHECK(ans[i].second == res[i].second); + for (int j = 0; j < ans[i].first.size(); ++j) { + CHECK(ans[i].first[j] == res[i].first[j]); + } + } +} + +TEST_CASE("Csv-test") { + SUBCASE("and-train") { + DataMarker loader = DataMarker("./tests/data/and-train.csv", FileExtension::Csv, 50); + DataLoader for_train = loader.get_train_loader(); + DataLoader for_check = loader.get_check_loader(); + std::vector, float>> ans = {{{0, 0}, 0}, {{0, 1}, 0}, {{1, 0}, 0}, {{1, 1}, 1}}; + std::vector, float>> res; + CHECK(for_train.size() == 2); + CHECK(for_check.size() == 2); + for (int i = 0; i < 2; ++i) { + res.push_back(for_train.get_raw(i)); + res.push_back(for_check.get_raw(i)); + } + sort(res.begin(), res.end()); + check_vectors(ans, res); + } + SUBCASE("xor-train") { + DataMarker loader = DataMarker("./tests/data/xor-train.csv", FileExtension::Csv, 50); + DataLoader for_train = loader.get_train_loader(); + DataLoader for_check = loader.get_check_loader(); + std::vector, float>> ans = {{{0, 0}, 0}, {{0, 1}, 1}, {{1, 0}, 1}, {{1, 1}, 0}}; + std::vector, float>> res; + CHECK(for_train.size() == 2); + CHECK(for_check.size() == 2); + for (int i = 0; i < 2; ++i) { + res.push_back(for_train.get_raw(i)); + res.push_back(for_check.get_raw(i)); + } + sort(res.begin(), res.end()); + check_vectors(ans, res); + } +} \ No newline at end of file From 89650b541300f72ead356348ed9a710732d457e4 Mon Sep 17 00:00:00 2001 From: MaxVorosh Date: Sat, 9 Dec 2023 13:19:58 -0500 Subject: [PATCH 3/7] Add unshuffled loader for zip --- server/api/CsvLoader.cpp | 20 +++++++++++++++ server/api/CsvLoader.h | 1 + server/api/DataLoader.cpp | 4 +-- server/api/DataLoader.h | 2 +- server/api/DataMarker.cpp | 4 +-- server/api/ImageLoader.cpp | 7 ++++- server/api/ImageLoader.h | 3 ++- server/api/UnshuffledCsvLoader.cpp | 4 +-- server/api/UnshuffledCsvLoader.h | 2 +- server/api/UnshuffledDataLoader.h | 2 +- server/api/UnshuffledImgLoader.cpp | 41 ++++++++++++++++++++++++++++++ server/api/UnshuffledImgLoader.h | 18 +++++++++++++ 12 files changed, 97 insertions(+), 11 deletions(-) create mode 100644 server/api/UnshuffledImgLoader.cpp create mode 100644 server/api/UnshuffledImgLoader.h diff --git a/server/api/CsvLoader.cpp b/server/api/CsvLoader.cpp index bd5970b3..98f75724 100644 --- a/server/api/CsvLoader.cpp +++ b/server/api/CsvLoader.cpp @@ -21,3 +21,23 @@ std::vector> CsvLoader::load_csv(std::string path) { } return result; } + +std::vector> CsvLoader::load_labels(std::string path) { + std::ifstream fin(path); + if (!fin) { + throw std::runtime_error("No such csv file in directory"); + } + std::string line; + std::vector> result; + getline(fin, line); + while (!line.empty()) { + std::stringstream line_stream(line); + std::string file; + std::string label; + getline(line_stream, file, ','); + getline(line_stream, label, ','); + result.push_back({file, std::stof(label)}); + getline(fin, line); + } + return result; +} \ No newline at end of file diff --git a/server/api/CsvLoader.h b/server/api/CsvLoader.h index a9890d89..1d597ef1 100644 --- a/server/api/CsvLoader.h +++ b/server/api/CsvLoader.h @@ -6,4 +6,5 @@ class CsvLoader { public: static std::vector> load_csv(std::string path); + static std::vector> load_labels(std::string path); }; diff --git a/server/api/DataLoader.cpp b/server/api/DataLoader.cpp index b08f06a9..b8dfbdcb 100644 --- a/server/api/DataLoader.cpp +++ b/server/api/DataLoader.cpp @@ -39,8 +39,8 @@ std::size_t DataLoader::size() const { return loader->size(); } -void DataLoader::add_data(std::pair, float> instance) { - loader->add_data(instance); +void DataLoader::add_data(const DataLoader& other, int index) { + loader->add_data(other.loader, index); } std::pair, float> DataLoader::get_raw(std::size_t index) const { diff --git a/server/api/DataLoader.h b/server/api/DataLoader.h index 98c19c8b..d4650259 100644 --- a/server/api/DataLoader.h +++ b/server/api/DataLoader.h @@ -15,7 +15,7 @@ class DataLoader { DataLoader(UnshuffledDataLoader* _loader, std::string path); void load_data(std::string path); std::pair operator[](std::size_t index) const; - void add_data(std::pair, float> instance); + void add_data(const DataLoader& other, int index); std::size_t size() const; std::pair, float> get_raw(std::size_t index) const; }; diff --git a/server/api/DataMarker.cpp b/server/api/DataMarker.cpp index 1b9945bb..47096b55 100644 --- a/server/api/DataMarker.cpp +++ b/server/api/DataMarker.cpp @@ -29,10 +29,10 @@ DataMarker::DataMarker(std::string path, FileExtension type, int percentage_for_ int instances_for_train = percentage_for_train * (file_loader.size()) / 100; for (int i = 0; i < file_loader.size(); ++i) { if (i < instances_for_train) { - train_loader.add_data(file_loader.get_raw(rearrangement[i])); + train_loader.add_data(file_loader, rearrangement[i]); } else { - check_loader.add_data(file_loader.get_raw(rearrangement[i])); + check_loader.add_data(file_loader, rearrangement[i]); } } } diff --git a/server/api/ImageLoader.cpp b/server/api/ImageLoader.cpp index a6201421..2aaa6ac9 100644 --- a/server/api/ImageLoader.cpp +++ b/server/api/ImageLoader.cpp @@ -1,6 +1,6 @@ #include "ImageLoader.h" -std::vector ImageLoader::load_image(char* path) { +std::vector ImageLoader::load_image(const char* path) { cimg_library::CImg image(path); return get_pixels(image); } @@ -22,4 +22,9 @@ std::vector ImageLoader::get_pixels(cimg_library::CImg img } } return ans; +} + +std::pair ImageLoader::get_size(const char *path) { + cimg_library::CImg image(path); + return {image.width(), image.height()}; } \ No newline at end of file diff --git a/server/api/ImageLoader.h b/server/api/ImageLoader.h index ca57561b..8fd09733 100644 --- a/server/api/ImageLoader.h +++ b/server/api/ImageLoader.h @@ -7,6 +7,7 @@ class ImageLoader { public: - static std::vector load_image(char* path); + static std::vector load_image(const char* path); static std::vector get_pixels(cimg_library::CImg); + static std::pair get_size(const char* path); }; diff --git a/server/api/UnshuffledCsvLoader.cpp b/server/api/UnshuffledCsvLoader.cpp index 96565c22..b4b9f585 100644 --- a/server/api/UnshuffledCsvLoader.cpp +++ b/server/api/UnshuffledCsvLoader.cpp @@ -22,8 +22,8 @@ std::pair UnshuffledCsvLoader::operator[](std::size_t index) const return {Blob::constBlob(Shape({0, 0, 1, data[index].first.size()}), data[index].first.data()), data[index].second}; } -void UnshuffledCsvLoader::add_data(std::pair, float> instance) { - data.push_back(instance); +void UnshuffledCsvLoader::add_data(const UnshuffledDataLoader* other, int index) { + data.push_back(other->get_raw(index)); } std::size_t UnshuffledCsvLoader::size() const { diff --git a/server/api/UnshuffledCsvLoader.h b/server/api/UnshuffledCsvLoader.h index 8e8913ab..a6466a22 100644 --- a/server/api/UnshuffledCsvLoader.h +++ b/server/api/UnshuffledCsvLoader.h @@ -13,7 +13,7 @@ class UnshuffledCsvLoader: public UnshuffledDataLoader { UnshuffledCsvLoader() = default; void load_data(std::string path) override; std::pair operator[](std::size_t index) const override; - void add_data(std::pair, float> instance) override; + void add_data(const UnshuffledDataLoader* other, int index) override; std::size_t size() const override; virtual std::pair, float> get_raw(std::size_t index) const override; }; diff --git a/server/api/UnshuffledDataLoader.h b/server/api/UnshuffledDataLoader.h index 2daf80e6..72f82c97 100644 --- a/server/api/UnshuffledDataLoader.h +++ b/server/api/UnshuffledDataLoader.h @@ -10,7 +10,7 @@ class UnshuffledDataLoader { virtual ~UnshuffledDataLoader() = default; virtual void load_data(std::string path) = 0; virtual std::pair operator[](std::size_t index) const = 0; - virtual void add_data(std::pair, float> instance) = 0; + virtual void add_data(const UnshuffledDataLoader* other, int index) = 0; virtual std::size_t size() const = 0; virtual std::pair, float> get_raw(std::size_t index) const = 0; }; diff --git a/server/api/UnshuffledImgLoader.cpp b/server/api/UnshuffledImgLoader.cpp new file mode 100644 index 00000000..357e443f --- /dev/null +++ b/server/api/UnshuffledImgLoader.cpp @@ -0,0 +1,41 @@ +#include "UnshuffledImgLoader.h" +#include "CsvLoader.h" +#include "Blob.h" +#include "Allocator.h" +#include "ImageLoader.h" +#include +#include + +void UnshuffledImgLoader::load_data(std::string path) { + for (auto const& dir_entry : std::filesystem::recursive_directory_iterator(path)) { + std::string file_path = dir_entry.path(); + if (file_path.size() >= 4 && file_path.substr(file_path.size() - 4, 4) == ".csv") { + data = CsvLoader::load_labels(file_path.c_str()); + break; + } + } +} + +std::pair UnshuffledImgLoader::operator[](std::size_t index) const { + + auto index_data = get_raw(index); + auto img_size = ImageLoader::get_size(data[index].first.c_str()); + return {Blob::constBlob(Shape({1, 3, img_size.first, img_size.second}), index_data.first.data()), index_data.second}; +} + +std::pair, float> UnshuffledImgLoader::get_raw(std::size_t index) const { + if (index >= data.size()) { + throw std::out_of_range("Index out of range"); + } + std::string file_path = data[index].first; + float ans = data[index].second; + return {ImageLoader::load_image(file_path.c_str()), ans}; +} + +std::size_t UnshuffledImgLoader::size() const { + return data.size(); +} + +void UnshuffledImgLoader::add_data(const UnshuffledDataLoader* other, int index) { + data.push_back(reinterpret_cast(other)->data[index]); +} \ No newline at end of file diff --git a/server/api/UnshuffledImgLoader.h b/server/api/UnshuffledImgLoader.h new file mode 100644 index 00000000..911b2cab --- /dev/null +++ b/server/api/UnshuffledImgLoader.h @@ -0,0 +1,18 @@ +#pragma once + +#include "UnshuffledDataLoader.h" +#include "Blob.h" +#include +#include + +class UnshuffledImgLoader: public UnshuffledDataLoader { +private: + std::vector> data; +public: + UnshuffledImgLoader() = default; + void load_data(std::string path) override; // path to folder + std::pair operator[](std::size_t index) const override; + void add_data(const UnshuffledDataLoader* other, int index) override; + std::size_t size() const override; + virtual std::pair, float> get_raw(std::size_t index) const override; +}; From d539cdc54ac8ffb3ab4f8f0c372fd762258647a2 Mon Sep 17 00:00:00 2001 From: MaxVorosh Date: Sat, 9 Dec 2023 14:32:29 -0500 Subject: [PATCH 4/7] Add test for image --- server/api/DataMarker.cpp | 13 ++++++++----- server/api/UnshuffledImgLoader.cpp | 3 +++ server/tests/DataMarkerTests.cpp | 17 +++++++++++++++++ server/tests/data/1/black_pixel.png | Bin 0 -> 120 bytes server/tests/data/1/labels.csv | 5 +++++ server/tests/data/1/lazure_pixel.png | Bin 0 -> 120 bytes server/tests/data/1/picture.png | Bin 0 -> 143 bytes server/tests/data/1/traffic_light.png | Bin 0 -> 126 bytes server/tests/data/1/white_pixel.png | Bin 0 -> 120 bytes 9 files changed, 33 insertions(+), 5 deletions(-) create mode 100644 server/tests/data/1/black_pixel.png create mode 100644 server/tests/data/1/labels.csv create mode 100644 server/tests/data/1/lazure_pixel.png create mode 100644 server/tests/data/1/picture.png create mode 100644 server/tests/data/1/traffic_light.png create mode 100644 server/tests/data/1/white_pixel.png diff --git a/server/api/DataMarker.cpp b/server/api/DataMarker.cpp index 47096b55..8191c07e 100644 --- a/server/api/DataMarker.cpp +++ b/server/api/DataMarker.cpp @@ -1,6 +1,7 @@ #include #include "DataMarker.h" #include "UnshuffledCsvLoader.h" +#include "UnshuffledImgLoader.h" #include "Blob.h" DataMarker::DataMarker(std::string path, FileExtension type, int percentage_for_train) { @@ -8,20 +9,21 @@ DataMarker::DataMarker(std::string path, FileExtension type, int percentage_for_ throw std::logic_error("Wrong percentage"); } DataLoader file_loader; - UnshuffledCsvLoader file_unshuffled_loader; + UnshuffledDataLoader* file_unshuffled_loader; if (type == FileExtension::Csv) { - file_unshuffled_loader = UnshuffledCsvLoader(); - + file_unshuffled_loader = new UnshuffledCsvLoader; train_unshuffled_loader = new UnshuffledCsvLoader; check_unshuffled_loader = new UnshuffledCsvLoader; } else if (type == FileExtension::Png) { - throw std::logic_error("Not implemented"); + file_unshuffled_loader = new UnshuffledImgLoader; + train_unshuffled_loader = new UnshuffledImgLoader; + check_unshuffled_loader = new UnshuffledImgLoader; } else { throw std::logic_error("Unsupported type"); } - file_loader = DataLoader(&file_unshuffled_loader, path); + file_loader = DataLoader(file_unshuffled_loader, path); std::vector rearrangement; generate_rearrangement(rearrangement, file_loader.size()); train_loader = DataLoader(train_unshuffled_loader); @@ -35,6 +37,7 @@ DataMarker::DataMarker(std::string path, FileExtension type, int percentage_for_ check_loader.add_data(file_loader, rearrangement[i]); } } + delete file_unshuffled_loader; } DataMarker::~DataMarker() { diff --git a/server/api/UnshuffledImgLoader.cpp b/server/api/UnshuffledImgLoader.cpp index 357e443f..5d18fd34 100644 --- a/server/api/UnshuffledImgLoader.cpp +++ b/server/api/UnshuffledImgLoader.cpp @@ -14,6 +14,9 @@ void UnshuffledImgLoader::load_data(std::string path) { break; } } + for (int i = 0; i < data.size(); ++i) { + data[i].first = path + "/" + data[i].first; + } } std::pair UnshuffledImgLoader::operator[](std::size_t index) const { diff --git a/server/tests/DataMarkerTests.cpp b/server/tests/DataMarkerTests.cpp index c2bd6644..a4c50881 100644 --- a/server/tests/DataMarkerTests.cpp +++ b/server/tests/DataMarkerTests.cpp @@ -45,4 +45,21 @@ TEST_CASE("Csv-test") { sort(res.begin(), res.end()); check_vectors(ans, res); } +} + +TEST_CASE("Image-test") { + DataMarker loader = DataMarker("./tests/data/1", FileExtension::Png, 80); + DataLoader for_train = loader.get_train_loader(); + DataLoader for_check = loader.get_check_loader(); + std::vector, float>> ans = {{{255, 255, 255}, 0}, {{0, 0, 0}, 0}, {{159, 252, 253}, 0}, {{255, 255, 0, 0, 255, 255, 0, 0, 0}, 1}, {{0, 255, 100, 153, 136, 255, 0, 174, 100, 217, 0, 255, 0, 201, 100, 234, 21, 255}, 1}}; + std::vector, float>> res; + CHECK(for_train.size() == 4); + CHECK(for_check.size() == 1); + for (int i = 0; i < 4; ++i) { + res.push_back(for_train.get_raw(i)); + } + res.push_back(for_check.get_raw(0)); + sort(res.begin(), res.end()); + sort(ans.begin(), ans.end()); + check_vectors(ans, res); } \ No newline at end of file diff --git a/server/tests/data/1/black_pixel.png b/server/tests/data/1/black_pixel.png new file mode 100644 index 0000000000000000000000000000000000000000..0279819e92a4b9651ea106eaaf6ad25f15455f18 GIT binary patch literal 120 zcmeAS@N?(olHy`uVBq!ia0vp^j3CUx1|;Q0k92}K#X;^)4C~IxyaaMs(j9#r85lP9 zbN@+X1@buyJR*x37=%hdnDJhkd<9UD*VDx@L?S#nAtB)hKLaBRBSV2gDmPGq!PC{x JWt~$(697xK7`p%f literal 0 HcmV?d00001 diff --git a/server/tests/data/1/labels.csv b/server/tests/data/1/labels.csv new file mode 100644 index 00000000..4d1e5d6a --- /dev/null +++ b/server/tests/data/1/labels.csv @@ -0,0 +1,5 @@ +picture.png, 1 +traffic_light.png, 1 +black_pixel.png, 0 +white_pixel.png, 0 +lazure_pixel.png, 0 diff --git a/server/tests/data/1/lazure_pixel.png b/server/tests/data/1/lazure_pixel.png new file mode 100644 index 0000000000000000000000000000000000000000..e914eb77d2514e2943aec9dd0fc22b5d07bda7d1 GIT binary patch literal 120 zcmeAS@N?(olHy`uVBq!ia0vp^j3CUx1|;Q0k92}K#X;^)4C~IxyaaMs(j9#r85lP9 zbN@+X1@buyJR*x37=%hdnDJhkd<9UD*VDx@L?S$S#{cK<>lrxmm}ex-sCNcRFnGH9 KxvXa~60+7Besim4Gngy)^j>prDGUi(`mHcydZY!Vmk;&$QW={eSIev?u&} hT7rkC=g=jJdFRaGZ^1y`?a}{ R`zBDD!PC{xWt~$(699r_A5s7S literal 0 HcmV?d00001 diff --git a/server/tests/data/1/white_pixel.png b/server/tests/data/1/white_pixel.png new file mode 100644 index 0000000000000000000000000000000000000000..b201b72e55464f60720eb1d283852b691b7de327 GIT binary patch literal 120 zcmeAS@N?(olHy`uVBq!ia0vp^j3CUx1|;Q0k92}K#X;^)4C~IxyaaMs(j9#r85lP9 zbN@+X1@buyJR*x37=%hdnDJhkd<9UD*VDx@L?S%-$N&HT>lrwIGyi3Ec0C1@VDNPH Kb6Mw<&;$Uh3mshm literal 0 HcmV?d00001 From 866bb3cc1432f0beb301e9597b46ff95fd6f2303 Mon Sep 17 00:00:00 2001 From: MaxVorosh Date: Sat, 9 Dec 2023 15:24:26 -0500 Subject: [PATCH 5/7] Add batches --- server/api/DataLoader.cpp | 36 +++++++++++++++++++++++++----- server/api/DataLoader.h | 9 ++++---- server/api/DataMarker.cpp | 8 +++---- server/api/DataMarker.h | 2 +- server/api/UnshuffledCsvLoader.cpp | 13 ++++------- server/api/UnshuffledCsvLoader.h | 5 ++--- server/api/UnshuffledDataLoader.h | 2 +- server/api/UnshuffledImgLoader.cpp | 6 ++--- server/api/UnshuffledImgLoader.h | 4 ++-- 9 files changed, 51 insertions(+), 34 deletions(-) diff --git a/server/api/DataLoader.cpp b/server/api/DataLoader.cpp index b8dfbdcb..c29ef39b 100644 --- a/server/api/DataLoader.cpp +++ b/server/api/DataLoader.cpp @@ -3,6 +3,7 @@ #include #include "DataLoader.h" #include "Blob.h" +#include "Allocator.h" void generate_rearrangement(std::vector& rearrangement, std::size_t size) { rearrangement.resize(size); @@ -15,11 +16,11 @@ void generate_rearrangement(std::vector& rearrangement, std::size_t size) { std::shuffle(rearrangement.begin(), rearrangement.end(), rng); } -DataLoader::DataLoader(UnshuffledDataLoader* _loader): loader(_loader) { +DataLoader::DataLoader(UnshuffledDataLoader* _loader, std::size_t _batch_size): loader(_loader), batch_size(_batch_size) { generate_rearrangement(rearrangement, loader->size()); } -DataLoader::DataLoader(UnshuffledDataLoader* _loader, std::string path): loader(_loader) { +DataLoader::DataLoader(UnshuffledDataLoader* _loader, std::size_t _batch_size, std::string path): loader(_loader), batch_size(_batch_size) { loader->load_data(path); generate_rearrangement(rearrangement, loader->size()); } @@ -28,11 +29,13 @@ void DataLoader::load_data(std::string path) { loader->load_data(path); } -std::pair DataLoader::operator[](std::size_t index) const { +std::pair> DataLoader::operator[](std::size_t index) const { // batch_size lines from index if (index >= loader->size()) { throw std::out_of_range("Index out of range"); } - return (*loader)[rearrangement[index]]; + auto data = get_raw(index); + Shape shape = loader->get_appropriate_shape(index, batch_size); + return {Blob::constBlob(shape, data.first.data()), data.second}; } std::size_t DataLoader::size() const { @@ -43,9 +46,30 @@ void DataLoader::add_data(const DataLoader& other, int index) { loader->add_data(other.loader, index); } -std::pair, float> DataLoader::get_raw(std::size_t index) const { +std::pair, std::vector> DataLoader::get_raw(std::size_t index) const { // batch_size lines from index if (index >= loader->size()) { throw std::out_of_range("Index out of range"); } - return loader->get_raw(index); + std::vector data; + std::vector res(batch_size, 0); + Shape shape = loader->get_appropriate_shape(index, batch_size); + auto dims = shape.getDims(); + int data_size = 1; + for (int i = 0; i < dims.size(); ++i) { + data_size *= dims[i]; + } + data.resize(data_size, 0); + int cur_data = 0; + for (int i = index; i < index + batch_size; ++i) { + if (i >= loader->size()) { + break; + } + auto line = loader->get_raw(i); + res[i - index] = line.second; + for (int j = 0; j < line.first.size(); ++j) { + data[cur_data] = line.first[j]; + cur_data++; + } + } + return {data, res}; } diff --git a/server/api/DataLoader.h b/server/api/DataLoader.h index d4650259..adbca094 100644 --- a/server/api/DataLoader.h +++ b/server/api/DataLoader.h @@ -9,13 +9,14 @@ class DataLoader { private: UnshuffledDataLoader* loader; std::vector rearrangement; + std::size_t batch_size; public: DataLoader() = default; - DataLoader(UnshuffledDataLoader* _loader); - DataLoader(UnshuffledDataLoader* _loader, std::string path); + DataLoader(UnshuffledDataLoader* _loader, std::size_t _batch_size); + DataLoader(UnshuffledDataLoader* _loader, std::size_t _batch_size, std::string path); void load_data(std::string path); - std::pair operator[](std::size_t index) const; + std::pair> operator[](std::size_t index) const; void add_data(const DataLoader& other, int index); std::size_t size() const; - std::pair, float> get_raw(std::size_t index) const; + std::pair, std::vector> get_raw(std::size_t index) const; }; diff --git a/server/api/DataMarker.cpp b/server/api/DataMarker.cpp index 8191c07e..e700a929 100644 --- a/server/api/DataMarker.cpp +++ b/server/api/DataMarker.cpp @@ -4,7 +4,7 @@ #include "UnshuffledImgLoader.h" #include "Blob.h" -DataMarker::DataMarker(std::string path, FileExtension type, int percentage_for_train) { +DataMarker::DataMarker(std::string path, FileExtension type, int percentage_for_train, std::size_t batch_size) { if (percentage_for_train > 100 || percentage_for_train < 0) { throw std::logic_error("Wrong percentage"); } @@ -23,11 +23,11 @@ DataMarker::DataMarker(std::string path, FileExtension type, int percentage_for_ else { throw std::logic_error("Unsupported type"); } - file_loader = DataLoader(file_unshuffled_loader, path); + file_loader = DataLoader(file_unshuffled_loader, batch_size, path); std::vector rearrangement; generate_rearrangement(rearrangement, file_loader.size()); - train_loader = DataLoader(train_unshuffled_loader); - check_loader = DataLoader(check_unshuffled_loader); + train_loader = DataLoader(train_unshuffled_loader, batch_size); + check_loader = DataLoader(check_unshuffled_loader, batch_size); int instances_for_train = percentage_for_train * (file_loader.size()) / 100; for (int i = 0; i < file_loader.size(); ++i) { if (i < instances_for_train) { diff --git a/server/api/DataMarker.h b/server/api/DataMarker.h index 0eaba1db..4beff113 100644 --- a/server/api/DataMarker.h +++ b/server/api/DataMarker.h @@ -14,7 +14,7 @@ class DataMarker { DataLoader check_loader; public: DataMarker() = default; - DataMarker(std::string path, FileExtension file_type, int percentage_for_train); + DataMarker(std::string path, FileExtension file_type, int percentage_for_train, std::size_t batch_size); ~DataMarker(); DataLoader get_train_loader(); DataLoader get_check_loader(); diff --git a/server/api/UnshuffledCsvLoader.cpp b/server/api/UnshuffledCsvLoader.cpp index b4b9f585..68c6a3f7 100644 --- a/server/api/UnshuffledCsvLoader.cpp +++ b/server/api/UnshuffledCsvLoader.cpp @@ -1,8 +1,6 @@ #include #include "UnshuffledCsvLoader.h" #include "CsvLoader.h" -#include "Allocator.h" -#include "Blob.h" void UnshuffledCsvLoader::load_data(std::string path) { data.clear(); @@ -15,13 +13,6 @@ void UnshuffledCsvLoader::load_data(std::string path) { } } -std::pair UnshuffledCsvLoader::operator[](std::size_t index) const { - if (index >= data.size()) { - throw std::out_of_range("Index out of range"); - } - return {Blob::constBlob(Shape({0, 0, 1, data[index].first.size()}), data[index].first.data()), data[index].second}; -} - void UnshuffledCsvLoader::add_data(const UnshuffledDataLoader* other, int index) { data.push_back(other->get_raw(index)); } @@ -35,4 +26,8 @@ std::pair, float> UnshuffledCsvLoader::get_raw(std::size_t in throw std::out_of_range("Index out of range"); } return {data[index].first, data[index].second}; +} + +Shape UnshuffledCsvLoader::get_appropriate_shape(std::size_t index, std::size_t batch_size) const { + return Shape({batch_size, data[index].first.size()}); } \ No newline at end of file diff --git a/server/api/UnshuffledCsvLoader.h b/server/api/UnshuffledCsvLoader.h index a6466a22..64b63279 100644 --- a/server/api/UnshuffledCsvLoader.h +++ b/server/api/UnshuffledCsvLoader.h @@ -3,7 +3,6 @@ #include #include #include "UnshuffledDataLoader.h" -#include "Allocator.h" class UnshuffledCsvLoader: public UnshuffledDataLoader { @@ -12,8 +11,8 @@ class UnshuffledCsvLoader: public UnshuffledDataLoader { public: UnshuffledCsvLoader() = default; void load_data(std::string path) override; - std::pair operator[](std::size_t index) const override; void add_data(const UnshuffledDataLoader* other, int index) override; std::size_t size() const override; - virtual std::pair, float> get_raw(std::size_t index) const override; + std::pair, float> get_raw(std::size_t index) const override; + Shape get_appropriate_shape(std::size_t index, std::size_t batch_size) const override; }; diff --git a/server/api/UnshuffledDataLoader.h b/server/api/UnshuffledDataLoader.h index 72f82c97..6e938111 100644 --- a/server/api/UnshuffledDataLoader.h +++ b/server/api/UnshuffledDataLoader.h @@ -9,8 +9,8 @@ class UnshuffledDataLoader { UnshuffledDataLoader() = default; virtual ~UnshuffledDataLoader() = default; virtual void load_data(std::string path) = 0; - virtual std::pair operator[](std::size_t index) const = 0; virtual void add_data(const UnshuffledDataLoader* other, int index) = 0; virtual std::size_t size() const = 0; virtual std::pair, float> get_raw(std::size_t index) const = 0; + virtual Shape get_appropriate_shape(std::size_t index, std::size_t batch_size) const = 0; }; diff --git a/server/api/UnshuffledImgLoader.cpp b/server/api/UnshuffledImgLoader.cpp index 5d18fd34..772bf959 100644 --- a/server/api/UnshuffledImgLoader.cpp +++ b/server/api/UnshuffledImgLoader.cpp @@ -19,11 +19,9 @@ void UnshuffledImgLoader::load_data(std::string path) { } } -std::pair UnshuffledImgLoader::operator[](std::size_t index) const { - - auto index_data = get_raw(index); +Shape UnshuffledImgLoader::get_appropriate_shape(std::size_t index, std::size_t batch_size) const { auto img_size = ImageLoader::get_size(data[index].first.c_str()); - return {Blob::constBlob(Shape({1, 3, img_size.first, img_size.second}), index_data.first.data()), index_data.second}; + return Shape({batch_size, 3, img_size.first, img_size.second}); } std::pair, float> UnshuffledImgLoader::get_raw(std::size_t index) const { diff --git a/server/api/UnshuffledImgLoader.h b/server/api/UnshuffledImgLoader.h index 911b2cab..2eb0753c 100644 --- a/server/api/UnshuffledImgLoader.h +++ b/server/api/UnshuffledImgLoader.h @@ -11,8 +11,8 @@ class UnshuffledImgLoader: public UnshuffledDataLoader { public: UnshuffledImgLoader() = default; void load_data(std::string path) override; // path to folder - std::pair operator[](std::size_t index) const override; void add_data(const UnshuffledDataLoader* other, int index) override; std::size_t size() const override; - virtual std::pair, float> get_raw(std::size_t index) const override; + std::pair, float> get_raw(std::size_t index) const override; + Shape get_appropriate_shape(std::size_t index, std::size_t batch_size) const override; }; From 30f5e25ac8f795082e0d9f1768959fd1f7dae865 Mon Sep 17 00:00:00 2001 From: MaxVorosh Date: Sat, 9 Dec 2023 15:25:36 -0500 Subject: [PATCH 6/7] Add batches --- server/tests/DataMarkerTests.cpp | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/server/tests/DataMarkerTests.cpp b/server/tests/DataMarkerTests.cpp index a4c50881..e4ba27a8 100644 --- a/server/tests/DataMarkerTests.cpp +++ b/server/tests/DataMarkerTests.cpp @@ -16,7 +16,7 @@ void check_vectors(std::vector, float>>& ans, std:: TEST_CASE("Csv-test") { SUBCASE("and-train") { - DataMarker loader = DataMarker("./tests/data/and-train.csv", FileExtension::Csv, 50); + DataMarker loader = DataMarker("./tests/data/and-train.csv", FileExtension::Csv, 50, 1); DataLoader for_train = loader.get_train_loader(); DataLoader for_check = loader.get_check_loader(); std::vector, float>> ans = {{{0, 0}, 0}, {{0, 1}, 0}, {{1, 0}, 0}, {{1, 1}, 1}}; @@ -24,14 +24,18 @@ TEST_CASE("Csv-test") { CHECK(for_train.size() == 2); CHECK(for_check.size() == 2); for (int i = 0; i < 2; ++i) { - res.push_back(for_train.get_raw(i)); - res.push_back(for_check.get_raw(i)); + auto line1 = for_train.get_raw(i); + CHECK(line1.second.size() == 1); + res.push_back({line1.first, line1.second[0]}); + auto line2 = for_check.get_raw(i); + CHECK(line2.second.size() == 1); + res.push_back({line2.first, line2.second[0]}); } sort(res.begin(), res.end()); check_vectors(ans, res); } SUBCASE("xor-train") { - DataMarker loader = DataMarker("./tests/data/xor-train.csv", FileExtension::Csv, 50); + DataMarker loader = DataMarker("./tests/data/xor-train.csv", FileExtension::Csv, 50, 1); DataLoader for_train = loader.get_train_loader(); DataLoader for_check = loader.get_check_loader(); std::vector, float>> ans = {{{0, 0}, 0}, {{0, 1}, 1}, {{1, 0}, 1}, {{1, 1}, 0}}; @@ -39,8 +43,12 @@ TEST_CASE("Csv-test") { CHECK(for_train.size() == 2); CHECK(for_check.size() == 2); for (int i = 0; i < 2; ++i) { - res.push_back(for_train.get_raw(i)); - res.push_back(for_check.get_raw(i)); + auto line1 = for_train.get_raw(i); + CHECK(line1.second.size() == 1); + res.push_back({line1.first, line1.second[0]}); + auto line2 = for_check.get_raw(i); + CHECK(line2.second.size() == 1); + res.push_back({line2.first, line2.second[0]}); } sort(res.begin(), res.end()); check_vectors(ans, res); @@ -48,7 +56,7 @@ TEST_CASE("Csv-test") { } TEST_CASE("Image-test") { - DataMarker loader = DataMarker("./tests/data/1", FileExtension::Png, 80); + DataMarker loader = DataMarker("./tests/data/1", FileExtension::Png, 80, 1); DataLoader for_train = loader.get_train_loader(); DataLoader for_check = loader.get_check_loader(); std::vector, float>> ans = {{{255, 255, 255}, 0}, {{0, 0, 0}, 0}, {{159, 252, 253}, 0}, {{255, 255, 0, 0, 255, 255, 0, 0, 0}, 1}, {{0, 255, 100, 153, 136, 255, 0, 174, 100, 217, 0, 255, 0, 201, 100, 234, 21, 255}, 1}}; @@ -56,9 +64,12 @@ TEST_CASE("Image-test") { CHECK(for_train.size() == 4); CHECK(for_check.size() == 1); for (int i = 0; i < 4; ++i) { - res.push_back(for_train.get_raw(i)); + auto line = for_train.get_raw(i); + CHECK(line.second.size() == 1); + res.push_back({line.first, line.second[0]}); } - res.push_back(for_check.get_raw(0)); + CHECK(for_check.get_raw(0).second.size() == 1); + res.push_back({for_check.get_raw(0).first, for_check.get_raw(0).second[0]}); sort(res.begin(), res.end()); sort(ans.begin(), ans.end()); check_vectors(ans, res); From aecc66e8045cdabaae71908fc3f4f2f88625fd96 Mon Sep 17 00:00:00 2001 From: MaxVorosh Date: Sun, 10 Dec 2023 10:59:49 -0500 Subject: [PATCH 7/7] Add shuffle by seed, manual shuffle and fix bug --- server/api/DataLoader.cpp | 11 +++++++---- server/api/DataLoader.h | 1 + server/api/DataMarker.cpp | 2 ++ 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/server/api/DataLoader.cpp b/server/api/DataLoader.cpp index c29ef39b..9c643ed4 100644 --- a/server/api/DataLoader.cpp +++ b/server/api/DataLoader.cpp @@ -11,8 +11,7 @@ void generate_rearrangement(std::vector& rearrangement, std::size_t size) { rearrangement[i] = i; } // Some shuffle magic from StackOverflow - auto rd = std::random_device {}; - auto rng = std::default_random_engine { rd() }; + auto rng = std::default_random_engine { 32 }; std::shuffle(rearrangement.begin(), rearrangement.end(), rng); } @@ -52,7 +51,7 @@ std::pair, std::vector> DataLoader::get_raw(std::size_ } std::vector data; std::vector res(batch_size, 0); - Shape shape = loader->get_appropriate_shape(index, batch_size); + Shape shape = loader->get_appropriate_shape(rearrangement[index], batch_size); auto dims = shape.getDims(); int data_size = 1; for (int i = 0; i < dims.size(); ++i) { @@ -64,7 +63,7 @@ std::pair, std::vector> DataLoader::get_raw(std::size_ if (i >= loader->size()) { break; } - auto line = loader->get_raw(i); + auto line = loader->get_raw(rearrangement[i]); res[i - index] = line.second; for (int j = 0; j < line.first.size(); ++j) { data[cur_data] = line.first[j]; @@ -73,3 +72,7 @@ std::pair, std::vector> DataLoader::get_raw(std::size_ } return {data, res}; } + +void DataLoader::shuffle() { + generate_rearrangement(rearrangement, loader->size()); +} diff --git a/server/api/DataLoader.h b/server/api/DataLoader.h index adbca094..a10582cb 100644 --- a/server/api/DataLoader.h +++ b/server/api/DataLoader.h @@ -19,4 +19,5 @@ class DataLoader { void add_data(const DataLoader& other, int index); std::size_t size() const; std::pair, std::vector> get_raw(std::size_t index) const; + void shuffle(); }; diff --git a/server/api/DataMarker.cpp b/server/api/DataMarker.cpp index e700a929..be37d5cd 100644 --- a/server/api/DataMarker.cpp +++ b/server/api/DataMarker.cpp @@ -37,6 +37,8 @@ DataMarker::DataMarker(std::string path, FileExtension type, int percentage_for_ check_loader.add_data(file_loader, rearrangement[i]); } } + train_loader.shuffle(); + check_loader.shuffle(); delete file_unshuffled_loader; }