diff --git a/CHANGELOG.md b/CHANGELOG.md index cabfe5013..7f6e82559 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,8 @@ - Removed unnecessary full facotorization in the examples and made the input 1 based. +- Added `cons` counterparts to `Vector::getData` methods. + ## Changes to Re::Solve in release 0.99.2 ### Major Features diff --git a/resolve/vector/Vector.cpp b/resolve/vector/Vector.cpp index 6d61d08ec..0bef1b5d9 100644 --- a/resolve/vector/Vector.cpp +++ b/resolve/vector/Vector.cpp @@ -211,10 +211,10 @@ namespace ReSolve * * @pre size of _v_ is equal or larger than the current vector size. */ - int Vector::copyDataFrom(Vector* v, memory::MemorySpace memspaceIn, memory::MemorySpace memspaceOut) + int Vector::copyDataFrom(Vector* source, memory::MemorySpace memspaceIn, memory::MemorySpace memspaceOut) { - real_type* data = v->getData(memspaceIn); - return copyDataFrom(data, memspaceIn, memspaceOut); + real_type* source_data = source->getData(memspaceIn); + return copyDataFrom(source_data, memspaceIn, memspaceOut); } /** @@ -228,7 +228,7 @@ namespace ReSolve * * @return 0 if successful, -1 otherwise. */ - int Vector::copyDataFrom(const real_type* data, memory::MemorySpace memspaceIn, memory::MemorySpace memspaceOut) + int Vector::copyDataFrom(const real_type* source, memory::MemorySpace memspaceIn, memory::MemorySpace memspaceOut) { int control = -1; if ((memspaceIn == memory::HOST) && (memspaceOut == memory::HOST)) @@ -264,22 +264,22 @@ namespace ReSolve switch (control) { case 0: // cpu->cpu - mem_.copyArrayHostToHost(h_data_, data, n_size_ * k_); + mem_.copyArrayHostToHost(h_data_, source, n_size_ * k_); setHostUpdated(true); setDeviceUpdated(false); break; case 2: // gpu->cpu - mem_.copyArrayDeviceToHost(h_data_, data, n_size_ * k_); + mem_.copyArrayDeviceToHost(h_data_, source, n_size_ * k_); setHostUpdated(true); setDeviceUpdated(false); break; case 1: // cpu->gpu - mem_.copyArrayHostToDevice(d_data_, data, n_size_ * k_); + mem_.copyArrayHostToDevice(d_data_, source, n_size_ * k_); setHostUpdated(false); setDeviceUpdated(true); break; case 3: // gpu->gpu - mem_.copyArrayDeviceToDevice(d_data_, data, n_size_ * k_); + mem_.copyArrayDeviceToDevice(d_data_, source, n_size_ * k_); setHostUpdated(false); setDeviceUpdated(true); break; @@ -298,7 +298,9 @@ namespace ReSolve * vectors are stored column-wise. * * @note This function gives you access to the pointer, not to a copy. - * If you change the values using the pointer, the vector values will change too. + * If you change the values using the pointer, the vector values will + * change too. Make sure to use setDataUpdated function to set the update + * flags correctly after changing the values. */ real_type* Vector::getData(memory::MemorySpace memspace) { @@ -324,6 +326,42 @@ namespace ReSolve } } + /** + * @brief get a pointer to HOST or DEVICE vector data. + * + * @param[in] memspace - Memory space of the pointer (HOST or DEVICE) + * + * @return pointer to the vector data (HOST or DEVICE). In case of multivectors, + * vectors are stored column-wise. + */ + const real_type* Vector::getData(memory::MemorySpace memspace) const + { + using memory::DEVICE; + using memory::HOST; + + switch (memspace) + { + case HOST: + if (cpu_updated_[0] == false) + { + out::error() << "Trying to get data on the host, but host data is out of date!\n" + << "Use syncData function to sync host data with the device data!\n"; + return nullptr; + } + return h_data_; + case DEVICE: + if (gpu_updated_[0] == false) + { + out::error() << "Trying to get data on the device, but device data is out of date!\n" + << "Use syncData function to sync device data with the host data!\n"; + return nullptr; + } + return d_data_; + default: + return nullptr; + } + } + /** * @brief get a pointer to HOST or DEVICE data of a particular vector in a multivector. * @@ -335,7 +373,9 @@ namespace ReSolve * @pre `j` < `k_` i.e, `j` is smaller than the total number of vectors in multivector. * * @note This function gives you access to the pointer, not to a copy. - * If you change the values using the pointer, the vector values will change too. + * If you change the values using the pointer, the vector values will + * change too. Make sure to use setDataUpdated function to set the update + * flags correctly after changing the values. */ real_type* Vector::getData(index_type j, memory::MemorySpace memspace) { @@ -368,6 +408,55 @@ namespace ReSolve } } + /** + * @brief get a const pointer to HOST or DEVICE data of a particular + * vector in a multivector. + * + * @param[in] j - Index of a vector in multivector + * @param[in] memspace - Memory space of the pointer (HOST or DEVICE) + * + * @return pointer to the _i_th vector data (HOST or DEVICE) within a multivector. + * + * @pre `j` < `k_` i.e, `j` is smaller than the total number of vectors in multivector. + * + */ + const real_type* Vector::getData(index_type j, memory::MemorySpace memspace) const + { + using memory::DEVICE; + using memory::HOST; + + if (k_ <= j) + { + out::error() << "Trying to get data for vector " << j << " in multivector" + << " but there are only " << k_ << " vectors!\n"; + return nullptr; + } + + switch (memspace) + { + case HOST: + if (cpu_updated_[j] == false) + { + out::error() << "Trying to get data for vector " << j << " on the host, " + << "but host data is out of date!\n" + << "Use syncData function to sync host data with the device data!\n"; + return nullptr; + } + return &h_data_[j * n_size_]; + case DEVICE: + if (gpu_updated_[j] == false) + { + out::error() << "Trying to get data for vector " << j << " on the device, " + << "but device data is out of date!\n" + << "Use syncData function to sync device data with the host data!\n"; + return nullptr; + } + return &d_data_[j * n_size_]; + default: + return nullptr; + } + } + /** * @brief Sync out of date memory space with the updated one. * @@ -543,11 +632,27 @@ namespace ReSolve delete[] h_data_; h_data_ = new real_type[n_capacity_ * k_]; owns_cpu_data_ = true; + if (gpu_updated_[0]) + { + cpu_updated_[0] = false; + } + else + { + cpu_updated_[0] = true; + } break; case DEVICE: mem_.deleteOnDevice(d_data_); mem_.allocateArrayOnDevice(&d_data_, n_capacity_ * k_); owns_gpu_data_ = true; + if (cpu_updated_[0]) + { + gpu_updated_[0] = false; + } + else + { + gpu_updated_[0] = true; + } break; } return 0; @@ -725,8 +830,8 @@ namespace ReSolve if (new_n_size > n_capacity_) { out::error() << "Trying to resize vector to " << new_n_size - << " elements but memory allocated only for " << n_capacity_ << "elements." - << "\n"; + << " elements but memory allocated only for " << n_capacity_ + << " elements.\n"; return 1; } else @@ -753,7 +858,7 @@ namespace ReSolve * @pre _dest_ is allocated in memspaceInOut memory space. * @post All elements of the vector _i_ are copied to the array _dest_. */ - int Vector::copyDataTo(real_type* dest, + int Vector::copyDataTo(real_type* destination, index_type i, memory::MemorySpace memspaceInOut) { @@ -768,10 +873,10 @@ namespace ReSolve switch (memspaceInOut) { case HOST: - mem_.copyArrayHostToHost(dest, data, n_size_); + mem_.copyArrayHostToHost(destination, data, n_size_); break; case DEVICE: - mem_.copyArrayDeviceToDevice(dest, data, n_size_); + mem_.copyArrayDeviceToDevice(destination, data, n_size_); break; } return 0; @@ -790,17 +895,17 @@ namespace ReSolve * * @pre _dest_ is allocated, and the size of _dest_ is at least _k_ * _n_ . */ - int Vector::copyDataTo(real_type* dest, memory::MemorySpace memspaceInOut) + int Vector::copyDataTo(real_type* destination, memory::MemorySpace memspaceInOut) { using namespace ReSolve::memory; real_type* data = this->getData(memspaceInOut); switch (memspaceInOut) { case HOST: - mem_.copyArrayHostToHost(dest, data, n_size_ * k_); + mem_.copyArrayHostToHost(destination, data, n_size_ * k_); break; case DEVICE: - mem_.copyArrayDeviceToDevice(dest, data, n_size_ * k_); + mem_.copyArrayDeviceToDevice(destination, data, n_size_ * k_); break; } return 0; diff --git a/resolve/vector/Vector.hpp b/resolve/vector/Vector.hpp index dac6183d7..bc6bb9fbc 100644 --- a/resolve/vector/Vector.hpp +++ b/resolve/vector/Vector.hpp @@ -35,10 +35,12 @@ namespace ReSolve Vector(index_type n, index_type k); ~Vector(); - int copyDataFrom(const real_type* data, memory::MemorySpace memspaceIn, memory::MemorySpace memspaceOut); - int copyDataFrom(Vector* v, memory::MemorySpace memspaceIn, memory::MemorySpace memspaceOut); - real_type* getData(memory::MemorySpace memspace); - real_type* getData(index_type i, memory::MemorySpace memspace); + int copyDataFrom(const real_type* source, memory::MemorySpace memspaceIn, memory::MemorySpace memspaceOut); + int copyDataFrom(Vector* source, memory::MemorySpace memspaceIn, memory::MemorySpace memspaceOut); + real_type* getData(memory::MemorySpace memspace); + real_type* getData(index_type i, memory::MemorySpace memspace); + const real_type* getData(memory::MemorySpace memspace) const; + const real_type* getData(index_type i, memory::MemorySpace memspace) const; index_type getCapacity() const; index_type getSize() const; @@ -55,8 +57,8 @@ namespace ReSolve int syncData(memory::MemorySpace memspaceOut); int syncData(index_type j, memory::MemorySpace memspaceOut); int resize(index_type new_n_current); - int copyDataTo(real_type* dest, index_type i, memory::MemorySpace memspace); - int copyDataTo(real_type* dest, memory::MemorySpace memspace); + int copyDataTo(real_type* destination, index_type i, memory::MemorySpace memspace); + int copyDataTo(real_type* destination, memory::MemorySpace memspace); private: void setHostUpdated(bool is_updated); diff --git a/tests/unit/vector/VectorTests.hpp b/tests/unit/vector/VectorTests.hpp index 6476c8243..e0b693e14 100644 --- a/tests/unit/vector/VectorTests.hpp +++ b/tests/unit/vector/VectorTests.hpp @@ -127,7 +127,7 @@ namespace ReSolve } x.setData(data, memspace_); - real_type* x_data = x.getData(memspace_); + const real_type* x_data = x.getData(memspace_); if (x_data == nullptr) { @@ -187,7 +187,7 @@ namespace ReSolve vector::Vector z(N); z.copyDataFrom(&y, memspace_, memory::HOST); - real_type* z_data = z.getData(memory::HOST); + const real_type* z_data = z.getData(memory::HOST); if (z_data == nullptr) {