diff --git a/torchrl/collectors/_runner.py b/torchrl/collectors/_runner.py
index f2db7b32e8e..c43b88060cf 100644
--- a/torchrl/collectors/_runner.py
+++ b/torchrl/collectors/_runner.py
@@ -487,7 +487,7 @@ def cast_tensor(x, MPS_ERROR=MPS_ERROR):
                 if x.device.type in ("cpu",):
                     x.share_memory_()
                 if x.device.type in ("mps",):
-                    RuntimeError(MPS_ERROR)
+                    raise RuntimeError(MPS_ERROR)
 
         collected_tensordict.apply(cast_tensor, filter_empty=True)
         data = (collected_tensordict, idx)
diff --git a/torchrl/collectors/distributed/ray.py b/torchrl/collectors/distributed/ray.py
index 2f8930207f6..dc2df7eb3e1 100644
--- a/torchrl/collectors/distributed/ray.py
+++ b/torchrl/collectors/distributed/ray.py
@@ -1010,7 +1010,7 @@ def _async_iterator(self) -> Iterator[TensorDictBase]:
 
             if self.update_after_each_batch or self.max_weight_update_interval > -1:
                 torchrl_logger.debug(f"Updating weights on worker {collector_index}")
-                self.update_policy_weights_(worker_ids=collector_index + 1)
+                self.update_policy_weights_(worker_ids=collector_index)
 
             # Schedule a new collection task
             future = collector.next.remote()
diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py
index 8b67103c416..310b1ecd44c 100644
--- a/torchrl/data/tensor_specs.py
+++ b/torchrl/data/tensor_specs.py
@@ -1229,9 +1229,7 @@ def __torch_function__(
         if func not in cls.SPEC_HANDLED_FUNCTIONS or not all(
             issubclass(t, (TensorSpec,)) for t in types
         ):
-            return NotImplementedError(
-                f"func {func} for spec {cls} with handles {cls.SPEC_HANDLED_FUNCTIONS}"
-            )
+            return NotImplemented
        return cls.SPEC_HANDLED_FUNCTIONS[func](*args, **kwargs)

    def unbind(self, dim: int = 0):