From 1413d69b1b1f12386ceb08714a607d2a10dc3196 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 30 Sep 2024 22:15:43 +0000 Subject: [PATCH 01/13] fixed functionality of cpu locked tensor --- csrc/aio/common/deepspeed_aio_utils.cpp | 2 +- csrc/aio/common/deepspeed_aio_utils.h | 2 +- csrc/aio/py_lib/deepspeed_aio_op_desc.cpp | 2 ++ csrc/aio/py_lib/deepspeed_aio_op_desc.h | 2 ++ csrc/aio/py_lib/deepspeed_cpu_op.cpp | 13 ++++++++----- csrc/aio/py_lib/deepspeed_cpu_op.h | 1 + csrc/aio/py_lib/deepspeed_pin_tensor.cpp | 16 +++++++++++++--- csrc/aio/py_lib/deepspeed_pin_tensor.h | 6 ++++-- csrc/aio/py_lib/deepspeed_py_io_handle.cpp | 7 ++++--- csrc/aio/py_lib/deepspeed_py_io_handle.h | 2 +- 10 files changed, 37 insertions(+), 16 deletions(-) diff --git a/csrc/aio/common/deepspeed_aio_utils.cpp b/csrc/aio/common/deepspeed_aio_utils.cpp index 763b2c253a34..8fccb1bf96cd 100644 --- a/csrc/aio/common/deepspeed_aio_utils.cpp +++ b/csrc/aio/common/deepspeed_aio_utils.cpp @@ -103,7 +103,7 @@ int get_file_size(const char* filename, long long int& size) return 0; } -void* ds_page_aligned_alloc(const size_t size, const bool lock) +void* ds_page_aligned_alloc(const long long int size, const bool lock) { void* ptr; int retval; diff --git a/csrc/aio/common/deepspeed_aio_utils.h b/csrc/aio/common/deepspeed_aio_utils.h index 9c58c2286610..ea56cd1de236 100644 --- a/csrc/aio/common/deepspeed_aio_utils.h +++ b/csrc/aio/common/deepspeed_aio_utils.h @@ -74,6 +74,6 @@ struct io_prep_generator { int prep_iocbs(const int n_iocbs, std::vector* iocbs); }; -void* ds_page_aligned_alloc(const size_t size, const bool lock = false); +void* ds_page_aligned_alloc(const long long int size, const bool lock = false); int get_file_size(const char* filename, long long int& size); diff --git a/csrc/aio/py_lib/deepspeed_aio_op_desc.cpp b/csrc/aio/py_lib/deepspeed_aio_op_desc.cpp index dc820be528d0..63d2ba022997 100644 --- a/csrc/aio/py_lib/deepspeed_aio_op_desc.cpp +++ b/csrc/aio/py_lib/deepspeed_aio_op_desc.cpp @@ -9,6 +9,7 @@ using namespace std; io_op_desc_t::io_op_desc_t(const bool read_op, const torch::Tensor& buffer, + const bool is_managed, const int fd, const char* filename, const long long int file_num_bytes, @@ -16,6 +17,7 @@ io_op_desc_t::io_op_desc_t(const bool read_op, const bool validate) : _read_op(read_op), _buffer(buffer), + _is_managed(is_managed), _fd(fd), _filename(filename), _file_num_bytes(file_num_bytes), diff --git a/csrc/aio/py_lib/deepspeed_aio_op_desc.h b/csrc/aio/py_lib/deepspeed_aio_op_desc.h index 350d28d29d58..42e55cd2e2a0 100644 --- a/csrc/aio/py_lib/deepspeed_aio_op_desc.h +++ b/csrc/aio/py_lib/deepspeed_aio_op_desc.h @@ -12,6 +12,7 @@ struct io_op_desc_t { const bool _read_op; torch::Tensor _buffer; + const bool _is_managed; int _fd; const std::string _filename; const long long int _file_num_bytes; @@ -22,6 +23,7 @@ struct io_op_desc_t { io_op_desc_t(const bool read_op, const torch::Tensor& buffer, + const bool is_managed, const int fd, const char* filename, const long long int file_num_bytes, diff --git a/csrc/aio/py_lib/deepspeed_cpu_op.cpp b/csrc/aio/py_lib/deepspeed_cpu_op.cpp index 41790b99bb88..22b751eef625 100644 --- a/csrc/aio/py_lib/deepspeed_cpu_op.cpp +++ b/csrc/aio/py_lib/deepspeed_cpu_op.cpp @@ -9,23 +9,25 @@ using namespace std; cpu_op_desc_t::cpu_op_desc_t(const bool read_op, const torch::Tensor& buffer, + const bool is_managed, const int fd, const char* filename, const long long int file_num_bytes, const int num_threads, const bool validate) - : io_op_desc_t(read_op, buffer, fd, filename, file_num_bytes, num_threads, validate), + : io_op_desc_t(read_op, buffer, is_managed, fd, filename, file_num_bytes, num_threads, validate), _cpu_buffer(buffer) { // Need to use CPU bounce buffer if buffer is not a page-locked DRAM memory. - _use_bounce_buffer = !(_buffer.is_cpu() && _buffer.is_pinned()); + _use_bounce_buffer = !(_buffer.is_cpu() && (_buffer.is_pinned() || _is_managed)); if (_use_bounce_buffer) { if (_read_op) { auto options = torch::TensorOptions() .dtype(_buffer.dtype()) .layout(_buffer.layout()) - .device(torch::kCPU); - _cpu_buffer = torch::empty(_buffer.nbytes(), options).pin_memory(); + .device(torch::kCPU) + .requires_grad(false); + _cpu_buffer = torch::empty(_buffer.numel(), options).pin_memory(); } else { _cpu_buffer = _buffer.to(torch::kCPU).pin_memory(); } @@ -37,9 +39,10 @@ char* cpu_op_desc_t::data_ptr() const { return (char*)_contiguous_buffer.data_pt void cpu_op_desc_t::finish() { - if (_read_op) { + if (_read_op && _use_bounce_buffer) { if (_buffer.is_cuda()) { _buffer.copy_(_cpu_buffer.to(torch::kCUDA)); } if (_buffer.is_xpu()) { _buffer.copy_(_cpu_buffer.to(torch::kXPU)); } + if (_buffer.is_cpu()) { _buffer.copy_(_cpu_buffer); } #if defined(__ENABLE_CANN__) if (torch_npu::utils::is_npu(_buffer)) { auto device = at::Device("npu:0"); diff --git a/csrc/aio/py_lib/deepspeed_cpu_op.h b/csrc/aio/py_lib/deepspeed_cpu_op.h index da96dd2b1d50..efc60f97f93c 100644 --- a/csrc/aio/py_lib/deepspeed_cpu_op.h +++ b/csrc/aio/py_lib/deepspeed_cpu_op.h @@ -13,6 +13,7 @@ struct cpu_op_desc_t : io_op_desc_t { cpu_op_desc_t(const bool read_op, const torch::Tensor& buffer, + const bool is_managed, const int fd, const char* filename, const long long int file_num_bytes, diff --git a/csrc/aio/py_lib/deepspeed_pin_tensor.cpp b/csrc/aio/py_lib/deepspeed_pin_tensor.cpp index 752823dc7dd2..f57b4394fc99 100644 --- a/csrc/aio/py_lib/deepspeed_pin_tensor.cpp +++ b/csrc/aio/py_lib/deepspeed_pin_tensor.cpp @@ -19,7 +19,7 @@ deepspeed_pin_tensor_t::~deepspeed_pin_tensor_t() _locked_tensors.clear(); } -torch::Tensor deepspeed_pin_tensor_t::alloc(const size_t num_elem, const at::ScalarType& elem_type) +torch::Tensor deepspeed_pin_tensor_t::alloc(const long long int num_elem, const at::ScalarType& elem_type) { const auto num_bytes = num_elem * elementSize(elem_type); auto pinned_buffer = ds_page_aligned_alloc(num_bytes, true); @@ -27,9 +27,9 @@ torch::Tensor deepspeed_pin_tensor_t::alloc(const size_t num_elem, const at::Sca _locked_tensors[pinned_buffer] = num_bytes; - auto options = torch::TensorOptions().dtype(elem_type).device(torch::kCPU); + auto options = torch::TensorOptions().dtype(elem_type).device(torch::kCPU).requires_grad(false); - return at::from_blob(pinned_buffer, static_cast(num_bytes), options); + return at::from_blob(pinned_buffer, static_cast(num_elem), options); } bool deepspeed_pin_tensor_t::free(torch::Tensor& locked_tensor) @@ -43,3 +43,13 @@ bool deepspeed_pin_tensor_t::free(torch::Tensor& locked_tensor) return false; } + +bool deepspeed_pin_tensor_t::is_managed(const torch::Tensor& buffer) +{ + auto addr = buffer.data_ptr(); + if (!buffer.is_cpu()){ return false;} + if (_locked_tensors.find(addr) != _locked_tensors.end()) { + return true; + } + return false; +}; diff --git a/csrc/aio/py_lib/deepspeed_pin_tensor.h b/csrc/aio/py_lib/deepspeed_pin_tensor.h index 4350a4ac7df6..195696a05833 100644 --- a/csrc/aio/py_lib/deepspeed_pin_tensor.h +++ b/csrc/aio/py_lib/deepspeed_pin_tensor.h @@ -15,13 +15,15 @@ Functionality for managing CPU tensors occupying page-locked memory. #include "deepspeed_py_aio.h" struct deepspeed_pin_tensor_t { - std::map _locked_tensors; + std::map _locked_tensors; deepspeed_pin_tensor_t() = default; ~deepspeed_pin_tensor_t(); - torch::Tensor alloc(const size_t num_elem, const at::ScalarType& elem_type); + torch::Tensor alloc(const long long num_elem, const at::ScalarType& elem_type); bool free(torch::Tensor& locked_tensor); + + bool is_managed(const torch::Tensor& buffer); }; diff --git a/csrc/aio/py_lib/deepspeed_py_io_handle.cpp b/csrc/aio/py_lib/deepspeed_py_io_handle.cpp index bdf2a858d797..e5e89419269d 100644 --- a/csrc/aio/py_lib/deepspeed_py_io_handle.cpp +++ b/csrc/aio/py_lib/deepspeed_py_io_handle.cpp @@ -209,8 +209,9 @@ std::shared_ptr deepspeed_io_handle_t::_create_io_op_desc( const long long int file_num_bytes, const bool validate) { + bool is_managed = _pinned_tensor_mgr->is_managed(buffer); return std::make_shared( - read_op, buffer, fd, filename, file_num_bytes, _num_threads, validate); + read_op, buffer, is_managed, fd, filename, file_num_bytes, _num_threads, validate); } int deepspeed_io_handle_t::pread(const torch::Tensor& buffer, @@ -229,7 +230,7 @@ int deepspeed_io_handle_t::pread(const torch::Tensor& buffer, std::cout << filename << ": buffer nbytes != file bytes " << buffer_bytes << " != " << num_file_bytes << std::endl; } - assert(static_cast(buffer.nbytes()) == num_file_bytes); + assert(buffer_bytes == num_file_bytes); assert((num_file_bytes % _num_threads) == 0); if (!_is_valid_parallel_aio_op(true, num_file_bytes)) { return -1; } @@ -288,7 +289,7 @@ int deepspeed_io_handle_t::async_pwrite(const torch::Tensor& buffer, const char* return pwrite(buffer, filename, false, true); } -at::Tensor deepspeed_io_handle_t::new_cpu_locked_tensor(const size_t num_elem, +at::Tensor deepspeed_io_handle_t::new_cpu_locked_tensor(const long long int num_elem, const torch::Tensor& example_tensor) { return _pinned_tensor_mgr->alloc(num_elem, example_tensor.scalar_type()); diff --git a/csrc/aio/py_lib/deepspeed_py_io_handle.h b/csrc/aio/py_lib/deepspeed_py_io_handle.h index 2974ebe87bfc..e21fd22fc3f7 100644 --- a/csrc/aio/py_lib/deepspeed_py_io_handle.h +++ b/csrc/aio/py_lib/deepspeed_py_io_handle.h @@ -61,7 +61,7 @@ struct deepspeed_io_handle_t { int async_pwrite(const torch::Tensor& buffer, const char* filename); // TODO: Make API's args to be shape and dtype. - torch::Tensor new_cpu_locked_tensor(const size_t num_elem, const torch::Tensor& example_tensor); + torch::Tensor new_cpu_locked_tensor(const long long int num_elem, const torch::Tensor& example_tensor); bool free_cpu_locked_tensor(torch::Tensor&); From c13bb10855bef4e0ede03da7a19b73439d51836f Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 1 Oct 2024 18:40:51 +0000 Subject: [PATCH 02/13] enabling cpu locked in unittests, and fixing compilation errors --- csrc/gds/py_lib/deepspeed_gds_op.cpp | 3 ++- csrc/gds/py_lib/deepspeed_gds_op.h | 1 + csrc/gds/py_lib/deepspeed_py_gds_handle.cpp | 2 +- tests/unit/ops/aio/test_aio.py | 6 +++--- tests/unit/ops/aio/test_gds.py | 2 +- 5 files changed, 8 insertions(+), 6 deletions(-) diff --git a/csrc/gds/py_lib/deepspeed_gds_op.cpp b/csrc/gds/py_lib/deepspeed_gds_op.cpp index c370a448e5a2..c0b3a335e268 100644 --- a/csrc/gds/py_lib/deepspeed_gds_op.cpp +++ b/csrc/gds/py_lib/deepspeed_gds_op.cpp @@ -92,12 +92,13 @@ void gds_op_desc_t::remove_buffer_from_registry(const torch::Tensor& buffer) gds_op_desc_t::gds_op_desc_t(const bool read_op, const torch::Tensor& buffer, + const bool is_managed, const int fd, const char* filename, const long long int file_num_bytes, const int num_threads, const bool validate) - : io_op_desc_t(read_op, buffer, fd, filename, file_num_bytes, num_threads, validate) + : io_op_desc_t(read_op, buffer,is_managed, fd, filename, file_num_bytes, num_threads, validate) { _contiguous_buffer = _buffer.contiguous(); const int64_t device = _buffer.get_device(); diff --git a/csrc/gds/py_lib/deepspeed_gds_op.h b/csrc/gds/py_lib/deepspeed_gds_op.h index b7fab64d4054..70c8f7ced4f5 100644 --- a/csrc/gds/py_lib/deepspeed_gds_op.h +++ b/csrc/gds/py_lib/deepspeed_gds_op.h @@ -20,6 +20,7 @@ struct gds_op_desc_t : io_op_desc_t { gds_op_desc_t(const bool read_op, const torch::Tensor& buffer, + const bool is_managed, const int fd, const char* filename, const long long int file_num_bytes, diff --git a/csrc/gds/py_lib/deepspeed_py_gds_handle.cpp b/csrc/gds/py_lib/deepspeed_py_gds_handle.cpp index 15fd516acaae..e65d7cc40cfd 100644 --- a/csrc/gds/py_lib/deepspeed_py_gds_handle.cpp +++ b/csrc/gds/py_lib/deepspeed_py_gds_handle.cpp @@ -107,7 +107,7 @@ std::shared_ptr deepspeed_gds_handle_t::_create_io_op_desc( { if (buffer.is_cuda()) { return std::make_shared( - read_op, buffer, fd, filename, file_num_bytes, _num_threads, validate); + read_op, buffer, false, fd, filename, file_num_bytes, _num_threads, validate); } return deepspeed_io_handle_t::_create_io_op_desc( read_op, buffer, fd, filename, file_num_bytes, validate); diff --git a/tests/unit/ops/aio/test_aio.py b/tests/unit/ops/aio/test_aio.py index e6927efc3824..9d4b12a6daf7 100644 --- a/tests/unit/ops/aio/test_aio.py +++ b/tests/unit/ops/aio/test_aio.py @@ -78,7 +78,7 @@ def _validate_handle_state(handle, single_submit, overlap_events): assert handle.get_queue_depth() == QUEUE_DEPTH -@pytest.mark.parametrize("use_cuda_pinned_tensor", [True]) # TODO: aio_handle pinned tensor API is broken +@pytest.mark.parametrize("use_cuda_pinned_tensor", [True, False]) @pytest.mark.parametrize("single_submit", [True, False]) @pytest.mark.parametrize("overlap_events", [True, False]) class TestRead(DistributedTest): @@ -144,7 +144,7 @@ def test_async_read(self, tmpdir, use_cuda_pinned_tensor, single_submit, overlap h.free_cpu_locked_tensor(aio_buffer) -@pytest.mark.parametrize("use_cuda_pinned_tensor", [True]) # TODO: aio_handle pinned tensor API is broken +@pytest.mark.parametrize("use_cuda_pinned_tensor", [True, False]) @pytest.mark.parametrize("single_submit", [True, False]) @pytest.mark.parametrize("overlap_events", [True, False]) class TestWrite(DistributedTest): @@ -213,7 +213,7 @@ def test_async_write(self, tmpdir, use_cuda_pinned_tensor, single_submit, overla @pytest.mark.sequential -@pytest.mark.parametrize("use_cuda_pinned_tensor", [True]) # TODO: aio_handle pinned tensor API is broken +@pytest.mark.parametrize("use_cuda_pinned_tensor", [True, False]) @pytest.mark.parametrize("cuda_device", [True, False]) class TestAsyncQueue(DistributedTest): world_size = 1 diff --git a/tests/unit/ops/aio/test_gds.py b/tests/unit/ops/aio/test_gds.py index 53655994b560..9612cc339876 100644 --- a/tests/unit/ops/aio/test_gds.py +++ b/tests/unit/ops/aio/test_gds.py @@ -54,7 +54,7 @@ def _get_test_write_file_and_device_buffer(tmpdir, ref_buffer, gds_handle, index def _validate_handle_state(handle, single_submit, overlap_events): assert handle.get_single_submit() == single_submit assert handle.get_overlap_events() == overlap_events - assert handle.get_thread_count() == IO_PARALLEL + assert handle.get_thread_count() == 1 assert handle.get_block_size() == BLOCK_SIZE assert handle.get_queue_depth() == QUEUE_DEPTH From b909702a02d8873a80f5958002c55bafb21dadaa Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 2 Oct 2024 18:08:04 +0000 Subject: [PATCH 03/13] passing gds tests --- csrc/gds/py_lib/deepspeed_py_gds_handle.cpp | 13 +++++++------ csrc/gds/py_lib/deepspeed_py_gds_handle.h | 6 +++++- tests/unit/ops/aio/test_gds.py | 2 +- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/csrc/gds/py_lib/deepspeed_py_gds_handle.cpp b/csrc/gds/py_lib/deepspeed_py_gds_handle.cpp index e65d7cc40cfd..79d606b14b3f 100644 --- a/csrc/gds/py_lib/deepspeed_py_gds_handle.cpp +++ b/csrc/gds/py_lib/deepspeed_py_gds_handle.cpp @@ -20,20 +20,21 @@ deepspeed_gds_handle_t::deepspeed_gds_handle_t(const int block_size, const bool single_submit, const bool overlap_events, const int num_threads) - : deepspeed_io_handle_t(block_size, queue_depth, single_submit, overlap_events, 1) + : deepspeed_io_handle_t(block_size, queue_depth, single_submit, overlap_events, 1), + _num_gpu_threads(num_threads) { - _init_cuFile(block_size, queue_depth, num_threads); + _init_cuFile(block_size,queue_depth); } deepspeed_gds_handle_t::~deepspeed_gds_handle_t() { _close_cuFile(); } -void deepspeed_gds_handle_t::_init_cuFile(const int block_size, - const int queue_depth, - const int num_threads) +const int deepspeed_gds_handle_t::get_thread_count() const { return _num_gpu_threads; } + +void deepspeed_gds_handle_t::_init_cuFile(const int block_size, const int queue_depth) { if (deepspeed_gds_handle_t::s_cuFile_init == 0) { std::string depthStr = std::to_string(queue_depth); - std::string threadsStr = std::to_string(num_threads); + std::string threadsStr = std::to_string(_num_gpu_threads); std::string json1 = R"({"execution": {"max_io_queue_depth": )" + depthStr + ", "; std::string json2 = R"("max_request_parallelism": )" + threadsStr + ", "; std::string json3 = R"("max_io_threads": )" + threadsStr + ", "; diff --git a/csrc/gds/py_lib/deepspeed_py_gds_handle.h b/csrc/gds/py_lib/deepspeed_py_gds_handle.h index f324e6b65e80..3cf49a4db453 100644 --- a/csrc/gds/py_lib/deepspeed_py_gds_handle.h +++ b/csrc/gds/py_lib/deepspeed_py_gds_handle.h @@ -12,6 +12,8 @@ Functionality for swapping optimizer tensors to/from (NVMe) storage devices. #include "deepspeed_py_io_handle.h" struct deepspeed_gds_handle_t : deepspeed_io_handle_t { + const int _num_gpu_threads; + deepspeed_gds_handle_t(const int block_size, const int queue_depth, const bool single_submit, @@ -29,10 +31,12 @@ struct deepspeed_gds_handle_t : deepspeed_io_handle_t { bool unpin_device_tensor(const torch::Tensor& buffer); - void _init_cuFile(const int block_size, const int queue_length, const int num_threads); + void _init_cuFile(const int block_size, const int queue_depth); void _close_cuFile(); + const int get_thread_count() const; + std::shared_ptr _create_io_op_desc(const bool read_op, const torch::Tensor& buffer, const int fd, diff --git a/tests/unit/ops/aio/test_gds.py b/tests/unit/ops/aio/test_gds.py index 9612cc339876..53655994b560 100644 --- a/tests/unit/ops/aio/test_gds.py +++ b/tests/unit/ops/aio/test_gds.py @@ -54,7 +54,7 @@ def _get_test_write_file_and_device_buffer(tmpdir, ref_buffer, gds_handle, index def _validate_handle_state(handle, single_submit, overlap_events): assert handle.get_single_submit() == single_submit assert handle.get_overlap_events() == overlap_events - assert handle.get_thread_count() == 1 + assert handle.get_thread_count() == IO_PARALLEL assert handle.get_block_size() == BLOCK_SIZE assert handle.get_queue_depth() == QUEUE_DEPTH From b1ee7118434eedb304a002744980f427e0707322 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 2 Oct 2024 18:42:21 +0000 Subject: [PATCH 04/13] renaming all instances of num_threads --- csrc/aio/py_lib/deepspeed_aio_op_desc.cpp | 6 +++--- csrc/aio/py_lib/deepspeed_aio_op_desc.h | 4 ++-- csrc/aio/py_lib/deepspeed_cpu_op.cpp | 6 +++--- csrc/aio/py_lib/deepspeed_cpu_op.h | 2 +- csrc/aio/py_lib/deepspeed_py_aio_handle.cpp | 4 ++-- csrc/aio/py_lib/deepspeed_py_aio_handle.h | 2 +- csrc/aio/py_lib/deepspeed_py_io_handle.cpp | 14 +++++++------- csrc/aio/py_lib/deepspeed_py_io_handle.h | 4 ++-- csrc/aio/py_lib/py_ds_aio.cpp | 2 +- csrc/gds/py_lib/deepspeed_gds_op.cpp | 6 +++--- csrc/gds/py_lib/deepspeed_gds_op.h | 2 +- csrc/gds/py_lib/deepspeed_py_gds_handle.cpp | 10 +++++----- csrc/gds/py_lib/deepspeed_py_gds_handle.h | 4 ++-- csrc/gds/py_lib/py_ds_gds.cpp | 2 +- 14 files changed, 34 insertions(+), 34 deletions(-) diff --git a/csrc/aio/py_lib/deepspeed_aio_op_desc.cpp b/csrc/aio/py_lib/deepspeed_aio_op_desc.cpp index 63d2ba022997..5abe8b41e1f3 100644 --- a/csrc/aio/py_lib/deepspeed_aio_op_desc.cpp +++ b/csrc/aio/py_lib/deepspeed_aio_op_desc.cpp @@ -13,7 +13,7 @@ io_op_desc_t::io_op_desc_t(const bool read_op, const int fd, const char* filename, const long long int file_num_bytes, - const int num_threads, + const int intra_op_parallelism, const bool validate) : _read_op(read_op), _buffer(buffer), @@ -21,8 +21,8 @@ io_op_desc_t::io_op_desc_t(const bool read_op, _fd(fd), _filename(filename), _file_num_bytes(file_num_bytes), - _num_threads(num_threads), - _num_bytes_per_thread(file_num_bytes / num_threads), + _intra_op_parallelism(intra_op_parallelism), + _num_bytes_per_thread(file_num_bytes / intra_op_parallelism), _validate(validate) { } diff --git a/csrc/aio/py_lib/deepspeed_aio_op_desc.h b/csrc/aio/py_lib/deepspeed_aio_op_desc.h index 42e55cd2e2a0..e3359708e08a 100644 --- a/csrc/aio/py_lib/deepspeed_aio_op_desc.h +++ b/csrc/aio/py_lib/deepspeed_aio_op_desc.h @@ -16,7 +16,7 @@ struct io_op_desc_t { int _fd; const std::string _filename; const long long int _file_num_bytes; - const int _num_threads; + const int _intra_op_parallelism; const long long int _num_bytes_per_thread; torch::Tensor _contiguous_buffer; const bool _validate; @@ -27,7 +27,7 @@ struct io_op_desc_t { const int fd, const char* filename, const long long int file_num_bytes, - const int num_threads, + const int intra_op_parallelism, const bool validate); virtual void run(const int tid, diff --git a/csrc/aio/py_lib/deepspeed_cpu_op.cpp b/csrc/aio/py_lib/deepspeed_cpu_op.cpp index 22b751eef625..c03b6aae780a 100644 --- a/csrc/aio/py_lib/deepspeed_cpu_op.cpp +++ b/csrc/aio/py_lib/deepspeed_cpu_op.cpp @@ -13,9 +13,9 @@ cpu_op_desc_t::cpu_op_desc_t(const bool read_op, const int fd, const char* filename, const long long int file_num_bytes, - const int num_threads, + const int intra_op_parallelism, const bool validate) - : io_op_desc_t(read_op, buffer, is_managed, fd, filename, file_num_bytes, num_threads, validate), + : io_op_desc_t(read_op, buffer, is_managed, fd, filename, file_num_bytes, intra_op_parallelism, validate), _cpu_buffer(buffer) { // Need to use CPU bounce buffer if buffer is not a page-locked DRAM memory. @@ -61,7 +61,7 @@ void cpu_op_desc_t::run(const int tid, std::unique_ptr& aio_ctxt, deepspeed_aio_config_t* aio_config) { - assert(tid < _num_threads); + assert(tid < _intra_op_parallelism); const auto base_offset = _num_bytes_per_thread * tid; std::unique_ptr xfer_ctxt( diff --git a/csrc/aio/py_lib/deepspeed_cpu_op.h b/csrc/aio/py_lib/deepspeed_cpu_op.h index efc60f97f93c..3dd8f7222b23 100644 --- a/csrc/aio/py_lib/deepspeed_cpu_op.h +++ b/csrc/aio/py_lib/deepspeed_cpu_op.h @@ -17,7 +17,7 @@ struct cpu_op_desc_t : io_op_desc_t { const int fd, const char* filename, const long long int file_num_bytes, - const int num_threads, + const int intra_op_parallelism, const bool validate); void run(const int tid, diff --git a/csrc/aio/py_lib/deepspeed_py_aio_handle.cpp b/csrc/aio/py_lib/deepspeed_py_aio_handle.cpp index c7ca5e82afde..aed87d0c694d 100644 --- a/csrc/aio/py_lib/deepspeed_py_aio_handle.cpp +++ b/csrc/aio/py_lib/deepspeed_py_aio_handle.cpp @@ -16,8 +16,8 @@ deepspeed_aio_handle_t::deepspeed_aio_handle_t(const int block_size, const int queue_depth, const bool single_submit, const bool overlap_events, - const int num_threads) - : deepspeed_io_handle_t(block_size, queue_depth, single_submit, overlap_events, num_threads) + const int intra_op_parallelism) + : deepspeed_io_handle_t(block_size, queue_depth, single_submit, overlap_events, intra_op_parallelism) { } diff --git a/csrc/aio/py_lib/deepspeed_py_aio_handle.h b/csrc/aio/py_lib/deepspeed_py_aio_handle.h index eb6b90ea22f0..1398df9a56c9 100644 --- a/csrc/aio/py_lib/deepspeed_py_aio_handle.h +++ b/csrc/aio/py_lib/deepspeed_py_aio_handle.h @@ -16,7 +16,7 @@ struct deepspeed_aio_handle_t : deepspeed_io_handle_t { const int queue_depth, const bool single_submit, const bool overlap_events, - const int num_threads); + const int intra_op_parallelism); ~deepspeed_aio_handle_t(); }; diff --git a/csrc/aio/py_lib/deepspeed_py_io_handle.cpp b/csrc/aio/py_lib/deepspeed_py_io_handle.cpp index e5e89419269d..1e896464d7c2 100644 --- a/csrc/aio/py_lib/deepspeed_py_io_handle.cpp +++ b/csrc/aio/py_lib/deepspeed_py_io_handle.cpp @@ -18,16 +18,16 @@ deepspeed_io_handle_t::deepspeed_io_handle_t(const int block_size, const int queue_depth, const bool single_submit, const bool overlap_events, - const int num_threads) + const int intra_op_parallelism) : _aio_ctxt(new aio_context(block_size, queue_depth)), _single_submit(single_submit), _overlap_events(overlap_events), - _num_threads(num_threads), + _intra_op_parallelism(intra_op_parallelism), _aio_config(block_size, queue_depth, single_submit, overlap_events, false), _num_pending_ops(0), _pinned_tensor_mgr(new deepspeed_pin_tensor_t()) { - for (auto i = 0; i < num_threads; ++i) { + for (auto i = 0; i < intra_op_parallelism; ++i) { _thread_contexts.push_back(std::make_shared(i, _aio_config)); } @@ -56,7 +56,7 @@ const bool deepspeed_io_handle_t::get_single_submit() const { return _single_sub const bool deepspeed_io_handle_t::get_overlap_events() const { return _overlap_events; } -const int deepspeed_io_handle_t::get_thread_count() const { return _num_threads; } +const int deepspeed_io_handle_t::get_thread_count() const { return _intra_op_parallelism; } int deepspeed_io_handle_t::read(torch::Tensor& buffer, const char* filename, const bool validate) { @@ -211,7 +211,7 @@ std::shared_ptr deepspeed_io_handle_t::_create_io_op_desc( { bool is_managed = _pinned_tensor_mgr->is_managed(buffer); return std::make_shared( - read_op, buffer, is_managed, fd, filename, file_num_bytes, _num_threads, validate); + read_op, buffer, is_managed, fd, filename, file_num_bytes, _intra_op_parallelism, validate); } int deepspeed_io_handle_t::pread(const torch::Tensor& buffer, @@ -231,7 +231,7 @@ int deepspeed_io_handle_t::pread(const torch::Tensor& buffer, << " != " << num_file_bytes << std::endl; } assert(buffer_bytes == num_file_bytes); - assert((num_file_bytes % _num_threads) == 0); + assert((num_file_bytes % _intra_op_parallelism) == 0); if (!_is_valid_parallel_aio_op(true, num_file_bytes)) { return -1; } @@ -253,7 +253,7 @@ int deepspeed_io_handle_t::pwrite(const torch::Tensor& buffer, const bool async) { const auto num_write_bytes = static_cast(buffer.nbytes()); - assert((num_write_bytes % _num_threads) == 0); + assert((num_write_bytes % _intra_op_parallelism) == 0); if (!_is_valid_parallel_aio_op(false, num_write_bytes)) { return -1; } diff --git a/csrc/aio/py_lib/deepspeed_py_io_handle.h b/csrc/aio/py_lib/deepspeed_py_io_handle.h index e21fd22fc3f7..9140d468d4f9 100644 --- a/csrc/aio/py_lib/deepspeed_py_io_handle.h +++ b/csrc/aio/py_lib/deepspeed_py_io_handle.h @@ -16,7 +16,7 @@ struct deepspeed_io_handle_t { std::unique_ptr _aio_ctxt; const bool _single_submit; const bool _overlap_events; - const int _num_threads; + const int _intra_op_parallelism; deepspeed_aio_config_t _aio_config; std::vector> _thread_contexts; @@ -28,7 +28,7 @@ struct deepspeed_io_handle_t { const int queue_depth, const bool single_submit, const bool overlap_events, - const int num_threads); + const int intra_op_parallelism); virtual ~deepspeed_io_handle_t() = 0; diff --git a/csrc/aio/py_lib/py_ds_aio.cpp b/csrc/aio/py_lib/py_ds_aio.cpp index 3171d0c6bf3c..1b0aa17d413f 100644 --- a/csrc/aio/py_lib/py_ds_aio.cpp +++ b/csrc/aio/py_lib/py_ds_aio.cpp @@ -27,7 +27,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) "queue_depth"_a = 128, "single_submit"_a = false, "overlap_events"_a = false, - "num_threads"_a = 1) + "intra_op_parallelism"_a = 1) .def("get_block_size", &deepspeed_aio_handle_t::get_block_size) .def("get_queue_depth", &deepspeed_aio_handle_t::get_queue_depth) diff --git a/csrc/gds/py_lib/deepspeed_gds_op.cpp b/csrc/gds/py_lib/deepspeed_gds_op.cpp index c0b3a335e268..ed0d6372a274 100644 --- a/csrc/gds/py_lib/deepspeed_gds_op.cpp +++ b/csrc/gds/py_lib/deepspeed_gds_op.cpp @@ -96,9 +96,9 @@ gds_op_desc_t::gds_op_desc_t(const bool read_op, const int fd, const char* filename, const long long int file_num_bytes, - const int num_threads, + const int intra_op_parallelism, const bool validate) - : io_op_desc_t(read_op, buffer,is_managed, fd, filename, file_num_bytes, num_threads, validate) + : io_op_desc_t(read_op, buffer,is_managed, fd, filename, file_num_bytes, intra_op_parallelism, validate) { _contiguous_buffer = _buffer.contiguous(); const int64_t device = _buffer.get_device(); @@ -124,7 +124,7 @@ void gds_op_desc_t::run(const int tid, std::unique_ptr& aio_ctxt, deepspeed_aio_config_t* aio_config) { - assert(tid < _num_threads); + assert(tid < _intra_op_parallelism); check_cudaruntimecall(cudaSetDevice(_buffer.get_device())); int64_t buf_offset = data_ptr() + (_num_bytes_per_thread * tid) - (char*)_base_ptr; const auto file_offset = _num_bytes_per_thread * tid; diff --git a/csrc/gds/py_lib/deepspeed_gds_op.h b/csrc/gds/py_lib/deepspeed_gds_op.h index 70c8f7ced4f5..5fd9540c21c0 100644 --- a/csrc/gds/py_lib/deepspeed_gds_op.h +++ b/csrc/gds/py_lib/deepspeed_gds_op.h @@ -24,7 +24,7 @@ struct gds_op_desc_t : io_op_desc_t { const int fd, const char* filename, const long long int file_num_bytes, - const int num_threads, + const int intra_op_parallelism, const bool validate); void run(const int tid, diff --git a/csrc/gds/py_lib/deepspeed_py_gds_handle.cpp b/csrc/gds/py_lib/deepspeed_py_gds_handle.cpp index 79d606b14b3f..b33014bb4a05 100644 --- a/csrc/gds/py_lib/deepspeed_py_gds_handle.cpp +++ b/csrc/gds/py_lib/deepspeed_py_gds_handle.cpp @@ -19,22 +19,22 @@ deepspeed_gds_handle_t::deepspeed_gds_handle_t(const int block_size, const int queue_depth, const bool single_submit, const bool overlap_events, - const int num_threads) + const int intra_op_parallelism) : deepspeed_io_handle_t(block_size, queue_depth, single_submit, overlap_events, 1), - _num_gpu_threads(num_threads) + _intra_gds_op_parallelism(intra_op_parallelism) { _init_cuFile(block_size,queue_depth); } deepspeed_gds_handle_t::~deepspeed_gds_handle_t() { _close_cuFile(); } -const int deepspeed_gds_handle_t::get_thread_count() const { return _num_gpu_threads; } +const int deepspeed_gds_handle_t::get_thread_count() const { return _intra_gds_op_parallelism; } void deepspeed_gds_handle_t::_init_cuFile(const int block_size, const int queue_depth) { if (deepspeed_gds_handle_t::s_cuFile_init == 0) { std::string depthStr = std::to_string(queue_depth); - std::string threadsStr = std::to_string(_num_gpu_threads); + std::string threadsStr = std::to_string(_intra_gds_op_parallelism); std::string json1 = R"({"execution": {"max_io_queue_depth": )" + depthStr + ", "; std::string json2 = R"("max_request_parallelism": )" + threadsStr + ", "; std::string json3 = R"("max_io_threads": )" + threadsStr + ", "; @@ -108,7 +108,7 @@ std::shared_ptr deepspeed_gds_handle_t::_create_io_op_desc( { if (buffer.is_cuda()) { return std::make_shared( - read_op, buffer, false, fd, filename, file_num_bytes, _num_threads, validate); + read_op, buffer, false, fd, filename, file_num_bytes, _intra_op_parallelism, validate); } return deepspeed_io_handle_t::_create_io_op_desc( read_op, buffer, fd, filename, file_num_bytes, validate); diff --git a/csrc/gds/py_lib/deepspeed_py_gds_handle.h b/csrc/gds/py_lib/deepspeed_py_gds_handle.h index 3cf49a4db453..8e8e8df3d6fe 100644 --- a/csrc/gds/py_lib/deepspeed_py_gds_handle.h +++ b/csrc/gds/py_lib/deepspeed_py_gds_handle.h @@ -12,13 +12,13 @@ Functionality for swapping optimizer tensors to/from (NVMe) storage devices. #include "deepspeed_py_io_handle.h" struct deepspeed_gds_handle_t : deepspeed_io_handle_t { - const int _num_gpu_threads; + const int _intra_gds_op_parallelism; deepspeed_gds_handle_t(const int block_size, const int queue_depth, const bool single_submit, const bool overlap_events, - const int num_threads); + const int intra_op_parallelism); ~deepspeed_gds_handle_t(); diff --git a/csrc/gds/py_lib/py_ds_gds.cpp b/csrc/gds/py_lib/py_ds_gds.cpp index 66eb34d4ea8c..14e3eec3fbbb 100644 --- a/csrc/gds/py_lib/py_ds_gds.cpp +++ b/csrc/gds/py_lib/py_ds_gds.cpp @@ -20,7 +20,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) "queue_depth"_a = 128, "single_submit"_a = false, "overlap_events"_a = false, - "num_threads"_a = 1) + "intra_op_parallelism"_a = 1) .def("get_block_size", &deepspeed_gds_handle_t::get_block_size) .def("get_queue_depth", &deepspeed_gds_handle_t::get_queue_depth) From ada1b8303b839b70c4f14ca5155cd1b1ba7427a9 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 2 Oct 2024 20:44:01 +0000 Subject: [PATCH 05/13] updating function names to match --- csrc/aio/py_lib/deepspeed_aio_thread.cpp | 4 ++-- csrc/aio/py_lib/deepspeed_aio_thread.h | 4 ++-- csrc/aio/py_lib/deepspeed_py_aio_handle.h | 2 +- csrc/aio/py_lib/deepspeed_py_io_handle.cpp | 12 ++++++------ csrc/aio/py_lib/deepspeed_py_io_handle.h | 4 ++-- csrc/aio/py_lib/py_ds_aio.cpp | 2 +- csrc/gds/py_lib/deepspeed_py_gds_handle.cpp | 2 +- csrc/gds/py_lib/deepspeed_py_gds_handle.h | 2 +- csrc/gds/py_lib/py_ds_gds.cpp | 2 +- docs/_tutorials/deepnvme.md | 4 ++-- 10 files changed, 19 insertions(+), 19 deletions(-) diff --git a/csrc/aio/py_lib/deepspeed_aio_thread.cpp b/csrc/aio/py_lib/deepspeed_aio_thread.cpp index 30c3b4914397..25e1df809a85 100644 --- a/csrc/aio/py_lib/deepspeed_aio_thread.cpp +++ b/csrc/aio/py_lib/deepspeed_aio_thread.cpp @@ -28,7 +28,7 @@ void deepspeed_aio_thread_t::run() { std::unique_lock lock(_work_sync._mutex); - _work_sync._cond_var.wait(lock, + _work_sync._cond_var2.wait(lock, [this] { return (!_work_queue.empty() || _time_to_exit); }); if (!_work_queue.empty()) { next_io_op = _work_queue.front(); @@ -43,7 +43,7 @@ void deepspeed_aio_thread_t::run() std::lock_guard lock(_complete_sync._mutex); _complete_queue.push(next_io_op); } - _complete_sync._cond_var.notify_one(); + _complete_sync._cond_var2.notify_one(); } if (_time_to_exit) { break; } diff --git a/csrc/aio/py_lib/deepspeed_aio_thread.h b/csrc/aio/py_lib/deepspeed_aio_thread.h index a192804db13d..ef12b8178fa7 100644 --- a/csrc/aio/py_lib/deepspeed_aio_thread.h +++ b/csrc/aio/py_lib/deepspeed_aio_thread.h @@ -7,14 +7,14 @@ Functionality for swapping optimizer tensors to/from (NVMe) storage devices. */ -#include +#include #include #include #include "deepspeed_cpu_op.h" struct thread_sync_t { std::mutex _mutex; - std::condition_variable _cond_var; + std::condition_var2iable _cond_var2; }; struct deepspeed_aio_thread_t { diff --git a/csrc/aio/py_lib/deepspeed_py_aio_handle.h b/csrc/aio/py_lib/deepspeed_py_aio_handle.h index 1398df9a56c9..5af6736afc0f 100644 --- a/csrc/aio/py_lib/deepspeed_py_aio_handle.h +++ b/csrc/aio/py_lib/deepspeed_py_aio_handle.h @@ -7,7 +7,7 @@ Functionality for swapping optimizer tensors to/from (NVMe) storage devices. */ -#include +#include #include #include "deepspeed_py_io_handle.h" diff --git a/csrc/aio/py_lib/deepspeed_py_io_handle.cpp b/csrc/aio/py_lib/deepspeed_py_io_handle.cpp index 1e896464d7c2..12243c3d902c 100644 --- a/csrc/aio/py_lib/deepspeed_py_io_handle.cpp +++ b/csrc/aio/py_lib/deepspeed_py_io_handle.cpp @@ -56,7 +56,7 @@ const bool deepspeed_io_handle_t::get_single_submit() const { return _single_sub const bool deepspeed_io_handle_t::get_overlap_events() const { return _overlap_events; } -const int deepspeed_io_handle_t::get_thread_count() const { return _intra_op_parallelism; } +const int deepspeed_io_handle_t::get_intra_op_parallelism() const { return _intra_op_parallelism; } int deepspeed_io_handle_t::read(torch::Tensor& buffer, const char* filename, const bool validate) { @@ -137,7 +137,7 @@ void deepspeed_io_handle_t::_schedule_aio_work(std::shared_ptr lock(ctxt->_work_sync._mutex); ctxt->_work_queue.push(scheduled_op); } - ctxt->_work_sync._cond_var.notify_one(); + ctxt->_work_sync._cond_var2.notify_one(); } _num_pending_ops++; } @@ -147,7 +147,7 @@ std::shared_ptr deepspeed_io_handle_t::_wait_for_aio_work() std::shared_ptr completed_op = nullptr; for (auto& ctxt : _thread_contexts) { std::unique_lock lock(ctxt->_complete_sync._mutex); - ctxt->_complete_sync._cond_var.wait(lock, + ctxt->_complete_sync._cond_var2.wait(lock, [ctxt] { return !ctxt->_complete_queue.empty(); }); completed_op = ctxt->_complete_queue.front(); ctxt->_complete_queue.pop(); @@ -163,7 +163,7 @@ void deepspeed_io_handle_t::_stop_threads() std::lock_guard lock(ctxt->_work_sync._mutex); ctxt->_time_to_exit = true; } - ctxt->_work_sync._cond_var.notify_one(); + ctxt->_work_sync._cond_var2.notify_one(); } } @@ -192,9 +192,9 @@ bool deepspeed_io_handle_t::_is_valid_parallel_aio_op(const bool read_op, const long long int num_bytes) { const auto op_string = read_op ? "Read" : "Write"; - if (num_bytes % get_thread_count()) { + if (num_bytes % get_intra_op_parallelism()) { std::cout << "deepspeed_aio failure: parallel " << op_string << " num_bytes = " << num_bytes - << " not divisible by thread count = " << get_thread_count() << std::endl; + << " not divisible by thread count = " << get_intra_op_parallelism() << std::endl; return false; } diff --git a/csrc/aio/py_lib/deepspeed_py_io_handle.h b/csrc/aio/py_lib/deepspeed_py_io_handle.h index 9140d468d4f9..3e3bed73b229 100644 --- a/csrc/aio/py_lib/deepspeed_py_io_handle.h +++ b/csrc/aio/py_lib/deepspeed_py_io_handle.h @@ -7,7 +7,7 @@ Functionality for swapping optimizer tensors to/from (NVMe) storage devices. */ -#include +#include #include #include "deepspeed_aio_thread.h" #include "deepspeed_pin_tensor.h" @@ -36,7 +36,7 @@ struct deepspeed_io_handle_t { const int get_queue_depth() const; const bool get_single_submit() const; const bool get_overlap_events() const; - const int get_thread_count() const; + const int get_intra_op_parallelism() const; int read(torch::Tensor& buffer, const char* filename, const bool validate); diff --git a/csrc/aio/py_lib/py_ds_aio.cpp b/csrc/aio/py_lib/py_ds_aio.cpp index 1b0aa17d413f..b80fa2d6c8e6 100644 --- a/csrc/aio/py_lib/py_ds_aio.cpp +++ b/csrc/aio/py_lib/py_ds_aio.cpp @@ -33,7 +33,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) .def("get_queue_depth", &deepspeed_aio_handle_t::get_queue_depth) .def("get_single_submit", &deepspeed_aio_handle_t::get_single_submit) .def("get_overlap_events", &deepspeed_aio_handle_t::get_overlap_events) - .def("get_thread_count", &deepspeed_aio_handle_t::get_thread_count) + .def("get_intra_op_parallelism", &deepspeed_aio_handle_t::get_intra_op_parallelism) .def("read", &deepspeed_aio_handle_t::read, diff --git a/csrc/gds/py_lib/deepspeed_py_gds_handle.cpp b/csrc/gds/py_lib/deepspeed_py_gds_handle.cpp index b33014bb4a05..02bd5c990661 100644 --- a/csrc/gds/py_lib/deepspeed_py_gds_handle.cpp +++ b/csrc/gds/py_lib/deepspeed_py_gds_handle.cpp @@ -28,7 +28,7 @@ deepspeed_gds_handle_t::deepspeed_gds_handle_t(const int block_size, deepspeed_gds_handle_t::~deepspeed_gds_handle_t() { _close_cuFile(); } -const int deepspeed_gds_handle_t::get_thread_count() const { return _intra_gds_op_parallelism; } +const int deepspeed_gds_handle_t::get_intra_op_parallelism() const { return _intra_gds_op_parallelism; } void deepspeed_gds_handle_t::_init_cuFile(const int block_size, const int queue_depth) { diff --git a/csrc/gds/py_lib/deepspeed_py_gds_handle.h b/csrc/gds/py_lib/deepspeed_py_gds_handle.h index 8e8e8df3d6fe..a3c10a4f6467 100644 --- a/csrc/gds/py_lib/deepspeed_py_gds_handle.h +++ b/csrc/gds/py_lib/deepspeed_py_gds_handle.h @@ -35,7 +35,7 @@ struct deepspeed_gds_handle_t : deepspeed_io_handle_t { void _close_cuFile(); - const int get_thread_count() const; + const int get_intra_op_parallelism() const; std::shared_ptr _create_io_op_desc(const bool read_op, const torch::Tensor& buffer, diff --git a/csrc/gds/py_lib/py_ds_gds.cpp b/csrc/gds/py_lib/py_ds_gds.cpp index 14e3eec3fbbb..57bf8d2207c4 100644 --- a/csrc/gds/py_lib/py_ds_gds.cpp +++ b/csrc/gds/py_lib/py_ds_gds.cpp @@ -26,7 +26,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) .def("get_queue_depth", &deepspeed_gds_handle_t::get_queue_depth) .def("get_single_submit", &deepspeed_gds_handle_t::get_single_submit) .def("get_overlap_events", &deepspeed_gds_handle_t::get_overlap_events) - .def("get_thread_count", &deepspeed_gds_handle_t::get_thread_count) + .def("get_intra_op_parallelism", &deepspeed_gds_handle_t::get_intra_op_parallelism) .def("read", &deepspeed_gds_handle_t::read, diff --git a/docs/_tutorials/deepnvme.md b/docs/_tutorials/deepnvme.md index 70c6ac097963..f31621999a59 100644 --- a/docs/_tutorials/deepnvme.md +++ b/docs/_tutorials/deepnvme.md @@ -50,7 +50,7 @@ Type "help", "copyright", "credits" or "license" for more information. >>> h = AsyncIOBuilder().load().aio_handle() >>> h. h.async_pread( h.free_cpu_locked_tensor( h.get_overlap_events( h.get_single_submit( h.new_cpu_locked_tensor( h.pwrite( h.sync_pread( h.wait( -h.async_pwrite( h.get_block_size( h.get_queue_depth( h.get_thread_count( h.pread( h.read( h.sync_pwrite( h.write( +h.async_pwrite( h.get_block_size( h.get_queue_depth( h.get_intra_op_parallelism( h.pread( h.read( h.sync_pwrite( h.write( ``` The APIs of interest for performing I/O operations are those named with `pread` and `pwrite` substrings. For brevity, we will focus on the file write APIs, namely `sync_pwrite`, `async_pwrite`, and `pwrite`. We will discuss only `sync_pwrite` and `async_pwrite` below because they are specializations of `pwrite`. @@ -292,6 +292,6 @@ Function | Description |---|---| get_queue_depth | Return queue depth setting | get_single_submit | Return whether single_submit is enabled | -get_thread_count | Return I/O parallelism degree | +get_intra_op_parallelism | Return I/O parallelism degree | get_block_size | Return I/O block size setting | get_overlap_events | Return whether overlap_event is enabled | From 1cb88ce5e4b47ec7cde22ab982730527857b82d5 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 2 Oct 2024 20:45:09 +0000 Subject: [PATCH 06/13] fix formatting --- csrc/aio/py_lib/deepspeed_aio_thread.cpp | 2 +- csrc/aio/py_lib/deepspeed_cpu_op.cpp | 9 ++++++++- csrc/aio/py_lib/deepspeed_pin_tensor.cpp | 9 ++++----- csrc/aio/py_lib/deepspeed_pin_tensor.h | 2 +- csrc/aio/py_lib/deepspeed_py_aio_handle.cpp | 6 +++++- csrc/aio/py_lib/deepspeed_py_io_handle.cpp | 2 +- csrc/aio/py_lib/deepspeed_py_io_handle.h | 3 ++- csrc/gds/py_lib/deepspeed_gds_op.cpp | 9 ++++++++- csrc/gds/py_lib/deepspeed_py_gds_handle.cpp | 9 ++++++--- tests/unit/ops/aio/test_aio.py | 2 +- 10 files changed, 37 insertions(+), 16 deletions(-) diff --git a/csrc/aio/py_lib/deepspeed_aio_thread.cpp b/csrc/aio/py_lib/deepspeed_aio_thread.cpp index 25e1df809a85..8c51087a0b5d 100644 --- a/csrc/aio/py_lib/deepspeed_aio_thread.cpp +++ b/csrc/aio/py_lib/deepspeed_aio_thread.cpp @@ -29,7 +29,7 @@ void deepspeed_aio_thread_t::run() { std::unique_lock lock(_work_sync._mutex); _work_sync._cond_var2.wait(lock, - [this] { return (!_work_queue.empty() || _time_to_exit); }); + [this] { return (!_work_queue.empty() || _time_to_exit); }); if (!_work_queue.empty()) { next_io_op = _work_queue.front(); _work_queue.pop(); diff --git a/csrc/aio/py_lib/deepspeed_cpu_op.cpp b/csrc/aio/py_lib/deepspeed_cpu_op.cpp index c03b6aae780a..170d6da75987 100644 --- a/csrc/aio/py_lib/deepspeed_cpu_op.cpp +++ b/csrc/aio/py_lib/deepspeed_cpu_op.cpp @@ -15,7 +15,14 @@ cpu_op_desc_t::cpu_op_desc_t(const bool read_op, const long long int file_num_bytes, const int intra_op_parallelism, const bool validate) - : io_op_desc_t(read_op, buffer, is_managed, fd, filename, file_num_bytes, intra_op_parallelism, validate), + : io_op_desc_t(read_op, + buffer, + is_managed, + fd, + filename, + file_num_bytes, + intra_op_parallelism, + validate), _cpu_buffer(buffer) { // Need to use CPU bounce buffer if buffer is not a page-locked DRAM memory. diff --git a/csrc/aio/py_lib/deepspeed_pin_tensor.cpp b/csrc/aio/py_lib/deepspeed_pin_tensor.cpp index f57b4394fc99..8d337c14d34c 100644 --- a/csrc/aio/py_lib/deepspeed_pin_tensor.cpp +++ b/csrc/aio/py_lib/deepspeed_pin_tensor.cpp @@ -19,7 +19,8 @@ deepspeed_pin_tensor_t::~deepspeed_pin_tensor_t() _locked_tensors.clear(); } -torch::Tensor deepspeed_pin_tensor_t::alloc(const long long int num_elem, const at::ScalarType& elem_type) +torch::Tensor deepspeed_pin_tensor_t::alloc(const long long int num_elem, + const at::ScalarType& elem_type) { const auto num_bytes = num_elem * elementSize(elem_type); auto pinned_buffer = ds_page_aligned_alloc(num_bytes, true); @@ -47,9 +48,7 @@ bool deepspeed_pin_tensor_t::free(torch::Tensor& locked_tensor) bool deepspeed_pin_tensor_t::is_managed(const torch::Tensor& buffer) { auto addr = buffer.data_ptr(); - if (!buffer.is_cpu()){ return false;} - if (_locked_tensors.find(addr) != _locked_tensors.end()) { - return true; - } + if (!buffer.is_cpu()) { return false; } + if (_locked_tensors.find(addr) != _locked_tensors.end()) { return true; } return false; }; diff --git a/csrc/aio/py_lib/deepspeed_pin_tensor.h b/csrc/aio/py_lib/deepspeed_pin_tensor.h index 195696a05833..95d2f1f92e10 100644 --- a/csrc/aio/py_lib/deepspeed_pin_tensor.h +++ b/csrc/aio/py_lib/deepspeed_pin_tensor.h @@ -24,6 +24,6 @@ struct deepspeed_pin_tensor_t { torch::Tensor alloc(const long long num_elem, const at::ScalarType& elem_type); bool free(torch::Tensor& locked_tensor); - + bool is_managed(const torch::Tensor& buffer); }; diff --git a/csrc/aio/py_lib/deepspeed_py_aio_handle.cpp b/csrc/aio/py_lib/deepspeed_py_aio_handle.cpp index aed87d0c694d..2b1093e99286 100644 --- a/csrc/aio/py_lib/deepspeed_py_aio_handle.cpp +++ b/csrc/aio/py_lib/deepspeed_py_aio_handle.cpp @@ -17,7 +17,11 @@ deepspeed_aio_handle_t::deepspeed_aio_handle_t(const int block_size, const bool single_submit, const bool overlap_events, const int intra_op_parallelism) - : deepspeed_io_handle_t(block_size, queue_depth, single_submit, overlap_events, intra_op_parallelism) + : deepspeed_io_handle_t(block_size, + queue_depth, + single_submit, + overlap_events, + intra_op_parallelism) { } diff --git a/csrc/aio/py_lib/deepspeed_py_io_handle.cpp b/csrc/aio/py_lib/deepspeed_py_io_handle.cpp index 12243c3d902c..31083e36ed8c 100644 --- a/csrc/aio/py_lib/deepspeed_py_io_handle.cpp +++ b/csrc/aio/py_lib/deepspeed_py_io_handle.cpp @@ -148,7 +148,7 @@ std::shared_ptr deepspeed_io_handle_t::_wait_for_aio_work() for (auto& ctxt : _thread_contexts) { std::unique_lock lock(ctxt->_complete_sync._mutex); ctxt->_complete_sync._cond_var2.wait(lock, - [ctxt] { return !ctxt->_complete_queue.empty(); }); + [ctxt] { return !ctxt->_complete_queue.empty(); }); completed_op = ctxt->_complete_queue.front(); ctxt->_complete_queue.pop(); } diff --git a/csrc/aio/py_lib/deepspeed_py_io_handle.h b/csrc/aio/py_lib/deepspeed_py_io_handle.h index 3e3bed73b229..3016da59d86d 100644 --- a/csrc/aio/py_lib/deepspeed_py_io_handle.h +++ b/csrc/aio/py_lib/deepspeed_py_io_handle.h @@ -61,7 +61,8 @@ struct deepspeed_io_handle_t { int async_pwrite(const torch::Tensor& buffer, const char* filename); // TODO: Make API's args to be shape and dtype. - torch::Tensor new_cpu_locked_tensor(const long long int num_elem, const torch::Tensor& example_tensor); + torch::Tensor new_cpu_locked_tensor(const long long int num_elem, + const torch::Tensor& example_tensor); bool free_cpu_locked_tensor(torch::Tensor&); diff --git a/csrc/gds/py_lib/deepspeed_gds_op.cpp b/csrc/gds/py_lib/deepspeed_gds_op.cpp index ed0d6372a274..c44b4655a9a9 100644 --- a/csrc/gds/py_lib/deepspeed_gds_op.cpp +++ b/csrc/gds/py_lib/deepspeed_gds_op.cpp @@ -98,7 +98,14 @@ gds_op_desc_t::gds_op_desc_t(const bool read_op, const long long int file_num_bytes, const int intra_op_parallelism, const bool validate) - : io_op_desc_t(read_op, buffer,is_managed, fd, filename, file_num_bytes, intra_op_parallelism, validate) + : io_op_desc_t(read_op, + buffer, + is_managed, + fd, + filename, + file_num_bytes, + intra_op_parallelism, + validate) { _contiguous_buffer = _buffer.contiguous(); const int64_t device = _buffer.get_device(); diff --git a/csrc/gds/py_lib/deepspeed_py_gds_handle.cpp b/csrc/gds/py_lib/deepspeed_py_gds_handle.cpp index 02bd5c990661..d093187597d4 100644 --- a/csrc/gds/py_lib/deepspeed_py_gds_handle.cpp +++ b/csrc/gds/py_lib/deepspeed_py_gds_handle.cpp @@ -21,14 +21,17 @@ deepspeed_gds_handle_t::deepspeed_gds_handle_t(const int block_size, const bool overlap_events, const int intra_op_parallelism) : deepspeed_io_handle_t(block_size, queue_depth, single_submit, overlap_events, 1), - _intra_gds_op_parallelism(intra_op_parallelism) + _intra_gds_op_parallelism(intra_op_parallelism) { - _init_cuFile(block_size,queue_depth); + _init_cuFile(block_size, queue_depth); } deepspeed_gds_handle_t::~deepspeed_gds_handle_t() { _close_cuFile(); } -const int deepspeed_gds_handle_t::get_intra_op_parallelism() const { return _intra_gds_op_parallelism; } +const int deepspeed_gds_handle_t::get_intra_op_parallelism() const +{ + return _intra_gds_op_parallelism; +} void deepspeed_gds_handle_t::_init_cuFile(const int block_size, const int queue_depth) { diff --git a/tests/unit/ops/aio/test_aio.py b/tests/unit/ops/aio/test_aio.py index 9d4b12a6daf7..f6d175ce67bc 100644 --- a/tests/unit/ops/aio/test_aio.py +++ b/tests/unit/ops/aio/test_aio.py @@ -144,7 +144,7 @@ def test_async_read(self, tmpdir, use_cuda_pinned_tensor, single_submit, overlap h.free_cpu_locked_tensor(aio_buffer) -@pytest.mark.parametrize("use_cuda_pinned_tensor", [True, False]) +@pytest.mark.parametrize("use_cuda_pinned_tensor", [True, False]) @pytest.mark.parametrize("single_submit", [True, False]) @pytest.mark.parametrize("overlap_events", [True, False]) class TestWrite(DistributedTest): From f5528daf5147b2dcb56e77a38e08d6068c7b8d2f Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 2 Oct 2024 21:53:47 +0000 Subject: [PATCH 07/13] variable name change to fix compilation --- csrc/aio/py_lib/deepspeed_aio_thread.cpp | 4 ++-- csrc/aio/py_lib/deepspeed_aio_thread.h | 4 ++-- csrc/aio/py_lib/deepspeed_py_aio_handle.h | 2 +- csrc/aio/py_lib/deepspeed_py_io_handle.cpp | 6 +++--- csrc/aio/py_lib/deepspeed_py_io_handle.h | 2 +- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/csrc/aio/py_lib/deepspeed_aio_thread.cpp b/csrc/aio/py_lib/deepspeed_aio_thread.cpp index 8c51087a0b5d..3b7b048f16bd 100644 --- a/csrc/aio/py_lib/deepspeed_aio_thread.cpp +++ b/csrc/aio/py_lib/deepspeed_aio_thread.cpp @@ -28,7 +28,7 @@ void deepspeed_aio_thread_t::run() { std::unique_lock lock(_work_sync._mutex); - _work_sync._cond_var2.wait(lock, + _work_sync._cond_var.wait(lock, [this] { return (!_work_queue.empty() || _time_to_exit); }); if (!_work_queue.empty()) { next_io_op = _work_queue.front(); @@ -43,7 +43,7 @@ void deepspeed_aio_thread_t::run() std::lock_guard lock(_complete_sync._mutex); _complete_queue.push(next_io_op); } - _complete_sync._cond_var2.notify_one(); + _complete_sync._cond_var.notify_one(); } if (_time_to_exit) { break; } diff --git a/csrc/aio/py_lib/deepspeed_aio_thread.h b/csrc/aio/py_lib/deepspeed_aio_thread.h index ef12b8178fa7..a192804db13d 100644 --- a/csrc/aio/py_lib/deepspeed_aio_thread.h +++ b/csrc/aio/py_lib/deepspeed_aio_thread.h @@ -7,14 +7,14 @@ Functionality for swapping optimizer tensors to/from (NVMe) storage devices. */ -#include +#include #include #include #include "deepspeed_cpu_op.h" struct thread_sync_t { std::mutex _mutex; - std::condition_var2iable _cond_var2; + std::condition_variable _cond_var; }; struct deepspeed_aio_thread_t { diff --git a/csrc/aio/py_lib/deepspeed_py_aio_handle.h b/csrc/aio/py_lib/deepspeed_py_aio_handle.h index 5af6736afc0f..1398df9a56c9 100644 --- a/csrc/aio/py_lib/deepspeed_py_aio_handle.h +++ b/csrc/aio/py_lib/deepspeed_py_aio_handle.h @@ -7,7 +7,7 @@ Functionality for swapping optimizer tensors to/from (NVMe) storage devices. */ -#include +#include #include #include "deepspeed_py_io_handle.h" diff --git a/csrc/aio/py_lib/deepspeed_py_io_handle.cpp b/csrc/aio/py_lib/deepspeed_py_io_handle.cpp index 31083e36ed8c..8bd7966ba086 100644 --- a/csrc/aio/py_lib/deepspeed_py_io_handle.cpp +++ b/csrc/aio/py_lib/deepspeed_py_io_handle.cpp @@ -137,7 +137,7 @@ void deepspeed_io_handle_t::_schedule_aio_work(std::shared_ptr lock(ctxt->_work_sync._mutex); ctxt->_work_queue.push(scheduled_op); } - ctxt->_work_sync._cond_var2.notify_one(); + ctxt->_work_sync._cond_var.notify_one(); } _num_pending_ops++; } @@ -147,7 +147,7 @@ std::shared_ptr deepspeed_io_handle_t::_wait_for_aio_work() std::shared_ptr completed_op = nullptr; for (auto& ctxt : _thread_contexts) { std::unique_lock lock(ctxt->_complete_sync._mutex); - ctxt->_complete_sync._cond_var2.wait(lock, + ctxt->_complete_sync._cond_var.wait(lock, [ctxt] { return !ctxt->_complete_queue.empty(); }); completed_op = ctxt->_complete_queue.front(); ctxt->_complete_queue.pop(); @@ -163,7 +163,7 @@ void deepspeed_io_handle_t::_stop_threads() std::lock_guard lock(ctxt->_work_sync._mutex); ctxt->_time_to_exit = true; } - ctxt->_work_sync._cond_var2.notify_one(); + ctxt->_work_sync._cond_var.notify_one(); } } diff --git a/csrc/aio/py_lib/deepspeed_py_io_handle.h b/csrc/aio/py_lib/deepspeed_py_io_handle.h index 3016da59d86d..8e649be1e4c2 100644 --- a/csrc/aio/py_lib/deepspeed_py_io_handle.h +++ b/csrc/aio/py_lib/deepspeed_py_io_handle.h @@ -7,7 +7,7 @@ Functionality for swapping optimizer tensors to/from (NVMe) storage devices. */ -#include +#include #include #include "deepspeed_aio_thread.h" #include "deepspeed_pin_tensor.h" From f576d291a7675df1b85e4201b58ea1013b38938c Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 2 Oct 2024 22:02:10 +0000 Subject: [PATCH 08/13] formatting --- csrc/aio/py_lib/deepspeed_aio_thread.cpp | 2 +- csrc/aio/py_lib/deepspeed_py_io_handle.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/aio/py_lib/deepspeed_aio_thread.cpp b/csrc/aio/py_lib/deepspeed_aio_thread.cpp index 3b7b048f16bd..30c3b4914397 100644 --- a/csrc/aio/py_lib/deepspeed_aio_thread.cpp +++ b/csrc/aio/py_lib/deepspeed_aio_thread.cpp @@ -29,7 +29,7 @@ void deepspeed_aio_thread_t::run() { std::unique_lock lock(_work_sync._mutex); _work_sync._cond_var.wait(lock, - [this] { return (!_work_queue.empty() || _time_to_exit); }); + [this] { return (!_work_queue.empty() || _time_to_exit); }); if (!_work_queue.empty()) { next_io_op = _work_queue.front(); _work_queue.pop(); diff --git a/csrc/aio/py_lib/deepspeed_py_io_handle.cpp b/csrc/aio/py_lib/deepspeed_py_io_handle.cpp index 8bd7966ba086..596c2427feed 100644 --- a/csrc/aio/py_lib/deepspeed_py_io_handle.cpp +++ b/csrc/aio/py_lib/deepspeed_py_io_handle.cpp @@ -148,7 +148,7 @@ std::shared_ptr deepspeed_io_handle_t::_wait_for_aio_work() for (auto& ctxt : _thread_contexts) { std::unique_lock lock(ctxt->_complete_sync._mutex); ctxt->_complete_sync._cond_var.wait(lock, - [ctxt] { return !ctxt->_complete_queue.empty(); }); + [ctxt] { return !ctxt->_complete_queue.empty(); }); completed_op = ctxt->_complete_queue.front(); ctxt->_complete_queue.pop(); } From 5a47bf3a5a932ba9fac38401ae374d456c5c479c Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 3 Oct 2024 16:03:41 +0000 Subject: [PATCH 09/13] update references in tutorial --- docs/_tutorials/deepnvme.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/_tutorials/deepnvme.md b/docs/_tutorials/deepnvme.md index f31621999a59..4ed528412eae 100644 --- a/docs/_tutorials/deepnvme.md +++ b/docs/_tutorials/deepnvme.md @@ -107,7 +107,7 @@ Similar safety problems apply to reading the destination tensor of a non-blockin ### Parallel File Write -An important DeepNVMe optimization is the ability to parallelize individual I/O operations. This optimization is enabled by specifying the desired parallelism degree when constructing a DeepNVMe handle. Subsequent I/O operations with that handle are automatically parallelized over the requested number of host or device threads, as appropriate. I/O parallelism is composable with either the blocking or non-blocking I/O APIs. The example below illustrates 4-way parallelism of a file write using `async_pwrite`. Note the use of `num_threads` argument to specify the desired parallelism degree in handle creation. +An important DeepNVMe optimization is the ability to parallelize individual I/O operations. This optimization is enabled by specifying the desired parallelism degree when constructing a DeepNVMe handle. Subsequent I/O operations with that handle are automatically parallelized over the requested number of host or device threads, as appropriate. I/O parallelism is composable with either the blocking or non-blocking I/O APIs. The example below illustrates 4-way parallelism of a file write using `async_pwrite`. Note the use of `intra_op_parallelism` argument to specify the desired parallelism degree in handle creation. ```bash >>> import os @@ -116,7 +116,7 @@ False >>> import torch >>> t=torch.empty(1024**3, dtype=torch.uint8).cuda() >>> from deepspeed.ops.op_builder import AsyncIOBuilder ->>> h = AsyncIOBuilder().load().aio_handle(num_threads=4) +>>> h = AsyncIOBuilder().load().aio_handle(intra_op_parallelism=4) >>> h.async_pwrite(t,'/local_nvme/test_1GB.pt') >>> h.wait() 1 @@ -188,7 +188,7 @@ This tutorial has been significantly improved by feedback from [Guanhua Wang](ht ## Appendix ### Advanced Handle Creation -Achieving peak I/O performance with DeepNVMe requires careful configuration of handle creation. In particular, the parameters of `aio_handle` and `gds_handle` constructors are performance-critical because they determine how efficiently DeepNVMe interacts with the underlying storage subsystem (i.e., `libaio`, GDS, PCIe, and SSD). For convenience we make it possible to create handles using default parameter values which will provide decent performance in most scenarios. However, squeezing out every available performance in your environment will likely require tuning the constructor parameters, namely `block_size`, `queue_depth`, `single_submit`, `overlap_events`, and `num_threads`. The `aio_handle` constructor parameters and default values are illustrated below: +Achieving peak I/O performance with DeepNVMe requires careful configuration of handle creation. In particular, the parameters of `aio_handle` and `gds_handle` constructors are performance-critical because they determine how efficiently DeepNVMe interacts with the underlying storage subsystem (i.e., `libaio`, GDS, PCIe, and SSD). For convenience we make it possible to create handles using default parameter values which will provide decent performance in most scenarios. However, squeezing out every available performance in your environment will likely require tuning the constructor parameters, namely `block_size`, `queue_depth`, `single_submit`, `overlap_events`, and `intra_op_parallelism`. The `aio_handle` constructor parameters and default values are illustrated below: ```bash >>> from deepspeed.ops.op_builder import AsyncIOBuilder >>> help(AsyncIOBuilder().load().aio_handle()) @@ -203,7 +203,7 @@ class aio_handle(pybind11_builtins.pybind11_object) | Methods defined here: | | __init__(...) - | __init__(self: async_io.aio_handle, block_size: int = 1048576, queue_depth: int = 128, single_submit: bool = False, overlap_events: bool = False, num_threads: int = 1) -> None + | __init__(self: async_io.aio_handle, block_size: int = 1048576, queue_depth: int = 128, single_submit: bool = False, overlap_events: bool = False, intra_op_parallelism: int = 1) -> None | | AIO handle constructor ``` @@ -219,7 +219,7 @@ Best performance (GB/sec): read = 3.69, write = 3.18 "aio": { "single_submit": "false", "overlap_events": "true", - "num_threads": 8, + "intra_op_parallelism": 8, "queue_depth": 32, "block_size": 1048576 } @@ -233,7 +233,7 @@ The above tuning was executed on a Lambda workstation equipped with two NVIDIA A queue_depth=32, single_submit=False, overlap_events=True, - num_threads=8) + intra_op_parallelism=8) ``` From 884c0fd6ded5d936f8491889937d02b7c10488dc Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Tue, 8 Oct 2024 21:30:58 -0400 Subject: [PATCH 10/13] async_io operator for CPU accelerator --- accelerator/cpu_accelerator.py | 6 +- csrc/aio/py_lib/deepspeed_aio_op_desc.cpp | 2 - csrc/aio/py_lib/deepspeed_aio_op_desc.h | 2 - csrc/aio/py_lib/deepspeed_cpu_op.cpp | 87 +++++++++++-------- csrc/aio/py_lib/deepspeed_cpu_op.h | 8 +- csrc/aio/py_lib/deepspeed_pin_tensor.cpp | 18 ++-- csrc/aio/py_lib/deepspeed_pin_tensor.h | 1 + csrc/aio/py_lib/deepspeed_py_io_handle.cpp | 11 ++- csrc/gds/py_lib/deepspeed_gds_op.cpp | 10 +-- csrc/gds/py_lib/deepspeed_gds_op.h | 1 - csrc/gds/py_lib/deepspeed_py_gds_handle.cpp | 2 +- op_builder/cpu/__init__.py | 1 + op_builder/cpu/async_io.py | 92 +++++++++++++++++++++ tests/unit/ops/aio/test_aio.py | 78 ++++++++--------- tests/unit/ops/aio/test_gds.py | 2 +- 15 files changed, 220 insertions(+), 101 deletions(-) create mode 100644 op_builder/cpu/async_io.py diff --git a/accelerator/cpu_accelerator.py b/accelerator/cpu_accelerator.py index d4fcbb0b1e3e..1e4335b19292 100644 --- a/accelerator/cpu_accelerator.py +++ b/accelerator/cpu_accelerator.py @@ -301,9 +301,9 @@ def get_op_builder(self, class_name): # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed # if successful this also means we're doing a local install and not JIT compile path from op_builder import __deepspeed__ # noqa: F401 # type: ignore - from op_builder.cpu import CCLCommBuilder, ShareMemCommBuilder, FusedAdamBuilder, CPUAdamBuilder, NotImplementedBuilder + from op_builder.cpu import AsyncIOBuilder, CCLCommBuilder, ShareMemCommBuilder, FusedAdamBuilder, CPUAdamBuilder, NotImplementedBuilder except ImportError: - from deepspeed.ops.op_builder.cpu import CCLCommBuilder, ShareMemCommBuilder, FusedAdamBuilder, CPUAdamBuilder, NotImplementedBuilder + from deepspeed.ops.op_builder.cpu import AsyncIOBuilder, CCLCommBuilder, ShareMemCommBuilder, FusedAdamBuilder, CPUAdamBuilder, NotImplementedBuilder if class_name == "CCLCommBuilder": return CCLCommBuilder @@ -313,6 +313,8 @@ def get_op_builder(self, class_name): return FusedAdamBuilder elif class_name == "CPUAdamBuilder": return CPUAdamBuilder + elif class_name == "AsyncIOBuilder": + return AsyncIOBuilder else: # return a NotImplementedBuilder to avoid get NoneType[Name] in unit tests return NotImplementedBuilder diff --git a/csrc/aio/py_lib/deepspeed_aio_op_desc.cpp b/csrc/aio/py_lib/deepspeed_aio_op_desc.cpp index 5abe8b41e1f3..bbdf7e7d2321 100644 --- a/csrc/aio/py_lib/deepspeed_aio_op_desc.cpp +++ b/csrc/aio/py_lib/deepspeed_aio_op_desc.cpp @@ -9,7 +9,6 @@ using namespace std; io_op_desc_t::io_op_desc_t(const bool read_op, const torch::Tensor& buffer, - const bool is_managed, const int fd, const char* filename, const long long int file_num_bytes, @@ -17,7 +16,6 @@ io_op_desc_t::io_op_desc_t(const bool read_op, const bool validate) : _read_op(read_op), _buffer(buffer), - _is_managed(is_managed), _fd(fd), _filename(filename), _file_num_bytes(file_num_bytes), diff --git a/csrc/aio/py_lib/deepspeed_aio_op_desc.h b/csrc/aio/py_lib/deepspeed_aio_op_desc.h index e3359708e08a..8b7b6f577346 100644 --- a/csrc/aio/py_lib/deepspeed_aio_op_desc.h +++ b/csrc/aio/py_lib/deepspeed_aio_op_desc.h @@ -12,7 +12,6 @@ struct io_op_desc_t { const bool _read_op; torch::Tensor _buffer; - const bool _is_managed; int _fd; const std::string _filename; const long long int _file_num_bytes; @@ -23,7 +22,6 @@ struct io_op_desc_t { io_op_desc_t(const bool read_op, const torch::Tensor& buffer, - const bool is_managed, const int fd, const char* filename, const long long int file_num_bytes, diff --git a/csrc/aio/py_lib/deepspeed_cpu_op.cpp b/csrc/aio/py_lib/deepspeed_cpu_op.cpp index 170d6da75987..24d207575d68 100644 --- a/csrc/aio/py_lib/deepspeed_cpu_op.cpp +++ b/csrc/aio/py_lib/deepspeed_cpu_op.cpp @@ -4,40 +4,30 @@ // DeepSpeed Team #include "deepspeed_cpu_op.h" +#include "deepspeed_pin_tensor.h" using namespace std; -cpu_op_desc_t::cpu_op_desc_t(const bool read_op, - const torch::Tensor& buffer, - const bool is_managed, - const int fd, - const char* filename, - const long long int file_num_bytes, - const int intra_op_parallelism, - const bool validate) - : io_op_desc_t(read_op, - buffer, - is_managed, - fd, - filename, - file_num_bytes, - intra_op_parallelism, - validate), - _cpu_buffer(buffer) +cpu_op_desc_t::cpu_op_desc_t( + const bool read_op, + const torch::Tensor& buffer, + const std::unique_ptr& pinned_tensor_mgr, + const int fd, + const char* filename, + const long long int file_num_bytes, + const int intra_op_parallelism, + const bool validate) + : io_op_desc_t(read_op, buffer, fd, filename, file_num_bytes, intra_op_parallelism, validate), + _cpu_buffer(buffer), + _pinned_tensor_mgr(pinned_tensor_mgr), + _is_managed_bounce_buffer(false) { // Need to use CPU bounce buffer if buffer is not a page-locked DRAM memory. - _use_bounce_buffer = !(_buffer.is_cpu() && (_buffer.is_pinned() || _is_managed)); + _use_bounce_buffer = + !(_buffer.is_cpu() && (_buffer.is_pinned() || _pinned_tensor_mgr->is_managed(_buffer))); if (_use_bounce_buffer) { - if (_read_op) { - auto options = torch::TensorOptions() - .dtype(_buffer.dtype()) - .layout(_buffer.layout()) - .device(torch::kCPU) - .requires_grad(false); - _cpu_buffer = torch::empty(_buffer.numel(), options).pin_memory(); - } else { - _cpu_buffer = _buffer.to(torch::kCPU).pin_memory(); - } + _alloc_bounce_buffer(); + if (!_read_op) { _cpu_buffer.copy_(_buffer); } } _contiguous_buffer = _cpu_buffer.contiguous(); } @@ -46,16 +36,20 @@ char* cpu_op_desc_t::data_ptr() const { return (char*)_contiguous_buffer.data_pt void cpu_op_desc_t::finish() { - if (_read_op && _use_bounce_buffer) { - if (_buffer.is_cuda()) { _buffer.copy_(_cpu_buffer.to(torch::kCUDA)); } - if (_buffer.is_xpu()) { _buffer.copy_(_cpu_buffer.to(torch::kXPU)); } - if (_buffer.is_cpu()) { _buffer.copy_(_cpu_buffer); } + if (_use_bounce_buffer) { + if (_read_op) { + if (_buffer.is_cuda()) { _buffer.copy_(_cpu_buffer.to(torch::kCUDA)); } + if (_buffer.is_xpu()) { _buffer.copy_(_cpu_buffer.to(torch::kXPU)); } + if (_buffer.is_cpu()) { _buffer.copy_(_cpu_buffer); } #if defined(__ENABLE_CANN__) - if (torch_npu::utils::is_npu(_buffer)) { - auto device = at::Device("npu:0"); - _buffer.copy_(_cpu_buffer.to(device)); - } + if (torch_npu::utils::is_npu(_buffer)) { + auto device = at::Device("npu:0"); + _buffer.copy_(_cpu_buffer.to(device)); + } #endif + } + + _free_bounce_buffer(); } } @@ -80,3 +74,24 @@ void cpu_op_desc_t::run(const int tid, do_aio_operation_sequential(_read_op, aio_ctxt, xfer_ctxt, aio_config, nullptr); } } + +void cpu_op_desc_t::_alloc_bounce_buffer() +{ + auto options = torch::TensorOptions() + .dtype(_buffer.dtype()) + .layout(_buffer.layout()) + .device(torch::kCPU) + .requires_grad(false); + +#if defined(__CUDA_ARCH__) + _cpu_buffer = torch::empty(_buffer.numel(), options).pin_memory(); +#else + _is_managed_bounce_buffer = true; + _cpu_buffer = _pinned_tensor_mgr->alloc(_buffer.numel(), options); +#endif +} + +void cpu_op_desc_t::_free_bounce_buffer() +{ + if (_is_managed_bounce_buffer) { _pinned_tensor_mgr->free(_cpu_buffer); } +} diff --git a/csrc/aio/py_lib/deepspeed_cpu_op.h b/csrc/aio/py_lib/deepspeed_cpu_op.h index 3dd8f7222b23..09c4cf261907 100644 --- a/csrc/aio/py_lib/deepspeed_cpu_op.h +++ b/csrc/aio/py_lib/deepspeed_cpu_op.h @@ -6,14 +6,17 @@ #include #include #include "deepspeed_aio_op_desc.h" +// #include "deepspeed_pin_tensor.h" struct cpu_op_desc_t : io_op_desc_t { torch::Tensor _cpu_buffer; bool _use_bounce_buffer; + bool _is_managed_bounce_buffer; + const std::unique_ptr& _pinned_tensor_mgr; cpu_op_desc_t(const bool read_op, const torch::Tensor& buffer, - const bool is_managed, + const std::unique_ptr& pinned_tensor_mgr, const int fd, const char* filename, const long long int file_num_bytes, @@ -29,4 +32,7 @@ struct cpu_op_desc_t : io_op_desc_t { void validate(); void finish(); + + void _alloc_bounce_buffer(); + void _free_bounce_buffer(); }; diff --git a/csrc/aio/py_lib/deepspeed_pin_tensor.cpp b/csrc/aio/py_lib/deepspeed_pin_tensor.cpp index 8d337c14d34c..7f484ba95740 100644 --- a/csrc/aio/py_lib/deepspeed_pin_tensor.cpp +++ b/csrc/aio/py_lib/deepspeed_pin_tensor.cpp @@ -19,20 +19,26 @@ deepspeed_pin_tensor_t::~deepspeed_pin_tensor_t() _locked_tensors.clear(); } -torch::Tensor deepspeed_pin_tensor_t::alloc(const long long int num_elem, - const at::ScalarType& elem_type) +torch::Tensor deepspeed_pin_tensor_t::alloc(const int64_t num_elem, + const torch::TensorOptions& options) { - const auto num_bytes = num_elem * elementSize(elem_type); + const auto scalar_dtype = torch::typeMetaToScalarType(options.dtype()); + const auto num_bytes = num_elem * torch::elementSize(scalar_dtype); auto pinned_buffer = ds_page_aligned_alloc(num_bytes, true); assert(nullptr != pinned_buffer); _locked_tensors[pinned_buffer] = num_bytes; - auto options = torch::TensorOptions().dtype(elem_type).device(torch::kCPU).requires_grad(false); - return at::from_blob(pinned_buffer, static_cast(num_elem), options); } +torch::Tensor deepspeed_pin_tensor_t::alloc(const long long int num_elem, + const at::ScalarType& elem_type) +{ + auto options = torch::TensorOptions().dtype(elem_type).device(torch::kCPU).requires_grad(false); + return alloc(num_elem, options); +} + bool deepspeed_pin_tensor_t::free(torch::Tensor& locked_tensor) { auto addr = locked_tensor.data_ptr(); @@ -47,8 +53,8 @@ bool deepspeed_pin_tensor_t::free(torch::Tensor& locked_tensor) bool deepspeed_pin_tensor_t::is_managed(const torch::Tensor& buffer) { - auto addr = buffer.data_ptr(); if (!buffer.is_cpu()) { return false; } + auto addr = buffer.data_ptr(); if (_locked_tensors.find(addr) != _locked_tensors.end()) { return true; } return false; }; diff --git a/csrc/aio/py_lib/deepspeed_pin_tensor.h b/csrc/aio/py_lib/deepspeed_pin_tensor.h index 95d2f1f92e10..c298c3d78013 100644 --- a/csrc/aio/py_lib/deepspeed_pin_tensor.h +++ b/csrc/aio/py_lib/deepspeed_pin_tensor.h @@ -22,6 +22,7 @@ struct deepspeed_pin_tensor_t { ~deepspeed_pin_tensor_t(); torch::Tensor alloc(const long long num_elem, const at::ScalarType& elem_type); + torch::Tensor alloc(const int64_t num_elem, const torch::TensorOptions& options); bool free(torch::Tensor& locked_tensor); diff --git a/csrc/aio/py_lib/deepspeed_py_io_handle.cpp b/csrc/aio/py_lib/deepspeed_py_io_handle.cpp index 596c2427feed..d0b5a7a418f9 100644 --- a/csrc/aio/py_lib/deepspeed_py_io_handle.cpp +++ b/csrc/aio/py_lib/deepspeed_py_io_handle.cpp @@ -209,9 +209,14 @@ std::shared_ptr deepspeed_io_handle_t::_create_io_op_desc( const long long int file_num_bytes, const bool validate) { - bool is_managed = _pinned_tensor_mgr->is_managed(buffer); - return std::make_shared( - read_op, buffer, is_managed, fd, filename, file_num_bytes, _intra_op_parallelism, validate); + return std::make_shared(read_op, + buffer, + _pinned_tensor_mgr, + fd, + filename, + file_num_bytes, + _intra_op_parallelism, + validate); } int deepspeed_io_handle_t::pread(const torch::Tensor& buffer, diff --git a/csrc/gds/py_lib/deepspeed_gds_op.cpp b/csrc/gds/py_lib/deepspeed_gds_op.cpp index c44b4655a9a9..dae2eef21c6f 100644 --- a/csrc/gds/py_lib/deepspeed_gds_op.cpp +++ b/csrc/gds/py_lib/deepspeed_gds_op.cpp @@ -92,20 +92,12 @@ void gds_op_desc_t::remove_buffer_from_registry(const torch::Tensor& buffer) gds_op_desc_t::gds_op_desc_t(const bool read_op, const torch::Tensor& buffer, - const bool is_managed, const int fd, const char* filename, const long long int file_num_bytes, const int intra_op_parallelism, const bool validate) - : io_op_desc_t(read_op, - buffer, - is_managed, - fd, - filename, - file_num_bytes, - intra_op_parallelism, - validate) + : io_op_desc_t(read_op, buffer, fd, filename, file_num_bytes, intra_op_parallelism, validate) { _contiguous_buffer = _buffer.contiguous(); const int64_t device = _buffer.get_device(); diff --git a/csrc/gds/py_lib/deepspeed_gds_op.h b/csrc/gds/py_lib/deepspeed_gds_op.h index 5fd9540c21c0..c9d4c076f189 100644 --- a/csrc/gds/py_lib/deepspeed_gds_op.h +++ b/csrc/gds/py_lib/deepspeed_gds_op.h @@ -20,7 +20,6 @@ struct gds_op_desc_t : io_op_desc_t { gds_op_desc_t(const bool read_op, const torch::Tensor& buffer, - const bool is_managed, const int fd, const char* filename, const long long int file_num_bytes, diff --git a/csrc/gds/py_lib/deepspeed_py_gds_handle.cpp b/csrc/gds/py_lib/deepspeed_py_gds_handle.cpp index d093187597d4..43705939dc3e 100644 --- a/csrc/gds/py_lib/deepspeed_py_gds_handle.cpp +++ b/csrc/gds/py_lib/deepspeed_py_gds_handle.cpp @@ -111,7 +111,7 @@ std::shared_ptr deepspeed_gds_handle_t::_create_io_op_desc( { if (buffer.is_cuda()) { return std::make_shared( - read_op, buffer, false, fd, filename, file_num_bytes, _intra_op_parallelism, validate); + read_op, buffer, fd, filename, file_num_bytes, _intra_op_parallelism, validate); } return deepspeed_io_handle_t::_create_io_op_desc( read_op, buffer, fd, filename, file_num_bytes, validate); diff --git a/op_builder/cpu/__init__.py b/op_builder/cpu/__init__.py index 30238add3f90..7084db8469f1 100644 --- a/op_builder/cpu/__init__.py +++ b/op_builder/cpu/__init__.py @@ -8,3 +8,4 @@ from .fused_adam import FusedAdamBuilder from .cpu_adam import CPUAdamBuilder from .no_impl import NotImplementedBuilder +from .async_io import AsyncIOBuilder diff --git a/op_builder/cpu/async_io.py b/op_builder/cpu/async_io.py new file mode 100644 index 000000000000..56eb71303d78 --- /dev/null +++ b/op_builder/cpu/async_io.py @@ -0,0 +1,92 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import os +import distutils.spawn +import subprocess + +from .builder import CPUOpBuilder + + +class AsyncIOBuilder(CPUOpBuilder): + BUILD_VAR = "DS_BUILD_AIO" + NAME = "async_io" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'deepspeed.ops.aio.{self.NAME}_op' + + def lib_sources(self): + src_list = [ + 'csrc/aio/py_lib/deepspeed_py_io_handle.cpp', 'csrc/aio/py_lib/deepspeed_py_aio.cpp', + 'csrc/aio/py_lib/deepspeed_py_aio_handle.cpp', 'csrc/aio/py_lib/deepspeed_aio_thread.cpp', + 'csrc/aio/common/deepspeed_aio_utils.cpp', 'csrc/aio/common/deepspeed_aio_common.cpp', + 'csrc/aio/common/deepspeed_aio_types.cpp', 'csrc/aio/py_lib/deepspeed_cpu_op.cpp', + 'csrc/aio/py_lib/deepspeed_aio_op_desc.cpp', 'csrc/aio/py_lib/deepspeed_py_copy.cpp', + 'csrc/aio/py_lib/deepspeed_pin_tensor.cpp' + ] + return src_list + + def sources(self): + return self.lib_sources() + ['csrc/aio/py_lib/py_ds_aio.cpp'] + + def include_paths(self): + return ['csrc/aio/py_lib', 'csrc/aio/common'] + + def cxx_args(self): + # -O0 for improved debugging, since performance is bound by I/O + args = super().cxx_args() + import torch + TORCH_MAJOR, TORCH_MINOR = map(int, torch.__version__.split('.')[0:2]) + if not (TORCH_MAJOR >= 2 and TORCH_MINOR >= 1): + args.remove('-std=c++17') + args.append('-std=c++14') + args += ['-Wall', '-O0', '-shared', '-fPIC', '-Wno-reorder'] + return args + + def extra_ldflags(self): + return ['-laio', '-fopenmp'] + + def check_for_libaio_pkg(self): + libs = dict( + dpkg=["-l", "libaio-dev", "apt"], + pacman=["-Q", "libaio", "pacman"], + rpm=["-q", "libaio-devel", "yum"], + ) + + found = False + for pkgmgr, data in libs.items(): + flag, lib, tool = data + path = distutils.spawn.find_executable(pkgmgr) + if path is not None: + cmd = [pkgmgr, flag, lib] + result = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + if result.wait() == 0: + found = True + else: + self.warning(f"{self.NAME}: please install the {lib} package with {tool}") + break + return found + + def is_compatible(self, verbose=False): + # Check for the existence of libaio by using distutils + # to compile and link a test program that calls io_submit, + # which is a function provided by libaio that is used in the async_io op. + # If needed, one can define -I and -L entries in CFLAGS and LDFLAGS + # respectively to specify the directories for libaio.h and libaio.so. + aio_compatible = self.has_function('io_submit', ('aio', )) + if verbose and not aio_compatible: + self.warning(f"{self.NAME} requires the dev libaio .so object and headers but these were not found.") + + # Check for the libaio package via known package managers + # to print suggestions on which package to install. + self.check_for_libaio_pkg() + + self.warning( + "If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found." + ) + return super().is_compatible(verbose) and aio_compatible diff --git a/tests/unit/ops/aio/test_aio.py b/tests/unit/ops/aio/test_aio.py index f6d175ce67bc..a074cfca317f 100644 --- a/tests/unit/ops/aio/test_aio.py +++ b/tests/unit/ops/aio/test_aio.py @@ -23,12 +23,10 @@ pytest.skip('Skip tests since async-io is not compatible', allow_module_level=True) -def _skip_for_invalid_environment(use_cuda_device=True, use_cuda_pinned_tensor=True): - if not get_accelerator().is_available(): - if use_cuda_device: - pytest.skip("GPU tensors only supported in CUDA environments.") +def _skip_for_invalid_environment(use_cuda_pinned_tensor=True): + if get_accelerator().device_name() != 'cuda': if use_cuda_pinned_tensor: - pytest.skip("CUDA-pinned tensors only supported in CUDA environments.") + pytest.skip("torch.pin_memory is only supported in CUDA environments.") def _get_local_rank(): @@ -52,13 +50,13 @@ def _get_test_write_file(tmpdir, index): return os.path.join(tmpdir, f'_aio_write_random_{file_suffix}.pt') -def _get_test_write_file_and_cuda_buffer(tmpdir, ref_buffer, index=0): +def _get_test_write_file_and_unpinned_tensor(tmpdir, ref_buffer, index=0): test_file = _get_test_write_file(tmpdir, index) test_buffer = get_accelerator().ByteTensor(list(ref_buffer)) return test_file, test_buffer -def _get_test_write_file_and_cpu_buffer(tmpdir, ref_buffer, aio_handle=None, index=0): +def _get_test_write_file_and_pinned_tensor(tmpdir, ref_buffer, aio_handle=None, index=0): test_file = _get_test_write_file(tmpdir, index) if aio_handle is None: test_buffer = get_accelerator().pin_memory(torch.ByteTensor(list(ref_buffer))) @@ -73,7 +71,7 @@ def _get_test_write_file_and_cpu_buffer(tmpdir, ref_buffer, aio_handle=None, ind def _validate_handle_state(handle, single_submit, overlap_events): assert handle.get_single_submit() == single_submit assert handle.get_overlap_events() == overlap_events - assert handle.get_thread_count() == IO_PARALLEL + assert handle.get_intra_op_parallelism() == IO_PARALLEL assert handle.get_block_size() == BLOCK_SIZE assert handle.get_queue_depth() == QUEUE_DEPTH @@ -89,12 +87,15 @@ class TestRead(DistributedTest): init_distributed = False set_dist_env = False - def test_parallel_read(self, tmpdir, use_cuda_pinned_tensor, single_submit, overlap_events): - _skip_for_invalid_environment(use_cuda_device=False, use_cuda_pinned_tensor=use_cuda_pinned_tensor) + @pytest.mark.parametrize("use_unpinned_tensor", [True, False]) + def test_parallel_read(self, tmpdir, use_cuda_pinned_tensor, single_submit, overlap_events, use_unpinned_tensor): + _skip_for_invalid_environment(use_cuda_pinned_tensor=use_cuda_pinned_tensor) h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, QUEUE_DEPTH, single_submit, overlap_events, IO_PARALLEL) - if use_cuda_pinned_tensor: + if use_unpinned_tensor: + aio_buffer = torch.empty(IO_SIZE, dtype=torch.uint8, device=get_accelerator().device_name()) + elif use_cuda_pinned_tensor: aio_buffer = get_accelerator().pin_memory(torch.empty(IO_SIZE, dtype=torch.uint8, device='cpu')) else: aio_buffer = h.new_cpu_locked_tensor(IO_SIZE, torch.empty(0, dtype=torch.uint8)) @@ -112,14 +113,14 @@ def test_parallel_read(self, tmpdir, use_cuda_pinned_tensor, single_submit, over if not use_cuda_pinned_tensor: h.free_cpu_locked_tensor(aio_buffer) - @pytest.mark.parametrize("cuda_device", [True, False]) - def test_async_read(self, tmpdir, use_cuda_pinned_tensor, single_submit, overlap_events, cuda_device): - _skip_for_invalid_environment(use_cuda_device=cuda_device, use_cuda_pinned_tensor=use_cuda_pinned_tensor) + @pytest.mark.parametrize("use_unpinned_tensor", [True, False]) + def test_async_read(self, tmpdir, use_cuda_pinned_tensor, single_submit, overlap_events, use_unpinned_tensor): + _skip_for_invalid_environment(use_cuda_pinned_tensor=use_cuda_pinned_tensor) use_cpu_locked_tensor = False h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, QUEUE_DEPTH, single_submit, overlap_events, IO_PARALLEL) - if cuda_device: + if use_unpinned_tensor: aio_buffer = torch.empty(IO_SIZE, dtype=torch.uint8, device=get_accelerator().device_name()) elif use_cuda_pinned_tensor: aio_buffer = get_accelerator().pin_memory(torch.empty(IO_SIZE, dtype=torch.uint8, device='cpu')) @@ -155,16 +156,19 @@ class TestWrite(DistributedTest): init_distributed = False set_dist_env = False - def test_parallel_write(self, tmpdir, use_cuda_pinned_tensor, single_submit, overlap_events): - _skip_for_invalid_environment(use_cuda_device=False, use_cuda_pinned_tensor=use_cuda_pinned_tensor) + @pytest.mark.parametrize("use_unpinned_tensor", [True, False]) + def test_parallel_write(self, tmpdir, use_cuda_pinned_tensor, single_submit, overlap_events, use_unpinned_tensor): + _skip_for_invalid_environment(use_cuda_pinned_tensor=use_cuda_pinned_tensor) ref_file, ref_buffer = _do_ref_write(tmpdir) h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, QUEUE_DEPTH, single_submit, overlap_events, IO_PARALLEL) + if use_unpinned_tensor: + aio_file, aio_buffer = _get_test_write_file_and_unpinned_tensor(tmpdir, ref_buffer) if use_cuda_pinned_tensor: - aio_file, aio_buffer = _get_test_write_file_and_cpu_buffer(tmpdir, ref_buffer) + aio_file, aio_buffer = _get_test_write_file_and_pinned_tensor(tmpdir, ref_buffer) else: - aio_file, aio_buffer = _get_test_write_file_and_cpu_buffer(tmpdir, ref_buffer, h) + aio_file, aio_buffer = _get_test_write_file_and_pinned_tensor(tmpdir, ref_buffer, h) _validate_handle_state(h, single_submit, overlap_events) @@ -179,20 +183,20 @@ def test_parallel_write(self, tmpdir, use_cuda_pinned_tensor, single_submit, ove filecmp.clear_cache() assert filecmp.cmp(ref_file, aio_file, shallow=False) - @pytest.mark.parametrize("cuda_device", [True, False]) - def test_async_write(self, tmpdir, use_cuda_pinned_tensor, single_submit, overlap_events, cuda_device): - _skip_for_invalid_environment(use_cuda_device=cuda_device, use_cuda_pinned_tensor=use_cuda_pinned_tensor) + @pytest.mark.parametrize("use_unpinned_tensor", [True, False]) + def test_async_write(self, tmpdir, use_cuda_pinned_tensor, single_submit, overlap_events, use_unpinned_tensor): + _skip_for_invalid_environment(use_cuda_pinned_tensor=use_cuda_pinned_tensor) ref_file, ref_buffer = _do_ref_write(tmpdir) h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, QUEUE_DEPTH, single_submit, overlap_events, IO_PARALLEL) use_cpu_locked_tensor = False - if cuda_device: - aio_file, aio_buffer = _get_test_write_file_and_cuda_buffer(tmpdir, ref_buffer) + if use_unpinned_tensor: + aio_file, aio_buffer = _get_test_write_file_and_unpinned_tensor(tmpdir, ref_buffer) elif use_cuda_pinned_tensor: - aio_file, aio_buffer = _get_test_write_file_and_cpu_buffer(tmpdir, ref_buffer) + aio_file, aio_buffer = _get_test_write_file_and_pinned_tensor(tmpdir, ref_buffer) else: - aio_file, aio_buffer = _get_test_write_file_and_cpu_buffer(tmpdir, ref_buffer, h) + aio_file, aio_buffer = _get_test_write_file_and_pinned_tensor(tmpdir, ref_buffer, h) use_cpu_locked_tensor = True _validate_handle_state(h, single_submit, overlap_events) @@ -214,7 +218,7 @@ def test_async_write(self, tmpdir, use_cuda_pinned_tensor, single_submit, overla @pytest.mark.sequential @pytest.mark.parametrize("use_cuda_pinned_tensor", [True, False]) -@pytest.mark.parametrize("cuda_device", [True, False]) +@pytest.mark.parametrize("use_unpinned_tensor", [True, False]) class TestAsyncQueue(DistributedTest): world_size = 1 requires_cuda_env = False @@ -223,8 +227,8 @@ class TestAsyncQueue(DistributedTest): set_dist_env = False @pytest.mark.parametrize("async_queue", [2, 3]) - def test_read(self, tmpdir, async_queue, use_cuda_pinned_tensor, cuda_device): - _skip_for_invalid_environment(use_cuda_device=cuda_device, use_cuda_pinned_tensor=use_cuda_pinned_tensor) + def test_read(self, tmpdir, async_queue, use_cuda_pinned_tensor, use_unpinned_tensor): + _skip_for_invalid_environment(use_cuda_pinned_tensor=use_cuda_pinned_tensor) ref_files = [] for i in range(async_queue): @@ -236,7 +240,7 @@ def test_read(self, tmpdir, async_queue, use_cuda_pinned_tensor, cuda_device): h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, QUEUE_DEPTH, single_submit, overlap_events, IO_PARALLEL) use_cpu_locked_tensor = False - if cuda_device: + if use_unpinned_tensor: aio_buffers = [ torch.empty(IO_SIZE, dtype=torch.uint8, device=get_accelerator().device_name()) for _ in range(async_queue) @@ -270,8 +274,8 @@ def test_read(self, tmpdir, async_queue, use_cuda_pinned_tensor, cuda_device): h.free_cpu_locked_tensor(t) @pytest.mark.parametrize("async_queue", [2, 3]) - def test_write(self, tmpdir, use_cuda_pinned_tensor, async_queue, cuda_device): - _skip_for_invalid_environment(use_cuda_device=cuda_device, use_cuda_pinned_tensor=use_cuda_pinned_tensor) + def test_write(self, tmpdir, use_cuda_pinned_tensor, async_queue, use_unpinned_tensor): + _skip_for_invalid_environment(use_cuda_pinned_tensor=use_cuda_pinned_tensor) ref_files = [] ref_buffers = [] @@ -287,16 +291,16 @@ def test_write(self, tmpdir, use_cuda_pinned_tensor, async_queue, cuda_device): aio_files = [] aio_buffers = [] for i in range(async_queue): - if cuda_device: - f, buf = _get_test_write_file_and_cuda_buffer(tmpdir, ref_buffers[i], i) + if use_unpinned_tensor: + f, buf = _get_test_write_file_and_unpinned_tensor(tmpdir, ref_buffers[i], i) elif use_cuda_pinned_tensor: - f, buf = _get_test_write_file_and_cpu_buffer(tmpdir, ref_buffers[i], None, i) + f, buf = _get_test_write_file_and_pinned_tensor(tmpdir, ref_buffers[i], None, i) else: - f, buf = _get_test_write_file_and_cpu_buffer(tmpdir, ref_buffers[i], h, i) + f, buf = _get_test_write_file_and_pinned_tensor(tmpdir, ref_buffers[i], h, i) aio_files.append(f) aio_buffers.append(buf) - use_cpu_locked_tensor = not (cuda_device or use_cuda_pinned_tensor) + use_cpu_locked_tensor = not (use_unpinned_tensor or use_cuda_pinned_tensor) _validate_handle_state(h, single_submit, overlap_events) diff --git a/tests/unit/ops/aio/test_gds.py b/tests/unit/ops/aio/test_gds.py index 53655994b560..e94d42cd22af 100644 --- a/tests/unit/ops/aio/test_gds.py +++ b/tests/unit/ops/aio/test_gds.py @@ -54,7 +54,7 @@ def _get_test_write_file_and_device_buffer(tmpdir, ref_buffer, gds_handle, index def _validate_handle_state(handle, single_submit, overlap_events): assert handle.get_single_submit() == single_submit assert handle.get_overlap_events() == overlap_events - assert handle.get_thread_count() == IO_PARALLEL + assert handle.get_intra_op_parallelism() == IO_PARALLEL assert handle.get_block_size() == BLOCK_SIZE assert handle.get_queue_depth() == QUEUE_DEPTH From 98988cdf81747192448acbd1ee3370ec292bbb8b Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Tue, 8 Oct 2024 21:41:09 -0400 Subject: [PATCH 11/13] Formatting; Use int64_t --- csrc/aio/common/deepspeed_aio_common.cpp | 43 +++++++++++----------- csrc/aio/common/deepspeed_aio_common.h | 2 +- csrc/aio/common/deepspeed_aio_utils.cpp | 18 ++++----- csrc/aio/common/deepspeed_aio_utils.h | 22 +++++------ csrc/aio/py_lib/deepspeed_aio_op_desc.cpp | 2 +- csrc/aio/py_lib/deepspeed_aio_op_desc.h | 6 +-- csrc/aio/py_lib/deepspeed_cpu_op.cpp | 2 +- csrc/aio/py_lib/deepspeed_cpu_op.h | 3 +- csrc/aio/py_lib/deepspeed_pin_tensor.cpp | 5 +-- csrc/aio/py_lib/deepspeed_pin_tensor.h | 4 +- csrc/aio/py_lib/deepspeed_py_aio.cpp | 6 +-- csrc/aio/py_lib/deepspeed_py_io_handle.cpp | 19 +++++----- csrc/aio/py_lib/deepspeed_py_io_handle.h | 17 ++++----- op_builder/cpu/async_io.py | 1 - 14 files changed, 72 insertions(+), 78 deletions(-) diff --git a/csrc/aio/common/deepspeed_aio_common.cpp b/csrc/aio/common/deepspeed_aio_common.cpp index a65cc500cc82..81c315e9a558 100644 --- a/csrc/aio/common/deepspeed_aio_common.cpp +++ b/csrc/aio/common/deepspeed_aio_common.cpp @@ -68,8 +68,8 @@ static void _get_aio_latencies(std::vector>& raw_l std::accumulate(lat_usec.begin(), lat_usec.end(), 0) / lat_usec.size(); } -static void _do_io_submit_singles(const long long int n_iocbs, - const long long int iocb_index, +static void _do_io_submit_singles(const int64_t n_iocbs, + const int64_t iocb_index, std::unique_ptr& aio_ctxt, std::vector>& submit_times) { @@ -89,8 +89,8 @@ static void _do_io_submit_singles(const long long int n_iocbs, } } -static void _do_io_submit_block(const long long int n_iocbs, - const long long int iocb_index, +static void _do_io_submit_block(const int64_t n_iocbs, + const int64_t iocb_index, std::unique_ptr& aio_ctxt, std::vector>& submit_times) { @@ -109,18 +109,18 @@ static void _do_io_submit_block(const long long int n_iocbs, assert(submit_ret > 0); } -static int _do_io_complete(const long long int min_completes, - const long long int max_completes, +static int _do_io_complete(const int64_t min_completes, + const int64_t max_completes, std::unique_ptr& aio_ctxt, std::vector>& reap_times) { const auto start_time = std::chrono::high_resolution_clock::now(); - long long int n_completes = io_pgetevents(aio_ctxt->_io_ctxt, - min_completes, - max_completes, - aio_ctxt->_io_events.data(), - nullptr, - nullptr); + int64_t n_completes = io_pgetevents(aio_ctxt->_io_ctxt, + min_completes, + max_completes, + aio_ctxt->_io_events.data(), + nullptr, + nullptr); reap_times.push_back(std::chrono::high_resolution_clock::now() - start_time); assert(n_completes >= min_completes); return n_completes; @@ -134,7 +134,7 @@ void do_aio_operation_sequential(const bool read_op, { struct io_prep_context prep_ctxt(read_op, xfer_ctxt, aio_ctxt->_block_size, &aio_ctxt->_iocbs); - const auto num_io_blocks = static_cast( + const auto num_io_blocks = static_cast( ceil(static_cast(xfer_ctxt->_num_bytes) / aio_ctxt->_block_size)); #if DEBUG_DS_AIO_PERF const auto io_op_name = std::string(read_op ? "read" : "write"); @@ -145,15 +145,14 @@ void do_aio_operation_sequential(const bool read_op, std::vector> submit_times; std::vector> reap_times; const auto max_queue_bytes = - static_cast(aio_ctxt->_queue_depth * aio_ctxt->_block_size); + static_cast(aio_ctxt->_queue_depth * aio_ctxt->_block_size); auto start = std::chrono::high_resolution_clock::now(); - for (long long iocb_index = 0; iocb_index < num_io_blocks; - iocb_index += aio_ctxt->_queue_depth) { + for (int64_t iocb_index = 0; iocb_index < num_io_blocks; iocb_index += aio_ctxt->_queue_depth) { const auto start_offset = iocb_index * aio_ctxt->_block_size; const auto start_buffer = (char*)xfer_ctxt->_mem_buffer + start_offset; const auto n_iocbs = - min(static_cast(aio_ctxt->_queue_depth), (num_io_blocks - iocb_index)); + min(static_cast(aio_ctxt->_queue_depth), (num_io_blocks - iocb_index)); const auto num_bytes = min(max_queue_bytes, (xfer_ctxt->_num_bytes - start_offset)); prep_ctxt.prep_iocbs(n_iocbs, num_bytes, start_buffer, start_offset); @@ -285,13 +284,13 @@ int open_file(const char* filename, const bool read_op) int regular_read(const char* filename, std::vector& buffer) { - long long int num_bytes; + int64_t num_bytes; const auto f_size = get_file_size(filename, num_bytes); assert(f_size != -1); buffer.resize(num_bytes); const auto fd = open(filename, O_RDONLY, 0600); assert(fd != -1); - long long int read_bytes = 0; + int64_t read_bytes = 0; auto r = 0; do { const auto buffer_ptr = buffer.data() + read_bytes; @@ -309,7 +308,7 @@ int regular_read(const char* filename, std::vector& buffer) return 0; } -static bool _validate_buffer(const char* filename, void* aio_buffer, const long long int num_bytes) +static bool _validate_buffer(const char* filename, void* aio_buffer, const int64_t num_bytes) { std::vector regular_buffer; const auto reg_ret = regular_read(filename, regular_buffer); @@ -317,7 +316,7 @@ static bool _validate_buffer(const char* filename, void* aio_buffer, const long std::cout << "regular read of " << filename << " returned " << regular_buffer.size() << " bytes" << std::endl; - if (static_cast(regular_buffer.size()) != num_bytes) { return false; } + if (static_cast(regular_buffer.size()) != num_bytes) { return false; } return (0 == memcmp(aio_buffer, regular_buffer.data(), regular_buffer.size())); } @@ -325,7 +324,7 @@ static bool _validate_buffer(const char* filename, void* aio_buffer, const long bool validate_aio_operation(const bool read_op, const char* filename, void* aio_buffer, - const long long int num_bytes) + const int64_t num_bytes) { const auto msg_suffix = std::string("deepspeed_aio_") + std::string(read_op ? "read()" : "write()") + diff --git a/csrc/aio/common/deepspeed_aio_common.h b/csrc/aio/common/deepspeed_aio_common.h index 2940de945ee8..aa4e49f4f4ed 100644 --- a/csrc/aio/common/deepspeed_aio_common.h +++ b/csrc/aio/common/deepspeed_aio_common.h @@ -35,4 +35,4 @@ int regular_read(const char* filename, std::vector& buffer); bool validate_aio_operation(const bool read_op, const char* filename, void* aio_buffer, - const long long int num_bytes); + const int64_t num_bytes); diff --git a/csrc/aio/common/deepspeed_aio_utils.cpp b/csrc/aio/common/deepspeed_aio_utils.cpp index 8fccb1bf96cd..0536ff6a362e 100644 --- a/csrc/aio/common/deepspeed_aio_utils.cpp +++ b/csrc/aio/common/deepspeed_aio_utils.cpp @@ -18,8 +18,8 @@ const int c_block_size = 128 * 1024; const int c_io_queue_depth = 8; io_xfer_ctxt::io_xfer_ctxt(const int fd, - const long long int file_offset, - const long long int num_bytes, + const int64_t file_offset, + const int64_t num_bytes, const void* buffer) : _fd(fd), _base_offset(file_offset), _mem_buffer(buffer), _num_bytes(num_bytes) { @@ -36,7 +36,7 @@ io_prep_context::io_prep_context(const bool read_op, void io_prep_context::prep_iocbs(const int n_iocbs, const size_t num_bytes, const void* start_buffer, - const long long int start_offset) + const int64_t start_offset) { assert(static_cast(n_iocbs) <= _iocbs->size()); for (auto i = 0; i < n_iocbs; ++i) { @@ -64,24 +64,24 @@ io_prep_generator::io_prep_generator(const bool read_op, _next_iocb_index(0) { _num_io_blocks = - static_cast(ceil(static_cast(xfer_ctxt->_num_bytes) / block_size)); + static_cast(ceil(static_cast(xfer_ctxt->_num_bytes) / block_size)); _remaining_io_blocks = _num_io_blocks; } int io_prep_generator::prep_iocbs(const int n_iocbs, std::vector* iocbs) { if ((_remaining_bytes) == 0 || (_remaining_io_blocks == 0)) { - assert(static_cast(_remaining_bytes) == _remaining_io_blocks); + assert(static_cast(_remaining_bytes) == _remaining_io_blocks); return 0; } assert(static_cast(n_iocbs) <= iocbs->size()); - auto actual_n_iocbs = min(static_cast(n_iocbs), _remaining_io_blocks); + auto actual_n_iocbs = min(static_cast(n_iocbs), _remaining_io_blocks); for (auto i = 0; i < actual_n_iocbs; ++i, ++_next_iocb_index) { const auto xfer_offset = _xfer_ctxt->_base_offset + (_next_iocb_index * _block_size); const auto xfer_buffer = (char*)_xfer_ctxt->_mem_buffer + xfer_offset; - const auto num_bytes = min(static_cast(_block_size), _remaining_bytes); + const auto num_bytes = min(static_cast(_block_size), _remaining_bytes); if (_read_op) { io_prep_pread(iocbs->at(i), _xfer_ctxt->_fd, xfer_buffer, num_bytes, xfer_offset); @@ -95,7 +95,7 @@ int io_prep_generator::prep_iocbs(const int n_iocbs, std::vector* return actual_n_iocbs; } -int get_file_size(const char* filename, long long int& size) +int get_file_size(const char* filename, int64_t& size) { struct stat st; if (stat(filename, &st) == -1) { return -1; } @@ -103,7 +103,7 @@ int get_file_size(const char* filename, long long int& size) return 0; } -void* ds_page_aligned_alloc(const long long int size, const bool lock) +void* ds_page_aligned_alloc(const int64_t size, const bool lock) { void* ptr; int retval; diff --git a/csrc/aio/common/deepspeed_aio_utils.h b/csrc/aio/common/deepspeed_aio_utils.h index ea56cd1de236..20e81fe8eebd 100644 --- a/csrc/aio/common/deepspeed_aio_utils.h +++ b/csrc/aio/common/deepspeed_aio_utils.h @@ -30,13 +30,13 @@ Functionality for swapping optimizer tensors to/from (NVMe) storage devices. struct io_xfer_ctxt { const int _fd; - const long long int _base_offset; + const int64_t _base_offset; const void* _mem_buffer; - const long long int _num_bytes; + const int64_t _num_bytes; io_xfer_ctxt(const int fd, - const long long int file_offset, - const long long int num_bytes, + const int64_t file_offset, + const int64_t num_bytes, const void* buffer); }; @@ -54,7 +54,7 @@ struct io_prep_context { void prep_iocbs(const int n_iocbs, const size_t num_bytes, const void* start_buffer, - const long long int start_offset); + const int64_t start_offset); }; struct io_prep_generator { @@ -62,10 +62,10 @@ struct io_prep_generator { const std::unique_ptr& _xfer_ctxt; const size_t _block_size; - long long int _remaining_bytes; - long long int _num_io_blocks; - long long int _remaining_io_blocks; - long long int _next_iocb_index; + int64_t _remaining_bytes; + int64_t _num_io_blocks; + int64_t _remaining_io_blocks; + int64_t _next_iocb_index; io_prep_generator(const bool read_op, const std::unique_ptr& xfer_ctxt, @@ -74,6 +74,6 @@ struct io_prep_generator { int prep_iocbs(const int n_iocbs, std::vector* iocbs); }; -void* ds_page_aligned_alloc(const long long int size, const bool lock = false); +void* ds_page_aligned_alloc(const int64_t size, const bool lock = false); -int get_file_size(const char* filename, long long int& size); +int get_file_size(const char* filename, int64_t& size); diff --git a/csrc/aio/py_lib/deepspeed_aio_op_desc.cpp b/csrc/aio/py_lib/deepspeed_aio_op_desc.cpp index bbdf7e7d2321..6f311c5400c7 100644 --- a/csrc/aio/py_lib/deepspeed_aio_op_desc.cpp +++ b/csrc/aio/py_lib/deepspeed_aio_op_desc.cpp @@ -11,7 +11,7 @@ io_op_desc_t::io_op_desc_t(const bool read_op, const torch::Tensor& buffer, const int fd, const char* filename, - const long long int file_num_bytes, + const int64_t file_num_bytes, const int intra_op_parallelism, const bool validate) : _read_op(read_op), diff --git a/csrc/aio/py_lib/deepspeed_aio_op_desc.h b/csrc/aio/py_lib/deepspeed_aio_op_desc.h index 8b7b6f577346..f841b8ce520a 100644 --- a/csrc/aio/py_lib/deepspeed_aio_op_desc.h +++ b/csrc/aio/py_lib/deepspeed_aio_op_desc.h @@ -14,9 +14,9 @@ struct io_op_desc_t { torch::Tensor _buffer; int _fd; const std::string _filename; - const long long int _file_num_bytes; + const int64_t _file_num_bytes; const int _intra_op_parallelism; - const long long int _num_bytes_per_thread; + const int64_t _num_bytes_per_thread; torch::Tensor _contiguous_buffer; const bool _validate; @@ -24,7 +24,7 @@ struct io_op_desc_t { const torch::Tensor& buffer, const int fd, const char* filename, - const long long int file_num_bytes, + const int64_t file_num_bytes, const int intra_op_parallelism, const bool validate); diff --git a/csrc/aio/py_lib/deepspeed_cpu_op.cpp b/csrc/aio/py_lib/deepspeed_cpu_op.cpp index 24d207575d68..da1a52d9c6e3 100644 --- a/csrc/aio/py_lib/deepspeed_cpu_op.cpp +++ b/csrc/aio/py_lib/deepspeed_cpu_op.cpp @@ -14,7 +14,7 @@ cpu_op_desc_t::cpu_op_desc_t( const std::unique_ptr& pinned_tensor_mgr, const int fd, const char* filename, - const long long int file_num_bytes, + const int64_t file_num_bytes, const int intra_op_parallelism, const bool validate) : io_op_desc_t(read_op, buffer, fd, filename, file_num_bytes, intra_op_parallelism, validate), diff --git a/csrc/aio/py_lib/deepspeed_cpu_op.h b/csrc/aio/py_lib/deepspeed_cpu_op.h index 09c4cf261907..9de2fa254048 100644 --- a/csrc/aio/py_lib/deepspeed_cpu_op.h +++ b/csrc/aio/py_lib/deepspeed_cpu_op.h @@ -6,7 +6,6 @@ #include #include #include "deepspeed_aio_op_desc.h" -// #include "deepspeed_pin_tensor.h" struct cpu_op_desc_t : io_op_desc_t { torch::Tensor _cpu_buffer; @@ -19,7 +18,7 @@ struct cpu_op_desc_t : io_op_desc_t { const std::unique_ptr& pinned_tensor_mgr, const int fd, const char* filename, - const long long int file_num_bytes, + const int64_t file_num_bytes, const int intra_op_parallelism, const bool validate); diff --git a/csrc/aio/py_lib/deepspeed_pin_tensor.cpp b/csrc/aio/py_lib/deepspeed_pin_tensor.cpp index 7f484ba95740..6d2800468e06 100644 --- a/csrc/aio/py_lib/deepspeed_pin_tensor.cpp +++ b/csrc/aio/py_lib/deepspeed_pin_tensor.cpp @@ -29,11 +29,10 @@ torch::Tensor deepspeed_pin_tensor_t::alloc(const int64_t num_elem, _locked_tensors[pinned_buffer] = num_bytes; - return at::from_blob(pinned_buffer, static_cast(num_elem), options); + return at::from_blob(pinned_buffer, static_cast(num_elem), options); } -torch::Tensor deepspeed_pin_tensor_t::alloc(const long long int num_elem, - const at::ScalarType& elem_type) +torch::Tensor deepspeed_pin_tensor_t::alloc(const int64_t num_elem, const at::ScalarType& elem_type) { auto options = torch::TensorOptions().dtype(elem_type).device(torch::kCPU).requires_grad(false); return alloc(num_elem, options); diff --git a/csrc/aio/py_lib/deepspeed_pin_tensor.h b/csrc/aio/py_lib/deepspeed_pin_tensor.h index c298c3d78013..4b8ad7e76085 100644 --- a/csrc/aio/py_lib/deepspeed_pin_tensor.h +++ b/csrc/aio/py_lib/deepspeed_pin_tensor.h @@ -15,13 +15,13 @@ Functionality for managing CPU tensors occupying page-locked memory. #include "deepspeed_py_aio.h" struct deepspeed_pin_tensor_t { - std::map _locked_tensors; + std::map _locked_tensors; deepspeed_pin_tensor_t() = default; ~deepspeed_pin_tensor_t(); - torch::Tensor alloc(const long long num_elem, const at::ScalarType& elem_type); + torch::Tensor alloc(const int64_t num_elem, const at::ScalarType& elem_type); torch::Tensor alloc(const int64_t num_elem, const torch::TensorOptions& options); bool free(torch::Tensor& locked_tensor); diff --git a/csrc/aio/py_lib/deepspeed_py_aio.cpp b/csrc/aio/py_lib/deepspeed_py_aio.cpp index eac268d33433..02b04057d1ac 100644 --- a/csrc/aio/py_lib/deepspeed_py_aio.cpp +++ b/csrc/aio/py_lib/deepspeed_py_aio.cpp @@ -51,7 +51,7 @@ int deepspeed_py_aio_write(const torch::Tensor& buffer, if (fd == -1) { return -1; } auto write_buffer = (char*)buffer.data_ptr(); - const auto num_write_bytes = static_cast(buffer.nbytes()); + const auto num_write_bytes = static_cast(buffer.nbytes()); std::unique_ptr xfer_ctxt(new io_xfer_ctxt(fd, 0, num_write_bytes, write_buffer)); std::unique_ptr aio_ctxt(new aio_context(config._block_size, config._queue_depth)); @@ -83,7 +83,7 @@ int deepspeed_py_aio_read(torch::Tensor& buffer, const bool validate) { const auto start_time = std::chrono::high_resolution_clock::now(); - long long num_file_bytes; + int64_t num_file_bytes; if (-1 == get_file_size(filename, num_file_bytes)) { const auto error_code = errno; report_file_error(filename, " fstat for read", error_code); @@ -95,7 +95,7 @@ int deepspeed_py_aio_read(torch::Tensor& buffer, if (fd == -1) { return -1; } auto read_buffer = (char*)buffer.data_ptr(); - assert(static_cast(buffer.nbytes()) == num_file_bytes); + assert(static_cast(buffer.nbytes()) == num_file_bytes); std::unique_ptr xfer_ctxt(new io_xfer_ctxt(fd, 0, num_file_bytes, read_buffer)); std::unique_ptr aio_ctxt(new aio_context(config._block_size, config._queue_depth)); diff --git a/csrc/aio/py_lib/deepspeed_py_io_handle.cpp b/csrc/aio/py_lib/deepspeed_py_io_handle.cpp index d0b5a7a418f9..48ea8a1339d4 100644 --- a/csrc/aio/py_lib/deepspeed_py_io_handle.cpp +++ b/csrc/aio/py_lib/deepspeed_py_io_handle.cpp @@ -64,13 +64,13 @@ int deepspeed_io_handle_t::read(torch::Tensor& buffer, const char* filename, con assert(_aio_ctxt); - long long num_file_bytes; + int64_t num_file_bytes; if (-1 == get_file_size(filename, num_file_bytes)) { const auto error_code = errno; report_file_error(filename, " fstat for read", error_code); return -1; } - assert(static_cast(buffer.nbytes()) == num_file_bytes); + assert(static_cast(buffer.nbytes()) == num_file_bytes); const auto fd = open_file(filename, true); if (fd == -1) { return -1; } @@ -108,7 +108,7 @@ int deepspeed_io_handle_t::write(const torch::Tensor& buffer, if (fd == -1) { return -1; } auto write_buffer = (char*)buffer.data_ptr(); - const auto num_write_bytes = static_cast(buffer.nbytes()); + const auto num_write_bytes = static_cast(buffer.nbytes()); std::unique_ptr xfer_ctxt(new io_xfer_ctxt(fd, 0, num_write_bytes, write_buffer)); if (_aio_config._overlap_events) { @@ -188,8 +188,7 @@ int deepspeed_io_handle_t::wait() return num_completed_ops; } -bool deepspeed_io_handle_t::_is_valid_parallel_aio_op(const bool read_op, - const long long int num_bytes) +bool deepspeed_io_handle_t::_is_valid_parallel_aio_op(const bool read_op, const int64_t num_bytes) { const auto op_string = read_op ? "Read" : "Write"; if (num_bytes % get_intra_op_parallelism()) { @@ -206,7 +205,7 @@ std::shared_ptr deepspeed_io_handle_t::_create_io_op_desc( const torch::Tensor& buffer, const int fd, const char* filename, - const long long int file_num_bytes, + const int64_t file_num_bytes, const bool validate) { return std::make_shared(read_op, @@ -224,13 +223,13 @@ int deepspeed_io_handle_t::pread(const torch::Tensor& buffer, const bool validate, const bool async) { - long long num_file_bytes; + int64_t num_file_bytes; if (-1 == get_file_size(filename, num_file_bytes)) { const auto error_code = errno; report_file_error(filename, " fstat for read", error_code); return -1; } - const auto buffer_bytes = static_cast(buffer.nbytes()); + const auto buffer_bytes = static_cast(buffer.nbytes()); if (buffer_bytes != num_file_bytes) { std::cout << filename << ": buffer nbytes != file bytes " << buffer_bytes << " != " << num_file_bytes << std::endl; @@ -257,7 +256,7 @@ int deepspeed_io_handle_t::pwrite(const torch::Tensor& buffer, const bool validate, const bool async) { - const auto num_write_bytes = static_cast(buffer.nbytes()); + const auto num_write_bytes = static_cast(buffer.nbytes()); assert((num_write_bytes % _intra_op_parallelism) == 0); if (!_is_valid_parallel_aio_op(false, num_write_bytes)) { return -1; } @@ -294,7 +293,7 @@ int deepspeed_io_handle_t::async_pwrite(const torch::Tensor& buffer, const char* return pwrite(buffer, filename, false, true); } -at::Tensor deepspeed_io_handle_t::new_cpu_locked_tensor(const long long int num_elem, +at::Tensor deepspeed_io_handle_t::new_cpu_locked_tensor(const int64_t num_elem, const torch::Tensor& example_tensor) { return _pinned_tensor_mgr->alloc(num_elem, example_tensor.scalar_type()); diff --git a/csrc/aio/py_lib/deepspeed_py_io_handle.h b/csrc/aio/py_lib/deepspeed_py_io_handle.h index 8e649be1e4c2..4fedf8080818 100644 --- a/csrc/aio/py_lib/deepspeed_py_io_handle.h +++ b/csrc/aio/py_lib/deepspeed_py_io_handle.h @@ -61,7 +61,7 @@ struct deepspeed_io_handle_t { int async_pwrite(const torch::Tensor& buffer, const char* filename); // TODO: Make API's args to be shape and dtype. - torch::Tensor new_cpu_locked_tensor(const long long int num_elem, + torch::Tensor new_cpu_locked_tensor(const int64_t num_elem, const torch::Tensor& example_tensor); bool free_cpu_locked_tensor(torch::Tensor&); @@ -74,13 +74,12 @@ struct deepspeed_io_handle_t { std::shared_ptr _wait_for_aio_work(); - bool _is_valid_parallel_aio_op(const bool read_op, const long long int num_bytes); + bool _is_valid_parallel_aio_op(const bool read_op, const int64_t num_bytes); - virtual std::shared_ptr _create_io_op_desc( - const bool read_op, - const torch::Tensor& buffer, - const int fd, - const char* filename, - const long long int file_num_bytes, - const bool validate); + virtual std::shared_ptr _create_io_op_desc(const bool read_op, + const torch::Tensor& buffer, + const int fd, + const char* filename, + const int64_t file_num_bytes, + const bool validate); }; diff --git a/op_builder/cpu/async_io.py b/op_builder/cpu/async_io.py index 56eb71303d78..493ef174566e 100644 --- a/op_builder/cpu/async_io.py +++ b/op_builder/cpu/async_io.py @@ -3,7 +3,6 @@ # DeepSpeed Team -import os import distutils.spawn import subprocess From 90e25da4ebeeaa7d8b2eb7a6def3eaba48ca2f42 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Wed, 9 Oct 2024 09:09:30 -0400 Subject: [PATCH 12/13] Skip fp16 tests on CPU --- tests/unit/runtime/zero/test_nvme_checkpointing.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/unit/runtime/zero/test_nvme_checkpointing.py b/tests/unit/runtime/zero/test_nvme_checkpointing.py index 75cba2e789c1..850c8eb3e349 100644 --- a/tests/unit/runtime/zero/test_nvme_checkpointing.py +++ b/tests/unit/runtime/zero/test_nvme_checkpointing.py @@ -15,6 +15,7 @@ from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum from deepspeed.runtime.zero.partition_parameters import Init from deepspeed.ops.aio import AsyncIOBuilder +from deepspeed.accelerator import get_accelerator class TestNVMeCheckpointing(DistributedTest): @@ -29,6 +30,9 @@ def test_nvme_checkpointing(self, tmpdir, param_offload_device, optim_offload_de first_stage_steps, second_stage_steps = 2, 2 + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") + if not deepspeed.ops.__compatible_ops__[AsyncIOBuilder.NAME]: pytest.skip('Skip tests since async-io is not compatible') From 8a52388eb57b3cb481acf63fefdb0ba61b80b932 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Wed, 9 Oct 2024 14:46:17 -0400 Subject: [PATCH 13/13] Add Cuda 12.6 --- op_builder/builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/op_builder/builder.py b/op_builder/builder.py index f95341f137b4..1609bc9005f4 100644 --- a/op_builder/builder.py +++ b/op_builder/builder.py @@ -76,7 +76,7 @@ def get_default_compute_capabilities(): cuda_minor_mismatch_ok = { 10: ["10.0", "10.1", "10.2"], 11: ["11.0", "11.1", "11.2", "11.3", "11.4", "11.5", "11.6", "11.7", "11.8"], - 12: ["12.0", "12.1", "12.2", "12.3", "12.4", "12.5"], + 12: ["12.0", "12.1", "12.2", "12.3", "12.4", "12.5", "12.6"], }