Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/caffe/solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ class WorkerSolver : public Solver<Dtype> {
explicit WorkerSolver(const SolverParameter& param,
const Solver<Dtype>* root_solver = NULL)
: Solver<Dtype>(param, root_solver) {}
virtual ~WorkerSolver();

protected:
void ApplyUpdate() {}
Expand Down
15 changes: 6 additions & 9 deletions include/caffe/syncedmem.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,8 @@ inline void CaffeFreeHost(void* ptr, bool use_cuda) {
*/
class SyncedMemory {
public:
SyncedMemory()
: cpu_ptr_(NULL), gpu_ptr_(NULL), size_(0), head_(UNINITIALIZED),
own_cpu_data_(false), cpu_malloc_use_cuda_(false), own_gpu_data_(false),
gpu_device_(-1) {}
explicit SyncedMemory(size_t size)
: cpu_ptr_(NULL), gpu_ptr_(NULL), size_(size), head_(UNINITIALIZED),
own_cpu_data_(false), cpu_malloc_use_cuda_(false), own_gpu_data_(false),
gpu_device_(-1) {}
SyncedMemory();
explicit SyncedMemory(size_t size);
~SyncedMemory();
const void* cpu_data();
void set_cpu_data(void* data);
Expand All @@ -68,6 +62,9 @@ class SyncedMemory {
#endif

private:
void check_device();
void check_device(void* data);

void to_cpu();
void to_gpu();
void* cpu_ptr_;
Expand All @@ -77,7 +74,7 @@ class SyncedMemory {
bool own_cpu_data_;
bool cpu_malloc_use_cuda_;
bool own_gpu_data_;
int gpu_device_;
int device_;

DISABLE_COPY_AND_ASSIGN(SyncedMemory);
}; // class SyncedMemory
Expand Down
15 changes: 15 additions & 0 deletions src/caffe/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,21 @@ void Solver<Dtype>::UpdateSmoothedLoss(Dtype loss, int start_iter,
}
}

/**
 * Releases the worker's nets while the device named by
 * param_.device_id() is current, then switches back to whichever device
 * was current on entry. In CPU_ONLY builds only the nets are released.
 */
template <typename Dtype>
WorkerSolver<Dtype>::~WorkerSolver() {
#ifndef CPU_ONLY
  // Remember the caller's device so it can be restored at the end.
  int initial_device;
  CUDA_CHECK(cudaGetDevice(&initial_device));
  CUDA_CHECK(cudaSetDevice(this->param_.device_id()));
#endif
  // Drop the nets explicitly here, while the device switch is in effect.
  this->net_.reset();
  this->test_nets_.resize(0);
#ifndef CPU_ONLY
  CUDA_CHECK(cudaSetDevice(initial_device));
#endif
}

INSTANTIATE_CLASS(Solver);
INSTANTIATE_CLASS(WorkerSolver);

} // namespace caffe
62 changes: 47 additions & 15 deletions src/caffe/syncedmem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,41 @@
#include "caffe/util/math_functions.hpp"

namespace caffe {
/**
 * Creates an empty (size 0) SyncedMemory with no host or device buffer and
 * the head state set to UNINITIALIZED.
 *
 * device_ starts at -1 so that it always holds a defined value; in DEBUG
 * builds with CUDA enabled it is then overwritten with the device that is
 * current at construction time.
 */
SyncedMemory::SyncedMemory()
  : cpu_ptr_(NULL), gpu_ptr_(NULL), size_(0), head_(UNINITIALIZED),
    own_cpu_data_(false), cpu_malloc_use_cuda_(false), own_gpu_data_(false),
    device_(-1) {
#ifndef CPU_ONLY
#ifdef DEBUG
  CUDA_CHECK(cudaGetDevice(&device_));
#endif
#endif
}

/**
 * Creates a SyncedMemory that records a buffer size of `size` but has no
 * host or device buffer yet; the head state is UNINITIALIZED.
 *
 * device_ starts at -1 so that it always holds a defined value; in DEBUG
 * builds with CUDA enabled it is then overwritten with the device that is
 * current at construction time.
 *
 * @param size value stored in size_ and later passed to the allocation calls.
 */
SyncedMemory::SyncedMemory(size_t size)
  : cpu_ptr_(NULL), gpu_ptr_(NULL), size_(size), head_(UNINITIALIZED),
    own_cpu_data_(false), cpu_malloc_use_cuda_(false), own_gpu_data_(false),
    device_(-1) {
#ifndef CPU_ONLY
#ifdef DEBUG
  CUDA_CHECK(cudaGetDevice(&device_));
#endif
#endif
}

// Releases the host and device buffers held by this object.
// NOTE(review): this text comes from a diff view, so lines from the old and
// the new revision appear to be interleaved here (gpu_device_ handling next
// to check_device()) — confirm against the actual file before relying on it.
SyncedMemory::~SyncedMemory() {
check_device();
// The host buffer is released only when own_cpu_data_ is set.
if (cpu_ptr_ && own_cpu_data_) {
CaffeFreeHost(cpu_ptr_, cpu_malloc_use_cuda_);
}

#ifndef CPU_ONLY
// The device buffer is released only when own_gpu_data_ is set.
if (gpu_ptr_ && own_gpu_data_) {
// Save the current device, switch to gpu_device_ when one was recorded,
// free the buffer, then switch back.
int initial_device;
cudaGetDevice(&initial_device);
if (gpu_device_ != -1) {
CUDA_CHECK(cudaSetDevice(gpu_device_));
}
CUDA_CHECK(cudaFree(gpu_ptr_));
cudaSetDevice(initial_device);
}
#endif // CPU_ONLY
}

inline void SyncedMemory::to_cpu() {
check_device();
switch (head_) {
case UNINITIALIZED:
CaffeMallocHost(&cpu_ptr_, size_, &cpu_malloc_use_cuda_);
Expand All @@ -49,18 +64,17 @@ inline void SyncedMemory::to_cpu() {
}

inline void SyncedMemory::to_gpu() {
check_device();
#ifndef CPU_ONLY
switch (head_) {
case UNINITIALIZED:
CUDA_CHECK(cudaGetDevice(&gpu_device_));
CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_));
caffe_gpu_memset(size_, 0, gpu_ptr_);
head_ = HEAD_AT_GPU;
own_gpu_data_ = true;
break;
case HEAD_AT_CPU:
if (gpu_ptr_ == NULL) {
CUDA_CHECK(cudaGetDevice(&gpu_device_));
CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_));
own_gpu_data_ = true;
}
Expand All @@ -77,11 +91,13 @@ inline void SyncedMemory::to_gpu() {
}

/**
 * Read-only access to the host buffer.
 *
 * Runs the debug device check and to_cpu() first, then hands back cpu_ptr_
 * as a pointer to const.
 */
const void* SyncedMemory::cpu_data() {
  check_device();
  to_cpu();
  const void* host_view = cpu_ptr_;
  return host_view;
}

void SyncedMemory::set_cpu_data(void* data) {
check_device();
CHECK(data);
if (own_cpu_data_) {
CaffeFreeHost(cpu_ptr_, cpu_malloc_use_cuda_);
Expand All @@ -92,6 +108,7 @@ void SyncedMemory::set_cpu_data(void* data) {
}

const void* SyncedMemory::gpu_data() {
check_device();
#ifndef CPU_ONLY
to_gpu();
return (const void*)gpu_ptr_;
Expand All @@ -102,16 +119,11 @@ const void* SyncedMemory::gpu_data() {
}

void SyncedMemory::set_gpu_data(void* data) {
check_device(data);
#ifndef CPU_ONLY
CHECK(data);
if (own_gpu_data_) {
int initial_device;
cudaGetDevice(&initial_device);
if (gpu_device_ != -1) {
CUDA_CHECK(cudaSetDevice(gpu_device_));
}
CUDA_CHECK(cudaFree(gpu_ptr_));
cudaSetDevice(initial_device);
}
gpu_ptr_ = data;
head_ = HEAD_AT_GPU;
Expand All @@ -122,12 +134,14 @@ void SyncedMemory::set_gpu_data(void* data) {
}

/**
 * Writable access to the host buffer.
 *
 * Runs the debug device check and to_cpu(), then sets head_ to HEAD_AT_CPU
 * before returning cpu_ptr_.
 */
void* SyncedMemory::mutable_cpu_data() {
  check_device();
  to_cpu();
  // Record the head state as HEAD_AT_CPU before exposing the raw pointer.
  head_ = HEAD_AT_CPU;
  void* host_buffer = cpu_ptr_;
  return host_buffer;
}

void* SyncedMemory::mutable_gpu_data() {
check_device();
#ifndef CPU_ONLY
to_gpu();
head_ = HEAD_AT_GPU;
Expand All @@ -140,9 +154,9 @@ void* SyncedMemory::mutable_gpu_data() {

#ifndef CPU_ONLY
void SyncedMemory::async_gpu_push(const cudaStream_t& stream) {
check_device();
CHECK(head_ == HEAD_AT_CPU);
if (gpu_ptr_ == NULL) {
CUDA_CHECK(cudaGetDevice(&gpu_device_));
CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_));
own_gpu_data_ = true;
}
Expand All @@ -153,5 +167,23 @@ void SyncedMemory::async_gpu_push(const cudaStream_t& stream) {
}
#endif

/**
 * Runs the device check on this object's own GPU buffer.
 *
 * Does nothing unless gpu_ptr_ is non-null and own_gpu_data_ is set.
 */
void SyncedMemory::check_device() {
  if (!gpu_ptr_ || !own_gpu_data_) {
    return;
  }
  check_device(gpu_ptr_);
}
/**
 * Debug-only consistency check; compiles to an empty body when CPU_ONLY is
 * defined or DEBUG is not.
 *
 * Asserts that the currently selected CUDA device equals device_, and that
 * cudaPointerGetAttributes reports `data` as belonging to device_.
 *
 * @param data pointer whose device attribute is verified.
 */
void SyncedMemory::check_device(void* data) {
#ifndef CPU_ONLY
#ifdef DEBUG
  int device;
  // Checked like the other CUDA calls here: if the query failed silently,
  // `device` would be read uninitialized by the comparison below.
  CUDA_CHECK(cudaGetDevice(&device));
  CHECK(device == device_);
  cudaPointerAttributes attributes;
  CUDA_CHECK(cudaPointerGetAttributes(&attributes, data));
  CHECK(attributes.device == device_);
#endif
#endif
}

} // namespace caffe