Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/caffe/util/math_functions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ void caffe_gpu_axpby(const int N, const Dtype alpha, const Dtype* X,
template <typename Dtype>
void caffe_copy(const int N, const Dtype *X, Dtype *Y);

void caffe_memcpy(const size_t N, const void *X, void *Y);
void caffe_gpu_memcpy(const size_t N, const void *X, void *Y);

template <typename Dtype>
void caffe_set(const int N, const Dtype alpha, Dtype *X);
Expand Down
4 changes: 2 additions & 2 deletions src/caffe/syncedmem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ inline void SyncedMemory::to_cpu() {
CaffeMallocHost(&cpu_ptr_, size_);
own_cpu_data_ = true;
}
caffe_memcpy(size_, gpu_ptr_, cpu_ptr_);
caffe_gpu_memcpy(size_, gpu_ptr_, cpu_ptr_);
head_ = SYNCED;
break;
case HEAD_AT_CPU:
Expand All @@ -53,7 +53,7 @@ inline void SyncedMemory::to_gpu() {
if (gpu_ptr_ == NULL) {
CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_));
}
caffe_memcpy(size_, cpu_ptr_, gpu_ptr_);
caffe_gpu_memcpy(size_, cpu_ptr_, gpu_ptr_);
head_ = SYNCED;
break;
case HEAD_AT_GPU:
Expand Down
2 changes: 2 additions & 0 deletions src/caffe/test/test_math_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ TYPED_TEST(MathFunctionsTest, TestCopyCPU) {
const int n = this->blob_bottom_->count();
const TypeParam* bottom_data = this->blob_bottom_->cpu_data();
TypeParam* top_data = this->blob_top_->mutable_cpu_data();
Caffe::set_mode(Caffe::CPU);
caffe_copy(n, bottom_data, top_data);
for (int i = 0; i < n; ++i) {
EXPECT_EQ(bottom_data[i], top_data[i]);
Expand All @@ -219,6 +220,7 @@ TYPED_TEST(MathFunctionsTest, TestCopyGPU) {
const int n = this->blob_bottom_->count();
const TypeParam* bottom_data = this->blob_bottom_->gpu_data();
TypeParam* top_data = this->blob_top_->mutable_gpu_data();
Caffe::set_mode(Caffe::GPU);
caffe_copy(n, bottom_data, top_data);
bottom_data = this->blob_bottom_->cpu_data();
top_data = this->blob_top_->mutable_cpu_data();
Expand Down
4 changes: 2 additions & 2 deletions src/caffe/test/test_syncedmem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ TEST_F(SyncedMemoryTest, TestGPURead) {
EXPECT_EQ(mem.head(), SyncedMemory::SYNCED);
// check if values are the same
char* recovered_value = new char[10];
caffe_memcpy(10, gpu_data, recovered_value);
caffe_gpu_memcpy(10, gpu_data, recovered_value);
for (int i = 0; i < mem.size(); ++i) {
EXPECT_EQ((reinterpret_cast<char*>(recovered_value))[i], 1);
}
Expand All @@ -72,7 +72,7 @@ TEST_F(SyncedMemoryTest, TestGPURead) {
gpu_data = mem.gpu_data();
EXPECT_EQ(mem.head(), SyncedMemory::SYNCED);
// check if values are the same
caffe_memcpy(10, gpu_data, recovered_value);
caffe_gpu_memcpy(10, gpu_data, recovered_value);
for (int i = 0; i < mem.size(); ++i) {
EXPECT_EQ((reinterpret_cast<char*>(recovered_value))[i], 2);
}
Expand Down
8 changes: 6 additions & 2 deletions src/caffe/util/math_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,11 @@ void caffe_add_scalar(const int N, const double alpha, double* Y) {
template <typename Dtype>
void caffe_copy(const int N, const Dtype* X, Dtype* Y) {
if (X != Y) {
CUDA_CHECK(cudaMemcpy(Y, X, sizeof(Dtype) * N, cudaMemcpyDefault));
if (Caffe::mode() == Caffe::GPU) {
CUDA_CHECK(cudaMemcpy(Y, X, sizeof(Dtype) * N, cudaMemcpyDefault));
} else {
memcpy(Y, X, sizeof(Dtype) * N);
}
}
}

Expand All @@ -162,7 +166,7 @@ template void caffe_copy<unsigned int>(const int N, const unsigned int* X,
template void caffe_copy<float>(const int N, const float* X, float* Y);
template void caffe_copy<double>(const int N, const double* X, double* Y);

void caffe_memcpy(const size_t N, const void* X, void* Y) {
void caffe_gpu_memcpy(const size_t N, const void* X, void* Y) {
if (X != Y) {
CUDA_CHECK(cudaMemcpy(Y, X, N, cudaMemcpyDefault));
}
Expand Down