Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions cpp/src/arrow/buffer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1023,4 +1023,17 @@ TEST(TestBufferConcatenation, EmptyBuffer) {
AssertMyBufferEqual(*result, contents);
}

TEST(TestDeviceRegistry, Basics) {
  // Exercise the failure paths of the device mapper registry.

  // The CPU device type is already registered, so registering it again must
  // be rejected with a KeyError.
  auto cpu_mapper = [](int64_t device_id) { return default_cpu_memory_manager(); };
  ASSERT_RAISES(KeyError, RegisterDeviceMapper(DeviceAllocationType::kCPU, cpu_mapper));

  // VPI is not registered, so looking it up must fail with a KeyError.
  ASSERT_RAISES(KeyError, GetDeviceMapper(DeviceAllocationType::kVPI));
}

} // namespace arrow
11 changes: 5 additions & 6 deletions cpp/src/arrow/c/bridge.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1967,12 +1967,11 @@ Result<std::shared_ptr<RecordBatch>> ImportRecordBatch(struct ArrowArray* array,
return ImportRecordBatch(array, *maybe_schema);
}

Result<std::shared_ptr<MemoryManager>> DefaultDeviceMapper(ArrowDeviceType device_type,
int64_t device_id) {
if (device_type != ARROW_DEVICE_CPU) {
return Status::NotImplemented("Only importing data on CPU is supported");
}
return default_cpu_memory_manager();
// Default mapper used when importing device arrays: converts the C-level device
// type to a DeviceAllocationType, fetches the mapper registered for it and calls
// that mapper with `device_id`. Any lookup error is returned to the caller.
Result<std::shared_ptr<MemoryManager>> DefaultDeviceMemoryMapper(
    ArrowDeviceType device_type, int64_t device_id) {
  const auto allocation_type = static_cast<DeviceAllocationType>(device_type);
  ARROW_ASSIGN_OR_RAISE(auto mapper, GetDeviceMapper(allocation_type));
  return mapper(device_id);
}

Result<std::shared_ptr<Array>> ImportDeviceArray(struct ArrowDeviceArray* array,
Expand Down
12 changes: 6 additions & 6 deletions cpp/src/arrow/c/bridge.h
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,8 @@ using DeviceMemoryMapper =
std::function<Result<std::shared_ptr<MemoryManager>>(ArrowDeviceType, int64_t)>;

ARROW_EXPORT
Result<std::shared_ptr<MemoryManager>> DefaultDeviceMapper(ArrowDeviceType device_type,
int64_t device_id);
Result<std::shared_ptr<MemoryManager>> DefaultDeviceMemoryMapper(
ArrowDeviceType device_type, int64_t device_id);

/// \brief EXPERIMENTAL: Import C++ device array from the C data interface.
///
Expand All @@ -236,7 +236,7 @@ Result<std::shared_ptr<MemoryManager>> DefaultDeviceMapper(ArrowDeviceType devic
ARROW_EXPORT
Result<std::shared_ptr<Array>> ImportDeviceArray(
struct ArrowDeviceArray* array, std::shared_ptr<DataType> type,
const DeviceMemoryMapper& mapper = DefaultDeviceMapper);
const DeviceMemoryMapper& mapper = DefaultDeviceMemoryMapper);

/// \brief EXPERIMENTAL: Import C++ device array and its type from the C data interface.
///
Expand All @@ -253,7 +253,7 @@ Result<std::shared_ptr<Array>> ImportDeviceArray(
ARROW_EXPORT
Result<std::shared_ptr<Array>> ImportDeviceArray(
struct ArrowDeviceArray* array, struct ArrowSchema* type,
const DeviceMemoryMapper& mapper = DefaultDeviceMapper);
const DeviceMemoryMapper& mapper = DefaultDeviceMemoryMapper);

/// \brief EXPERIMENTAL: Import C++ record batch with buffers on a device from the C data
/// interface.
Expand All @@ -271,7 +271,7 @@ Result<std::shared_ptr<Array>> ImportDeviceArray(
ARROW_EXPORT
Result<std::shared_ptr<RecordBatch>> ImportDeviceRecordBatch(
struct ArrowDeviceArray* array, std::shared_ptr<Schema> schema,
const DeviceMemoryMapper& mapper = DefaultDeviceMapper);
const DeviceMemoryMapper& mapper = DefaultDeviceMemoryMapper);

/// \brief EXPERIMENTAL: Import C++ record batch with buffers on a device and its schema
/// from the C data interface.
Expand All @@ -291,7 +291,7 @@ Result<std::shared_ptr<RecordBatch>> ImportDeviceRecordBatch(
ARROW_EXPORT
Result<std::shared_ptr<RecordBatch>> ImportDeviceRecordBatch(
struct ArrowDeviceArray* array, struct ArrowSchema* schema,
const DeviceMemoryMapper& mapper = DefaultDeviceMapper);
const DeviceMemoryMapper& mapper = DefaultDeviceMemoryMapper);

/// @}

Expand Down
63 changes: 63 additions & 0 deletions cpp/src/arrow/device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
#include "arrow/device.h"

#include <cstring>
#include <mutex>
#include <unordered_map>
#include <utility>

#include "arrow/array.h"
Expand Down Expand Up @@ -268,4 +270,65 @@ std::shared_ptr<MemoryManager> CPUDevice::default_memory_manager() {
return default_cpu_memory_manager();
}

namespace {

/// \brief Table associating each DeviceAllocationType with one DeviceMapper.
///
/// Both member functions hold `lock_` for the whole time they touch the map.
class DeviceMapperRegistryImpl {
 public:
  DeviceMapperRegistryImpl() = default;

  /// \brief Store `memory_mapper` as the mapper for `device_type`.
  ///
  /// \param[in] device_type the device type to register
  /// \param[in] memory_mapper the mapper to store for that device type
  /// \return Status::OK on success, or KeyError if `device_type` already has
  /// an entry (the existing entry is left untouched)
  Status RegisterDevice(DeviceAllocationType device_type, DeviceMapper memory_mapper) {
    std::lock_guard<std::mutex> lock(lock_);
    // try_emplace only inserts when the key is absent, so a second registration
    // for the same device type never replaces the first one.
    const bool inserted =
        registry_.try_emplace(device_type, std::move(memory_mapper)).second;
    if (!inserted) {
      return Status::KeyError("Device type ", static_cast<int>(device_type),
                              " is already registered");
    }
    return Status::OK();
  }

  /// \brief Look up the mapper stored for `device_type`.
  ///
  /// \param[in] device_type the device type to look up
  /// \return a copy of the stored mapper, or KeyError if none is registered
  Result<DeviceMapper> GetMapper(DeviceAllocationType device_type) {
    std::lock_guard<std::mutex> lock(lock_);
    const auto it = registry_.find(device_type);
    if (it == registry_.end()) {
      // The leading space keeps the numeric device type separated from the
      // rest of the message.
      return Status::KeyError("Device type ", static_cast<int>(device_type),
                              " is not registered");
    }
    return it->second;
  }

 private:
  std::mutex lock_;
  std::unordered_map<DeviceAllocationType, DeviceMapper> registry_;
};

// Mapper registered for the CPU device type. `device_id` is not consulted:
// every id yields the default CPU memory manager.
Result<std::shared_ptr<MemoryManager>> DefaultCPUDeviceMapper(int64_t device_id) {
  std::shared_ptr<MemoryManager> cpu_manager = default_cpu_memory_manager();
  return cpu_manager;
}

static std::unique_ptr<DeviceMapperRegistryImpl> CreateDeviceRegistry() {
auto registry = std::make_unique<DeviceMapperRegistryImpl>();

// Always register the CPU device
DCHECK_OK(registry->RegisterDevice(DeviceAllocationType::kCPU, DefaultCPUDeviceMapper));

return registry;
}

// Returns the single registry instance. It is a function-local static, so it
// is created by CreateDeviceRegistry() on the first call and reused afterwards.
DeviceMapperRegistryImpl* GetDeviceRegistry() {
  static const std::unique_ptr<DeviceMapperRegistryImpl> registry_owner =
      CreateDeviceRegistry();
  return registry_owner.get();
}

} // namespace

// Forwards to the registry instance; returns whatever RegisterDevice reports
// (KeyError when `device_type` already has a mapper).
Status RegisterDeviceMapper(DeviceAllocationType device_type, DeviceMapper mapper) {
  return GetDeviceRegistry()->RegisterDevice(device_type, std::move(mapper));
}

// Forwards to the registry instance; returns the stored mapper or the KeyError
// produced by GetMapper when `device_type` has no entry.
Result<DeviceMapper> GetDeviceMapper(DeviceAllocationType device_type) {
  return GetDeviceRegistry()->GetMapper(device_type);
}

} // namespace arrow
28 changes: 28 additions & 0 deletions cpp/src/arrow/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -363,4 +363,32 @@ class ARROW_EXPORT CPUMemoryManager : public MemoryManager {
ARROW_EXPORT
std::shared_ptr<MemoryManager> default_cpu_memory_manager();

/// \brief Function that takes a device id and returns a MemoryManager.
///
/// Unlike the DeviceMemoryMapper used by the C bridge import functions, this
/// callable does not receive a device type: it is stored per device type in the
/// registry (see RegisterDeviceMapper / GetDeviceMapper).
using DeviceMapper =
std::function<Result<std::shared_ptr<MemoryManager>>(int64_t device_id)>;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, but a couple more suggestions to unify naming:

  1. rename MemoryMapper to DeviceMemoryMapper?
  2. rename RegisterDeviceMemoryManager to RegisterDeviceMemoryMapper
  3. rename GetDeviceMemoryManager to GetDeviceMemoryMapper

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good points, that naming is definitely more consistent.

There is, however, one problem: we already define a DeviceMemoryMapper as the type of the mapper keyword argument in the existing bridge.h Import methods:

using DeviceMemoryMapper =
std::function<Result<std::shared_ptr<MemoryManager>>(ArrowDeviceType, int64_t)>;

and we should probably find a distinct name, given that the two are slightly different (one takes device_type + device_id and returns a MemoryManager, while the other is a function already tied to a specific device_type and thus only takes a device_id, again returning a MemoryManager).

It's of course a subtle difference that might be difficult to embody in a name. But at least using distinct names seems best.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps DeviceIdMapper then? Not terribly pretty I admit...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that sounds good for the function type alias, but then I would personally leave the register/get functions as is? I would find RegisterDeviceIdMapper a bit strange with the focus on the id, because you are also registering a device type, it's just that the value you store for the registered type is the DeviceIdMapper ..

Anyway, in the end it doesn't matter that much, happy to go with whatever we come up with.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or DeviceMapper / RegisterDeviceMapper / GetDeviceMapper ? (that's a bit more generic, but keeps the three consistent with each other)


/// \brief Register a function to retrieve a MemoryManager for a Device type
///
/// This registers the device type globally. A specific device type can only
/// be registered once. This method is thread-safe.
///
/// Currently, this registry is only used for importing data through the C Device
/// Data Interface (for the default Device to MemoryManager mapper in
/// arrow::ImportDeviceArray/ImportDeviceRecordBatch).
///
/// \param[in] device_type the device type for which to register a MemoryManager
/// \param[in] mapper function that takes a device id and returns the appropriate
/// MemoryManager for the registered device type and given device id
/// \return Status
ARROW_EXPORT
Status RegisterDeviceMapper(DeviceAllocationType device_type, DeviceMapper mapper);

/// \brief Get the registered function to retrieve a MemoryManager for the
/// given Device type
///
/// \param[in] device_type the device type
/// \return function that takes a device id and returns the appropriate
/// MemoryManager for the registered device type and given device id
ARROW_EXPORT
Result<DeviceMapper> GetDeviceMapper(DeviceAllocationType device_type);

} // namespace arrow
19 changes: 19 additions & 0 deletions cpp/src/arrow/gpu/cuda_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <cuda.h>

