🐛 Bug Description
In torchrl.collectors.distributed.ray.RayCollector._async_iterator, the update_policy_weights_() method is called with worker_ids=collector_index + 1. However, collector_index is an index into the self.remote_collectors list, and adding 1 to it may result in an IndexError downstream when evaluating RayWeightUpdater._skip_update().
🔍 Relevant Code (TorchRL v0.8.1)
# Inside RayCollector._async_iterator()
if self.update_after_each_batch or self.max_weight_update_interval > -1:
    torchrl_logger.info(f"Updating weights on worker {collector_index}")
    self.update_policy_weights_(worker_ids=collector_index + 1)
Why This is a Bug
collector_index comes directly from enumerate(self.remote_collectors), so it lies in the range [0, len(self.remote_collectors) - 1]. Adding 1 produces a value that is passed to self.update_policy_weights_() and eventually to RayWeightUpdater._skip_update(), which uses it to index into self._batches_since_weight_update, a list of length len(self.remote_collectors). For the last collector, the shifted index is out of range and raises an IndexError.
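To make the failure concrete, here is a minimal standalone sketch (not the actual TorchRL source) of the indexing described above, assuming _skip_update simply looks up a per-collector counter:

# Hypothetical, simplified illustration of the off-by-one lookup
remote_collectors = ["collector_0", "collector_1"]  # stand-ins for the Ray actors
_batches_since_weight_update = [0] * len(remote_collectors)  # one counter per collector

for collector_index, _ in enumerate(remote_collectors):
    worker_id = collector_index + 1  # the buggy shift: 0 -> 1, 1 -> 2
    # mirrors the lookup done inside RayWeightUpdater._skip_update
    _batches_since_weight_update[worker_id]  # IndexError once collector_index == 1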
This likely came from a transition of the argument from worker_rank to worker_id. See this earlier merged PR. It used to be the case that update_policy_weights_ accepted a worker_rank argument whose value was >= 1; see L883-907 in torchrl/collectors/distributed/generic.py.
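For illustration only (hypothetical variable names, not TorchRL code), the two conventions differ as follows:

remote_collectors = ["collector_0", "collector_1"]

for collector_index, _ in enumerate(remote_collectors):
    worker_rank = collector_index + 1  # old generic.py convention: ranks start at 1
    worker_id = collector_index        # Ray updater convention: plain 0-based list index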
Suggested Fix
Replace
self.update_policy_weights_(worker_ids=collector_index + 1)
with
self.update_policy_weights_(worker_ids=collector_index)
Reproduction Instructions
Create a clean virtualenv and install torchrl==0.8.1.0, ray==2.47.0, and gymnasium==1.1.1.
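Before running the script below, a quick sanity check of the installed versions (assuming the usual __version__ attributes) can help:

import gymnasium, ray, torchrl
# Expect roughly: torchrl 0.8.1, ray 2.47.0, gymnasium 1.1.1
print(torchrl.__version__, ray.__version__, gymnasium.__version__)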
Run the following script:
import ray
from torchrl.envs import GymEnv
from torch import nn
from torchrl.collectors.distributed.ray import RayCollector
from tensordict.nn import TensorDictModule
import random
import time

def create_env():
    return GymEnv("CartPole-v1")

# A simple policy compatible with TorchRL
def create_policy():
    module = nn.Sequential(
        nn.Linear(4, 32),
        nn.ReLU(),
        nn.Linear(32, 2),
    )
    time.sleep(random.randint(0, 3))
    return TensorDictModule(
        module=module,
        in_keys=["observation"],
        out_keys=["action"],
    )

# Ray setup
ray.init(ignore_reinit_error=True, include_dashboard=False, log_to_driver=False)

# === Trigger the bug ===
collector = RayCollector(
    create_env_fn=create_env,
    policy_factory=create_policy,
    frames_per_batch=8,
    total_frames=64,
    update_after_each_batch=True,  # This will trigger the problematic weight sync
    num_collectors=2,
    sync=False,  # Async mode so one collector finishes before the other
)

try:
    for _ in collector:
        print("Collected batch")
except IndexError as e:
    print("\n🔥🔥🔥 Caught IndexError as expected due to off-by-one bug! 🔥🔥🔥")
    print(e)
finally:
    collector.shutdown()
    ray.shutdown()
The random sleep makes it so that the second worker is sometimes the first to finish (it is a bit of a pain to get the bug to trigger consistently). If worker zero finishes first, the issue does not appear.