From 7a0da9fd6954abfea5508b16d2326ce9fa6dd509 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 29 Nov 2023 11:32:21 +0000 Subject: [PATCH 01/19] init --- torchrl/data/datasets/d4rl.py | 21 +++-- torchrl/data/datasets/minari.py | 151 ++++++++++++++++++++++++++++++++ torchrl/data/datasets/utils.py | 9 ++ 3 files changed, 170 insertions(+), 11 deletions(-) create mode 100644 torchrl/data/datasets/minari.py create mode 100644 torchrl/data/datasets/utils.py diff --git a/torchrl/data/datasets/d4rl.py b/torchrl/data/datasets/d4rl.py index b5fd63696a3..f6314f06e11 100644 --- a/torchrl/data/datasets/d4rl.py +++ b/torchrl/data/datasets/d4rl.py @@ -19,6 +19,8 @@ from torchrl.collectors.utils import split_trajectories from torchrl.data.datasets.d4rl_infos import D4RL_DATASETS + +from torchrl.data.datasets.utils import _get_root_dir from torchrl.data.replay_buffers import TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import Sampler from torchrl.data.replay_buffers.storages import LazyMemmapStorage @@ -410,17 +412,14 @@ def _filepath_from_url(dataset_url): return dataset_filepath -def _set_dataset_path(path): - global DATASET_PATH - DATASET_PATH = path - os.makedirs(path, exist_ok=True) - - -_set_dataset_path( - os.environ.get( - "D4RL_DATASET_DIR", os.path.expanduser("~/.cache/torchrl/data/d4rl/datasets") - ) -) +# def _set_dataset_path(path): +# global DATASET_PATH +# DATASET_PATH = path +# os.makedirs(path, exist_ok=True) +# +# +# _set_dataset_path( +# os.environ.get(_get_root_dir("d4rl"))) if __name__ == "__main__": data = D4RLExperienceReplay("kitchen-partial-v0", batch_size=128) diff --git a/torchrl/data/datasets/minari.py b/torchrl/data/datasets/minari.py new file mode 100644 index 00000000000..e7adbc84921 --- /dev/null +++ b/torchrl/data/datasets/minari.py @@ -0,0 +1,151 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import os.path +import tempfile +from pathlib import Path + +import torch +from tensordict import MemoryMappedTensor, PersistentTensorDict, TensorDict +from torchrl._utils import KeyDependentDefaultDict +from torchrl.data.datasets.utils import _get_root_dir +from torchrl.data.replay_buffers import TensorDictReplayBuffer +from torchrl.data.replay_buffers.storages import TensorStorage + +_NAME_MATCH = KeyDependentDefaultDict(lambda key: key) +_NAME_MATCH["observations"] = "observation" +_NAME_MATCH["rewards"] = "reward" +_NAME_MATCH["truncations"] = "truncated" +_NAME_MATCH["terminations"] = "terminated" +_NAME_MATCH["actions"] = "action" +_NAME_MATCH["infos"] = "info" + + +class MinariExperienceReplay(TensorDictReplayBuffer): + def __init__( + self, + dataset_id, + batch_size: int, + *, + root: str | Path | None = None, + download: bool = True, + sampler: Sampler | None = None, + writer: Writer | None = None, + collate_fn: Callable | None = None, + pin_memory: bool = False, + prefetch: int | None = None, + transform: "torchrl.envs.Transform" | None = None, # noqa-F821 + split_trajs: bool = False, + **env_kwargs, + ): + self.dataset_id = dataset_id + if root is None: + root = _get_root_dir("minari") + os.makedirs(root, exist_ok=True) + self.root = root + self.split_trajs = split_trajs + self.download = download + if self.download and not self._is_downloaded(): + storage = self._download_and_preproc() + elif self.split_trajs and not os.path.exists(self.data_path): + storage = self._make_split() + else: + storage = self._load() + storage = TensorStorage(storage) + super().__init__( + storage=storage, + sampler=sampler, + writer=writer, + collate_fn=collate_fn, + pin_memory=pin_memory, + prefetch=prefetch, + batch_size=batch_size, + ) + + def _is_downloaded(self): + return os.path.exists(self.data_path) + 
+ @property + def data_path(self): + if self.split_trajs: + return Path(self.root) / (self.dataset_id + "_split") + return self.data_path_root + + @property + def data_path_root(self): + return Path(self.root) / self.dataset_id + + def _download_and_preproc(self): + import minari + + with tempfile.TemporaryDirectory() as tmpdir: + os.environ["MINARI_DATASETS_PATH"] = tmpdir + minari.download_dataset(dataset_id=self.dataset_id) + dataset = minari.load_dataset(self.dataset_id) + h5_data = PersistentTensorDict.from_h5( + Path(tmpdir) / self.dataset_id / "data/main_data.hdf5" + ) + + # Get the total number of steps for the dataset + total_steps = sum( + h5_data[episode, "actions"].shape[0] for episode in h5_data.keys() + ) + # populate the tensordict + td_data = TensorDict({}, []) + for key, episode in h5_data.items(): + for key, val in episode.items(): + match = _NAME_MATCH[key] + if key in ("observations", "state"): + td_data.set(("next", match), torch.zeros_like(val)[0]) + td_data.set(match, torch.zeros_like(val)[0]) + elif key not in ("terminations", "truncations", "rewards"): + td_data.set(match, torch.zeros_like(val)[0]) + else: + td_data.set( + ("next", match), torch.zeros_like(val)[0].unsqueeze(-1) + ) + break + # give it the proper size + td_data = td_data.expand(total_steps) + # save to designated location + td_data.memmap_(self.data_path_root) + # iterate over episodes and populate the tensordict + index = 0 + for key, episode in h5_data.items(): + for key, val in episode.items(): + match = _NAME_MATCH[key] + if key in ( + "observations", + "state", + ): + steps = val.shape[0] - 1 + td_data["next", match][index : (index + steps)] = val[1:] + td_data[match][index : (index + steps)] = val[:-1] + elif key not in ("terminations", "truncations", "rewards"): + steps = val.shape[0] + td_data[match][index : (index + val.shape[0])] = val + else: + steps = val.shape[0] + td_data[("next", match)][ + index : (index + val.shape[0]) + ] = val.unsqueeze(-1) + index += 
steps + # Add a "done" entry + with td_data.unlock_(): + td_data["next", "done"] = MemoryMappedTensor.from_tensor( + (td_data["next", "terminated"] | td_data["next", "truncated"]) + ) + if self.split_trajs: + td_data = split_trajectories(td_data).memmap_(self.data_path) + return td_data + + def _make_split(self): + td_data = TensorDict.load_memmap(self.data_path_root) + td_data = split_trajectories(td_data).memmap_(self.data_path) + return td_data + + def _load(self): + return TensorDict.load_memmap(self.data_path) diff --git a/torchrl/data/datasets/utils.py b/torchrl/data/datasets/utils.py new file mode 100644 index 00000000000..b88e3aee14e --- /dev/null +++ b/torchrl/data/datasets/utils.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import os + + +def _get_root_dir(dataset: str): + return os.path.join(os.path.expanduser("~"), ".cache", "torchrl", dataset) From ae10b94099218c584acbb5cb0bd26bb29ad2399e Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 29 Nov 2023 12:28:55 +0000 Subject: [PATCH 02/19] fix d4rl --- torchrl/data/datasets/d4rl.py | 161 +++++++++++++++++++------------- torchrl/data/datasets/minari.py | 11 ++- 2 files changed, 106 insertions(+), 66 deletions(-) diff --git a/torchrl/data/datasets/d4rl.py b/torchrl/data/datasets/d4rl.py index f6314f06e11..1b596e29a49 100644 --- a/torchrl/data/datasets/d4rl.py +++ b/torchrl/data/datasets/d4rl.py @@ -6,15 +6,18 @@ import importlib import os +import tempfile import urllib import warnings + +from pathlib import Path from typing import Callable import numpy as np import torch -from tensordict import PersistentTensorDict +from tensordict import PersistentTensorDict, TensorDict from tensordict.tensordict import make_tensordict from torchrl.collectors.utils import split_trajectories @@ -23,7 +26,7 @@ from torchrl.data.datasets.utils import _get_root_dir from 
torchrl.data.replay_buffers import TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import Sampler -from torchrl.data.replay_buffers.storages import LazyMemmapStorage +from torchrl.data.replay_buffers.storages import LazyMemmapStorage, TensorStorage from torchrl.data.replay_buffers.writers import Writer @@ -105,7 +108,7 @@ class D4RLExperienceReplay(TensorDictReplayBuffer): >>> from torchrl.envs import ObservationNorm >>> data = D4RLExperienceReplay("maze2d-umaze-v1", 128) >>> # we can append transforms to the dataset - >>> data.append_transform(ObservationNorm(loc=-1, scale=1.0)) + >>> data.append_transform(ObservationNorm(loc=-1, scale=1.0, in_keys=["observation"])) >>> data.sample(128) """ @@ -138,13 +141,16 @@ def __init__( use_truncated_as_done: bool = True, direct_download: bool = None, terminate_on_end: bool = None, + download: bool = True, + root: str | Path | None = None, **env_kwargs, ): self.use_truncated_as_done = use_truncated_as_done - - if not from_env and direct_download is None: - self._import_d4rl() - direct_download = not self._has_d4rl + if root is None: + root = _get_root_dir("d4rl") + self.root = root + self.name = name + dataset = None if not direct_download: if from_env is None: @@ -158,43 +164,59 @@ def __init__( ) from_env = True self.from_env = from_env - if terminate_on_end is None: - # we use the default of d4rl - terminate_on_end = False - self._import_d4rl() - - if not self._has_d4rl: - raise ImportError("Could not import d4rl") from self.D4RL_ERR - - if from_env: - dataset = self._get_dataset_from_env(name, env_kwargs) - else: - if self.use_truncated_as_done: - warnings.warn( - "Using use_truncated_as_done=True + terminate_on_end=True " - "with from_env=False may not have the intended effect " - "as the timeouts (truncation) " - "can be absent from the static dataset." 
- ) - env_kwargs.update({"terminate_on_end": terminate_on_end}) - dataset = self._get_dataset_direct(name, env_kwargs) else: if from_env is None: from_env = False self.from_env = from_env - if terminate_on_end is False: - raise ValueError( - "Using terminate_on_end=False is not compatible with direct_download=True." - ) - dataset = self._get_dataset_direct_download(name, env_kwargs) - # Fill unknown next states with 0 - dataset["next", "observation"][dataset["next", "done"].squeeze()] = 0 - if split_trajs: - dataset = split_trajectories(dataset) - dataset["next", "done"][:, -1] = True + if not from_env and direct_download is None: + self._import_d4rl() + direct_download = not self._has_d4rl + if download and not self._is_downloaded(): + if not direct_download: + if terminate_on_end is None: + # we use the default of d4rl + terminate_on_end = False + self._import_d4rl() + + if not self._has_d4rl: + raise ImportError("Could not import d4rl") from self.D4RL_ERR + + if from_env: + dataset = self._get_dataset_from_env(name, env_kwargs) + else: + if self.use_truncated_as_done: + warnings.warn( + "Using use_truncated_as_done=True + terminate_on_end=True " + "with from_env=False may not have the intended effect " + "as the timeouts (truncation) " + "can be absent from the static dataset." + ) + env_kwargs.update({"terminate_on_end": terminate_on_end}) + dataset = self._get_dataset_direct(name, env_kwargs) + else: + if terminate_on_end is False: + raise ValueError( + "Using terminate_on_end=False is not compatible with direct_download=True." 
+ ) + dataset = self._get_dataset_direct_download(name, env_kwargs) + # Fill unknown next states with 0 + dataset["next", "observation"][dataset["next", "done"].squeeze()] = 0 + + if split_trajs: + dataset = split_trajectories(dataset) + dataset["next", "done"][:, -1] = True + + storage = LazyMemmapStorage( + dataset.shape[0], scratch_dir=Path(self.root) / name + ) + elif self._is_downloaded(): + storage = TensorStorage(TensorDict.load_memmap(Path(self.root) / name)) + else: + raise RuntimeError( + f"The dataset could not be found in {Path(self.root) / name}." + ) - storage = LazyMemmapStorage(dataset.shape[0]) super().__init__( batch_size=batch_size, storage=storage, @@ -205,7 +227,12 @@ def __init__( prefetch=prefetch, transform=transform, ) - self.extend(dataset) + if dataset is not None: + # if dataset has just been downloaded + self.extend(dataset) + + def _is_downloaded(self): + return os.path.exists(Path(self.root) / self.name) def _get_dataset_direct_download(self, name, env_kwargs): """Directly download and use a D4RL dataset.""" @@ -216,10 +243,12 @@ def _get_dataset_direct_download(self, name, env_kwargs): url = D4RL_DATASETS.get(name, None) if url is None: raise KeyError(f"Env {name} not found.") - h5path = _download_dataset_from_url(url) - # h5path_parent = Path(h5path).parent - dataset = PersistentTensorDict.from_h5(h5path) - dataset = dataset.to_tensordict() + with tempfile.TemporaryDirectory() as tmpdir: + os.environ["D4RL_DATASET_DIR"] = tmpdir + h5path = _download_dataset_from_url(url, tmpdir) + # h5path_parent = Path(h5path).parent + dataset = PersistentTensorDict.from_h5(h5path) + dataset = dataset.to_tensordict() with dataset.unlock_(): dataset = self._process_data_from_env(dataset) return dataset @@ -235,15 +264,17 @@ def _get_dataset_direct(self, name, env_kwargs): import gym env = GymWrapper(gym.make(name)) - dataset = d4rl.qlearning_dataset(env._env, **env_kwargs) - - dataset = make_tensordict( - { - k: torch.from_numpy(item) - for k, item 
in dataset.items() - if isinstance(item, np.ndarray) - } - ) + with tempfile.TemporaryDirectory() as tmpdir: + os.environ["D4RL_DATASET_DIR"] = tmpdir + dataset = d4rl.qlearning_dataset(env._env, **env_kwargs) + + dataset = make_tensordict( + { + k: torch.from_numpy(item) + for k, item in dataset.items() + if isinstance(item, np.ndarray) + } + ) dataset = dataset.unflatten_keys("/") if "metadata" in dataset.keys(): metadata = dataset.get("metadata") @@ -304,14 +335,16 @@ def _get_dataset_from_env(self, name, env_kwargs): # we do a local import to avoid circular import issues from torchrl.envs.libs.gym import GymWrapper - env = GymWrapper(gym.make(name)) - dataset = make_tensordict( - { - k: torch.from_numpy(item) - for k, item in env.get_dataset().items() - if isinstance(item, np.ndarray) - } - ) + with tempfile.TemporaryDirectory() as tmpdir: + os.environ["D4RL_DATASET_DIR"] = tmpdir + env = GymWrapper(gym.make(name)) + dataset = make_tensordict( + { + k: torch.from_numpy(item) + for k, item in env.get_dataset().items() + if isinstance(item, np.ndarray) + } + ) dataset = dataset.unflatten_keys("/") dataset = self._process_data_from_env(dataset, env) return dataset @@ -396,8 +429,8 @@ def _shift_reward_done(self, dataset): dataset[key][0] = 0 -def _download_dataset_from_url(dataset_url): - dataset_filepath = _filepath_from_url(dataset_url) +def _download_dataset_from_url(dataset_url, dataset_path): + dataset_filepath = _filepath_from_url(dataset_url, dataset_path) if not os.path.exists(dataset_filepath): print("Downloading dataset:", dataset_url, "to", dataset_filepath) urllib.request.urlretrieve(dataset_url, dataset_filepath) @@ -406,9 +439,9 @@ def _download_dataset_from_url(dataset_url): return dataset_filepath -def _filepath_from_url(dataset_url): +def _filepath_from_url(dataset_url, dataset_path): _, dataset_name = os.path.split(dataset_url) - dataset_filepath = os.path.join(DATASET_PATH, dataset_name) + dataset_filepath = os.path.join(dataset_path, 
dataset_name) return dataset_filepath diff --git a/torchrl/data/datasets/minari.py b/torchrl/data/datasets/minari.py index e7adbc84921..2e97d739d80 100644 --- a/torchrl/data/datasets/minari.py +++ b/torchrl/data/datasets/minari.py @@ -7,13 +7,16 @@ import os.path import tempfile from pathlib import Path +from typing import Callable import torch from tensordict import MemoryMappedTensor, PersistentTensorDict, TensorDict from torchrl._utils import KeyDependentDefaultDict from torchrl.data.datasets.utils import _get_root_dir -from torchrl.data.replay_buffers import TensorDictReplayBuffer +from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer +from torchrl.data.replay_buffers.samplers import Sampler from torchrl.data.replay_buffers.storages import TensorStorage +from torchrl.data.replay_buffers.writers import Writer _NAME_MATCH = KeyDependentDefaultDict(lambda key: key) _NAME_MATCH["observations"] = "observation" @@ -84,7 +87,7 @@ def _download_and_preproc(self): with tempfile.TemporaryDirectory() as tmpdir: os.environ["MINARI_DATASETS_PATH"] = tmpdir minari.download_dataset(dataset_id=self.dataset_id) - dataset = minari.load_dataset(self.dataset_id) + minari.load_dataset(self.dataset_id) h5_data = PersistentTensorDict.from_h5( Path(tmpdir) / self.dataset_id / "data/main_data.hdf5" ) @@ -139,10 +142,14 @@ def _download_and_preproc(self): (td_data["next", "terminated"] | td_data["next", "truncated"]) ) if self.split_trajs: + from torchrl.objectives.utils import split_trajectories + td_data = split_trajectories(td_data).memmap_(self.data_path) return td_data def _make_split(self): + from torchrl.objectives.utils import split_trajectories + td_data = TensorDict.load_memmap(self.data_path_root) td_data = split_trajectories(td_data).memmap_(self.data_path) return td_data From 2de024779610099f6afdd92588fe3ff4c14902b8 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 29 Nov 2023 13:52:20 +0000 Subject: [PATCH 03/19] amend d4rl --- 
torchrl/data/datasets/d4rl.py | 7 ++++--- torchrl/data/datasets/minari.py | 10 ++++++++++ 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/torchrl/data/datasets/d4rl.py b/torchrl/data/datasets/d4rl.py index 1b596e29a49..2cd7645eff3 100644 --- a/torchrl/data/datasets/d4rl.py +++ b/torchrl/data/datasets/d4rl.py @@ -152,6 +152,10 @@ def __init__( self.name = name dataset = None + if not from_env and direct_download is None: + self._import_d4rl() + direct_download = not self._has_d4rl + if not direct_download: if from_env is None: warnings.warn( @@ -169,9 +173,6 @@ def __init__( from_env = False self.from_env = from_env - if not from_env and direct_download is None: - self._import_d4rl() - direct_download = not self._has_d4rl if download and not self._is_downloaded(): if not direct_download: if terminate_on_end is None: diff --git a/torchrl/data/datasets/minari.py b/torchrl/data/datasets/minari.py index 2e97d739d80..69c78ce556f 100644 --- a/torchrl/data/datasets/minari.py +++ b/torchrl/data/datasets/minari.py @@ -28,6 +28,16 @@ class MinariExperienceReplay(TensorDictReplayBuffer): + """Minari Experience replay dataset. 
+ + Args: + dataset_id (str): + batch_size (int): + + Keyword Args: + root (Path or str, optional): + download (bool, optional): + """ def __init__( self, dataset_id, From d209c8b3adb884035658eb90e766b5827d42dff2 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 29 Nov 2023 17:06:17 +0000 Subject: [PATCH 04/19] amend --- .../linux_libs/scripts_minari/environment.yml | 20 + .../linux_libs/scripts_minari/install.sh | 51 +++ .../linux_libs/scripts_minari/post_process.sh | 6 + .../scripts_minari/run-clang-format.py | 356 ++++++++++++++++ .../linux_libs/scripts_minari/run_test.sh | 61 +++ .../linux_libs/scripts_minari/setup_env.sh | 50 +++ .github/workflows/test-linux-d4rl.yml | 1 + .github/workflows/test-linux-minari.yml | 42 ++ test/test_libs.py | 36 ++ torchrl/data/datasets/__init__.py | 1 + torchrl/data/datasets/d4rl.py | 6 + torchrl/data/datasets/minari.py | 168 -------- torchrl/data/datasets/minari_data.py | 403 ++++++++++++++++++ torchrl/data/replay_buffers/storages.py | 2 +- 14 files changed, 1034 insertions(+), 169 deletions(-) create mode 100644 .github/unittest/linux_libs/scripts_minari/environment.yml create mode 100755 .github/unittest/linux_libs/scripts_minari/install.sh create mode 100755 .github/unittest/linux_libs/scripts_minari/post_process.sh create mode 100755 .github/unittest/linux_libs/scripts_minari/run-clang-format.py create mode 100755 .github/unittest/linux_libs/scripts_minari/run_test.sh create mode 100755 .github/unittest/linux_libs/scripts_minari/setup_env.sh create mode 100644 .github/workflows/test-linux-minari.yml delete mode 100644 torchrl/data/datasets/minari.py create mode 100644 torchrl/data/datasets/minari_data.py diff --git a/.github/unittest/linux_libs/scripts_minari/environment.yml b/.github/unittest/linux_libs/scripts_minari/environment.yml new file mode 100644 index 00000000000..27963a42a24 --- /dev/null +++ b/.github/unittest/linux_libs/scripts_minari/environment.yml @@ -0,0 +1,20 @@ +channels: + - pytorch + - defaults 
+dependencies: + - pip + - pip: + - hypothesis + - future + - cloudpickle + - pytest + - pytest-cov + - pytest-mock + - pytest-instafail + - pytest-rerunfailures + - pytest-error-for-skips + - expecttest + - pyyaml + - scipy + - hydra-core + - minari diff --git a/.github/unittest/linux_libs/scripts_minari/install.sh b/.github/unittest/linux_libs/scripts_minari/install.sh new file mode 100755 index 00000000000..2eb52b8f65e --- /dev/null +++ b/.github/unittest/linux_libs/scripts_minari/install.sh @@ -0,0 +1,51 @@ +#!/usr/bin/env bash + +unset PYTORCH_VERSION +# For unittest, nightly PyTorch is used as the following section, +# so no need to set PYTORCH_VERSION. +# In fact, keeping PYTORCH_VERSION forces us to hardcode PyTorch version in config. +apt-get update && apt-get install -y git wget gcc g++ +#apt-get update && apt-get install -y git wget freeglut3 freeglut3-dev + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env + +if [ "${CU_VERSION:-}" == cpu ] ; then + version="cpu" +else + if [[ ${#CU_VERSION} -eq 4 ]]; then + CUDA_VERSION="${CU_VERSION:2:1}.${CU_VERSION:3:1}" + elif [[ ${#CU_VERSION} -eq 5 ]]; then + CUDA_VERSION="${CU_VERSION:2:2}.${CU_VERSION:4:1}" + fi + echo "Using CUDA $CUDA_VERSION as determined by CU_VERSION ($CU_VERSION)" + version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")" +fi + + +# submodules +git submodule sync && git submodule update --init --recursive + +printf "Installing PyTorch with %s\n" "${CU_VERSION}" +if [ "${CU_VERSION:-}" == cpu ] ; then + # conda install -y pytorch torchvision cpuonly -c pytorch-nightly + # use pip to install pytorch as conda can frequently pick older release +# conda install -y pytorch cpuonly -c pytorch-nightly + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu --force-reinstall +else + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 --force-reinstall +fi + +# install tensordict +pip install 
git+https://github.com/pytorch/tensordict.git + +# smoke test +python -c "import functorch;import tensordict" + +printf "* Installing torchrl\n" +python setup.py develop + +# smoke test +python -c "import torchrl" diff --git a/.github/unittest/linux_libs/scripts_minari/post_process.sh b/.github/unittest/linux_libs/scripts_minari/post_process.sh new file mode 100755 index 00000000000..e97bf2a7b1b --- /dev/null +++ b/.github/unittest/linux_libs/scripts_minari/post_process.sh @@ -0,0 +1,6 @@ +#!/usr/bin/env bash + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env diff --git a/.github/unittest/linux_libs/scripts_minari/run-clang-format.py b/.github/unittest/linux_libs/scripts_minari/run-clang-format.py new file mode 100755 index 00000000000..5783a885d86 --- /dev/null +++ b/.github/unittest/linux_libs/scripts_minari/run-clang-format.py @@ -0,0 +1,356 @@ +#!/usr/bin/env python +""" +MIT License + +Copyright (c) 2017 Guillaume Papin + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +A wrapper script around clang-format, suitable for linting multiple files +and to use for continuous integration. + +This is an alternative API for the clang-format command line. +It runs over multiple files and directories in parallel. +A diff output is produced and a sensible exit code is returned. + +""" + +import argparse +import difflib +import fnmatch +import multiprocessing +import os +import signal +import subprocess +import sys +import traceback +from functools import partial + +try: + from subprocess import DEVNULL # py3k +except ImportError: + DEVNULL = open(os.devnull, "wb") + + +DEFAULT_EXTENSIONS = "c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu" + + +class ExitStatus: + SUCCESS = 0 + DIFF = 1 + TROUBLE = 2 + + +def list_files(files, recursive=False, extensions=None, exclude=None): + if extensions is None: + extensions = [] + if exclude is None: + exclude = [] + + out = [] + for file in files: + if recursive and os.path.isdir(file): + for dirpath, dnames, fnames in os.walk(file): + fpaths = [os.path.join(dirpath, fname) for fname in fnames] + for pattern in exclude: + # os.walk() supports trimming down the dnames list + # by modifying it in-place, + # to avoid unnecessary directory listings. 
+ dnames[:] = [ + x + for x in dnames + if not fnmatch.fnmatch(os.path.join(dirpath, x), pattern) + ] + fpaths = [x for x in fpaths if not fnmatch.fnmatch(x, pattern)] + for f in fpaths: + ext = os.path.splitext(f)[1][1:] + if ext in extensions: + out.append(f) + else: + out.append(file) + return out + + +def make_diff(file, original, reformatted): + return list( + difflib.unified_diff( + original, + reformatted, + fromfile=f"{file}\t(original)", + tofile=f"{file}\t(reformatted)", + n=3, + ) + ) + + +class DiffError(Exception): + def __init__(self, message, errs=None): + super().__init__(message) + self.errs = errs or [] + + +class UnexpectedError(Exception): + def __init__(self, message, exc=None): + super().__init__(message) + self.formatted_traceback = traceback.format_exc() + self.exc = exc + + +def run_clang_format_diff_wrapper(args, file): + try: + ret = run_clang_format_diff(args, file) + return ret + except DiffError: + raise + except Exception as e: + raise UnexpectedError(f"{file}: {e.__class__.__name__}: {e}", e) + + +def run_clang_format_diff(args, file): + try: + with open(file, encoding="utf-8") as f: + original = f.readlines() + except OSError as exc: + raise DiffError(str(exc)) + invocation = [args.clang_format_executable, file] + + # Use of utf-8 to decode the process output. + # + # Hopefully, this is the correct thing to do. + # + # It's done due to the following assumptions (which may be incorrect): + # - clang-format will returns the bytes read from the files as-is, + # without conversion, and it is already assumed that the files use utf-8. + # - if the diagnostics were internationalized, they would use utf-8: + # > Adding Translations to Clang + # > + # > Not possible yet! + # > Diagnostic strings should be written in UTF-8, + # > the client can translate to the relevant code page if needed. + # > Each translation completely replaces the format string + # > for the diagnostic. 
+ # > -- http://clang.llvm.org/docs/InternalsManual.html#internals-diag-translation + + try: + proc = subprocess.Popen( + invocation, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + universal_newlines=True, + encoding="utf-8", + ) + except OSError as exc: + raise DiffError( + f"Command '{subprocess.list2cmdline(invocation)}' failed to start: {exc}" + ) + proc_stdout = proc.stdout + proc_stderr = proc.stderr + + # hopefully the stderr pipe won't get full and block the process + outs = list(proc_stdout.readlines()) + errs = list(proc_stderr.readlines()) + proc.wait() + if proc.returncode: + raise DiffError( + "Command '{}' returned non-zero exit status {}".format( + subprocess.list2cmdline(invocation), proc.returncode + ), + errs, + ) + return make_diff(file, original, outs), errs + + +def bold_red(s): + return "\x1b[1m\x1b[31m" + s + "\x1b[0m" + + +def colorize(diff_lines): + def bold(s): + return "\x1b[1m" + s + "\x1b[0m" + + def cyan(s): + return "\x1b[36m" + s + "\x1b[0m" + + def green(s): + return "\x1b[32m" + s + "\x1b[0m" + + def red(s): + return "\x1b[31m" + s + "\x1b[0m" + + for line in diff_lines: + if line[:4] in ["--- ", "+++ "]: + yield bold(line) + elif line.startswith("@@ "): + yield cyan(line) + elif line.startswith("+"): + yield green(line) + elif line.startswith("-"): + yield red(line) + else: + yield line + + +def print_diff(diff_lines, use_color): + if use_color: + diff_lines = colorize(diff_lines) + sys.stdout.writelines(diff_lines) + + +def print_trouble(prog, message, use_colors): + error_text = "error:" + if use_colors: + error_text = bold_red(error_text) + print(f"{prog}: {error_text} {message}", file=sys.stderr) + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--clang-format-executable", + metavar="EXECUTABLE", + help="path to the clang-format executable", + default="clang-format", + ) + parser.add_argument( + "--extensions", + help=f"comma separated list of file extensions (default: 
{DEFAULT_EXTENSIONS})", + default=DEFAULT_EXTENSIONS, + ) + parser.add_argument( + "-r", + "--recursive", + action="store_true", + help="run recursively over directories", + ) + parser.add_argument("files", metavar="file", nargs="+") + parser.add_argument("-q", "--quiet", action="store_true") + parser.add_argument( + "-j", + metavar="N", + type=int, + default=0, + help="run N clang-format jobs in parallel (default number of cpus + 1)", + ) + parser.add_argument( + "--color", + default="auto", + choices=["auto", "always", "never"], + help="show colored diff (default: auto)", + ) + parser.add_argument( + "-e", + "--exclude", + metavar="PATTERN", + action="append", + default=[], + help="exclude paths matching the given glob-like pattern(s) from recursive search", + ) + + args = parser.parse_args() + + # use default signal handling, like diff return SIGINT value on ^C + # https://bugs.python.org/issue14229#msg156446 + signal.signal(signal.SIGINT, signal.SIG_DFL) + try: + signal.SIGPIPE + except AttributeError: + # compatibility, SIGPIPE does not exist on Windows + pass + else: + signal.signal(signal.SIGPIPE, signal.SIG_DFL) + + colored_stdout = False + colored_stderr = False + if args.color == "always": + colored_stdout = True + colored_stderr = True + elif args.color == "auto": + colored_stdout = sys.stdout.isatty() + colored_stderr = sys.stderr.isatty() + + version_invocation = [args.clang_format_executable, "--version"] + try: + subprocess.check_call(version_invocation, stdout=DEVNULL) + except subprocess.CalledProcessError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + return ExitStatus.TROUBLE + except OSError as e: + print_trouble( + parser.prog, + f"Command '{subprocess.list2cmdline(version_invocation)}' failed to start: {e}", + use_colors=colored_stderr, + ) + return ExitStatus.TROUBLE + + retcode = ExitStatus.SUCCESS + files = list_files( + args.files, + recursive=args.recursive, + exclude=args.exclude, + 
extensions=args.extensions.split(","), + ) + + if not files: + return + + njobs = args.j + if njobs == 0: + njobs = multiprocessing.cpu_count() + 1 + njobs = min(len(files), njobs) + + if njobs == 1: + # execute directly instead of in a pool, + # less overhead, simpler stacktraces + it = (run_clang_format_diff_wrapper(args, file) for file in files) + pool = None + else: + pool = multiprocessing.Pool(njobs) + it = pool.imap_unordered(partial(run_clang_format_diff_wrapper, args), files) + while True: + try: + outs, errs = next(it) + except StopIteration: + break + except DiffError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + retcode = ExitStatus.TROUBLE + sys.stderr.writelines(e.errs) + except UnexpectedError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + sys.stderr.write(e.formatted_traceback) + retcode = ExitStatus.TROUBLE + # stop at the first unexpected error, + # something could be very wrong, + # don't process all files unnecessarily + if pool: + pool.terminate() + break + else: + sys.stderr.writelines(errs) + if outs == []: + continue + if not args.quiet: + print_diff(outs, use_color=colored_stdout) + if retcode == ExitStatus.SUCCESS: + retcode = ExitStatus.DIFF + return retcode + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/.github/unittest/linux_libs/scripts_minari/run_test.sh b/.github/unittest/linux_libs/scripts_minari/run_test.sh new file mode 100755 index 00000000000..3723399a859 --- /dev/null +++ b/.github/unittest/linux_libs/scripts_minari/run_test.sh @@ -0,0 +1,61 @@ +#!/usr/bin/env bash + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env + +apt-get update && apt-get remove swig -y && apt-get install -y git gcc patchelf libosmesa6-dev libgl1-mesa-glx libglfw3 swig3.0 +ln -s /usr/bin/swig3.0 /usr/bin/swig + +# we install d4rl here bc env variables have been updated +git clone https://github.com/Farama-Foundation/d4rl.git +cd d4rl +#pip3 install -U 
'mujoco-py<2.1,>=2.0' +pip3 install -U "gym[classic_control,atari,accept-rom-license]"==0.23 +pip3 install -U six +pip install -e . +cd .. + +#flow is a dependency disaster of biblical scale +#git clone https://github.com/flow-project/flow.git +#cd flow +#python setup.py develop +#cd .. + +export PYTORCH_TEST_WITH_SLOW='1' +python -m torch.utils.collect_env +# Avoid error: "fatal: unsafe repository" +git config --global --add safe.directory '*' + +root_dir="$(git rev-parse --show-toplevel)" +env_dir="${root_dir}/env" +lib_dir="${env_dir}/lib" + +conda deactivate && conda activate ./env + +# this workflow only tests the libs +python -c "import gym, d4rl" + +python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestD4RL --error-for-skips +coverage combine +coverage xml -i + +## check what happens if we update gym +#pip install gym -U +#python -c """ +#from torchrl.data.datasets import D4RLExperienceReplay +#data = D4RLExperienceReplay('halfcheetah-medium-v2', batch_size=10, from_env=False, direct_download=True) +#for batch in data: +# print(batch) +# break +# +#data = D4RLExperienceReplay('halfcheetah-medium-v2', batch_size=10, from_env=False, direct_download=False) +#for batch in data: +# print(batch) +# break +# +#import d4rl +#import gym +#gym.make('halfcheetah-medium-v2') +#""" diff --git a/.github/unittest/linux_libs/scripts_minari/setup_env.sh b/.github/unittest/linux_libs/scripts_minari/setup_env.sh new file mode 100755 index 00000000000..5214617c2ac --- /dev/null +++ b/.github/unittest/linux_libs/scripts_minari/setup_env.sh @@ -0,0 +1,50 @@ +#!/usr/bin/env bash + +# This script is for setting up environment in which unit test is ran. +# To speed up the CI time, the resulting environment is cached. +# +# Do not install PyTorch and torchvision here, otherwise they also get cached. 
+ +set -e +set -v + +this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +# Avoid error: "fatal: unsafe repository" +apt-get update && apt-get install -y git wget gcc g++ unzip + +git config --global --add safe.directory '*' +root_dir="$(git rev-parse --show-toplevel)" +conda_dir="${root_dir}/conda" +env_dir="${root_dir}/env" + +cd "${root_dir}" + +case "$(uname -s)" in + Darwin*) os=MacOSX;; + *) os=Linux +esac + +# 1. Install conda at ./conda +if [ ! -d "${conda_dir}" ]; then + printf "* Installing conda\n" + wget -O miniconda.sh "http://repo.continuum.io/miniconda/Miniconda3-latest-${os}-x86_64.sh" + bash ./miniconda.sh -b -f -p "${conda_dir}" +fi +eval "$(${conda_dir}/bin/conda shell.bash hook)" + +# 2. Create test environment at ./env +printf "python: ${PYTHON_VERSION}\n" +if [ ! -d "${env_dir}" ]; then + printf "* Creating a test environment\n" + conda create --prefix "${env_dir}" -y python="$PYTHON_VERSION" +fi +conda activate "${env_dir}" + +# 3. Install Conda dependencies +printf "* Installing dependencies (except PyTorch)\n" +echo " - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml" +cat "${this_dir}/environment.yml" + +pip3 install pip --upgrade + +conda env update --file "${this_dir}/environment.yml" --prune diff --git a/.github/workflows/test-linux-d4rl.yml b/.github/workflows/test-linux-d4rl.yml index 3a0d534cd8e..ef986e34498 100644 --- a/.github/workflows/test-linux-d4rl.yml +++ b/.github/workflows/test-linux-d4rl.yml @@ -21,6 +21,7 @@ jobs: matrix: python_version: ["3.9"] cuda_arch_version: ["12.1"] + if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Data') }} uses: pytorch/test-infra/.github/workflows/linux_job.yml@main with: repository: pytorch/rl diff --git a/.github/workflows/test-linux-minari.yml b/.github/workflows/test-linux-minari.yml new file mode 100644 index 00000000000..8ef92de54e4 --- /dev/null +++ b/.github/workflows/test-linux-minari.yml @@ -0,0 +1,42 @@ +name: 
D4RL Tests on Linux + +on: + pull_request: + push: + branches: + - nightly + - main + - release/* + workflow_dispatch: + +concurrency: + # Documentation suggests ${{ github.head_ref }}, but that's only available on pull_request/pull_request_target triggers, so using ${{ github.ref }}. + # On master, we want all builds to complete even if merging happens faster to make it easier to discover at which point something broke. + group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && format('ci-master-{0}', github.sha) || format('ci-{0}', github.ref) }} + cancel-in-progress: true + +jobs: + unittests: + strategy: + matrix: + python_version: ["3.9"] + cuda_arch_version: ["12.1"] + if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Data') }} + uses: pytorch/test-infra/.github/workflows/linux_job.yml@main + with: + repository: pytorch/rl + runner: "linux.g5.4xlarge.nvidia.gpu" + docker-image: "nvidia/cudagl:11.4.0-base" + timeout: 120 + script: | + set -euo pipefail + export PYTHON_VERSION="3.9" + export CU_VERSION="cu117" + export TAR_OPTIONS="--no-same-owner" + export UPLOAD_CHANNEL="nightly" + export TF_CPP_MIN_LOG_LEVEL=0 + + bash .github/unittest/linux_libs/scripts_minari/setup_env.sh + bash .github/unittest/linux_libs/scripts_minari/install.sh + bash .github/unittest/linux_libs/scripts_minari/run_test.sh + bash .github/unittest/linux_libs/scripts_minari/post_process.sh diff --git a/test/test_libs.py b/test/test_libs.py index c3379021510..7cdb786de6f 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -50,6 +50,7 @@ from torchrl._utils import implement_for from torchrl.collectors.collectors import RandomPolicy, SyncDataCollector from torchrl.data.datasets.d4rl import D4RLExperienceReplay +from torchrl.data.datasets.minari_data import MinariExperienceReplay from torchrl.data.datasets.openml import OpenMLExperienceReplay from torchrl.data.replay_buffers import SamplerWithoutReplacement from torchrl.envs import ( 
@@ -90,6 +91,8 @@ _has_gym_robotics = importlib.util.find_spec("gymnasium_robotics") is not None +_has_minari = importlib.util.find_spec("minari") is not None + if _has_gym: try: import gymnasium as gym @@ -1961,6 +1964,39 @@ def test_d4rl_iteration(self, task, split_trajs): print(f"terminated test after {time.time()-t0}s") +@pytest.mark.skipif(not _has_minari, reason="Minari not found") +class TestMinari: + @pytest.fixture(scope="class") + def selected_datasets(self): + torch.manual_seed(0) + import minari + + keys = list(minari.list_remote_datasets()) + indices = torch.randperm(len(keys))[:10] + keys = [keys[idx] for idx in indices] + keys = [ + key + for key in keys + if "=0.4" in minari.list_remote_datasets()[key]["minari_version"] + ] + assert len(keys) > 5 + return keys + + def test_load(self, selected_datasets): + for dataset in selected_datasets: + print("dataset", dataset) + data = MinariExperienceReplay(dataset, batch_size=32) + t0 = time.time() + for i, sample in enumerate(data): + t1 = time.time() + print(f"sampling time {1000 * (t1-t0): 4.4f}ms") + assert data.metadata["action_space"].is_in(sample["action"]) + assert data.metadata["observation_space"].is_in(sample["observation"]) + t0 = time.time() + if i == 10: + break + + @pytest.mark.skipif(not _has_sklearn, reason="Scikit-learn not found") @pytest.mark.parametrize( "dataset", diff --git a/torchrl/data/datasets/__init__.py b/torchrl/data/datasets/__init__.py index 81a668648d0..85b8e064917 100644 --- a/torchrl/data/datasets/__init__.py +++ b/torchrl/data/datasets/__init__.py @@ -1,2 +1,3 @@ from .d4rl import D4RLExperienceReplay +from .minari_data import MinariExperienceReplay from .openml import OpenMLExperienceReplay diff --git a/torchrl/data/datasets/d4rl.py b/torchrl/data/datasets/d4rl.py index 2cd7645eff3..163efebb7a1 100644 --- a/torchrl/data/datasets/d4rl.py +++ b/torchrl/data/datasets/d4rl.py @@ -99,6 +99,12 @@ class D4RLExperienceReplay(TensorDictReplayBuffer): terminate_on_end (bool, 
optional): Set ``done=True`` on the last timestep in a trajectory. Default is ``False``, and will discard the last timestep in each trajectory. + root (Path or str, optional): The D4RL dataset root directory. + The actual dataset memory-mapped files will be saved under + `/`. If none is provided, it defaults to + ``~/.cache/torchrl/d4rl`. + download (bool, optional): Whether the dataset should be downloaded if + not found. Defaults to ``True``. **env_kwargs (key-value pairs): additional kwargs for :func:`d4rl.qlearning_dataset`. diff --git a/torchrl/data/datasets/minari.py b/torchrl/data/datasets/minari.py deleted file mode 100644 index 69c78ce556f..00000000000 --- a/torchrl/data/datasets/minari.py +++ /dev/null @@ -1,168 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. -from __future__ import annotations - -import os.path -import tempfile -from pathlib import Path -from typing import Callable - -import torch -from tensordict import MemoryMappedTensor, PersistentTensorDict, TensorDict -from torchrl._utils import KeyDependentDefaultDict -from torchrl.data.datasets.utils import _get_root_dir -from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer -from torchrl.data.replay_buffers.samplers import Sampler -from torchrl.data.replay_buffers.storages import TensorStorage -from torchrl.data.replay_buffers.writers import Writer - -_NAME_MATCH = KeyDependentDefaultDict(lambda key: key) -_NAME_MATCH["observations"] = "observation" -_NAME_MATCH["rewards"] = "reward" -_NAME_MATCH["truncations"] = "truncated" -_NAME_MATCH["terminations"] = "terminated" -_NAME_MATCH["actions"] = "action" -_NAME_MATCH["infos"] = "info" - - -class MinariExperienceReplay(TensorDictReplayBuffer): - """Minari Experience replay dataset. 
- - Args: - dataset_id (str): - batch_size (int): - - Keyword Args: - root (Path or str, optional): - download (bool, optional): - """ - def __init__( - self, - dataset_id, - batch_size: int, - *, - root: str | Path | None = None, - download: bool = True, - sampler: Sampler | None = None, - writer: Writer | None = None, - collate_fn: Callable | None = None, - pin_memory: bool = False, - prefetch: int | None = None, - transform: "torchrl.envs.Transform" | None = None, # noqa-F821 - split_trajs: bool = False, - **env_kwargs, - ): - self.dataset_id = dataset_id - if root is None: - root = _get_root_dir("minari") - os.makedirs(root, exist_ok=True) - self.root = root - self.split_trajs = split_trajs - self.download = download - if self.download and not self._is_downloaded(): - storage = self._download_and_preproc() - elif self.split_trajs and not os.path.exists(self.data_path): - storage = self._make_split() - else: - storage = self._load() - storage = TensorStorage(storage) - super().__init__( - storage=storage, - sampler=sampler, - writer=writer, - collate_fn=collate_fn, - pin_memory=pin_memory, - prefetch=prefetch, - batch_size=batch_size, - ) - - def _is_downloaded(self): - return os.path.exists(self.data_path) - - @property - def data_path(self): - if self.split_trajs: - return Path(self.root) / (self.dataset_id + "_split") - return self.data_path_root - - @property - def data_path_root(self): - return Path(self.root) / self.dataset_id - - def _download_and_preproc(self): - import minari - - with tempfile.TemporaryDirectory() as tmpdir: - os.environ["MINARI_DATASETS_PATH"] = tmpdir - minari.download_dataset(dataset_id=self.dataset_id) - minari.load_dataset(self.dataset_id) - h5_data = PersistentTensorDict.from_h5( - Path(tmpdir) / self.dataset_id / "data/main_data.hdf5" - ) - - # Get the total number of steps for the dataset - total_steps = sum( - h5_data[episode, "actions"].shape[0] for episode in h5_data.keys() - ) - # populate the tensordict - td_data = 
TensorDict({}, []) - for key, episode in h5_data.items(): - for key, val in episode.items(): - match = _NAME_MATCH[key] - if key in ("observations", "state"): - td_data.set(("next", match), torch.zeros_like(val)[0]) - td_data.set(match, torch.zeros_like(val)[0]) - elif key not in ("terminations", "truncations", "rewards"): - td_data.set(match, torch.zeros_like(val)[0]) - else: - td_data.set( - ("next", match), torch.zeros_like(val)[0].unsqueeze(-1) - ) - break - # give it the proper size - td_data = td_data.expand(total_steps) - # save to designated location - td_data.memmap_(self.data_path_root) - # iterate over episodes and populate the tensordict - index = 0 - for key, episode in h5_data.items(): - for key, val in episode.items(): - match = _NAME_MATCH[key] - if key in ( - "observations", - "state", - ): - steps = val.shape[0] - 1 - td_data["next", match][index : (index + steps)] = val[1:] - td_data[match][index : (index + steps)] = val[:-1] - elif key not in ("terminations", "truncations", "rewards"): - steps = val.shape[0] - td_data[match][index : (index + val.shape[0])] = val - else: - steps = val.shape[0] - td_data[("next", match)][ - index : (index + val.shape[0]) - ] = val.unsqueeze(-1) - index += steps - # Add a "done" entry - with td_data.unlock_(): - td_data["next", "done"] = MemoryMappedTensor.from_tensor( - (td_data["next", "terminated"] | td_data["next", "truncated"]) - ) - if self.split_trajs: - from torchrl.objectives.utils import split_trajectories - - td_data = split_trajectories(td_data).memmap_(self.data_path) - return td_data - - def _make_split(self): - from torchrl.objectives.utils import split_trajectories - - td_data = TensorDict.load_memmap(self.data_path_root) - td_data = split_trajectories(td_data).memmap_(self.data_path) - return td_data - - def _load(self): - return TensorDict.load_memmap(self.data_path) diff --git a/torchrl/data/datasets/minari_data.py b/torchrl/data/datasets/minari_data.py new file mode 100644 index 
00000000000..2766feb956a --- /dev/null +++ b/torchrl/data/datasets/minari_data.py @@ -0,0 +1,403 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import json +import os.path +import shutil +import tempfile +from dataclasses import asdict +from pathlib import Path +from typing import Callable + +import torch +import tqdm + +from tensordict import MemoryMappedTensor, PersistentTensorDict, TensorDict +from torchrl._utils import KeyDependentDefaultDict +from torchrl.data.datasets.utils import _get_root_dir +from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer +from torchrl.data.replay_buffers.samplers import Sampler +from torchrl.data.replay_buffers.storages import TensorStorage +from torchrl.data.replay_buffers.writers import Writer +from torchrl.data.tensor_specs import ( + BoundedTensorSpec, + CompositeSpec, + DiscreteTensorSpec, + UnboundedContinuousTensorSpec, +) + +_NAME_MATCH = KeyDependentDefaultDict(lambda key: key) +_NAME_MATCH["observations"] = "observation" +_NAME_MATCH["rewards"] = "reward" +_NAME_MATCH["truncations"] = "truncated" +_NAME_MATCH["terminations"] = "terminated" +_NAME_MATCH["actions"] = "action" +_NAME_MATCH["infos"] = "info" + + +class MinariExperienceReplay(TensorDictReplayBuffer): + """Minari Experience replay dataset. + + Args: + dataset_id (str): + batch_size (int): + + Keyword Args: + root (Path or str, optional): The Minari dataset root directory. + The actual dataset memory-mapped files will be saved under + `/`. If none is provided, it defaults to + ``~/.cache/torchrl/minari`. + download (bool or str, optional): Whether the dataset should be downloaded if + not found. Defaults to ``True``. Download can also be passed as "force", + in which case the downloaded data will be overwritten. 
+ sampler (Sampler, optional): the sampler to be used. If none is provided + a default RandomSampler() will be used. + writer (Writer, optional): the writer to be used. If none is provided + a default RoundRobinWriter() will be used. + collate_fn (callable, optional): merges a list of samples to form a + mini-batch of Tensor(s)/outputs. Used when using batched + loading from a map-style dataset. + pin_memory (bool): whether pin_memory() should be called on the rb + samples. + prefetch (int, optional): number of next batches to be prefetched + using multithreading. + transform (Transform, optional): Transform to be executed when sample() is called. + To chain transforms use the :obj:`Compose` class. + split_trajs (bool, optional): if ``True``, the trajectories will be split + along the first dimension and padded to have a matching shape. + To split the trajectories, the ``"done"`` signal will be used, which + is recovered via ``done = truncated | terminated``. In other words, + it is assumed that any ``truncated`` or ``terminated`` signal is + equivalent to the end of a trajectory. For some datasets from + ``D4RL``, this may not be true. It is up to the user to make + accurate choices regarding this usage of ``split_trajs``. + Defaults to ``False``. + + Examples: + >>> from torchrl.data.datasets.minari_data import MinariExperienceReplay + >>> data = MinariExperienceReplay("door-human-v1", batch_size=32, download="force") + >>> for sample in data: + ... print(sample) + ... 
break + TensorDict( + fields={ + action: Tensor(shape=torch.Size([32, 28]), device=cpu, dtype=torch.float32, is_shared=False), + index: Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.int64, is_shared=False), + info: TensorDict( + fields={ + success: Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([32]), + device=None, + is_shared=False), + next: TensorDict( + fields={ + observation: Tensor(shape=torch.Size([32, 39]), device=cpu, dtype=torch.float64, is_shared=False), + reward: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.float64, is_shared=False), + state: TensorDict( + fields={ + door_body_pos: Tensor(shape=torch.Size([32, 3]), device=cpu, dtype=torch.float64, is_shared=False), + qpos: Tensor(shape=torch.Size([32, 30]), device=cpu, dtype=torch.float64, is_shared=False), + qvel: Tensor(shape=torch.Size([32, 30]), device=cpu, dtype=torch.float64, is_shared=False)}, + batch_size=torch.Size([32]), + device=None, + is_shared=False), + terminated: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([32]), + device=None, + is_shared=False), + observation: Tensor(shape=torch.Size([32, 39]), device=cpu, dtype=torch.float64, is_shared=False), + state: TensorDict( + fields={ + door_body_pos: Tensor(shape=torch.Size([32, 3]), device=cpu, dtype=torch.float64, is_shared=False), + qpos: Tensor(shape=torch.Size([32, 30]), device=cpu, dtype=torch.float64, is_shared=False), + qvel: Tensor(shape=torch.Size([32, 30]), device=cpu, dtype=torch.float64, is_shared=False)}, + batch_size=torch.Size([32]), + device=None, + is_shared=False)}, + batch_size=torch.Size([32]), + device=None, + is_shared=False) + + """ + + def __init__( + self, + dataset_id, + batch_size: int, + *, + root: str | Path | None = None, + download: bool = True, + sampler: Sampler | None = None, 
+ writer: Writer | None = None, + collate_fn: Callable | None = None, + pin_memory: bool = False, + prefetch: int | None = None, + transform: "torchrl.envs.Transform" | None = None, # noqa-F821 + split_trajs: bool = False, + **env_kwargs, + ): + self.dataset_id = dataset_id + if root is None: + root = _get_root_dir("minari") + os.makedirs(root, exist_ok=True) + self.root = root + self.split_trajs = split_trajs + self.download = download + if self.download == "force" or (self.download and not self._is_downloaded()): + if self.download == "force": + try: + shutil.rmtree(self.data_path_root) + if self.data_path != self.data_path_root: + shutil.rmtree(self.data_path) + except FileNotFoundError: + pass + storage = self._download_and_preproc() + elif self.split_trajs and not os.path.exists(self.data_path): + storage = self._make_split() + else: + storage = self._load() + storage = TensorStorage(storage) + super().__init__( + storage=storage, + sampler=sampler, + writer=writer, + collate_fn=collate_fn, + pin_memory=pin_memory, + prefetch=prefetch, + batch_size=batch_size, + ) + + def available_datasets(self): + import minari + + return minari.list_remote_datasets().keys() + + def _is_downloaded(self): + return os.path.exists(self.data_path) + + @property + def data_path(self): + if self.split_trajs: + return Path(self.root) / (self.dataset_id + "_split") + return self.data_path_root + + @property + def data_path_root(self): + return Path(self.root) / self.dataset_id + + @property + def metadata_path(self): + return Path(self.root) / self.dataset_id / "env_metadata.json" + + def _download_and_preproc(self): + import minari + + with tempfile.TemporaryDirectory() as tmpdir: + os.environ["MINARI_DATASETS_PATH"] = tmpdir + minari.download_dataset(dataset_id=self.dataset_id) + parent_dir = Path(tmpdir) / self.dataset_id / "data" + + h5files = [] + for filename in os.listdir(parent_dir): + if filename.endswith(".hdf5"): + file_path = parent_dir / filename + 
h5files.append(file_path) + + td_data = TensorDict({}, []) + total_steps = 0 + print("first read through data to create data structure...") + with tqdm.tqdm(h5files) as pbar: + for h5file in pbar: + pbar.set_description(f"reading h5 {h5file}") + h5_data = PersistentTensorDict.from_h5(h5file) + # Get the total number of steps for the dataset + total_steps += sum( + h5_data[episode, "actions"].shape[0] + for episode in h5_data.keys() + ) + # populate the tensordict + for key, episode in h5_data.items(): + for key, val in episode.items(): + match = _NAME_MATCH[key] + if key in ("observations", "state", "infos"): + if not val.shape: + # Data is ambiguous, skipping + continue + # unique_shapes = defaultdict([]) + # for subkey, subval in val.items(): + # unique_shapes[subval.shape[0]].append(subkey) + # if not len(unique_shapes) == 2: + # raise RuntimeError("Unique shapes in a sub-tensordict can only be of length 2.") + # val_td = val.to_tensordict() + # min_shape = min(*unique_shapes) # can only be found at root + # max_shape = min_shape + 1 + # val_td = val_td.select(*unique_shapes[min_shape]) + # print("key - val", key, val) + # print("episode", episode) + td_data.set(("next", match), torch.zeros_like(val)[0]) + td_data.set(match, torch.zeros_like(val)[0]) + if key not in ("terminations", "truncations", "rewards"): + td_data.set(match, torch.zeros_like(val)[0]) + else: + td_data.set( + ("next", match), + torch.zeros_like(val)[0].unsqueeze(-1), + ) + break + h5_data.close() + + # give it the proper size + td_data = td_data.expand(total_steps) + # save to designated location + print(f"creating tensordict data in {self.data_path_root}: ", end="\t") + td_data = td_data.memmap_like(self.data_path_root) + print(td_data) + + print("Reading data") + index = 0 + with tqdm.tqdm(total=total_steps) as pbar: + for h5file in h5files: + h5_data = PersistentTensorDict.from_h5(h5file) + # TODO: sort episodes + # iterate over episodes and populate the tensordict + for key, episode in 
h5_data.items(): + for key, val in episode.items(): + match = _NAME_MATCH[key] + if key in ( + "observations", + "state", + "infos", + ): + if not val.shape: + # Data is ambiguous, skipping + continue + steps = val.shape[0] - 1 + td_data["next", match][index : (index + steps)] = val[ + 1: + ] + td_data[match][index : (index + steps)] = val[:-1] + elif key not in ("terminations", "truncations", "rewards"): + steps = val.shape[0] + td_data[match][index : (index + val.shape[0])] = val + else: + steps = val.shape[0] + td_data[("next", match)][ + index : (index + val.shape[0]) + ] = val.unsqueeze(-1) + pbar.update(steps) + pbar.set_description(f"index={index} - h5 {h5file}") + index += steps + h5_data.close() + # Add a "done" entry + with td_data.unlock_(): + td_data["next", "done"] = MemoryMappedTensor.from_tensor( + (td_data["next", "terminated"] | td_data["next", "truncated"]) + ) + if self.split_trajs: + from torchrl.objectives.utils import split_trajectories + + td_data = split_trajectories(td_data).memmap_(self.data_path) + with open(self.metadata_path, "w") as metadata_file: + dataset = minari.load_dataset(self.dataset_id) + self.metadata = asdict(dataset.spec) + self.metadata["observation_space"] = _spec_to_dict( + self.metadata["observation_space"] + ) + self.metadata["action_space"] = _spec_to_dict( + self.metadata["action_space"] + ) + print("self.metadata", self.metadata) + json.dump(self.metadata, metadata_file) + self._load_and_proc_metadata() + return td_data + + def _make_split(self): + from torchrl.objectives.utils import split_trajectories + + self._load_and_proc_metadata() + td_data = TensorDict.load_memmap(self.data_path_root) + td_data = split_trajectories(td_data).memmap_(self.data_path) + return td_data + + def _load(self): + self._load_and_proc_metadata() + return TensorDict.load_memmap(self.data_path) + + def _load_and_proc_metadata(self): + with open(self.metadata_path, "r") as file: + self.metadata = json.load(file) + 
self.metadata["observation_space"] = _proc_spec( + self.metadata["observation_space"] + ) + self.metadata["action_space"] = _proc_spec(self.metadata["action_space"]) + print("Loaded metadata", self.metadata) + + +def _proc_spec(spec): + if spec is None: + return + if spec["type"] == "Dict": + return CompositeSpec( + {key: _proc_spec(subspec) for key, subspec in spec["subspaces"].items()} + ) + elif spec["type"] == "Box": + if all(item == -float("inf") for item in spec["low"]) and all( + item == float("inf") for item in spec["high"] + ): + return UnboundedContinuousTensorSpec( + spec["shape"], dtype=_DTYPE_DIR[spec["dtype"]] + ) + return BoundedTensorSpec( + shape=spec["shape"], + low=torch.tensor(spec["low"]), + high=torch.tensor(spec["high"]), + dtype=_DTYPE_DIR[spec["dtype"]], + ) + elif spec["type"] == "Discrete": + return DiscreteTensorSpec( + spec["n"], shape=spec["shape"], dtype=_DTYPE_DIR[spec["dtype"]] + ) + else: + raise NotImplementedError(f"{type(spec)}") + + +def _spec_to_dict(spec): + from torchrl.envs.libs.gym import gym_backend + + if isinstance(spec, gym_backend("spaces").Dict): + return { + "type": "Dict", + "subspaces": {key: _spec_to_dict(val) for key, val in spec.items()}, + } + if isinstance(spec, gym_backend("spaces").Box): + return { + "type": "Box", + "low": spec.low.tolist(), + "high": spec.high.tolist(), + "dtype": str(spec.dtype), + "shape": tuple(spec.shape), + } + if isinstance(spec, gym_backend("spaces").Discrete): + return { + "type": "Discrete", + "dtype": str(spec.dtype), + "n": int(spec.n), + "shape": tuple(spec.shape), + } + if isinstance(spec, gym_backend("spaces").Text): + return + raise NotImplementedError(f"{type(spec)}, {str(spec)}") + + +_DTYPE_DIR = { + "float16": torch.float16, + "float32": torch.float32, + "float64": torch.float64, + "int64": torch.int64, + "int32": torch.int32, + "uint8": torch.uint8, +} diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 
9c8417b9c97..ec7a4a467ac 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -394,7 +394,7 @@ def get(self, index: Union[int, Sequence[int], slice]) -> Any: out = self._storage[index] if is_tensor_collection(out): out = _reset_batch_size(out) - return out.unlock_() + return out # .unlock_() return out def __len__(self): From 1f9c0d6fb17529575d4b31b5cb7a8deb96739a22 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 29 Nov 2023 18:04:00 +0000 Subject: [PATCH 05/19] empty From 49f5f5112acb8f7ca17ea778f901bf0732c0fb7e Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 1 Dec 2023 10:58:57 +0000 Subject: [PATCH 06/19] amend --- docs/source/reference/data.rst | 1 + torchrl/data/datasets/minari_data.py | 146 +++++++++++++-------------- 2 files changed, 70 insertions(+), 77 deletions(-) diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst index 98d2d40cd5c..d14c2c6cf38 100644 --- a/docs/source/reference/data.rst +++ b/docs/source/reference/data.rst @@ -189,6 +189,7 @@ Here's an example: D4RLExperienceReplay + MinariExperienceReplay OpenMLExperienceReplay TensorSpec diff --git a/torchrl/data/datasets/minari_data.py b/torchrl/data/datasets/minari_data.py index 2766feb956a..476de4e1304 100644 --- a/torchrl/data/datasets/minari_data.py +++ b/torchrl/data/datasets/minari_data.py @@ -201,54 +201,47 @@ def _download_and_preproc(self): minari.download_dataset(dataset_id=self.dataset_id) parent_dir = Path(tmpdir) / self.dataset_id / "data" - h5files = [] - for filename in os.listdir(parent_dir): - if filename.endswith(".hdf5"): - file_path = parent_dir / filename - h5files.append(file_path) - td_data = TensorDict({}, []) total_steps = 0 print("first read through data to create data structure...") - with tqdm.tqdm(h5files) as pbar: - for h5file in pbar: - pbar.set_description(f"reading h5 {h5file}") - h5_data = PersistentTensorDict.from_h5(h5file) - # Get the total number of steps for the dataset - total_steps += 
sum( - h5_data[episode, "actions"].shape[0] - for episode in h5_data.keys() - ) - # populate the tensordict - for key, episode in h5_data.items(): - for key, val in episode.items(): - match = _NAME_MATCH[key] - if key in ("observations", "state", "infos"): - if not val.shape: - # Data is ambiguous, skipping - continue - # unique_shapes = defaultdict([]) - # for subkey, subval in val.items(): - # unique_shapes[subval.shape[0]].append(subkey) - # if not len(unique_shapes) == 2: - # raise RuntimeError("Unique shapes in a sub-tensordict can only be of length 2.") - # val_td = val.to_tensordict() - # min_shape = min(*unique_shapes) # can only be found at root - # max_shape = min_shape + 1 - # val_td = val_td.select(*unique_shapes[min_shape]) - # print("key - val", key, val) - # print("episode", episode) - td_data.set(("next", match), torch.zeros_like(val)[0]) - td_data.set(match, torch.zeros_like(val)[0]) - if key not in ("terminations", "truncations", "rewards"): - td_data.set(match, torch.zeros_like(val)[0]) - else: - td_data.set( - ("next", match), - torch.zeros_like(val)[0].unsqueeze(-1), - ) - break - h5_data.close() + h5_data = PersistentTensorDict.from_h5(parent_dir / "main_data.hdf5") + # Get the total number of steps for the dataset + total_steps += sum( + h5_data[episode, "actions"].shape[0] + for episode in h5_data.keys() + ) + # populate the tensordict + episode_dict = {} + for key, episode in h5_data.items(): + for key, val in episode.items(): + episode_num = int(key[len("episode_"):]) + episode_dict[episode_num] = key + match = _NAME_MATCH[key] + if key in ("observations", "state", "infos"): + if not val.shape: + # Data is ambiguous, skipping + continue + # unique_shapes = defaultdict([]) + # for subkey, subval in val.items(): + # unique_shapes[subval.shape[0]].append(subkey) + # if not len(unique_shapes) == 2: + # raise RuntimeError("Unique shapes in a sub-tensordict can only be of length 2.") + # val_td = val.to_tensordict() + # min_shape = 
min(*unique_shapes) # can only be found at root + # max_shape = min_shape + 1 + # val_td = val_td.select(*unique_shapes[min_shape]) + # print("key - val", key, val) + # print("episode", episode) + td_data.set(("next", match), torch.zeros_like(val)[0]) + td_data.set(match, torch.zeros_like(val)[0]) + if key not in ("terminations", "truncations", "rewards"): + td_data.set(match, torch.zeros_like(val)[0]) + else: + td_data.set( + ("next", match), + torch.zeros_like(val)[0].unsqueeze(-1), + ) + break # give it the proper size td_data = td_data.expand(total_steps) @@ -260,38 +253,37 @@ def _download_and_preproc(self): print("Reading data") index = 0 with tqdm.tqdm(total=total_steps) as pbar: - for h5file in h5files: - h5_data = PersistentTensorDict.from_h5(h5file) - # TODO: sort episodes - # iterate over episodes and populate the tensordict - for key, episode in h5_data.items(): - for key, val in episode.items(): - match = _NAME_MATCH[key] - if key in ( - "observations", - "state", - "infos", - ): - if not val.shape: - # Data is ambiguous, skipping - continue - steps = val.shape[0] - 1 - td_data["next", match][index : (index + steps)] = val[ - 1: - ] - td_data[match][index : (index + steps)] = val[:-1] - elif key not in ("terminations", "truncations", "rewards"): - steps = val.shape[0] - td_data[match][index : (index + val.shape[0])] = val - else: - steps = val.shape[0] - td_data[("next", match)][ - index : (index + val.shape[0]) - ] = val.unsqueeze(-1) - pbar.update(steps) - pbar.set_description(f"index={index} - h5 {h5file}") - index += steps - h5_data.close() + # iterate over episodes and populate the tensordict + for episode_num in sorted(episode_dict): + key = episode_dict[episode_num] + episode = h5_data.get(key) + for key, val in episode.items(): + match = _NAME_MATCH[key] + if key in ( + "observations", + "state", + "infos", + ): + if not val.shape: + # Data is ambiguous, skipping + continue + steps = val.shape[0] - 1 + td_data["next", match][index : (index + 
steps)] = val[ + 1: + ] + td_data[match][index : (index + steps)] = val[:-1] + elif key not in ("terminations", "truncations", "rewards"): + steps = val.shape[0] + td_data[match][index : (index + val.shape[0])] = val + else: + steps = val.shape[0] + td_data[("next", match)][ + index : (index + val.shape[0]) + ] = val.unsqueeze(-1) + pbar.update(steps) + pbar.set_description(f"index={index} - episode num {episode_num}") + index += steps + h5_data.close() # Add a "done" entry with td_data.unlock_(): td_data["next", "done"] = MemoryMappedTensor.from_tensor( From 88d9661e181eb536630f4dd065247c301e12289a Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 1 Dec 2023 11:00:07 +0000 Subject: [PATCH 07/19] amend --- torchrl/data/datasets/minari_data.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torchrl/data/datasets/minari_data.py b/torchrl/data/datasets/minari_data.py index 476de4e1304..d91c2e7a5a8 100644 --- a/torchrl/data/datasets/minari_data.py +++ b/torchrl/data/datasets/minari_data.py @@ -212,10 +212,10 @@ def _download_and_preproc(self): ) # populate the tensordict episode_dict = {} - for key, episode in h5_data.items(): + for episode_key, episode in h5_data.items(): + episode_num = int(episode_key[len("episode_"):]) + episode_dict[episode_num] = episode_key for key, val in episode.items(): - episode_num = int(key[len("episode_"):]) - episode_dict[episode_num] = key match = _NAME_MATCH[key] if key in ("observations", "state", "infos"): if not val.shape: @@ -255,8 +255,8 @@ def _download_and_preproc(self): with tqdm.tqdm(total=total_steps) as pbar: # iterate over episodes and populate the tensordict for episode_num in sorted(episode_dict): - key = episode_dict[episode_num] - episode = h5_data.get(key) + episode_key = episode_dict[episode_num] + episode = h5_data.get(episode_key) for key, val in episode.items(): match = _NAME_MATCH[key] if key in ( From b6a3dc0111ba6ecfaddd236d6feb8a177c2d6e17 Mon Sep 17 00:00:00 2001 From: vmoens Date: 
Fri, 1 Dec 2023 11:15:11 +0000 Subject: [PATCH 08/19] amend --- test/test_libs.py | 5 +++-- torchrl/data/datasets/minari_data.py | 6 ++---- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index 7cdb786de6f..8b54ca1c243 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -1982,10 +1982,11 @@ def selected_datasets(self): assert len(keys) > 5 return keys - def test_load(self, selected_datasets): + @pytest.mark.parametrize("split", [False, True]) + def test_load(self, selected_datasets, split): for dataset in selected_datasets: print("dataset", dataset) - data = MinariExperienceReplay(dataset, batch_size=32) + data = MinariExperienceReplay(dataset, batch_size=32, split_trajs=split) t0 = time.time() for i, sample in enumerate(data): t1 = time.time() diff --git a/torchrl/data/datasets/minari_data.py b/torchrl/data/datasets/minari_data.py index d91c2e7a5a8..a70b3904c67 100644 --- a/torchrl/data/datasets/minari_data.py +++ b/torchrl/data/datasets/minari_data.py @@ -8,6 +8,7 @@ import os.path import shutil import tempfile +import time from dataclasses import asdict from pathlib import Path from typing import Callable @@ -241,7 +242,6 @@ def _download_and_preproc(self): ("next", match), torch.zeros_like(val)[0].unsqueeze(-1), ) - break # give it the proper size td_data = td_data.expand(total_steps) @@ -250,7 +250,7 @@ def _download_and_preproc(self): td_data = td_data.memmap_like(self.data_path_root) print(td_data) - print("Reading data") + print(f"Reading data from {max(*episode_dict)} episodes") index = 0 with tqdm.tqdm(total=total_steps) as pbar: # iterate over episodes and populate the tensordict @@ -302,7 +302,6 @@ def _download_and_preproc(self): self.metadata["action_space"] = _spec_to_dict( self.metadata["action_space"] ) - print("self.metadata", self.metadata) json.dump(self.metadata, metadata_file) self._load_and_proc_metadata() return td_data @@ -326,7 +325,6 @@ def _load_and_proc_metadata(self): 
self.metadata["observation_space"] ) self.metadata["action_space"] = _proc_spec(self.metadata["action_space"]) - print("Loaded metadata", self.metadata) def _proc_spec(spec): From 4a791cd2e81e7504b70f7a16aa9155f34d54fc8a Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 1 Dec 2023 12:02:06 +0000 Subject: [PATCH 09/19] amend --- torchrl/data/datasets/minari_data.py | 61 ++++++++++++++++------------ 1 file changed, 34 insertions(+), 27 deletions(-) diff --git a/torchrl/data/datasets/minari_data.py b/torchrl/data/datasets/minari_data.py index a70b3904c67..842e353a750 100644 --- a/torchrl/data/datasets/minari_data.py +++ b/torchrl/data/datasets/minari_data.py @@ -8,7 +8,8 @@ import os.path import shutil import tempfile -import time + +from collections import defaultdict from dataclasses import asdict from pathlib import Path from typing import Callable @@ -178,7 +179,7 @@ def available_datasets(self): return minari.list_remote_datasets().keys() def _is_downloaded(self): - return os.path.exists(self.data_path) + return os.path.exists(self.data_path_root) @property def data_path(self): @@ -208,39 +209,26 @@ def _download_and_preproc(self): h5_data = PersistentTensorDict.from_h5(parent_dir / "main_data.hdf5") # Get the total number of steps for the dataset total_steps += sum( - h5_data[episode, "actions"].shape[0] - for episode in h5_data.keys() + h5_data[episode, "actions"].shape[0] for episode in h5_data.keys() ) # populate the tensordict episode_dict = {} for episode_key, episode in h5_data.items(): - episode_num = int(episode_key[len("episode_"):]) + episode_num = int(episode_key[len("episode_") :]) episode_dict[episode_num] = episode_key for key, val in episode.items(): match = _NAME_MATCH[key] if key in ("observations", "state", "infos"): if not val.shape: - # Data is ambiguous, skipping - continue - # unique_shapes = defaultdict([]) - # for subkey, subval in val.items(): - # unique_shapes[subval.shape[0]].append(subkey) - # if not len(unique_shapes) == 2: - # raise 
RuntimeError("Unique shapes in a sub-tensordict can only be of length 2.") - # val_td = val.to_tensordict() - # min_shape = min(*unique_shapes) # can only be found at root - # max_shape = min_shape + 1 - # val_td = val_td.select(*unique_shapes[min_shape]) - # print("key - val", key, val) - # print("episode", episode) - td_data.set(("next", match), torch.zeros_like(val)[0]) - td_data.set(match, torch.zeros_like(val)[0]) + val = _patch_info(val) + td_data.set(("next", match), torch.zeros_like(val[0])) + td_data.set(match, torch.zeros_like(val[0])) if key not in ("terminations", "truncations", "rewards"): - td_data.set(match, torch.zeros_like(val)[0]) + td_data.set(match, torch.zeros_like(val[0])) else: td_data.set( ("next", match), - torch.zeros_like(val)[0].unsqueeze(-1), + torch.zeros_like(val[0].unsqueeze(-1)), ) # give it the proper size @@ -265,12 +253,9 @@ def _download_and_preproc(self): "infos", ): if not val.shape: - # Data is ambiguous, skipping - continue + val = _patch_info(val) steps = val.shape[0] - 1 - td_data["next", match][index : (index + steps)] = val[ - 1: - ] + td_data["next", match][index : (index + steps)] = val[1:] td_data[match][index : (index + steps)] = val[:-1] elif key not in ("terminations", "truncations", "rewards"): steps = val.shape[0] @@ -391,3 +376,25 @@ def _spec_to_dict(spec): "int32": torch.int32, "uint8": torch.uint8, } + + +def _patch_info(info_td): + # Some info dicts have tensors with one less element than others + # We explicitely assume that the missing item is in the first position because + # it wasn't given at reset time. + # An alternative explanation could be that the last element is missing because + # deemed useless for training... 
+ unique_shapes = defaultdict(list) + for subkey, subval in info_td.items(): + unique_shapes[subval.shape[0]].append(subkey) + if not len(unique_shapes) == 2: + raise RuntimeError("Unique shapes in a sub-tensordict can only be of length 2.") + val_td = info_td.to_tensordict() + min_shape = min(*unique_shapes) # can only be found at root + max_shape = min_shape + 1 + val_td_sel = val_td.select(*unique_shapes[min_shape]).apply( + lambda x: torch.cat([torch.zeros_like(x[:1]), x], 0) + ) + val_td_sel.batch_size = [min_shape + 1] + val_td_sel.update(val_td.select(*unique_shapes[max_shape])) + return val_td_sel From 1985222794a67c3855418f97ab4c0488898ba3fa Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 1 Dec 2023 13:41:50 +0000 Subject: [PATCH 10/19] amend --- torchrl/data/datasets/minari_data.py | 37 +++++++++++++++++----------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/torchrl/data/datasets/minari_data.py b/torchrl/data/datasets/minari_data.py index 842e353a750..df70145c0e1 100644 --- a/torchrl/data/datasets/minari_data.py +++ b/torchrl/data/datasets/minari_data.py @@ -4,18 +4,19 @@ # LICENSE file in the root directory of this source tree. 
from __future__ import annotations +import importlib.util import json import os.path import shutil import tempfile from collections import defaultdict +from contextlib import nullcontext from dataclasses import asdict from pathlib import Path from typing import Callable import torch -import tqdm from tensordict import MemoryMappedTensor, PersistentTensorDict, TensorDict from torchrl._utils import KeyDependentDefaultDict @@ -31,6 +32,8 @@ UnboundedContinuousTensorSpec, ) +_has_tqdm = importlib.util.find_spec("tqdm", None) is not None + _NAME_MATCH = KeyDependentDefaultDict(lambda key: key) _NAME_MATCH["observations"] = "observation" _NAME_MATCH["rewards"] = "reward" @@ -40,6 +43,16 @@ _NAME_MATCH["infos"] = "info" +_DTYPE_DIR = { + "float16": torch.float16, + "float32": torch.float32, + "float64": torch.float64, + "int64": torch.int64, + "int32": torch.int32, + "uint8": torch.uint8, +} + + class MinariExperienceReplay(TensorDictReplayBuffer): """Minari Experience replay dataset. @@ -198,6 +211,9 @@ def metadata_path(self): def _download_and_preproc(self): import minari + if _has_tqdm: + from tqdm import tqdm + with tempfile.TemporaryDirectory() as tmpdir: os.environ["MINARI_DATASETS_PATH"] = tmpdir minari.download_dataset(dataset_id=self.dataset_id) @@ -240,7 +256,7 @@ def _download_and_preproc(self): print(f"Reading data from {max(*episode_dict)} episodes") index = 0 - with tqdm.tqdm(total=total_steps) as pbar: + with tqdm(total=total_steps) if _has_tqdm else nullcontext() as pbar: # iterate over episodes and populate the tensordict for episode_num in sorted(episode_dict): episode_key = episode_dict[episode_num] @@ -265,8 +281,11 @@ def _download_and_preproc(self): td_data[("next", match)][ index : (index + val.shape[0]) ] = val.unsqueeze(-1) - pbar.update(steps) - pbar.set_description(f"index={index} - episode num {episode_num}") + if pbar is not None: + pbar.update(steps) + pbar.set_description( + f"index={index} - episode num {episode_num}" + ) index += steps 
h5_data.close() # Add a "done" entry @@ -368,16 +387,6 @@ def _spec_to_dict(spec): raise NotImplementedError(f"{type(spec)}, {str(spec)}") -_DTYPE_DIR = { - "float16": torch.float16, - "float32": torch.float32, - "float64": torch.float64, - "int64": torch.int64, - "int32": torch.int32, - "uint8": torch.uint8, -} - - def _patch_info(info_td): # Some info dicts have tensors with one less element than others # We explicitely assume that the missing item is in the first position because From 3a213cbaf22211efbf3a3a6fd2753c051ad231f6 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 1 Dec 2023 14:31:16 +0000 Subject: [PATCH 11/19] amend --- .github/workflows/test-linux-minari.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-linux-minari.yml b/.github/workflows/test-linux-minari.yml index 8ef92de54e4..aa473d5aef2 100644 --- a/.github/workflows/test-linux-minari.yml +++ b/.github/workflows/test-linux-minari.yml @@ -1,4 +1,4 @@ -name: D4RL Tests on Linux +name: Minari Tests on Linux on: pull_request: From 1a31f763373c63ba6e8337d15e5e95b2be140a72 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 1 Dec 2023 16:26:46 +0000 Subject: [PATCH 12/19] amend --- .../linux_libs/scripts_minari/run_test.sh | 38 +------------------ 1 file changed, 2 insertions(+), 36 deletions(-) diff --git a/.github/unittest/linux_libs/scripts_minari/run_test.sh b/.github/unittest/linux_libs/scripts_minari/run_test.sh index 3723399a859..bf2a49e3b43 100755 --- a/.github/unittest/linux_libs/scripts_minari/run_test.sh +++ b/.github/unittest/linux_libs/scripts_minari/run_test.sh @@ -8,21 +8,6 @@ conda activate ./env apt-get update && apt-get remove swig -y && apt-get install -y git gcc patchelf libosmesa6-dev libgl1-mesa-glx libglfw3 swig3.0 ln -s /usr/bin/swig3.0 /usr/bin/swig -# we install d4rl here bc env variables have been updated -git clone https://github.com/Farama-Foundation/d4rl.git -cd d4rl -#pip3 install -U 'mujoco-py<2.1,>=2.0' -pip3 install -U 
"gym[classic_control,atari,accept-rom-license]"==0.23 -pip3 install -U six -pip install -e . -cd .. - -#flow is a dependency disaster of biblical scale -#git clone https://github.com/flow-project/flow.git -#cd flow -#python setup.py develop -#cd .. - export PYTORCH_TEST_WITH_SLOW='1' python -m torch.utils.collect_env # Avoid error: "fatal: unsafe repository" @@ -35,27 +20,8 @@ lib_dir="${env_dir}/lib" conda deactivate && conda activate ./env # this workflow only tests the libs -python -c "import gym, d4rl" +python -c "import gym, minari" -python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestD4RL --error-for-skips +python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestMinari --error-for-skips coverage combine coverage xml -i - -## check what happens if we update gym -#pip install gym -U -#python -c """ -#from torchrl.data.datasets import D4RLExperienceReplay -#data = D4RLExperienceReplay('halfcheetah-medium-v2', batch_size=10, from_env=False, direct_download=True) -#for batch in data: -# print(batch) -# break -# -#data = D4RLExperienceReplay('halfcheetah-medium-v2', batch_size=10, from_env=False, direct_download=False) -#for batch in data: -# print(batch) -# break -# -#import d4rl -#import gym -#gym.make('halfcheetah-medium-v2') -#""" From 18c7f10cc949d288d527ae4c73d2c2aa056c111b Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 1 Dec 2023 17:05:38 +0000 Subject: [PATCH 13/19] amend --- .github/unittest/linux_libs/scripts_minari/run_test.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/unittest/linux_libs/scripts_minari/run_test.sh b/.github/unittest/linux_libs/scripts_minari/run_test.sh index bf2a49e3b43..7741a491f5b 100755 --- a/.github/unittest/linux_libs/scripts_minari/run_test.sh +++ b/.github/unittest/linux_libs/scripts_minari/run_test.sh @@ -20,7 +20,7 @@ 
lib_dir="${env_dir}/lib" conda deactivate && conda activate ./env # this workflow only tests the libs -python -c "import gym, minari" +python -c "import minari" python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestMinari --error-for-skips coverage combine From 2714b62cab6dc15891cd425e1129f25d0002d967 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 4 Dec 2023 10:08:52 +0000 Subject: [PATCH 14/19] fixes --- test/test_libs.py | 71 ++++++++++++--------- torchrl/data/datasets/minari_data.py | 93 +++++++++++++++++++--------- 2 files changed, 105 insertions(+), 59 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index 8b54ca1c243..00869c8fa16 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -1964,38 +1964,51 @@ def test_d4rl_iteration(self, task, split_trajs): print(f"terminated test after {time.time()-t0}s") +_MINARI_DATASETS = [] + + +def _minari_selected_datasets(): + if not _has_minari: + return + global _MINARI_DATASETS + import minari + + torch.manual_seed(0) + + keys = list(minari.list_remote_datasets()) + indices = torch.randperm(len(keys))[:10] + keys = [keys[idx] for idx in indices] + keys = [ + key + for key in keys + if "=0.4" in minari.list_remote_datasets()[key]["minari_version"] + ] + assert len(keys) > 5 + _MINARI_DATASETS += keys + print("_MINARI_DATASETS", _MINARI_DATASETS) + + +_minari_selected_datasets() + + @pytest.mark.skipif(not _has_minari, reason="Minari not found") +@pytest.mark.parametrize("split", [False, True]) +@pytest.mark.parametrize("selected_dataset", _MINARI_DATASETS) class TestMinari: - @pytest.fixture(scope="class") - def selected_datasets(self): - torch.manual_seed(0) - import minari - - keys = list(minari.list_remote_datasets()) - indices = torch.randperm(len(keys))[:10] - keys = [keys[idx] for idx in indices] - keys = [ - key - for key in keys - if "=0.4" in minari.list_remote_datasets()[key]["minari_version"] - ] - assert 
len(keys) > 5 - return keys - - @pytest.mark.parametrize("split", [False, True]) - def test_load(self, selected_datasets, split): - for dataset in selected_datasets: - print("dataset", dataset) - data = MinariExperienceReplay(dataset, batch_size=32, split_trajs=split) + def test_load(self, selected_dataset, split): + print("dataset", selected_dataset) + data = MinariExperienceReplay( + selected_dataset, batch_size=32, split_trajs=split + ) + t0 = time.time() + for i, sample in enumerate(data): + t1 = time.time() + print(f"sampling time {1000 * (t1-t0): 4.4f}ms") + assert data.metadata["action_space"].is_in(sample["action"]) + assert data.metadata["observation_space"].is_in(sample["observation"]) t0 = time.time() - for i, sample in enumerate(data): - t1 = time.time() - print(f"sampling time {1000 * (t1-t0): 4.4f}ms") - assert data.metadata["action_space"].is_in(sample["action"]) - assert data.metadata["observation_space"].is_in(sample["observation"]) - t0 = time.time() - if i == 10: - break + if i == 10: + break @pytest.mark.skipif(not _has_sklearn, reason="Scikit-learn not found") diff --git a/torchrl/data/datasets/minari_data.py b/torchrl/data/datasets/minari_data.py index df70145c0e1..945ad9d7320 100644 --- a/torchrl/data/datasets/minari_data.py +++ b/torchrl/data/datasets/minari_data.py @@ -18,7 +18,7 @@ import torch -from tensordict import MemoryMappedTensor, PersistentTensorDict, TensorDict +from tensordict import PersistentTensorDict, TensorDict from torchrl._utils import KeyDependentDefaultDict from torchrl.data.datasets.utils import _get_root_dir from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer @@ -223,19 +223,22 @@ def _download_and_preproc(self): total_steps = 0 print("first read through data to create data structure...") h5_data = PersistentTensorDict.from_h5(parent_dir / "main_data.hdf5") - # Get the total number of steps for the dataset - total_steps += sum( - h5_data[episode, "actions"].shape[0] for episode in 
h5_data.keys() - ) # populate the tensordict episode_dict = {} for episode_key, episode in h5_data.items(): episode_num = int(episode_key[len("episode_") :]) - episode_dict[episode_num] = episode_key + episode_len = episode["actions"].shape[0] + episode_dict[episode_num] = (episode_key, episode_len) + # Get the total number of steps for the dataset + total_steps += episode_len for key, val in episode.items(): match = _NAME_MATCH[key] if key in ("observations", "state", "infos"): - if not val.shape: + if ( + not val.shape + ): # no need for this, we don't need the proper length: or steps != val.shape[0] - 1: + if val.is_empty(): + continue val = _patch_info(val) td_data.set(("next", match), torch.zeros_like(val[0])) td_data.set(match, torch.zeros_like(val[0])) @@ -248,19 +251,26 @@ def _download_and_preproc(self): ) # give it the proper size + td_data["next", "done"] = ( + td_data["next", "truncated"] | td_data["next", "terminated"] + ) + if "terminated" in td_data.keys(): + td_data["done"] = td_data["truncated"] | td_data["terminated"] td_data = td_data.expand(total_steps) # save to designated location print(f"creating tensordict data in {self.data_path_root}: ", end="\t") td_data = td_data.memmap_like(self.data_path_root) - print(td_data) + print("tensordict structure:", td_data) print(f"Reading data from {max(*episode_dict)} episodes") index = 0 with tqdm(total=total_steps) if _has_tqdm else nullcontext() as pbar: # iterate over episodes and populate the tensordict for episode_num in sorted(episode_dict): - episode_key = episode_dict[episode_num] + episode_key, steps = episode_dict[episode_num] episode = h5_data.get(episode_key) + idx = slice(index, (index + steps)) + data_view = td_data[idx] for key, val in episode.items(): match = _NAME_MATCH[key] if key in ( @@ -268,19 +278,41 @@ def _download_and_preproc(self): "state", "infos", ): - if not val.shape: + if not val.shape or steps != val.shape[0] - 1: + if val.is_empty(): + continue val = _patch_info(val) - 
steps = val.shape[0] - 1 - td_data["next", match][index : (index + steps)] = val[1:] - td_data[match][index : (index + steps)] = val[:-1] + if steps != val.shape[0] - 1: + raise RuntimeError( + f"Mismatching number of steps for key {key}: was {steps} but got {val.shape[0] - 1}." + ) + data_view["next", match].copy_(val[1:]) + data_view[match].copy_(val[:-1]) elif key not in ("terminations", "truncations", "rewards"): - steps = val.shape[0] - td_data[match][index : (index + val.shape[0])] = val + if steps is None: + steps = val.shape[0] + else: + if steps != val.shape[0]: + raise RuntimeError( + f"Mismatching number of steps for key {key}: was {steps} but got {val.shape[0]}." + ) + data_view[match].copy_(val) else: - steps = val.shape[0] - td_data[("next", match)][ - index : (index + val.shape[0]) - ] = val.unsqueeze(-1) + if steps is None: + steps = val.shape[0] + else: + if steps != val.shape[0]: + raise RuntimeError( + f"Mismatching number of steps for key {key}: was {steps} but got {val.shape[0]}." 
+ ) + data_view[("next", match)].copy_(val.unsqueeze(-1)) + data_view["next", "done"].copy_( + data_view["next", "terminated"] | data_view["next", "truncated"] + ) + if "done" in data_view.keys(): + data_view["done"].copy_( + data_view["terminated"] | data_view["truncated"] + ) if pbar is not None: pbar.update(steps) pbar.set_description( @@ -289,11 +321,8 @@ def _download_and_preproc(self): index += steps h5_data.close() # Add a "done" entry - with td_data.unlock_(): - td_data["next", "done"] = MemoryMappedTensor.from_tensor( - (td_data["next", "terminated"] | td_data["next", "truncated"]) - ) - if self.split_trajs: + if self.split_trajs: + with td_data.unlock_(): from torchrl.objectives.utils import split_trajectories td_data = split_trajectories(td_data).memmap_(self.data_path) @@ -311,7 +340,7 @@ def _download_and_preproc(self): return td_data def _make_split(self): - from torchrl.objectives.utils import split_trajectories + from torchrl.collectors.utils import split_trajectories self._load_and_proc_metadata() td_data = TensorDict.load_memmap(self.data_path_root) @@ -396,14 +425,18 @@ def _patch_info(info_td): unique_shapes = defaultdict(list) for subkey, subval in info_td.items(): unique_shapes[subval.shape[0]].append(subkey) - if not len(unique_shapes) == 2: - raise RuntimeError("Unique shapes in a sub-tensordict can only be of length 2.") + if len(unique_shapes) == 1: + unique_shapes[subval.shape[0] + 1] = [] + if len(unique_shapes) != 2: + raise RuntimeError( + f"Unique shapes in a sub-tensordict can only be of length 2, got shapes {unique_shapes}." 
+ ) val_td = info_td.to_tensordict() min_shape = min(*unique_shapes) # can only be found at root max_shape = min_shape + 1 - val_td_sel = val_td.select(*unique_shapes[min_shape]).apply( - lambda x: torch.cat([torch.zeros_like(x[:1]), x], 0) + val_td_sel = val_td.select(*unique_shapes[min_shape]) + val_td_sel = val_td_sel.apply( + lambda x: torch.cat([torch.zeros_like(x[:1]), x], 0), batch_size=[min_shape + 1] ) - val_td_sel.batch_size = [min_shape + 1] val_td_sel.update(val_td.select(*unique_shapes[max_shape])) return val_td_sel From a7b7425f93144e128187540212280451fcce7d6b Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 4 Dec 2023 10:09:41 +0000 Subject: [PATCH 15/19] docstrings --- torchrl/data/datasets/minari_data.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/torchrl/data/datasets/minari_data.py b/torchrl/data/datasets/minari_data.py index 945ad9d7320..6a9a7d78d9a 100644 --- a/torchrl/data/datasets/minari_data.py +++ b/torchrl/data/datasets/minari_data.py @@ -105,10 +105,17 @@ class MinariExperienceReplay(TensorDictReplayBuffer): fields={ success: Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([32]), - device=None, + device=cpu, is_shared=False), next: TensorDict( fields={ + done: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False), + info: TensorDict( + fields={ + success: Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([32]), + device=cpu, + is_shared=False), observation: Tensor(shape=torch.Size([32, 39]), device=cpu, dtype=torch.float64, is_shared=False), reward: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.float64, is_shared=False), state: TensorDict( @@ -117,12 +124,12 @@ class MinariExperienceReplay(TensorDictReplayBuffer): qpos: Tensor(shape=torch.Size([32, 30]), device=cpu, dtype=torch.float64, is_shared=False), qvel: Tensor(shape=torch.Size([32, 30]), 
device=cpu, dtype=torch.float64, is_shared=False)}, batch_size=torch.Size([32]), - device=None, + device=cpu, is_shared=False), terminated: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False), truncated: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([32]), - device=None, + device=cpu, is_shared=False), observation: Tensor(shape=torch.Size([32, 39]), device=cpu, dtype=torch.float64, is_shared=False), state: TensorDict( @@ -131,10 +138,10 @@ class MinariExperienceReplay(TensorDictReplayBuffer): qpos: Tensor(shape=torch.Size([32, 30]), device=cpu, dtype=torch.float64, is_shared=False), qvel: Tensor(shape=torch.Size([32, 30]), device=cpu, dtype=torch.float64, is_shared=False)}, batch_size=torch.Size([32]), - device=None, + device=cpu, is_shared=False)}, batch_size=torch.Size([32]), - device=None, + device=cpu, is_shared=False) """ From 20a6289c504798026e96ee5b2880489348733273 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 4 Dec 2023 11:08:12 +0000 Subject: [PATCH 16/19] amend --- test/test_libs.py | 5 +++++ torchrl/data/datasets/d4rl.py | 14 ++++++++++++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index 00869c8fa16..1e81b7d8f12 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -1832,6 +1832,7 @@ def test_terminate_on_end(self, task, use_truncated_as_done, split_trajs): terminate_on_end=True, batch_size=2, use_truncated_as_done=use_truncated_as_done, + download="force", ) _ = D4RLExperienceReplay( task, @@ -1840,6 +1841,7 @@ def test_terminate_on_end(self, task, use_truncated_as_done, split_trajs): terminate_on_end=False, batch_size=2, use_truncated_as_done=use_truncated_as_done, + download="force", ) data_from_env = D4RLExperienceReplay( task, @@ -1847,6 +1849,7 @@ def test_terminate_on_end(self, task, use_truncated_as_done, split_trajs): from_env=True, batch_size=2, use_truncated_as_done=use_truncated_as_done, + 
download="force", ) if not use_truncated_as_done: keys = set(data_from_env._storage._storage.keys(True, True)) @@ -1885,6 +1888,7 @@ def test_direct_download(self, task): batch_size=2, use_truncated_as_done=True, direct_download=True, + download="force", ) data_d4rl = D4RLExperienceReplay( task, @@ -1894,6 +1898,7 @@ def test_direct_download(self, task): use_truncated_as_done=True, direct_download=False, terminate_on_end=True, # keep the last time step + download="force", ) keys = set(data_direct._storage._storage.keys(True, True)) keys = keys.intersection(data_d4rl._storage._storage.keys(True, True)) diff --git a/torchrl/data/datasets/d4rl.py b/torchrl/data/datasets/d4rl.py index 163efebb7a1..38fce4a6b7c 100644 --- a/torchrl/data/datasets/d4rl.py +++ b/torchrl/data/datasets/d4rl.py @@ -98,7 +98,8 @@ class D4RLExperienceReplay(TensorDictReplayBuffer): Otherwise, only the ``terminated`` key is used. Defaults to ``True``. terminate_on_end (bool, optional): Set ``done=True`` on the last timestep in a trajectory. Default is ``False``, and will discard the - last timestep in each trajectory. + last timestep in each trajectory. This is to be used only with + ``direct_download=False``. root (Path or str, optional): The D4RL dataset root directory. The actual dataset memory-mapped files will be saved under `/`. If none is provided, it defaults to @@ -173,13 +174,22 @@ def __init__( category=DeprecationWarning, ) from_env = True + else: + warnings.warn( + "You are using the D4RL library for collecting data. " + "We advise against this use, as D4RL formatting can be " + "inconsistent. " + "To download the D4RL data without the D4RL library, use " + "direct_download=True in the dataset constructor. " + "Recurring to `direct_download=False` will soon be deprecated." 
+ ) self.from_env = from_env else: if from_env is None: from_env = False self.from_env = from_env - if download and not self._is_downloaded(): + if (download == "force") or (download and not self._is_downloaded()): if not direct_download: if terminate_on_end is None: # we use the default of d4rl From b31eaa8105d661204bb78ee6932e012ecaebfb8e Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 4 Dec 2023 11:32:46 +0000 Subject: [PATCH 17/19] amend --- test/test_libs.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/test/test_libs.py b/test/test_libs.py index 1e81b7d8f12..41b62d8b02f 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -1820,7 +1820,10 @@ class TestD4RL: @pytest.mark.parametrize("task", ["walker2d-medium-replay-v2"]) @pytest.mark.parametrize("use_truncated_as_done", [True, False]) @pytest.mark.parametrize("split_trajs", [True, False]) - def test_terminate_on_end(self, task, use_truncated_as_done, split_trajs): + def test_terminate_on_end(self, task, use_truncated_as_done, split_trajs, tmpdir): + root1 = tmpdir / "1" + root2 = tmpdir / "1" + root3 = tmpdir / "1" with pytest.warns( UserWarning, match="Using use_truncated_as_done=True" @@ -1833,6 +1836,7 @@ def test_terminate_on_end(self, task, use_truncated_as_done, split_trajs): batch_size=2, use_truncated_as_done=use_truncated_as_done, download="force", + root=root1, ) _ = D4RLExperienceReplay( task, @@ -1842,6 +1846,7 @@ def test_terminate_on_end(self, task, use_truncated_as_done, split_trajs): batch_size=2, use_truncated_as_done=use_truncated_as_done, download="force", + root=root2, ) data_from_env = D4RLExperienceReplay( task, @@ -1850,6 +1855,7 @@ def test_terminate_on_end(self, task, use_truncated_as_done, split_trajs): batch_size=2, use_truncated_as_done=use_truncated_as_done, download="force", + root=root3, ) if not use_truncated_as_done: keys = set(data_from_env._storage._storage.keys(True, True)) From 75630a70632ffee40fa2da0a838bcd61186054b9 Mon Sep 17 00:00:00 
2001 From: vmoens Date: Mon, 4 Dec 2023 11:55:29 +0000 Subject: [PATCH 18/19] amend --- test/test_libs.py | 4 ++-- torchrl/data/datasets/minari_data.py | 6 ++++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index 41b62d8b02f..211e527ce70 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -1822,8 +1822,8 @@ class TestD4RL: @pytest.mark.parametrize("split_trajs", [True, False]) def test_terminate_on_end(self, task, use_truncated_as_done, split_trajs, tmpdir): root1 = tmpdir / "1" - root2 = tmpdir / "1" - root3 = tmpdir / "1" + root2 = tmpdir / "2" + root3 = tmpdir / "3" with pytest.warns( UserWarning, match="Using use_truncated_as_done=True" diff --git a/torchrl/data/datasets/minari_data.py b/torchrl/data/datasets/minari_data.py index 6a9a7d78d9a..492ac0fff58 100644 --- a/torchrl/data/datasets/minari_data.py +++ b/torchrl/data/datasets/minari_data.py @@ -91,6 +91,12 @@ class MinariExperienceReplay(TensorDictReplayBuffer): accurate choices regarding this usage of ``split_trajs``. Defaults to ``False``. + .. note:: + Text data is currently discarded from the wrapped dataset, as there is no + PyTorch native way of representing text data. + If this feature is required, please post an issue on TorchRL's GitHub + repository.
+ Examples: >>> from torchrl.data.datasets.minari_data import MinariExperienceReplay >>> data = MinariExperienceReplay("door-human-v1", batch_size=32, download="force") From 0090fe0f97535e9d276a33b6859d4c75ff6bb6b6 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 4 Dec 2023 12:32:57 +0000 Subject: [PATCH 19/19] amend --- test/test_libs.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/test/test_libs.py b/test/test_libs.py index 211e527ce70..ede7968f15c 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -1886,7 +1886,9 @@ def test_terminate_on_end(self, task, use_truncated_as_done, split_trajs, tmpdir assert "truncated" not in leaf_names @pytest.mark.parametrize("task", ["walker2d-medium-replay-v2"]) - def test_direct_download(self, task): + def test_direct_download(self, task, tmpdir): + root1 = tmpdir / "1" + root2 = tmpdir / "2" data_direct = D4RLExperienceReplay( task, split_trajs=False, @@ -1895,6 +1897,7 @@ def test_direct_download(self, task): use_truncated_as_done=True, direct_download=True, download="force", + root=root1, ) data_d4rl = D4RLExperienceReplay( task, @@ -1905,6 +1908,7 @@ def test_direct_download(self, task): direct_download=False, terminate_on_end=True, # keep the last time step download="force", + root=root2, ) keys = set(data_direct._storage._storage.keys(True, True)) keys = keys.intersection(data_d4rl._storage._storage.keys(True, True))