diff --git a/cpprb/PyReplayBuffer.pyx b/cpprb/PyReplayBuffer.pyx index f441d890..eac55ea7 100644 --- a/cpprb/PyReplayBuffer.pyx +++ b/cpprb/PyReplayBuffer.pyx @@ -1,25 +1,26 @@ # distutils: language = c++ # cython: linetrace=True - +import base64 import ctypes +import multiprocessing +import multiprocessing as mp +import time +import warnings +from functools import partial from logging import getLogger, StreamHandler, Formatter, INFO -from multiprocessing import Event, Lock, Process +from multiprocessing import Event, Lock +from multiprocessing import shared_memory +from multiprocessing.managers import SyncManager from multiprocessing.sharedctypes import Value, RawValue, RawArray -import time from typing import Any, Dict, Callable, Optional -import warnings -cimport numpy as np -import numpy as np import cython -from cython.operator cimport dereference - +import numpy as np from cpprb.ReplayBuffer cimport * from .VectorWrapper cimport * -from .VectorWrapper import (VectorWrapper, - VectorInt,VectorSize_t, - VectorDouble,PointerDouble,VectorFloat) +from .VectorWrapper import (VectorSize_t, + PointerDouble, VectorFloat) def default_logger(level=INFO): """ @@ -464,10 +465,108 @@ cdef class SharedBuffer: return (SharedBuffer,(self.view.shape,self.dtype,self.data)) +@cython.embedsignature(True) +cdef class SharedMemoryBuffer: + cdef dtype + cdef data + cdef data_ndarray + cdef view + cdef str shm_name + cdef i_am_the_creator + cdef shape + + def __init__(self,shape,dtype,data=None,shm_name=None): + self.shm_name = shm_name + self.dtype = np.dtype(dtype) + self.i_am_the_creator = False + + + if isinstance(shape,np.ndarray): + self.shape = shape.tolist() + + self.shape = tuple(shape) + + if data is None: + + n_elems = int(np.array(shape,copy=False,dtype="int").prod()) + size = self.dtype.itemsize * n_elems + self.data = shared_memory.SharedMemory(name=shm_name, create=True, size=size) + self.i_am_the_creator = True + + elif data is not None: + self.data = data + + self.data_ndarray = np.ndarray(shape=self.shape, + dtype=self.dtype, + buffer=self.data.buf) + + + # Reinterpretation + if self.dtype != self.data_ndarray.dtype: + self.view = self.data_ndarray.view(self.dtype) + else: + self.view = self.data_ndarray + + + + + def __getitem__(self,key): + return self.view[key] + + def __setitem__(self,key,value): + self.view[key] = value + + # def __reduce__(self): + # return SharedMemoryBuffer,self.view.shape,self.dtype,None,self.shm_name,True + + def __getstate__(self): + # Copy the object's state from self.__dict__ which contains + # all our instance attributes. Always use the dict.copy() + # method to avoid modifying the original state. + state = {"shm_name": self.shm_name, + "dtype": self.dtype, + "i_am_the_creator": False, + "shape": self.shape} + + # Remove the unpicklable entries. + return state + + def __setstate__(self, state): + # Restore instance attributes. + + self.shm_name = state["shm_name"] + self.dtype = state["dtype"] + self.i_am_the_creator = False + self.shape = state["shape"] + + # Restore shared memory + self.data = shared_memory.SharedMemory(name=self.shm_name, create=False) + + self.data_ndarray = np.ndarray(shape=self.shape, + dtype=self.dtype, + buffer=self.data.buf) + + # Reinterpretation + if self.dtype != self.data_ndarray.dtype: + self.view = self.data_ndarray.view(self.dtype) + else: + self.view = self.data_ndarray + + + def __del__(self): + if self.data is not None: + self.data.close() + if self.i_am_the_creator: + self.data.unlink() + + + + def dict2buffer(buffer_size: int,env_dict: Dict,*, stack_compress = None, default_dtype = None, mmap_prefix: Optional[str] = None, - shared: bool = False): + shared: bool = False, + shm_name = None): """Create buffer from env_dict Parameters @@ -483,6 +582,8 @@ def dict2buffer(buffer_size: int,env_dict: Dict,*, mmap_prefix : str, optional File name prefix to save buffer data using mmap. If `None` (default), save only on memory. + shm_name : str, optional + multiprocessing.SharedMemory string name Returns ------- @@ -494,9 +595,12 @@ def dict2buffer(buffer_size: int,env_dict: Dict,*, default_dtype = default_dtype or np.single def zeros(name,shape,dtype): - if shared: + if shared and shm_name is None: return SharedBuffer(shape,dtype) + if shm_name: + return SharedMemoryBuffer(shape=shape,dtype=dtype,shm_name=shm_name+"."+name) + if mmap_prefix: if not isinstance(shape,tuple): shape = tuple(shape) @@ -527,6 +631,7 @@ def dict2buffer(buffer_size: int,env_dict: Dict,*, shape[0] = -1 defs["add_shape"] = shape + return buffer def find_array(dict,key): @@ -835,10 +940,16 @@ cdef class RingBufferIndex: cdef buffer_size cdef is_full - def __init__(self,buffer_size): - self.index = RawValue(ctypes.c_size_t,0) - self.buffer_size = RawValue(ctypes.c_size_t,buffer_size) - self.is_full = RawValue(ctypes.c_int,0) + def __init__(self,buffer_size,m=None): + + if m is not None: + self.index = m.Value(ctypes.c_size_t,0) + self.buffer_size = m.Value(ctypes.c_size_t,buffer_size) + self.is_full = m.Value(ctypes.c_int,0) + else: + self.index = RawValue(ctypes.c_size_t, 0) + self.buffer_size = RawValue(ctypes.c_size_t, buffer_size) + self.is_full = RawValue(ctypes.c_int, 0) cdef size_t get_next_index(self): return self.index.value @@ -884,9 +995,12 @@ cdef class ProcessSafeRingBufferIndex(RingBufferIndex): """ cdef lock - def __init__(self,buffer_size): - super().__init__(buffer_size) - self.lock = Lock() + def __init__(self,buffer_size,m=None): + super().__init__(buffer_size,m) + if m is not None: + self.lock = m.Lock() + else: + self.lock = Lock() cdef size_t get_next_index(self): with self.lock: @@ -1843,9 +1957,13 @@ cdef class MPReplayBuffer: cdef StepChecker size_check cdef explorer_ready cdef explorer_count + cdef explorer_count_lock cdef learner_ready + cdef shm_name + cdef sync_manager_owner + cdef memory_manager - def __init__(self,size,env_dict=None,*,default_dtype=None,logger=None,**kwargs): + def __init__(self,size,env_dict=None,*,default_dtype=None,logger=None,shm_name=None,**kwargs): r"""Initialize ReplayBuffer Parameters @@ -1859,36 +1977,84 @@ cdef class MPReplayBuffer: default_dtype : numpy.dtype, optional fallback dtype for not specified in `env_dict`. default is numpy.single """ + + self.shm_name = shm_name + self.memory_manager = SyncManager() + self.memory_manager.start() + self.sync_manager_owner = True + self.env_dict = env_dict.copy() if env_dict else {} + cdef special_keys = [] self.buffer_size = size - self.index = ProcessSafeRingBufferIndex(self.buffer_size) + self.index = ProcessSafeRingBufferIndex(self.buffer_size, self.memory_manager) self.default_dtype = default_dtype or np.single # side effect: Add "add_shape" key into self.env_dict self.buffer = dict2buffer(self.buffer_size,self.env_dict, default_dtype = self.default_dtype, - shared = True) + shared = True, + shm_name=shm_name) self.size_check = StepChecker(self.env_dict,special_keys) - self.learner_ready = Event() + self.learner_ready = self.memory_manager.Event() self.learner_ready.clear() - self.explorer_ready = Event() + self.explorer_ready = self.memory_manager.Event() self.explorer_ready.set() + self.explorer_count = self.memory_manager.Value(ctypes.c_size_t, 0) + self.explorer_count_lock = self.memory_manager.Lock() + + + def __getstate__(self): + + state = dict() + + # Save instance persistent attributes. + state["shm_name"]= self.shm_name + state["buffer"] = self.buffer + state["buffer_size"] = self.buffer_size + state["env_dict"] = self.env_dict + + state["index"] = self.index + state["default_dtype"] = self.default_dtype + state["size_check"] = self.size_check + + state["explorer_ready"] = self.explorer_ready + state["explorer_count"] = self.explorer_count + state["explorer_count_lock"] = self.explorer_count_lock + state["learner_ready"] = self.learner_ready + + return state + + def __setstate__(self, state): + + # Restore instance attributes. + self.shm_name = state["shm_name"] + self.buffer = state["buffer"] + self.buffer_size = state["buffer_size"] + self.env_dict = state["env_dict"] + + self.index = state["index"] + self.default_dtype = state["default_dtype"] + self.size_check = state["size_check"] + self.explorer_ready = state["explorer_ready"] + self.explorer_count = state["explorer_count"] + self.explorer_count_lock = state["explorer_count_lock"] + self.learner_ready = state["learner_ready"] + self.sync_manager_owner = False - self.explorer_count = Value(ctypes.c_size_t,0) cdef void _lock_explorer(self) except *: self.explorer_ready.wait() # Wait permission self.learner_ready.clear() # Block learner - with self.explorer_count.get_lock(): + with self.explorer_count_lock: self.explorer_count.value += 1 cdef void _unlock_explorer(self) except *: - with self.explorer_count.get_lock(): + with self.explorer_count_lock: self.explorer_count.value -= 1 if self.explorer_count.value == 0: self.learner_ready.set() @@ -2081,6 +2247,13 @@ cdef class MPReplayBuffer: """ return False + def __del__(self): + if self.sync_manager_owner: + self.memory_manager.shutdown() + + + + cdef class ThreadSafePrioritizedSampler: cdef size_t size @@ -2162,6 +2335,7 @@ cdef class MPPrioritizedReplayBuffer(MPReplayBuffer): cdef helper cdef terminate cdef explorer_per_count + cdef explorer_per_count_lock cdef learner_per_ready cdef explorer_per_ready cdef vector[size_t] idx_vec @@ -2203,32 +2377,75 @@ cdef class MPPrioritizedReplayBuffer(MPReplayBuffer): self.weights = VectorFloat() self.indexes = VectorSize_t() - shm = RawArray(np.ctypeslib.as_ctypes_type(np.bool_), - int(np.array(size,copy=False,dtype='int').prod())) + shm = self.memory_manager.Array('b', + np.ones(shape=size, dtype='int').flatten()) + self.unchange_since_sample = np.ctypeslib.as_array(shm) self.unchange_since_sample[:] = True self.helper = None - self.terminate = Value(ctypes.c_bool) + self.terminate = self.memory_manager.Value(ctypes.c_bool,0) self.terminate.value = False - self.learner_per_ready = Event() + self.learner_per_ready = self.memory_manager.Event() self.learner_per_ready.clear() - self.explorer_per_ready = Event() + self.explorer_per_ready = self.memory_manager.Event() self.explorer_per_ready.set() - self.explorer_per_count = Value(ctypes.c_size_t,0) + self.explorer_per_count = self.memory_manager.Value(ctypes.c_size_t, 0) + self.explorer_per_count_lock = self.memory_manager.Lock() + self.idx_vec = vector[size_t]() self.ps_vec = vector[float]() + + def __getstate__(self): + + state = super().__getstate__() + + # Save instance persistent attributes. + state["weights"] = self.weights + state["indexes"] = self.indexes + state["per"] = self.per + state["unchange_since_sample"] = self.unchange_since_sample + state["helper"] = self.helper + state["terminate"] = self.terminate + state["explorer_per_count"] = self.explorer_per_count + state["explorer_per_count_lock"] = self.explorer_per_count_lock + state["learner_per_ready"] = self.learner_per_ready + state["explorer_per_ready"] = self.explorer_per_ready + state["idx_vec"] = self.idx_vec + state["ps_vec"] = self.ps_vec + + return state + + def __setstate__(self, state): + + super().__setstate__(state) + + # Restore instance attributes. + self.weights = state["weights"] + self.indexes = state["indexes"] + self.per = state["per"] + self.unchange_since_sample = state["unchange_since_sample"] + self.helper = state["helper"] + self.terminate = state["terminate"] + self.explorer_per_count = state["explorer_per_count"] + self.explorer_per_count_lock = state["explorer_per_count_lock"] + self.learner_per_ready = state["learner_per_ready"] + self.explorer_per_ready = state["explorer_per_ready"] + self.idx_vec = state["idx_vec"] + self.ps_vec = state["ps_vec"] + + cdef void _lock_explorer_per(self) except *: self.explorer_per_ready.wait() # Wait permission self.learner_per_ready.clear() # Block learner - with self.explorer_per_count.get_lock(): + with self.explorer_per_count_lock: self.explorer_per_count.value += 1 cdef void _unlock_explorer_per(self) except *: - with self.explorer_per_count.get_lock(): + with self.explorer_per_count_lock: self.explorer_per_count.value -= 1 if self.explorer_per_count.value == 0: self.learner_per_ready.set()