diff --git a/test/test_rb.py b/test/test_rb.py index f56baf6a1e9..571582f19e9 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -830,6 +830,8 @@ def test_smoke_replay_buffer_transform(transform): @pytest.mark.parametrize("transform", transforms) def test_smoke_replay_buffer_transform_no_inkeys(transform): + if transform is PinMemoryTransform and not torch.cuda.is_available(): + raise pytest.skip("No CUDA device detected, skipping PinMemory") rb = ReplayBuffer(collate_fn=lambda x: torch.stack(x, 0), transform=transform()) td = TensorDict({"observation": torch.randn(3, 3, 3, 16, 1)}, []) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index f56c6ae171a..a79de3a1e81 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -2632,19 +2632,18 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase: ) for in_key, out_key in zip(self.in_keys, self.out_keys): if out_key in tensordict.keys(): - z = torch.zeros_like(tensordict[out_key]) - _reset = _reset.view_as(z) - tensordict[out_key][_reset] = z[_reset] + value = tensordict[out_key] + dtype = value.dtype + tensordict[out_key] = value * (~_reset).to(dtype) elif in_key == "reward": # Since the episode reward is not in the tensordict, we need to allocate it # with zeros entirely (regardless of the _reset mask) - z = self.parent.reward_spec.zero(self.parent.batch_size) - tensordict[out_key] = z + tensordict[out_key] = self.parent.reward_spec.zero() else: try: - tensordict[out_key] = self.parent.observation_spec[in_key].zero( - self.parent.batch_size - ) + tensordict[out_key] = self.parent.observation_spec[ in_key ].zero() except KeyError as err: raise KeyError( f"The key {in_key} was not found in the parent "