From faa0a24816024b7b153ff1ef245bd13ad90e4f48 Mon Sep 17 00:00:00 2001
From: Jash Shah
Date: Tue, 24 Feb 2026 16:01:15 -0800
Subject: [PATCH] [BugFix] Fix missing raise, incorrect return, and off-by-one
 in worker_ids

1. collectors/_runner.py: Add the missing `raise` before
   `RuntimeError(MPS_ERROR)`. The exception was constructed but never
   raised, silently allowing execution to continue when tensors are on
   MPS devices.

2. data/tensor_specs.py: Fix `__torch_function__` to return the
   `NotImplemented` sentinel instead of a `NotImplementedError(...)`
   instance. The previous code returned an exception object as a value
   rather than following the standard `__torch_function__` protocol.

3. collectors/distributed/ray.py: Fix an off-by-one in `_async_iterator`
   where `collector_index + 1` was passed to `update_policy_weights_()`.
   `collector_index` is already 0-indexed from `pending_tasks.pop(future)`,
   so the +1 causes weight updates to target the wrong worker.

Minimal standalone sketches illustrating each bug follow the diff below.

Fixes #3000
---
 torchrl/collectors/_runner.py         | 2 +-
 torchrl/collectors/distributed/ray.py | 2 +-
 torchrl/data/tensor_specs.py          | 4 +---
 3 files changed, 3 insertions(+), 5 deletions(-)

diff --git a/torchrl/collectors/_runner.py b/torchrl/collectors/_runner.py
index f2db7b32e8e..c43b88060cf 100644
--- a/torchrl/collectors/_runner.py
+++ b/torchrl/collectors/_runner.py
@@ -487,7 +487,7 @@ def cast_tensor(x, MPS_ERROR=MPS_ERROR):
             if x.device.type in ("cpu",):
                 x.share_memory_()
             if x.device.type in ("mps",):
-                RuntimeError(MPS_ERROR)
+                raise RuntimeError(MPS_ERROR)
 
         collected_tensordict.apply(cast_tensor, filter_empty=True)
         data = (collected_tensordict, idx)
diff --git a/torchrl/collectors/distributed/ray.py b/torchrl/collectors/distributed/ray.py
index 2f8930207f6..dc2df7eb3e1 100644
--- a/torchrl/collectors/distributed/ray.py
+++ b/torchrl/collectors/distributed/ray.py
@@ -1010,7 +1010,7 @@ def _async_iterator(self) -> Iterator[TensorDictBase]:
 
             if self.update_after_each_batch or self.max_weight_update_interval > -1:
                 torchrl_logger.debug(f"Updating weights on worker {collector_index}")
-                self.update_policy_weights_(worker_ids=collector_index + 1)
+                self.update_policy_weights_(worker_ids=collector_index)
 
             # Schedule a new collection task
             future = collector.next.remote()
diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py
index 8b67103c416..310b1ecd44c 100644
--- a/torchrl/data/tensor_specs.py
+++ b/torchrl/data/tensor_specs.py
@@ -1229,9 +1229,7 @@ def __torch_function__(
         if func not in cls.SPEC_HANDLED_FUNCTIONS or not all(
             issubclass(t, (TensorSpec,)) for t in types
         ):
-            return NotImplementedError(
-                f"func {func} for spec {cls} with handles {cls.SPEC_HANDLED_FUNCTIONS}"
-            )
+            return NotImplemented
         return cls.SPEC_HANDLED_FUNCTIONS[func](*args, **kwargs)
 
     def unbind(self, dim: int = 0):
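
Sketch 1 (missing raise, _runner.py): a minimal standalone illustration
of the bug class. The function names and the MPS_ERROR message are
hypothetical stand-ins, not the torchrl originals.

    import torch

    MPS_ERROR = "got a tensor on an MPS device"  # stand-in message

    def cast_tensor_buggy(x):
        if x.device.type in ("mps",):
            RuntimeError(MPS_ERROR)  # BUG: builds the exception, then discards it

    def cast_tensor_fixed(x):
        if x.device.type in ("mps",):
            raise RuntimeError(MPS_ERROR)  # the error actually propagates

    # Both variants are no-ops on a CPU tensor; on an MPS tensor only the
    # fixed variant stops execution instead of letting the caller continue.
    cast_tensor_buggy(torch.zeros(3))
    cast_tensor_fixed(torch.zeros(3))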
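
Sketch 2 (`__torch_function__` protocol, tensor_specs.py): a minimal
class following the standard protocol. `MySpec` and `HANDLED_FUNCTIONS`
are hypothetical stand-ins for `TensorSpec` and `SPEC_HANDLED_FUNCTIONS`;
the sentinel-versus-exception contrast is the point.

    import torch

    class MySpec:
        HANDLED_FUNCTIONS = {}  # no torch functions registered in this sketch

        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            if kwargs is None:
                kwargs = {}
            if func not in cls.HANDLED_FUNCTIONS or not all(
                issubclass(t, MySpec) for t in types
            ):
                # Correct: return the NotImplemented sentinel so torch can
                # try other handlers and, failing that, raise a clean
                # TypeError. `return NotImplementedError(...)` would instead
                # hand the caller an exception *instance* as the result.
                return NotImplemented
            return cls.HANDLED_FUNCTIONS[func](*args, **kwargs)

    try:
        torch.stack([MySpec(), MySpec()])  # no handler for torch.stack
    except TypeError as err:
        print(err)  # protocol-level error, not an exception object result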
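
Sketch 3 (off-by-one, ray.py): the indexing bug reduced to plain Python.
The names mirror the patch, but strings stand in for Ray futures and
remote collectors, and update_policy_weights_ is simplified to a lookup.

    workers = ["worker-0", "worker-1", "worker-2"]
    pending_tasks = {"fut-a": 0, "fut-b": 1, "fut-c": 2}  # future -> worker index

    def update_policy_weights_(worker_ids):
        # 0-indexed lookup; per the commit, collector_index is already
        # 0-indexed when popped from pending_tasks.
        return workers[worker_ids]

    done_future = "fut-c"
    collector_index = pending_tasks.pop(done_future)  # -> 2, the last worker

    assert update_policy_weights_(collector_index) == "worker-2"  # fixed call
    # The buggy call, update_policy_weights_(collector_index + 1), targets
    # the wrong worker, and for the last worker it raises IndexError here.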