diff --git a/rllib/agents/dqn/apex.py b/rllib/agents/dqn/apex.py
index c3fd13008e47..bdc1ba512160 100644
--- a/rllib/agents/dqn/apex.py
+++ b/rllib/agents/dqn/apex.py
@@ -13,7 +13,7 @@
 from ray.rllib.execution.replay_ops import StoreToReplayBuffer, Replay
 from ray.rllib.execution.train_ops import UpdateTargetNetwork
 from ray.rllib.execution.metric_ops import StandardMetricsReporting
-from ray.rllib.execution.replay_buffer import ReplayActor
+from ray.rllib.execution.replay_buffer import ReplayActor, VanillaReplayActor
 from ray.rllib.utils import merge_dicts
 from ray.rllib.utils.actors import create_colocated
 
@@ -88,15 +88,21 @@ def __call__(self, item: ("ActorHandle", SampleBatchType)):
 def apex_execution_plan(workers: WorkerSet, config: dict):
     # Create a number of replay buffer actors.
     num_replay_buffer_shards = config["optimizer"]["num_replay_buffer_shards"]
-    replay_actors = create_colocated(ReplayActor, [
+    replay_actor_cls = ReplayActor if config[
+        "prioritized_replay"] else VanillaReplayActor
+    replay_actors = create_colocated(
+        replay_actor_cls,
+        [
+            num_replay_buffer_shards,
+            config["learning_starts"],
+            config["buffer_size"],
+            config["train_batch_size"],
+            config["prioritized_replay_alpha"],
+            config["prioritized_replay_beta"],
+            config["prioritized_replay_eps"],
+        ],
         num_replay_buffer_shards,
-        config["learning_starts"],
-        config["buffer_size"],
-        config["train_batch_size"],
-        config["prioritized_replay_alpha"],
-        config["prioritized_replay_beta"],
-        config["prioritized_replay_eps"],
-    ], num_replay_buffer_shards)
+    )
 
     # Start the learner thread.
     learner_thread = LearnerThread(workers.local_worker())
@@ -105,7 +111,8 @@ def apex_execution_plan(workers: WorkerSet, config: dict):
     # Update experience priorities post learning.
     def update_prio_and_stats(item: ("ActorHandle", dict, int)):
         actor, prio_dict, count = item
-        actor.update_priorities.remote(prio_dict)
+        if config["prioritized_replay"]:
+            actor.update_priorities.remote(prio_dict)
         metrics = _get_shared_metrics()
         # Manually update the steps trained counter since the learner thread
         # is executing outside the pipeline.
diff --git a/rllib/agents/sac/apex.py b/rllib/agents/sac/apex.py
index 6465f8de7b39..f8ea337da690 100644
--- a/rllib/agents/sac/apex.py
+++ b/rllib/agents/sac/apex.py
@@ -41,6 +41,9 @@
 # __sphinx_doc_end__
 # yapf: enable
 
+
 ApexSACTrainer = SACTrainer.with_updates(
-    name="APEX_SAC", default_config=APEX_SAC_DEFAULT_CONFIG, execution_plan=apex_execution_plan
+    name="APEX_SAC",
+    default_config=APEX_SAC_DEFAULT_CONFIG,
+    execution_plan=apex_execution_plan,
 )
diff --git a/rllib/execution/replay_buffer.py b/rllib/execution/replay_buffer.py
index a1ecab230fa3..25d6e57d6b44 100644
--- a/rllib/execution/replay_buffer.py
+++ b/rllib/execution/replay_buffer.py
@@ -416,4 +416,109 @@ def stats(self, debug=False):
         return stat
 
 
+# Visible for testing.
+_local_vanilla_replay_buffer = None
+
+
+class LocalVanillaReplayBuffer(LocalReplayBuffer):
+    """A replay buffer shard.
+
+    Ray actors are single-threaded, so for scalability multiple replay actors
+    may be created to increase parallelism."""
+
+    def __init__(
+        self,
+        num_shards,
+        learning_starts,
+        buffer_size,
+        replay_batch_size,
+        prioritized_replay_alpha=0.6,
+        prioritized_replay_beta=0.4,
+        prioritized_replay_eps=1e-6,
+        multiagent_sync_replay=False,
+    ):
+        self.replay_starts = learning_starts // num_shards
+        self.buffer_size = buffer_size // num_shards
+        self.replay_batch_size = replay_batch_size
+        self.prioritized_replay_beta = prioritized_replay_beta
+        self.prioritized_replay_eps = prioritized_replay_eps
+        self.multiagent_sync_replay = multiagent_sync_replay
+
+        def gen_replay():
+            while True:
+                yield self.replay()
+
+        ParallelIteratorWorker.__init__(self, gen_replay, False)
+
+        def new_buffer():
+            return ReplayBuffer(self.buffer_size)
+
+        self.replay_buffers = collections.defaultdict(new_buffer)
+
+        # Metrics
+        self.add_batch_timer = TimerStat()
+        self.replay_timer = TimerStat()
+        self.update_priorities_timer = TimerStat()
+        self.num_added = 0
+
+        # Make externally accessible for testing.
+        global _local_vanilla_replay_buffer
+        _local_vanilla_replay_buffer = self
+        # If set, return this instead of the usual data for testing.
+        self._fake_batch = None
+
+    @staticmethod
+    def get_instance_for_testing():
+        global _local_vanilla_replay_buffer
+        return _local_vanilla_replay_buffer
+
+    def replay(self):
+        if self._fake_batch:
+            fake_batch = SampleBatch(self._fake_batch)
+            return MultiAgentBatch({DEFAULT_POLICY_ID: fake_batch}, fake_batch.count)
+
+        if self.num_added < self.replay_starts:
+            return None
+
+        with self.replay_timer:
+            samples = {}
+            idxes = None
+            for policy_id, replay_buffer in self.replay_buffers.items():
+                if self.multiagent_sync_replay:
+                    if idxes is None:
+                        idxes = replay_buffer.sample_idxes(self.replay_batch_size)
+                else:
+                    idxes = replay_buffer.sample_idxes(self.replay_batch_size)
+                (
+                    obses_t,
+                    actions,
+                    rewards,
+                    obses_tp1,
+                    dones,
+                ) = replay_buffer.sample_with_idxes(idxes)
+                samples[policy_id] = SampleBatch(
+                    {
+                        "obs": obses_t,
+                        "actions": actions,
+                        "rewards": rewards,
+                        "new_obs": obses_tp1,
+                        "dones": dones,
+                    }
+                )
+            return MultiAgentBatch(samples, self.replay_batch_size)
+
+    def stats(self, debug=False):
+        stat = {
+            "add_batch_time_ms": round(1000 * self.add_batch_timer.mean, 3),
+            "replay_time_ms": round(1000 * self.replay_timer.mean, 3),
+        }
+        for policy_id, replay_buffer in self.replay_buffers.items():
+            stat.update(
+                {"policy_{}".format(policy_id): replay_buffer.stats(debug=debug)}
+            )
+        return stat
+
+
+VanillaReplayActor = ray.remote(num_cpus=0)(LocalVanillaReplayBuffer)
+
 ReplayActor = ray.remote(num_cpus=0)(LocalReplayBuffer)
diff --git a/rllib/tests/agents/parameters.py b/rllib/tests/agents/parameters.py
index 4a24dc14724a..6eb5d9fb236d 100644
--- a/rllib/tests/agents/parameters.py
+++ b/rllib/tests/agents/parameters.py
@@ -283,9 +283,20 @@ def astuple(self):
         config_updates={
             "num_workers": 8,
             "exploration_config": {"type": "StochasticSampling"},
-            "no_done_at_end": True,
+            "prioritized_replay": False,
+            "no_done_at_end": False,
+        },
+        n_iter=200,
+        threshold=-350.,
+    ),
+    TestAgentParams.for_pendulum(
+        algorithm=ContinuousActionSpaceAlgorithm.APEX_SAC,
+        config_updates={
+            "num_workers": 8,
+            "exploration_config": {"type": "StochasticSampling"},
+            "prioritized_replay": True,
+            "no_done_at_end": True,
         },
-        # TODO: Delete next line before landing PR
         n_iter=200,
         threshold=-350.,
     ),
diff --git a/rllib/tests/agents/test_learning.py b/rllib/tests/agents/test_learning.py
index 3e55f79bf583..beeb1e9ff881 100644
--- a/rllib/tests/agents/test_learning.py
+++ b/rllib/tests/agents/test_learning.py
@@ -77,7 +77,7 @@ def test_monotonically_improving_algorithms_can_converge_with_different_framewor
     """
     learnt = False
     episode_reward_mean = -float("inf")
-    for i in range(n_iter):
+    for _ in range(n_iter):
         results = trainer.train()
         episode_reward_mean = results["episode_reward_mean"]
         if episode_reward_mean >= threshold:
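---

Usage note: below is a minimal driver sketch for the new uniform-replay path. It is a hypothetical script, not part of this patch; the environment name `Pendulum-v0` and the iteration count are illustrative, and `num_workers: 8` is taken from the test parameters above. With `"prioritized_replay": False`, `apex_execution_plan` colocates `VanillaReplayActor` shards instead of the prioritized `ReplayActor` shards, and `update_prio_and_stats` skips the `update_priorities` calls:

    import ray
    from ray.rllib.agents.sac.apex import APEX_SAC_DEFAULT_CONFIG, ApexSACTrainer

    ray.init()

    config = APEX_SAC_DEFAULT_CONFIG.copy()
    config.update({
        "num_workers": 8,
        # False selects VanillaReplayActor shards (uniform sampling, no
        # priority updates); True keeps the prioritized ReplayActor shards.
        "prioritized_replay": False,
    })

    trainer = ApexSACTrainer(config=config, env="Pendulum-v0")
    for _ in range(10):
        print(trainer.train()["episode_reward_mean"])

Design-wise, `LocalVanillaReplayBuffer` keeps `LocalReplayBuffer`'s constructor signature even though the `prioritized_replay_*` arguments play no role in uniform sampling, so the `create_colocated` call site in `apex_execution_plan` can pass the identical argument list for both actor classes.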