From 46b02a3c69cac52b43185a98ab1a6983dbd91ef7 Mon Sep 17 00:00:00 2001 From: Achintya P Date: Tue, 28 Apr 2026 13:24:58 -0700 Subject: [PATCH 1/3] [Feature] Add Safety-Gymnasium environment wrapper Adds SafetyGymnasiumWrapper and SafetyGymnasiumEnv for constrained-RL benchmarks. The underlying step API returns a 6-tuple with a parallel cost signal; the wrapper folds cost into the gymnasium info dict so the existing GymWrapper machinery is reused, then registers an info-dict reader that exposes cost as a top-level tensordict key alongside reward. --- test/libs/test_safety_gymnasium.py | 53 ++++++++ torchrl/envs/libs/__init__.py | 3 + torchrl/envs/libs/safety_gymnasium.py | 188 ++++++++++++++++++++++++++ 3 files changed, 244 insertions(+) create mode 100644 test/libs/test_safety_gymnasium.py create mode 100644 torchrl/envs/libs/safety_gymnasium.py diff --git a/test/libs/test_safety_gymnasium.py b/test/libs/test_safety_gymnasium.py new file mode 100644 index 00000000000..9f7ad8fa75b --- /dev/null +++ b/test/libs/test_safety_gymnasium.py @@ -0,0 +1,53 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+from __future__ import annotations + +import pytest +import torch + +from torchrl.envs.libs.safety_gymnasium import ( + _has_safety_gymnasium, + SafetyGymnasiumEnv, + SafetyGymnasiumWrapper, +) +from torchrl.envs.utils import check_env_specs + + +@pytest.mark.skipif( + not _has_safety_gymnasium, reason="safety-gymnasium not installed" +) +class TestSafetyGymnasium: + def test_wrapper_specs(self): + import safety_gymnasium + + base = safety_gymnasium.make("SafetyPointGoal1-v0") + env = SafetyGymnasiumWrapper(base) + check_env_specs(env) + assert "cost" in env.observation_spec.keys() + + def test_env_from_name_specs(self): + env = SafetyGymnasiumEnv(env_name="SafetyPointGoal1-v0") + check_env_specs(env) + assert "cost" in env.observation_spec.keys() + + def test_rollout_exposes_cost(self): + env = SafetyGymnasiumEnv(env_name="SafetyPointGoal1-v0") + env.set_seed(0) + td = env.rollout(5) + assert ("next", "cost") in td.keys(True) + assert td["next", "cost"].dtype == torch.float64 + assert td["next", "cost"].shape == td["next", "reward"].shape[:-1] + + def test_cost_fires_on_hazard_contact(self): + # SafetyCarPush2-v0 has dense hazards; under random actions we expect + # at least one positive cost in a long rollout. Without this signal + # being plumbed through, every cost would be zero. + env = SafetyGymnasiumEnv(env_name="SafetyCarPush2-v0") + env.set_seed(0) + td = env.rollout(2000, break_when_any_done=False) + assert (td["next", "cost"] > 0).any(), ( + "Expected at least one nonzero cost over 2000 random steps; " + "cost signal may not be plumbed correctly." 
+ ) diff --git a/torchrl/envs/libs/__init__.py b/torchrl/envs/libs/__init__.py index c7fb8a3a046..742953ba1d9 100644 --- a/torchrl/envs/libs/__init__.py +++ b/torchrl/envs/libs/__init__.py @@ -25,6 +25,7 @@ from .pettingzoo import PettingZooEnv, PettingZooWrapper from .procgen import ProcgenEnv, ProcgenWrapper from .robohive import RoboHiveEnv +from .safety_gymnasium import SafetyGymnasiumEnv, SafetyGymnasiumWrapper from .smacv2 import SMACv2Env, SMACv2Wrapper from .unity_mlagents import UnityMLAgentsEnv, UnityMLAgentsWrapper from .vmas import VmasEnv, VmasWrapper @@ -56,6 +57,8 @@ "ProcgenEnv", "ProcgenWrapper", "RoboHiveEnv", + "SafetyGymnasiumEnv", + "SafetyGymnasiumWrapper", "SMACv2Env", "SMACv2Wrapper", "UnityMLAgentsEnv", diff --git a/torchrl/envs/libs/safety_gymnasium.py b/torchrl/envs/libs/safety_gymnasium.py new file mode 100644 index 00000000000..b8b16fe71b6 --- /dev/null +++ b/torchrl/envs/libs/safety_gymnasium.py @@ -0,0 +1,188 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import importlib.util +from types import ModuleType + +import numpy as np +import torch + +from torchrl.data.tensor_specs import Composite, Unbounded +from torchrl.envs.gym_like import default_info_dict_reader +from torchrl.envs.libs.gym import GymEnv, GymWrapper, set_gym_backend +from torchrl.envs.utils import _classproperty + +_has_safety_gymnasium = importlib.util.find_spec("safety_gymnasium") is not None + + +def _make_cost_reader() -> default_info_dict_reader: + cost_spec = Composite( + cost=Unbounded(shape=(), dtype=torch.float64), shape=[] + ) + return default_info_dict_reader(["cost"], spec=cost_spec) + + +class SafetyGymnasiumWrapper(GymWrapper): + """Safety-Gymnasium environment wrapper. 
+ + Safety-Gymnasium (https://github.com/PKU-Alignment/safety-gymnasium) is the + actively-maintained successor to OpenAI's Safety-Gym. It provides + constrained-RL benchmarks where each step emits a parallel ``cost`` signal + alongside the standard reward, allowing agents to optimize reward subject + to a safety budget. + + The underlying ``step`` API returns a 6-tuple + ``(obs, reward, cost, terminated, truncated, info)``. This wrapper folds + ``cost`` into the info dict so that the standard + :class:`~torchrl.envs.libs.gym.GymWrapper` machinery can be reused, and + registers an info-dict reader that exposes ``cost`` as a top-level key in + the returned tensordict. + + Args: + env (safety_gymnasium.Env): the environment to wrap. + + Examples: + >>> import safety_gymnasium # doctest: +SKIP + >>> from torchrl.envs.libs.safety_gymnasium import SafetyGymnasiumWrapper + >>> base = safety_gymnasium.make("SafetyPointGoal1-v0") # doctest: +SKIP + >>> env = SafetyGymnasiumWrapper(base) # doctest: +SKIP + >>> td = env.rollout(3) # doctest: +SKIP + >>> assert ("next", "cost") in td.keys(True) # doctest: +SKIP + + """ + + git_url = "https://github.com/PKU-Alignment/safety-gymnasium" + libname = "safety-gymnasium" + + _make_specs = set_gym_backend("gymnasium")(GymEnv._make_specs) + + def __init__(self, env=None, **kwargs): + super().__init__(env=env, **kwargs) + self.set_info_dict_reader(_make_cost_reader()) + + def _output_transform(self, step_outputs_tuple): + observations, reward, cost, terminated, truncated, info = step_outputs_tuple + info = dict(info) if info is not None else {} + # The default info_dict_reader expects values with a `.dtype` + # attribute. safety-gymnasium emits cost as a Python float, so we + # promote it to a numpy scalar of fixed dtype. 
+ info["cost"] = np.asarray(cost, dtype=np.float64) + return ( + observations, + reward, + terminated, + truncated, + terminated | truncated, + info, + ) + + @_classproperty + def available_envs(cls): + if not _has_safety_gymnasium: + return [] + # Curated list of canonical safety-gymnasium task ids. The library + # registers more (different difficulty levels, vision variants, etc.); + # this list mirrors the ones documented as primary benchmarks. + return [ + # Point robot + "SafetyPointGoal0-v0", + "SafetyPointGoal1-v0", + "SafetyPointGoal2-v0", + "SafetyPointButton0-v0", + "SafetyPointButton1-v0", + "SafetyPointButton2-v0", + "SafetyPointPush0-v0", + "SafetyPointPush1-v0", + "SafetyPointPush2-v0", + "SafetyPointCircle0-v0", + "SafetyPointCircle1-v0", + "SafetyPointRace0-v0", + "SafetyPointRace1-v0", + "SafetyPointRace2-v0", + # Car robot + "SafetyCarGoal0-v0", + "SafetyCarGoal1-v0", + "SafetyCarGoal2-v0", + "SafetyCarButton0-v0", + "SafetyCarButton1-v0", + "SafetyCarButton2-v0", + "SafetyCarPush0-v0", + "SafetyCarPush1-v0", + "SafetyCarPush2-v0", + "SafetyCarCircle0-v0", + "SafetyCarCircle1-v0", + "SafetyCarRace0-v0", + "SafetyCarRace1-v0", + "SafetyCarRace2-v0", + # Mujoco velocity tasks + "SafetyAntVelocity-v1", + "SafetyHalfCheetahVelocity-v1", + "SafetyHopperVelocity-v1", + "SafetyHumanoidVelocity-v1", + "SafetySwimmerVelocity-v1", + "SafetyWalker2dVelocity-v1", + ] + + +class SafetyGymnasiumEnv(GymEnv): + """Safety-Gymnasium environment built from an env id. + + See :class:`SafetyGymnasiumWrapper` for behavior details. The constructor + builds the environment via ``safety_gymnasium.make(env_name)`` and applies + the same cost-extraction pipeline. + + Args: + env_name (str): the safety-gymnasium task id, e.g. + ``"SafetyPointGoal1-v0"``. 
+ + Examples: + >>> from torchrl.envs.libs.safety_gymnasium import SafetyGymnasiumEnv + >>> env = SafetyGymnasiumEnv(env_name="SafetyPointGoal1-v0") # doctest: +SKIP + >>> td = env.rollout(3) # doctest: +SKIP + >>> assert ("next", "cost") in td.keys(True) # doctest: +SKIP + + """ + + git_url = "https://github.com/PKU-Alignment/safety-gymnasium" + libname = "safety-gymnasium" + + available_envs = SafetyGymnasiumWrapper.available_envs + + @property + def lib(self) -> ModuleType: + if _has_safety_gymnasium: + import safety_gymnasium + + return safety_gymnasium + try: + import safety_gymnasium # noqa: F401 + except ImportError as err: + raise ImportError( + "safety-gymnasium not found, install with " + "`pip install safety-gymnasium`" + ) from err + + _make_specs = set_gym_backend("gymnasium")(GymEnv._make_specs) + + def __init__(self, env_name=None, **kwargs): + super().__init__(env_name=env_name, **kwargs) + self.set_info_dict_reader(_make_cost_reader()) + + def _output_transform(self, step_outputs_tuple): + observations, reward, cost, terminated, truncated, info = step_outputs_tuple + info = dict(info) if info is not None else {} + # The default info_dict_reader expects values with a `.dtype` + # attribute. safety-gymnasium emits cost as a Python float, so we + # promote it to a numpy scalar of fixed dtype. + info["cost"] = np.asarray(cost, dtype=np.float64) + return ( + observations, + reward, + terminated, + truncated, + terminated | truncated, + info, + ) From a5dfcf7da322bbafeb1da6095663232c3b79ca6e Mon Sep 17 00:00:00 2001 From: Achintya P Date: Tue, 28 Apr 2026 15:09:58 -0700 Subject: [PATCH 2/3] added documentation, torch tensor, and autodiscover. 
implemented the feedback --- .../scripts_safety_gymnasium/install.sh | 70 +++++++ .../scripts_safety_gymnasium/post_process.sh | 6 + .../scripts_safety_gymnasium/requirements.txt | 18 ++ .../scripts_safety_gymnasium/run_test.sh | 28 +++ .../scripts_safety_gymnasium/setup_env.sh | 65 ++++++ .github/workflows/test-linux-libs.yml | 40 ++++ docs/source/reference/envs_libraries.rst | 2 + torchrl/envs/libs/safety_gymnasium.py | 185 ++++++++---------- 8 files changed, 310 insertions(+), 104 deletions(-) create mode 100755 .github/unittest/linux_libs/scripts_safety_gymnasium/install.sh create mode 100755 .github/unittest/linux_libs/scripts_safety_gymnasium/post_process.sh create mode 100644 .github/unittest/linux_libs/scripts_safety_gymnasium/requirements.txt create mode 100755 .github/unittest/linux_libs/scripts_safety_gymnasium/run_test.sh create mode 100755 .github/unittest/linux_libs/scripts_safety_gymnasium/setup_env.sh diff --git a/.github/unittest/linux_libs/scripts_safety_gymnasium/install.sh b/.github/unittest/linux_libs/scripts_safety_gymnasium/install.sh new file mode 100755 index 00000000000..d078a95e8e7 --- /dev/null +++ b/.github/unittest/linux_libs/scripts_safety_gymnasium/install.sh @@ -0,0 +1,70 @@ +#!/usr/bin/env bash + +unset PYTORCH_VERSION +# For unittest, nightly PyTorch is used as the following section, +# so no need to set PYTORCH_VERSION. +# In fact, keeping PYTORCH_VERSION forces us to hardcode PyTorch version in config. 
+ +set -e + +# Ensure uv is in PATH +export PATH="$HOME/.local/bin:$PATH" + +# Activate the virtual environment +source ./env/bin/activate + +if [ "${CU_VERSION:-}" == cpu ] ; then + version="cpu" +else + if [[ ${#CU_VERSION} -eq 4 ]]; then + CUDA_VERSION="${CU_VERSION:2:1}.${CU_VERSION:3:1}" + elif [[ ${#CU_VERSION} -eq 5 ]]; then + CUDA_VERSION="${CU_VERSION:2:2}.${CU_VERSION:4:1}" + fi + echo "Using CUDA $CUDA_VERSION as determined by CU_VERSION ($CU_VERSION)" + version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")" +fi + +# submodules +git submodule sync && git submodule update --init --recursive + +printf "Installing PyTorch with cu128" +if [[ "$TORCH_VERSION" == "nightly" ]]; then + if [ "${CU_VERSION:-}" == cpu ] ; then + uv pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U + else + uv pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128 -U + fi +elif [[ "$TORCH_VERSION" == "stable" ]]; then + if [ "${CU_VERSION:-}" == cpu ] ; then + uv pip install torch --index-url https://download.pytorch.org/whl/cpu -U + else + uv pip install torch --index-url https://download.pytorch.org/whl/cu128 -U + fi +else + printf "Failed to install pytorch" + exit 1 +fi + +# Ensure tensordict and torchrl dependencies are installed +# (since we use --no-deps for tensordict and torchrl) +uv pip install numpy pyvers packaging cloudpickle + +# Install build dependencies for torchrl (needed with --no-build-isolation) +uv pip install setuptools wheel setuptools_scm ninja "pybind11[global]" + +# install tensordict +if [[ "$RELEASE" == 0 ]]; then + uv pip install --no-deps git+https://github.com/pytorch/tensordict.git +else + uv pip install --no-deps tensordict +fi + +# smoke test +python -c "import functorch;import tensordict" + +printf "* Installing torchrl\n" +python -m pip install -e . 
--no-build-isolation --no-deps + +# smoke test +python -c "import torchrl" diff --git a/.github/unittest/linux_libs/scripts_safety_gymnasium/post_process.sh b/.github/unittest/linux_libs/scripts_safety_gymnasium/post_process.sh new file mode 100755 index 00000000000..b143bf9fe69 --- /dev/null +++ b/.github/unittest/linux_libs/scripts_safety_gymnasium/post_process.sh @@ -0,0 +1,6 @@ +#!/usr/bin/env bash + +set -e + +# Activate the virtual environment +source ./env/bin/activate diff --git a/.github/unittest/linux_libs/scripts_safety_gymnasium/requirements.txt b/.github/unittest/linux_libs/scripts_safety_gymnasium/requirements.txt new file mode 100644 index 00000000000..e3e8e971633 --- /dev/null +++ b/.github/unittest/linux_libs/scripts_safety_gymnasium/requirements.txt @@ -0,0 +1,18 @@ +# Core dependencies for tensordict/torchrl (installed with --no-deps) +numpy +pyvers +packaging +cloudpickle + +# Test dependencies +pytest +pytest-xdist +pytest-instafail +pytest-error-for-skips +coverage + +# Safety-Gymnasium and friends +safety-gymnasium +gymnasium +mujoco +imageio diff --git a/.github/unittest/linux_libs/scripts_safety_gymnasium/run_test.sh b/.github/unittest/linux_libs/scripts_safety_gymnasium/run_test.sh new file mode 100755 index 00000000000..e8aec8ee903 --- /dev/null +++ b/.github/unittest/linux_libs/scripts_safety_gymnasium/run_test.sh @@ -0,0 +1,28 @@ +#!/usr/bin/env bash + +set -e + +# Activate the virtual environment +source ./env/bin/activate + +apt-get update && apt-get install -y git wget cmake + +export PYTORCH_TEST_WITH_SLOW='1' +export LAZY_LEGACY_OP=False +python -m torch.utils.collect_env +# Avoid error: "fatal: unsafe repository" +git config --global --add safe.directory '*' + +root_dir="$(git rev-parse --show-toplevel)" +env_dir="${root_dir}/env" +lib_dir="${env_dir}/lib" + +deactivate 2>/dev/null || true && source ./env/bin/activate + +# this workflow only tests the libs +python -c "import safety_gymnasium" + +python 
.github/unittest/helpers/coverage_run_parallel.py -m pytest test/libs --instafail -v --durations 200 --capture no -k TestSafetyGymnasium --error-for-skips --runslow + +coverage combine -q +coverage xml -i diff --git a/.github/unittest/linux_libs/scripts_safety_gymnasium/setup_env.sh b/.github/unittest/linux_libs/scripts_safety_gymnasium/setup_env.sh new file mode 100755 index 00000000000..0f1210f3a54 --- /dev/null +++ b/.github/unittest/linux_libs/scripts_safety_gymnasium/setup_env.sh @@ -0,0 +1,65 @@ +#!/usr/bin/env bash + +# This script is for setting up the environment in which unit tests are run. +# To speed up the CI time, the resulting environment is cached. +# +# Do not install PyTorch and torchvision here, otherwise they also get cached. + +set -e +set -v + +apt-get update && apt-get upgrade -y && apt-get install -y git cmake +# Avoid error: "fatal: unsafe repository" +git config --global --add safe.directory '*' +apt-get install -y wget \ + gcc \ + g++ \ + unzip \ + curl \ + patchelf \ + libosmesa6-dev \ + libgl1-mesa-glx \ + libglfw3 \ + swig3.0 \ + libglew-dev \ + libglvnd0 \ + libgl1 \ + libglx0 \ + libegl1 \ + libgles2 \ + libglib2.0-0 + +# Upgrade specific package +apt-get upgrade -y libstdc++6 + +this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +root_dir="$(git rev-parse --show-toplevel)" +env_dir="${root_dir}/env" + +cd "${root_dir}" + +# Install uv if not already installed +if ! command -v uv &> /dev/null; then + printf "* Installing uv\n" + curl -LsSf https://astral.sh/uv/install.sh | sh + export PATH="$HOME/.local/bin:$PATH" +fi + +# Create virtual environment using uv +printf "python: ${PYTHON_VERSION}\n" +if [ ! 
-d "${env_dir}" ]; then + printf "* Creating a test environment with uv\n" + uv venv "${env_dir}" --python "${PYTHON_VERSION}" +fi + +# Activate the virtual environment +source "${env_dir}/bin/activate" + +# Upgrade pip +uv pip install --upgrade pip + +# Install dependencies from requirements.txt +printf "* Installing dependencies (except PyTorch)\n" +if [ -f "${this_dir}/requirements.txt" ]; then + uv pip install -r "${this_dir}/requirements.txt" +fi diff --git a/.github/workflows/test-linux-libs.yml b/.github/workflows/test-linux-libs.yml index 9eb25699f02..1a246200820 100644 --- a/.github/workflows/test-linux-libs.yml +++ b/.github/workflows/test-linux-libs.yml @@ -625,6 +625,46 @@ jobs: bash .github/unittest/linux_libs/scripts_procgen/run_test.sh bash .github/unittest/linux_libs/scripts_procgen/post_process.sh + unittests-safety-gymnasium: + strategy: + matrix: + python_version: ["3.10"] + cuda_arch_version: ["12.8"] + if: ${{ github.event_name == 'push' || github.event_name == 'workflow_call' || github.event_name == 'workflow_dispatch' || contains(github.event.pull_request.labels.*.name, 'Environments') || contains(github.event.pull_request.labels.*.name, 'Environments/safety_gymnasium') }} + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + with: + repository: pytorch/rl + runner: "linux.g5.4xlarge.nvidia.gpu" + gpu-arch-type: cuda + gpu-arch-version: "12.8" + docker-image: "nvidia/cuda:12.4.0-devel-ubuntu22.04" + timeout: 120 + script: | + if [[ "${{ github.ref }}" =~ release/* ]]; then + export RELEASE=1 + export TORCH_VERSION=stable + else + export RELEASE=0 + export TORCH_VERSION=nightly + fi + + set -euo pipefail + export PYTHON_VERSION="3.10" + export CU_VERSION="12.8" + export TAR_OPTIONS="--no-same-owner" + export UPLOAD_CHANNEL="nightly" + export TF_CPP_MIN_LOG_LEVEL=0 + export BATCHED_PIPE_TIMEOUT=60 + export TD_GET_DEFAULTS_TO_NONE=1 + + nvidia-smi + + bash .github/unittest/linux_libs/scripts_safety_gymnasium/setup_env.sh + bash 
.github/unittest/linux_libs/scripts_safety_gymnasium/install.sh + PYTHON=./env/bin/python bash .github/unittest/helpers/assert_torch_version.sh "$TORCH_VERSION" + bash .github/unittest/linux_libs/scripts_safety_gymnasium/run_test.sh + bash .github/unittest/linux_libs/scripts_safety_gymnasium/post_process.sh + unittests-robohive: strategy: matrix: diff --git a/docs/source/reference/envs_libraries.rst b/docs/source/reference/envs_libraries.rst index d36a1b9db4b..b839275b2cc 100644 --- a/docs/source/reference/envs_libraries.rst +++ b/docs/source/reference/envs_libraries.rst @@ -105,6 +105,8 @@ Available wrappers PettingZooWrapper ProcgenWrapper RoboHiveEnv + SafetyGymnasiumEnv + SafetyGymnasiumWrapper SMACv2Env SMACv2Wrapper UnityMLAgentsEnv diff --git a/torchrl/envs/libs/safety_gymnasium.py b/torchrl/envs/libs/safety_gymnasium.py index b8b16fe71b6..e4c6dfe8962 100644 --- a/torchrl/envs/libs/safety_gymnasium.py +++ b/torchrl/envs/libs/safety_gymnasium.py @@ -7,39 +7,88 @@ import importlib.util from types import ModuleType -import numpy as np import torch -from torchrl.data.tensor_specs import Composite, Unbounded -from torchrl.envs.gym_like import default_info_dict_reader +from torchrl.data.tensor_specs import Unbounded from torchrl.envs.libs.gym import GymEnv, GymWrapper, set_gym_backend from torchrl.envs.utils import _classproperty _has_safety_gymnasium = importlib.util.find_spec("safety_gymnasium") is not None -def _make_cost_reader() -> default_info_dict_reader: - cost_spec = Composite( - cost=Unbounded(shape=(), dtype=torch.float64), shape=[] +def _list_safety_gymnasium_envs() -> list[str]: + """Discover task ids exposed by safety-gymnasium. + + safety-gymnasium registers many id variants (``*Gymnasium``, + ``*Vision*``, ``*Debug``, ``*FadingEasy*``, ...). We surface the + canonical 6-tuple-step ids and skip the ``Gymnasium`` variants because + those return the standard 5-tuple and would not match this wrapper's + ``_output_transform``. 
+ """ + if not _has_safety_gymnasium: + return [] + import gymnasium + import safety_gymnasium # noqa: F401 -- import side-effect: register envs + + return sorted( + env_id + for env_id in gymnasium.envs.registry + if env_id.startswith("Safety") and "Gymnasium" not in env_id ) - return default_info_dict_reader(["cost"], spec=cost_spec) -class SafetyGymnasiumWrapper(GymWrapper): +class _SafetyGymCostMixin: + """Expose safety-gymnasium's per-step ``cost`` signal as a top-level + observation key. + + safety-gymnasium's ``step`` returns a 6-tuple + ``(obs, reward, cost, terminated, truncated, info)``. We collapse the + extra ``cost`` element into a stashed attribute, then write it onto + the step/reset tensordict so it travels with the observation rather + than through the info-dict-reader machinery. + """ + + def _post_init_cost(self) -> None: + self.observation_spec["cost"] = Unbounded( + shape=(), dtype=torch.float64 + ) + self._last_cost = torch.zeros((), dtype=torch.float64) + + def _output_transform(self, step_outputs_tuple): + observations, reward, cost, terminated, truncated, info = step_outputs_tuple + self._last_cost = torch.as_tensor(cost, dtype=torch.float64) + return ( + observations, + reward, + terminated, + truncated, + terminated | truncated, + info, + ) + + def _step(self, tensordict): + out = super()._step(tensordict) + out.set("cost", self._last_cost) + return out + + def _reset(self, tensordict=None, **kwargs): + out = super()._reset(tensordict, **kwargs) + out.set("cost", torch.zeros_like(self._last_cost)) + return out + + +class SafetyGymnasiumWrapper(_SafetyGymCostMixin, GymWrapper): """Safety-Gymnasium environment wrapper. - Safety-Gymnasium (https://github.com/PKU-Alignment/safety-gymnasium) is the - actively-maintained successor to OpenAI's Safety-Gym. It provides - constrained-RL benchmarks where each step emits a parallel ``cost`` signal - alongside the standard reward, allowing agents to optimize reward subject - to a safety budget. 
+ Safety-Gymnasium (https://github.com/PKU-Alignment/safety-gymnasium) is + the actively-maintained successor to OpenAI's Safety-Gym. It provides + constrained-RL benchmarks where each step emits a parallel ``cost`` + signal alongside the standard reward, allowing agents to optimize + reward subject to a safety budget. - The underlying ``step`` API returns a 6-tuple - ``(obs, reward, cost, terminated, truncated, info)``. This wrapper folds - ``cost`` into the info dict so that the standard - :class:`~torchrl.envs.libs.gym.GymWrapper` machinery can be reused, and - registers an info-dict reader that exposes ``cost`` as a top-level key in - the returned tensordict. + The underlying ``step`` API returns a 6-tuple. This wrapper folds + ``cost`` into the output tensordict as a top-level key alongside + ``reward``. Args: env (safety_gymnasium.Env): the environment to wrap. @@ -61,78 +110,20 @@ class SafetyGymnasiumWrapper(GymWrapper): def __init__(self, env=None, **kwargs): super().__init__(env=env, **kwargs) - self.set_info_dict_reader(_make_cost_reader()) - - def _output_transform(self, step_outputs_tuple): - observations, reward, cost, terminated, truncated, info = step_outputs_tuple - info = dict(info) if info is not None else {} - # The default info_dict_reader expects values with a `.dtype` - # attribute. safety-gymnasium emits cost as a Python float, so we - # promote it to a numpy scalar of fixed dtype. - info["cost"] = np.asarray(cost, dtype=np.float64) - return ( - observations, - reward, - terminated, - truncated, - terminated | truncated, - info, - ) + self._post_init_cost() @_classproperty def available_envs(cls): - if not _has_safety_gymnasium: - return [] - # Curated list of canonical safety-gymnasium task ids. The library - # registers more (different difficulty levels, vision variants, etc.); - # this list mirrors the ones documented as primary benchmarks. 
- return [ - # Point robot - "SafetyPointGoal0-v0", - "SafetyPointGoal1-v0", - "SafetyPointGoal2-v0", - "SafetyPointButton0-v0", - "SafetyPointButton1-v0", - "SafetyPointButton2-v0", - "SafetyPointPush0-v0", - "SafetyPointPush1-v0", - "SafetyPointPush2-v0", - "SafetyPointCircle0-v0", - "SafetyPointCircle1-v0", - "SafetyPointRace0-v0", - "SafetyPointRace1-v0", - "SafetyPointRace2-v0", - # Car robot - "SafetyCarGoal0-v0", - "SafetyCarGoal1-v0", - "SafetyCarGoal2-v0", - "SafetyCarButton0-v0", - "SafetyCarButton1-v0", - "SafetyCarButton2-v0", - "SafetyCarPush0-v0", - "SafetyCarPush1-v0", - "SafetyCarPush2-v0", - "SafetyCarCircle0-v0", - "SafetyCarCircle1-v0", - "SafetyCarRace0-v0", - "SafetyCarRace1-v0", - "SafetyCarRace2-v0", - # Mujoco velocity tasks - "SafetyAntVelocity-v1", - "SafetyHalfCheetahVelocity-v1", - "SafetyHopperVelocity-v1", - "SafetyHumanoidVelocity-v1", - "SafetySwimmerVelocity-v1", - "SafetyWalker2dVelocity-v1", - ] - - -class SafetyGymnasiumEnv(GymEnv): + return _list_safety_gymnasium_envs() + + +class SafetyGymnasiumEnv(_SafetyGymCostMixin, GymEnv): """Safety-Gymnasium environment built from an env id. - See :class:`SafetyGymnasiumWrapper` for behavior details. The constructor - builds the environment via ``safety_gymnasium.make(env_name)`` and applies - the same cost-extraction pipeline. + See :class:`SafetyGymnasiumWrapper` for behavior details. The + constructor builds the environment via + ``safety_gymnasium.make(env_name)`` and applies the same + cost-extraction pipeline. Args: env_name (str): the safety-gymnasium task id, e.g. 
@@ -149,7 +140,9 @@ class SafetyGymnasiumEnv(GymEnv): git_url = "https://github.com/PKU-Alignment/safety-gymnasium" libname = "safety-gymnasium" - available_envs = SafetyGymnasiumWrapper.available_envs + @_classproperty + def available_envs(cls): + return _list_safety_gymnasium_envs() @property def lib(self) -> ModuleType: @@ -169,20 +162,4 @@ def lib(self) -> ModuleType: def __init__(self, env_name=None, **kwargs): super().__init__(env_name=env_name, **kwargs) - self.set_info_dict_reader(_make_cost_reader()) - - def _output_transform(self, step_outputs_tuple): - observations, reward, cost, terminated, truncated, info = step_outputs_tuple - info = dict(info) if info is not None else {} - # The default info_dict_reader expects values with a `.dtype` - # attribute. safety-gymnasium emits cost as a Python float, so we - # promote it to a numpy scalar of fixed dtype. - info["cost"] = np.asarray(cost, dtype=np.float64) - return ( - observations, - reward, - terminated, - truncated, - terminated | truncated, - info, - ) + self._post_init_cost() From 55cea2715de70090a7cf00e56bb7eb2e3adb777d Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 29 Apr 2026 10:04:00 +0100 Subject: [PATCH 3/3] linter --- test/libs/test_safety_gymnasium.py | 4 +--- torchrl/envs/libs/safety_gymnasium.py | 7 ++----- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/test/libs/test_safety_gymnasium.py b/test/libs/test_safety_gymnasium.py index 9f7ad8fa75b..bc79b9301a5 100644 --- a/test/libs/test_safety_gymnasium.py +++ b/test/libs/test_safety_gymnasium.py @@ -15,9 +15,7 @@ from torchrl.envs.utils import check_env_specs -@pytest.mark.skipif( - not _has_safety_gymnasium, reason="safety-gymnasium not installed" -) +@pytest.mark.skipif(not _has_safety_gymnasium, reason="safety-gymnasium not installed") class TestSafetyGymnasium: def test_wrapper_specs(self): import safety_gymnasium diff --git a/torchrl/envs/libs/safety_gymnasium.py b/torchrl/envs/libs/safety_gymnasium.py index 
e4c6dfe8962..bca9906bc5d 100644 --- a/torchrl/envs/libs/safety_gymnasium.py +++ b/torchrl/envs/libs/safety_gymnasium.py @@ -38,8 +38,7 @@ def _list_safety_gymnasium_envs() -> list[str]: class _SafetyGymCostMixin: - """Expose safety-gymnasium's per-step ``cost`` signal as a top-level - observation key. + """Expose safety-gymnasium's per-step ``cost`` signal as a top-level observation key. safety-gymnasium's ``step`` returns a 6-tuple ``(obs, reward, cost, terminated, truncated, info)``. We collapse the @@ -49,9 +48,7 @@ class _SafetyGymCostMixin: """ def _post_init_cost(self) -> None: - self.observation_spec["cost"] = Unbounded( - shape=(), dtype=torch.float64 - ) + self.observation_spec["cost"] = Unbounded(shape=(), dtype=torch.float64) self._last_cost = torch.zeros((), dtype=torch.float64) def _output_transform(self, step_outputs_tuple):