#include "arrow/buffer.h"
#include "arrow/device.h"
#include "arrow/io/memory.h"
#include "arrow/memory_pool.h"
#include "arrow/status.h"
Expand Down Expand Up @@ -501,5 +502,23 @@ Result<std::shared_ptr<MemoryManager>> DefaultMemoryMapper(ArrowDeviceType devic
}
}

namespace {

// Mapper registered for the CUDA device type: builds the CudaDevice for
// `device_id` and hands back its default memory manager. A failure from
// CudaDevice::Make is returned to the caller.
Result<std::shared_ptr<MemoryManager>> DefaultCUDADeviceMapper(int64_t device_id) {
  ARROW_ASSIGN_OR_RAISE(auto cuda_device, CudaDevice::Make(device_id));
  return cuda_device->default_memory_manager();
}

bool RegisterCUDADeviceInternal() {
DCHECK_OK(RegisterDeviceMapper(DeviceAllocationType::kCUDA, DefaultCUDADeviceMapper));
// TODO add the CUDA_HOST and CUDA_MANAGED allocation types when they are supported in
// the CudaDevice
return true;
}

// Initializing this variable calls RegisterCUDADeviceInternal(), which is what
// performs the CUDA registration; the stored value itself is not read here.
static bool cuda_registered = RegisterCUDADeviceInternal();

} // namespace

} // namespace cuda
} // namespace arrow
4 changes: 3 additions & 1 deletion cpp/src/arrow/gpu/cuda_memory.h
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,9 @@ Result<uintptr_t> GetDeviceAddress(const uint8_t* cpu_data,
ARROW_EXPORT
Result<uint8_t*> GetHostAddress(uintptr_t device_ptr);

ARROW_EXPORT
/// \brief Deprecated mapper from a (device type, device id) pair to a MemoryManager.
///
/// \deprecated Deprecated in 16.0.0. The CUDA device is registered by default,
/// so arrow::DefaultDeviceMemoryMapper can be used instead.
ARROW_DEPRECATED(
    "Deprecated in 16.0.0. The CUDA device is registered by default, and you can use "
    "arrow::DefaultDeviceMemoryMapper instead.")
Result<std::shared_ptr<MemoryManager>> DefaultMemoryMapper(ArrowDeviceType device_type,
                                                           int64_t device_id);

Expand Down
15 changes: 2 additions & 13 deletions cpp/src/arrow/gpu/cuda_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -716,17 +716,6 @@ class TestCudaDeviceArrayRoundtrip : public ::testing::Test {
public:
using ArrayFactory = std::function<Result<std::shared_ptr<Array>>()>;

// Test-local mapper: resolves a CUDA device id to that device's default memory
// manager and rejects every other device type with NotImplemented.
static Result<std::shared_ptr<MemoryManager>> DeviceMapper(ArrowDeviceType type,
                                                           int64_t id) {
  if (type == ARROW_DEVICE_CUDA) {
    ARROW_ASSIGN_OR_RAISE(auto device_manager, cuda::CudaDeviceManager::Instance());
    ARROW_ASSIGN_OR_RAISE(auto cuda_device, device_manager->GetDevice(id));
    return cuda_device->default_memory_manager();
  }

  return Status::NotImplemented("should only be CUDA device");
}

// Returns a factory that, each time it is invoked, builds an array of `type`
// from the JSON text `json`. The factory keeps a copy of the `type` pointer and
// the raw `json` pointer, so the pointed-to string must stay valid while the
// factory is in use.
static ArrayFactory JSONArrayFactory(std::shared_ptr<DataType> type, const char* json) {
  auto make_array = [type, json]() { return ArrayFromJSON(type, json); };
  return make_array;
}
Expand Down Expand Up @@ -759,7 +748,7 @@ class TestCudaDeviceArrayRoundtrip : public ::testing::Test {

std::shared_ptr<Array> device_array_roundtripped;
ASSERT_OK_AND_ASSIGN(device_array_roundtripped,
ImportDeviceArray(&c_array, &c_schema, DeviceMapper));
ImportDeviceArray(&c_array, &c_schema));
ASSERT_TRUE(ArrowSchemaIsReleased(&c_schema));
ASSERT_TRUE(ArrowArrayIsReleased(&c_array.array));

Expand All @@ -779,7 +768,7 @@ class TestCudaDeviceArrayRoundtrip : public ::testing::Test {
ASSERT_OK(ExportDeviceArray(*device_array, sync, &c_array, &c_schema));
device_array_roundtripped.reset();
ASSERT_OK_AND_ASSIGN(device_array_roundtripped,
ImportDeviceArray(&c_array, &c_schema, DeviceMapper));
ImportDeviceArray(&c_array, &c_schema));
ASSERT_TRUE(ArrowSchemaIsReleased(&c_schema));
ASSERT_TRUE(ArrowArrayIsReleased(&c_array.array));

Expand Down