Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion torchrl/collectors/_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ def cast_tensor(x, MPS_ERROR=MPS_ERROR):
if x.device.type in ("cpu",):
x.share_memory_()
if x.device.type in ("mps",):
RuntimeError(MPS_ERROR)
raise RuntimeError(MPS_ERROR)

collected_tensordict.apply(cast_tensor, filter_empty=True)
data = (collected_tensordict, idx)
Expand Down
2 changes: 1 addition & 1 deletion torchrl/collectors/distributed/ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1010,7 +1010,7 @@ def _async_iterator(self) -> Iterator[TensorDictBase]:

if self.update_after_each_batch or self.max_weight_update_interval > -1:
torchrl_logger.debug(f"Updating weights on worker {collector_index}")
self.update_policy_weights_(worker_ids=collector_index + 1)
self.update_policy_weights_(worker_ids=collector_index)

# Schedule a new collection task
future = collector.next.remote()
Expand Down
4 changes: 1 addition & 3 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1229,9 +1229,7 @@ def __torch_function__(
if func not in cls.SPEC_HANDLED_FUNCTIONS or not all(
issubclass(t, (TensorSpec,)) for t in types
):
return NotImplementedError(
f"func {func} for spec {cls} with handles {cls.SPEC_HANDLED_FUNCTIONS}"
)
return NotImplemented
return cls.SPEC_HANDLED_FUNCTIONS[func](*args, **kwargs)

def unbind(self, dim: int = 0):
Expand Down
Loading