Merged
27 changes: 17 additions & 10 deletions rllib/agents/dqn/apex.py
@@ -13,7 +13,7 @@
 from ray.rllib.execution.replay_ops import StoreToReplayBuffer, Replay
 from ray.rllib.execution.train_ops import UpdateTargetNetwork
 from ray.rllib.execution.metric_ops import StandardMetricsReporting
-from ray.rllib.execution.replay_buffer import ReplayActor
+from ray.rllib.execution.replay_buffer import ReplayActor, VanillaReplayActor
 from ray.rllib.utils import merge_dicts
 from ray.rllib.utils.actors import create_colocated

@@ -88,15 +88,21 @@ def __call__(self, item: ("ActorHandle", SampleBatchType)):
 def apex_execution_plan(workers: WorkerSet, config: dict):
     # Create a number of replay buffer actors.
     num_replay_buffer_shards = config["optimizer"]["num_replay_buffer_shards"]
-    replay_actors = create_colocated(ReplayActor, [
+    replay_actor_cls = ReplayActor if config[
+        "prioritized_replay"] else VanillaReplayActor
+    replay_actors = create_colocated(
+        replay_actor_cls,
+        [
+            num_replay_buffer_shards,
+            config["learning_starts"],
+            config["buffer_size"],
+            config["train_batch_size"],
+            config["prioritized_replay_alpha"],
+            config["prioritized_replay_beta"],
+            config["prioritized_replay_eps"],
+        ],
         num_replay_buffer_shards,
-        config["learning_starts"],
-        config["buffer_size"],
-        config["train_batch_size"],
-        config["prioritized_replay_alpha"],
-        config["prioritized_replay_beta"],
-        config["prioritized_replay_eps"],
-    ], num_replay_buffer_shards)
+    )

     # Start the learner thread.
     learner_thread = LearnerThread(workers.local_worker())
@@ -105,7 +111,8 @@ def apex_execution_plan(workers: WorkerSet, config: dict):
     # Update experience priorities post learning.
     def update_prio_and_stats(item: ("ActorHandle", dict, int)):
         actor, prio_dict, count = item
-        actor.update_priorities.remote(prio_dict)
+        if config["prioritized_replay"]:
+            actor.update_priorities.remote(prio_dict)
         metrics = _get_shared_metrics()
         # Manually update the steps trained counter since the learner thread
         # is executing outside the pipeline.
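
For reference, the new replay path is selected purely through the trainer config. A minimal usage sketch (illustrative, not part of this diff; assumes a standard RLlib/Tune setup of this vintage, with a placeholder env and stop criterion):

import ray
from ray import tune

ray.init()
# With "prioritized_replay": False, apex_execution_plan colocates
# VanillaReplayActor shards (uniform sampling) and skips the
# update_priorities calls after each learner step.
tune.run(
    "APEX",
    config={
        "env": "CartPole-v0",
        "num_workers": 8,
        "prioritized_replay": False,
    },
    stop={"training_iteration": 10},
)
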
5 changes: 4 additions & 1 deletion rllib/agents/sac/apex.py
@@ -41,6 +41,9 @@
 # __sphinx_doc_end__
 # yapf: enable

+
 ApexSACTrainer = SACTrainer.with_updates(
-    name="APEX_SAC", default_config=APEX_SAC_DEFAULT_CONFIG, execution_plan=apex_execution_plan
+    name="APEX_SAC",
+    default_config=APEX_SAC_DEFAULT_CONFIG,
+    execution_plan=apex_execution_plan,
 )
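
The new trainer is built with the standard with_updates pattern, so it can be driven like any other RLlib trainer. A minimal sketch (assumes the module path shown in this diff; the env and iteration count are illustrative):

import ray
from ray.rllib.agents.sac.apex import ApexSACTrainer

ray.init()
# APEX_SAC with the defaults from APEX_SAC_DEFAULT_CONFIG.
trainer = ApexSACTrainer(env="Pendulum-v0", config={"num_workers": 8})
for _ in range(5):
    result = trainer.train()
    print(result["episode_reward_mean"])
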
105 changes: 105 additions & 0 deletions rllib/execution/replay_buffer.py
@@ -416,4 +416,109 @@ def stats(self, debug=False):
         return stat


+# Visible for testing.
+_local_vanilla_replay_buffer = None
+
+
+class LocalVanillaReplayBuffer(LocalReplayBuffer):
+    """A replay buffer shard.
+
+    Ray actors are single-threaded, so for scalability multiple replay actors
+    may be created to increase parallelism."""
+
+    def __init__(
+        self,
+        num_shards,
+        learning_starts,
+        buffer_size,
+        replay_batch_size,
+        prioritized_replay_alpha=0.6,
+        prioritized_replay_beta=0.4,
+        prioritized_replay_eps=1e-6,
+        multiagent_sync_replay=False,
+    ):
+        self.replay_starts = learning_starts // num_shards
+        self.buffer_size = buffer_size // num_shards
+        self.replay_batch_size = replay_batch_size
+        self.prioritized_replay_beta = prioritized_replay_beta
+        self.prioritized_replay_eps = prioritized_replay_eps
+        self.multiagent_sync_replay = multiagent_sync_replay
+
+        def gen_replay():
+            while True:
+                yield self.replay()
+
+        ParallelIteratorWorker.__init__(self, gen_replay, False)
+
+        def new_buffer():
+            return ReplayBuffer(self.buffer_size)
+
+        self.replay_buffers = collections.defaultdict(new_buffer)
+
+        # Metrics
+        self.add_batch_timer = TimerStat()
+        self.replay_timer = TimerStat()
+        self.update_priorities_timer = TimerStat()
+        self.num_added = 0
+
+        # Make externally accessible for testing.
+        global _local_vanilla_replay_buffer
+        _local_vanilla_replay_buffer = self
+        # If set, return this instead of the usual data for testing.
+        self._fake_batch = None
+
+    @staticmethod
+    def get_instance_for_testing():
+        global _local_vanilla_replay_buffer
+        return _local_vanilla_replay_buffer
+
+    def replay(self):
+        if self._fake_batch:
+            fake_batch = SampleBatch(self._fake_batch)
+            return MultiAgentBatch({DEFAULT_POLICY_ID: fake_batch}, fake_batch.count)
+
+        if self.num_added < self.replay_starts:
+            return None
+
+        with self.replay_timer:
+            samples = {}
+            idxes = None
+            for policy_id, replay_buffer in self.replay_buffers.items():
+                if self.multiagent_sync_replay:
+                    if idxes is None:
+                        idxes = replay_buffer.sample_idxes(self.replay_batch_size)
+                else:
+                    idxes = replay_buffer.sample_idxes(self.replay_batch_size)
+                (
+                    obses_t,
+                    actions,
+                    rewards,
+                    obses_tp1,
+                    dones,
+                ) = replay_buffer.sample_with_idxes(idxes)
+                samples[policy_id] = SampleBatch(
+                    {
+                        "obs": obses_t,
+                        "actions": actions,
+                        "rewards": rewards,
+                        "new_obs": obses_tp1,
+                        "dones": dones,
+                    }
+                )
+            return MultiAgentBatch(samples, self.replay_batch_size)
+
+    def stats(self, debug=False):
+        stat = {
+            "add_batch_time_ms": round(1000 * self.add_batch_timer.mean, 3),
+            "replay_time_ms": round(1000 * self.replay_timer.mean, 3),
+        }
+        for policy_id, replay_buffer in self.replay_buffers.items():
+            stat.update(
+                {"policy_{}".format(policy_id): replay_buffer.stats(debug=debug)}
+            )
+        return stat
+
+
+VanillaReplayActor = ray.remote(num_cpus=0)(LocalVanillaReplayBuffer)
+
 ReplayActor = ray.remote(num_cpus=0)(LocalReplayBuffer)
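
To see the uniform-sampling behavior in isolation, the new shard can be exercised locally, without wrapping it in a Ray actor. A rough sketch (assumes the add_batch API inherited from LocalReplayBuffer; the batch contents are illustrative):

from ray.rllib.execution.replay_buffer import LocalVanillaReplayBuffer
from ray.rllib.policy.sample_batch import SampleBatch

buf = LocalVanillaReplayBuffer(
    num_shards=1,
    learning_starts=8,
    buffer_size=100,
    replay_batch_size=4,
)
for i in range(16):
    buf.add_batch(SampleBatch({
        "obs": [i],
        "actions": [0],
        "rewards": [1.0],
        "new_obs": [i + 1],
        "dones": [False],
    }))
# replay() returns None until learning_starts samples have been added,
# then a MultiAgentBatch of replay_batch_size uniformly drawn timesteps.
print(buf.replay())
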
15 changes: 13 additions & 2 deletions rllib/tests/agents/parameters.py
@@ -283,9 +283,20 @@ def astuple(self):
         config_updates={
             "num_workers": 8,
             "exploration_config": {"type": "StochasticSampling"},
-            "no_done_at_end": True,
+            "prioritized_replay": False,
+            "no_done_at_end": False,
         },
         n_iter=200,
         threshold=-350.,
     ),
+    TestAgentParams.for_pendulum(
+        algorithm=ContinuousActionSpaceAlgorithm.APEX_SAC,
+        config_updates={
+            "num_workers": 8,
+            "exploration_config": {"type": "StochasticSampling"},
+            "prioritized_replay": True,
+            "no_done_at_end": True
+        },
+        # TODO: Delete next line before landing PR
+        n_iter=200,
+        threshold=-350.,
+    ),
2 changes: 1 addition & 1 deletion rllib/tests/agents/test_learning.py
@@ -77,7 +77,7 @@ def test_monotonically_improving_algorithms_can_converge_with_different_framewor
"""
learnt = False
episode_reward_mean = -float("inf")
for i in range(n_iter):
for _ in range(n_iter):
results = trainer.train()
episode_reward_mean = results["episode_reward_mean"]
if episode_reward_mean >= threshold: