diff --git a/python/ray/_raylet.pxd b/python/ray/_raylet.pxd index 2788a49c8ebf..3b9eb53b7cd2 100644 --- a/python/ray/_raylet.pxd +++ b/python/ray/_raylet.pxd @@ -108,7 +108,8 @@ cdef class CoreWorker: owner_address=*) cdef store_task_outputs( self, worker, outputs, const c_vector[CObjectID] return_ids, - c_vector[shared_ptr[CRayObject]] *returns) + c_vector[shared_ptr[CRayObject]] *returns, + c_vector[CObjectID]* contained_ids) cdef yield_current_fiber(self, CFiberEvent &fiber_event) cdef make_actor_handle(self, ActorHandleSharedPtr c_actor_handle) diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 128cd027556a..5c09fa29f84b 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -394,7 +394,8 @@ cdef execute_task( const c_vector[CObjectID] &c_arg_reference_ids, const c_vector[CObjectID] &c_return_ids, const c_string debugger_breakpoint, - c_vector[shared_ptr[CRayObject]] *returns): + c_vector[shared_ptr[CRayObject]] *returns, + c_vector[CObjectID] *contained_id): worker = ray.worker.global_worker manager = worker.function_actor_manager @@ -563,7 +564,7 @@ cdef execute_task( # Store the outputs in the object store. with core_worker.profile_event(b"task:store_outputs"): core_worker.store_task_outputs( - worker, outputs, c_return_ids, returns) + worker, outputs, c_return_ids, returns, contained_id) except Exception as error: # If the debugger is enabled, drop into the remote pdb here. 
if "RAY_PDB" in os.environ: @@ -582,7 +583,7 @@ cdef execute_task( for _ in range(c_return_ids.size()): errors.append(failure_object) core_worker.store_task_outputs( - worker, errors, c_return_ids, returns) + worker, errors, c_return_ids, returns, contained_id) ray._private.utils.push_error_to_driver( worker, ray_constants.TASK_PUSH_ERROR, @@ -619,6 +620,7 @@ cdef CRayStatus task_execution_handler( const c_vector[CObjectID] &c_return_ids, const c_string debugger_breakpoint, c_vector[shared_ptr[CRayObject]] *returns, + c_vector[CObjectID] *contained_id, shared_ptr[LocalMemoryBuffer] &creation_task_exception_pb_bytes) nogil: with gil: try: @@ -628,7 +630,7 @@ cdef CRayStatus task_execution_handler( # it does, that indicates that there was an internal error. execute_task(task_type, task_name, ray_function, c_resources, c_args, c_arg_reference_ids, c_return_ids, - debugger_breakpoint, returns) + debugger_breakpoint, returns, contained_id) except Exception as e: sys_exit = SystemExit() if isinstance(e, RayActorError) and \ @@ -1633,17 +1635,22 @@ cdef class CoreWorker: cdef store_task_outputs( self, worker, outputs, const c_vector[CObjectID] return_ids, - c_vector[shared_ptr[CRayObject]] *returns): + c_vector[shared_ptr[CRayObject]] *returns, + c_vector[CObjectID]* contained_id_out): cdef: CObjectID return_id size_t data_size shared_ptr[CBuffer] metadata + c_vector[CObjectID]* contained_id_ptr c_vector[CObjectID] contained_id c_vector[CObjectID] return_ids_vector if return_ids.size() == 0: return - + if contained_id_out == NULL: + contained_id_ptr = &contained_id + else: + contained_id_ptr = contained_id_out n_returns = len(outputs) returns.resize(n_returns) for i in range(n_returns): @@ -1661,13 +1668,13 @@ cdef class CoreWorker: # Reset debugging context of this worker. 
ray.worker.global_worker.debugger_get_breakpoint = b"" metadata = string_to_buffer(metadata_str) - contained_id = ObjectRefsToVector( + contained_id_ptr[0] = ObjectRefsToVector( serialized_object.contained_object_refs) with nogil: check_status( CCoreWorkerProcess.GetCoreWorker().AllocateReturnObject( - return_id, data_size, metadata, contained_id, + return_id, data_size, metadata, contained_id_ptr[0], &returns[0][i])) if returns[0][i].get() != NULL: diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index 23fd4b9944f1..a3b8fc2ce97b 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -266,6 +266,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: const c_vector[CObjectID] &return_ids, const c_string debugger_breakpoint, c_vector[shared_ptr[CRayObject]] *returns, + c_vector[CObjectID] *return_contained_ids, shared_ptr[LocalMemoryBuffer] &creation_task_exception_pb_bytes) nogil ) task_execution_callback diff --git a/python/ray/tests/test_basic_3.py b/python/ray/tests/test_basic_3.py index 9ddf336710b1..5fd4dc2a289b 100644 --- a/python/ray/tests/test_basic_3.py +++ b/python/ray/tests/test_basic_3.py @@ -18,6 +18,22 @@ logger = logging.getLogger(__name__) +def test_object_transfer(shutdown_only): + ray.init() + + @ray.remote + class Test: + def gen(self): + r = ray.put(b"a" * 10 * 1024 * 1024) + return [r] + + actor = Test.remote() + v = actor.gen.remote() + ray.wait([v]) + ray.kill(actor) + assert ray.get(ray.get(v)[0]) == b"a" * 10 * 1024 * 1024 + + def test_auto_global_gc(shutdown_only): # 100MB ray.init(num_cpus=1, object_store_memory=100 * 1024 * 1024) diff --git a/src/ray/common/ray_config_def.h b/src/ray/common/ray_config_def.h index e25b337b48fe..bb6d1d8c585a 100644 --- a/src/ray/common/ray_config_def.h +++ b/src/ray/common/ray_config_def.h @@ -415,4 +415,8 @@ RAY_CONFIG(bool, gcs_task_scheduling_enabled, getenv("RAY_GCS_TASK_SCHEDULING_ENABLED") != nullptr && 
getenv("RAY_GCS_TASK_SCHEDULING_ENABLED") == std::string("true")) +RAY_CONFIG(bool, ownership_transfer_enabled, + getenv("RAY_TRANSFER_OWNERSHIP") != nullptr && + getenv("RAY_TRANSFER_OWNERSHIP") == std::string("1")) + RAY_CONFIG(uint32_t, max_error_msg_size_bytes, 512 * 1024) diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index ef4ef574b6bc..ea74c17f3f7f 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -374,12 +374,15 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_ // Initialize task receivers. if (options_.worker_type == WorkerType::WORKER || options_.is_local_mode) { RAY_CHECK(options_.task_execution_callback != nullptr); - auto execute_task = - std::bind(&CoreWorker::ExecuteTask, this, std::placeholders::_1, - std::placeholders::_2, std::placeholders::_3, std::placeholders::_4); + auto execute_task = std::bind(&CoreWorker::ExecuteTask, this, std::placeholders::_1, + std::placeholders::_2, std::placeholders::_3, + std::placeholders::_4, std::placeholders::_5); + auto object_transfer = + std::bind(&CoreWorker::ShareOwnershipInternal, this, std::placeholders::_1, + std::placeholders::_2, std::placeholders::_3); direct_task_receiver_ = std::make_unique( worker_context_, task_execution_service_, execute_task, - [this] { return local_raylet_client_->TaskDone(); }); + [this] { return local_raylet_client_->TaskDone(); }, object_transfer); } // Initialize raylet client. 
@@ -554,7 +557,7 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_ "CoreWorker.ReconstructObject"); }; task_manager_.reset(new TaskManager( - memory_store_, reference_counter_, + memory_store_, reference_counter_, rpc_address_, /* retry_task_callback= */ [this](TaskSpecification &spec, bool delay) { if (delay) { @@ -2069,6 +2072,7 @@ Status CoreWorker::AllocateReturnObject(const ObjectID &object_id, Status CoreWorker::ExecuteTask(const TaskSpecification &task_spec, const std::shared_ptr &resource_ids, std::vector> *return_objects, + std::vector *contained_ids, ReferenceCounter::ReferenceTableProto *borrowed_refs) { RAY_LOG(DEBUG) << "Executing task, task info = " << task_spec.DebugString(); task_queue_length_ -= 1; @@ -2126,7 +2130,7 @@ Status CoreWorker::ExecuteTask(const TaskSpecification &task_spec, status = options_.task_execution_callback( task_type, task_spec.GetName(), func, task_spec.GetRequiredResources().GetResourceMap(), args, arg_reference_ids, - return_ids, task_spec.GetDebuggerBreakpoint(), return_objects, + return_ids, task_spec.GetDebuggerBreakpoint(), return_objects, contained_ids, creation_task_exception_pb_bytes); // Get the reference counts for any IDs that we borrowed during this task and @@ -2217,7 +2221,8 @@ void CoreWorker::ExecuteTaskLocalMode(const TaskSpecification &task_spec, } auto old_id = GetActorId(); SetActorId(actor_id); - RAY_UNUSED(ExecuteTask(task_spec, resource_ids, &return_objects, &borrowed_refs)); + RAY_UNUSED( + ExecuteTask(task_spec, resource_ids, &return_objects, nullptr, &borrowed_refs)); SetActorId(old_id); } @@ -2824,6 +2829,79 @@ void CoreWorker::HandleRunOnUtilWorker(const rpc::RunOnUtilWorkerRequest &reques } } +void CoreWorker::ShareOwnershipInternal( + const rpc::Address &to_addr, const std::vector &ids, + std::function)> cb) { + std::vector> node_id_mapping; + for (auto id : ids) { + if (!reference_counter_->OwnedByUs(id)) { + continue; + } + auto node_id = 
reference_counter_->GetObjectPinnedLocation(id); + if (node_id) { + node_id_mapping.emplace_back(*node_id, id); + } else { + // TODO (yic) Should wait until object is ready. + RAY_LOG(DEBUG) << "We only take care of put objects right now"; + continue; + } + } + + if (node_id_mapping.empty()) { + cb({}); + } else { + auto in_flight = std::make_shared<std::atomic<size_t>>(node_id_mapping.size()); + auto succeeded_ids = std::make_shared<absl::flat_hash_set<ObjectID>>(); + for (auto &v : node_id_mapping) { + auto node_info = gcs_client_->Nodes().Get(v.first); + auto grpc_client = rpc::NodeManagerWorkerClient::make( + node_info->node_manager_address(), node_info->node_manager_port(), + *client_call_manager_); + auto raylet_client = std::make_shared<raylet::RayletClient>(std::move(grpc_client)); + raylet_client->PinObjectIDs( + to_addr, {v.second}, + [this, to_addr, in_flight, succeeded_ids, id = v.second, node_id = v.first, cb]( + auto &status, auto &pin_reply) mutable { + if (status.ok()) { + succeeded_ids->insert(id); + } + // TODO (yic): better with a barrier + if (--*in_flight == 0) { + absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> results; + bool exception = false; + plasma_store_provider_->Get(*succeeded_ids, -1, worker_context_, &results, + &exception); + RAY_CHECK(!exception) << "Failed to get object from store"; + google::protobuf::RepeatedPtrField<rpc::SharedObjectInfo> transferred_objs; + for (auto &result : results) { + RAY_CHECK(result.second->IsInPlasmaError()) + << "Inline objects are shared by passing value"; + auto obj = transferred_objs.Add(); + obj->set_object_id(result.first.Binary()); + obj->set_object_size(result.second->GetSize()); + obj->set_pinned_at_node(node_id.Binary()); + } + cb(std::move(transferred_objs)); + } + }); + } + } +} + +void CoreWorker::HandleShareOwnership(const rpc::ShareOwnershipRequest &request, + rpc::ShareOwnershipReply *reply, + rpc::SendReplyCallback send_reply_callback) { + std::vector<ObjectID> ids; + for (const auto &id : request.object_ids()) { + ids.push_back(ObjectID::FromBinary(id)); + } + const auto &addr = request.new_owner_address(); + 
ShareOwnershipInternal(addr, ids, [send_reply_callback, reply](auto ids) { + reply->mutable_shared_objs()->Swap(&ids); + send_reply_callback(Status::OK(), nullptr, nullptr); + }); +} + void CoreWorker::HandleSpillObjects(const rpc::SpillObjectsRequest &request, rpc::SpillObjectsReply *reply, rpc::SendReplyCallback send_reply_callback) { diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 29884dc7642a..5e352f83f79c 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -68,6 +68,7 @@ struct CoreWorkerOptions { const std::vector &arg_reference_ids, const std::vector &return_ids, const std::string &debugger_breakpoint, std::vector> *results, + std::vector *return_contained_ids, std::shared_ptr &creation_task_exception_pb_bytes)>; CoreWorkerOptions() @@ -992,6 +993,10 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { void HandleExit(const rpc::ExitRequest &request, rpc::ExitReply *reply, rpc::SendReplyCallback send_reply_callback) override; + void HandleShareOwnership(const rpc::ShareOwnershipRequest &request, + rpc::ShareOwnershipReply *reply, + rpc::SendReplyCallback send_reply_callback) override; + /// /// Public methods related to async actor call. This should only be used when /// the actor is (1) direct actor and (2) using asyncio mode. 
@@ -1023,6 +1028,10 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { bool IsExiting() const; private: + void ShareOwnershipInternal( + const rpc::Address &to_addr, const std::vector &ids, + std::function)> cb); + void SetCurrentTaskId(const TaskID &task_id); void SetActorId(const ActorID &actor_id); @@ -1089,6 +1098,7 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { Status ExecuteTask(const TaskSpecification &task_spec, const std::shared_ptr &resource_ids, std::vector> *return_objects, + std::vector *contained_ids, ReferenceCounter::ReferenceTableProto *borrowed_refs); /// Execute a local mode task (runs normal ExecuteTask) diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc index e1d0a3b49544..3fa57972d01a 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc @@ -102,6 +102,7 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize( const std::vector &arg_reference_ids, const std::vector &return_ids, const std::string &debugger_breakpoint, std::vector> *results, + std::vector *return_contained_ids, std::shared_ptr &creation_task_exception_pb) { JNIEnv *env = GetJNIEnv(); RAY_CHECK(java_task_executor); @@ -169,6 +170,11 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize( return_objects[i]->HasData() ? 
return_objects[i]->GetData()->Size() : 0; auto &metadata = return_objects[i]->GetMetadata(); auto &contained_object_id = return_objects[i]->GetNestedIds(); + if (return_contained_ids != nullptr) { + return_contained_ids->insert(return_contained_ids->end(), + contained_object_id.begin(), + contained_object_id.end()); + } auto result_ptr = &(*results)[0]; RAY_CHECK_OK(ray::CoreWorkerProcess::GetCoreWorker().AllocateReturnObject( diff --git a/src/ray/core_worker/reference_count.cc b/src/ray/core_worker/reference_count.cc index 6c5133a893b2..aafba128a2b0 100644 --- a/src/ray/core_worker/reference_count.cc +++ b/src/ray/core_worker/reference_count.cc @@ -769,6 +769,16 @@ void ReferenceCounter::CleanupBorrowersOnRefRemoved( DeleteReferenceInternal(it, nullptr); } +void ReferenceCounter::RemoveBorrower(const ObjectID &object_id, + const rpc::Address &address) { + absl::MutexLock lock(&mutex_); + auto it = object_id_refs_.find(object_id); + RAY_CHECK(it != object_id_refs_.end()) << object_id; + if (it->second.borrowers.erase(address)) { + DeleteReferenceInternal(it, nullptr); + } +} + void ReferenceCounter::WaitForRefRemoved(const ReferenceTable::iterator &ref_it, const rpc::WorkerAddress &addr, const ObjectID &contained_in_id) { @@ -986,7 +996,7 @@ bool ReferenceCounter::RemoveObjectLocation(const ObjectID &object_id, } absl::optional> ReferenceCounter::GetObjectLocations( - const ObjectID &object_id) { + const ObjectID &object_id) const { absl::MutexLock lock(&mutex_); auto it = object_id_refs_.find(object_id); if (it == object_id_refs_.end()) { @@ -997,6 +1007,18 @@ absl::optional> ReferenceCounter::GetObjectLocations return it->second.locations; } +absl::optional ReferenceCounter::GetObjectPinnedLocation( + const ObjectID &object_id) const { + absl::MutexLock lock(&mutex_); + auto it = object_id_refs_.find(object_id); + if (it == object_id_refs_.end()) { + RAY_LOG(WARNING) << "Tried to get the object locations for an object " << object_id + << " that doesn't exist in 
the reference table"; + return absl::nullopt; + } + return it->second.pinned_at_raylet_id; +} + size_t ReferenceCounter::GetObjectSize(const ObjectID &object_id) const { absl::MutexLock lock(&mutex_); auto it = object_id_refs_.find(object_id); diff --git a/src/ray/core_worker/reference_count.h b/src/ray/core_worker/reference_count.h index 33d84ad9ccec..de62e797771f 100644 --- a/src/ray/core_worker/reference_count.h +++ b/src/ray/core_worker/reference_count.h @@ -286,6 +286,9 @@ class ReferenceCounter : public ReferenceCounterInterface, /// \param[in] object_id The object that we were borrowing. void HandleRefRemoved(const ObjectID &object_id) EXCLUSIVE_LOCKS_REQUIRED(mutex_); + void RemoveBorrower(const ObjectID &object_id, const rpc::Address &address) + LOCKS_EXCLUDED(mutex_); + /// Returns the total number of ObjectIDs currently in scope. size_t NumObjectIDsInScope() const LOCKS_EXCLUDED(mutex_); @@ -404,7 +407,10 @@ class ReferenceCounter : public ReferenceCounterInterface, /// \return The nodes that have the object if the reference exists, empty optional /// otherwise. absl::optional<absl::flat_hash_set<NodeID>> GetObjectLocations( - const ObjectID &object_id) LOCKS_EXCLUDED(mutex_); + const ObjectID &object_id) const LOCKS_EXCLUDED(mutex_); + + absl::optional<NodeID> GetObjectPinnedLocation(const ObjectID &object_id) const + LOCKS_EXCLUDED(mutex_); /// Subscribe to object location changes that are more recent than the given version. /// The provided callback will be invoked when new locations become available. 
diff --git a/src/ray/core_worker/task_manager.cc b/src/ray/core_worker/task_manager.cc index 72136fb15ccf..9a0520ed0f22 100644 --- a/src/ray/core_worker/task_manager.cc +++ b/src/ray/core_worker/task_manager.cc @@ -230,6 +230,13 @@ void TaskManager::CompletePendingTask(const TaskID &task_id, } } + for (auto &obj : reply.shared_obj_info()) { + ObjectID object_id = ObjectID::FromBinary(obj.object_id()); + reference_counter_->AddOwnedObject(object_id, {}, rpc_address_, "", + obj.object_size(), false, + NodeID::FromBinary(obj.pinned_at_node())); + } + TaskSpecification spec; bool release_lineage = true; { diff --git a/src/ray/core_worker/task_manager.h b/src/ray/core_worker/task_manager.h index bfea53fb3b4a..11682f6bf3ab 100644 --- a/src/ray/core_worker/task_manager.h +++ b/src/ray/core_worker/task_manager.h @@ -63,11 +63,12 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa public: TaskManager(std::shared_ptr in_memory_store, std::shared_ptr reference_counter, - RetryTaskCallback retry_task_callback, + const rpc::Address &address, RetryTaskCallback retry_task_callback, const std::function &check_node_alive, ReconstructObjectCallback reconstruct_object_callback) : in_memory_store_(in_memory_store), reference_counter_(reference_counter), + rpc_address_(address), retry_task_callback_(retry_task_callback), check_node_alive_(check_node_alive), reconstruct_object_callback_(reconstruct_object_callback) { @@ -249,6 +250,8 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa /// submitted tasks (dependencies and return objects). std::shared_ptr reference_counter_; + rpc::Address rpc_address_; + /// Called when a task should be retried. 
const RetryTaskCallback retry_task_callback_; diff --git a/src/ray/core_worker/transport/direct_actor_transport.cc b/src/ray/core_worker/transport/direct_actor_transport.cc index 4ca73e77dae8..6121e930ab05 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.cc +++ b/src/ray/core_worker/transport/direct_actor_transport.cc @@ -473,16 +473,15 @@ void CoreWorkerDirectTaskReceiver::HandleTask( RAY_CHECK(num_returns >= 0); std::vector> return_objects; + std::vector return_contained_ids; auto status = task_handler_(task_spec, resource_ids, &return_objects, - reply->mutable_borrowed_refs()); - + &return_contained_ids, reply->mutable_borrowed_refs()); bool objects_valid = return_objects.size() == num_returns; if (objects_valid) { for (size_t i = 0; i < return_objects.size(); i++) { auto return_object = reply->add_return_objects(); ObjectID id = ObjectID::FromIndex(task_spec.TaskId(), /*index=*/i + 1); return_object->set_object_id(id.Binary()); - // The object is nullptr if it already existed in the object store. const auto &result = return_objects[i]; return_object->set_size(result->GetSize()); @@ -510,19 +509,46 @@ void CoreWorkerDirectTaskReceiver::HandleTask( RAY_CHECK_OK(task_done_()); } } - if (status.ShouldExitWorker()) { - // Don't allow the worker to be reused, even though the reply status is OK. - // The worker will be shutting down shortly. - reply->set_worker_exiting(true); - if (objects_valid) { - // This happens when max_calls is hit. We still need to return the objects. 
- send_reply_callback(Status::OK(), nullptr, nullptr); + + if (RayConfig::instance().ownership_transfer_enabled()) { + // Pin the object in raylet + auto &caller_addr = task_spec.CallerAddress(); + transfer_handler_( + caller_addr, return_contained_ids, + [status, send_reply_callback, objects_valid, reply, num_returns]( + google::protobuf::RepeatedPtrField shared_objs) { + reply->mutable_shared_obj_info()->Swap(&shared_objs); + if (status.ShouldExitWorker()) { + // Don't allow the worker to be reused, even though the reply + // status is OK. The worker will be shutting down shortly. + reply->set_worker_exiting(true); + if (objects_valid) { + // This happens when max_calls is hit. We still need to + // return the objects. + send_reply_callback(Status::OK(), nullptr, nullptr); + } else { + send_reply_callback(status, nullptr, nullptr); + } + } else { + RAY_CHECK(objects_valid); + send_reply_callback(status, nullptr, nullptr); + } + }); + } else { + if (status.ShouldExitWorker()) { + // Don't allow the worker to be reused, even though the reply status is OK. + // The worker will be shutting down shortly. + reply->set_worker_exiting(true); + if (objects_valid) { + // This happens when max_calls is hit. We still need to return the objects. 
+ send_reply_callback(Status::OK(), nullptr, nullptr); + } else { + send_reply_callback(status, nullptr, nullptr); + } } else { + RAY_CHECK(objects_valid) << return_objects.size() << " " << num_returns; send_reply_callback(status, nullptr, nullptr); } - } else { - RAY_CHECK(objects_valid) << return_objects.size() << " " << num_returns; - send_reply_callback(status, nullptr, nullptr); } }; diff --git a/src/ray/core_worker/transport/direct_actor_transport.h b/src/ray/core_worker/transport/direct_actor_transport.h index 71761cee6ff3..be81ff9eb5b1 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.h +++ b/src/ray/core_worker/transport/direct_actor_transport.h @@ -611,18 +611,23 @@ class CoreWorkerDirectTaskReceiver { std::function resource_ids, std::vector> *return_objects, + std::vector *return_contained_ids, ReferenceCounter::ReferenceTableProto *borrower_refs)>; using OnTaskDone = std::function; - + using ObjectTransfer = std::function &, + std::function)>)>; CoreWorkerDirectTaskReceiver(WorkerContext &worker_context, instrumented_io_context &main_io_service, const TaskHandler &task_handler, - const OnTaskDone &task_done) + const OnTaskDone &task_done, + const ObjectTransfer &transfer_handler) : worker_context_(worker_context), task_handler_(task_handler), task_main_io_service_(main_io_service), - task_done_(task_done) {} + task_done_(task_done), + transfer_handler_(transfer_handler) {} /// Initialize this receiver. This must be called prior to use. void Init(std::shared_ptr, rpc::Address rpc_address, @@ -671,7 +676,7 @@ class CoreWorkerDirectTaskReceiver { std::shared_ptr pool_; /// Whether this actor use asyncio for concurrency. bool is_asyncio_ = false; - + ObjectTransfer transfer_handler_; /// Set the max concurrency of an actor. /// This should be called once for the actor creation task. 
void SetMaxActorConcurrency(bool is_asyncio, int max_concurrency); diff --git a/src/ray/protobuf/core_worker.proto b/src/ray/protobuf/core_worker.proto index 03bcd225e7d1..3c0bcf3e325a 100644 --- a/src/ray/protobuf/core_worker.proto +++ b/src/ray/protobuf/core_worker.proto @@ -112,6 +112,7 @@ message PushTaskReply { // may now be borrowing. The reference counts also include any new borrowers // that the worker created by passing a borrowed ID into a nested task. repeated ObjectReferenceCount borrowed_refs = 3; + repeated SharedObjectInfo shared_obj_info = 4; } message DirectActorCallArgWaitCompleteRequest { @@ -362,6 +363,21 @@ message RunOnUtilWorkerRequest { message RunOnUtilWorkerReply { } +message ShareOwnershipRequest { + repeated bytes object_ids = 1; + Address new_owner_address = 2; +} + +message SharedObjectInfo { + bytes object_id = 1; + bytes pinned_at_node = 2; + uint64 object_size = 3; +} + +message ShareOwnershipReply { + repeated SharedObjectInfo shared_objs = 1; +} + service CoreWorkerService { // Push a task directly to this worker from another. rpc PushTask(PushTaskRequest) returns (PushTaskReply); @@ -417,4 +433,5 @@ service CoreWorkerService { rpc RunOnUtilWorker(RunOnUtilWorkerRequest) returns (RunOnUtilWorkerReply); // Request for a worker to exit. 
rpc Exit(ExitRequest) returns (ExitReply); + rpc ShareOwnership(ShareOwnershipRequest) returns (ShareOwnershipReply); } diff --git a/src/ray/rpc/worker/core_worker_client.h b/src/ray/rpc/worker/core_worker_client.h index ec167e3af7cc..0e92e5dde11a 100644 --- a/src/ray/rpc/worker/core_worker_client.h +++ b/src/ray/rpc/worker/core_worker_client.h @@ -187,6 +187,9 @@ class CoreWorkerClientInterface { const DeleteSpilledObjectsRequest &request, const ClientCallback &callback) {} + virtual void ShareOwnership(const ShareOwnershipRequest &request, + const ClientCallback &callback) {} + virtual void AddSpilledUrl(const AddSpilledUrlRequest &request, const ClientCallback &callback) {} @@ -257,6 +260,8 @@ class CoreWorkerClient : public std::enable_shared_from_this, VOID_RPC_CLIENT_METHOD(CoreWorkerService, DeleteSpilledObjects, grpc_client_, override) + VOID_RPC_CLIENT_METHOD(CoreWorkerService, ShareOwnership, grpc_client_, override) + VOID_RPC_CLIENT_METHOD(CoreWorkerService, AddSpilledUrl, grpc_client_, override) VOID_RPC_CLIENT_METHOD(CoreWorkerService, RunOnUtilWorker, grpc_client_, override) diff --git a/src/ray/rpc/worker/core_worker_server.h b/src/ray/rpc/worker/core_worker_server.h index 04b9a720c061..918e39e7f98f 100644 --- a/src/ray/rpc/worker/core_worker_server.h +++ b/src/ray/rpc/worker/core_worker_server.h @@ -45,6 +45,7 @@ namespace rpc { RPC_SERVICE_HANDLER(CoreWorkerService, SpillObjects) \ RPC_SERVICE_HANDLER(CoreWorkerService, RestoreSpilledObjects) \ RPC_SERVICE_HANDLER(CoreWorkerService, DeleteSpilledObjects) \ + RPC_SERVICE_HANDLER(CoreWorkerService, ShareOwnership) \ RPC_SERVICE_HANDLER(CoreWorkerService, AddSpilledUrl) \ RPC_SERVICE_HANDLER(CoreWorkerService, PlasmaObjectReady) \ RPC_SERVICE_HANDLER(CoreWorkerService, RunOnUtilWorker) \ @@ -67,6 +68,7 @@ namespace rpc { DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(LocalGC) \ DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(SpillObjects) \ DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(RestoreSpilledObjects) \ + 
DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(ShareOwnership) \ DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(DeleteSpilledObjects) \ DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(AddSpilledUrl) \ DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(PlasmaObjectReady) \