diff --git a/.buildkite/pipeline.ml.yml b/.buildkite/pipeline.ml.yml index 1228cfbc12b2..c72980ae4462 100644 --- a/.buildkite/pipeline.ml.yml +++ b/.buildkite/pipeline.ml.yml @@ -377,7 +377,7 @@ - DOC_TESTING=1 PYTHON=3.7 ./ci/env/install-dependencies.sh - bazel test --config=ci $(./ci/run/bazel_export_options) --build_tests_only --test_tag_filters=-ray_air,-gpu,-py37,-post_wheel_build doc/... -- label: ":book: :ariplane: Ray AIR examples" +- label: ":book: :airplane: Ray AIR examples" conditions: ["RAY_CI_PYTHON_AFFECTED", "RAY_CI_TUNE_AFFECTED", "RAY_CI_DOC_AFFECTED", "RAY_CI_SERVE_AFFECTED"] commands: diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs new file mode 100644 index 000000000000..cbfc76e415d9 --- /dev/null +++ b/.git-blame-ignore-revs @@ -0,0 +1,2 @@ +# Formatted Python code with Black +7f1bacc7dc9caf6d0ec042e39499bbf1d9a7d065 diff --git a/.github/ISSUE_TEMPLATE/documentation-issue.yml b/.github/ISSUE_TEMPLATE/documentation-issue.yml new file mode 100644 index 000000000000..a35084dc5669 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/documentation-issue.yml @@ -0,0 +1,26 @@ +name: Documentation +title: "[] " +description: Report an issue with the Ray documentation +labels: [docs] +body: + - type: markdown + attributes: + value: Thank you for helping us improve the Ray documentation! + + - type: textarea + attributes: + label: Description + description: | + Tell us about the change you'd like to see. For example, "I'd like to + see more examples of how to use `ray.remote`." + validations: + required: true + + - type: textarea + attributes: + label: Link + description: | + If the problem is related to an existing section, please add a link to + the section. For example, https://docs.ray.io/en/master/ray-core/package-ref.html#ray.remote. + validations: + required: false diff --git a/.gitignore b/.gitignore index a13d7d04ad3d..42b5c72c9722 100644 --- a/.gitignore +++ b/.gitignore @@ -213,3 +213,11 @@ workflow_data/ # Jupyter Notebooks **/.ipynb_checkpoints/ + +### Added by Hedron's Bazel Compile Commands Extractor: https://github.com/hedronvision/bazel-compile-commands-extractor +# The external link: Differs on Windows vs macOS/Linux, so we can't check it in. The pattern needs to not have a trailing / because it's a symlink on macOS/Linux. +/external +# Compiled output -> don't check in +/compile_commands.json +# Directory where clangd puts its indexing work +/.cache/ diff --git a/WORKSPACE b/WORKSPACE index c175d8e40d2e..1d6cb1ec30b3 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -20,3 +20,9 @@ load("@bazel_skylib//lib:versions.bzl", "versions") # When the bazel version is updated, make sure to update it # in setup.py as well. versions.check(minimum_bazel_version = "4.2.1") + +# Tools to generate `compile_commands.json` to enable awesome tooling of the C language family. +# Just run `bazel run @hedron_compile_commands//:refresh_all` +load("@hedron_compile_commands//:workspace_setup.bzl", "hedron_compile_commands_setup") + +hedron_compile_commands_setup() diff --git a/bazel/ray_deps_setup.bzl b/bazel/ray_deps_setup.bzl index 64d77e40de10..4793a551eb88 100644 --- a/bazel/ray_deps_setup.bzl +++ b/bazel/ray_deps_setup.bzl @@ -306,3 +306,16 @@ def ray_deps_setup(): ], sha256 = "379113459b0feaf6bfbb584a91874c065078aa673222846ac765f86661c27407", ) + + # Hedron's Compile Commands Extractor for Bazel + # https://github.com/hedronvision/bazel-compile-commands-extractor + http_archive( + name = "hedron_compile_commands", + + # Replace the commit hash in both places (below) with the latest, rather than using the stale one here. + # Even better, set up Renovate and let it do the work for you (see "Suggestion: Updates" in the README). + url = "https://github.com/hedronvision/bazel-compile-commands-extractor/archive/cfd16a16cb4c4f27337ef652aa8510dcf1dd01ce.tar.gz", + strip_prefix = "bazel-compile-commands-extractor-cfd16a16cb4c4f27337ef652aa8510dcf1dd01ce", + # When you first run this tool, it'll recommend a sha256 hash to put here with a message like: "DEBUG: Rule 'hedron_compile_commands' indicated that a canonical reproducible form can be obtained by modifying arguments sha256 = ..." + sha256 = "4c2753a8d446f561391b7968a6d0eed748e8bb0f40adeda51301c57e829c7696", + ) diff --git a/ci/env/install-dependencies.sh b/ci/env/install-dependencies.sh index 71d566c7ceec..8e6c6a640dbb 100755 --- a/ci/env/install-dependencies.sh +++ b/ci/env/install-dependencies.sh @@ -379,7 +379,7 @@ install_dependencies() { # dependencies with Modin. if [ "${INSTALL_LUDWIG-}" = 1 ]; then # TODO: eventually pin this to master. - pip install -U "ludwig[test]">=0.4 + pip install -U "ludwig[test]">=0.4 jsonschema>=4 fi # Data processing test dependencies. diff --git a/ci/lint/check_api_annotations.py b/ci/lint/check_api_annotations.py index 04675ba3bf59..f7d32e51d155 100755 --- a/ci/lint/check_api_annotations.py +++ b/ci/lint/check_api_annotations.py @@ -5,7 +5,16 @@ import ray from ray.util.annotations import _is_annotated -IGNORE_PATHS = {".impl.", ".backend.", ".experimental.", ".internal.", ".generated."} +IGNORE_PATHS = { + ".impl.", + ".backend.", + ".experimental.", + ".internal.", + ".generated.", + ".test_utils.", + ".annotations.", + ".deprecation", +} def _fullname(attr): @@ -76,11 +85,12 @@ def verify(symbol, scanned, ok, output, prefix=None): verify(ray.data, set(), ok, output) # Sanity check the lint logic. assert len(ok) >= 60, len(ok) + + verify(ray.rllib, set(), ok, output) # TODO(ekl) enable it for all modules. # verify(ray.ml, set(), ok, output) # verify(ray.train, set(), ok, output) # verify(ray.serve, set(), ok, output) - # verify(ray.rllib, set(), ok, output) # verify(ray.tune, set(), ok, output) # verify(ray, set(), ok, output) diff --git a/ci/pipeline/determine_tests_to_run.py b/ci/pipeline/determine_tests_to_run.py index 4be1491c747f..96a4176c84a6 100644 --- a/ci/pipeline/determine_tests_to_run.py +++ b/ci/pipeline/determine_tests_to_run.py @@ -201,6 +201,10 @@ def get_commit_range(): elif any(changed_file.startswith(prefix) for prefix in skip_prefix_list): # nothing is run but linting in these cases pass + elif changed_file.startswith("release/ray_release/"): + # Tests for release/ray_release always run, so it is unnecessary to + # tag affected tests. + pass elif changed_file.endswith("build-docker-images.py"): RAY_CI_DOCKER_AFFECTED = 1 RAY_CI_LINUX_WHEELS_AFFECTED = 1 diff --git a/dashboard/client/src/pages/dashboard/node-info/NodeInfo.tsx b/dashboard/client/src/pages/dashboard/node-info/NodeInfo.tsx index b49428530670..fd6eead1016f 100644 --- a/dashboard/client/src/pages/dashboard/node-info/NodeInfo.tsx +++ b/dashboard/client/src/pages/dashboard/node-info/NodeInfo.tsx @@ -217,7 +217,10 @@ const NodeInfo: React.FC<{}> = () => { // If a Ray node is running in a K8s pod, it marks available disk as 1 byte. // (See ReporterAgent._get_disk_usage() in reporter_agent.py) // Check if there are any nodes with realistic disk total: - const showDisk = nodes.filter((n) => n.disk["/"].total > 10).length !== 0; + const showDisk = + nodes.filter( + (n) => n !== undefined && n.disk !== undefined && n.disk["/"].total > 10, + ).length !== 0; const filterPredicate = ( feature: NodeInfoFeature | HeaderInfo, diff --git a/dashboard/modules/job/job_manager.py b/dashboard/modules/job/job_manager.py index a38a59e6e220..0925e7433375 100644 --- a/dashboard/modules/job/job_manager.py +++ b/dashboard/modules/job/job_manager.py @@ -415,10 +415,16 @@ def _get_supervisor_runtime_env( runtime_env = ( copy.deepcopy(user_runtime_env) if user_runtime_env is not None else {} ) + + # NOTE(edoakes): Can't use .get(, {}) here because we need to handle the case + # where env_vars is explicitly set to `None`. + env_vars = runtime_env.get("env_vars") + if env_vars is None: + env_vars = {} + # Don't set CUDA_VISIBLE_DEVICES for the supervisor actor so the # driver can use GPUs if it wants to. This will be removed from # the driver's runtime_env so it isn't inherited by tasks & actors. - env_vars = runtime_env.get("env_vars", {}) env_vars[ray_constants.NOSET_CUDA_VISIBLE_DEVICES_ENV_VAR] = "1" runtime_env["env_vars"] = env_vars return runtime_env diff --git a/dashboard/modules/job/tests/test_job_manager.py b/dashboard/modules/job/tests/test_job_manager.py index c9c1e18d11d3..ffe62e954fee 100644 --- a/dashboard/modules/job/tests/test_job_manager.py +++ b/dashboard/modules/job/tests/test_job_manager.py @@ -3,6 +3,7 @@ import psutil import tempfile import sys +import urllib.request from uuid import uuid4 import signal @@ -270,6 +271,24 @@ async def test_submit_with_s3_runtime_env(self, job_manager): job_manager.get_job_logs(job_id) == "Executing main() from script.py !!\n" ) + async def test_submit_with_file_runtime_env(self, job_manager): + with tempfile.NamedTemporaryFile(suffix=".zip") as f: + filename, _ = urllib.request.urlretrieve( + "https://runtime-env-test.s3.amazonaws.com/script_runtime_env.zip", + filename=f.name, + ) + job_id = job_manager.submit_job( + entrypoint="python script.py", + runtime_env={"working_dir": "file://" + filename}, + ) + await async_wait_for_condition( + check_job_succeeded, job_manager=job_manager, job_id=job_id + ) + assert ( + job_manager.get_job_logs(job_id) + == "Executing main() from script.py !!\n" + ) + @pytest.mark.asyncio class TestRuntimeEnv: @@ -424,13 +443,21 @@ def dict_to_str(d): {JOB_NAME_METADATA_KEY: "custom_name", JOB_ID_METADATA_KEY: job_id} ) in job_manager.get_job_logs(job_id) - async def test_cuda_visible_devices(self, job_manager): + @pytest.mark.parametrize( + "env_vars", + [None, {}, {"hello": "world"}], + ) + async def test_cuda_visible_devices(self, job_manager, env_vars): """Check CUDA_VISIBLE_DEVICES behavior. Should not be set in the driver, but should be set in tasks. + + We test a variety of `env_vars` parameters due to custom parsing logic + that caused https://github.com/ray-project/ray/issues/25086. """ run_cmd = f"python {_driver_script_path('check_cuda_devices.py')}" - job_id = job_manager.submit_job(entrypoint=run_cmd) + runtime_env = {"env_vars": env_vars} + job_id = job_manager.submit_job(entrypoint=run_cmd, runtime_env=runtime_env) await async_wait_for_condition( check_job_succeeded, job_manager=job_manager, job_id=job_id diff --git a/dashboard/modules/serve/serve_head.py b/dashboard/modules/serve/serve_head.py index 143d4161e728..f8918999a350 100644 --- a/dashboard/modules/serve/serve_head.py +++ b/dashboard/modules/serve/serve_head.py @@ -33,14 +33,14 @@ async def get_all_deployments(self, req: Request) -> Response: @routes.get("/api/serve/deployments/status") @optional_utils.init_ray_and_catch_exceptions(connect_to_serve=True) async def get_all_deployment_statuses(self, req: Request) -> Response: - from ray.serve.api import get_deployment_statuses - from ray.serve.schema import serve_application_status_to_schema + from ray.serve.context import get_global_client + from ray.serve.schema import serve_status_to_schema - serve_application_status_schema = serve_application_status_to_schema( - get_deployment_statuses() - ) + client = get_global_client(_override_controller_namespace="serve") + + serve_status_schema = serve_status_to_schema(client.get_serve_status()) return Response( - text=serve_application_status_schema.json(), + text=serve_status_schema.json(), content_type="application/json", ) diff --git a/dashboard/modules/serve/tests/test_serve_head.py b/dashboard/modules/serve/tests/test_serve_head.py index 447f1ed8e6b6..908e0772253e 100644 --- a/dashboard/modules/serve/tests/test_serve_head.py +++ b/dashboard/modules/serve/tests/test_serve_head.py @@ -1,14 +1,15 @@ +import os +import sys import json +import time +import pytest +import requests import subprocess -import sys -import os from typing import List, Dict, Set -import pytest - -import requests import ray from ray import serve +from ray._private.test_utils import wait_for_condition GET_OR_PUT_URL = "http://localhost:8265/api/serve/deployments/" @@ -124,6 +125,14 @@ def test_put_get_success(ray_start_stop): GET_OR_PUT_URL, json={"deployments": deployments}, timeout=30 ) assert put_response.status_code == 200 + + # Use wait_for_condition() to ensure "deep" deployment deleted + wait_for_condition( + lambda: len(requests.get(GET_OR_PUT_URL, timeout=3).json()["deployments"]) + == 2, + timeout=10, + ) + assert ( requests.get("http://localhost:8000/shallow", timeout=30).text == "Hello shallow world!" @@ -176,9 +185,12 @@ def test_delete_success(ray_start_stop): delete_response = requests.delete(GET_OR_PUT_URL, timeout=30) assert delete_response.status_code == 200 - # Make sure no deployments exist - get_response = requests.get(GET_OR_PUT_URL, timeout=30) - assert len(get_response.json()["deployments"]) == 0 + # Make sure all deployments are deleted + wait_for_condition( + lambda: len(requests.get(GET_OR_PUT_URL, timeout=3).json()["deployments"]) + == 0, + timeout=10, + ) def test_get_status_info(ray_start_stop): @@ -216,18 +228,23 @@ def test_get_status_info(ray_start_stop): status_response = requests.get(STATUS_URL, timeout=30) assert status_response.status_code == 200 + serve_status = status_response.json() - statuses = status_response.json()["statuses"] - assert len(statuses) == len(deployments) + deployment_statuses = serve_status["deployment_statuses"] + assert len(deployment_statuses) == len(deployments) expected_deployment_names = {deployment["name"] for deployment in deployments} - for deployment_status in statuses: + for deployment_status in deployment_statuses: assert deployment_status["name"] in expected_deployment_names expected_deployment_names.remove(deployment_status["name"]) assert deployment_status["status"] in {"UPDATING", "HEALTHY"} assert deployment_status["message"] == "" assert len(expected_deployment_names) == 0 - print(statuses) + assert serve_status["app_status"]["status"] in {"DEPLOYING", "RUNNING"} + wait_for_condition( + lambda: time.time() > serve_status["app_status"]["deployment_timestamp"], + timeout=2, + ) def test_serve_namespace(ray_start_stop): diff --git a/dashboard/modules/state/state_head.py b/dashboard/modules/state/state_head.py index 66af666415dc..13037503904a 100644 --- a/dashboard/modules/state/state_head.py +++ b/dashboard/modules/state/state_head.py @@ -1,4 +1,5 @@ import logging + import aiohttp.web import dataclasses @@ -10,6 +11,7 @@ from ray.dashboard.optional_utils import rest_response from ray.dashboard.state_aggregator import StateAPIManager from ray.experimental.state.common import ListApiOptions +from ray.experimental.state.exception import DataSourceUnavailable from ray.experimental.state.state_manager import StateDataSourceClient logger = logging.getLogger(__name__) @@ -33,13 +35,19 @@ def __init__(self, dashboard_head): def _options_from_req(self, req) -> ListApiOptions: """Obtain `ListApiOptions` from the aiohttp request.""" limit = int(req.query.get("limit")) + # Only apply 80% of the timeout so that + # the API will reply before client times out if query to the source fails. timeout = int(req.query.get("timeout")) return ListApiOptions(limit=limit, timeout=timeout) - def _reply(self, success: bool, message: str, result: dict): + def _reply(self, success: bool, error_message: str, result: dict, **kwargs): """Reply to the client.""" return rest_response( - success=success, message=message, result=result, convert_google_style=False + success=success, + message=error_message, + result=result, + convert_google_style=False, + **kwargs, ) async def _update_raylet_stubs(self, change: Change): @@ -85,57 +93,61 @@ async def _update_agent_stubs(self, change: Change): int(ports[1]), ) + async def _handle_list_api(self, list_api_fn, req): + try: + result = await list_api_fn(option=self._options_from_req(req)) + return self._reply( + success=True, + error_message="", + result=result.result, + partial_failure_warning=result.partial_failure_warning, + ) + except DataSourceUnavailable as e: + return self._reply(success=False, error_message=str(e), result=None) + @routes.get("/api/v0/actors") async def list_actors(self, req) -> aiohttp.web.Response: - data = await self._state_api.list_actors(option=self._options_from_req(req)) - return self._reply(success=True, message="", result=data) + return await self._handle_list_api(self._state_api.list_actors, req) @routes.get("/api/v0/jobs") async def list_jobs(self, req) -> aiohttp.web.Response: - data = self._state_api.list_jobs(option=self._options_from_req(req)) - return self._reply( - success=True, - message="", - result={ - job_id: dataclasses.asdict(job_info) - for job_id, job_info in data.items() - }, - ) + try: + result = self._state_api.list_jobs(option=self._options_from_req(req)) + return self._reply( + success=True, + error_message="", + result={ + job_id: dataclasses.asdict(job_info) + for job_id, job_info in result.result.items() + }, + partial_failure_warning=result.partial_failure_warning, + ) + except DataSourceUnavailable as e: + return self._reply(success=False, error_message=str(e), result=None) @routes.get("/api/v0/nodes") async def list_nodes(self, req) -> aiohttp.web.Response: - data = await self._state_api.list_nodes(option=self._options_from_req(req)) - return self._reply(success=True, message="", result=data) + return await self._handle_list_api(self._state_api.list_nodes, req) @routes.get("/api/v0/placement_groups") async def list_placement_groups(self, req) -> aiohttp.web.Response: - data = await self._state_api.list_placement_groups( - option=self._options_from_req(req) - ) - return self._reply(success=True, message="", result=data) + return await self._handle_list_api(self._state_api.list_placement_groups, req) @routes.get("/api/v0/workers") async def list_workers(self, req) -> aiohttp.web.Response: - data = await self._state_api.list_workers(option=self._options_from_req(req)) - return self._reply(success=True, message="", result=data) + return await self._handle_list_api(self._state_api.list_workers, req) @routes.get("/api/v0/tasks") async def list_tasks(self, req) -> aiohttp.web.Response: - data = await self._state_api.list_tasks(option=self._options_from_req(req)) - return self._reply(success=True, message="", result=data) + return await self._handle_list_api(self._state_api.list_tasks, req) @routes.get("/api/v0/objects") async def list_objects(self, req) -> aiohttp.web.Response: - data = await self._state_api.list_objects(option=self._options_from_req(req)) - return self._reply(success=True, message="", result=data) + return await self._handle_list_api(self._state_api.list_objects, req) @routes.get("/api/v0/runtime_envs") - @dashboard_optional_utils.aiohttp_cache async def list_runtime_envs(self, req) -> aiohttp.web.Response: - data = await self._state_api.list_runtime_envs( - option=self._options_from_req(req) - ) - return self._reply(success=True, message="", result=data) + return await self._handle_list_api(self._state_api.list_runtime_envs, req) async def run(self, server): gcs_channel = self._dashboard_head.aiogrpc_gcs_channel diff --git a/dashboard/state_aggregator.py b/dashboard/state_aggregator.py index ec4d751eeb9b..65bc16130e2f 100644 --- a/dashboard/state_aggregator.py +++ b/dashboard/state_aggregator.py @@ -1,13 +1,12 @@ import asyncio import logging -from typing import List, Dict from itertools import islice +from typing import List from ray.core.generated.common_pb2 import TaskStatus import ray.dashboard.utils as dashboard_utils import ray.dashboard.memory_utils as memory_utils -from ray.dashboard.modules.job.common import JobInfo from ray.experimental.state.common import ( filter_fields, @@ -19,13 +18,34 @@ ObjectState, RuntimeEnvState, ListApiOptions, + ListApiResponse, +) +from ray.experimental.state.state_manager import ( + StateDataSourceClient, + DataSourceUnavailable, ) -from ray.experimental.state.state_manager import StateDataSourceClient from ray.runtime_env import RuntimeEnv from ray._private.utils import binary_to_hex logger = logging.getLogger(__name__) +GCS_QUERY_FAILURE_WARNING = ( + "Failed to query data from GCS. It is due to " + "(1) GCS is unexpectedly failed. " + "(2) GCS is overloaded. " + "(3) There's an unexpected network issue. " + "Please check the gcs_server.out log to find the root cause." +) +NODE_QUERY_FAILURE_WARNING = ( + "Failed to query data from {type}. " + "Queryed {total} {type} " + "and {network_failures} {type} failed to reply. It is due to " + "(1) {type} is unexpectedly failed. " + "(2) {type} is overloaded. " + "(3) There's an unexpected network issue. Please check the " + "{log_command} to find the root cause." +) + # TODO(sang): Move the class to state/state_manager.py. # TODO(sang): Remove *State and replaces with Pydantic or protobuf. @@ -42,14 +62,19 @@ def __init__(self, state_data_source_client: StateDataSourceClient): def data_source_client(self): return self._client - async def list_actors(self, *, option: ListApiOptions) -> dict: + async def list_actors(self, *, option: ListApiOptions) -> ListApiResponse: """List all actor information from the cluster. Returns: {actor_id -> actor_data_in_dict} actor_data_in_dict's schema is in ActorState + """ - reply = await self._client.get_all_actor_info(timeout=option.timeout) + try: + reply = await self._client.get_all_actor_info(timeout=option.timeout) + except DataSourceUnavailable: + raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING) + result = [] for message in reply.actor_table_data: data = self._message_to_dict(message=message, fields_to_decode=["actor_id"]) @@ -58,16 +83,24 @@ async def list_actors(self, *, option: ListApiOptions) -> dict: # Sort to make the output deterministic. result.sort(key=lambda entry: entry["actor_id"]) - return {d["actor_id"]: d for d in islice(result, option.limit)} + return ListApiResponse( + result={d["actor_id"]: d for d in islice(result, option.limit)} + ) - async def list_placement_groups(self, *, option: ListApiOptions) -> dict: + async def list_placement_groups(self, *, option: ListApiOptions) -> ListApiResponse: """List all placement group information from the cluster. Returns: {pg_id -> pg_data_in_dict} pg_data_in_dict's schema is in PlacementGroupState """ - reply = await self._client.get_all_placement_group_info(timeout=option.timeout) + try: + reply = await self._client.get_all_placement_group_info( + timeout=option.timeout + ) + except DataSourceUnavailable: + raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING) + result = [] for message in reply.placement_group_table_data: @@ -80,16 +113,22 @@ async def list_placement_groups(self, *, option: ListApiOptions) -> dict: # Sort to make the output deterministic. result.sort(key=lambda entry: entry["placement_group_id"]) - return {d["placement_group_id"]: d for d in islice(result, option.limit)} + return ListApiResponse( + result={d["placement_group_id"]: d for d in islice(result, option.limit)} + ) - async def list_nodes(self, *, option: ListApiOptions) -> dict: + async def list_nodes(self, *, option: ListApiOptions) -> ListApiResponse: """List all node information from the cluster. Returns: {node_id -> node_data_in_dict} node_data_in_dict's schema is in NodeState """ - reply = await self._client.get_all_node_info(timeout=option.timeout) + try: + reply = await self._client.get_all_node_info(timeout=option.timeout) + except DataSourceUnavailable: + raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING) + result = [] for message in reply.node_info_list: data = self._message_to_dict(message=message, fields_to_decode=["node_id"]) @@ -98,16 +137,22 @@ async def list_nodes(self, *, option: ListApiOptions) -> dict: # Sort to make the output deterministic. result.sort(key=lambda entry: entry["node_id"]) - return {d["node_id"]: d for d in islice(result, option.limit)} + return ListApiResponse( + result={d["node_id"]: d for d in islice(result, option.limit)} + ) - async def list_workers(self, *, option: ListApiOptions) -> dict: + async def list_workers(self, *, option: ListApiOptions) -> ListApiResponse: """List all worker information from the cluster. Returns: {worker_id -> worker_data_in_dict} worker_data_in_dict's schema is in WorkerState """ - reply = await self._client.get_all_worker_info(timeout=option.timeout) + try: + reply = await self._client.get_all_worker_info(timeout=option.timeout) + except DataSourceUnavailable: + raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING) + result = [] for message in reply.worker_table_data: data = self._message_to_dict( @@ -119,34 +164,65 @@ async def list_workers(self, *, option: ListApiOptions) -> dict: # Sort to make the output deterministic. result.sort(key=lambda entry: entry["worker_id"]) - return {d["worker_id"]: d for d in islice(result, option.limit)} + return ListApiResponse( + result={d["worker_id"]: d for d in islice(result, option.limit)} + ) - def list_jobs(self, *, option: ListApiOptions) -> Dict[str, JobInfo]: + def list_jobs(self, *, option: ListApiOptions) -> ListApiResponse: # TODO(sang): Support limit & timeout & async calls. - return self._client.get_job_info() + try: + result = self._client.get_job_info() + except DataSourceUnavailable: + raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING) + return ListApiResponse(result=result) - async def list_tasks(self, *, option: ListApiOptions) -> dict: + async def list_tasks(self, *, option: ListApiOptions) -> ListApiResponse: """List all task information from the cluster. Returns: {task_id -> task_data_in_dict} task_data_in_dict's schema is in TaskState """ + raylet_ids = self._client.get_all_registered_raylet_ids() replies = await asyncio.gather( *[ self._client.get_task_info(node_id, timeout=option.timeout) - for node_id in self._client.get_all_registered_raylet_ids() - ] + for node_id in raylet_ids + ], + return_exceptions=True, ) + unresponsive_nodes = 0 running_task_id = set() + successful_replies = [] for reply in replies: + if isinstance(reply, DataSourceUnavailable): + unresponsive_nodes += 1 + continue + elif isinstance(reply, Exception): + raise reply + + successful_replies.append(reply) for task_id in reply.running_task_ids: running_task_id.add(binary_to_hex(task_id)) + partial_failure_warning = None + if len(raylet_ids) > 0 and unresponsive_nodes > 0: + warning_msg = NODE_QUERY_FAILURE_WARNING.format( + type="raylet", + total=len(raylet_ids), + network_failures=unresponsive_nodes, + log_command="raylet.out", + ) + if unresponsive_nodes == len(raylet_ids): + raise DataSourceUnavailable(warning_msg) + partial_failure_warning = ( + f"The returned data may contain incomplete result. {warning_msg}" + ) + result = [] - for reply in replies: - logger.info(reply) + for reply in successful_replies: + assert not isinstance(reply, Exception) tasks = reply.owned_task_info_entries for task in tasks: data = self._message_to_dict( @@ -162,24 +238,36 @@ async def list_tasks(self, *, option: ListApiOptions) -> dict: # Sort to make the output deterministic. result.sort(key=lambda entry: entry["task_id"]) - return {d["task_id"]: d for d in islice(result, option.limit)} + return ListApiResponse( + result={d["task_id"]: d for d in islice(result, option.limit)}, + partial_failure_warning=partial_failure_warning, + ) - async def list_objects(self, *, option: ListApiOptions) -> dict: + async def list_objects(self, *, option: ListApiOptions) -> ListApiResponse: """List all object information from the cluster. Returns: {object_id -> object_data_in_dict} object_data_in_dict's schema is in ObjectState """ + raylet_ids = self._client.get_all_registered_raylet_ids() replies = await asyncio.gather( *[ self._client.get_object_info(node_id, timeout=option.timeout) - for node_id in self._client.get_all_registered_raylet_ids() - ] + for node_id in raylet_ids + ], + return_exceptions=True, ) + unresponsive_nodes = 0 worker_stats = [] - for reply in replies: + for reply, node_id in zip(replies, raylet_ids): + if isinstance(reply, DataSourceUnavailable): + unresponsive_nodes += 1 + continue + elif isinstance(reply, Exception): + raise reply + for core_worker_stat in reply.core_workers_stats: # NOTE: Set preserving_proto_field_name=False here because # `construct_memory_table` requires a dictionary that has @@ -193,6 +281,20 @@ async def list_objects(self, *, option: ListApiOptions) -> dict: ) ) + partial_failure_warning = None + if len(raylet_ids) > 0 and unresponsive_nodes > 0: + warning_msg = NODE_QUERY_FAILURE_WARNING.format( + type="raylet", + total=len(raylet_ids), + network_failures=unresponsive_nodes, + log_command="raylet.out", + ) + if unresponsive_nodes == len(raylet_ids): + raise DataSourceUnavailable(warning_msg) + partial_failure_warning = ( + f"The returned data may contain incomplete result. {warning_msg}" + ) + result = [] memory_table = memory_utils.construct_memory_table(worker_stats) for entry in memory_table.table: @@ -207,9 +309,12 @@ async def list_objects(self, *, option: ListApiOptions) -> dict: # Sort to make the output deterministic. result.sort(key=lambda entry: entry["object_id"]) - return {d["object_id"]: d for d in islice(result, option.limit)} + return ListApiResponse( + result={d["object_id"]: d for d in islice(result, option.limit)}, + partial_failure_warning=partial_failure_warning, + ) - async def list_runtime_envs(self, *, option: ListApiOptions) -> List[dict]: + async def list_runtime_envs(self, *, option: ListApiOptions) -> ListApiResponse: """List all runtime env information from the cluster. Returns: @@ -219,14 +324,24 @@ async def list_runtime_envs(self, *, option: ListApiOptions) -> List[dict]: We don't have id -> data mapping like other API because runtime env doesn't have unique ids. """ + agent_ids = self._client.get_all_registered_agent_ids() replies = await asyncio.gather( *[ self._client.get_runtime_envs_info(node_id, timeout=option.timeout) - for node_id in self._client.get_all_registered_agent_ids() - ] + for node_id in agent_ids + ], + return_exceptions=True, ) + result = [] + unresponsive_nodes = 0 for node_id, reply in zip(self._client.get_all_registered_agent_ids(), replies): + if isinstance(reply, DataSourceUnavailable): + unresponsive_nodes += 1 + continue + elif isinstance(reply, Exception): + raise reply + states = reply.runtime_env_states for state in states: data = self._message_to_dict(message=state, fields_to_decode=[]) @@ -238,6 +353,20 @@ async def list_runtime_envs(self, *, option: ListApiOptions) -> List[dict]: data = filter_fields(data, RuntimeEnvState) result.append(data) + partial_failure_warning = None + if len(agent_ids) > 0 and unresponsive_nodes > 0: + warning_msg = NODE_QUERY_FAILURE_WARNING.format( + type="agent", + total=len(agent_ids), + network_failures=unresponsive_nodes, + log_command="dashboard_agent.log", + ) + if unresponsive_nodes == len(agent_ids): + raise DataSourceUnavailable(warning_msg) + partial_failure_warning = ( + f"The returned data may contain incomplete result. {warning_msg}" + ) + # Sort to make the output deterministic. def sort_func(entry): # If creation time is not there yet (runtime env is failed @@ -251,7 +380,10 @@ def sort_func(entry): return float(entry["creation_time_ms"]) result.sort(key=sort_func, reverse=True) - return list(islice(result, option.limit)) + return ListApiResponse( + result=list(islice(result, option.limit)), + partial_failure_warning=partial_failure_warning, + ) def _message_to_dict( self, diff --git a/dashboard/tests/test_dashboard.py b/dashboard/tests/test_dashboard.py index 6735d4ae4a13..4064d6da49ea 100644 --- a/dashboard/tests/test_dashboard.py +++ b/dashboard/tests/test_dashboard.py @@ -697,6 +697,32 @@ def test_dashboard_port_conflict(ray_start_with_dashboard): raise Exception("Timed out while testing.") +@pytest.mark.skipif( + os.environ.get("RAY_MINIMAL") == "1", + reason="This test is not supposed to work for minimal installation.", +) +def test_gcs_check_alive(fast_gcs_failure_detection, ray_start_with_dashboard): + assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True + + all_processes = ray.worker._global_node.all_processes + dashboard_info = all_processes[ray_constants.PROCESS_TYPE_DASHBOARD][0] + dashboard_proc = psutil.Process(dashboard_info.process.pid) + gcs_server_info = all_processes[ray_constants.PROCESS_TYPE_GCS_SERVER][0] + gcs_server_proc = psutil.Process(gcs_server_info.process.pid) + + assert dashboard_proc.status() in [ + psutil.STATUS_RUNNING, + psutil.STATUS_SLEEPING, + psutil.STATUS_DISK_SLEEP, + ] + + gcs_server_proc.kill() + gcs_server_proc.wait() + + # The dashboard exits by os._exit(-1) + assert dashboard_proc.wait(10) == 255 + + @pytest.mark.skipif( os.environ.get("RAY_DEFAULT") != "1", reason="This test only works for default installation.", diff --git a/doc/source/_toc.yml b/doc/source/_toc.yml index 887c8ad35aa4..2e86e62cd825 100644 --- a/doc/source/_toc.yml +++ b/doc/source/_toc.yml @@ -224,7 +224,9 @@ parts: - caption: Ray Clusters chapters: + - file: cluster/index - file: cluster/quickstart + - file: cluster/key-concepts - file: cluster/user-guide - file: cluster/cloud - file: cluster/deploy diff --git a/doc/source/cluster/api.rst b/doc/source/cluster/api.rst index 78dc6f4a17c6..131f5c50381e 100644 --- a/doc/source/cluster/api.rst +++ b/doc/source/cluster/api.rst @@ -1,3 +1,5 @@ +.. _ref-cluster-api: + Ray Cluster API =============== diff --git a/doc/source/cluster/guide.rst b/doc/source/cluster/guide.rst index 2ddabe15e2d2..9c6d12b08a5a 100644 --- a/doc/source/cluster/guide.rst +++ b/doc/source/cluster/guide.rst @@ -42,6 +42,7 @@ To simplify Operator configuration, Ray provides a :ref:`a Helm chart Installing the Helm chart will create an Operator Deployment. The Operator manages autoscaling Ray clusters; each Ray node runs in its own K8s Pod. +.. _deployment-guide-autoscaler: Autoscaling with Ray -------------------- diff --git a/doc/source/cluster/index.rst b/doc/source/cluster/index.rst index 9ab498565071..894ab6da80c9 100644 --- a/doc/source/cluster/index.rst +++ b/doc/source/cluster/index.rst @@ -4,42 +4,81 @@ .. _cluster-index: -Ray Cluster Overview -==================== +Ray Clusters Overview +===================== What is a Ray cluster? ------------------------- +---------------------- -One of Ray's strengths is the ability to leverage multiple machines in the same -program. Ray can, of course, be run on a single machine (and is done so often), -but the real power is using Ray on a cluster of machines. +One of Ray's strengths is the ability to leverage multiple machines for +distributed execution. Ray can, of course, be run on a single machine (and is +done so often), but the real power is using Ray on a cluster of machines. -A Ray cluster consists of a **head node** and a set of **worker nodes**. The -head node needs to be started first, and the worker nodes are given the address -of the head node to form the cluster: +Ray can automatically interact with the cloud provider to request or release +instances. You can specify :ref:`a configuration ` to launch +clusters on :ref:`AWS, GCP, Azure, Kubernetes, Aliyun, on-premise, or even on +your custom node provider `. Your cluster can have a fixed size +or :ref:`automatically scale up and down` depending on the +demands of your application. -.. image:: ray-cluster.jpg - :align: center - :width: 600px +Where to go from here? +---------------------- -You can use the Ray Cluster Launcher to provision machines and launch a -multi-node Ray cluster. You can use the cluster launcher :ref:`on AWS, GCP, -Azure, Kubernetes, Aliyun, on-premise, and Staroid or even on your custom node provider -`. Ray clusters can also make use of the Ray Autoscaler, which -allows Ray to interact with a cloud provider to request or release instances -following :ref:`a specification ` and according to application -workload. +.. panels:: + :container: text-center + :column: col-lg-6 px-2 py-2 + :card: + Quick Start + ^^^ -Next steps ----------- + In this quick start tutorial you will take a sample application designed to + run on a laptop and scale it up in the cloud. -To get started with Ray Clusters, we recommend that you check out the :ref:`Ray -Cluster Quick Start `. For more advanced examples of -use, you can also refer to the :ref:`full specification for Ray Cluster -configuration `. + +++ + .. link-button:: ref-cluster-quick-start + :type: ref + :text: Ray Clusters Quick Start + :classes: btn-outline-info btn-block + --- -To learn about best practices for deploying a Ray cluster, :ref:`check out the -deployment guide `. + Key Concepts + ^^^ + + Understand the key concepts behind Ray Clusters. Learn about the main + concepts and the different ways to interact with a cluster. + + +++ + .. link-button:: cluster-key-concepts + :type: ref + :text: Learn Key Concepts + :classes: btn-outline-info btn-block + --- + + Deployment Guide + ^^^ + + Learn how to set up a distributed Ray cluster and run your workloads on it. + + +++ + .. link-button:: ref-deployment-guide + :type: ref + :text: Deploy on a Ray Cluster + :classes: btn-outline-info btn-block + --- + + API + ^^^ + + Get more in-depth information about the various APIs to interact with Ray + Clusters, including the :ref:`Ray cluster config YAML and CLI`, + the :ref:`Ray Client API` and the + :ref:`Ray job submission API`. + + +++ + .. link-button:: ref-cluster-api + :type: ref + :text: Read the API Reference + :classes: btn-outline-info btn-block .. include:: /_includes/clusters/announcement_bottom.rst diff --git a/doc/source/cluster/key-concepts.rst b/doc/source/cluster/key-concepts.rst new file mode 100644 index 000000000000..e552eebde835 --- /dev/null +++ b/doc/source/cluster/key-concepts.rst @@ -0,0 +1,107 @@ +.. include:: we_are_hiring.rst + +.. _cluster-key-concepts: + +Key Concepts +============ + +Cluster +------- + +A Ray cluster is a set of one or more nodes that are running Ray and share the +same :ref:`head node`. + +.. _cluster-node-types: + +Node types +---------- + +A Ray cluster consists of a :ref:`head node` and a set of +:ref:`worker nodes`. + +.. image:: ray-cluster.jpg + :align: center + :width: 600px + +.. _cluster-head-node: + +Head node +~~~~~~~~~ + +The head node is the first node started by the +:ref:`Ray cluster launcher` when trying to launch a Ray +cluster. Among other things, the head node holds the :ref:`Global Control Store +(GCS)` and runs the :ref:`autoscaler`. Once the head +node is started, it will be responsible for launching any additional +:ref:`worker nodes`. The head node itself will also execute +tasks and actors to utilize its capacity. + +.. _cluster-worker-node: + +Worker node +~~~~~~~~~~~ + +A worker node is any node in the Ray cluster that is not functioning as head node. +Therefore, worker nodes are simply responsible for executing tasks and actors. +When a worker node is launched, it will be given the address of the head node to +form a cluster. + +.. _cluster-launcher: + +Cluster launcher +---------------- + +The cluster launcher is a process responsible for bootstrapping the Ray cluster +by launching the :ref:`head node`. For more information on how +to use the cluster launcher, refer to +:ref:`cluster launcher CLI commands documentation` and the +corresponding :ref:`documentation for the configuration file`. + +.. _cluster-autoscaler: + +Autoscaler +---------- + +The autoscaler is a process that runs on the :ref:`head node` +and is responsible for adding or removing :ref:`worker nodes` +to meet the needs of the Ray workload while matching the specification in the +:ref:`cluster config file`. In particular, if the resource +demands of the Ray workload exceed the current capacity of the cluster, the +autoscaler will try to add nodes. Conversely, if a node is idle for long enough, +the autoscaler will remove it from the cluster. To learn more about autoscaling, +refer to the :ref:`Ray cluster deployment guide`. + +Ray Client +---------- +The Ray Client is an API that connects a Python script to a remote Ray cluster. +To learn more about the Ray Client, you can refer to the :ref:`documentation`. + +Job submission +-------------- + +Ray Job submission is a mechanism to submit locally developed and tested applications +to a remote Ray cluster. It simplifies the experience of packaging, deploying, +and managing a Ray application. To learn more about Ray jobs, refer to the +:ref:`documentation`. + +Cloud clusters +-------------- + +If you’re using AWS, Azure, GCP or Aliyun, you can use the +:ref:`Ray cluster launcher` to launch cloud clusters, which +greatly simplifies the cluster setup process. + +Cluster managers +---------------- + +You can simplify the process of managing Ray clusters using a number of popular +cluster managers including :ref:`Kubernetes`, +:ref:`YARN`, :ref:`Slurm` and :ref:`LSF`. + +Kubernetes (K8s) operator +------------------------- + +Deployments of Ray on Kubernetes are managed by the Ray Kubernetes Operator. The +Ray Operator makes it easy to deploy clusters of Ray pods within a Kubernetes +cluster. To learn more about the K8s operator, refer to +the :ref:`documentation`. diff --git a/doc/source/cluster/quickstart.rst b/doc/source/cluster/quickstart.rst index 7b8e04f3e570..26b0f3dd5f62 100644 --- a/doc/source/cluster/quickstart.rst +++ b/doc/source/cluster/quickstart.rst @@ -4,8 +4,8 @@ .. _ref-cluster-quick-start: -Ray Cluster Quick Start -======================= +Ray Clusters Quick Start +======================== This quick start demonstrates the capabilities of the Ray cluster. Using the Ray cluster, we'll take a sample application designed to run on a laptop and scale it up in the cloud. Ray will launch clusters and scale Python with just a few commands. diff --git a/doc/source/cluster/user-guide.rst b/doc/source/cluster/user-guide.rst index 5ecca629cf92..5813c1474ded 100644 --- a/doc/source/cluster/user-guide.rst +++ b/doc/source/cluster/user-guide.rst @@ -16,7 +16,6 @@ To run an interactive Ray workload and see the output in real time in a client o .. toctree:: :maxdepth: 2 - index.rst guide.rst job-submission.rst ray-client.rst diff --git a/doc/source/data/dataset.rst b/doc/source/data/dataset.rst index bf3bb9caf315..74dc15bcefce 100644 --- a/doc/source/data/dataset.rst +++ b/doc/source/data/dataset.rst @@ -64,7 +64,7 @@ Advanced users can refer directly to the Ray Datasets :ref:`API reference ` @@ -78,7 +78,7 @@ Advanced users can refer directly to the Ray Datasets :ref:`API reference `, :ref:`save @@ -107,6 +107,19 @@ Advanced users can refer directly to the Ray Datasets :ref:`API reference `. For scaling diff --git a/doc/source/rllib/rllib-algorithms.rst b/doc/source/rllib/rllib-algorithms.rst index 89b18e4cf92b..69efcd391557 100644 --- a/doc/source/rllib/rllib-algorithms.rst +++ b/doc/source/rllib/rllib-algorithms.rst @@ -791,13 +791,13 @@ Tuned examples: `Two-step game `__ `[implementation] `__ MADDPG is a DDPG centralized/shared critic algorithm. Code here is adapted from https://github.com/openai/maddpg to integrate with RLlib multi-agent APIs. Please check `justinkterry/maddpg-rllib `__ for examples and more information. Note that the implementation here is based on OpenAI's, and is intended for use with the discrete MPE environments. Please also note that people typically find this method difficult to get to work, even with all applicable optimizations for their environment applied. This method should be viewed as for research purposes, and for reproducing the results of the paper introducing it. +`[paper] `__ `[implementation] `__ MADDPG is a DDPG centralized/shared critic algorithm. Code here is adapted from https://github.com/openai/maddpg to integrate with RLlib multi-agent APIs. Please check `justinkterry/maddpg-rllib `__ for examples and more information. Note that the implementation here is based on OpenAI's, and is intended for use with the discrete MPE environments. Please also note that people typically find this method difficult to get to work, even with all applicable optimizations for their environment applied. This method should be viewed as for research purposes, and for reproducing the results of the paper introducing it. **MADDPG-specific configs** (see also `common configs `__): Tuned examples: `Multi-Agent Particle Environment `__, `Two-step game `__ -.. literalinclude:: ../../../rllib/agents/maddpg/maddpg.py +.. literalinclude:: ../../../rllib/algorithms/maddpg/maddpg.py :language: python :start-after: __sphinx_doc_begin__ :end-before: __sphinx_doc_end__ diff --git a/doc/source/train/user_guide.rst b/doc/source/train/user_guide.rst index 4aff156bc85c..7dc1bd79ce3a 100644 --- a/doc/source/train/user_guide.rst +++ b/doc/source/train/user_guide.rst @@ -659,7 +659,7 @@ appropriately in distributed training. for epoch in range(config["num_epochs"]): model.fit(X, Y, batch_size=20) - train.save_checkpoint(epoch=epoch, model_weights=model.get_weights()) + train.save_checkpoint(epoch=epoch, model=model.get_weights()) trainer = Trainer(backend="tensorflow", num_workers=2) @@ -861,7 +861,7 @@ Checkpoints can be loaded into the training function in 2 steps: for epoch in range(start_epoch, config["num_epochs"]): model.fit(X, Y, batch_size=20) - train.save_checkpoint(epoch=epoch, model_weights=model.get_weights()) + train.save_checkpoint(epoch=epoch, model=model.get_weights()) trainer = Trainer(backend="tensorflow", num_workers=2) diff --git a/doc/source/tune/api_docs/trainable.rst b/doc/source/tune/api_docs/trainable.rst index 5ac3eedd3e70..107321e13107 100644 --- a/doc/source/tune/api_docs/trainable.rst +++ b/doc/source/tune/api_docs/trainable.rst @@ -384,9 +384,3 @@ tune.with_parameters .. autofunction:: ray.tune.with_parameters - -StatusReporter --------------- - -.. autoclass:: ray.tune.function_runner.StatusReporter - :members: __call__, logdir diff --git a/doc/source/tune/tutorials/tune-checkpoints.rst b/doc/source/tune/tutorials/tune-checkpoints.rst index 2ba20dffe367..eeda9efc76c5 100644 --- a/doc/source/tune/tutorials/tune-checkpoints.rst +++ b/doc/source/tune/tutorials/tune-checkpoints.rst @@ -207,7 +207,7 @@ via ``ray.init()``, making your script on your laptop the "driver". # configure how checkpoints are sync'd to the scheduler/sampler # we recommend cloud storage checkpointing as it survives the cluster when # instances are terminated, and has better performance - sync_config = tune.syncConfig( + sync_config = tune.SyncConfig( upload_dir="s3://my-checkpoints-bucket/path/", # requires AWS credentials ) diff --git a/python/ray/_private/gcs_pubsub.py b/python/ray/_private/gcs_pubsub.py index 775ec333c371..669e7568e67b 100644 --- a/python/ray/_private/gcs_pubsub.py +++ b/python/ray/_private/gcs_pubsub.py @@ -104,28 +104,16 @@ def _unsubscribe_request(self, channels): ) return req - def _handle_polling_failure(self, e: grpc.RpcError) -> bool: - if self._close.is_set(): - return False - - if e.code() in ( - grpc.StatusCode.UNAVAILABLE, - grpc.StatusCode.UNKNOWN, - grpc.StatusCode.DEADLINE_EXCEEDED, - ): + @staticmethod + def _should_terminate_polling(e: grpc.RpcError) -> None: + # Caller only expects polling to be terminated after deadline exceeded. + if e.code() == grpc.StatusCode.DEADLINE_EXCEEDED: return True - - raise e - - def _handle_subscribe_failure(self, e: grpc.RpcError): - if e.code() in ( - grpc.StatusCode.UNAVAILABLE, - grpc.StatusCode.UNKNOWN, - grpc.StatusCode.DEADLINE_EXCEEDED, - ): - time.sleep(1) - else: - raise e + # Could be a temporary connection issue. Suppress error. + # TODO: reconnect GRPC channel? + if e.code() == grpc.StatusCode.UNAVAILABLE: + return True + return False @staticmethod def _pop_error_info(queue): @@ -223,7 +211,7 @@ def __init__( # Type of the channel. self._channel = pubsub_channel_type # Protects multi-threaded read and write of self._queue. - self._lock = threading.RLock() + self._lock = threading.Lock() # A queue of received PubMessage. self._queue = deque() # Indicates whether the subscriber has closed. @@ -236,24 +224,13 @@ def subscribe(self) -> None: saved for the subscriber. """ with self._lock: + if self._close.is_set(): + return req = self._subscribe_request(self._channel) - start = time.time() - from ray._raylet import Config - - while True: - try: - if self._close.is_set(): - return - return self._stub.GcsSubscriberCommandBatch(req, timeout=30) - except grpc.RpcError as e: - self._handle_subscribe_failure(e) - if ( - time.time() - start - > Config.gcs_rpc_server_reconnect_timeout_s() - ): - raise e + self._stub.GcsSubscriberCommandBatch(req, timeout=30) def _poll_locked(self, timeout=None) -> None: + assert self._lock.locked() # Poll until data becomes available. while len(self._queue) == 0: @@ -280,18 +257,9 @@ def _poll_locked(self, timeout=None) -> None: # GRPC has not replied, continue waiting. continue except grpc.RpcError as e: - if ( - e.code() == grpc.StatusCode.DEADLINE_EXCEEDED - and timeout is not None - ): - return - if self._handle_polling_failure(e) is True: - self.subscribe() - fut = self._stub.GcsSubscriberPoll.future( - self._poll_request(), timeout=timeout - ) - else: + if self._should_terminate_polling(e): return + raise if fut.done(): self._last_batch_size = len(fut.result().pub_messages) @@ -309,13 +277,11 @@ def close(self) -> None: return self._close.set() req = self._unsubscribe_request(channels=[self._channel]) - try: self._stub.GcsSubscriberCommandBatch(req, timeout=5) except Exception: pass - with self._lock: - self._stub = None + self._stub = None class GcsErrorSubscriber(_SyncSubscriber): @@ -530,39 +496,18 @@ async def subscribe(self) -> None: if self._close.is_set(): return req = self._subscribe_request(self._channel) - start = time.time() - from ray._raylet import Config - - while True: - try: - return await self._stub.GcsSubscriberCommandBatch(req, timeout=30) - except grpc.RpcError as e: - self._handle_subscribe_failure(e) - if time.time() - start > Config.gcs_rpc_server_reconnect_timeout_s(): - raise + await self._stub.GcsSubscriberCommandBatch(req, timeout=30) async def _poll_call(self, req, timeout=None): # Wrap GRPC _AioCall as a coroutine. - while True: - try: - return await self._stub.GcsSubscriberPoll(req, timeout=timeout) - except grpc.RpcError as e: - if ( - e.code() == grpc.StatusCode.DEADLINE_EXCEEDED - and timeout is not None - ): - return - if self._handle_polling_failure(e) is True: - await self.subscribe() - else: - return + return await self._stub.GcsSubscriberPoll(req, timeout=timeout) async def _poll(self, timeout=None) -> None: req = self._poll_request() while len(self._queue) == 0: # TODO: use asyncio.create_task() after Python 3.6 is no longer # supported. - poll = asyncio.ensure_future(self._poll_call(req)) + poll = asyncio.ensure_future(self._poll_call(req, timeout=timeout)) close = asyncio.ensure_future(self._close.wait()) done, _ = await asyncio.wait( [poll, close], timeout=timeout, return_when=asyncio.FIRST_COMPLETED @@ -570,9 +515,14 @@ async def _poll(self, timeout=None) -> None: if poll not in done or close in done: # Request timed out or subscriber closed. break - self._last_batch_size = len(poll.result().pub_messages) - for msg in poll.result().pub_messages: - self._queue.append(msg) + try: + self._last_batch_size = len(poll.result().pub_messages) + for msg in poll.result().pub_messages: + self._queue.append(msg) + except grpc.RpcError as e: + if self._should_terminate_polling(e): + return + raise async def close(self) -> None: """Closes the subscriber and its active subscription.""" diff --git a/python/ray/_private/runtime_env/packaging.py b/python/ray/_private/runtime_env/packaging.py index f1121f38f3dd..21f474c10d66 100644 --- a/python/ray/_private/runtime_env/packaging.py +++ b/python/ray/_private/runtime_env/packaging.py @@ -55,12 +55,13 @@ def __new__(cls, value, doc=None): HTTPS = "https", "Remote https path, assumes everything packed in one zip file." S3 = "s3", "Remote s3 path, assumes everything packed in one zip file." GS = "gs", "Remote google storage path, assumes everything packed in one zip file." + FILE = "file", "File storage path, assumes everything packed in one zip file." @classmethod def remote_protocols(cls): - # Returns a lit of protocols that support remote storage + # Returns a list of protocols that support remote storage # These protocols should only be used with paths that end in ".zip" - return [cls.HTTPS, cls.S3, cls.GS] + return [cls.HTTPS, cls.S3, cls.GS, cls.FILE] def _xor_bytes(left: bytes, right: bytes) -> bytes: @@ -174,6 +175,15 @@ def parse_uri(pkg_uri: str) -> Tuple[Protocol, str]: ) -> ("gs", "gs_public-runtime-env-test_test_module.zip") + For FILE URIs, the path will have '/' replaced with '_'. The package name + will be the adjusted path with 'file_' prepended. + urlparse("file:///path/to/test_module.zip") + -> ParseResult( + scheme='file', + netloc='path', + path='/path/to/test_module.zip' + ) + -> ("file", "file__path_to_test_module.zip") """ uri = urlparse(pkg_uri) protocol = Protocol(uri.scheme) @@ -184,6 +194,11 @@ def parse_uri(pkg_uri: str) -> Tuple[Protocol, str]: protocol, f"https_{uri.netloc.replace('.', '_')}{uri.path.replace('/', '_')}", ) + elif protocol == Protocol.FILE: + return ( + protocol, + f"file_{uri.path.replace('/', '_')}", + ) else: return (protocol, uri.netloc) @@ -575,7 +590,7 @@ def download_and_unpack_package( if protocol == Protocol.S3: try: - from smart_open import open + from smart_open import open as open_file import boto3 except ImportError: raise ImportError( @@ -586,7 +601,7 @@ def download_and_unpack_package( tp = {"client": boto3.client("s3")} elif protocol == Protocol.GS: try: - from smart_open import open + from smart_open import open as open_file from google.cloud import storage # noqa: F401 except ImportError: raise ImportError( @@ -594,17 +609,23 @@ def download_and_unpack_package( "`pip install google-cloud-storage` " "to fetch URIs in Google Cloud Storage bucket." ) + elif protocol == Protocol.FILE: + pkg_uri = pkg_uri[len("file://") :] + + def open_file(uri, mode, *, transport_params=None): + return open(uri, mode) + else: try: - from smart_open import open + from smart_open import open as open_file except ImportError: raise ImportError( "You must `pip install smart_open` " f"to fetch {protocol.value.upper()} URIs." ) - with open(pkg_uri, "rb", transport_params=tp) as package_zip: - with open(pkg_file, "wb") as fin: + with open_file(pkg_uri, "rb", transport_params=tp) as package_zip: + with open_file(pkg_file, "wb") as fin: fin.write(package_zip.read()) unzip_package( diff --git a/python/ray/autoscaler/_private/command_runner.py b/python/ray/autoscaler/_private/command_runner.py index 0799430fd95a..0723880da149 100644 --- a/python/ray/autoscaler/_private/command_runner.py +++ b/python/ray/autoscaler/_private/command_runner.py @@ -443,8 +443,8 @@ def _run_helper( Full command to run. Should include SSH options and other processing that we do. with_output (bool): - If `with_output` is `True`, command stdout and stderr - will be captured and returned. + If `with_output` is `True`, command stdout will be captured and + returned. exit_on_fail (bool): If `exit_on_fail` is `True`, the process will exit if the command fails (exits with a code other than 0). @@ -465,10 +465,8 @@ def _run_helper( silent=silent, use_login_shells=is_using_login_shells(), ) - if with_output: - return self.process_runner.check_output(final_cmd) else: - return self.process_runner.check_call(final_cmd) + return self.process_runner.check_output(final_cmd) except subprocess.CalledProcessError as e: joined_cmd = " ".join(final_cmd) if not is_using_login_shells(): @@ -488,6 +486,11 @@ def _run_helper( if is_output_redirected(): fail_msg += " See above for the output from the failure." raise click.ClickException(fail_msg) from None + finally: + # Do our best to flush output to terminal. + # See https://github.com/ray-project/ray/pull/19473. + sys.stdout.flush() + sys.stderr.flush() def run( self, diff --git a/python/ray/autoscaler/_private/subprocess_output_util.py b/python/ray/autoscaler/_private/subprocess_output_util.py index febbd231268a..1b83f55ed7fd 100644 --- a/python/ray/autoscaler/_private/subprocess_output_util.py +++ b/python/ray/autoscaler/_private/subprocess_output_util.py @@ -325,7 +325,7 @@ def run_cmd_redirected( process_runner: Process runner used for executing commands. silent (bool): If true, the command output will be silenced completely (redirected to /dev/null), unless verbose logging - is enabled. Use this for runnign utility commands like + is enabled. Use this for running utility commands like rsync. """ if silent and cli_logger.verbosity < 1: diff --git a/python/ray/autoscaler/_private/updater.py b/python/ray/autoscaler/_private/updater.py index f3662319db91..a78dc01addf8 100644 --- a/python/ray/autoscaler/_private/updater.py +++ b/python/ray/autoscaler/_private/updater.py @@ -436,7 +436,7 @@ def do_update(self): _numbered=("[]", 4, NUM_SETUP_STEPS), ) with cli_logger.group( - "Initalizing command runner", + "Initializing command runner", # todo: fix command numbering _numbered=("[]", 5, NUM_SETUP_STEPS), ): diff --git a/python/ray/autoscaler/gcp/tpu.yaml b/python/ray/autoscaler/gcp/tpu.yaml index 1df29d828a0b..001dcd87c56a 100644 --- a/python/ray/autoscaler/gcp/tpu.yaml +++ b/python/ray/autoscaler/gcp/tpu.yaml @@ -39,7 +39,7 @@ available_node_types: provider: type: gcp region: us-central1 - availability_zone: us-central1-f + availability_zone: us-central1-b project_id: null # replace with your GCP project id setup_commands: [] diff --git a/python/ray/data/impl/push_based_shuffle.py b/python/ray/data/impl/push_based_shuffle.py index 44519af17df6..df93053020b0 100644 --- a/python/ray/data/impl/push_based_shuffle.py +++ b/python/ray/data/impl/push_based_shuffle.py @@ -138,6 +138,7 @@ def execute( # The placement strategy for reduce tasks is overwritten to colocate # them with their inputs from the merge stage, so remove any # pre-specified scheduling strategy here. + reduce_ray_remote_args = reduce_ray_remote_args.copy() reduce_ray_remote_args.pop("scheduling_strategy", None) map_fn = self._map_partition @@ -249,6 +250,7 @@ def submit_merge_task(arg): last_merge_metadata_results, merge_args, ) + shuffle_merge_metadata += prev_merge_metadata for merge_idx, merge_result in enumerate(merge_results): all_merge_results[merge_idx].append(merge_result) del merge_results diff --git a/python/ray/data/tests/test_dataset.py b/python/ray/data/tests/test_dataset.py index ffee72f24e58..ed67c7373b15 100644 --- a/python/ray/data/tests/test_dataset.py +++ b/python/ray/data/tests/test_dataset.py @@ -3802,33 +3802,47 @@ def test_random_shuffle_check_random(shutdown_only): prev = x -def test_random_shuffle_spread(ray_start_cluster): - cluster = ray_start_cluster - cluster.add_node( - resources={"bar:1": 100}, - num_cpus=10, - _system_config={"max_direct_call_object_size": 0}, - ) - cluster.add_node(resources={"bar:2": 100}, num_cpus=10) - cluster.add_node(resources={"bar:3": 100}, num_cpus=0) +@pytest.mark.parametrize("use_push_based_shuffle", [False, True]) +def test_random_shuffle_spread(ray_start_cluster, use_push_based_shuffle): + ctx = ray.data.context.DatasetContext.get_current() + try: + original = ctx.use_push_based_shuffle + ctx.use_push_based_shuffle = use_push_based_shuffle - ray.init(cluster.address) + cluster = ray_start_cluster + cluster.add_node( + resources={"bar:1": 100}, + num_cpus=10, + _system_config={"max_direct_call_object_size": 0}, + ) + cluster.add_node(resources={"bar:2": 100}, num_cpus=10) + cluster.add_node(resources={"bar:3": 100}, num_cpus=0) - @ray.remote - def get_node_id(): - return ray.get_runtime_context().node_id.hex() + ray.init(cluster.address) - node1_id = ray.get(get_node_id.options(resources={"bar:1": 1}).remote()) - node2_id = ray.get(get_node_id.options(resources={"bar:2": 1}).remote()) + @ray.remote + def get_node_id(): + return ray.get_runtime_context().node_id.hex() - ds = ray.data.range(100, parallelism=2).random_shuffle() - blocks = ds.get_internal_block_refs() - ray.wait(blocks, num_returns=len(blocks), fetch_local=False) - location_data = ray.experimental.get_object_locations(blocks) - locations = [] - for block in blocks: - locations.extend(location_data[block]["node_ids"]) - assert set(locations) == {node1_id, node2_id} + node1_id = ray.get(get_node_id.options(resources={"bar:1": 1}).remote()) + node2_id = ray.get(get_node_id.options(resources={"bar:2": 1}).remote()) + + ds = ray.data.range(100, parallelism=2).random_shuffle() + blocks = ds.get_internal_block_refs() + ray.wait(blocks, num_returns=len(blocks), fetch_local=False) + location_data = ray.experimental.get_object_locations(blocks) + locations = [] + for block in blocks: + locations.extend(location_data[block]["node_ids"]) + assert "2 nodes used" in ds.stats() + + if not use_push_based_shuffle: + # We don't check this for push-based shuffle since it will try to + # colocate reduce tasks to improve locality. + assert set(locations) == {node1_id, node2_id} + + finally: + ctx.use_push_based_shuffle = original def test_parquet_read_spread(ray_start_cluster, tmp_path): diff --git a/python/ray/data/tests/test_sort.py b/python/ray/data/tests/test_sort.py index 77e8e51441bc..ed907007c676 100644 --- a/python/ray/data/tests/test_sort.py +++ b/python/ray/data/tests/test_sort.py @@ -212,6 +212,45 @@ def _test(num_input_blocks, merge_factor, num_cpus_per_node_map): _test(100, 10, {"node1": 10, "node2": 10, "node3": 10}) +def test_push_based_shuffle_stats(ray_start_cluster): + ctx = ray.data.context.DatasetContext.get_current() + try: + original = ctx.use_push_based_shuffle + ctx.use_push_based_shuffle = True + + cluster = ray_start_cluster + cluster.add_node( + resources={"bar:1": 100}, + num_cpus=10, + _system_config={"max_direct_call_object_size": 0}, + ) + cluster.add_node(resources={"bar:2": 100}, num_cpus=10) + cluster.add_node(resources={"bar:3": 100}, num_cpus=0) + + ray.init(cluster.address) + + parallelism = 100 + ds = ray.data.range(1000, parallelism=parallelism).random_shuffle() + assert "random_shuffle_merge" in ds.stats() + # Check all nodes used. + assert "2 nodes used" in ds.stats() + assert "1 nodes used" not in ds.stats() + + # Check all merge tasks are included in stats. + internal_stats = ds._plan.stats() + num_merge_tasks = len(internal_stats.stages["random_shuffle_merge"]) + # Merge factor is 2 for random_shuffle ops. + merge_factor = 2 + assert ( + parallelism // (merge_factor + 1) + <= num_merge_tasks + <= parallelism // merge_factor + ) + + finally: + ctx.use_push_based_shuffle = original + + if __name__ == "__main__": import sys diff --git a/python/ray/experimental/dag/py_obj_scanner.py b/python/ray/experimental/dag/py_obj_scanner.py index 7ea21eb4654b..3ce3bf92a0e9 100644 --- a/python/ray/experimental/dag/py_obj_scanner.py +++ b/python/ray/experimental/dag/py_obj_scanner.py @@ -49,6 +49,13 @@ def __init__(self): from ray.serve.pipeline.deployment_node import DeploymentNode from ray.serve.pipeline.deployment_method_node import DeploymentMethodNode from ray.serve.pipeline.deployment_function_node import DeploymentFunctionNode + from ray.serve.deployment_executor_node import DeploymentExecutorNode + from ray.serve.deployment_method_executor_node import ( + DeploymentMethodExecutorNode, + ) + from ray.serve.deployment_function_executor_node import ( + DeploymentFunctionExecutorNode, + ) self.dispatch_table[FunctionNode] = self._reduce_dag_node self.dispatch_table[ClassNode] = self._reduce_dag_node @@ -58,6 +65,11 @@ def __init__(self): self.dispatch_table[DeploymentNode] = self._reduce_dag_node self.dispatch_table[DeploymentMethodNode] = self._reduce_dag_node self.dispatch_table[DeploymentFunctionNode] = self._reduce_dag_node + + self.dispatch_table[DeploymentExecutorNode] = self._reduce_dag_node + self.dispatch_table[DeploymentMethodExecutorNode] = self._reduce_dag_node + self.dispatch_table[DeploymentFunctionExecutorNode] = self._reduce_dag_node + super().__init__(self._buf) def find_nodes(self, obj: Any) -> List["DAGNode"]: diff --git a/python/ray/experimental/state/api.py b/python/ray/experimental/state/api.py index 4fe83f6125a0..eea31f4b5638 100644 --- a/python/ray/experimental/state/api.py +++ b/python/ray/experimental/state/api.py @@ -1,4 +1,5 @@ import requests +import warnings from dataclasses import fields @@ -8,18 +9,26 @@ DEFAULT_RPC_TIMEOUT, DEFAULT_LIMIT, ) +from ray.experimental.state.exception import RayStateApiException # TODO(sang): Replace it with auto-generated methods. -def _list(resource_name: str, options: ListApiOptions, api_server_url: str = None): +def _list( + resource_name: str, + options: ListApiOptions, + api_server_url: str = None, + _explain: bool = False, +): """Query the API server in address to list "resource_name" states. Args: resource_name: The name of the resource. E.g., actor, task. options: The options for the REST API that are translated to query strings. - address: The address of API server. If it is not give, it assumes the ray + api_server_url: The address of API server. If it is not give, it assumes the ray is already connected and obtains the API server address using Ray API. + explain: Print the API information such as API + latency or failed query information. """ if api_server_url is None: assert ray.is_initialized() @@ -40,10 +49,18 @@ def _list(resource_name: str, options: ListApiOptions, api_server_url: str = Non r.raise_for_status() response = r.json() - if not response["result"]: - raise ValueError( - "API server internal error. See dashboard.log file for more details." + if response["result"] is False: + raise RayStateApiException( + "API server internal error. See dashboard.log file for more details. " + f"Error: {response['msg']}" ) + + if _explain: + # Print warnings if anything was given. + warning_msg = response["data"].get("partial_failure_warning", None) + if warning_msg is not None: + warnings.warn(warning_msg, RuntimeWarning) + return r.json()["data"]["result"] @@ -51,11 +68,13 @@ def list_actors( api_server_url: str = None, limit: int = DEFAULT_LIMIT, timeout: int = DEFAULT_RPC_TIMEOUT, + _explain: bool = False, ): return _list( "actors", ListApiOptions(limit=limit, timeout=timeout), api_server_url=api_server_url, + _explain=_explain, ) @@ -63,11 +82,13 @@ def list_placement_groups( api_server_url: str = None, limit: int = DEFAULT_LIMIT, timeout: int = DEFAULT_RPC_TIMEOUT, + _explain: bool = False, ): return _list( "placement_groups", ListApiOptions(limit=limit, timeout=timeout), api_server_url=api_server_url, + _explain=_explain, ) @@ -75,11 +96,13 @@ def list_nodes( api_server_url: str = None, limit: int = DEFAULT_LIMIT, timeout: int = DEFAULT_RPC_TIMEOUT, + _explain: bool = False, ): return _list( "nodes", ListApiOptions(limit=limit, timeout=timeout), api_server_url=api_server_url, + _explain=_explain, ) @@ -87,11 +110,13 @@ def list_jobs( api_server_url: str = None, limit: int = DEFAULT_LIMIT, timeout: int = DEFAULT_RPC_TIMEOUT, + _explain: bool = False, ): return _list( "jobs", ListApiOptions(limit=limit, timeout=timeout), api_server_url=api_server_url, + _explain=_explain, ) @@ -99,11 +124,13 @@ def list_workers( api_server_url: str = None, limit: int = DEFAULT_LIMIT, timeout: int = DEFAULT_RPC_TIMEOUT, + _explain: bool = False, ): return _list( "workers", ListApiOptions(limit=limit, timeout=timeout), api_server_url=api_server_url, + _explain=_explain, ) @@ -111,11 +138,13 @@ def list_tasks( api_server_url: str = None, limit: int = DEFAULT_LIMIT, timeout: int = DEFAULT_RPC_TIMEOUT, + _explain: bool = False, ): return _list( "tasks", ListApiOptions(limit=limit, timeout=timeout), api_server_url=api_server_url, + _explain=_explain, ) @@ -123,17 +152,25 @@ def list_objects( api_server_url: str = None, limit: int = DEFAULT_LIMIT, timeout: int = DEFAULT_RPC_TIMEOUT, + _explain: bool = False, ): return _list( "objects", ListApiOptions(limit=limit, timeout=timeout), api_server_url=api_server_url, + _explain=_explain, ) -def list_runtime_envs(api_server_url: str = None, limit: int = 1000, timeout: int = 30): +def list_runtime_envs( + api_server_url: str = None, + limit: int = DEFAULT_LIMIT, + timeout: int = DEFAULT_RPC_TIMEOUT, + _explain: bool = False, +): return _list( "runtime_envs", ListApiOptions(limit=limit, timeout=timeout), api_server_url=api_server_url, + _explain=_explain, ) diff --git a/python/ray/experimental/state/common.py b/python/ray/experimental/state/common.py index 789faae72fbc..7b119083824e 100644 --- a/python/ray/experimental/state/common.py +++ b/python/ray/experimental/state/common.py @@ -1,6 +1,7 @@ import logging from dataclasses import dataclass, fields +from typing import List, Dict, Union from ray.dashboard.modules.job.common import JobInfo @@ -23,11 +24,20 @@ def filter_fields(data: dict, state_dataclass) -> dict: class ListApiOptions: limit: int timeout: int + # When the request is processed on the server side, + # we should apply multiplier so that server side can finish + # processing a request within timeout. Otherwise, + # timeout will always lead Http timeout. + _server_timeout_multiplier: float = 0.8 # TODO(sang): Use Pydantic instead. def __post_init__(self): assert isinstance(self.limit, int) assert isinstance(self.timeout, int) + # To return the data to users, when there's a partial failure + # we need to have a timeout that's smaller than the users' timeout. + # 80% is configured arbitrarily. + self.timeout = int(self.timeout * self._server_timeout_multiplier) # TODO(sang): Replace it with Pydantic or gRPC schema (once interface is finalized). @@ -94,3 +104,30 @@ class RuntimeEnvState: error: str creation_time_ms: float node_id: str + + +@dataclass(init=True) +class ListApiResponse: + # Returned data. None if no data is returned. + result: Union[ + Dict[ + str, + Union[ + ActorState, + PlacementGroupState, + NodeState, + JobInfo, + WorkerState, + TaskState, + ObjectState, + ], + ], + List[RuntimeEnvState], + ] = None + # List API can have a partial failure if queries to + # all sources fail. For example, getting object states + # require to ping all raylets, and it is possible some of + # them fails. Note that it is impossible to guarantee high + # availability of data because ray's state information is + # not replicated. + partial_failure_warning: str = "" diff --git a/python/ray/experimental/state/exception.py b/python/ray/experimental/state/exception.py new file mode 100644 index 000000000000..e91b8d9313ea --- /dev/null +++ b/python/ray/experimental/state/exception.py @@ -0,0 +1,12 @@ +"""Internal Error""" + + +class DataSourceUnavailable(Exception): + pass + + +"""User-facing Error""" + + +class RayStateApiException(Exception): + pass diff --git a/python/ray/experimental/state/state_cli.py b/python/ray/experimental/state/state_cli.py index d221ec57f741..24773092a68a 100644 --- a/python/ray/experimental/state/state_cli.py +++ b/python/ray/experimental/state/state_cli.py @@ -59,6 +59,12 @@ def get_state_api_output_to_print( ) +def _should_explain(format: AvailableFormat): + # If the format is json or yaml, it should not print stats because + # users don't want additional strings. + return format == AvailableFormat.DEFAULT or format == AvailableFormat.TABLE + + @click.group("list") @click.pass_context def list_state_cli_group(ctx): @@ -71,6 +77,7 @@ def list_state_cli_group(ctx): namespace=ray_constants.KV_NAMESPACE_DASHBOARD, num_retries=20, ) + if api_server_url is None: raise ValueError( ( @@ -92,9 +99,11 @@ def list_state_cli_group(ctx): @click.pass_context def actors(ctx, format: str): url = ctx.obj["api_server_url"] + format = AvailableFormat(format) print( get_state_api_output_to_print( - list_actors(api_server_url=url), format=AvailableFormat(format) + list_actors(api_server_url=url, _explain=_should_explain(format)), + format=format, ) ) @@ -106,10 +115,11 @@ def actors(ctx, format: str): @click.pass_context def placement_groups(ctx, format: str): url = ctx.obj["api_server_url"] + format = AvailableFormat(format) print( get_state_api_output_to_print( - list_placement_groups(api_server_url=url), - format=AvailableFormat(format), + list_placement_groups(api_server_url=url, _explain=_should_explain(format)), + format=format, ) ) @@ -121,9 +131,11 @@ def placement_groups(ctx, format: str): @click.pass_context def nodes(ctx, format: str): url = ctx.obj["api_server_url"] + format = AvailableFormat(format) print( get_state_api_output_to_print( - list_nodes(api_server_url=url), format=AvailableFormat(format) + list_nodes(api_server_url=url, _explain=_should_explain(format)), + format=format, ) ) @@ -135,9 +147,11 @@ def nodes(ctx, format: str): @click.pass_context def jobs(ctx, format: str): url = ctx.obj["api_server_url"] + format = AvailableFormat(format) print( get_state_api_output_to_print( - list_jobs(api_server_url=url), format=AvailableFormat(format) + list_jobs(api_server_url=url, _explain=_should_explain(format)), + format=format, ) ) @@ -149,9 +163,11 @@ def jobs(ctx, format: str): @click.pass_context def workers(ctx, format: str): url = ctx.obj["api_server_url"] + format = AvailableFormat(format) print( get_state_api_output_to_print( - list_workers(api_server_url=url), format=AvailableFormat(format) + list_workers(api_server_url=url, _explain=_should_explain(format)), + format=format, ) ) @@ -163,9 +179,11 @@ def workers(ctx, format: str): @click.pass_context def tasks(ctx, format: str): url = ctx.obj["api_server_url"] + format = AvailableFormat(format) print( get_state_api_output_to_print( - list_tasks(api_server_url=url), format=AvailableFormat(format) + list_tasks(api_server_url=url, _explain=_should_explain(format)), + format=format, ) ) @@ -177,9 +195,11 @@ def tasks(ctx, format: str): @click.pass_context def objects(ctx, format: str): url = ctx.obj["api_server_url"] + format = AvailableFormat(format) print( get_state_api_output_to_print( - list_objects(api_server_url=url), format=AvailableFormat(format) + list_objects(api_server_url=url, _explain=_should_explain(format)), + format=format, ) ) @@ -191,9 +211,10 @@ def objects(ctx, format: str): @click.pass_context def runtime_envs(ctx, format: str): url = ctx.obj["api_server_url"] + format = AvailableFormat(format) print( get_state_api_output_to_print( - list_runtime_envs(api_server_url=url), - format=AvailableFormat(format), + list_runtime_envs(api_server_url=url, _explain=_should_explain(format)), + format=format, ) ) diff --git a/python/ray/experimental/state/state_manager.py b/python/ray/experimental/state/state_manager.py index ce471ed3db94..656f2ae24cc9 100644 --- a/python/ray/experimental/state/state_manager.py +++ b/python/ray/experimental/state/state_manager.py @@ -6,7 +6,7 @@ import grpc import ray -from typing import Dict, List +from typing import Dict, List, Optional from ray import ray_constants from ray.core.generated.gcs_service_pb2 import ( @@ -33,19 +33,13 @@ from ray.core.generated import gcs_service_pb2_grpc from ray.core.generated.node_manager_pb2_grpc import NodeManagerServiceStub from ray.dashboard.modules.job.common import JobInfoStorageClient, JobInfo +from ray.experimental.state.exception import DataSourceUnavailable logger = logging.getLogger(__name__) -class StateSourceNetworkException(Exception): - """Exceptions raised when there's a network error from data source query.""" - - pass - - -def handle_network_errors(func): - """Apply the network error handling logic to each APIs, - such as retry or exception policies. +def handle_grpc_network_errors(func): + """Decorator to add a network handling logic. It is a helper method for `StateDataSourceClient`. The method can only be used for async methods. @@ -54,20 +48,32 @@ def handle_network_errors(func): @wraps(func) async def api_with_network_error_handler(*args, **kwargs): + """Apply the network error handling logic to each APIs, + such as retry or exception policies. + + Returns: + If RPC succeeds, it returns what the original function returns. + If RPC fails, it raises exceptions. + Exceptions: + DataSourceUnavailable: if the source is unavailable because it is down + or there's a slow network issue causing timeout. + Otherwise, the raw network exceptions (e.g., gRPC) will be raised. + """ # TODO(sang): Add a retry policy. try: return await func(*args, **kwargs) - except ( - # https://grpc.github.io/grpc/python/grpc_asyncio.html#grpc-exceptions - grpc.aio.AioRpcError, - grpc.aio.InternalError, - grpc.aio.AbortError, - grpc.aio.BaseError, - grpc.aio.UsageError, - ) as e: - raise StateSourceNetworkException( - f"Failed to query the data source, {func}" - ) from e + except grpc.aio.AioRpcError as e: + if ( + e.code() == grpc.StatusCode.DEADLINE_EXCEEDED + or e.code() == grpc.StatusCode.UNAVAILABLE + ): + raise DataSourceUnavailable( + "Failed to query the data source. " + "It is either there's a network issue, or the source is down." + ) + else: + logger.exception(e) + raise e return api_with_network_error_handler @@ -81,8 +87,9 @@ class StateDataSourceClient: finding services and register stubs through `register*` APIs. Non `register*` APIs + - Return the protobuf directly if it succeeds to query the source. + - Raises an exception if there's any network issue. - throw a ValueError if it cannot find the source. - - throw `StateSourceNetworkException` if there's any network errors. """ def __init__(self, gcs_channel: grpc.aio.Channel): @@ -132,50 +139,66 @@ def get_all_registered_raylet_ids(self) -> List[str]: def get_all_registered_agent_ids(self) -> List[str]: return self._agent_stubs.keys() - @handle_network_errors - async def get_all_actor_info(self, timeout: int = None) -> GetAllActorInfoReply: + @handle_grpc_network_errors + async def get_all_actor_info( + self, timeout: int = None + ) -> Optional[GetAllActorInfoReply]: request = GetAllActorInfoRequest() reply = await self._gcs_actor_info_stub.GetAllActorInfo( request, timeout=timeout ) return reply - @handle_network_errors + @handle_grpc_network_errors async def get_all_placement_group_info( self, timeout: int = None - ) -> GetAllPlacementGroupReply: + ) -> Optional[GetAllPlacementGroupReply]: request = GetAllPlacementGroupRequest() reply = await self._gcs_pg_info_stub.GetAllPlacementGroup( request, timeout=timeout ) return reply - @handle_network_errors - async def get_all_node_info(self, timeout: int = None) -> GetAllNodeInfoReply: + @handle_grpc_network_errors + async def get_all_node_info( + self, timeout: int = None + ) -> Optional[GetAllNodeInfoReply]: request = GetAllNodeInfoRequest() reply = await self._gcs_node_info_stub.GetAllNodeInfo(request, timeout=timeout) return reply - @handle_network_errors - async def get_all_worker_info(self, timeout: int = None) -> GetAllWorkerInfoReply: + @handle_grpc_network_errors + async def get_all_worker_info( + self, timeout: int = None + ) -> Optional[GetAllWorkerInfoReply]: request = GetAllWorkerInfoRequest() reply = await self._gcs_worker_info_stub.GetAllWorkerInfo( request, timeout=timeout ) return reply - def get_job_info(self) -> Dict[str, JobInfo]: - # Cannot use @handle_network_errors because async def is not supported yet. + def get_job_info(self) -> Optional[Dict[str, JobInfo]]: + # Cannot use @handle_grpc_network_errors because async def is not supported yet. # TODO(sang): Support timeout & make it async try: return self._job_client.get_all_jobs() - except Exception as e: - raise StateSourceNetworkException("Failed to query the job info.") from e - - @handle_network_errors + except grpc.aio.AioRpcError as e: + if ( + e.code == grpc.StatusCode.DEADLINE_EXCEEDED + or e.code == grpc.StatusCode.UNAVAILABLE + ): + raise DataSourceUnavailable( + "Failed to query the data source. " + "It is either there's a network issue, or the source is down." + ) + else: + logger.exception(e) + raise e + + @handle_grpc_network_errors async def get_task_info( self, node_id: str, timeout: int = None - ) -> GetTasksInfoReply: + ) -> Optional[GetTasksInfoReply]: stub = self._raylet_stubs.get(node_id) if not stub: raise ValueError(f"Raylet for a node id, {node_id} doesn't exist.") @@ -183,10 +206,10 @@ async def get_task_info( reply = await stub.GetTasksInfo(GetTasksInfoRequest(), timeout=timeout) return reply - @handle_network_errors + @handle_grpc_network_errors async def get_object_info( self, node_id: str, timeout: int = None - ) -> GetNodeStatsReply: + ) -> Optional[GetNodeStatsReply]: stub = self._raylet_stubs.get(node_id) if not stub: raise ValueError(f"Raylet for a node id, {node_id} doesn't exist.") @@ -197,10 +220,10 @@ async def get_object_info( ) return reply - @handle_network_errors + @handle_grpc_network_errors async def get_runtime_envs_info( self, node_id: str, timeout: int = None - ) -> GetRuntimeEnvsInfoReply: + ) -> Optional[GetRuntimeEnvsInfoReply]: stub = self._agent_stubs.get(node_id) if not stub: raise ValueError(f"Agent for a node id, {node_id} doesn't exist.") diff --git a/python/ray/includes/ray_config.pxd b/python/ray/includes/ray_config.pxd index 7b715fe4128e..e9ac6d2ccd2b 100644 --- a/python/ray/includes/ray_config.pxd +++ b/python/ray/includes/ray_config.pxd @@ -1,5 +1,5 @@ from libcpp cimport bool as c_bool -from libc.stdint cimport int64_t, uint64_t, uint32_t, int32_t +from libc.stdint cimport int64_t, uint64_t, uint32_t from libcpp.string cimport string as c_string from libcpp.unordered_map cimport unordered_map @@ -68,5 +68,3 @@ cdef extern from "ray/common/ray_config.h" nogil: c_bool start_python_importer_thread() const c_bool use_ray_syncer() const - - int32_t gcs_rpc_server_reconnect_timeout_s() const diff --git a/python/ray/includes/ray_config.pxi b/python/ray/includes/ray_config.pxi index 49950095373d..c65bfbc291c3 100644 --- a/python/ray/includes/ray_config.pxi +++ b/python/ray/includes/ray_config.pxi @@ -84,10 +84,6 @@ cdef class Config: def object_manager_default_chunk_size(): return RayConfig.instance().object_manager_default_chunk_size() - @staticmethod - def gcs_rpc_server_reconnect_timeout_s(): - return RayConfig.instance().gcs_rpc_server_reconnect_timeout_s() - @staticmethod def maximum_gcs_deletion_batch_size(): return RayConfig.instance().maximum_gcs_deletion_batch_size() diff --git a/python/ray/ml/checkpoint.py b/python/ray/ml/checkpoint.py index 920b34ca9cf5..985e4a370100 100644 --- a/python/ray/ml/checkpoint.py +++ b/python/ray/ml/checkpoint.py @@ -479,7 +479,7 @@ def _temporary_checkpoint_dir() -> str: def _pack(path: str) -> bytes: """Pack directory in ``path`` into an archive, return as bytes string.""" stream = io.BytesIO() - with tarfile.open(fileobj=stream, mode="w:gz", format=tarfile.PAX_FORMAT) as tar: + with tarfile.open(fileobj=stream, mode="w", format=tarfile.PAX_FORMAT) as tar: tar.add(path, arcname="") return stream.getvalue() diff --git a/python/ray/ml/predictors/integrations/tensorflow/tensorflow_predictor.py b/python/ray/ml/predictors/integrations/tensorflow/tensorflow_predictor.py index c4c91c966025..2eb125db30aa 100644 --- a/python/ray/ml/predictors/integrations/tensorflow/tensorflow_predictor.py +++ b/python/ray/ml/predictors/integrations/tensorflow/tensorflow_predictor.py @@ -147,5 +147,5 @@ def build_model(self): if self.model_weights: model.set_weights(self.model_weights) - prediction = model(tensor).numpy().ravel() - return pd.DataFrame(prediction, columns=["predictions"]) + prediction = list(model(tensor).numpy()) + return pd.DataFrame({"predictions": prediction}, columns=["predictions"]) diff --git a/python/ray/ml/preprocessors/encoder.py b/python/ray/ml/preprocessors/encoder.py index 7416545ed2cc..e82e259f8a00 100644 --- a/python/ray/ml/preprocessors/encoder.py +++ b/python/ray/ml/preprocessors/encoder.py @@ -1,5 +1,6 @@ from typing import List, Dict, Optional, Union +from collections import Counter import pandas as pd from ray.data import Dataset @@ -50,19 +51,50 @@ class OneHotEncoder(Preprocessor): for each of the values from the fitted dataset. The value of a column will be set to 1 if the value matches, otherwise 0. - Transforming values not included in the fitted dataset will result in all - of the encoded column values being 0. + Transforming values not included in the fitted dataset or not among + the top popular values (see ``limit``) will result in all of the encoded column + values being 0. + + Example: + + .. code-block:: python + + ohe = OneHotEncoder( + columns=[ + "trip_start_hour", + "trip_start_day", + "trip_start_month", + "dropoff_census_tract", + "pickup_community_area", + "dropoff_community_area", + "payment_type", + "company", + ], + limit={ + "dropoff_census_tract": 25, + "pickup_community_area": 20, + "dropoff_community_area": 20, + "payment_type": 2, + "company": 7, + }, + ) Args: columns: The columns that will individually be encoded. + limit: If set, only the top "limit" number of most popular values become + categorical variables. The less frequent ones will result in all + the encoded column values being 0. This is a dict of column to + its corresponding limit. The column in this dictionary has to be + in ``columns``. """ - def __init__(self, columns: List[str]): + def __init__(self, columns: List[str], limit: Optional[Dict[str, int]] = None): # TODO: add `drop` parameter. self.columns = columns + self.limit = limit def _fit(self, dataset: Dataset) -> Preprocessor: - self.stats_ = _get_unique_value_indices(dataset, self.columns) + self.stats_ = _get_unique_value_indices(dataset, self.columns, limit=self.limit) return self def _transform_pandas(self, df: pd.DataFrame): @@ -177,35 +209,59 @@ def _get_unique_value_indices( columns: List[str], drop_na_values: bool = False, key_format: str = "unique_values({0})", + limit: Optional[Dict[str, int]] = None, ) -> Dict[str, Dict[str, int]]: """If drop_na_values is True, will silently drop NA values.""" + limit = limit or {} + for column in limit: + if column not in columns: + raise ValueError( + f"You set limit for {column}, which is not present in {columns}." + ) - def get_pd_unique_values(df: pd.DataFrame) -> List[Dict[str, set]]: - return [{col: set(df[col].unique()) for col in columns}] - - uniques = dataset.map_batches(get_pd_unique_values, batch_format="pandas") - final_uniques = {col: set() for col in columns} - for batch in uniques.iter_batches(): - for col_uniques in batch: - for col, uniques in col_uniques.items(): - final_uniques[col].update(uniques) - - for col, uniques in final_uniques.items(): + def get_pd_value_counts(df: pd.DataFrame) -> List[Dict[str, Counter]]: + result = [ + { + col: Counter(df[col].value_counts(dropna=False).to_dict()) + for col in columns + } + ] + return result + + value_counts = dataset.map_batches(get_pd_value_counts, batch_format="pandas") + final_counters = {col: Counter() for col in columns} + for batch in value_counts.iter_batches(): + for col_value_counts in batch: + for col, value_counts in col_value_counts.items(): + final_counters[col] += value_counts + + # Inspect if there is any NA values. + for col in columns: if drop_na_values: - final_uniques[col] = {v for v in uniques if not pd.isnull(v)} + counter = final_counters[col] + counter_dict = dict(counter) + sanitized_dict = {k: v for k, v in counter_dict.items() if not pd.isnull(k)} + final_counters[col] = Counter(sanitized_dict) else: - if any(pd.isnull(v) for v in uniques): + if any(pd.isnull(k) for k in final_counters[col]): raise ValueError( - f"Unable to fit column '{col}' because it contains null values. " - f"Consider imputing missing values first." + f"Unable to fit column '{col}' because it contains null" + f" values. Consider imputing missing values first." ) - unique_values_with_indices = { - key_format.format(column): { - k: j for j, k in enumerate(sorted(final_uniques[column])) - } - for column in columns - } + unique_values_with_indices = dict() + for column in columns: + if column in limit: + # Output sorted by freq. + unique_values_with_indices[key_format.format(column)] = { + k[0]: j + for j, k in enumerate(final_counters[column].most_common(limit[column])) + } + else: + # Output sorted by column name. + unique_values_with_indices[key_format.format(column)] = { + k: j for j, k in enumerate(sorted(dict(final_counters[column]).keys())) + } return unique_values_with_indices diff --git a/python/ray/ml/tests/test_preprocessors.py b/python/ray/ml/tests/test_preprocessors.py index 85aec4ccdae5..36b08f409e1e 100644 --- a/python/ray/ml/tests/test_preprocessors.py +++ b/python/ray/ml/tests/test_preprocessors.py @@ -470,6 +470,20 @@ def test_one_hot_encoder(): null_encoder.transform_batch(nonnull_df) +def test_one_hot_encoder_with_limit(): + """Tests basic OneHotEncoder functionality with limit.""" + col_a = ["red", "green", "blue", "red"] + col_b = ["warm", "cold", "hot", "cold"] + col_c = [1, 10, 5, 10] + in_df = pd.DataFrame.from_dict({"A": col_a, "B": col_b, "C": col_c}) + ds = ray.data.from_pandas(in_df) + + encoder = OneHotEncoder(["B", "C"], limit={"B": 2}) + + ds_out = encoder.fit_transform(ds) + assert len(ds_out.to_pandas().columns) == 1 + 2 + 3 + + def test_label_encoder(): """Tests basic LabelEncoder functionality.""" col_a = ["red", "green", "blue", "red"] diff --git a/python/ray/ml/train/integrations/sklearn/sklearn_trainer.py b/python/ray/ml/train/integrations/sklearn/sklearn_trainer.py index 8696470b0d52..f9da0e4494e0 100644 --- a/python/ray/ml/train/integrations/sklearn/sklearn_trainer.py +++ b/python/ray/ml/train/integrations/sklearn/sklearn_trainer.py @@ -452,7 +452,7 @@ def load_checkpoint( with checkpoint.as_directory() as checkpoint_path: estimator_path = os.path.join(checkpoint_path, MODEL_KEY) with open(estimator_path, "rb") as f: - estimator_path = cpickle.load(f) + estimator = cpickle.load(f) preprocessor = load_preprocessor_from_dir(checkpoint_path) - return estimator_path, preprocessor + return estimator, preprocessor diff --git a/python/ray/ml/train/integrations/tensorflow/tensorflow_trainer.py b/python/ray/ml/train/integrations/tensorflow/tensorflow_trainer.py index 30987731eb49..2f25781f2d1e 100644 --- a/python/ray/ml/train/integrations/tensorflow/tensorflow_trainer.py +++ b/python/ray/ml/train/integrations/tensorflow/tensorflow_trainer.py @@ -125,7 +125,7 @@ def train_loop_for_worker(config): ) model.fit(tf_dataset) train.save_checkpoint( - epoch=epoch, model_weights=model.get_weights()) + epoch=epoch, model=model.get_weights()) train_dataset = ray.data.from_items( [{"x": x, "y": x + 1} for x in range(32)]) diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py index 6a4e873df236..291864dc99b2 100644 --- a/python/ray/serve/api.py +++ b/python/ray/serve/api.py @@ -16,7 +16,14 @@ from uvicorn.config import Config from uvicorn.lifespan.on import LifespanOn -from ray.serve.common import DeploymentStatusInfo +import ray +from ray import cloudpickle +from ray._private.usage import usage_lib +from ray.experimental.dag import DAGNode +from ray.util.annotations import PublicAPI + +from ray.serve.application import Application +from ray.serve.client import ServeControllerClient, get_controller_namespace from ray.serve.config import ( AutoscalingConfig, DeploymentConfig, @@ -30,13 +37,23 @@ DEFAULT_HTTP_HOST, DEFAULT_HTTP_PORT, ) +from ray.serve.context import ( + set_global_client, + get_global_client, + get_internal_replica_context, + ReplicaContext, +) from ray.serve.controller import ServeController from ray.serve.deployment import Deployment +from ray.serve.deployment_graph import ClassNode, FunctionNode from ray.serve.exceptions import RayServeException -from ray.experimental.dag import DAGNode from ray.serve.handle import RayServeHandle from ray.serve.http_util import ASGIHTTPSender, make_fastapi_class_based_view from ray.serve.logging_utils import LoggingContext +from ray.serve.pipeline.api import ( + build as pipeline_build, + get_and_validate_ingress_deployment, +) from ray.serve.utils import ( ensure_serialization_context, format_actor_name, @@ -46,21 +63,7 @@ DEFAULT, install_serve_encoders_to_fastapi, ) -from ray.util.annotations import PublicAPI -import ray -from ray import cloudpickle -from ray.serve.deployment_graph import ClassNode, FunctionNode -from ray.serve.application import Application -from ray.serve.client import ServeControllerClient, get_controller_namespace -from ray.serve.context import ( - set_global_client, - get_global_client, - get_internal_replica_context, - ReplicaContext, -) -from ray.serve.pipeline.api import build as pipeline_build -from ray.serve.pipeline.api import get_and_validate_ingress_deployment -from ray._private.usage import usage_lib + logger = logging.getLogger(__file__) @@ -547,27 +550,6 @@ def list_deployments() -> Dict[str, Deployment]: return deployments -def get_deployment_statuses() -> Dict[str, DeploymentStatusInfo]: - """Returns a dictionary of deployment statuses. - - A deployment's status is one of {UPDATING, UNHEALTHY, and HEALTHY}. - - Example: - >>> from ray.serve.api import get_deployment_statuses - >>> statuses = get_deployment_statuses() # doctest: +SKIP - >>> status_info = statuses["deployment_name"] # doctest: +SKIP - >>> status = status_info.status # doctest: +SKIP - >>> message = status_info.message # doctest: +SKIP - - Returns: - Dict[str, DeploymentStatus]: This dictionary maps the running - deployment's name to a DeploymentStatus object containing its - status and a message explaining the status. - """ - - return get_global_client().get_deployment_statuses() - - @PublicAPI(stability="alpha") def run( target: Union[ClassNode, FunctionNode], diff --git a/python/ray/serve/client.py b/python/ray/serve/client.py index 33e298db4916..8872dc550be1 100644 --- a/python/ray/serve/client.py +++ b/python/ray/serve/client.py @@ -21,7 +21,7 @@ from ray.serve.common import ( DeploymentInfo, DeploymentStatus, - DeploymentStatusInfo, + StatusOverview, ) from ray.serve.config import ( DeploymentConfig, @@ -38,7 +38,7 @@ from ray.serve.generated.serve_pb2 import ( DeploymentRoute, DeploymentRouteList, - DeploymentStatusInfoList, + StatusOverview as StatusOverviewProto, ) from ray.serve.handle import RayServeHandle, RayServeSyncHandle @@ -155,16 +155,19 @@ def _wait_for_deployments_shutdown(self, timeout_s: int = 60): """ start = time.time() while time.time() - start < timeout_s: - statuses = self.get_deployment_statuses() - if len(statuses) == 0: + deployment_statuses = self.get_serve_status().deployment_statuses + if len(deployment_statuses) == 0: break else: logger.debug( - f"Waiting for shutdown, {len(statuses)} deployments still alive." + f"Waiting for shutdown, {len(deployment_statuses)} " + "deployments still alive." ) time.sleep(CLIENT_POLLING_INTERVAL_S) else: - live_names = list(statuses.keys()) + live_names = [ + deployment_status.name for deployment_status in deployment_statuses + ] raise TimeoutError( f"Shutdown didn't complete after {timeout_s}s. " f"Deployments still alive: {live_names}." @@ -180,25 +183,28 @@ def _wait_for_deployment_healthy(self, name: str, timeout_s: int = -1): """ start = time.time() while time.time() - start < timeout_s or timeout_s < 0: - statuses = self.get_deployment_statuses() - try: - status = statuses[name] - except KeyError: + + status = self.get_serve_status().get_deployment_status(name) + + if status is None: raise RuntimeError( f"Waiting for deployment {name} to be HEALTHY, " "but deployment doesn't exist." - ) from None + ) if status.status == DeploymentStatus.HEALTHY: break elif status.status == DeploymentStatus.UNHEALTHY: - raise RuntimeError(f"Deployment {name} is UNHEALTHY: {status.message}") + raise RuntimeError( + f"Deployment {name} is UNHEALTHY: " f"{status.message}" + ) else: # Guard against new unhandled statuses being added. assert status.status == DeploymentStatus.UPDATING logger.debug( - f"Waiting for {name} to be healthy, current status: {status.status}." + f"Waiting for {name} to be healthy, current status: " + f"{status.status}." ) time.sleep(CLIENT_POLLING_INTERVAL_S) else: @@ -213,14 +219,12 @@ def _wait_for_deployment_deleted(self, name: str, timeout_s: int = 60): """ start = time.time() while time.time() - start < timeout_s: - statuses = self.get_deployment_statuses() - if name not in statuses: + curr_status = self.get_serve_status().get_deployment_status(name) + if curr_status is None: break - else: - curr_status = statuses[name].status - logger.debug( - f"Waiting for {name} to be deleted, current status: {curr_status}." - ) + logger.debug( + f"Waiting for {name} to be deleted, current status: {curr_status}." + ) time.sleep(CLIENT_POLLING_INTERVAL_S) else: raise TimeoutError(f"Deployment {name} wasn't deleted after {timeout_s}s.") @@ -313,7 +317,7 @@ def deploy_group( deployment_names_to_delete = all_deployments_names.difference( new_deployments_names ) - self.delete_deployments(deployment_names_to_delete) + self.delete_deployments(deployment_names_to_delete, blocking=_blocking) @_ensure_connected def delete_deployments(self, names: Iterable[str], blocking: bool = True) -> None: @@ -346,16 +350,11 @@ def list_deployments(self) -> Dict[str, Tuple[DeploymentInfo, str]]: } @_ensure_connected - def get_deployment_statuses(self) -> Dict[str, DeploymentStatusInfo]: - proto = DeploymentStatusInfoList.FromString( - ray.get(self._controller.get_deployment_statuses.remote()) + def get_serve_status(self) -> StatusOverview: + proto = StatusOverviewProto.FromString( + ray.get(self._controller.get_serve_status.remote()) ) - return { - deployment_status_info.name: DeploymentStatusInfo.from_proto( - deployment_status_info - ) - for deployment_status_info in proto.deployment_status_infos - } + return StatusOverview.from_proto(proto) @_ensure_connected def get_handle( diff --git a/python/ray/serve/common.py b/python/ray/serve/common.py index cd438706cd45..8bbafcc46f70 100644 --- a/python/ray/serve/common.py +++ b/python/ray/serve/common.py @@ -1,6 +1,7 @@ -from dataclasses import dataclass +import json from enum import Enum -from typing import Any, Dict, Optional +from dataclasses import dataclass, field, asdict +from typing import Any, List, Dict, Optional import ray from ray.actor import ActorHandle @@ -10,6 +11,10 @@ DeploymentInfo as DeploymentInfoProto, DeploymentStatusInfo as DeploymentStatusInfoProto, DeploymentStatus as DeploymentStatusProto, + DeploymentStatusInfoList as DeploymentStatusInfoListProto, + ApplicationStatus as ApplicationStatusProto, + ApplicationStatusInfo as ApplicationStatusInfoProto, + StatusOverview as StatusOverviewProto, DeploymentLanguage, ) @@ -24,28 +29,127 @@ class EndpointInfo: route: str +class ApplicationStatus(str, Enum): + DEPLOYING = "DEPLOYING" + RUNNING = "RUNNING" + DEPLOY_FAILED = "DEPLOY_FAILED" + + +@dataclass(eq=True) +class ApplicationStatusInfo: + status: ApplicationStatus + message: str = "" + deployment_timestamp: float = 0 + + def debug_string(self): + return json.dumps(asdict(self), indent=4) + + def to_proto(self): + return ApplicationStatusInfoProto( + status=self.status, + message=self.message, + deployment_timestamp=self.deployment_timestamp, + ) + + @classmethod + def from_proto(cls, proto: ApplicationStatusInfoProto): + return cls( + status=ApplicationStatus(ApplicationStatusProto.Name(proto.status)), + message=proto.message, + deployment_timestamp=proto.deployment_timestamp, + ) + + class DeploymentStatus(str, Enum): UPDATING = "UPDATING" HEALTHY = "HEALTHY" UNHEALTHY = "UNHEALTHY" -@dataclass +@dataclass(eq=True) class DeploymentStatusInfo: + name: str status: DeploymentStatus message: str = "" + def debug_string(self): + return json.dumps(asdict(self), indent=4) + def to_proto(self): - return DeploymentStatusInfoProto(status=self.status, message=self.message) + return DeploymentStatusInfoProto( + name=self.name, status=self.status, message=self.message + ) @classmethod def from_proto(cls, proto: DeploymentStatusInfoProto): return cls( + name=proto.name, status=DeploymentStatus(DeploymentStatusProto.Name(proto.status)), message=proto.message, ) +@dataclass(eq=True) +class StatusOverview: + app_status: ApplicationStatusInfo + deployment_statuses: List[DeploymentStatusInfo] = field(default_factory=list) + + def debug_string(self): + return json.dumps(asdict(self), indent=4) + + def get_deployment_status(self, name: str) -> Optional[DeploymentStatusInfo]: + """Get a deployment's status by name. + + Args: + name (str): Deployment's name. + + Return (Optional[DeploymentStatusInfo]): Status with a name matching + the argument, if one exists. Otherwise, returns None. + """ + + for deployment_status in self.deployment_statuses: + if name == deployment_status.name: + return deployment_status + + return None + + def to_proto(self): + + # Create a protobuf for the Serve Application info + app_status_proto = self.app_status.to_proto() + + # Create protobufs for all individual deployment statuses + deployment_status_protos = map( + lambda status: status.to_proto(), self.deployment_statuses + ) + + # Create a protobuf list containing all the deployment status protobufs + deployment_status_proto_list = DeploymentStatusInfoListProto() + deployment_status_proto_list.deployment_status_infos.extend( + deployment_status_protos + ) + + # Return protobuf encapsulating application and deployment protos + return StatusOverviewProto( + app_status=app_status_proto, + deployment_statuses=deployment_status_proto_list, + ) + + @classmethod + def from_proto(cls, proto: StatusOverviewProto): + + # Recreate Serve Application info + app_status = ApplicationStatusInfo.from_proto(proto.app_status) + + # Recreate deployment statuses + deployment_statuses = [] + for proto in proto.deployment_statuses.deployment_status_infos: + deployment_statuses.append(DeploymentStatusInfo.from_proto(proto)) + + # Recreate StatusInfo + return cls(app_status=app_status, deployment_statuses=deployment_statuses) + + HEALTH_CHECK_CONCURRENCY_GROUP = "health_check" REPLICA_DEFAULT_ACTOR_OPTIONS = { "concurrency_groups": {HEALTH_CHECK_CONCURRENCY_GROUP: 1} diff --git a/python/ray/serve/controller.py b/python/ray/serve/controller.py index 530072e3d8a4..cc62d17df47e 100644 --- a/python/ray/serve/controller.py +++ b/python/ray/serve/controller.py @@ -18,6 +18,9 @@ EndpointInfo, NodeId, RunningReplicaInfo, + ApplicationStatus, + ApplicationStatusInfo, + StatusOverview, ) from ray.serve.config import DeploymentConfig, HTTPOptions, ReplicaConfig from ray.serve.constants import ( @@ -477,22 +480,21 @@ def list_deployments(self, include_deleted: Optional[bool] = False) -> bytes: ) return deployment_route_list.SerializeToString() - def get_deployment_statuses(self) -> bytes: - """Gets the current status information about all deployments. + def get_serve_status(self) -> bytes: - Returns: - DeploymentStatusInfoList's protobuf serialized bytes - """ - from ray.serve.generated.serve_pb2 import DeploymentStatusInfoList - - deployment_status_info_list = DeploymentStatusInfoList() - for ( - name, - deployment_status_info, - ) in self.deployment_state_manager.get_deployment_statuses().items(): - deployment_status_info_proto = deployment_status_info.to_proto() - deployment_status_info_proto.name = name - deployment_status_info_list.deployment_status_infos.append( - deployment_status_info_proto - ) - return deployment_status_info_list.SerializeToString() + # TODO (shrekris-anyscale): Replace defaults with actual REST API status + serve_app_status = ApplicationStatus.RUNNING + serve_app_message = "" + deployment_timestamp = time.time() + + app_status = ApplicationStatusInfo( + serve_app_status, serve_app_message, deployment_timestamp + ) + deployment_statuses = self.deployment_state_manager.get_deployment_statuses() + + status_info = StatusOverview( + app_status=app_status, + deployment_statuses=deployment_statuses, + ) + + return status_info.to_proto().SerializeToString() diff --git a/python/ray/serve/deployment_executor_node.py b/python/ray/serve/deployment_executor_node.py new file mode 100644 index 000000000000..61794a13aecd --- /dev/null +++ b/python/ray/serve/deployment_executor_node.py @@ -0,0 +1,79 @@ +from typing import Any, Dict, List + +from ray.experimental.dag import DAGNode +from ray.serve.deployment_method_executor_node import DeploymentMethodExecutorNode +from ray.experimental.dag.constants import DAGNODE_TYPE_KEY, PARENT_CLASS_NODE_KEY +from ray.experimental.dag.format_utils import get_dag_node_str +from ray.serve.handle import RayServeHandle + + +class DeploymentExecutorNode(DAGNode): + """The lightweight executor DAGNode of DeploymentNode that optimizes for + efficiency. + + - We need Ray DAGNode's traversal and replacement mechanism to deal + with deeply nested nodes as args in the DAG + - Meanwhile, __init__, _copy_impl and _execute_impl are on the critical + pass of execution for every request. + + Therefore for serve we introduce a minimal weight node as the final product + of DAG transformation, and will be used in actual execution as well as + deployment. + """ + + def __init__( + self, + deployment_handle, + dag_args, # Not deployment init args + dag_kwargs, # Not deployment init kwargs + ): + self._deployment_handle = deployment_handle + super().__init__(dag_args, dag_kwargs, {}, {}) + + def _copy_impl( + self, + new_args: List[Any], + new_kwargs: Dict[str, Any], + new_options: Dict[str, Any], + new_other_args_to_resolve: Dict[str, Any], + ) -> "DeploymentExecutorNode": + return DeploymentExecutorNode( + self._deployment_handle, + new_args, + new_kwargs, + ) + + def _execute_impl(self, *args, **kwargs) -> RayServeHandle: + """Does not call into anything or produce a new value, as the time + this function gets called, all child nodes are already resolved to + ObjectRefs. + """ + return self._deployment_handle + + def __getattr__(self, method_name: str): + return DeploymentMethodExecutorNode( + method_name, + (), + {}, + other_args_to_resolve={ + PARENT_CLASS_NODE_KEY: self, + }, + ) + + def __str__(self) -> str: + return get_dag_node_str(self, str(self._deployment_handle)) + + def to_json(self) -> Dict[str, Any]: + return { + DAGNODE_TYPE_KEY: DeploymentExecutorNode.__name__, + "deployment_handle": self._deployment_handle, + "args": self.get_args(), + "kwargs": self.get_kwargs(), + } + + @classmethod + def from_json(cls, input_json): + assert input_json[DAGNODE_TYPE_KEY] == DeploymentExecutorNode.__name__ + return cls( + input_json["deployment_handle"], input_json["args"], input_json["kwargs"] + ) diff --git a/python/ray/serve/deployment_function_executor_node.py b/python/ray/serve/deployment_function_executor_node.py new file mode 100644 index 000000000000..381f58bf7eb8 --- /dev/null +++ b/python/ray/serve/deployment_function_executor_node.py @@ -0,0 +1,78 @@ +from typing import Any, Dict, List, Union + +from ray import ObjectRef +from ray.experimental.dag import DAGNode +from ray.serve.handle import RayServeSyncHandle, RayServeHandle +from ray.experimental.dag.constants import DAGNODE_TYPE_KEY +from ray.experimental.dag.format_utils import get_dag_node_str + + +class DeploymentFunctionExecutorNode(DAGNode): + """The lightweight executor DAGNode of DeploymentFunctionNode that optimizes + for efficiency. + + - We need Ray DAGNode's traversal and replacement mechanism to deal + with deeply nested nodes as args in the DAG + - Meanwhile, __init__, _copy_impl and _execute_impl are on the critical + pass of execution for every request. + + Therefore for serve we introduce a minimal weight node as the final product + of DAG transformation, and will be used in actual execution as well as + deployment. + """ + + def __init__( + self, + deployment_function_handle: Union[RayServeSyncHandle, RayServeHandle], + func_args, + func_kwargs, + ): + super().__init__( + func_args, + func_kwargs, + {}, + {}, + ) + self._deployment_function_handle = deployment_function_handle + + def _copy_impl( + self, + new_args: List[Any], + new_kwargs: Dict[str, Any], + new_options: Dict[str, Any], + new_other_args_to_resolve: Dict[str, Any], + ) -> "DeploymentFunctionExecutorNode": + return DeploymentFunctionExecutorNode( + self._deployment_function_handle, new_args, new_kwargs + ) + + def _execute_impl(self, *args, **kwargs) -> ObjectRef: + """Executor of DeploymentNode getting called each time on dag.execute. + + The execute implementation is recursive, that is, the method nodes will + receive whatever this method returns. We return a handle here so method + node can directly call upon. + """ + return self._deployment_function_handle.remote( + *self._bound_args, **self._bound_kwargs + ) + + def __str__(self) -> str: + return get_dag_node_str(self, str(self._deployment_function_handle)) + + def to_json(self) -> Dict[str, Any]: + return { + DAGNODE_TYPE_KEY: DeploymentFunctionExecutorNode.__name__, + "deployment_function_handle": self._deployment_function_handle, + "args": self.get_args(), + "kwargs": self.get_kwargs(), + } + + @classmethod + def from_json(cls, input_json): + assert input_json[DAGNODE_TYPE_KEY] == DeploymentFunctionExecutorNode.__name__ + return cls( + input_json["deployment_function_handle"], + input_json["args"], + input_json["kwargs"], + ) diff --git a/python/ray/serve/deployment_method_executor_node.py b/python/ray/serve/deployment_method_executor_node.py new file mode 100644 index 000000000000..99d24600ac41 --- /dev/null +++ b/python/ray/serve/deployment_method_executor_node.py @@ -0,0 +1,84 @@ +from typing import Any, Dict, List + +from ray import ObjectRef +from ray.experimental.dag import DAGNode +from ray.experimental.dag.constants import DAGNODE_TYPE_KEY, PARENT_CLASS_NODE_KEY +from ray.experimental.dag.format_utils import get_dag_node_str + + +class DeploymentMethodExecutorNode(DAGNode): + """The lightweight executor DAGNode of DeploymentMethodNode that optimizes + for efficiency. + + - We need Ray DAGNode's traversal and replacement mechanism to deal + with deeply nested nodes as args in the DAG + - Meanwhile, __init__, _copy_impl and _execute_impl are on the critical + pass of execution for every request. + + Therefore for serve we introduce a minimal weight node as the final product + of DAG transformation, and will be used in actual execution as well as + deployment. + """ + + def __init__( + self, + deployment_method_name: str, + dag_args, + dag_kwargs, + other_args_to_resolve=None, + ): + super().__init__( + dag_args, dag_kwargs, {}, other_args_to_resolve=other_args_to_resolve + ) + self._deployment_node_replaced_by_handle = other_args_to_resolve[ + PARENT_CLASS_NODE_KEY + ] + self._deployment_method_name = deployment_method_name + + def _copy_impl( + self, + new_args: List[Any], + new_kwargs: Dict[str, Any], + new_options: Dict[str, Any], + new_other_args_to_resolve: Dict[str, Any], + ) -> "DeploymentMethodExecutorNode": + return DeploymentMethodExecutorNode( + self._deployment_method_name, + new_args, + new_kwargs, + other_args_to_resolve=new_other_args_to_resolve, + ) + + def _execute_impl(self, *args, **kwargs) -> ObjectRef: + """Executor of DeploymentNode getting called each time on dag.execute. + + The execute implementation is recursive, that is, the method nodes will + receive whatever this method returns. We return a handle here so method + node can directly call upon. + """ + method_body = getattr( + self._deployment_node_replaced_by_handle, self._deployment_method_name + ) + return method_body.remote(*self._bound_args, **self._bound_kwargs) + + def __str__(self) -> str: + return get_dag_node_str(self, str(self._deployment_method_name)) + + def to_json(self) -> Dict[str, Any]: + return { + DAGNODE_TYPE_KEY: DeploymentMethodExecutorNode.__name__, + "deployment_method_name": self._deployment_method_name, + "args": self.get_args(), + "kwargs": self.get_kwargs(), + "other_args_to_resolve": self.get_other_args_to_resolve(), + } + + @classmethod + def from_json(cls, input_json): + assert input_json[DAGNODE_TYPE_KEY] == DeploymentMethodExecutorNode.__name__ + return cls( + input_json["deployment_method_name"], + input_json["args"], + input_json["kwargs"], + other_args_to_resolve=input_json["other_args_to_resolve"], + ) diff --git a/python/ray/serve/deployment_state.py b/python/ray/serve/deployment_state.py index ab1cc8af10ed..b0e169bf7f7f 100644 --- a/python/ray/serve/deployment_state.py +++ b/python/ray/serve/deployment_state.py @@ -941,7 +941,7 @@ def __init__( self._replica_constructor_retry_counter: int = 0 self._replicas: ReplicaStateContainer = ReplicaStateContainer() self._curr_status_info: DeploymentStatusInfo = DeploymentStatusInfo( - DeploymentStatus.UPDATING + self._name, DeploymentStatus.UPDATING ) def get_target_state_checkpoint_data(self): @@ -1045,7 +1045,9 @@ def _set_deployment_goal(self, deployment_info: Optional[DeploymentInfo]) -> Non else: self._target_replicas = 0 - self._curr_status_info = DeploymentStatusInfo(DeploymentStatus.UPDATING) + self._curr_status_info = DeploymentStatusInfo( + self._name, DeploymentStatus.UPDATING + ) version_str = ( deployment_info if deployment_info is None else deployment_info.version @@ -1306,6 +1308,7 @@ def _check_curr_status(self) -> bool: self._replica_constructor_retry_counter = -1 else: self._curr_status_info = DeploymentStatusInfo( + name=self._name, status=DeploymentStatus.UNHEALTHY, message=( "The Deployment constructor failed " @@ -1333,7 +1336,9 @@ def _check_curr_status(self) -> bool: # Check for a non-zero number of deployments. elif target_replica_count == running_at_target_version_replica_cnt: - self._curr_status_info = DeploymentStatusInfo(DeploymentStatus.HEALTHY) + self._curr_status_info = DeploymentStatusInfo( + self._name, DeploymentStatus.HEALTHY + ) return False return False @@ -1412,7 +1417,7 @@ def _check_and_update_replicas(self) -> bool: # recovered or a new deploy happens. if replica.version == self._target_version: self._curr_status_info: DeploymentStatusInfo = DeploymentStatusInfo( - DeploymentStatus.UNHEALTHY + self._name, DeploymentStatus.UNHEALTHY ) slow_start_replicas = [] @@ -1505,6 +1510,7 @@ def update(self) -> bool: deleted = self._check_curr_status() except Exception: self._curr_status_info = DeploymentStatusInfo( + name=self._name, status=DeploymentStatus.UNHEALTHY, message="Failed to update deployment:" f"\n{traceback.format_exc()}", ) @@ -1690,11 +1696,10 @@ def get_deployment( else: return None - def get_deployment_statuses(self) -> Dict[str, DeploymentStatusInfo]: - return { - name: state.curr_status_info - for name, state in self._deployment_states.items() - } + def get_deployment_statuses(self) -> List[DeploymentStatusInfo]: + return list( + map(lambda state: state.curr_status_info, self._deployment_states.values()) + ) def deploy(self, deployment_name: str, deployment_info: DeploymentInfo) -> bool: """Deploy the deployment. diff --git a/python/ray/serve/pipeline/api.py b/python/ray/serve/pipeline/api.py index afa96854d928..00bb1e6ba08b 100644 --- a/python/ray/serve/pipeline/api.py +++ b/python/ray/serve/pipeline/api.py @@ -4,7 +4,9 @@ from ray.serve.pipeline.generate import ( transform_ray_dag_to_serve_dag, extract_deployments_from_serve_dag, + transform_serve_dag_to_serve_executor_dag, process_ingress_deployment_in_serve_dag, + generate_executor_dag_driver_deployment, ) from ray.serve.deployment import Deployment from ray.experimental.dag.utils import DAGNodeNameGenerator @@ -67,6 +69,19 @@ def build(ray_dag_root_node: DAGNode) -> List[Deployment]: lambda node: transform_ray_dag_to_serve_dag(node, node_name_generator) ) deployments = extract_deployments_from_serve_dag(serve_root_dag) + + # After Ray DAG is transformed to Serve DAG with deployments and their init + # args filled, generate a minimal weight executor serve dag for perf + serve_executor_root_dag = serve_root_dag.apply_recursive( + transform_serve_dag_to_serve_executor_dag + ) + root_driver_deployment = deployments[-1] + new_driver_deployment = generate_executor_dag_driver_deployment( + serve_executor_root_dag, root_driver_deployment + ) + # Replace DAGDriver deployment with executor DAGDriver deployment + deployments[-1] = new_driver_deployment + # Validate and only expose HTTP for the endpoint deployments_with_http = process_ingress_deployment_in_serve_dag(deployments) return deployments_with_http diff --git a/python/ray/serve/pipeline/deployment_function_node.py b/python/ray/serve/pipeline/deployment_function_node.py index 9690fbec18f8..0b1df10f0510 100644 --- a/python/ray/serve/pipeline/deployment_function_node.py +++ b/python/ray/serve/pipeline/deployment_function_node.py @@ -7,8 +7,8 @@ from ray.experimental.dag.constants import DAGNODE_TYPE_KEY from ray.serve.deployment import Deployment, schema_to_deployment from ray.serve.config import DeploymentConfig -from ray.serve.schema import DeploymentSchema from ray.serve.handle import RayServeLazySyncHandle +from ray.serve.schema import DeploymentSchema from ray.serve.utils import get_deployment_import_path @@ -74,7 +74,7 @@ def __init__( _internal=True, ) # TODO (jiaodong): Polish with async handle support later - self._deployment_handle = RayServeLazySyncHandle(deployment_name) + self._deployment_handle = RayServeLazySyncHandle(self._deployment.name) def _copy_impl( self, @@ -130,7 +130,7 @@ def to_json(self) -> Dict[str, Any]: @classmethod def from_json(cls, input_json): assert input_json[DAGNODE_TYPE_KEY] == DeploymentFunctionNode.__name__ - node = cls( + return cls( input_json["import_path"], input_json["deployment_name"], input_json["args"], @@ -138,5 +138,3 @@ def from_json(cls, input_json): input_json["options"], other_args_to_resolve=input_json["other_args_to_resolve"], ) - node._stable_uuid = input_json["uuid"] - return node diff --git a/python/ray/serve/pipeline/deployment_method_node.py b/python/ray/serve/pipeline/deployment_method_node.py index 74072a26fe80..9b303a3e7356 100644 --- a/python/ray/serve/pipeline/deployment_method_node.py +++ b/python/ray/serve/pipeline/deployment_method_node.py @@ -21,7 +21,7 @@ def __init__( ): self._deployment = deployment self._deployment_method_name: str = deployment_method_name - self._deployment_node = other_args_to_resolve[PARENT_CLASS_NODE_KEY] + self._deployment_handle = other_args_to_resolve[PARENT_CLASS_NODE_KEY] super().__init__( method_args, method_kwargs, @@ -48,7 +48,7 @@ def _copy_impl( def _execute_impl(self, *args, **kwargs): """Executor of DeploymentMethodNode by ray.remote()""" # Execute with bound args. - method_body = getattr(self._deployment_node, self._deployment_method_name) + method_body = getattr(self._deployment_handle, self._deployment_method_name) return method_body.remote( *self._bound_args, **self._bound_kwargs, diff --git a/python/ray/serve/pipeline/deployment_node.py b/python/ray/serve/pipeline/deployment_node.py index dc6a254b2433..64630fcbf70c 100644 --- a/python/ray/serve/pipeline/deployment_node.py +++ b/python/ray/serve/pipeline/deployment_node.py @@ -3,11 +3,24 @@ from typing import Any, Callable, Dict, Optional, List, Tuple, Union from ray.experimental.dag import DAGNode, InputNode -from ray.serve.handle import RayServeLazySyncHandle, RayServeSyncHandle, RayServeHandle +from ray.serve.deployment_executor_node import DeploymentExecutorNode +from ray.serve.deployment_function_executor_node import ( + DeploymentFunctionExecutorNode, +) +from ray.serve.deployment_method_executor_node import ( + DeploymentMethodExecutorNode, +) +from ray.serve.handle import ( + RayServeLazySyncHandle, + RayServeSyncHandle, + RayServeHandle, +) from ray.serve.pipeline.deployment_method_node import DeploymentMethodNode from ray.serve.pipeline.deployment_function_node import DeploymentFunctionNode -from ray.serve.pipeline.constants import USE_SYNC_HANDLE_KEY -from ray.experimental.dag.constants import DAGNODE_TYPE_KEY, PARENT_CLASS_NODE_KEY +from ray.experimental.dag.constants import ( + DAGNODE_TYPE_KEY, + PARENT_CLASS_NODE_KEY, +) from ray.experimental.dag.format_utils import get_dag_node_str from ray.serve.deployment import Deployment, schema_to_deployment from ray.serve.deployment_graph import RayServeDAGHandle @@ -54,12 +67,24 @@ def __init__( # Thus we need convert all DeploymentNode used in init args into # deployment handles (executable and picklable) in ray serve DAG to make # serve DAG end to end executable. + # TODO(jiaodong): This part does some magic for DAGDriver and will throw + # error with weird pickle replace table error. Move this out. def replace_with_handle(node): if isinstance(node, DeploymentNode): return node._get_serve_deployment_handle( node._deployment, node._bound_other_args_to_resolve ) - elif isinstance(node, (DeploymentMethodNode, DeploymentFunctionNode)): + elif isinstance(node, DeploymentExecutorNode): + return node._deployment_handle + elif isinstance( + node, + ( + DeploymentMethodNode, + DeploymentMethodExecutorNode, + DeploymentFunctionNode, + DeploymentFunctionExecutorNode, + ), + ): from ray.serve.pipeline.json_serde import DAGNodeEncoder serve_dag_root_json = json.dumps(node, cls=DAGNodeEncoder) @@ -71,7 +96,15 @@ def replace_with_handle(node): ) = self.apply_functional( [deployment_init_args, deployment_init_kwargs], predictate_fn=lambda node: isinstance( - node, (DeploymentNode, DeploymentMethodNode, DeploymentFunctionNode) + node, + ( + DeploymentNode, + DeploymentMethodNode, + DeploymentFunctionNode, + DeploymentExecutorNode, + DeploymentFunctionExecutorNode, + DeploymentMethodExecutorNode, + ), ), apply_fn=replace_with_handle, ) @@ -117,9 +150,9 @@ def replace_with_handle(node): ray_actor_options=ray_actor_options, _internal=True, ) - self._deployment_handle: Union[ - RayServeLazySyncHandle, RayServeHandle, RayServeSyncHandle - ] = self._get_serve_deployment_handle(self._deployment, other_args_to_resolve) + self._deployment_handle: RayServeLazySyncHandle = ( + self._get_serve_deployment_handle(self._deployment, other_args_to_resolve) + ) def _copy_impl( self, @@ -165,20 +198,8 @@ def _get_serve_deployment_handle( return async handle only if user explicitly set USE_SYNC_HANDLE_KEY with value of False. """ - # TODO (jiaodong): Support configurable async handle - if USE_SYNC_HANDLE_KEY not in bound_other_args_to_resolve: - # Return sync RayServeLazySyncHandle - return RayServeLazySyncHandle(deployment.name) - elif bound_other_args_to_resolve.get(USE_SYNC_HANDLE_KEY) is True: - # Return sync RayServeSyncHandle - return deployment.get_handle(sync=True) - elif bound_other_args_to_resolve.get(USE_SYNC_HANDLE_KEY) is False: - # Return async RayServeHandle - return deployment.get_handle(sync=False) - else: - raise ValueError( - f"{USE_SYNC_HANDLE_KEY} should only be set with a boolean value." - ) + # TODO: (jiaodong) Support async handle + return RayServeLazySyncHandle(deployment.name) def _contains_input_node(self) -> bool: """Check if InputNode is used in children DAGNodes with current node @@ -234,7 +255,7 @@ def to_json(self) -> Dict[str, Any]: } @classmethod - def from_json(cls, input_json, object_hook=None): + def from_json(cls, input_json): assert input_json[DAGNODE_TYPE_KEY] == DeploymentNode.__name__ return cls( input_json["import_path"], diff --git a/python/ray/serve/pipeline/generate.py b/python/ray/serve/pipeline/generate.py index ec0a99c9d5f3..e6711560179d 100644 --- a/python/ray/serve/pipeline/generate.py +++ b/python/ray/serve/pipeline/generate.py @@ -1,3 +1,4 @@ +import json from typing import List from collections import OrderedDict @@ -11,9 +12,14 @@ from ray.experimental.dag.input_node import InputNode from ray.experimental.dag.utils import DAGNodeNameGenerator from ray.serve.deployment import Deployment +from ray.serve.deployment_graph import RayServeDAGHandle from ray.serve.pipeline.deployment_method_node import DeploymentMethodNode from ray.serve.pipeline.deployment_node import DeploymentNode from ray.serve.pipeline.deployment_function_node import DeploymentFunctionNode +from ray.serve.deployment_executor_node import DeploymentExecutorNode +from ray.serve.deployment_method_executor_node import DeploymentMethodExecutorNode +from ray.serve.deployment_function_executor_node import DeploymentFunctionExecutorNode +from ray.serve.pipeline.json_serde import DAGNodeEncoder def transform_ray_dag_to_serve_dag( @@ -95,6 +101,90 @@ def extractor(dag_node): return list(deployments.values()) +def transform_serve_dag_to_serve_executor_dag(serve_dag_root_node: DAGNode): + """Given a runnable serve dag with deployment init args and options + processed, transform into an equivalent, but minimal dag optimized for + execution. + """ + if isinstance(serve_dag_root_node, DeploymentNode): + return DeploymentExecutorNode( + serve_dag_root_node._deployment_handle, + serve_dag_root_node.get_args(), + serve_dag_root_node.get_kwargs(), + ) + elif isinstance(serve_dag_root_node, DeploymentMethodNode): + return DeploymentMethodExecutorNode( + # Deployment method handle + serve_dag_root_node._deployment_method_name, + serve_dag_root_node.get_args(), + serve_dag_root_node.get_kwargs(), + other_args_to_resolve=serve_dag_root_node.get_other_args_to_resolve(), + ) + elif isinstance(serve_dag_root_node, DeploymentFunctionNode): + return DeploymentFunctionExecutorNode( + serve_dag_root_node._deployment_handle, + serve_dag_root_node.get_args(), + serve_dag_root_node.get_kwargs(), + ) + else: + return serve_dag_root_node + + +def generate_executor_dag_driver_deployment( + serve_executor_dag_root_node: DAGNode, original_driver_deployment: Deployment +): + """Given a transformed minimal execution serve dag, and original DAGDriver + deployment, generate new DAGDriver deployment that uses new serve executor + dag as init_args. + + Args: + serve_executor_dag_root_node (DeploymentExecutorNode): Transformed + executor serve dag with only barebone deployment handles. + original_driver_deployment (Deployment): User's original DAGDriver + deployment that wrapped Ray DAG as init args. + Returns: + executor_dag_driver_deployment (Deployment): New DAGDriver deployment + with executor serve dag as init args. + """ + + def replace_with_handle(node): + if isinstance(node, DeploymentExecutorNode): + return node._deployment_handle + elif isinstance( + node, + ( + DeploymentMethodExecutorNode, + DeploymentFunctionExecutorNode, + ), + ): + serve_dag_root_json = json.dumps(node, cls=DAGNodeEncoder) + return RayServeDAGHandle(serve_dag_root_json) + + ( + replaced_deployment_init_args, + replaced_deployment_init_kwargs, + ) = serve_executor_dag_root_node.apply_functional( + [ + serve_executor_dag_root_node.get_args(), + serve_executor_dag_root_node.get_kwargs(), + ], + predictate_fn=lambda node: isinstance( + node, + ( + DeploymentExecutorNode, + DeploymentFunctionExecutorNode, + DeploymentMethodExecutorNode, + ), + ), + apply_fn=replace_with_handle, + ) + + return original_driver_deployment.options( + init_args=replaced_deployment_init_args, + init_kwargs=replaced_deployment_init_kwargs, + ) + + def get_pipeline_input_node(serve_dag_root_node: DAGNode): """Return the InputNode singleton node from serve dag, and throw exceptions if we didn't find any, or found more than one. diff --git a/python/ray/serve/pipeline/json_serde.py b/python/ray/serve/pipeline/json_serde.py index 392e4cdfbcfa..5a4085ab85f7 100644 --- a/python/ray/serve/pipeline/json_serde.py +++ b/python/ray/serve/pipeline/json_serde.py @@ -15,6 +15,10 @@ from ray.serve.pipeline.deployment_node import DeploymentNode from ray.serve.pipeline.deployment_method_node import DeploymentMethodNode from ray.serve.pipeline.deployment_function_node import DeploymentFunctionNode +from ray.serve.deployment_executor_node import DeploymentExecutorNode +from ray.serve.deployment_method_executor_node import DeploymentMethodExecutorNode +from ray.serve.deployment_function_executor_node import DeploymentFunctionExecutorNode + from ray.serve.schema import ( DeploymentSchema, ) @@ -125,6 +129,21 @@ def dagnode_from_json(input_json: Any) -> Union[DAGNode, RayServeHandle, Any]: that we perserve the same parent node. - .options() does not contain any DAGNode type """ + node_type_to_cls = { + # Ray DAG Inputs + InputNode.__name__: InputNode, + InputAttributeNode.__name__: InputAttributeNode, + # Ray DAG Nodes + ClassMethodNode.__name__: ClassMethodNode, + # Deployment transformation nodes + DeploymentNode.__name__: DeploymentNode, + DeploymentMethodNode.__name__: DeploymentMethodNode, + DeploymentFunctionNode.__name__: DeploymentFunctionNode, + # Deployment graph execution nodes + DeploymentExecutorNode.__name__: DeploymentExecutorNode, + DeploymentMethodExecutorNode.__name__: DeploymentMethodExecutorNode, + DeploymentFunctionExecutorNode.__name__: DeploymentFunctionExecutorNode, + } # Deserialize RayServeHandle type if SERVE_HANDLE_JSON_KEY in input_json: return serve_handle_from_json_dict(input_json) @@ -141,18 +160,8 @@ def dagnode_from_json(input_json: Any) -> Union[DAGNode, RayServeHandle, Any]: HandleOptions(input_json["handle_options_method_name"]), ) # Deserialize DAGNode type - elif input_json[DAGNODE_TYPE_KEY] == InputNode.__name__: - return InputNode.from_json(input_json) - elif input_json[DAGNODE_TYPE_KEY] == InputAttributeNode.__name__: - return InputAttributeNode.from_json(input_json) - elif input_json[DAGNODE_TYPE_KEY] == ClassMethodNode.__name__: - return ClassMethodNode.from_json(input_json) - elif input_json[DAGNODE_TYPE_KEY] == DeploymentNode.__name__: - return DeploymentNode.from_json(input_json) - elif input_json[DAGNODE_TYPE_KEY] == DeploymentMethodNode.__name__: - return DeploymentMethodNode.from_json(input_json) - elif input_json[DAGNODE_TYPE_KEY] == DeploymentFunctionNode.__name__: - return DeploymentFunctionNode.from_json(input_json) + elif input_json[DAGNODE_TYPE_KEY] in node_type_to_cls: + return node_type_to_cls[input_json[DAGNODE_TYPE_KEY]].from_json(input_json) else: # Class and Function nodes require original module as body. module_name, attr_name = parse_import_path(input_json["import_path"]) diff --git a/python/ray/serve/pipeline/tests/test_deployment_node.py b/python/ray/serve/pipeline/tests/test_deployment_node.py index 84e029f72375..e80397e133fe 100644 --- a/python/ray/serve/pipeline/tests/test_deployment_node.py +++ b/python/ray/serve/pipeline/tests/test_deployment_node.py @@ -43,7 +43,7 @@ async def get(self): return self.i -@pytest.mark.asyncio +@pytest.mark.skip(reason="async handle not enabled yet") async def test_simple_deployment_async(serve_instance): """Internal testing only for simple creation and execution. @@ -131,21 +131,6 @@ def test_no_input_node_as_init_args(): ) -def test_invalid_use_sync_handle(): - with pytest.raises( - ValueError, - match=f"{USE_SYNC_HANDLE_KEY} should only be set with a boolean value", - ): - _ = DeploymentNode( - Actor, - "test", - [], - {}, - {}, - other_args_to_resolve={USE_SYNC_HANDLE_KEY: {"options_a": "hii"}}, - ) - - def test_mix_sync_async_handle(serve_instance): # TODO: (jiaodong) Add complex multi-deployment tests from ray DAG. pass diff --git a/python/ray/serve/pipeline/tests/test_generate.py b/python/ray/serve/pipeline/tests/test_generate.py index f91629c22f28..60368026b018 100644 --- a/python/ray/serve/pipeline/tests/test_generate.py +++ b/python/ray/serve/pipeline/tests/test_generate.py @@ -2,14 +2,13 @@ import ray from ray import serve -from ray.serve.handle import RayServeLazySyncHandle from ray.experimental.dag import InputNode +from ray.serve.handle import RayServeLazySyncHandle from ray.serve.pipeline.generate import ( transform_ray_dag_to_serve_dag, extract_deployments_from_serve_dag, get_pipeline_input_node, ) -from ray.serve.pipeline.api import build from ray.serve.pipeline.tests.resources.test_modules import ( Model, NESTED_HANDLE_KEY, @@ -238,12 +237,21 @@ def test_get_pipeline_input_node(): get_pipeline_input_node(serve_dag) -def test_unique_name_reset_upon_build(): +def test_unique_name_reset_upon_build(serve_instance): ray_dag, _ = get_multi_instantiation_class_deployment_in_init_args_dag() - deployments = build(ray_dag) + with DAGNodeNameGenerator() as node_name_generator: + serve_root_dag = ray_dag.apply_recursive( + lambda node: transform_ray_dag_to_serve_dag(node, node_name_generator) + ) + deployments = extract_deployments_from_serve_dag(serve_root_dag) assert deployments[0].name == "Model" assert deployments[1].name == "Model_1" - deployments = build(ray_dag) + + with DAGNodeNameGenerator() as node_name_generator: + serve_root_dag = ray_dag.apply_recursive( + lambda node: transform_ray_dag_to_serve_dag(node, node_name_generator) + ) + deployments = extract_deployments_from_serve_dag(serve_root_dag) # Assert we don't keep increasing suffix id between build() calls assert deployments[0].name == "Model" assert deployments[1].name == "Model_1" diff --git a/python/ray/serve/schema.py b/python/ray/serve/schema.py index f193af35c102..ab8e86ca0f37 100644 --- a/python/ray/serve/schema.py +++ b/python/ray/serve/schema.py @@ -1,7 +1,11 @@ from pydantic import BaseModel, Field, Extra, root_validator, validator from typing import Union, Tuple, List, Dict from ray._private.runtime_env.packaging import parse_uri -from ray.serve.common import DeploymentStatus, DeploymentStatusInfo +from ray.serve.common import ( + DeploymentStatusInfo, + ApplicationStatusInfo, + StatusOverview, +) from ray.serve.utils import DEFAULT @@ -393,46 +397,29 @@ def import_path_format_valid(cls, v: str): ) -class DeploymentStatusSchema(BaseModel, extra=Extra.forbid): - name: str = Field(..., description="The deployment's name.") - status: DeploymentStatus = Field( - default=None, description="The deployment's status." +class ServeStatusSchema(BaseModel, extra=Extra.forbid): + app_status: ApplicationStatusInfo = Field( + ..., + description=( + "Describes if the Serve application is DEPLOYING, if the " + "DEPLOY_FAILED, or if the app is RUNNING. Includes a timestamp of " + "when the application was deployed." + ), ) - message: str = Field( - default="", description="Information about the deployment's status." + deployment_statuses: List[DeploymentStatusInfo] = Field( + default=[], + description=( + "List of statuses for all the deployments running in this Serve " + "application. Each status contains the deployment name, the " + "deployment's status, and a message providing extra context on " + "the status." + ), ) -class ServeApplicationStatusSchema(BaseModel, extra=Extra.forbid): - statuses: List[DeploymentStatusSchema] = Field(...) - - -def status_info_to_schema( - deployment_name: str, status_info: Union[DeploymentStatusInfo, Dict] -) -> DeploymentStatusSchema: - if isinstance(status_info, DeploymentStatusInfo): - return DeploymentStatusSchema( - name=deployment_name, status=status_info.status, message=status_info.message - ) - elif isinstance(status_info, dict): - return DeploymentStatusSchema( - name=deployment_name, - status=status_info["status"], - message=status_info["message"], - ) - else: - raise TypeError( - f"Got {type(status_info)} as status_info's " - "type. Expected status_info to be either a " - "DeploymentStatusInfo or a dictionary." - ) - - -def serve_application_status_to_schema( - status_infos: Dict[str, Union[DeploymentStatusInfo, Dict]] -) -> ServeApplicationStatusSchema: - schemas = [ - status_info_to_schema(deployment_name, status_info) - for deployment_name, status_info in status_infos.items() - ] - return ServeApplicationStatusSchema(statuses=schemas) +def serve_status_to_schema(serve_status: StatusOverview) -> ServeStatusSchema: + + return ServeStatusSchema( + app_status=serve_status.app_status, + deployment_statuses=serve_status.deployment_statuses, + ) diff --git a/python/ray/serve/scripts.py b/python/ray/serve/scripts.py index c2265bbfc2ae..6b7cd9241058 100644 --- a/python/ray/serve/scripts.py +++ b/python/ray/serve/scripts.py @@ -1,12 +1,11 @@ #!/usr/bin/env python -import json import os -import pathlib -import click -import time import sys -from typing import Optional, Union +import time import yaml +import click +import pathlib +from typing import Optional, Union import ray from ray._private.utils import import_attr @@ -350,7 +349,7 @@ def config(address: str): def status(address: str): app_status = ServeSubmissionClient(address).get_status() if app_status is not None: - print(json.dumps(app_status["statuses"], indent=4)) + print(yaml.safe_dump(app_status, default_flow_style=False, sort_keys=False)) @cli.command( diff --git a/python/ray/serve/tests/test_cli.py b/python/ray/serve/tests/test_cli.py index 33f08cec26ef..62ac55aa304d 100644 --- a/python/ray/serve/tests/test_cli.py +++ b/python/ray/serve/tests/test_cli.py @@ -1,11 +1,11 @@ -import yaml -import json import os -import subprocess import sys +import time +import yaml import signal import pytest import requests +import subprocess from tempfile import NamedTemporaryFile import ray @@ -107,6 +107,7 @@ def test_deploy(ray_start_stop): assert success_message_fragment in deploy_response for name, deployment_config in expected_deployments.items(): + # New deployments must be deployed wait_for_condition( lambda: ( requests.get(f"{request_url}{name}").text @@ -115,6 +116,12 @@ def test_deploy(ray_start_stop): timeout=15, ) + # Outdated deployments should be deleted + wait_for_condition( + lambda: len(serve.list_deployments()) == len(expected_deployments), + timeout=15, + ) + running_deployments = serve.list_deployments() # Check that running deployment names match expected deployment names @@ -209,15 +216,21 @@ def test_status(ray_start_stop): subprocess.check_output(["serve", "deploy", config_file_name]) status_response = subprocess.check_output(["serve", "status"]) - statuses = json.loads(status_response) + serve_status = yaml.safe_load(status_response) expected_deployments = {"shallow", "deep", "one"} - for status in statuses: + for status in serve_status["deployment_statuses"]: expected_deployments.remove(status["name"]) assert status["status"] in {"HEALTHY", "UPDATING"} assert "message" in status assert len(expected_deployments) == 0 + assert serve_status["app_status"]["status"] in {"DEPLOYING", "RUNNING"} + wait_for_condition( + lambda: time.time() > serve_status["app_status"]["deployment_timestamp"], + timeout=2, + ) + @pytest.mark.skipif(sys.platform == "win32", reason="File path incorrect on Windows.") def test_delete(ray_start_stop): diff --git a/python/ray/serve/tests/test_common.py b/python/ray/serve/tests/test_common.py index beda91f42486..e3284c0fc044 100644 --- a/python/ray/serve/tests/test_common.py +++ b/python/ray/serve/tests/test_common.py @@ -1,7 +1,20 @@ +import time import pytest from ray.serve.utils import get_random_letters -from ray.serve.common import ReplicaName +from ray.serve.common import ( + ReplicaName, + StatusOverview, + DeploymentStatus, + DeploymentStatusInfo, + ApplicationStatus, + ApplicationStatusInfo, +) +from ray.serve.generated.serve_pb2 import ( + StatusOverview as StatusOverviewProto, + DeploymentStatusInfo as DeploymentStatusInfoProto, + ApplicationStatusInfo as ApplicationStatusInfoProto, +) def test_replica_tag_formatting(): @@ -48,6 +61,130 @@ def test_is_replica_name(): ) +class TestDeploymentStatusInfo: + def test_name_required(self): + with pytest.raises(TypeError): + DeploymentStatusInfo(status=DeploymentStatus.HEALTHY) + + def test_deployment_status_required(self): + with pytest.raises(TypeError): + DeploymentStatusInfo(name="test_name") + + @pytest.mark.parametrize("status", list(DeploymentStatus)) + def test_proto(self, status): + deployment_status_info = DeploymentStatusInfo( + name="test_name", status=status, message="context about status" + ) + serialized_proto = deployment_status_info.to_proto().SerializeToString() + deserialized_proto = DeploymentStatusInfoProto.FromString(serialized_proto) + reconstructed_info = DeploymentStatusInfo.from_proto(deserialized_proto) + + assert deployment_status_info == reconstructed_info + + +class TestApplicationStatusInfo: + def test_application_status_required(self): + with pytest.raises(TypeError): + ApplicationStatusInfo( + message="context about status", deployment_timestamp=time.time() + ) + + @pytest.mark.parametrize("status", list(ApplicationStatus)) + def test_proto(self, status): + serve_application_status_info = ApplicationStatusInfo( + status=status, + message="context about status", + deployment_timestamp=time.time(), + ) + serialized_proto = serve_application_status_info.to_proto().SerializeToString() + deserialized_proto = ApplicationStatusInfoProto.FromString(serialized_proto) + reconstructed_info = ApplicationStatusInfo.from_proto(deserialized_proto) + + assert serve_application_status_info == reconstructed_info + + +class TestStatusOverview: + def get_valid_serve_application_status_info(self): + return ApplicationStatusInfo( + status=ApplicationStatus.RUNNING, + message="", + deployment_timestamp=time.time(), + ) + + def test_app_status_required(self): + with pytest.raises(TypeError): + StatusOverview(deployment_statuses=[]) + + def test_empty_list_valid(self): + """Should be able to create StatusOverview with no deployment statuses.""" + + # Check default is empty list + status_info = StatusOverview( + app_status=self.get_valid_serve_application_status_info() + ) + status_info.deployment_statuses == [] + + # Ensure empty list can be passed in explicitly + status_info = StatusOverview( + app_status=self.get_valid_serve_application_status_info(), + deployment_statuses=[], + ) + status_info.deployment_statuses == [] + + def test_equality_mismatched_deployment_statuses(self): + """Check that StatusOverviews with different numbers of statuses are unequal.""" + + status_info_few_deployments = StatusOverview( + app_status=self.get_valid_serve_application_status_info(), + deployment_statuses=[ + DeploymentStatusInfo(name="1", status=DeploymentStatus.HEALTHY), + DeploymentStatusInfo(name="2", status=DeploymentStatus.UNHEALTHY), + ], + ) + + status_info_many_deployments = StatusOverview( + app_status=self.get_valid_serve_application_status_info(), + deployment_statuses=[ + DeploymentStatusInfo(name="1", status=DeploymentStatus.HEALTHY), + DeploymentStatusInfo(name="2", status=DeploymentStatus.UNHEALTHY), + DeploymentStatusInfo(name="3", status=DeploymentStatus.UNHEALTHY), + DeploymentStatusInfo(name="4", status=DeploymentStatus.UPDATING), + ], + ) + + assert status_info_few_deployments != status_info_many_deployments + + @pytest.mark.parametrize("application_status", list(ApplicationStatus)) + def test_proto(self, application_status): + status_info = StatusOverview( + app_status=ApplicationStatusInfo( + status=application_status, + message="context about this status", + deployment_timestamp=time.time(), + ), + deployment_statuses=[ + DeploymentStatusInfo( + name="name1", + status=DeploymentStatus.UPDATING, + message="deployment updating", + ), + DeploymentStatusInfo( + name="name2", status=DeploymentStatus.HEALTHY, message="" + ), + DeploymentStatusInfo( + name="name3", + status=DeploymentStatus.UNHEALTHY, + message="this deployment is unhealthy", + ), + ], + ) + serialized_proto = status_info.to_proto().SerializeToString() + deserialized_proto = StatusOverviewProto.FromString(serialized_proto) + reconstructed_info = StatusOverview.from_proto(deserialized_proto) + + assert status_info == reconstructed_info + + if __name__ == "__main__": import sys diff --git a/python/ray/serve/tests/test_pipeline_dag.py b/python/ray/serve/tests/test_pipeline_dag.py index 36d00ced6a67..4e8edff8e532 100644 --- a/python/ray/serve/tests/test_pipeline_dag.py +++ b/python/ray/serve/tests/test_pipeline_dag.py @@ -156,16 +156,23 @@ def func_1(input): def func_2(input): return input * 2 + @serve.deployment + def func_3(input): + return input * 3 + with InputNode() as dag_input: output_1 = func_1.bind(dag_input) output_2 = func_2.bind(dag_input) - serve_dag = combine.bind(output_1, output_2) + output_3 = func_3.bind(output_2) + ray_dag = combine.bind(output_1, output_2, kwargs_output=output_3) with pytest.raises(ValueError, match="Please provide a driver class"): - _ = serve.run(serve_dag) + _ = serve.run(ray_dag) - handle = serve.run(DAGDriver.bind(serve_dag, http_adapter=json_resolver)) - assert ray.get(handle.predict.remote(2)) == 6 # 2 + 2*2 - assert requests.post("http://127.0.0.1:8000/", json=2).json() == 6 + serve_dag = DAGDriver.bind(ray_dag, http_adapter=json_resolver) + + handle = serve.run(serve_dag) + assert ray.get(handle.predict.remote(2)) == 18 # 2 + 2*2 + (2*2) * 3 + assert requests.post("http://127.0.0.1:8000/", json=2).json() == 18 @pytest.mark.parametrize("use_build", [False, True]) diff --git a/python/ray/serve/tests/test_schema.py b/python/ray/serve/tests/test_schema.py index d26ac1da595b..e3b51a2534f1 100644 --- a/python/ray/serve/tests/test_schema.py +++ b/python/ray/serve/tests/test_schema.py @@ -1,24 +1,26 @@ import sys -import requests +import time import pytest +import requests from pydantic import ValidationError from typing import List, Dict import ray from ray import serve +from ray.serve.common import ( + StatusOverview, + DeploymentStatusInfo, + ApplicationStatusInfo, +) from ray.serve.schema import ( RayActorOptionsSchema, DeploymentSchema, - DeploymentStatusSchema, ServeApplicationSchema, - ServeApplicationStatusSchema, - status_info_to_schema, - serve_application_status_to_schema, + ServeStatusSchema, + serve_status_to_schema, ) from ray.util.accelerators.accelerators import NVIDIA_TESLA_V100, NVIDIA_TESLA_P4 from ray.serve.config import AutoscalingConfig -from ray.serve.common import DeploymentStatus, DeploymentStatusInfo -from ray.serve.api import get_deployment_statuses from ray.serve.deployment import ( deployment_to_schema, schema_to_deployment, @@ -503,91 +505,49 @@ def test_serve_application_invalid_import_path(self, path): ServeApplicationSchema.parse_obj(serve_application_schema) -class TestDeploymentStatusSchema: - def get_valid_deployment_status_schema(self): - return { - "deployment_1": DeploymentStatusInfo(DeploymentStatus.HEALTHY), - "deployment_2": DeploymentStatusInfo( - DeploymentStatus.UNHEALTHY, "This is an unhealthy deployment." +class TestServeStatusSchema: + def get_valid_serve_status_schema(self): + return StatusOverview( + app_status=ApplicationStatusInfo( + status="DEPLOYING", + message="", + deployment_timestamp=time.time(), ), - "deployment_3": DeploymentStatusInfo(DeploymentStatus.UPDATING), - } - - def test_valid_deployment_status_schema(self): - # Ensure valid DeploymentStatusSchemas can be generated - - deployment_status_schemas = self.get_valid_deployment_status_schema() - - for name, status_info in deployment_status_schemas.items(): - status_info_to_schema(name, status_info) - - def test_invalid_status(self): - # Ensure a DeploymentStatusSchema cannot be initialized with an invalid status - - status_info = { - "status": "nonexistent status", - "message": "welcome to nonexistence", - } - with pytest.raises(ValidationError): - status_info_to_schema("deployment name", status_info) - - def test_extra_fields_invalid_deployment_status_schema(self): - # Undefined fields should be forbidden in the schema - - deployment_status_schemas = self.get_valid_deployment_status_schema() - - # Schema should be createable with valid fields - for name, status_info in deployment_status_schemas.items(): - DeploymentStatusSchema( - name=name, status=status_info.status, message=status_info.message - ) - - # Schema should raise error when a nonspecified field is included - for name, status_info in deployment_status_schemas.items(): - with pytest.raises(ValidationError): - DeploymentStatusSchema( - name=name, - status=status_info.status, - message=status_info.message, - fake_field=None, - ) - - -class TestServeApplicationStatusSchema: - def get_valid_serve_application_status_schema(self): - return { - "deployment_1": {"status": "HEALTHY", "message": ""}, - "deployment_2": { - "status": "UNHEALTHY", - "message": "this deployment is deeply unhealthy", - }, - } + deployment_statuses=[ + DeploymentStatusInfo( + name="deployment_1", + status="HEALTHY", + message="", + ), + DeploymentStatusInfo( + name="deployment_2", + status="UNHEALTHY", + message="this deployment is deeply unhealthy", + ), + ], + ) - def test_valid_serve_application_status_schema(self): - # Ensure a valid ServeApplicationStatusSchema can be generated + def test_valid_serve_status_schema(self): + # Ensure a valid ServeStatusSchema can be generated - serve_application_status_schema = ( - self.get_valid_serve_application_status_schema() - ) - serve_application_status_to_schema(serve_application_status_schema) + serve_status_schema = self.get_valid_serve_status_schema() + serve_status_to_schema(serve_status_schema) - def test_extra_fields_invalid_serve_application_status_schema(self): + def test_extra_fields_invalid_serve_status_schema(self): # Undefined fields should be forbidden in the schema - serve_application_status_schema = ( - self.get_valid_serve_application_status_schema() - ) + serve_status_schema = self.get_valid_serve_status_schema() # Schema should be createable with valid fields - serve_application_status_to_schema(serve_application_status_schema) + serve_status_to_schema(serve_status_schema) # Schema should raise error when a nonspecified field is included with pytest.raises(ValidationError): - statuses = [ - status_info_to_schema(name, status_info) - for name, status_info in serve_application_status_schema.items() - ] - ServeApplicationStatusSchema(statuses=statuses, fake_field=None) + ServeStatusSchema( + app_status=serve_status_schema.app_status, + deployment_statuses=[], + fake_field=None, + ) # This function is defined globally to be accessible via import path @@ -680,13 +640,13 @@ def f2(): f1._func_or_class = "ray.serve.tests.test_schema.global_f" f2._func_or_class = "ray.serve.tests.test_schema.global_f" - serve.start() + client = serve.start() f1.deploy() f2.deploy() # Check statuses - statuses = serve_application_status_to_schema(get_deployment_statuses()).statuses + statuses = serve_status_to_schema(client.get_serve_status()).deployment_statuses deployment_names = {"f1", "f2"} for deployment_status in statuses: assert deployment_status.status in {"UPDATING", "HEALTHY"} diff --git a/python/ray/serve/tests/test_standalone2.py b/python/ray/serve/tests/test_standalone2.py index 60e231bf1eb4..b93c84249bde 100644 --- a/python/ray/serve/tests/test_standalone2.py +++ b/python/ray/serve/tests/test_standalone2.py @@ -195,6 +195,26 @@ def controller_died(handle): ray.shutdown() +def test_get_serve_status(shutdown_ray): + + ray.init() + client = serve.start() + + @serve.deployment + def f(*args): + return "Hello world" + + f.deploy() + + status_info_1 = client.get_serve_status() + assert status_info_1.app_status.status == "RUNNING" + assert status_info_1.deployment_statuses[0].name == "f" + assert status_info_1.deployment_statuses[0].status in {"UPDATING", "HEALTHY"} + + serve.shutdown() + ray.shutdown() + + def test_shutdown_remote(start_and_shutdown_ray_cli): """Check that serve.shutdown() works on a remote Ray cluster.""" diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index d32c9258addc..df0be478ed67 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -59,7 +59,6 @@ py_test_module_list( "test_healthcheck.py", "test_kill_raylet_signal_log.py", "test_memstat.py", - "test_mldataset.py", ], size = "medium", extra_srcs = SRCS, @@ -221,6 +220,7 @@ py_test_module_list( "test_threaded_actor.py", "test_stress_failure.py", "test_reconstruction.py", + "test_reconstruction_2.py", "test_failure_2.py", "test_failure_3.py", "test_chaos.py", diff --git a/python/ray/tests/test_actor_pool.py b/python/ray/tests/test_actor_pool.py index b02cdf331abc..27d40e505541 100644 --- a/python/ray/tests/test_actor_pool.py +++ b/python/ray/tests/test_actor_pool.py @@ -1,3 +1,4 @@ +import asyncio import sys import time import pytest @@ -101,11 +102,15 @@ def double(self, x): def test_map_gh23107(init): + sleep_time = 40 + # Reference - https://github.com/ray-project/ray/issues/23107 @ray.remote class DummyActor: async def identity(self, s): - return s + if s == 6: + await asyncio.sleep(sleep_time) + return s, time.time() def func(a, v): return a.identity.remote(v) @@ -114,13 +119,21 @@ def func(a, v): pool_map = ActorPool([DummyActor.remote() for i in range(2)]) pool_map.submit(func, 6) + start_time = time.time() gen = pool_map.map(func, map_values) - assert list(gen) == [1, 2, 3, 4, 5] + assert all(elem[0] in [1, 2, 3, 4, 5] for elem in list(gen)) + assert all( + abs(elem[1] - start_time) < sleep_time in [1, 2, 3, 4, 5] for elem in list(gen) + ) pool_map_unordered = ActorPool([DummyActor.remote() for i in range(2)]) pool_map_unordered.submit(func, 6) - gen = pool_map_unordered.map(func, map_values) - assert all(elem in [1, 2, 3, 4, 5] for elem in list(gen)) + start_time = time.time() + gen = pool_map_unordered.map_unordered(func, map_values) + assert all(elem[0] in [1, 2, 3, 4, 5] for elem in list(gen)) + assert all( + abs(elem[1] - start_time) < sleep_time in [1, 2, 3, 4, 5] for elem in list(gen) + ) def test_get_next_timeout(init): diff --git a/python/ray/tests/test_cli_patterns/test_ray_up.txt b/python/ray/tests/test_cli_patterns/test_ray_up.txt index a9da9ee7d70a..9c72fac96746 100644 --- a/python/ray/tests/test_cli_patterns/test_ray_up.txt +++ b/python/ray/tests/test_cli_patterns/test_ray_up.txt @@ -32,7 +32,7 @@ Acquiring an up-to-date head node \[3/7\] No worker file mounts to sync New status: setting-up \[4/7\] Running initialization commands - \[5/7\] Initalizing command runner + \[5/7\] Initializing command runner \[6/7\] Running setup commands \(0/4\) echo a \(1/4\) echo b diff --git a/python/ray/tests/test_cli_patterns/test_ray_up_docker.txt b/python/ray/tests/test_cli_patterns/test_ray_up_docker.txt index a9da9ee7d70a..9c72fac96746 100644 --- a/python/ray/tests/test_cli_patterns/test_ray_up_docker.txt +++ b/python/ray/tests/test_cli_patterns/test_ray_up_docker.txt @@ -32,7 +32,7 @@ Acquiring an up-to-date head node \[3/7\] No worker file mounts to sync New status: setting-up \[4/7\] Running initialization commands - \[5/7\] Initalizing command runner + \[5/7\] Initializing command runner \[6/7\] Running setup commands \(0/4\) echo a \(1/4\) echo b diff --git a/python/ray/tests/test_cli_patterns/test_ray_up_record.txt b/python/ray/tests/test_cli_patterns/test_ray_up_record.txt index f313e2db6ccb..401c0bec06bf 100644 --- a/python/ray/tests/test_cli_patterns/test_ray_up_record.txt +++ b/python/ray/tests/test_cli_patterns/test_ray_up_record.txt @@ -58,7 +58,7 @@ .+\.py.*Running `echo init` .+\.py.*Full command is `ssh.+` .+\.py.*NodeUpdater: i-.+: Initialization commands succeeded \[LogTimer=.+\] -.+\.py.*\[5/7\] Initalizing command runner +.+\.py.*\[5/7\] Initializing command runner .+\.py.*\[6/7\] Running setup commands .+\.py.*\(0/4\) echo a .+\.py.*Running `echo a` diff --git a/python/ray/tests/test_gcs_fault_tolerance.py b/python/ray/tests/test_gcs_fault_tolerance.py index 52a3c3c3d7a0..2e9025f41a7a 100644 --- a/python/ray/tests/test_gcs_fault_tolerance.py +++ b/python/ray/tests/test_gcs_fault_tolerance.py @@ -9,16 +9,11 @@ from ray._private.test_utils import ( generate_system_config_map, - run_string_as_driver_nonblocking, wait_for_condition, wait_for_pid_to_exit, convert_actor_state, ) -import logging - -logger = logging.getLogger(__name__) - @ray.remote class Increase: @@ -352,72 +347,31 @@ def ready(self): ], indirect=True, ) -def test_py_resubscription(tmp_path, ray_start_regular_with_external_redis): - # This test is to ensure python pubsub works - from filelock import FileLock - - lock_file1 = str(tmp_path / "lock1") - lock1 = FileLock(lock_file1) - lock1.acquire() - - lock_file2 = str(tmp_path / "lock2") - lock2 = FileLock(lock_file2) - - script = f""" -from filelock import FileLock -import ray - -@ray.remote -def f(): - print("OK1", flush=True) - # wait until log_monitor push this - from time import sleep - sleep(2) - lock1 = FileLock(r"{lock_file1}") - lock2 = FileLock(r"{lock_file2}") - - lock2.acquire() - lock1.acquire() - - # wait until log_monitor push this - from time import sleep - sleep(2) - print("OK2", flush=True) - -ray.init(address='auto') -ray.get(f.remote()) -ray.shutdown() -""" - proc = run_string_as_driver_nonblocking(script) - - def condition(): - import filelock +def test_detached_actor_restarts(ray_start_regular_with_external_redis): + # Detached actors are owned by GCS. This test is to ensure detached actors + # can restart even GCS restarts. - try: - lock2.acquire(timeout=1) - except filelock.Timeout: - return True + @ray.remote + class A: + def ready(self): + import os - lock2.release() - return False + return os.getpid() - # make sure the script has printed "OK1" - wait_for_condition(condition, timeout=10) + a = A.options(name="a", lifetime="detached", max_restarts=-1).remote() + pid = ray.get(a.ready.remote()) ray.worker._global_node.kill_gcs_server() - import time - - time.sleep(2) + p = psutil.Process(pid) + p.kill() ray.worker._global_node.start_gcs_server() - lock1.release() - proc.wait() - output = proc.stdout.read() - # Print logs which are useful for debugging in CI - print("=================== OUTPUTS ============") - print(output.decode()) - assert b"OK1" in output - assert b"OK2" in output + while True: + try: + assert ray.get(a.ready.remote()) != pid + break + except ray.exceptions.RayActorError: + continue @pytest.mark.parametrize("auto_reconnect", [True, False]) diff --git a/python/ray/tests/test_gcs_pubsub.py b/python/ray/tests/test_gcs_pubsub.py index b2d47bad7e59..8d0942abd83f 100644 --- a/python/ray/tests/test_gcs_pubsub.py +++ b/python/ray/tests/test_gcs_pubsub.py @@ -1,7 +1,6 @@ import sys import threading -import ray from ray._private.gcs_pubsub import ( GcsPublisher, GcsErrorSubscriber, @@ -35,72 +34,6 @@ def test_publish_and_subscribe_error_info(ray_start_regular): subscriber.close() -def test_publish_and_subscribe_error_info_ft(ray_start_regular_with_external_redis): - address_info = ray_start_regular_with_external_redis - gcs_server_addr = address_info["gcs_address"] - from threading import Barrier, Thread - - subscriber = GcsErrorSubscriber(address=gcs_server_addr) - subscriber.subscribe() - - publisher = GcsPublisher(address=gcs_server_addr) - - err1 = ErrorTableData(error_message="test error message 1") - err2 = ErrorTableData(error_message="test error message 2") - err3 = ErrorTableData(error_message="test error message 3") - err4 = ErrorTableData(error_message="test error message 4") - b = Barrier(3) - - def publisher_func(): - print("Publisher HERE") - publisher.publish_error(b"aaa_id", err1) - publisher.publish_error(b"bbb_id", err2) - - b.wait() - - print("Publisher HERE") - # Wait fo subscriber to subscribe first. - # It's ok to loose log messages. - from time import sleep - - sleep(5) - publisher.publish_error(b"aaa_id", err3) - print("pub err1") - publisher.publish_error(b"bbb_id", err4) - print("pub err2") - print("DONE") - - def subscriber_func(): - print("Subscriber HERE") - assert subscriber.poll() == (b"aaa_id", err1) - assert subscriber.poll() == (b"bbb_id", err2) - - b.wait() - assert subscriber.poll() == (b"aaa_id", err3) - print("sub err1") - assert subscriber.poll() == (b"bbb_id", err4) - print("sub err2") - - subscriber.close() - print("DONE") - - t1 = Thread(target=publisher_func) - t2 = Thread(target=subscriber_func) - t1.start() - t2.start() - b.wait() - - ray.worker._global_node.kill_gcs_server() - from time import sleep - - sleep(1) - ray.worker._global_node.start_gcs_server() - sleep(1) - - t1.join() - t2.join() - - @pytest.mark.asyncio async def test_aio_publish_and_subscribe_error_info(ray_start_regular): address_info = ray_start_regular @@ -121,61 +54,6 @@ async def test_aio_publish_and_subscribe_error_info(ray_start_regular): await subscriber.close() -@pytest.mark.asyncio -async def test_aio_publish_and_subscribe_error_info_ft( - ray_start_regular_with_external_redis, -): - address_info = ray_start_regular_with_external_redis - gcs_server_addr = address_info["gcs_address"] - - subscriber = GcsAioErrorSubscriber(address=gcs_server_addr) - await subscriber.subscribe() - - err1 = ErrorTableData(error_message="test error message 1") - err2 = ErrorTableData(error_message="test error message 2") - err3 = ErrorTableData(error_message="test error message 3") - err4 = ErrorTableData(error_message="test error message 4") - - def restart_gcs_server(): - import asyncio - - asyncio.set_event_loop(asyncio.new_event_loop()) - from time import sleep - - publisher = GcsAioPublisher(address=gcs_server_addr) - asyncio.get_event_loop().run_until_complete( - publisher.publish_error(b"aaa_id", err1) - ) - asyncio.get_event_loop().run_until_complete( - publisher.publish_error(b"bbb_id", err2) - ) - - # wait until subscribe consume everything - sleep(5) - ray.worker._global_node.kill_gcs_server() - sleep(1) - ray.worker._global_node.start_gcs_server() - # wait until subscriber resubscribed - sleep(5) - - asyncio.get_event_loop().run_until_complete( - publisher.publish_error(b"aaa_id", err3) - ) - asyncio.get_event_loop().run_until_complete( - publisher.publish_error(b"bbb_id", err4) - ) - - t1 = threading.Thread(target=restart_gcs_server) - t1.start() - assert await subscriber.poll() == (b"aaa_id", err1) - assert await subscriber.poll() == (b"bbb_id", err2) - assert await subscriber.poll() == (b"aaa_id", err3) - assert await subscriber.poll() == (b"bbb_id", err4) - - await subscriber.close() - t1.join() - - def test_publish_and_subscribe_logs(ray_start_regular): address_info = ray_start_regular gcs_server_addr = address_info["gcs_address"] diff --git a/python/ray/tests/test_mldataset.py b/python/ray/tests/test_mldataset.py deleted file mode 100644 index 9faba8a3063a..000000000000 --- a/python/ray/tests/test_mldataset.py +++ /dev/null @@ -1,147 +0,0 @@ -import ray.util.iter as parallel_it -import ray.util.data as ml_data -import pytest - -import pyarrow as pa -import pyarrow.parquet as pq -import pandas as pd -import numpy as np -import os - - -def test_read_parquet(ray_start_regular_shared, tmp_path): - df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]}) - table = pa.Table.from_pandas(df1) - pq.write_table(table, os.path.join(tmp_path, "test1.parquet")) - df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]}) - table = pa.Table.from_pandas(df2) - pq.write_table(table, os.path.join(tmp_path, "test2.parquet")) - - # without columns - ds = ml_data.read_parquet(tmp_path, num_shards=2) - result = list(ds.gather_sync()) - assert df1.equals(result[0]) - assert df2.equals(result[1]) - - # with columns one - ds = ml_data.read_parquet(tmp_path, num_shards=2, columns=["one"]) - result = list(ds.gather_sync()) - assert df1[["one"]].equals(result[0]) - assert df2[["one"]].equals(result[1]) - - # with columns two - ds = ml_data.read_parquet(tmp_path, num_shards=2, columns=["two"]) - result = list(ds.gather_sync()) - assert df1[["two"]].equals(result[0]) - assert df2[["two"]].equals(result[1]) - - -def test_from_parallel_it(ray_start_regular_shared): - para_it = parallel_it.from_range(4).for_each(lambda x: [x]) - ds = ml_data.from_parallel_iter(para_it, batch_size=2) - assert repr(ds) == ( - "MLDataset[from_range[4, shards=2].for_each().batch(2).to_pandas()]" - ) - collected = list(ds.gather_sync()) - assert len(collected) == 2 - assert all(d.shape == (2, 1) for d in collected) - expected = para_it.flatten().batch(2).gather_sync().flatten() - flattened = ds.gather_sync().for_each(lambda x: x[0].to_list()).flatten() - assert list(flattened) == list(expected) - - -def test_batch(ray_start_regular_shared): - para_it = parallel_it.from_range(16).for_each(lambda x: [x]) - ds = ml_data.from_parallel_iter(para_it, batch_size=2) - collected = list(ds.gather_sync()) - assert len(collected) == 8 - assert all(d.shape == (2, 1) for d in collected) - - ds = ds.batch(4) - assert repr(ds) == ( - "MLDataset[from_range[16, shards=2]" - ".for_each().batch(2).to_pandas().batch(4)]" - ) - collected = list(ds.gather_sync()) - assert len(collected) == 4 - assert all(d.shape == (4, 1) for d in collected) - expected = para_it.flatten().batch(4).gather_sync().flatten() - flattened = ds.gather_sync().for_each(lambda x: x[0].to_list()).flatten() - assert list(flattened) == list(expected) - - -def test_local_shuffle(ray_start_regular_shared): - para_it = parallel_it.from_range(100).for_each(lambda x: [x]) - - # batch_size larger than 1 and shuffle_buffer_size larger than 1 - ds = ml_data.from_parallel_iter(para_it, batch_size=10) - ds1 = ds.local_shuffle(shuffle_buffer_size=5) - ds2 = ds.local_shuffle(shuffle_buffer_size=5) - - l1 = list(ds1.gather_sync()) - l2 = list(ds2.gather_sync()) - assert not all(df1.equals(df2) for df1, df2 in zip(l1, l2)) - - # batch_size equals 1 and shuffle_buffer_size larger than 1 - ds = ml_data.from_parallel_iter(para_it, batch_size=1) - ds1 = ds.local_shuffle(shuffle_buffer_size=5) - ds2 = ds.local_shuffle(shuffle_buffer_size=5) - - l1 = list(ds1.gather_sync()) - l2 = list(ds2.gather_sync()) - assert not all(df1.equals(df2) for df1, df2 in zip(l1, l2)) - - # batch_size equals 1 and shuffle_buffer_size equals 1 - ds = ml_data.from_parallel_iter(para_it, batch_size=1) - ds1 = ds.local_shuffle(shuffle_buffer_size=1) - ds2 = ds.local_shuffle(shuffle_buffer_size=1) - - l1 = list(ds1.gather_sync()) - l2 = list(ds2.gather_sync()) - assert all(df1.equals(df2) for df1, df2 in zip(l1, l2)) - - -def test_union(ray_start_regular_shared): - para_it1 = parallel_it.from_range(4, 2, False).for_each(lambda x: [x]) - ds1 = ml_data.from_parallel_iter(para_it1, True, 2, False) - para_it2 = parallel_it.from_range(4, 2, True).for_each(lambda x: [x]) - ds2 = ml_data.from_parallel_iter(para_it2, True, 2, True) - - with pytest.raises(TypeError) as ex: - ds1.union(ds2) - assert "two MLDataset which have different repeated type" in str(ex.value) - - # union two MLDataset with same batch size - para_it2 = parallel_it.from_range(4, 2, False).for_each(lambda x: [x]) - ds2 = ml_data.from_parallel_iter(para_it2, True, 2, False) - ds = ds1.union(ds2) - assert ds.batch_size == 2 - - # union two MLDataset with different batch size - para_it2 = parallel_it.from_range(4, 2, False).for_each(lambda x: [x]) - ds2 = ml_data.from_parallel_iter(para_it2, True, 1, False) - ds = ds1.union(ds2) - # batch_size 0 means batch_size unknown - assert ds.batch_size == 0 - - -@pytest.mark.skipif( - True, reason="Broken on all platforms (incorrect use of gather_sync())" -) -def test_from_modin(ray_start_regular_shared): - try: - import modin.pandas as pd - except ImportError: - pytest.mark.skip(reason="Modin is not installed") - return - - df = pd.DataFrame(np.random.randint(0, 100, size=(2 ** 8, 16))).add_prefix("col") - ds = ml_data.MLDataset.from_modin(df, 2) - # Not guaranteed to maintain order, so sort to ensure equality - assert df._to_pandas().sort_index().equals(ds.gather_sync().sort_index()) - - -if __name__ == "__main__": - import sys - - sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/tests/test_reconstruction.py b/python/ray/tests/test_reconstruction.py index ba23be897f66..a157af43ca50 100644 --- a/python/ray/tests/test_reconstruction.py +++ b/python/ray/tests/test_reconstruction.py @@ -10,19 +10,10 @@ from ray._private.test_utils import ( wait_for_condition, wait_for_pid_to_exit, - SignalActor, - Semaphore, ) -from ray.internal.internal_api import memory_summary SIGKILL = signal.SIGKILL if sys.platform != "win32" else signal.SIGTERM -# Task status. -WAITING_FOR_DEPENDENCIES = "WAITING_FOR_DEPENDENCIES" -SCHEDULED = "SCHEDULED" -FINISHED = "FINISHED" -WAITING_FOR_EXECUTION = "WAITING_FOR_EXECUTION" - def test_cached_object(ray_start_cluster): config = { @@ -711,407 +702,6 @@ def dependent_task(x): i += 1 -@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.") -@pytest.mark.parametrize("reconstruction_enabled", [False, True]) -def test_nondeterministic_output(ray_start_cluster, reconstruction_enabled): - config = { - "num_heartbeats_timeout": 10, - "raylet_heartbeat_period_milliseconds": 100, - "max_direct_call_object_size": 100, - "task_retry_delay_ms": 100, - "object_timeout_milliseconds": 200, - } - cluster = ray_start_cluster - # Head node with no resources. - cluster.add_node( - num_cpus=0, _system_config=config, enable_object_reconstruction=True - ) - ray.init(address=cluster.address) - # Node to place the initial object. - node_to_kill = cluster.add_node( - num_cpus=1, resources={"node1": 1}, object_store_memory=10 ** 8 - ) - cluster.add_node(num_cpus=1, object_store_memory=10 ** 8) - cluster.wait_for_nodes() - - @ray.remote - def nondeterministic_object(): - if np.random.rand() < 0.5: - return np.zeros(10 ** 5, dtype=np.uint8) - else: - return 0 - - @ray.remote - def dependent_task(x): - return - - for _ in range(10): - obj = nondeterministic_object.options(resources={"node1": 1}).remote() - for _ in range(3): - ray.get(dependent_task.remote(obj)) - x = dependent_task.remote(obj) - cluster.remove_node(node_to_kill, allow_graceful=False) - node_to_kill = cluster.add_node( - num_cpus=1, resources={"node1": 1}, object_store_memory=10 ** 8 - ) - ray.get(x) - - -@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.") -def test_reconstruction_hangs(ray_start_cluster): - config = { - "num_heartbeats_timeout": 10, - "raylet_heartbeat_period_milliseconds": 100, - "max_direct_call_object_size": 100, - "task_retry_delay_ms": 100, - "object_timeout_milliseconds": 200, - "fetch_warn_timeout_milliseconds": 1000, - } - cluster = ray_start_cluster - # Head node with no resources. - cluster.add_node( - num_cpus=0, _system_config=config, enable_object_reconstruction=True - ) - ray.init(address=cluster.address) - # Node to place the initial object. - node_to_kill = cluster.add_node( - num_cpus=1, resources={"node1": 1}, object_store_memory=10 ** 8 - ) - cluster.add_node(num_cpus=1, object_store_memory=10 ** 8) - cluster.wait_for_nodes() - - @ray.remote - def sleep(): - # Task takes longer than the reconstruction timeout. - time.sleep(3) - return np.zeros(10 ** 5, dtype=np.uint8) - - @ray.remote - def dependent_task(x): - return - - obj = sleep.options(resources={"node1": 1}).remote() - for _ in range(3): - ray.get(dependent_task.remote(obj)) - x = dependent_task.remote(obj) - cluster.remove_node(node_to_kill, allow_graceful=False) - node_to_kill = cluster.add_node( - num_cpus=1, resources={"node1": 1}, object_store_memory=10 ** 8 - ) - ray.get(x) - - -def test_lineage_evicted(ray_start_cluster): - config = { - "num_heartbeats_timeout": 10, - "raylet_heartbeat_period_milliseconds": 100, - "object_timeout_milliseconds": 200, - "max_lineage_bytes": 10_000, - } - - cluster = ray_start_cluster - # Head node with no resources. - cluster.add_node( - num_cpus=0, - _system_config=config, - object_store_memory=10 ** 8, - enable_object_reconstruction=True, - ) - ray.init(address=cluster.address) - node_to_kill = cluster.add_node(num_cpus=1, object_store_memory=10 ** 8) - cluster.wait_for_nodes() - - @ray.remote - def large_object(): - return np.zeros(10 ** 7, dtype=np.uint8) - - @ray.remote - def chain(x): - return x - - @ray.remote - def dependent_task(x): - return x - - obj = large_object.remote() - for _ in range(5): - obj = chain.remote(obj) - ray.get(dependent_task.remote(obj)) - - cluster.remove_node(node_to_kill, allow_graceful=False) - node_to_kill = cluster.add_node(num_cpus=1, object_store_memory=10 ** 8) - ray.get(dependent_task.remote(obj)) - - # Lineage now exceeds the eviction factor. - for _ in range(100): - obj = chain.remote(obj) - ray.get(dependent_task.remote(obj)) - - cluster.remove_node(node_to_kill, allow_graceful=False) - cluster.add_node(num_cpus=1, object_store_memory=10 ** 8) - try: - ray.get(dependent_task.remote(obj)) - assert False - except ray.exceptions.RayTaskError as e: - assert "ObjectReconstructionFailedLineageEvictedError" in str(e) - - -@pytest.mark.parametrize("reconstruction_enabled", [False, True]) -def test_multiple_returns(ray_start_cluster, reconstruction_enabled): - config = { - "num_heartbeats_timeout": 10, - "raylet_heartbeat_period_milliseconds": 100, - "object_timeout_milliseconds": 200, - } - # Workaround to reset the config to the default value. - if not reconstruction_enabled: - config["lineage_pinning_enabled"] = False - - cluster = ray_start_cluster - # Head node with no resources. - cluster.add_node( - num_cpus=0, - _system_config=config, - enable_object_reconstruction=reconstruction_enabled, - ) - ray.init(address=cluster.address) - # Node to place the initial object. - node_to_kill = cluster.add_node(num_cpus=1, object_store_memory=10 ** 8) - cluster.wait_for_nodes() - - @ray.remote(num_returns=2) - def two_large_objects(): - return (np.zeros(10 ** 7, dtype=np.uint8), np.zeros(10 ** 7, dtype=np.uint8)) - - @ray.remote - def dependent_task(x): - return - - obj1, obj2 = two_large_objects.remote() - ray.get(dependent_task.remote(obj1)) - cluster.add_node(num_cpus=1, resources={"node": 1}, object_store_memory=10 ** 8) - ray.get(dependent_task.options(resources={"node": 1}).remote(obj1)) - - cluster.remove_node(node_to_kill, allow_graceful=False) - wait_for_condition( - lambda: not all(node["Alive"] for node in ray.nodes()), timeout=10 - ) - - if reconstruction_enabled: - ray.get(dependent_task.remote(obj1)) - ray.get(dependent_task.remote(obj2)) - else: - with pytest.raises(ray.exceptions.RayTaskError): - ray.get(dependent_task.remote(obj1)) - ray.get(dependent_task.remote(obj2)) - with pytest.raises(ray.exceptions.ObjectLostError): - ray.get(obj2) - - -@pytest.mark.parametrize("reconstruction_enabled", [False, True]) -def test_nested(ray_start_cluster, reconstruction_enabled): - config = { - "num_heartbeats_timeout": 10, - "raylet_heartbeat_period_milliseconds": 100, - "object_timeout_milliseconds": 200, - "fetch_fail_timeout_milliseconds": 10_000, - } - # Workaround to reset the config to the default value. - if not reconstruction_enabled: - config["lineage_pinning_enabled"] = False - - cluster = ray_start_cluster - # Head node with no resources. - cluster.add_node( - num_cpus=0, - _system_config=config, - enable_object_reconstruction=reconstruction_enabled, - ) - ray.init(address=cluster.address) - done_signal = SignalActor.remote() - exit_signal = SignalActor.remote() - ray.get(done_signal.wait.remote(should_wait=False)) - ray.get(exit_signal.wait.remote(should_wait=False)) - - # Node to place the initial object. - node_to_kill = cluster.add_node(num_cpus=1, object_store_memory=10 ** 8) - cluster.wait_for_nodes() - - @ray.remote - def dependent_task(x): - return - - @ray.remote - def large_object(): - return np.zeros(10 ** 7, dtype=np.uint8) - - @ray.remote - def nested(done_signal, exit_signal): - ref = ray.put(np.zeros(10 ** 7, dtype=np.uint8)) - # Flush object store. - for _ in range(20): - ray.put(np.zeros(10 ** 7, dtype=np.uint8)) - dep = dependent_task.options(resources={"node": 1}).remote(ref) - ray.get(done_signal.send.remote(clear=True)) - ray.get(dep) - return ray.get(ref) - - ref = nested.remote(done_signal, exit_signal) - # Wait for task to get scheduled on the node to kill. - ray.get(done_signal.wait.remote()) - # Wait for ray.put object to get transferred to the other node. - cluster.add_node(num_cpus=2, resources={"node": 10}, object_store_memory=10 ** 8) - ray.get(dependent_task.remote(ref)) - - # Destroy the task's output. - cluster.remove_node(node_to_kill, allow_graceful=False) - wait_for_condition( - lambda: not all(node["Alive"] for node in ray.nodes()), timeout=10 - ) - - if reconstruction_enabled: - ray.get(ref, timeout=60) - else: - with pytest.raises(ray.exceptions.ObjectLostError): - ray.get(ref, timeout=60) - - -@pytest.mark.parametrize("reconstruction_enabled", [False, True]) -def test_spilled(ray_start_cluster, reconstruction_enabled): - config = { - "num_heartbeats_timeout": 10, - "raylet_heartbeat_period_milliseconds": 100, - "object_timeout_milliseconds": 200, - } - # Workaround to reset the config to the default value. - if not reconstruction_enabled: - config["lineage_pinning_enabled"] = False - - cluster = ray_start_cluster - # Head node with no resources. - cluster.add_node( - num_cpus=0, - _system_config=config, - enable_object_reconstruction=reconstruction_enabled, - ) - ray.init(address=cluster.address) - # Node to place the initial object. - node_to_kill = cluster.add_node( - num_cpus=1, resources={"node1": 1}, object_store_memory=10 ** 8 - ) - cluster.wait_for_nodes() - - @ray.remote(max_retries=1 if reconstruction_enabled else 0) - def large_object(): - return np.zeros(10 ** 7, dtype=np.uint8) - - @ray.remote - def dependent_task(x): - return - - obj = large_object.options(resources={"node1": 1}).remote() - ray.get(dependent_task.options(resources={"node1": 1}).remote(obj)) - # Force spilling. - objs = [large_object.options(resources={"node1": 1}).remote() for _ in range(20)] - for o in objs: - ray.get(o) - - cluster.remove_node(node_to_kill, allow_graceful=False) - node_to_kill = cluster.add_node( - num_cpus=1, resources={"node1": 1}, object_store_memory=10 ** 8 - ) - - if reconstruction_enabled: - ray.get(dependent_task.remote(obj), timeout=60) - else: - with pytest.raises(ray.exceptions.RayTaskError): - ray.get(dependent_task.remote(obj), timeout=60) - with pytest.raises(ray.exceptions.ObjectLostError): - ray.get(obj, timeout=60) - - -def test_memory_util(ray_start_cluster): - config = { - "num_heartbeats_timeout": 10, - "raylet_heartbeat_period_milliseconds": 100, - "object_timeout_milliseconds": 200, - } - - cluster = ray_start_cluster - # Head node with no resources. - cluster.add_node( - num_cpus=0, - resources={"head": 1}, - _system_config=config, - enable_object_reconstruction=True, - ) - ray.init(address=cluster.address) - # Node to place the initial object. - node_to_kill = cluster.add_node( - num_cpus=1, resources={"node1": 1}, object_store_memory=10 ** 8 - ) - cluster.wait_for_nodes() - - @ray.remote - def large_object(sema=None): - if sema is not None: - ray.get(sema.acquire.remote()) - return np.zeros(10 ** 7, dtype=np.uint8) - - @ray.remote - def dependent_task(x, sema): - ray.get(sema.acquire.remote()) - return x - - def stats(): - info = memory_summary(cluster.address, line_wrap=False) - print(info) - info = info.split("\n") - reconstructing_waiting = [ - line - for line in info - if "Attempt #2" in line and WAITING_FOR_DEPENDENCIES in line - ] - reconstructing_scheduled = [ - line - for line in info - if "Attempt #2" in line and WAITING_FOR_EXECUTION in line - ] - reconstructing_finished = [ - line for line in info if "Attempt #2" in line and FINISHED in line - ] - return ( - len(reconstructing_waiting), - len(reconstructing_scheduled), - len(reconstructing_finished), - ) - - sema = Semaphore.options(resources={"head": 1}).remote(value=0) - obj = large_object.options(resources={"node1": 1}).remote(sema) - x = dependent_task.options(resources={"node1": 1}).remote(obj, sema) - ref = dependent_task.options(resources={"node1": 1}).remote(x, sema) - ray.get(sema.release.remote()) - ray.get(sema.release.remote()) - ray.get(sema.release.remote()) - ray.get(ref) - wait_for_condition(lambda: stats() == (0, 0, 0)) - del ref - - cluster.remove_node(node_to_kill, allow_graceful=False) - node_to_kill = cluster.add_node( - num_cpus=1, resources={"node1": 1}, object_store_memory=10 ** 8 - ) - - ref = dependent_task.remote(x, sema) - wait_for_condition(lambda: stats() == (1, 1, 0)) - ray.get(sema.release.remote()) - wait_for_condition(lambda: stats() == (0, 1, 1)) - ray.get(sema.release.remote()) - ray.get(sema.release.remote()) - ray.get(ref) - wait_for_condition(lambda: stats() == (0, 0, 2)) - - if __name__ == "__main__": import pytest diff --git a/python/ray/tests/test_reconstruction_2.py b/python/ray/tests/test_reconstruction_2.py new file mode 100644 index 000000000000..237dfeb1ba3c --- /dev/null +++ b/python/ray/tests/test_reconstruction_2.py @@ -0,0 +1,426 @@ +import sys +import time + +import numpy as np +import pytest + +import ray +from ray._private.test_utils import ( + wait_for_condition, + SignalActor, + Semaphore, +) +from ray.internal.internal_api import memory_summary + +# Task status. +WAITING_FOR_DEPENDENCIES = "WAITING_FOR_DEPENDENCIES" +SCHEDULED = "SCHEDULED" +FINISHED = "FINISHED" +WAITING_FOR_EXECUTION = "WAITING_FOR_EXECUTION" + + +@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.") +@pytest.mark.parametrize("reconstruction_enabled", [False, True]) +def test_nondeterministic_output(ray_start_cluster, reconstruction_enabled): + config = { + "num_heartbeats_timeout": 10, + "raylet_heartbeat_period_milliseconds": 100, + "max_direct_call_object_size": 100, + "task_retry_delay_ms": 100, + "object_timeout_milliseconds": 200, + } + cluster = ray_start_cluster + # Head node with no resources. + cluster.add_node( + num_cpus=0, _system_config=config, enable_object_reconstruction=True + ) + ray.init(address=cluster.address) + # Node to place the initial object. + node_to_kill = cluster.add_node( + num_cpus=1, resources={"node1": 1}, object_store_memory=10 ** 8 + ) + cluster.add_node(num_cpus=1, object_store_memory=10 ** 8) + cluster.wait_for_nodes() + + @ray.remote + def nondeterministic_object(): + if np.random.rand() < 0.5: + return np.zeros(10 ** 5, dtype=np.uint8) + else: + return 0 + + @ray.remote + def dependent_task(x): + return + + for _ in range(10): + obj = nondeterministic_object.options(resources={"node1": 1}).remote() + for _ in range(3): + ray.get(dependent_task.remote(obj)) + x = dependent_task.remote(obj) + cluster.remove_node(node_to_kill, allow_graceful=False) + node_to_kill = cluster.add_node( + num_cpus=1, resources={"node1": 1}, object_store_memory=10 ** 8 + ) + ray.get(x) + + +@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.") +def test_reconstruction_hangs(ray_start_cluster): + config = { + "num_heartbeats_timeout": 10, + "raylet_heartbeat_period_milliseconds": 100, + "max_direct_call_object_size": 100, + "task_retry_delay_ms": 100, + "object_timeout_milliseconds": 200, + "fetch_warn_timeout_milliseconds": 1000, + } + cluster = ray_start_cluster + # Head node with no resources. + cluster.add_node( + num_cpus=0, _system_config=config, enable_object_reconstruction=True + ) + ray.init(address=cluster.address) + # Node to place the initial object. + node_to_kill = cluster.add_node( + num_cpus=1, resources={"node1": 1}, object_store_memory=10 ** 8 + ) + cluster.add_node(num_cpus=1, object_store_memory=10 ** 8) + cluster.wait_for_nodes() + + @ray.remote + def sleep(): + # Task takes longer than the reconstruction timeout. + time.sleep(3) + return np.zeros(10 ** 5, dtype=np.uint8) + + @ray.remote + def dependent_task(x): + return + + obj = sleep.options(resources={"node1": 1}).remote() + for _ in range(3): + ray.get(dependent_task.remote(obj)) + x = dependent_task.remote(obj) + cluster.remove_node(node_to_kill, allow_graceful=False) + node_to_kill = cluster.add_node( + num_cpus=1, resources={"node1": 1}, object_store_memory=10 ** 8 + ) + ray.get(x) + + +def test_lineage_evicted(ray_start_cluster): + config = { + "num_heartbeats_timeout": 10, + "raylet_heartbeat_period_milliseconds": 100, + "object_timeout_milliseconds": 200, + "max_lineage_bytes": 10_000, + } + + cluster = ray_start_cluster + # Head node with no resources. + cluster.add_node( + num_cpus=0, + _system_config=config, + object_store_memory=10 ** 8, + enable_object_reconstruction=True, + ) + ray.init(address=cluster.address) + node_to_kill = cluster.add_node(num_cpus=1, object_store_memory=10 ** 8) + cluster.wait_for_nodes() + + @ray.remote + def large_object(): + return np.zeros(10 ** 7, dtype=np.uint8) + + @ray.remote + def chain(x): + return x + + @ray.remote + def dependent_task(x): + return x + + obj = large_object.remote() + for _ in range(5): + obj = chain.remote(obj) + ray.get(dependent_task.remote(obj)) + + cluster.remove_node(node_to_kill, allow_graceful=False) + node_to_kill = cluster.add_node(num_cpus=1, object_store_memory=10 ** 8) + ray.get(dependent_task.remote(obj)) + + # Lineage now exceeds the eviction factor. + for _ in range(100): + obj = chain.remote(obj) + ray.get(dependent_task.remote(obj)) + + cluster.remove_node(node_to_kill, allow_graceful=False) + cluster.add_node(num_cpus=1, object_store_memory=10 ** 8) + try: + ray.get(dependent_task.remote(obj)) + assert False + except ray.exceptions.RayTaskError as e: + assert "ObjectReconstructionFailedLineageEvictedError" in str(e) + + +@pytest.mark.parametrize("reconstruction_enabled", [False, True]) +def test_multiple_returns(ray_start_cluster, reconstruction_enabled): + config = { + "num_heartbeats_timeout": 10, + "raylet_heartbeat_period_milliseconds": 100, + "object_timeout_milliseconds": 200, + } + # Workaround to reset the config to the default value. + if not reconstruction_enabled: + config["lineage_pinning_enabled"] = False + + cluster = ray_start_cluster + # Head node with no resources. + cluster.add_node( + num_cpus=0, + _system_config=config, + enable_object_reconstruction=reconstruction_enabled, + ) + ray.init(address=cluster.address) + # Node to place the initial object. + node_to_kill = cluster.add_node(num_cpus=1, object_store_memory=10 ** 8) + cluster.wait_for_nodes() + + @ray.remote(num_returns=2) + def two_large_objects(): + return (np.zeros(10 ** 7, dtype=np.uint8), np.zeros(10 ** 7, dtype=np.uint8)) + + @ray.remote + def dependent_task(x): + return + + obj1, obj2 = two_large_objects.remote() + ray.get(dependent_task.remote(obj1)) + cluster.add_node(num_cpus=1, resources={"node": 1}, object_store_memory=10 ** 8) + ray.get(dependent_task.options(resources={"node": 1}).remote(obj1)) + + cluster.remove_node(node_to_kill, allow_graceful=False) + wait_for_condition( + lambda: not all(node["Alive"] for node in ray.nodes()), timeout=10 + ) + + if reconstruction_enabled: + ray.get(dependent_task.remote(obj1)) + ray.get(dependent_task.remote(obj2)) + else: + with pytest.raises(ray.exceptions.RayTaskError): + ray.get(dependent_task.remote(obj1)) + ray.get(dependent_task.remote(obj2)) + with pytest.raises(ray.exceptions.ObjectLostError): + ray.get(obj2) + + +@pytest.mark.parametrize("reconstruction_enabled", [False, True]) +def test_nested(ray_start_cluster, reconstruction_enabled): + config = { + "num_heartbeats_timeout": 10, + "raylet_heartbeat_period_milliseconds": 100, + "object_timeout_milliseconds": 200, + "fetch_fail_timeout_milliseconds": 10_000, + } + # Workaround to reset the config to the default value. + if not reconstruction_enabled: + config["lineage_pinning_enabled"] = False + + cluster = ray_start_cluster + # Head node with no resources. + cluster.add_node( + num_cpus=0, + _system_config=config, + enable_object_reconstruction=reconstruction_enabled, + ) + ray.init(address=cluster.address) + done_signal = SignalActor.remote() + exit_signal = SignalActor.remote() + ray.get(done_signal.wait.remote(should_wait=False)) + ray.get(exit_signal.wait.remote(should_wait=False)) + + # Node to place the initial object. + node_to_kill = cluster.add_node(num_cpus=1, object_store_memory=10 ** 8) + cluster.wait_for_nodes() + + @ray.remote + def dependent_task(x): + return + + @ray.remote + def large_object(): + return np.zeros(10 ** 7, dtype=np.uint8) + + @ray.remote + def nested(done_signal, exit_signal): + ref = ray.put(np.zeros(10 ** 7, dtype=np.uint8)) + # Flush object store. + for _ in range(20): + ray.put(np.zeros(10 ** 7, dtype=np.uint8)) + dep = dependent_task.options(resources={"node": 1}).remote(ref) + ray.get(done_signal.send.remote(clear=True)) + ray.get(dep) + return ray.get(ref) + + ref = nested.remote(done_signal, exit_signal) + # Wait for task to get scheduled on the node to kill. + ray.get(done_signal.wait.remote()) + # Wait for ray.put object to get transferred to the other node. + cluster.add_node(num_cpus=2, resources={"node": 10}, object_store_memory=10 ** 8) + ray.get(dependent_task.remote(ref)) + + # Destroy the task's output. + cluster.remove_node(node_to_kill, allow_graceful=False) + wait_for_condition( + lambda: not all(node["Alive"] for node in ray.nodes()), timeout=10 + ) + + if reconstruction_enabled: + ray.get(ref, timeout=60) + else: + with pytest.raises(ray.exceptions.ObjectLostError): + ray.get(ref, timeout=60) + + +@pytest.mark.parametrize("reconstruction_enabled", [False, True]) +def test_spilled(ray_start_cluster, reconstruction_enabled): + config = { + "num_heartbeats_timeout": 10, + "raylet_heartbeat_period_milliseconds": 100, + "object_timeout_milliseconds": 200, + } + # Workaround to reset the config to the default value. + if not reconstruction_enabled: + config["lineage_pinning_enabled"] = False + + cluster = ray_start_cluster + # Head node with no resources. + cluster.add_node( + num_cpus=0, + _system_config=config, + enable_object_reconstruction=reconstruction_enabled, + ) + ray.init(address=cluster.address) + # Node to place the initial object. + node_to_kill = cluster.add_node( + num_cpus=1, resources={"node1": 1}, object_store_memory=10 ** 8 + ) + cluster.wait_for_nodes() + + @ray.remote(max_retries=1 if reconstruction_enabled else 0) + def large_object(): + return np.zeros(10 ** 7, dtype=np.uint8) + + @ray.remote + def dependent_task(x): + return + + obj = large_object.options(resources={"node1": 1}).remote() + ray.get(dependent_task.options(resources={"node1": 1}).remote(obj)) + # Force spilling. + objs = [large_object.options(resources={"node1": 1}).remote() for _ in range(20)] + for o in objs: + ray.get(o) + + cluster.remove_node(node_to_kill, allow_graceful=False) + node_to_kill = cluster.add_node( + num_cpus=1, resources={"node1": 1}, object_store_memory=10 ** 8 + ) + + if reconstruction_enabled: + ray.get(dependent_task.remote(obj), timeout=60) + else: + with pytest.raises(ray.exceptions.RayTaskError): + ray.get(dependent_task.remote(obj), timeout=60) + with pytest.raises(ray.exceptions.ObjectLostError): + ray.get(obj, timeout=60) + + +def test_memory_util(ray_start_cluster): + config = { + "num_heartbeats_timeout": 10, + "raylet_heartbeat_period_milliseconds": 100, + "object_timeout_milliseconds": 200, + } + + cluster = ray_start_cluster + # Head node with no resources. + cluster.add_node( + num_cpus=0, + resources={"head": 1}, + _system_config=config, + enable_object_reconstruction=True, + ) + ray.init(address=cluster.address) + # Node to place the initial object. + node_to_kill = cluster.add_node( + num_cpus=1, resources={"node1": 1}, object_store_memory=10 ** 8 + ) + cluster.wait_for_nodes() + + @ray.remote + def large_object(sema=None): + if sema is not None: + ray.get(sema.acquire.remote()) + return np.zeros(10 ** 7, dtype=np.uint8) + + @ray.remote + def dependent_task(x, sema): + ray.get(sema.acquire.remote()) + return x + + def stats(): + info = memory_summary(cluster.address, line_wrap=False) + print(info) + info = info.split("\n") + reconstructing_waiting = [ + line + for line in info + if "Attempt #2" in line and WAITING_FOR_DEPENDENCIES in line + ] + reconstructing_scheduled = [ + line + for line in info + if "Attempt #2" in line and WAITING_FOR_EXECUTION in line + ] + reconstructing_finished = [ + line for line in info if "Attempt #2" in line and FINISHED in line + ] + return ( + len(reconstructing_waiting), + len(reconstructing_scheduled), + len(reconstructing_finished), + ) + + sema = Semaphore.options(resources={"head": 1}).remote(value=0) + obj = large_object.options(resources={"node1": 1}).remote(sema) + x = dependent_task.options(resources={"node1": 1}).remote(obj, sema) + ref = dependent_task.options(resources={"node1": 1}).remote(x, sema) + ray.get(sema.release.remote()) + ray.get(sema.release.remote()) + ray.get(sema.release.remote()) + ray.get(ref) + wait_for_condition(lambda: stats() == (0, 0, 0)) + del ref + + cluster.remove_node(node_to_kill, allow_graceful=False) + node_to_kill = cluster.add_node( + num_cpus=1, resources={"node1": 1}, object_store_memory=10 ** 8 + ) + + ref = dependent_task.remote(x, sema) + wait_for_condition(lambda: stats() == (1, 1, 0)) + ray.get(sema.release.remote()) + wait_for_condition(lambda: stats() == (0, 1, 1)) + ray.get(sema.release.remote()) + ray.get(sema.release.remote()) + ray.get(ref) + wait_for_condition(lambda: stats() == (0, 0, 2)) + + +if __name__ == "__main__": + import pytest + + sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/tests/test_scheduling_2.py b/python/ray/tests/test_scheduling_2.py index 2f414f87c414..73649e2ba28c 100644 --- a/python/ray/tests/test_scheduling_2.py +++ b/python/ray/tests/test_scheduling_2.py @@ -647,6 +647,50 @@ def check_backlog_info(): cluster.shutdown() +def test_data_locality_spilled_objects( + ray_start_cluster_enabled, fs_only_object_spilling_config +): + cluster = ray_start_cluster_enabled + object_spilling_config, _ = fs_only_object_spilling_config + cluster.add_node( + num_cpus=1, + object_store_memory=100 * 1024 * 1024, + _system_config={ + "min_spilling_size": 1, + "object_spilling_config": object_spilling_config, + }, + ) + ray.init(cluster.address) + cluster.add_node( + num_cpus=1, object_store_memory=100 * 1024 * 1024, resources={"remote": 1} + ) + + @ray.remote(resources={"remote": 1}) + def f(): + return ( + np.zeros(50 * 1024 * 1024, dtype=np.uint8), + ray.runtime_context.get_runtime_context().node_id, + ) + + @ray.remote + def check_locality(x): + _, node_id = x + assert node_id == ray.runtime_context.get_runtime_context().node_id + + # Check locality works when dependent task is already submitted by the time + # the upstream task finishes. + for _ in range(5): + ray.get(check_locality.remote(f.remote())) + + # Check locality works when some objects were spilled. + xs = [f.remote() for _ in range(5)] + ray.wait(xs, num_returns=len(xs), fetch_local=False) + for i, x in enumerate(xs): + task = check_locality.remote(x) + print(i, x, task) + ray.get(task) + + if __name__ == "__main__": import pytest diff --git a/python/ray/tests/test_state_api.py b/python/ray/tests/test_state_api.py index 1a1a0a8ed3a3..2c31f33dec73 100644 --- a/python/ray/tests/test_state_api.py +++ b/python/ray/tests/test_state_api.py @@ -8,7 +8,10 @@ from unittest.mock import MagicMock -from asyncmock import AsyncMock +if sys.version_info > (3, 7, 0): + from unittest.mock import AsyncMock +else: + from asyncmock import AsyncMock import ray import ray.ray_constants as ray_constants @@ -41,7 +44,11 @@ ) from ray.core.generated.runtime_env_agent_pb2 import GetRuntimeEnvsInfoReply import ray.dashboard.consts as dashboard_consts -from ray.dashboard.state_aggregator import StateAPIManager +from ray.dashboard.state_aggregator import ( + StateAPIManager, + GCS_QUERY_FAILURE_WARNING, + NODE_QUERY_FAILURE_WARNING, +) from ray.experimental.state.api import ( list_actors, list_placement_groups, @@ -64,9 +71,9 @@ DEFAULT_RPC_TIMEOUT, DEFAULT_LIMIT, ) +from ray.experimental.state.exception import DataSourceUnavailable, RayStateApiException from ray.experimental.state.state_manager import ( StateDataSourceClient, - StateSourceNetworkException, ) from ray.experimental.state.state_cli import ( list_state_cli_group, @@ -188,7 +195,7 @@ def generate_runtime_env_info(runtime_env, creation_time=None): def list_api_options(timeout: int = DEFAULT_RPC_TIMEOUT, limit: int = DEFAULT_LIMIT): - return ListApiOptions(limit=limit, timeout=timeout) + return ListApiOptions(limit=limit, timeout=timeout, _server_timeout_multiplier=1.0) @pytest.mark.asyncio @@ -199,15 +206,25 @@ async def test_api_manager_list_actors(state_api_manager): actor_table_data=[generate_actor_data(actor_id), generate_actor_data(b"12345")] ) result = await state_api_manager.list_actors(option=list_api_options()) - actor_data = list(result.values())[0] + data = result.result + actor_data = list(data.values())[0] verify_schema(ActorState, actor_data) """ Test limit """ - assert len(result) == 2 + assert len(data) == 2 result = await state_api_manager.list_actors(option=list_api_options(limit=1)) - assert len(result) == 1 + data = result.result + assert len(data) == 1 + + """ + Test error handling + """ + data_source_client.get_all_actor_info.side_effect = DataSourceUnavailable() + with pytest.raises(DataSourceUnavailable) as exc_info: + result = await state_api_manager.list_actors(option=list_api_options(limit=1)) + assert exc_info.value.args[0] == GCS_QUERY_FAILURE_WARNING @pytest.mark.asyncio @@ -223,17 +240,31 @@ async def test_api_manager_list_pgs(state_api_manager): ) ) result = await state_api_manager.list_placement_groups(option=list_api_options()) - data = list(result.values())[0] + data = result.result + data = list(data.values())[0] verify_schema(PlacementGroupState, data) """ Test limit """ - assert len(result) == 2 + assert len(data) == 2 result = await state_api_manager.list_placement_groups( option=list_api_options(limit=1) ) - assert len(result) == 1 + data = result.result + assert len(data) == 1 + + """ + Test error handling + """ + data_source_client.get_all_placement_group_info.side_effect = ( + DataSourceUnavailable() + ) + with pytest.raises(DataSourceUnavailable) as exc_info: + result = await state_api_manager.list_placement_groups( + option=list_api_options(limit=1) + ) + assert exc_info.value.args[0] == GCS_QUERY_FAILURE_WARNING @pytest.mark.asyncio @@ -244,15 +275,25 @@ async def test_api_manager_list_nodes(state_api_manager): node_info_list=[generate_node_data(id), generate_node_data(b"12345")] ) result = await state_api_manager.list_nodes(option=list_api_options()) - data = list(result.values())[0] + data = result.result + data = list(data.values())[0] verify_schema(NodeState, data) """ Test limit """ - assert len(result) == 2 + assert len(data) == 2 result = await state_api_manager.list_nodes(option=list_api_options(limit=1)) - assert len(result) == 1 + data = result.result + assert len(data) == 1 + + """ + Test error handling + """ + data_source_client.get_all_node_info.side_effect = DataSourceUnavailable() + with pytest.raises(DataSourceUnavailable) as exc_info: + result = await state_api_manager.list_nodes(option=list_api_options(limit=1)) + assert exc_info.value.args[0] == GCS_QUERY_FAILURE_WARNING @pytest.mark.asyncio @@ -266,19 +307,30 @@ async def test_api_manager_list_workers(state_api_manager): ] ) result = await state_api_manager.list_workers(option=list_api_options()) - data = list(result.values())[0] + data = result.result + data = list(data.values())[0] verify_schema(WorkerState, data) """ Test limit """ - assert len(result) == 2 + assert len(result.result) == 2 result = await state_api_manager.list_workers(option=list_api_options(limit=1)) - assert len(result) == 1 + data = result.result + assert len(data) == 1 + + """ + Test error handling + """ + data_source_client.get_all_worker_info.side_effect = DataSourceUnavailable() + with pytest.raises(DataSourceUnavailable) as exc_info: + result = await state_api_manager.list_workers(option=list_api_options(limit=1)) + assert exc_info.value.args[0] == GCS_QUERY_FAILURE_WARNING -@pytest.mark.skip( - reason=("Not passing in CI although it works locally. Will handle it later.") +@pytest.mark.skipif( + sys.version_info <= (3, 7, 0), + reason=("Not passing in CI although it works locally. Will handle it later."), ) @pytest.mark.asyncio async def test_api_manager_list_tasks(state_api_manager): @@ -288,17 +340,19 @@ async def test_api_manager_list_tasks(state_api_manager): first_task_name = "1" second_task_name = "2" + data_source_client.get_task_info = AsyncMock() data_source_client.get_task_info.side_effect = [ generate_task_data(b"1234", first_task_name), generate_task_data(b"2345", second_task_name), ] result = await state_api_manager.list_tasks(option=list_api_options()) - data_source_client.get_task_info.assert_any_call("1", timeout=DEFAULT_RPC_TIMEOUT) - data_source_client.get_task_info.assert_any_call("2", timeout=DEFAULT_RPC_TIMEOUT) - result = list(result.values()) - assert len(result) == 2 - verify_schema(TaskState, result[0]) - verify_schema(TaskState, result[1]) + data_source_client.get_task_info.assert_any_await("1", timeout=DEFAULT_RPC_TIMEOUT) + data_source_client.get_task_info.assert_any_await("2", timeout=DEFAULT_RPC_TIMEOUT) + data = result.result + data = list(data.values()) + assert len(data) == 2 + verify_schema(TaskState, data[0]) + verify_schema(TaskState, data[1]) """ Test limit @@ -308,11 +362,38 @@ async def test_api_manager_list_tasks(state_api_manager): generate_task_data(b"2345", second_task_name), ] result = await state_api_manager.list_tasks(option=list_api_options(limit=1)) - assert len(result) == 1 + data = result.result + assert len(data) == 1 + + """ + Test error handling + """ + data_source_client.get_task_info.side_effect = [ + DataSourceUnavailable(), + generate_task_data(b"2345", second_task_name), + ] + result = await state_api_manager.list_tasks(option=list_api_options(limit=1)) + # Make sure warnings are returned. + warning = result.partial_failure_warning + assert ( + NODE_QUERY_FAILURE_WARNING.format( + type="raylet", total=2, network_failures=1, log_command="raylet.out" + ) + in warning + ) + + # Test if all RPCs fail, it will raise an exception. + data_source_client.get_task_info.side_effect = [ + DataSourceUnavailable(), + DataSourceUnavailable(), + ] + with pytest.raises(DataSourceUnavailable): + result = await state_api_manager.list_tasks(option=list_api_options(limit=1)) -@pytest.mark.skip( - reason=("Not passing in CI although it works locally. Will handle it later.") +@pytest.mark.skipif( + sys.version_info <= (3, 7, 0), + reason=("Not passing in CI although it works locally. Will handle it later."), ) @pytest.mark.asyncio async def test_api_manager_list_objects(state_api_manager): @@ -322,17 +403,23 @@ async def test_api_manager_list_objects(state_api_manager): data_source_client.get_all_registered_raylet_ids = MagicMock() data_source_client.get_all_registered_raylet_ids.return_value = ["1", "2"] + data_source_client.get_object_info = AsyncMock() data_source_client.get_object_info.side_effect = [ GetNodeStatsReply(core_workers_stats=[generate_object_info(obj_1_id)]), GetNodeStatsReply(core_workers_stats=[generate_object_info(obj_2_id)]), ] result = await state_api_manager.list_objects(option=list_api_options()) - data_source_client.get_object_info.assert_any_call("1", timeout=DEFAULT_RPC_TIMEOUT) - data_source_client.get_object_info.assert_any_call("2", timeout=DEFAULT_RPC_TIMEOUT) - result = list(result.values()) - assert len(result) == 2 - verify_schema(ObjectState, result[0]) - verify_schema(ObjectState, result[1]) + data = result.result + data_source_client.get_object_info.assert_any_await( + "1", timeout=DEFAULT_RPC_TIMEOUT + ) + data_source_client.get_object_info.assert_any_await( + "2", timeout=DEFAULT_RPC_TIMEOUT + ) + data = list(data.values()) + assert len(data) == 2 + verify_schema(ObjectState, data[0]) + verify_schema(ObjectState, data[1]) """ Test limit @@ -342,11 +429,38 @@ async def test_api_manager_list_objects(state_api_manager): GetNodeStatsReply(core_workers_stats=[generate_object_info(obj_2_id)]), ] result = await state_api_manager.list_objects(option=list_api_options(limit=1)) - assert len(result) == 1 + data = result.result + assert len(data) == 1 + """ + Test error handling + """ + data_source_client.get_object_info.side_effect = [ + DataSourceUnavailable(), + GetNodeStatsReply(core_workers_stats=[generate_object_info(obj_2_id)]), + ] + result = await state_api_manager.list_objects(option=list_api_options(limit=1)) + # Make sure warnings are returned. + warning = result.partial_failure_warning + assert ( + NODE_QUERY_FAILURE_WARNING.format( + type="raylet", total=2, network_failures=1, log_command="raylet.out" + ) + in warning + ) -@pytest.mark.skip( - reason=("Not passing in CI although it works locally. Will handle it later.") + # Test if all RPCs fail, it will raise an exception. + data_source_client.get_object_info.side_effect = [ + DataSourceUnavailable(), + DataSourceUnavailable(), + ] + with pytest.raises(DataSourceUnavailable): + result = await state_api_manager.list_objects(option=list_api_options(limit=1)) + + +@pytest.mark.skipif( + sys.version_info <= (3, 7, 0), + reason=("Not passing in CI although it works locally. Will handle it later."), ) @pytest.mark.asyncio async def test_api_manager_list_runtime_envs(state_api_manager): @@ -354,6 +468,7 @@ async def test_api_manager_list_runtime_envs(state_api_manager): data_source_client.get_all_registered_agent_ids = MagicMock() data_source_client.get_all_registered_agent_ids.return_value = ["1", "2", "3"] + data_source_client.get_runtime_envs_info = AsyncMock() data_source_client.get_runtime_envs_info.side_effect = [ generate_runtime_env_info(RuntimeEnv(**{"pip": ["requests"]})), generate_runtime_env_info( @@ -362,23 +477,25 @@ async def test_api_manager_list_runtime_envs(state_api_manager): generate_runtime_env_info(RuntimeEnv(**{"pip": ["ray"]}), creation_time=10), ] result = await state_api_manager.list_runtime_envs(option=list_api_options()) - data_source_client.get_runtime_envs_info.assert_any_call( + data = result.result + data_source_client.get_runtime_envs_info.assert_any_await( "1", timeout=DEFAULT_RPC_TIMEOUT ) - data_source_client.get_runtime_envs_info.assert_any_call( + data_source_client.get_runtime_envs_info.assert_any_await( "2", timeout=DEFAULT_RPC_TIMEOUT ) - data_source_client.get_runtime_envs_info.assert_any_call( + + data_source_client.get_runtime_envs_info.assert_any_await( "3", timeout=DEFAULT_RPC_TIMEOUT ) - assert len(result) == 3 - verify_schema(RuntimeEnvState, result[0]) - verify_schema(RuntimeEnvState, result[1]) - verify_schema(RuntimeEnvState, result[2]) + assert len(data) == 3 + verify_schema(RuntimeEnvState, data[0]) + verify_schema(RuntimeEnvState, data[1]) + verify_schema(RuntimeEnvState, data[2]) # Make sure the higher creation time is sorted first. - assert "creation_time_ms" not in result[0] - result[1]["creation_time_ms"] > result[2]["creation_time_ms"] + assert "creation_time_ms" not in data[0] + data[1]["creation_time_ms"] > data[2]["creation_time_ms"] """ Test limit @@ -391,7 +508,38 @@ async def test_api_manager_list_runtime_envs(state_api_manager): generate_runtime_env_info(RuntimeEnv(**{"pip": ["ray"]})), ] result = await state_api_manager.list_runtime_envs(option=list_api_options(limit=1)) - assert len(result) == 1 + data = result.result + assert len(data) == 1 + + """ + Test error handling + """ + data_source_client.get_runtime_envs_info.side_effect = [ + DataSourceUnavailable(), + generate_runtime_env_info(RuntimeEnv(**{"pip": ["ray"]})), + generate_runtime_env_info(RuntimeEnv(**{"pip": ["ray"]})), + ] + result = await state_api_manager.list_runtime_envs(option=list_api_options(limit=1)) + # Make sure warnings are returned. + warning = result.partial_failure_warning + print(warning) + assert ( + NODE_QUERY_FAILURE_WARNING.format( + type="agent", total=3, network_failures=1, log_command="dashboard_agent.log" + ) + in warning + ) + + # Test if all RPCs fail, it will raise an exception. + data_source_client.get_runtime_envs_info.side_effect = [ + DataSourceUnavailable(), + DataSourceUnavailable(), + DataSourceUnavailable(), + ] + with pytest.raises(DataSourceUnavailable): + result = await state_api_manager.list_runtime_envs( + option=list_api_options(limit=1) + ) """ @@ -495,7 +643,7 @@ async def test_state_data_source_client(ray_start_cluster): """ with pytest.raises(ValueError): # Since we didn't register this node id, it should raise an exception. - result = await client.get_object_info("1234") + result = await client.get_runtime_envs_info("1234") wait_for_condition(lambda: len(ray.nodes()) == 2) for node in ray.nodes(): node_id = node["NodeID"] @@ -527,10 +675,9 @@ def get_port(): if node["Alive"]: continue - # Querying to the dead node raises gRPC error, which should be - # translated into `StateSourceNetworkException` - with pytest.raises(StateSourceNetworkException): - result = await client.get_object_info(node_id) + # Querying to the dead node raises gRPC error, which should raise an exception. + with pytest.raises(DataSourceUnavailable): + await client.get_object_info(node_id) # Make sure unregister API works as expected. client.unregister_raylet_client(node_id) @@ -685,6 +832,7 @@ def test_list_jobs(shutdown_only): def verify(): job_data = list(list_jobs().values())[0] + print(job_data) job_id_from_api = list(list_jobs().keys())[0] correct_state = job_data["status"] == "SUCCEEDED" correct_id = job_id == job_id_from_api @@ -914,6 +1062,91 @@ def ready(self): assert output == list_actors(limit=2) +def test_network_failure(shutdown_only): + """When the request fails due to network failure, + verifies it raises an exception.""" + ray.init() + + @ray.remote + def f(): + import time + + time.sleep(30) + + a = [f.remote() for _ in range(4)] # noqa + wait_for_condition(lambda: len(list_tasks()) == 4) + + # Kill raylet so that list_tasks will have network error on querying raylets. + ray.worker._global_node.kill_raylet() + + with pytest.raises(RayStateApiException): + list_tasks(_explain=True) + + +def test_network_partial_failures(ray_start_cluster): + """When the request fails due to network failure, + verifies it prints proper warning.""" + cluster = ray_start_cluster + cluster.add_node(num_cpus=2) + ray.init(address=cluster.address) + n = cluster.add_node(num_cpus=2) + + @ray.remote + def f(): + import time + + time.sleep(30) + + a = [f.remote() for _ in range(4)] # noqa + wait_for_condition(lambda: len(list_tasks()) == 4) + + # Make sure when there's 0 node failure, it doesn't print the error. + with pytest.warns(None) as record: + list_tasks(_explain=True) + assert len(record) == 0 + + # Kill raylet so that list_tasks will have network error on querying raylets. + cluster.remove_node(n, allow_graceful=False) + + with pytest.warns(RuntimeWarning): + list_tasks(_explain=True) + + # Make sure when _explain == False, warning is not printed. + with pytest.warns(None) as record: + list_tasks(_explain=False) + assert len(record) == 0 + + +def test_network_partial_failures_timeout(monkeypatch, ray_start_cluster): + """When the request fails due to network timeout, + verifies it prints proper warning.""" + cluster = ray_start_cluster + cluster.add_node(num_cpus=2) + ray.init(address=cluster.address) + with monkeypatch.context() as m: + # defer for 10s for the second node. + m.setenv( + "RAY_testing_asio_delay_us", + "NodeManagerService.grpc_server.GetTasksInfo=10000000:10000000", + ) + cluster.add_node(num_cpus=2) + + @ray.remote + def f(): + import time + + time.sleep(30) + + a = [f.remote() for _ in range(4)] # noqa + + def verify(): + with pytest.warns(None) as record: + list_tasks(_explain=True, timeout=5) + return len(record) == 1 + + wait_for_condition(verify) + + @pytest.mark.asyncio async def test_cli_format_print(state_api_manager): data_source_client = state_api_manager.data_source_client @@ -922,6 +1155,7 @@ async def test_cli_format_print(state_api_manager): actor_table_data=[generate_actor_data(actor_id), generate_actor_data(b"12345")] ) result = await state_api_manager.list_actors(option=list_api_options()) + result = result.result # If the format is not yaml, it will raise an exception. yaml.load( get_state_api_output_to_print(result, format=AvailableFormat.YAML), diff --git a/python/ray/tune/callback.py b/python/ray/tune/callback.py index 7b62546f4da0..178212761285 100644 --- a/python/ray/tune/callback.py +++ b/python/ray/tune/callback.py @@ -3,14 +3,14 @@ import warnings from ray.tune.checkpoint_manager import _TuneCheckpoint -from ray.util.annotations import PublicAPI +from ray.util.annotations import PublicAPI, DeveloperAPI if TYPE_CHECKING: from ray.tune.trial import Trial from ray.tune.stopper import Stopper -class CallbackMeta(ABCMeta): +class _CallbackMeta(ABCMeta): """A helper metaclass to ensure container classes (e.g. CallbackList) have implemented all the callback methods (e.g. `on_*`). """ @@ -60,7 +60,7 @@ def need_override_by_subclass(mcs, attr_name: str, attr: Any) -> bool: @PublicAPI(stability="beta") -class Callback(metaclass=CallbackMeta): +class Callback(metaclass=_CallbackMeta): """Tune base callback that can be extended and passed to a ``TrialRunner`` Tune callbacks are called from within the ``TrialRunner`` class. There are @@ -270,6 +270,7 @@ def on_experiment_end(self, trials: List["Trial"], **info): pass +@DeveloperAPI class CallbackList(Callback): """Call multiple callbacks at once.""" diff --git a/python/ray/tune/checkpoint_manager.py b/python/ray/tune/checkpoint_manager.py index 75cf4b8cb835..e2afe5b0b890 100644 --- a/python/ray/tune/checkpoint_manager.py +++ b/python/ray/tune/checkpoint_manager.py @@ -62,7 +62,7 @@ def __repr__(self): return f"Checkpoint({self.storage}, {self.value})" -class QueueItem: +class _QueueItem: def __init__(self, priority, value): self.priority = priority self.value = value @@ -74,7 +74,7 @@ def __repr__(self): return f"QueueItem({repr(self.value)})" -class CheckpointManager: +class _CheckpointManager: """Manages checkpoints on the driver for a trial.""" def __init__( @@ -184,7 +184,7 @@ def on_checkpoint(self, checkpoint: _TuneCheckpoint): # The tuple structure is (not is_nan(), metric), which makes # the nan values to be always considered as the worst # metrics by the heap - queue_item = QueueItem(self._priority(checkpoint), checkpoint) + queue_item = _QueueItem(self._priority(checkpoint), checkpoint) except KeyError: logger.error( "Result dict has no key: {}. " diff --git a/python/ray/tune/error.py b/python/ray/tune/error.py index 42eb3840eabf..9f2b427a2788 100644 --- a/python/ray/tune/error.py +++ b/python/ray/tune/error.py @@ -1,16 +1,20 @@ +from ray.util.annotations import PublicAPI + + +@PublicAPI class TuneError(Exception): """General error class raised by ray.tune.""" pass -class AbortTrialExecution(TuneError): +class _AbortTrialExecution(TuneError): """Error that indicates a trial should not be retried.""" pass -class SubCategoryTuneError(TuneError): +class _SubCategoryTuneError(TuneError): """The more specific TuneError that happens for a certain Tune subroutine. For example starting/stopping a trial. """ @@ -22,19 +26,19 @@ def __str__(self): return self.traceback_str -class TuneStopTrialError(SubCategoryTuneError): +class _TuneStopTrialError(_SubCategoryTuneError): """Error that happens when stopping a tune trial.""" pass -class TuneStartTrialError(SubCategoryTuneError): +class _TuneStartTrialError(_SubCategoryTuneError): """Error that happens when starting a tune trial.""" pass -class TuneGetNextExecutorEventError(SubCategoryTuneError): +class _TuneNoNextExecutorEventError(_SubCategoryTuneError): """Error that happens when waiting to get the next event to handle from RayTrialExecutor. diff --git a/python/ray/tune/function_runner.py b/python/ray/tune/function_runner.py index bc17cdb5de13..93cd5304b3c0 100644 --- a/python/ray/tune/function_runner.py +++ b/python/ray/tune/function_runner.py @@ -11,6 +11,7 @@ from typing import Any, Callable, Optional +from ray.util.annotations import DeveloperAPI from six.moves import queue from ray.util.debug import log_once @@ -42,6 +43,7 @@ TEMP_MARKER = ".temp_marker" +@DeveloperAPI class FuncCheckpointUtil: """Utility class holding various function-checkpointing mechanisms. @@ -120,14 +122,14 @@ def create_perm_checkpoint(checkpoint_dir, logdir, step): return perm_checkpoint_dir -class StatusReporter: +class _StatusReporter: """Object passed into your function that you can report status through. Example: - >>> from ray.tune.function_runner import StatusReporter - >>> reporter = StatusReporter(...) # doctest: +SKIP + >>> from ray.tune.function_runner import _StatusReporter + >>> reporter = _StatusReporter(...) # doctest: +SKIP >>> def trainable_function(config, reporter): # doctest: +SKIP - >>> assert isinstance(reporter, StatusReporter) # doctest: +SKIP + >>> assert isinstance(reporter, _StatusReporter) # doctest: +SKIP >>> reporter(timesteps_this_iter=1) # doctest: +SKIP """ @@ -169,8 +171,8 @@ def __call__(self, _metric=None, **kwargs): kwargs: Latest training result status. Example: - >>> from ray.tune.function_runner import StatusReporter - >>> reporter = StatusReporter(...) # doctest: +SKIP + >>> from ray.tune.function_runner import _StatusReporter + >>> reporter = _StatusReporter(...) # doctest: +SKIP >>> reporter(mean_accuracy=1, training_iteration=4) # doctest: +SKIP >>> reporter( # doctest: +SKIP ... mean_accuracy=1, training_iteration=4, done=True @@ -299,6 +301,7 @@ def run(self): ) +@DeveloperAPI class FunctionRunner(Trainable): """Trainable that runs a user function reporting results. @@ -323,7 +326,7 @@ def setup(self, config): # reporting to block until finished. self._error_queue = queue.Queue(1) - self._status_reporter = StatusReporter( + self._status_reporter = _StatusReporter( self._results_queue, self._continue_semaphore, self._end_event, diff --git a/python/ray/tune/insufficient_resources_manager.py b/python/ray/tune/insufficient_resources_manager.py index 5b9d8cc0d0e0..0599b98e44c1 100644 --- a/python/ray/tune/insufficient_resources_manager.py +++ b/python/ray/tune/insufficient_resources_manager.py @@ -86,7 +86,7 @@ def _get_insufficient_resources_error_msg(trial: Trial) -> str: ) -class InsufficientResourcesManager: +class _InsufficientResourcesManager: """Insufficient resources manager. Makes best effort, conservative guesses about if Tune loop is stuck due to diff --git a/python/ray/tune/integration/wandb.py b/python/ray/tune/integration/wandb.py index ab5f42658f2d..1d74b1ce1500 100644 --- a/python/ray/tune/integration/wandb.py +++ b/python/ray/tune/integration/wandb.py @@ -221,9 +221,10 @@ def __init__( self._trial_name = self.kwargs.get("name", "unknown") def run(self): - wandb.require("service") + # Since we're running in a separate process already, use threads. + os.environ["WANDB_START_METHOD"] = "thread" wandb.init(*self.args, **self.kwargs) - wandb.setup() + while True: item_type, item_content = self.queue.get() if item_type == _QueueItem.END: @@ -613,6 +614,6 @@ def __init__(self, config: Dict, *args, **kwargs): self.wandb = self._wandb.init(**wandb_init_kwargs) def stop(self): - self._wandb.join() + self._wandb.finish() if hasattr(super(), "stop"): super().stop() diff --git a/python/ray/tune/logger.py b/python/ray/tune/logger.py index eba7cbd5b9dc..ad7dd4d0c889 100644 --- a/python/ray/tune/logger.py +++ b/python/ray/tune/logger.py @@ -22,7 +22,7 @@ EXPR_RESULT_FILE, ) from ray.tune.utils import flatten_dict -from ray.util.annotations import PublicAPI +from ray.util.annotations import PublicAPI, DeveloperAPI if TYPE_CHECKING: from ray.tune.trial import Trial # noqa: F401 @@ -33,6 +33,7 @@ VALID_SUMMARY_TYPES = [int, float, np.float32, np.float64, np.int32, np.int64] +@DeveloperAPI class Logger: """Logging interface for ray.tune. @@ -76,11 +77,13 @@ def flush(self): pass +@PublicAPI class NoopLogger(Logger): def on_result(self, result): pass +@PublicAPI class JsonLogger(Logger): """Logs trial results in json format. @@ -119,6 +122,7 @@ def update_config(self, config: Dict): cloudpickle.dump(self.config, f) +@PublicAPI class CSVLogger(Logger): """Logs results to progress.csv under the trial directory. @@ -168,6 +172,7 @@ def close(self): self._file.close() +@PublicAPI class TBXLogger(Logger): """TensorBoardX Logger. @@ -296,6 +301,7 @@ def _try_log_hparams(self, result): DEFAULT_LOGGERS = (JsonLogger, CSVLogger, TBXLogger) +@PublicAPI class UnifiedLogger(Logger): """Unified result logger for TensorBoard, rllab/viskit, plain json. @@ -446,6 +452,7 @@ def on_trial_error( self.log_trial_end(trial, failed=True) +@DeveloperAPI class LegacyLoggerCallback(LoggerCallback): """Supports logging to trial-specific `Logger` classes. @@ -495,6 +502,7 @@ def log_trial_end(self, trial: "Trial", failed: bool = False): trial_loggers[trial].close() +@PublicAPI class JsonLoggerCallback(LoggerCallback): """Logs trial results in json format. @@ -551,6 +559,7 @@ def update_config(self, trial: "Trial", config: Dict): cloudpickle.dump(self._trial_configs[trial], f) +@PublicAPI class CSVLoggerCallback(LoggerCallback): """Logs results to progress.csv under the trial directory. @@ -608,6 +617,7 @@ def log_trial_end(self, trial: "Trial", failed: bool = False): del self._trial_files[trial] +@PublicAPI class TBXLoggerCallback(LoggerCallback): """TensorBoardX Logger. diff --git a/python/ray/tune/progress_reporter.py b/python/ray/tune/progress_reporter.py index 6e05b4a1b557..53a6d0045a05 100644 --- a/python/ray/tune/progress_reporter.py +++ b/python/ray/tune/progress_reporter.py @@ -27,7 +27,7 @@ TIMESTEPS_TOTAL, AUTO_RESULT_KEYS, ) -from ray.tune.trial import DEBUG_PRINT_INTERVAL, Trial, Location +from ray.tune.trial import DEBUG_PRINT_INTERVAL, Trial, _Location from ray.tune.utils import unflattened_lookup from ray.tune.utils.log import Verbosity, has_verbosity @@ -1021,12 +1021,12 @@ def _fair_filter_trials( return filtered_trials -def _get_trial_location(trial: Trial, result: dict) -> Location: +def _get_trial_location(trial: Trial, result: dict) -> _Location: # we get the location from the result, as the one in trial will be # reset when trial terminates node_ip, pid = result.get(NODE_IP, None), result.get(PID, None) if node_ip and pid: - location = Location(node_ip, pid) + location = _Location(node_ip, pid) else: # fallback to trial location if there hasn't been a report yet location = trial.location diff --git a/python/ray/tune/ray_trial_executor.py b/python/ray/tune/ray_trial_executor.py index 323e4b92ca22..2992b4adcb0d 100644 --- a/python/ray/tune/ray_trial_executor.py +++ b/python/ray/tune/ray_trial_executor.py @@ -23,19 +23,19 @@ import ray from ray.exceptions import GetTimeoutError, RayTaskError from ray.tune.error import ( - AbortTrialExecution, + _AbortTrialExecution, TuneError, - TuneStartTrialError, - TuneGetNextExecutorEventError, + _TuneStartTrialError, + _TuneNoNextExecutorEventError, ) from ray.tune.logger import NoopLogger from ray.tune.result import TRIAL_INFO, STDOUT_FILE, STDERR_FILE -from ray.tune.utils.placement_groups import PlacementGroupManager, get_tune_pg_prefix +from ray.tune.utils.placement_groups import _PlacementGroupManager, get_tune_pg_prefix from ray.tune.utils.trainable import TrainableUtil -from ray.tune.trial import Trial, _TuneCheckpoint, Location, TrialInfo +from ray.tune.trial import Trial, _TuneCheckpoint, _Location, _TrialInfo from ray.tune.trial_executor import TrialExecutor from ray.tune.utils import warn_if_slow -from ray.tune.utils.resource_updater import ResourceUpdater +from ray.tune.utils.resource_updater import _ResourceUpdater from ray.util import log_once from ray.util.annotations import DeveloperAPI from ray.util.placement_group import remove_placement_group, PlacementGroup @@ -149,7 +149,7 @@ def noop_logger_creator(config, logdir): return NoopLogger(config, logdir) -class ExecutorEventType(Enum): +class _ExecutorEventType(Enum): """The executor event type. Some of the events are internal events to executor while others @@ -165,7 +165,7 @@ class ExecutorEventType(Enum): YIELD = 8 # Yielding back to TrialRunner's main event loop. -class ExecutorEvent: +class _ExecutorEvent: """A struct that describes the event to be processed by TrialRunner. Attributes: @@ -180,7 +180,7 @@ class ExecutorEvent: def __init__( self, - event_type: ExecutorEventType, + event_type: _ExecutorEventType, trial: Optional[Trial] = None, result: Optional[Dict] = None, ): @@ -215,13 +215,13 @@ def __init__( else: self._trial_cleanup = None - self._resource_updater = ResourceUpdater(refresh_period) + self._resource_updater = _ResourceUpdater(refresh_period) self._has_cleaned_up_pgs = False self._reuse_actors = reuse_actors # The maxlen will be updated when `set_max_pending_trials()` is called self._cached_actor_pg = deque(maxlen=1) - self._pg_manager = PlacementGroupManager(prefix=get_tune_pg_prefix()) + self._pg_manager = _PlacementGroupManager(prefix=get_tune_pg_prefix()) self._staged_trials = set() self._trial_just_finished = False self._trial_just_finished_before = False @@ -308,7 +308,7 @@ def _setup_remote_runner(self, trial): if not self.reset_trial( trial, trial.config, trial.experiment_tag, logger_creator ): - raise AbortTrialExecution( + raise _AbortTrialExecution( "Trainable runner reuse requires reset_config() to be " "implemented and return True." ) @@ -316,7 +316,7 @@ def _setup_remote_runner(self, trial): trainable_cls = trial.get_trainable_cls() if not trainable_cls: - raise AbortTrialExecution( + raise _AbortTrialExecution( f"Invalid trainable: {trial.trainable_name}. If you passed " f"a string, make sure the trainable was registered before." ) @@ -328,12 +328,12 @@ def _setup_remote_runner(self, trial): full_actor_class = self._pg_manager.get_full_actor_cls(trial, _actor_cls) # Clear the Trial's location (to be updated later on result) # since we don't know where the remote runner is placed. - trial.set_location(Location()) + trial.set_location(_Location()) logger.debug("Trial %s: Setting up new remote runner.", trial) # Logging for trials is handled centrally by TrialRunner, so # configure the remote runner to use a noop-logger. trial_config = copy.deepcopy(trial.config) - trial_config[TRIAL_INFO] = TrialInfo(trial) + trial_config[TRIAL_INFO] = _TrialInfo(trial) stdout_file, stderr_file = trial.log_to_file trial_config[STDOUT_FILE] = stdout_file @@ -410,7 +410,7 @@ def _train(self, trial): if isinstance(remote, dict): remote = _LocalWrapper(remote) - self._futures[remote] = (ExecutorEventType.TRAINING_RESULT, trial) + self._futures[remote] = (_ExecutorEventType.TRAINING_RESULT, trial) trial_item = self._find_future(trial) assert len(trial_item) < 2, trial_item @@ -458,7 +458,7 @@ def _stop_trial( """ self.set_status(trial, Trial.ERROR if error or exc else Trial.TERMINATED) self._trial_just_finished = True - trial.set_location(Location()) + trial.set_location(_Location()) try: trial.write_error_log(exc=exc) @@ -498,7 +498,7 @@ def _stop_trial( future = trial.runner.stop.remote() pg = self._pg_manager.remove_from_in_use(trial) - self._futures[future] = (ExecutorEventType.STOP_RESULT, pg) + self._futures[future] = (_ExecutorEventType.STOP_RESULT, pg) if self._trial_cleanup: # force trial cleanup within a deadline self._trial_cleanup.add(future) @@ -524,7 +524,7 @@ def start_trial(self, trial: Trial) -> bool: """ try: return self._start_trial(trial) - except AbortTrialExecution as e: + except _AbortTrialExecution as e: logger.exception("Trial %s: Error starting runner, aborting!", trial) time.sleep(2) self._stop_trial(trial, exc=e) @@ -535,7 +535,9 @@ def start_trial(self, trial: Trial) -> bool: if isinstance(e, TuneError): self._stop_trial(trial, exc=e) else: - self._stop_trial(trial, exc=TuneStartTrialError(traceback.format_exc())) + self._stop_trial( + trial, exc=_TuneStartTrialError(traceback.format_exc()) + ) # Note that we don't return the resources, since they may # have been lost. TODO(ujvl): is this the right thing to do? return False @@ -590,7 +592,7 @@ def reset_trial( # Pass magic variables extra_config = copy.deepcopy(new_config) - extra_config[TRIAL_INFO] = TrialInfo(trial) + extra_config[TRIAL_INFO] = _TrialInfo(trial) stdout_file, stderr_file = trial.log_to_file extra_config[STDOUT_FILE] = stdout_file @@ -700,7 +702,7 @@ def save( value = trial.runner.save.remote() checkpoint = _TuneCheckpoint(storage, value, result) trial.saving_to = checkpoint - self._futures[value] = (ExecutorEventType.SAVING_RESULT, trial) + self._futures[value] = (_ExecutorEventType.SAVING_RESULT, trial) return checkpoint def restore(self, trial: Trial) -> None: @@ -745,13 +747,13 @@ def restore(self, trial: Trial) -> None: with self._change_working_directory(trial): remote = trial.runner.restore_from_object.remote(obj) else: - raise AbortTrialExecution( + raise _AbortTrialExecution( "Pass in `sync_on_checkpoint=True` for driver-based trial" "restoration. Pass in an `upload_dir` for remote " "storage-based restoration" ) - self._futures[remote] = (ExecutorEventType.RESTORING_RESULT, trial) + self._futures[remote] = (_ExecutorEventType.RESTORING_RESULT, trial) trial.restoring_from = checkpoint def export_trial_if_needed(self, trial: Trial) -> Dict: @@ -782,7 +784,7 @@ def cleanup(self, trials: List[Trial]) -> None: if not ready: continue event_type, trial_or_pg = self._futures.pop(ready[0]) - if event_type == ExecutorEventType.STOP_RESULT: + if event_type == _ExecutorEventType.STOP_RESULT: post_stop_cleanup(ready[0], trial_or_pg) self._pg_manager.reconcile_placement_groups(trials) @@ -808,7 +810,7 @@ def _change_working_directory(self, trial): def get_next_executor_event( self, live_trials: Set[Trial], next_trial_exists: bool - ) -> ExecutorEvent: + ) -> _ExecutorEvent: """Get the next executor event to be processed in TrialRunner. In case there are multiple events available for handling, the next @@ -863,11 +865,11 @@ def get_next_executor_event( # runner. The next trial can then be scheduled on this PG. if next_trial_exists: if len(self._cached_actor_pg) > 0: - return ExecutorEvent(ExecutorEventType.PG_READY) + return _ExecutorEvent(_ExecutorEventType.PG_READY) # TODO(xwjiang): Expose proper API when we decide to do # ActorPool abstraction. if any(len(r) > 0 for r in self._pg_manager._ready.values()): - return ExecutorEvent(ExecutorEventType.PG_READY) + return _ExecutorEvent(_ExecutorEventType.PG_READY) ################################################################### # Prepare for futures to wait @@ -901,11 +903,11 @@ def get_next_executor_event( # infeasible. # TODO: Move InsufficientResourceManager's logic # to TrialExecutor. It is not Runner's responsibility! - return ExecutorEvent(ExecutorEventType.NO_RUNNING_TRIAL_TIMEOUT) + return _ExecutorEvent(_ExecutorEventType.NO_RUNNING_TRIAL_TIMEOUT) else: # Training simply takes long time, yield the control back to main # event loop to print progress info etc. - return ExecutorEvent(ExecutorEventType.YIELD) + return _ExecutorEvent(_ExecutorEventType.YIELD) ################################################################### # If there is future returned. @@ -918,13 +920,13 @@ def get_next_executor_event( ################################################################### if ready_future not in self._futures.keys(): self._pg_manager.handle_ready_future(ready_future) - return ExecutorEvent(ExecutorEventType.PG_READY) + return _ExecutorEvent(_ExecutorEventType.PG_READY) ################################################################### # non PG_READY event ################################################################### result_type, trial_or_pg = self._futures.pop(ready_future) - if result_type == ExecutorEventType.STOP_RESULT: + if result_type == _ExecutorEventType.STOP_RESULT: pg = trial_or_pg post_stop_cleanup(ready_future, pg) else: @@ -936,30 +938,30 @@ def get_next_executor_event( if isinstance(future_result, _LocalWrapper): future_result = future_result.unwrap() if result_type in ( - ExecutorEventType.TRAINING_RESULT, - ExecutorEventType.SAVING_RESULT, - ExecutorEventType.RESTORING_RESULT, + _ExecutorEventType.TRAINING_RESULT, + _ExecutorEventType.SAVING_RESULT, + _ExecutorEventType.RESTORING_RESULT, ): logger.debug(f"Returning [{result_type}] for trial {trial}") - return ExecutorEvent( + return _ExecutorEvent( result_type, trial, - result={ExecutorEvent.KEY_FUTURE_RESULT: future_result}, + result={_ExecutorEvent.KEY_FUTURE_RESULT: future_result}, ) else: raise TuneError(f"Unexpected future type - [{result_type}]") except RayTaskError as e: - return ExecutorEvent( - ExecutorEventType.ERROR, + return _ExecutorEvent( + _ExecutorEventType.ERROR, trial, - result={ExecutorEvent.KEY_EXCEPTION: e.as_instanceof_cause()}, + result={_ExecutorEvent.KEY_EXCEPTION: e.as_instanceof_cause()}, ) except Exception: - return ExecutorEvent( - ExecutorEventType.ERROR, + return _ExecutorEvent( + _ExecutorEventType.ERROR, trial, result={ - ExecutorEvent.KEY_EXCEPTION: TuneGetNextExecutorEventError( + _ExecutorEvent.KEY_EXCEPTION: _TuneNoNextExecutorEventError( traceback.format_exc() ) }, diff --git a/python/ray/tune/resources.py b/python/ray/tune/resources.py index 7d960139b6ae..6de29d624a7d 100644 --- a/python/ray/tune/resources.py +++ b/python/ray/tune/resources.py @@ -6,6 +6,7 @@ # For compatibility under py2 to consider unicode as str from typing import Optional +from ray.util.annotations import Deprecated from six import string_types from ray._private.resource_spec import NODE_ID_PREFIX @@ -14,6 +15,7 @@ logger = logging.getLogger(__name__) +@Deprecated class Resources( namedtuple( "Resources", diff --git a/python/ray/tune/sample.py b/python/ray/tune/sample.py index 478e14ea6198..a9de07b5749f 100644 --- a/python/ray/tune/sample.py +++ b/python/ray/tune/sample.py @@ -7,6 +7,8 @@ import numpy as np # Backwards compatibility +from ray.util.annotations import DeveloperAPI, PublicAPI + try: # Added in numpy>=1.17 but we require numpy>=1.16 np_random_generator = np.random.Generator @@ -67,6 +69,7 @@ def __getattr__(self, name: str) -> Any: ] +@DeveloperAPI class Domain: """Base class to specify a type and valid range to sample parameters from. @@ -129,6 +132,7 @@ def domain_str(self): return "(unknown)" +@DeveloperAPI class Sampler: def sample( self, @@ -140,16 +144,19 @@ def sample( raise NotImplementedError +@DeveloperAPI class BaseSampler(Sampler): def __str__(self): return "Base" +@DeveloperAPI class Uniform(Sampler): def __str__(self): return "Uniform" +@DeveloperAPI class LogUniform(Sampler): def __init__(self, base: float = 10): self.base = base @@ -159,6 +166,7 @@ def __str__(self): return "LogUniform" +@DeveloperAPI class Normal(Sampler): def __init__(self, mean: float = 0.0, sd: float = 0.0): self.mean = mean @@ -170,6 +178,7 @@ def __str__(self): return "Normal" +@DeveloperAPI class Grid(Sampler): """Dummy sampler used for grid search""" @@ -183,6 +192,7 @@ def sample( return RuntimeError("Do not call `sample()` on grid.") +@DeveloperAPI class Float(Domain): class _Uniform(Uniform): def sample( @@ -315,6 +325,7 @@ def domain_str(self): return f"({self.lower}, {self.upper})" +@DeveloperAPI class Integer(Domain): class _Uniform(Uniform): def sample( @@ -396,6 +407,7 @@ def domain_str(self): return f"({self.lower}, {self.upper})" +@DeveloperAPI class Categorical(Domain): class _Uniform(Uniform): def sample( @@ -444,6 +456,7 @@ def domain_str(self): return f"{self.categories}" +@DeveloperAPI class Function(Domain): class _CallSampler(BaseSampler): def sample( @@ -499,6 +512,7 @@ def domain_str(self): return f"{self.func}()" +@DeveloperAPI class Quantized(Sampler): def __init__(self, sampler: Sampler, q: Union[float, int]): self.sampler = sampler @@ -532,6 +546,7 @@ def function(func): ) +@PublicAPI def sample_from(func: Callable[[Dict], Any]): """Specify that tune should sample configuration values from this function. @@ -541,6 +556,7 @@ def sample_from(func: Callable[[Dict], Any]): return Function(func) +@PublicAPI def uniform(lower: float, upper: float): """Sample a float value uniformly between ``lower`` and ``upper``. @@ -551,6 +567,7 @@ def uniform(lower: float, upper: float): return Float(lower, upper).uniform() +@PublicAPI def quniform(lower: float, upper: float, q: float): """Sample a quantized float value uniformly between ``lower`` and ``upper``. @@ -564,6 +581,7 @@ def quniform(lower: float, upper: float, q: float): return Float(lower, upper).uniform().quantized(q) +@PublicAPI def loguniform(lower: float, upper: float, base: float = 10): """Sugar for sampling in different orders of magnitude. @@ -576,6 +594,7 @@ def loguniform(lower: float, upper: float, base: float = 10): return Float(lower, upper).loguniform(base) +@PublicAPI def qloguniform(lower: float, upper: float, q: float, base: float = 10): """Sugar for sampling in different orders of magnitude. @@ -594,6 +613,7 @@ def qloguniform(lower: float, upper: float, q: float, base: float = 10): return Float(lower, upper).loguniform(base).quantized(q) +@PublicAPI def choice(categories: Sequence): """Sample a categorical value. @@ -604,6 +624,7 @@ def choice(categories: Sequence): return Categorical(categories).uniform() +@PublicAPI def randint(lower: int, upper: int): """Sample an integer value uniformly between ``lower`` and ``upper``. @@ -621,6 +642,7 @@ def randint(lower: int, upper: int): return Integer(lower, upper).uniform() +@PublicAPI def lograndint(lower: int, upper: int, base: float = 10): """Sample an integer value log-uniformly between ``lower`` and ``upper``, with ``base`` being the base of logarithm. @@ -636,6 +658,7 @@ def lograndint(lower: int, upper: int, base: float = 10): return Integer(lower, upper).loguniform(base) +@PublicAPI def qrandint(lower: int, upper: int, q: int = 1): """Sample an integer value uniformly between ``lower`` and ``upper``. @@ -653,6 +676,7 @@ def qrandint(lower: int, upper: int, q: int = 1): return Integer(lower, upper).uniform().quantized(q) +@PublicAPI def qlograndint(lower: int, upper: int, q: int, base: float = 10): """Sample an integer value log-uniformly between ``lower`` and ``upper``, with ``base`` being the base of logarithm. @@ -671,6 +695,7 @@ def qlograndint(lower: int, upper: int, q: int, base: float = 10): return Integer(lower, upper).loguniform(base).quantized(q) +@PublicAPI def randn(mean: float = 0.0, sd: float = 1.0): """Sample a float value normally with ``mean`` and ``sd``. @@ -682,6 +707,7 @@ def randn(mean: float = 0.0, sd: float = 1.0): return Float(None, None).normal(mean, sd) +@PublicAPI def qrandn(mean: float, sd: float, q: float): """Sample a float value normally with ``mean`` and ``sd``. diff --git a/python/ray/tune/schedulers/async_hyperband.py b/python/ray/tune/schedulers/async_hyperband.py index 8272a22a2174..540d56c7add1 100644 --- a/python/ray/tune/schedulers/async_hyperband.py +++ b/python/ray/tune/schedulers/async_hyperband.py @@ -8,10 +8,12 @@ from ray.tune.result import DEFAULT_METRIC from ray.tune.schedulers.trial_scheduler import FIFOScheduler, TrialScheduler from ray.tune.trial import Trial +from ray.util import PublicAPI logger = logging.getLogger(__name__) +@PublicAPI class AsyncHyperBandScheduler(FIFOScheduler): """Implements the Async Successive Halving. diff --git a/python/ray/tune/schedulers/hb_bohb.py b/python/ray/tune/schedulers/hb_bohb.py index 7652af95d8f8..0cb3171427c2 100644 --- a/python/ray/tune/schedulers/hb_bohb.py +++ b/python/ray/tune/schedulers/hb_bohb.py @@ -5,10 +5,12 @@ from ray.tune.schedulers.trial_scheduler import TrialScheduler from ray.tune.schedulers.hyperband import HyperBandScheduler from ray.tune.trial import Trial +from ray.util import PublicAPI logger = logging.getLogger(__name__) +@PublicAPI class HyperBandForBOHB(HyperBandScheduler): """Extends HyperBand early stopping algorithm for BOHB. diff --git a/python/ray/tune/schedulers/hyperband.py b/python/ray/tune/schedulers/hyperband.py index 12e940102390..40be9fdb5d72 100644 --- a/python/ray/tune/schedulers/hyperband.py +++ b/python/ray/tune/schedulers/hyperband.py @@ -9,6 +9,7 @@ from ray.tune.schedulers.trial_scheduler import FIFOScheduler, TrialScheduler from ray.tune.trial import Trial from ray.tune.error import TuneError +from ray.util import PublicAPI logger = logging.getLogger(__name__) @@ -35,6 +36,7 @@ # `max_attr=81, eta=3` from the blog post. Trials will fill up # from smallest bracket to largest, with largest # having the most rounds of successive halving. +@PublicAPI class HyperBandScheduler(FIFOScheduler): """Implements the HyperBand early stopping algorithm. @@ -188,7 +190,7 @@ def on_trial_add(self, trial_runner: "trial_runner.TrialRunner", trial: Trial): self._trial_info[trial] = cur_bracket, self._state["band_idx"] def _create_bracket(self, s): - return Bracket( + return _Bracket( time_attr=self._time_attr, max_trials=self._get_n0(s), init_t_attr=self._get_r0(s), @@ -232,7 +234,7 @@ def on_trial_result( return action def _process_bracket( - self, trial_runner: "trial_runner.TrialRunner", bracket: "Bracket" + self, trial_runner: "trial_runner.TrialRunner", bracket: "_Bracket" ) -> str: """This is called whenever a trial makes progress. @@ -353,7 +355,7 @@ def state(self) -> Dict[str, int]: } -class Bracket: +class _Bracket: """Logical object for tracking Hyperband bracket progress. Keeps track of proper parameters as designated by HyperBand. diff --git a/python/ray/tune/schedulers/median_stopping_rule.py b/python/ray/tune/schedulers/median_stopping_rule.py index 4948af94ade6..0ccb3643197a 100644 --- a/python/ray/tune/schedulers/median_stopping_rule.py +++ b/python/ray/tune/schedulers/median_stopping_rule.py @@ -8,10 +8,12 @@ from ray.tune.result import DEFAULT_METRIC from ray.tune.trial import Trial from ray.tune.schedulers.trial_scheduler import FIFOScheduler, TrialScheduler +from ray.util.annotations import PublicAPI logger = logging.getLogger(__name__) +@PublicAPI class MedianStoppingRule(FIFOScheduler): """Implements the median stopping rule as described in the Vizier paper: diff --git a/python/ray/tune/schedulers/pbt.py b/python/ray/tune/schedulers/pbt.py index a43276926eb5..32641bc4b3f5 100644 --- a/python/ray/tune/schedulers/pbt.py +++ b/python/ray/tune/schedulers/pbt.py @@ -17,12 +17,13 @@ from ray.tune.schedulers import FIFOScheduler, TrialScheduler from ray.tune.suggest.variant_generator import format_vars from ray.tune.trial import Trial, _TuneCheckpoint +from ray.util import PublicAPI from ray.util.debug import log_once logger = logging.getLogger(__name__) -class PBTTrialState: +class _PBTTrialState: """Internal PBT state tracked per-trial.""" def __init__(self, trial: Trial): @@ -44,7 +45,7 @@ def __repr__(self) -> str: ) -def explore( +def _explore( config: Dict, mutations: Dict, resample_probability: float, @@ -65,7 +66,7 @@ def explore( for key, distribution in mutations.items(): if isinstance(distribution, dict): new_config.update( - {key: explore(config[key], mutations[key], resample_probability, None)} + {key: _explore(config[key], mutations[key], resample_probability, None)} ) elif isinstance(distribution, list): if ( @@ -100,7 +101,7 @@ def explore( return new_config -def make_experiment_tag(orig_tag: str, config: Dict, mutations: Dict) -> str: +def _make_experiment_tag(orig_tag: str, config: Dict, mutations: Dict) -> str: """Appends perturbed params to the trial name to show in the console.""" resolved_vars = {} @@ -109,7 +110,7 @@ def make_experiment_tag(orig_tag: str, config: Dict, mutations: Dict) -> str: return "{}@perturbed[{}]".format(orig_tag, format_vars(resolved_vars)) -def fill_config( +def _fill_config( config: Dict, attr: str, search_space: Union[Callable, Domain, list, dict] ): """Add attr to config by sampling from search_space.""" @@ -122,9 +123,10 @@ def fill_config( elif isinstance(search_space, dict): config[attr] = {} for k, v in search_space.items(): - fill_config(config[attr], k, v) + _fill_config(config[attr], k, v) +@PublicAPI class PopulationBasedTraining(FIFOScheduler): """Implements the Population Based Training (PBT) algorithm. @@ -362,7 +364,7 @@ def on_trial_add(self, trial_runner: "trial_runner.TrialRunner", trial: Trial): ) ) - self._trial_state[trial] = PBTTrialState(trial) + self._trial_state[trial] = _PBTTrialState(trial) for attr in self._hyperparam_mutations.keys(): if attr not in trial.config: @@ -373,7 +375,7 @@ def on_trial_add(self, trial_runner: "trial_runner.TrialRunner", trial: Trial): ) # Add attr to trial's config by sampling search space from # hyperparam_mutations. - fill_config(trial.config, attr, self._hyperparam_mutations[attr]) + _fill_config(trial.config, attr, self._hyperparam_mutations[attr]) # Make sure this attribute is added to CLI output. trial.evaluated_params[attr] = trial.config[attr] @@ -490,7 +492,7 @@ def on_trial_result( ) def _save_trial_state( - self, state: PBTTrialState, time: int, result: Dict, trial: Trial + self, state: _PBTTrialState, time: int, result: Dict, trial: Trial ): """Saves necessary trial information when result is received. Args: @@ -548,8 +550,8 @@ def _checkpoint_or_exploit( def _log_config_on_step( self, - trial_state: PBTTrialState, - new_state: PBTTrialState, + trial_state: _PBTTrialState, + new_state: _PBTTrialState, trial: Trial, trial_to_clone: Trial, new_config: Dict, @@ -586,7 +588,7 @@ def _log_config_on_step( def _get_new_config(self, trial, trial_to_clone): """Gets new config for trial by exploring trial_to_clone's config.""" - return explore( + return _explore( trial_to_clone.config, self._hyperparam_mutations, self._resample_probability, @@ -632,7 +634,7 @@ def _exploit( trial_state, new_state, trial, trial_to_clone, new_config ) - new_tag = make_experiment_tag( + new_tag = _make_experiment_tag( trial_state.orig_tag, new_config, self._hyperparam_mutations ) if trial.status == Trial.PAUSED: @@ -728,6 +730,7 @@ def debug_string(self) -> str: ) +@PublicAPI class PopulationBasedTrainingReplay(FIFOScheduler): """Replays a Population Based Training run. @@ -875,7 +878,7 @@ def on_trial_result( trial, _TuneCheckpoint.MEMORY, result=result ) - new_tag = make_experiment_tag(self.experiment_tag, new_config, new_config) + new_tag = _make_experiment_tag(self.experiment_tag, new_config, new_config) trial_executor = trial_runner.trial_executor trial_executor.stop_trial(trial) diff --git a/python/ray/tune/schedulers/resource_changing_scheduler.py b/python/ray/tune/schedulers/resource_changing_scheduler.py index 37c2cffacffb..e1837decfee7 100644 --- a/python/ray/tune/schedulers/resource_changing_scheduler.py +++ b/python/ray/tune/schedulers/resource_changing_scheduler.py @@ -678,6 +678,7 @@ def evenly_distribute_cpus_gpus_distributed( ) +@PublicAPI(stability="beta") class ResourceChangingScheduler(TrialScheduler): """A utility scheduler to dynamically change resources of live trials. diff --git a/python/ray/tune/schedulers/trial_scheduler.py b/python/ray/tune/schedulers/trial_scheduler.py index c8f0ccfe7616..6da357e03bf7 100644 --- a/python/ray/tune/schedulers/trial_scheduler.py +++ b/python/ray/tune/schedulers/trial_scheduler.py @@ -3,8 +3,10 @@ from ray.tune import trial_runner from ray.tune.result import DEFAULT_METRIC from ray.tune.trial import Trial +from ray.util.annotations import DeveloperAPI, PublicAPI +@DeveloperAPI class TrialScheduler: """Interface for implementing a Trial Scheduler class.""" @@ -121,6 +123,7 @@ def restore(self, checkpoint_path: str): raise NotImplementedError +@PublicAPI class FIFOScheduler(TrialScheduler): """Simple scheduler that just runs trials in submission order.""" diff --git a/python/ray/tune/suggest/basic_variant.py b/python/ray/tune/suggest/basic_variant.py index 512fcda2f3ed..976016410078 100644 --- a/python/ray/tune/suggest/basic_variant.py +++ b/python/ray/tune/suggest/basic_variant.py @@ -21,6 +21,7 @@ ) from ray.tune.suggest.search import SearchAlgorithm from ray.tune.utils.util import atomic_save, load_newest_checkpoint +from ray.util import PublicAPI SERIALIZATION_THRESHOLD = 1e6 @@ -184,6 +185,7 @@ def __iter__(self): return self +@PublicAPI class BasicVariantGenerator(SearchAlgorithm): """Uses Tune's variant generation for resolving variables. diff --git a/python/ray/tune/suggest/repeater.py b/python/ray/tune/suggest/repeater.py index eaa36b396800..82f0873a666f 100644 --- a/python/ray/tune/suggest/repeater.py +++ b/python/ray/tune/suggest/repeater.py @@ -6,6 +6,7 @@ from ray.tune.suggest.suggestion import Searcher from ray.tune.suggest.util import set_search_properties_backwards_compatible +from ray.util import PublicAPI logger = logging.getLogger(__name__) @@ -70,6 +71,7 @@ def count(self) -> int: return len(self._trials) +@PublicAPI class Repeater(Searcher): """A wrapper algorithm for repeating trials of same parameters. diff --git a/python/ray/tune/suggest/search.py b/python/ray/tune/suggest/search.py index 45ca9e2382fd..1bd100ce03f6 100644 --- a/python/ray/tune/suggest/search.py +++ b/python/ray/tune/suggest/search.py @@ -1,8 +1,10 @@ from typing import Dict, List, Optional, Union from ray.tune.experiment import Experiment +from ray.util.annotations import DeveloperAPI +@DeveloperAPI class SearchAlgorithm: """Interface of an event handler API for hyperparameter search. diff --git a/python/ray/tune/suggest/search_generator.py b/python/ray/tune/suggest/search_generator.py index 3261366d690b..a7814d0bdd49 100644 --- a/python/ray/tune/suggest/search_generator.py +++ b/python/ray/tune/suggest/search_generator.py @@ -16,6 +16,7 @@ atomic_save, load_newest_checkpoint, ) +from ray.util.annotations import DeveloperAPI logger = logging.getLogger(__name__) @@ -26,6 +27,7 @@ def _warn_on_repeater(searcher, total_samples): _warn_num_samples(searcher, total_samples) +@DeveloperAPI class SearchGenerator(SearchAlgorithm): """Generates trials to be passed to the TrialRunner. diff --git a/python/ray/tune/suggest/suggestion.py b/python/ray/tune/suggest/suggestion.py index a443f0ba817b..091f1ea5749f 100644 --- a/python/ray/tune/suggest/suggestion.py +++ b/python/ray/tune/suggest/suggestion.py @@ -6,6 +6,7 @@ from typing import Dict, Optional, List, Union, Any, TYPE_CHECKING from ray.tune.suggest.util import set_search_properties_backwards_compatible +from ray.util.annotations import DeveloperAPI, PublicAPI from ray.util.debug import log_once if TYPE_CHECKING: @@ -37,6 +38,7 @@ ) +@DeveloperAPI class Searcher: """Abstract class for wrapping suggesting algorithms. @@ -423,6 +425,7 @@ def mode(self) -> str: return self._mode +@PublicAPI class ConcurrencyLimiter(Searcher): """A wrapper algorithm for limiting the number of concurrent trials. diff --git a/python/ray/tune/suggest/variant_generator.py b/python/ray/tune/suggest/variant_generator.py index 3633bc92661c..5e82cce4ba3b 100644 --- a/python/ray/tune/suggest/variant_generator.py +++ b/python/ray/tune/suggest/variant_generator.py @@ -9,6 +9,7 @@ from ray.tune import TuneError from ray.tune.sample import Categorical, Domain, Function, RandomState +from ray.util.annotations import DeveloperAPI logger = logging.getLogger(__name__) @@ -473,6 +474,7 @@ def __getattribute__(self, item): return value +@DeveloperAPI class RecursiveDependencyError(Exception): def __init__(self, msg: str): Exception.__init__(self, msg) diff --git a/python/ray/tune/sync_client.py b/python/ray/tune/sync_client.py index 2715e33e31b4..58368e6330a4 100644 --- a/python/ray/tune/sync_client.py +++ b/python/ray/tune/sync_client.py @@ -16,7 +16,7 @@ import ray from ray.tune.error import TuneError from ray.tune.utils.file_transfer import sync_dir_between_nodes, delete_on_node -from ray.util.annotations import PublicAPI +from ray.util.annotations import PublicAPI, DeveloperAPI from ray.ml.utils.remote_storage import ( S3_PREFIX, GS_PREFIX, @@ -186,6 +186,7 @@ def _is_legacy_sync_fn(func) -> bool: return True +@DeveloperAPI class FunctionBasedClient(SyncClient): def __init__(self, sync_up_func, sync_down_func, delete_func=None): self.sync_up_func = sync_up_func @@ -226,6 +227,7 @@ def delete(self, target): NOOP = FunctionBasedClient(noop, noop) +@DeveloperAPI class CommandBasedClient(SyncClient): """Syncs between two directories with the given command. @@ -429,6 +431,7 @@ def _validate_exclude_template(exclude_template): ) +@DeveloperAPI class RemoteTaskClient(SyncClient): """Sync client that uses remote tasks to synchronize two directories. diff --git a/python/ray/tune/syncer.py b/python/ray/tune/syncer.py index ab6d05663bb0..58fa3dcaa4ee 100644 --- a/python/ray/tune/syncer.py +++ b/python/ray/tune/syncer.py @@ -37,7 +37,7 @@ SyncClient, RemoteTaskClient, ) -from ray.util.annotations import PublicAPI +from ray.util.annotations import PublicAPI, DeveloperAPI if TYPE_CHECKING: from ray.tune.trial import Trial @@ -208,6 +208,7 @@ def __post_init__(self): validate_sync_config(self) +@DeveloperAPI class Syncer: def __init__(self, local_dir: str, remote_dir: str, sync_client: SyncClient = NOOP): """Syncs between two directories with the sync_function. @@ -321,6 +322,7 @@ def _remote_path(self) -> Optional[Union[str, Tuple[str, str]]]: return self._remote_dir +@DeveloperAPI class CloudSyncer(Syncer): """Syncer for syncing files to/from the cloud.""" @@ -336,6 +338,7 @@ def sync_down_if_needed(self, exclude: Optional[List] = None): ) +@DeveloperAPI class NodeSyncer(Syncer): """Syncer for syncing files to/from a remote dir to a local dir.""" @@ -407,6 +410,7 @@ def _remote_path(self) -> Optional[Union[str, Tuple[str, str]]]: return "{}@{}:{}/".format(ssh_user, self.worker_ip, self._remote_dir) +@DeveloperAPI def get_cloud_syncer( local_dir: str, remote_dir: Optional[str] = None, @@ -460,6 +464,7 @@ def get_cloud_syncer( return _syncers[key] +@DeveloperAPI def get_node_syncer( local_dir: str, remote_dir: Optional[str] = None, @@ -507,6 +512,7 @@ def get_node_syncer( return _syncers[key] +@DeveloperAPI class SyncerCallback(Callback): def __init__(self, sync_function: Optional[Union[bool, Callable]]): self._sync_function = sync_function diff --git a/python/ray/tune/tests/test_checkpoint_manager.py b/python/ray/tune/tests/test_checkpoint_manager.py index b4a9267b63d9..b9bcf5421203 100644 --- a/python/ray/tune/tests/test_checkpoint_manager.py +++ b/python/ray/tune/tests/test_checkpoint_manager.py @@ -7,7 +7,7 @@ from unittest.mock import patch from ray.tune.result import TRAINING_ITERATION -from ray.tune.checkpoint_manager import _TuneCheckpoint, CheckpointManager, logger +from ray.tune.checkpoint_manager import _TuneCheckpoint, _CheckpointManager, logger class CheckpointManagerTest(unittest.TestCase): @@ -16,7 +16,7 @@ def mock_result(metric, i): return {"i": metric, TRAINING_ITERATION: i} def checkpoint_manager(self, keep_checkpoints_num): - return CheckpointManager(keep_checkpoints_num, "i", delete_fn=lambda c: None) + return _CheckpointManager(keep_checkpoints_num, "i", delete_fn=lambda c: None) def testNewestCheckpoint(self): checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num=1) @@ -180,7 +180,7 @@ def testOnMemoryCheckpoint(self): self.assertEqual(checkpoint_manager.best_checkpoints(), []) def testSameCheckpoint(self): - checkpoint_manager = CheckpointManager( + checkpoint_manager = _CheckpointManager( 1, "i", delete_fn=lambda c: os.remove(c.value) ) diff --git a/python/ray/tune/tests/test_integration_wandb.py b/python/ray/tune/tests/test_integration_wandb.py index e2b4a5053995..da0cc98a313f 100644 --- a/python/ray/tune/tests/test_integration_wandb.py +++ b/python/ray/tune/tests/test_integration_wandb.py @@ -18,7 +18,7 @@ _QueueItem, ) from ray.tune.result import TRIAL_INFO -from ray.tune.trial import TrialInfo +from ray.tune.trial import _TrialInfo from ray.tune.utils.placement_groups import PlacementGroupFactory @@ -228,7 +228,7 @@ def testWandbMixinConfig(self): PlacementGroupFactory([{"CPU": 1}]), "/tmp", ) - trial_info = TrialInfo(trial) + trial_info = _TrialInfo(trial) config[TRIAL_INFO] = trial_info @@ -290,7 +290,7 @@ def testWandbDecoratorConfig(self): PlacementGroupFactory([{"CPU": 1}]), "/tmp", ) - trial_info = TrialInfo(trial) + trial_info = _TrialInfo(trial) @wandb_mixin def train_fn(config): diff --git a/python/ray/tune/tests/test_progress_reporter.py b/python/ray/tune/tests/test_progress_reporter.py index eff306555e2d..9ef8a177c593 100644 --- a/python/ray/tune/tests/test_progress_reporter.py +++ b/python/ray/tune/tests/test_progress_reporter.py @@ -68,7 +68,7 @@ END_TO_END_COMMAND = """ import ray from ray import tune -from ray.tune.trial import Location +from ray.tune.trial import _Location from ray.tune.progress_reporter import _get_trial_location from unittest.mock import patch @@ -76,7 +76,7 @@ def mock_get_trial_location(trial, result): location = _get_trial_location(trial, result) if location.pid: - return Location("123.123.123.123", "1") + return _Location("123.123.123.123", "1") return location @@ -262,7 +262,7 @@ def f(config): import random import numpy as np import time -from ray.tune.trial import Location +from ray.tune.trial import _Location from ray.tune.progress_reporter import _get_trial_location from unittest.mock import patch @@ -270,7 +270,7 @@ def f(config): def mock_get_trial_location(trial, result): location = _get_trial_location(trial, result) if location.pid: - return Location("123.123.123.123", "1") + return _Location("123.123.123.123", "1") return location diff --git a/python/ray/tune/tests/test_ray_trial_executor.py b/python/ray/tune/tests/test_ray_trial_executor.py index 77334abb0551..9ff66a7fbded 100644 --- a/python/ray/tune/tests/test_ray_trial_executor.py +++ b/python/ray/tune/tests/test_ray_trial_executor.py @@ -11,8 +11,8 @@ from ray.tune import Trainable from ray.tune.callback import Callback from ray.tune.ray_trial_executor import ( - ExecutorEvent, - ExecutorEventType, + _ExecutorEvent, + _ExecutorEventType, RayTrialExecutor, ) from ray.tune.registry import _global_registry, TRAINABLE_CLASS, register_trainable @@ -21,7 +21,10 @@ from ray.tune.trial import Trial, _TuneCheckpoint from ray.tune.resources import Resources from ray.cluster_utils import Cluster -from ray.tune.utils.placement_groups import PlacementGroupFactory, PlacementGroupManager +from ray.tune.utils.placement_groups import ( + PlacementGroupFactory, + _PlacementGroupManager, +) from unittest.mock import patch @@ -99,7 +102,7 @@ def _simulate_starting_trial(self, trial): future_result = self.trial_executor.get_next_executor_event( live_trials={trial}, next_trial_exists=True ) - assert future_result.type == ExecutorEventType.PG_READY + assert future_result.type == _ExecutorEventType.PG_READY self.assertTrue(self.trial_executor.start_trial(trial)) self.assertEqual(Trial.RUNNING, trial.status) @@ -108,9 +111,9 @@ def _simulate_getting_result(self, trial): event = self.trial_executor.get_next_executor_event( live_trials={trial}, next_trial_exists=False ) - if event.type == ExecutorEventType.TRAINING_RESULT: + if event.type == _ExecutorEventType.TRAINING_RESULT: break - training_result = event.result[ExecutorEvent.KEY_FUTURE_RESULT] + training_result = event.result[_ExecutorEvent.KEY_FUTURE_RESULT] if isinstance(training_result, list): for r in training_result: trial.update_last_result(r) @@ -124,8 +127,8 @@ def _simulate_saving(self, trial): event = self.trial_executor.get_next_executor_event( live_trials={trial}, next_trial_exists=False ) - assert event.type == ExecutorEventType.SAVING_RESULT - self.process_trial_save(trial, event.result[ExecutorEvent.KEY_FUTURE_RESULT]) + assert event.type == _ExecutorEventType.SAVING_RESULT + self.process_trial_save(trial, event.result[_ExecutorEvent.KEY_FUTURE_RESULT]) self.assertEqual(checkpoint, trial.checkpoint) def testStartStop(self): @@ -489,7 +492,7 @@ def testPlacementGroupFactoryEquality(self): self.assertEqual(counter[pgf_3], 3) def testHasResourcesForTrialWithCaching(self): - pgm = PlacementGroupManager() + pgm = _PlacementGroupManager() pgf1 = PlacementGroupFactory([{"CPU": self.head_cpus}]) pgf2 = PlacementGroupFactory([{"CPU": self.head_cpus - 1}]) diff --git a/python/ray/tune/tests/test_resource_updater.py b/python/ray/tune/tests/test_resource_updater.py index 0f9c7f41d5a9..3396ddfa8775 100644 --- a/python/ray/tune/tests/test_resource_updater.py +++ b/python/ray/tune/tests/test_resource_updater.py @@ -1,12 +1,12 @@ import ray from ray.tests.conftest import * # noqa -from ray.tune.utils.resource_updater import ResourceUpdater +from ray.tune.utils.resource_updater import _ResourceUpdater def test_resource_updater(ray_start_cluster): cluster = ray_start_cluster - resource_updater = ResourceUpdater(refresh_period=100) + resource_updater = _ResourceUpdater(refresh_period=100) # Before intialization, all resources are 0. assert resource_updater.get_num_cpus() == 0 assert resource_updater.get_num_gpus() == 0 @@ -26,7 +26,7 @@ def test_resource_updater(ray_start_cluster): assert resource_updater.get_num_cpus() == 1 assert resource_updater.get_num_gpus() == 2 - resource_updater = ResourceUpdater(refresh_period=0) + resource_updater = _ResourceUpdater(refresh_period=0) assert resource_updater.get_num_cpus() == 2 assert resource_updater.get_num_gpus() == 3 diff --git a/python/ray/tune/tests/test_trial_runner_callbacks.py b/python/ray/tune/tests/test_trial_runner_callbacks.py index cae78aeb0476..6a9138306c28 100644 --- a/python/ray/tune/tests/test_trial_runner_callbacks.py +++ b/python/ray/tune/tests/test_trial_runner_callbacks.py @@ -13,8 +13,8 @@ from ray.tune.checkpoint_manager import _TuneCheckpoint from ray.tune.logger import DEFAULT_LOGGERS, LoggerCallback, LegacyLoggerCallback from ray.tune.ray_trial_executor import ( - ExecutorEvent, - ExecutorEventType, + _ExecutorEvent, + _ExecutorEventType, RayTrialExecutor, ) from ray.tune.result import TRAINING_ITERATION @@ -110,8 +110,8 @@ def testCallbackSteps(self): for t in trials: self.trial_runner.add_trial(t) - self.executor.next_future_result = ExecutorEvent( - event_type=ExecutorEventType.PG_READY + self.executor.next_future_result = _ExecutorEvent( + event_type=_ExecutorEventType.PG_READY ) self.trial_runner.step() @@ -134,8 +134,8 @@ def testCallbackSteps(self): ) ) - self.executor.next_future_result = ExecutorEvent( - event_type=ExecutorEventType.PG_READY + self.executor.next_future_result = _ExecutorEvent( + event_type=_ExecutorEventType.PG_READY ) self.trial_runner.step() @@ -156,10 +156,10 @@ def testCallbackSteps(self): trials[0].saving_to = cp # Let the first trial save a checkpoint - self.executor.next_future_result = ExecutorEvent( - event_type=ExecutorEventType.SAVING_RESULT, + self.executor.next_future_result = _ExecutorEvent( + event_type=_ExecutorEventType.SAVING_RESULT, trial=trials[0], - result={ExecutorEvent.KEY_FUTURE_RESULT: "__checkpoint"}, + result={_ExecutorEvent.KEY_FUTURE_RESULT: "__checkpoint"}, ) self.trial_runner.step() self.assertEqual(self.callback.state["trial_save"]["iteration"], 2) @@ -167,8 +167,8 @@ def testCallbackSteps(self): # Let the second trial send a result result = {TRAINING_ITERATION: 1, "metric": 800, "done": False} - self.executor.next_future_result = ExecutorEvent( - event_type=ExecutorEventType.TRAINING_RESULT, + self.executor.next_future_result = _ExecutorEvent( + event_type=_ExecutorEventType.TRAINING_RESULT, trial=trials[1], result={"future_result": result}, ) @@ -181,8 +181,8 @@ def testCallbackSteps(self): # Let the second trial restore from a checkpoint trials[1].restoring_from = cp - self.executor.next_future_result = ExecutorEvent( - event_type=ExecutorEventType.RESTORING_RESULT, trial=trials[1] + self.executor.next_future_result = _ExecutorEvent( + event_type=_ExecutorEventType.RESTORING_RESULT, trial=trials[1] ) self.trial_runner.step() self.assertEqual(self.callback.state["trial_restore"]["iteration"], 4) @@ -190,11 +190,11 @@ def testCallbackSteps(self): # Let the second trial finish trials[1].restoring_from = None - self.executor.next_future_result = ExecutorEvent( - event_type=ExecutorEventType.TRAINING_RESULT, + self.executor.next_future_result = _ExecutorEvent( + event_type=_ExecutorEventType.TRAINING_RESULT, trial=trials[1], result={ - ExecutorEvent.KEY_FUTURE_RESULT: { + _ExecutorEvent.KEY_FUTURE_RESULT: { TRAINING_ITERATION: 2, "metric": 900, "done": True, @@ -206,10 +206,10 @@ def testCallbackSteps(self): self.assertEqual(self.callback.state["trial_complete"]["trial"].trial_id, "two") # Let the first trial error - self.executor.next_future_result = ExecutorEvent( - event_type=ExecutorEventType.ERROR, + self.executor.next_future_result = _ExecutorEvent( + event_type=_ExecutorEventType.ERROR, trial=trials[0], - result={ExecutorEvent.KEY_EXCEPTION: Exception()}, + result={_ExecutorEvent.KEY_EXCEPTION: Exception()}, ) self.trial_runner.step() self.assertEqual(self.callback.state["trial_fail"]["iteration"], 6) diff --git a/python/ray/tune/tests/test_trial_scheduler.py b/python/ray/tune/tests/test_trial_scheduler.py index df3b497626d4..20cbae4b3b77 100644 --- a/python/ray/tune/tests/test_trial_scheduler.py +++ b/python/ray/tune/tests/test_trial_scheduler.py @@ -25,7 +25,7 @@ HyperBandForBOHB, ) -from ray.tune.schedulers.pbt import explore, PopulationBasedTrainingReplay +from ray.tune.schedulers.pbt import _explore, PopulationBasedTrainingReplay from ray.tune.suggest._mock import _MockSearcher from ray.tune.suggest.suggestion import ConcurrencyLimiter from ray.tune.trial import Trial, _TuneCheckpoint @@ -1164,38 +1164,38 @@ def assertProduces(fn, values): # Categorical case assertProduces( - lambda: explore({"v": 4}, {"v": [3, 4, 8, 10]}, 0.0, lambda x: x), {3, 8} + lambda: _explore({"v": 4}, {"v": [3, 4, 8, 10]}, 0.0, lambda x: x), {3, 8} ) assertProduces( - lambda: explore({"v": 3}, {"v": [3, 4, 8, 10]}, 0.0, lambda x: x), {3, 4} + lambda: _explore({"v": 3}, {"v": [3, 4, 8, 10]}, 0.0, lambda x: x), {3, 4} ) assertProduces( - lambda: explore({"v": 10}, {"v": [3, 4, 8, 10]}, 0.0, lambda x: x), {8, 10} + lambda: _explore({"v": 10}, {"v": [3, 4, 8, 10]}, 0.0, lambda x: x), {8, 10} ) assertProduces( - lambda: explore({"v": 7}, {"v": [3, 4, 8, 10]}, 0.0, lambda x: x), + lambda: _explore({"v": 7}, {"v": [3, 4, 8, 10]}, 0.0, lambda x: x), {3, 4, 8, 10}, ) assertProduces( - lambda: explore({"v": 4}, {"v": [3, 4, 8, 10]}, 1.0, lambda x: x), + lambda: _explore({"v": 4}, {"v": [3, 4, 8, 10]}, 1.0, lambda x: x), {3, 4, 8, 10}, ) # Continuous case assertProduces( - lambda: explore( + lambda: _explore( {"v": 100}, {"v": lambda: random.choice([10, 100])}, 0.0, lambda x: x ), {80, 120}, ) assertProduces( - lambda: explore( + lambda: _explore( {"v": 100.0}, {"v": lambda: random.choice([10, 100])}, 0.0, lambda x: x ), {80.0, 120.0}, ) assertProduces( - lambda: explore( + lambda: _explore( {"v": 100.0}, {"v": lambda: random.choice([10, 100])}, 1.0, lambda x: x ), {10.0, 100.0}, @@ -1224,7 +1224,7 @@ def assertNestedProduces(fn, values): # Nested mutation and spec assertNestedProduces( - lambda: explore( + lambda: _explore( { "a": {"b": 4}, "1": {"2": {"3": 100}}, @@ -1246,7 +1246,7 @@ def assertNestedProduces(fn, values): # Nested mutation and spec assertNestedProduces( - lambda: explore( + lambda: _explore( { "a": {"b": 4}, "1": {"2": {"3": 100}}, diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index f793714262f2..98d11475b938 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -15,7 +15,7 @@ import ray.cloudpickle as cloudpickle from ray.exceptions import RayActorError, RayTaskError from ray.tune import TuneError -from ray.tune.checkpoint_manager import _TuneCheckpoint, CheckpointManager +from ray.tune.checkpoint_manager import _TuneCheckpoint, _CheckpointManager # NOTE(rkn): We import ray.tune.registry here instead of importing the names we # need because there are cyclic imports that may cause specific names to not @@ -45,7 +45,7 @@ logger = logging.getLogger(__name__) -class Location: +class _Location: """Describes the location at which Trial is placed to run.""" def __init__(self, hostname=None, pid=None): @@ -92,7 +92,7 @@ def validate(formats): raise TuneError("Unsupported import/export format: " + formats[i]) -class CheckpointDeleter: +class _CheckpointDeleter: """Checkpoint deleter callback for a runner.""" def __init__(self, trial_id, runner): @@ -129,7 +129,7 @@ def __call__(self, checkpoint: _TuneCheckpoint): logger.debug("Local checkpoint dir not found during deletion.") -class TrialInfo: +class _TrialInfo: """Serializable struct for holding information for a Trial. Attributes: @@ -286,7 +286,7 @@ def __init__( # Parameters that Tune varies across searches. self.evaluated_params = evaluated_params or {} self.experiment_tag = experiment_tag - self.location = Location() + self.location = _Location() trainable_cls = self.get_trainable_cls() if trainable_cls and _setup_default_resource: default_resources = trainable_cls.default_resource_request(self.config) @@ -369,10 +369,10 @@ def __init__( self.keep_checkpoints_num = keep_checkpoints_num self.checkpoint_score_attr = checkpoint_score_attr self.sync_on_checkpoint = sync_on_checkpoint - self.checkpoint_manager = CheckpointManager( + self.checkpoint_manager = _CheckpointManager( keep_checkpoints_num, checkpoint_score_attr, - CheckpointDeleter(self._trainable_name(), self.runner), + _CheckpointDeleter(self._trainable_name(), self.runner), ) # Restoration fields @@ -414,7 +414,7 @@ def _get_default_result_or_future(self) -> Optional[dict]: self._default_result_or_future = None if self._default_result_or_future and self.runner: self.set_location( - Location( + _Location( self._default_result_or_future.get(NODE_IP), self._default_result_or_future.get(PID), ) @@ -564,7 +564,7 @@ def set_runner(self, runner): self._default_result_or_future = runner.get_auto_filled_metrics.remote( debug_metrics_only=True ) - self.checkpoint_manager.delete = CheckpointDeleter( + self.checkpoint_manager.delete = _CheckpointDeleter( self._trainable_name(), runner ) # No need to invalidate state cache: runner is not stored in json @@ -678,7 +678,7 @@ def update_last_result(self, result): if self.experiment_tag: result.update(experiment_tag=self.experiment_tag) - self.set_location(Location(result.get(NODE_IP), result.get(PID))) + self.set_location(_Location(result.get(NODE_IP), result.get(PID))) self.last_result = result self.last_update_time = time.time() @@ -800,7 +800,7 @@ def __getstate__(self): state[key] = binary_to_hex(cloudpickle.dumps(state.get(key))) state["runner"] = None - state["location"] = Location() + state["location"] = _Location() # Avoid waiting for events that will never occur on resume. state["restoring_from"] = None state["saving_to"] = None diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index 1c199f2d02e7..f4e0ffd06cea 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -11,17 +11,17 @@ import ray from ray.exceptions import RayTaskError -from ray.tune.error import TuneStopTrialError +from ray.tune.error import _TuneStopTrialError from ray.tune.impl.out_of_band_serialize_dataset import out_of_band_serialize_dataset from ray.util import get_node_ip_address from ray.tune import TuneError from ray.tune.callback import CallbackList, Callback from ray.tune.experiment import Experiment -from ray.tune.insufficient_resources_manager import InsufficientResourcesManager +from ray.tune.insufficient_resources_manager import _InsufficientResourcesManager from ray.tune.ray_trial_executor import ( RayTrialExecutor, - ExecutorEventType, - ExecutorEvent, + _ExecutorEventType, + _ExecutorEvent, ) from ray.tune.result import ( DEBUG_METRICS, @@ -41,6 +41,7 @@ from ray.tune.utils.placement_groups import PlacementGroupFactory from ray.tune.utils.serialization import TuneFunctionDecoder, TuneFunctionEncoder from ray.tune.web_server import TuneServer +from ray.util.annotations import DeveloperAPI from ray.util.debug import log_once MAX_DEBUG_TRIALS = 20 @@ -200,6 +201,7 @@ def _serialize_and_write(): return self._checkpoint_dir +@DeveloperAPI class TrialRunner: """A TrialRunner implements the event loop for scheduling trials on Ray. @@ -280,7 +282,7 @@ def __init__( self._search_alg = search_alg or BasicVariantGenerator() self._scheduler_alg = scheduler or FIFOScheduler() self.trial_executor = trial_executor or RayTrialExecutor() - self._insufficient_resources_manager = InsufficientResourcesManager() + self._insufficient_resources_manager = _InsufficientResourcesManager() self._pending_trial_queue_times = {} # Set the number of maximum pending trials @@ -721,33 +723,33 @@ def _wait_and_handle_event(self, next_trial: Optional[Trial]): event = self.trial_executor.get_next_executor_event( self._live_trials, next_trial is not None ) - if event.type == ExecutorEventType.PG_READY: + if event.type == _ExecutorEventType.PG_READY: self._on_pg_ready(next_trial) - elif event.type == ExecutorEventType.NO_RUNNING_TRIAL_TIMEOUT: + elif event.type == _ExecutorEventType.NO_RUNNING_TRIAL_TIMEOUT: self._insufficient_resources_manager.on_no_available_trials( self.get_trials() ) - elif event.type == ExecutorEventType.YIELD: + elif event.type == _ExecutorEventType.YIELD: pass else: trial = event.trial result = event.result - if event.type == ExecutorEventType.ERROR: - self._on_executor_error(trial, result[ExecutorEvent.KEY_EXCEPTION]) - elif event.type == ExecutorEventType.RESTORING_RESULT: + if event.type == _ExecutorEventType.ERROR: + self._on_executor_error(trial, result[_ExecutorEvent.KEY_EXCEPTION]) + elif event.type == _ExecutorEventType.RESTORING_RESULT: self._on_restoring_result(trial) else: assert event.type in ( - ExecutorEventType.SAVING_RESULT, - ExecutorEventType.TRAINING_RESULT, + _ExecutorEventType.SAVING_RESULT, + _ExecutorEventType.TRAINING_RESULT, ), f"Unexpected future type - {event.type}" - if event.type == ExecutorEventType.TRAINING_RESULT: + if event.type == _ExecutorEventType.TRAINING_RESULT: self._on_training_result( - trial, result[ExecutorEvent.KEY_FUTURE_RESULT] + trial, result[_ExecutorEvent.KEY_FUTURE_RESULT] ) else: self._on_saving_result( - trial, result[ExecutorEvent.KEY_FUTURE_RESULT] + trial, result[_ExecutorEvent.KEY_FUTURE_RESULT] ) self._post_process_on_training_saving_result(trial) except Exception as e: @@ -1362,7 +1364,7 @@ def stop_trial(self, trial): self._process_trial_failure(trial, exc=e) else: self._process_trial_failure( - trial, TuneStopTrialError(traceback.format_exc()) + trial, _TuneStopTrialError(traceback.format_exc()) ) def cleanup_trials(self): @@ -1421,7 +1423,7 @@ def __setstate__(self, state): self._server = TuneServer(self, self._server_port) -class TrialExecutorWrapper(RayTrialExecutor): +class _TrialExecutorWrapper(RayTrialExecutor): """Wraps around TrialExecutor class, intercepts API calls and warns users of restricted API access. @@ -1450,6 +1452,7 @@ def __getattr__(self, attr): return getattr(self._trial_executor, attr) +@DeveloperAPI class TrialRunnerWrapper(TrialRunner): """Wraps around TrialRunner class, intercepts API calls and warns users of restricted API access. @@ -1467,7 +1470,7 @@ def __init__( executor_whitelist_attr: Optional[set] = None, ): self._trial_runner = trial_runner - self._trial_executor = TrialExecutorWrapper( + self._trial_executor = _TrialExecutorWrapper( trial_runner.trial_executor, executor_whitelist_attr ) self._runner_whitelist_attr = runner_whitelist_attr or set() diff --git a/python/ray/tune/utils/log.py b/python/ray/tune/utils/log.py index dd54adbb6851..ab11939d453e 100644 --- a/python/ray/tune/utils/log.py +++ b/python/ray/tune/utils/log.py @@ -1,7 +1,10 @@ from enum import Enum from typing import Union +from ray.util import PublicAPI + +@PublicAPI class Verbosity(Enum): V0_MINIMAL = 0 V1_EXPERIMENT = 1 diff --git a/python/ray/tune/utils/placement_groups.py b/python/ray/tune/utils/placement_groups.py index 4eb0e4bac869..aa3bb71ae882 100644 --- a/python/ray/tune/utils/placement_groups.py +++ b/python/ray/tune/utils/placement_groups.py @@ -263,7 +263,7 @@ def resource_dict_to_pg_factory(spec: Optional[Dict[str, float]]): return PlacementGroupFactory([bundle]) -class PlacementGroupManager: +class _PlacementGroupManager: """PlacementGroupManager to stage and manage placement groups. .. versionadded:: 1.3.0 diff --git a/python/ray/tune/utils/resource_updater.py b/python/ray/tune/utils/resource_updater.py index 1645d6f7b432..86c3c61a926f 100644 --- a/python/ray/tune/utils/resource_updater.py +++ b/python/ray/tune/utils/resource_updater.py @@ -17,7 +17,7 @@ def _to_gb(n_bytes): return round(n_bytes / (1024 ** 3), 2) -class ResourceUpdater: +class _ResourceUpdater: """Periodic Resource updater for Tune. Initially, all resources are set to 0. The updater will try to update resources @@ -130,4 +130,4 @@ def __reduce__(self): # Do not need to serialize resources, because we can always # update it again. This also prevents keeping outdated resources # when deserialized. - return ResourceUpdater, (self._refresh_period,) + return _ResourceUpdater, (self._refresh_period,) diff --git a/python/ray/tune/utils/serialization.py b/python/ray/tune/utils/serialization.py index 17d2cfad3574..12ac7b1af060 100644 --- a/python/ray/tune/utils/serialization.py +++ b/python/ray/tune/utils/serialization.py @@ -4,11 +4,13 @@ from ray import cloudpickle as cloudpickle from ray._private.utils import binary_to_hex, hex_to_binary +from ray.util.annotations import DeveloperAPI from ray.util.debug import log_once logger = logging.getLogger(__name__) +@DeveloperAPI class TuneFunctionEncoder(json.JSONEncoder): def default(self, obj): if isinstance(obj, types.FunctionType): @@ -27,6 +29,7 @@ def _to_cloudpickle(self, obj): } +@DeveloperAPI class TuneFunctionDecoder(json.JSONDecoder): def __init__(self, *args, **kwargs): json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs) diff --git a/python/ray/tune/utils/trainable.py b/python/ray/tune/utils/trainable.py index efd33f137f23..526d27dd9bc0 100644 --- a/python/ray/tune/utils/trainable.py +++ b/python/ray/tune/utils/trainable.py @@ -12,11 +12,13 @@ from ray.tune.registry import _ParameterRegistry from ray.tune.utils import detect_checkpoint_function from ray.util import placement_group +from ray.util.annotations import DeveloperAPI from six import string_types logger = logging.getLogger(__name__) +@DeveloperAPI class TrainableUtil: @staticmethod def process_checkpoint( @@ -239,6 +241,7 @@ def get_checkpoints_paths(logdir): return chkpt_df +@DeveloperAPI class PlacementGroupUtil: @staticmethod def get_remote_worker_options( diff --git a/python/ray/tune/web_server.py b/python/ray/tune/web_server.py index 2eb8ae856315..a88b43a9d4f5 100644 --- a/python/ray/tune/web_server.py +++ b/python/ray/tune/web_server.py @@ -10,6 +10,7 @@ from ray.tune import TuneError from ray.tune.suggest import BasicVariantGenerator from ray._private.utils import binary_to_hex, hex_to_binary +from ray.util.annotations import DeveloperAPI if TYPE_CHECKING: from ray.tune.trial_runner import TrialRunner @@ -26,6 +27,7 @@ ) +@DeveloperAPI class TuneClient: """Client to interact with an ongoing Tune experiment. @@ -93,6 +95,7 @@ def _deserialize(self, response): return parsed +@DeveloperAPI def RunnerHandler(runner): class Handler(SimpleHTTPRequestHandler): """A Handler is a custom handler for TuneServer. @@ -224,6 +227,7 @@ def _add_trials(self, name, spec): return Handler +@DeveloperAPI class TuneServer(threading.Thread): """A TuneServer is a thread that initializes and runs a HTTPServer. diff --git a/python/ray/util/actor_pool.py b/python/ray/util/actor_pool.py index ce377857ff05..d7ff9f241c36 100644 --- a/python/ray/util/actor_pool.py +++ b/python/ray/util/actor_pool.py @@ -69,7 +69,7 @@ def map(self, fn, values): # by calling `has_next` and `gen_next` repeteadly. while self.has_next(): try: - self.get_next(timeout=0) + self.get_next(timeout=0, ignore_if_timedout=True) except TimeoutError: pass @@ -165,7 +165,7 @@ def has_next(self): """ return bool(self._future_to_actor) - def get_next(self, timeout=None): + def get_next(self, timeout=None, ignore_if_timedout=False): """Returns the next pending result in order. This returns the next result produced by submit(), blocking for up to @@ -191,10 +191,15 @@ def get_next(self, timeout=None): "It is not allowed to call get_next() after get_next_unordered()." ) future = self._index_to_future[self._next_return_index] + timeout_msg = "Timed out waiting for result" + raise_timeout_after_ignore = False if timeout is not None: res, _ = ray.wait([future], timeout=timeout) if not res: - raise TimeoutError("Timed out waiting for result") + if not ignore_if_timedout: + raise TimeoutError(timeout_msg) + else: + raise_timeout_after_ignore = True del self._index_to_future[self._next_return_index] self._next_return_index += 1 @@ -202,9 +207,13 @@ def get_next(self, timeout=None): i, a = self._future_to_actor.pop(future_key) self._return_actor(a) + if raise_timeout_after_ignore: + raise TimeoutError( + timeout_msg + ". The task {} has been ignored.".format(future) + ) return ray.get(future) - def get_next_unordered(self, timeout=None): + def get_next_unordered(self, timeout=None, ignore_if_timedout=False): """Returns any of the next pending results. This returns some result produced by submit(), blocking for up to @@ -232,14 +241,23 @@ def get_next_unordered(self, timeout=None): raise StopIteration("No more results to get") # TODO(ekl) bulk wait for performance res, _ = ray.wait(list(self._future_to_actor), num_returns=1, timeout=timeout) + timeout_msg = "Timed out waiting for result" + raise_timeout_after_ignore = False if res: [future] = res else: - raise TimeoutError("Timed out waiting for result") + if not ignore_if_timedout: + raise TimeoutError(timeout_msg) + else: + raise_timeout_after_ignore = True i, a = self._future_to_actor.pop(future) self._return_actor(a) del self._index_to_future[i] self._next_return_index = max(self._next_return_index, i + 1) + if raise_timeout_after_ignore: + raise TimeoutError( + timeout_msg + ". The task {} has been ignored.".format(future) + ) return ray.get(future) def _return_actor(self, actor): diff --git a/python/ray/util/data/__init__.py b/python/ray/util/data/__init__.py deleted file mode 100644 index 6440d32e1b4c..000000000000 --- a/python/ray/util/data/__init__.py +++ /dev/null @@ -1,98 +0,0 @@ -from collections import defaultdict -from typing import Iterable - -import pandas as pd - -from ray.util.data.dataset import MLDataset -from ray.util.data.parquet import read_parquet -from ray.util.iter import T, ParallelIterator - -try: - import dataclasses -except: # noqa: E722 - pass -else: - from dataclasses import is_dataclass - - -def to_pandas( - it: ParallelIterator[T], batch_size: int = 32 -) -> "ParallelIterator[pd.DataFrame]": - """Convert the a ParallelIterator to ParallelIterator of pd.DataFrame. - - The record type should be list like object or dataclass instance. If - the record is a iterable, we will convert to a list. If the record has - __getitem__ attr, we will use __getitem__ to get the given column - indexes data to create pandas DataFrame. If the record is dataclass - instance we will use __getattr__ to get the given column. - - Args: - it (ParallelIterator[T]): the ParallelIterator to converted - batch_size (int): batch the given size to create a pandas DataFrame - Returns: - A ParallelIterator of pd.DataFrame - """ - it = it.batch(batch_size) - - def convert_fn(input_it: Iterable[T]) -> Iterable[pd.DataFrame]: - names = [] - for batch in input_it: - assert isinstance(batch, list) - if hasattr(batch[0], "__getitem__"): - batch = pd.DataFrame(batch) - elif hasattr(batch[0], "__iter__"): - batch = [list(item) for item in batch] - batch = pd.DataFrame(batch) - elif is_dataclass(batch[0]): - if not names: - names = [f.name for f in dataclasses.fields(batch[0])] - values = defaultdict(lambda x: []) - for item in batch: - for col in names: - values[col].append(getattr(item, col)) - batch = pd.DataFrame(values, columns=names) - else: - raise ValueError( - "MLDataset only support list like item or dataclass instance" - ) - - yield batch - - it = it._with_transform( - lambda local_it: local_it.transform(convert_fn), ".to_pandas()" - ) - return it - - -def from_parallel_iter( - para_it: ParallelIterator[T], - need_convert: bool = True, - batch_size: int = 32, - repeated: bool = False, -) -> MLDataset: - """Create a MLDataset from an existing ParallelIterator. - - The object of the ParallelIterator should be list like object or dataclass - instance. - - Args: - para_it (ParallelIterator[T]): An existing parallel iterator, and each - should be a list like object or dataclass instance. - need_convert (bool): whether need to convert to pandas.DataFrame. This - should be False if the record type is pandas.DataFrame. - batch_size (int): if need_convert is True, we will batch the batch_size - records to a pandas.DataFrame - repeated (bool): whether the para_it is repeated. - Returns: - a MLDataset - """ - - if need_convert: - para_it = to_pandas(para_it, batch_size) - else: - batch_size = 0 - - return MLDataset.from_parallel_it(para_it, batch_size, repeated) - - -__all__ = ["from_parallel_iter", "read_parquet", "MLDataset"] diff --git a/python/ray/util/data/dataset.py b/python/ray/util/data/dataset.py deleted file mode 100644 index 2d9416ba4f6f..000000000000 --- a/python/ray/util/data/dataset.py +++ /dev/null @@ -1,446 +0,0 @@ -import random -from typing import Callable, List, Iterable, Iterator - -import pandas as pd - -from ray.util.annotations import Deprecated -from ray.util.iter import ( - _NextValueNotReady, - LocalIterator, - ParallelIterator, - T, - U, - _ActorSet, - from_items, -) - - -@Deprecated -class MLDataset(ParallelIterator[pd.DataFrame]): - """A distributed ML dataset implemented based on ParallelIterator - - All item should be a list like object or dataclass instance. - - Args: - batch_size (int): The batch size of the current dataset. It should be - larger than zero, and 0 means unknown. - """ - - def __init__( - self, - actor_sets: List[_ActorSet], - name: str, - parent_iterators: List[ParallelIterator[pd.DataFrame]], - batch_size: int, - repeated: bool, - ): - super(MLDataset, self).__init__(actor_sets, name, parent_iterators) - self._batch_size = batch_size - self._repeated = repeated - - @classmethod - def from_modin(cls, df, num_shards: int = 2): - """Create a MLDataset from a Modin Dataframe. - - Args: - df (modin.pandas.DataFrame): A Modin Dataframe. - num_shards (int): The number of worker actors to create. - """ - try: - import modin.pandas as pd - except ImportError: - raise ImportError( - "Cannot convert from Modin because Modin is not installed." - ) from None - if not isinstance(df, (pd.DataFrame, pd.Series)): - raise ValueError("Must provide a modin.pandas DataFrame or Series") - from modin.distributed.dataframe.pandas.partitions import unwrap_partitions - - parts = unwrap_partitions(df) - modin_iter = from_items(parts, num_shards=num_shards, repeat=False) - return cls.from_parallel_it(modin_iter, batch_size=0, repeated=False) - - @staticmethod - def from_parallel_it( - para_it: ParallelIterator[pd.DataFrame], batch_size: int, repeated: bool = False - ) -> "MLDataset": - """Create a MLDataset from an parallel iterator - - The record of ParallelIterator should be pandas.DataFrame. - - Args: - para_it (ParallelIterator[T]): An existing parallel iterator, - and each should be a list like object or dataclass instance - batch_size (int): The batch size of the current dataset. It - should be larger than zero, and 0 means unknown. - repeated (bool): whether the para_it is repeated. - Returns: - A MLDataset - """ - return MLDataset( - para_it.actor_sets, - para_it.name, - para_it.parent_iterators, - batch_size, - repeated, - ) - - def __iter__(self): - raise TypeError( - "You must use it.gather_sync() or it.gather_async() to " - "iterate over the results of a MLDataset." - ) - - def __str__(self): - return repr(self) - - def __repr__(self): - return f"MLDataset[{self.name}]" - - def _with_transform(self, local_it_fn, name) -> "MLDataset": - """Helper function to create new MLDataset""" - para_it = super()._with_transform(local_it_fn, name) - return MLDataset.from_parallel_it(para_it, self._batch_size, self._repeated) - - def transform( - self, fn: Callable[[Iterable[pd.DataFrame]], Iterable[pd.DataFrame]] - ) -> "MLDataset": - """Apply the fn function to the MLDataset - - Args: - fn (Callable[[Iterable[DataFrame]], Iterable[DataFrame]]): - The function to applied. The input is a iterator of - pandas.DataFrame, and the output should also be a iterator of - pandas.DataFrame. - Returns: - A new MLDataset - """ - return self._with_transform( - lambda local_it: local_it.transform(fn), ".transform()" - ) - - def batch(self, batch_size: int) -> "MLDataset": - """Rebatch the number of rows for each pandas.DataFrame record - - Unlike the ParallelIterator.batch. This method rebatch the underlying - the pandas DataFrame, and each pandas DataFrame will have batch_size - rows. - """ - if batch_size == self._batch_size: - return self - - def batch_fn(it: Iterable[pd.DataFrame]) -> Iterable[pd.DataFrame]: - it = iter(it) - return_df = None - while True: - try: - cur_df = next(it) - cur_index = 0 - cur_size = cur_df.shape[0] - while cur_df is not None or (cur_index + batch_size) < cur_size: - if cur_df is None or cur_index == cur_size: - cur_df = next(it) - cur_index = 0 - cur_size = cur_df.shape[0] - if return_df is not None: - ri = cur_index + batch_size - return_df.shape[0] - ri = min(ri, cur_size) - tmp = cur_df.iloc[cur_index:ri] - return_df = pd.concat([return_df, tmp]) - cur_index = ri - else: - ri = cur_index + batch_size - ri = min(ri, cur_size) - return_df = cur_df.iloc[cur_index:ri] - cur_index = ri - if return_df.shape[0] == batch_size: - return_df.index = range(return_df.shape[0]) - yield return_df - return_df = None - except StopIteration: - break - - if return_df is not None: - return_df.index = range(return_df.shape[0]) - yield return_df - - self._batch_size = batch_size - return self._with_transform( - lambda local_it: local_it.transform(batch_fn), f".batch({batch_size})" - ) - - def flatten(self) -> "MLDataset": - raise Exception("Unsupported operation") - - def combine(self, fn: Callable[[T], List[U]]) -> "MLDataset": - raise Exception("Unsupported operation") - - @property - def repeated(self) -> bool: - return self._repeated - - @property - def batch_size(self) -> int: - return self._batch_size - - def local_shuffle(self, shuffle_buffer_size: int, seed: int = None) -> "MLDataset": - """Applying local shuffle - - Unlike the ParallelIterator.local_shuffle. This shuffle will first - apply the local_shuffle for each shards and then shuffle the each - pandas DataFrame. - """ - ds = super().local_shuffle(shuffle_buffer_size, seed) - - def shuffle_fn(it: Iterable[pd.DataFrame]) -> Iterable[pd.DataFrame]: - for df in it: - df = df.sample(frac=1, random_state=seed) - yield df - - ds = ds._with_transform( - lambda local_it: local_it.transform(shuffle_fn), ".inner_pandas_shuffle()" - ) - - return ds - - def repartition(self, num_partitions: int, batch_ms: int = 0) -> "MLDataset": - """see ParallelIterator.repartition""" - if num_partitions == self.num_shards(): - return self - para_it = super().repartition(num_partitions, batch_ms) - return MLDataset.from_parallel_it(para_it, self._batch_size) - - def union(self, other: "MLDataset") -> "MLDataset": - """Return an iterator that is the union of this and the other.""" - if not isinstance(other, MLDataset): - raise TypeError(f"other must be of type MLDataset, got {type(other)}") - - if self._repeated != other.repeated: - raise TypeError( - f"want to union two MLDataset which have different repeated " - f"type, self repeated: {self._repeated}, other repeated: " - f"{other.repeated}" - ) - - batch_size = 0 - if self._batch_size == other._batch_size: - batch_size = self._batch_size - - actor_sets = [] - actor_sets.extend(self.actor_sets) - actor_sets.extend(other.actor_sets) - # if one of these iterators is a result of a repartition, we need to - # keep an explicit reference to its parent iterator - return MLDataset( - actor_sets, - f"ParallelUnion[{self}, {other}]", - parent_iterators=self.parent_iterators + other.parent_iterators, - batch_size=batch_size, - repeated=self._repeated, - ) - - def select_shards(self, shards_to_keep: List[int]) -> "MLDataset": - para_it = super().select_shards(shards_to_keep) - return MLDataset.from_parallel_it(para_it, self._batch_size, self._repeated) - - def get_repeatable_shard( - self, - index: int, - batch_ms: int = 0, - num_async: int = 1, - shuffle: bool = False, - shuffle_buffer_size: int = 1, - seed: int = None, - ) -> Iterator: - """Get the given shard of the current dataset. - - The return is a iterator. Each call iter on the returned iterator will - get the shard data from beginning. And it support shuffle the return - iterator when each call iter on the return. - Args: - index (int): the shard index id, -1 means collect all data. - batch_ms (int): Batches items for batch_ms milliseconds - before retrieving it. Increasing batch_ms increases latency - but improves throughput. If this value is 0, then items are - returned immediately. - num_async (int): The max number of requests in flight. Increasing - this improves the amount of pipeline parallelism in the - iterator. - shuffle (bool): whether shuffle the given shard data - shuffle_buffer_size (int): same as ParallelIterator.local_shuffle - seed (int): the random seed - Returns: - The given shard iterator. If the shuffle is True, each call iter - will return a different ordered iterator. - """ - return _RepeatableIterator( - self, index, batch_ms, num_async, shuffle, shuffle_buffer_size, seed - ) - - def to_torch( - self, - feature_columns=None, - feature_shapes=None, - feature_types=None, - label_column=None, - label_shape=None, - label_type=None, - ): - """Create a TorchMLDataset from the current MLDataset. - - Args: - feature_columns (List[Any]): the column indexes name. - feature_shapes (Optional[List[Any]]): the feature shapes should - match the feature columns if provided. - feature_types (Optional[List["torch.dtype"]]): the feature types - should match the feature columns if provided. All feature will - be cast into torch.float by default. Otherwise, cast into the - provided type. - label_column (Any): the label name. - label_shape (Optional[int]): the label shape. - label_type (Optional["torch.dtype"]): the label type, this will be - cast into torch.float by default - Returns: - A TorchMLDataset - """ - from ray.util.sgd.torch.torch_dataset import TorchMLDataset - - return TorchMLDataset( - self, - feature_columns, - feature_shapes, - feature_types, - label_column, - label_shape, - label_type, - ) - - def to_tf( - self, - feature_columns=None, - feature_shapes=None, - feature_types=None, - label_column=None, - label_shape=None, - label_type=None, - ): - """Create a TFMLDataset from the current MLDataset. - - Args: - feature_columns (List[Any]): the column names. - feature_shapes (Optional[List[tf.TensorShape]]): the feature shapes - should match the feature columns if provided. - feature_types (Optional[List["tf.DType"]]): the feature types - should match the feature columns if provided. All feature will - be cast into tf.float by default. Otherwise, cast into the - provided type. - label_column (Any): the label name. - label_shape (Optional[tf.TensorShape]): the label shape. - label_type (Optional["tf.DType"]): the label type, this will be - cast into tf.float by default - Returns: - A TFMLDataset - """ - from ray.util.sgd.tf.tf_dataset import TFMLDataset - - return TFMLDataset( - self, - feature_columns, - feature_shapes, - feature_types, - label_column, - label_shape, - label_type, - ) - - -class _RepeatableIterator(Iterator[T]): - """A repeatable iterator for the given shard index data. - - Each call iter(_RepeatableIterator instance) will fetch the data from - beginning and will return a different order or data if set shuffle - Args: - ds (MLDataset): a MLDataset - shard_index (int): the shard index id. -1 means collect all data. - batch_ms (int): Batches items for batch_ms milliseconds - before retrieving it. Increasing batch_ms increases latency - but improves throughput. If this value is 0, then items are - returned immediately. - num_async (int): The max number of requests in flight. Increasing this - improves the amount of pipeline parallelism in the iterator. - shuffle (bool): whether shuffle the given shard data - shuffle_buffer_size (int): same as ParallelIterator.local_shuffle - seed (int): the random seed - """ - - def __init__( - self, - ds: MLDataset, - shard_index: int, - batch_ms: int = 0, - num_async: int = 1, - shuffle: bool = False, - shuffle_buffer_size: int = 1, - seed: int = None, - ): - super(_RepeatableIterator, self).__init__() - self._ds = ds - self._shard_index = shard_index - self._batch_ms = batch_ms - self._num_async = num_async - self._shuffle = shuffle - self._shuffle_buffer_size = shuffle_buffer_size - self._seed = seed - self._local_it: LocalIterator[T] = None - - self._i = 0 - - def __next__(self) -> T: - assert self._local_it is not None - return next(self._local_it) - - def __iter__(self) -> Iterator[T]: - if self._shard_index >= 0: - it = self._ds.get_shard(self._shard_index, self._batch_ms, self._num_async) - else: - if self._num_async > 0: - it = self._ds.gather_async( - batch_ms=self._batch_ms, num_async=self._num_async - ) - else: - it = self._ds.gather_sync() - if self._shuffle: - it = self.shuffle(it) - - self._local_it = it - return self - - def shuffle(self, local_it: LocalIterator[T]) -> LocalIterator[pd.DataFrame]: - shuffle_random = random.Random(self._seed) - - def apply_shuffle(it): - buffer = [] - for item in it: - if isinstance(item, _NextValueNotReady): - yield item - else: - buffer.append(item) - if len(buffer) >= self._shuffle_buffer_size: - item = buffer.pop(shuffle_random.randint(0, len(buffer) - 1)) - item = item.sample(frac=1, random_state=self._seed) - yield item - while len(buffer) > 0: - item = buffer.pop(shuffle_random.randint(0, len(buffer) - 1)) - item = item.sample(frac=1, random_state=self._seed) - yield item - - return LocalIterator( - local_it.base_iterator, - local_it.shared_metrics, - local_it.local_transforms + [apply_shuffle], - name=local_it.name - + ".shuffle(shuffle_buffer_size={}, seed={})".format( - self._shuffle_buffer_size, - str(self._seed) if self._seed is not None else "None", - ), - ) diff --git a/python/ray/util/data/examples/__init__.py b/python/ray/util/data/examples/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/python/ray/util/data/interface.py b/python/ray/util/data/interface.py deleted file mode 100644 index e8300544484b..000000000000 --- a/python/ray/util/data/interface.py +++ /dev/null @@ -1,21 +0,0 @@ -from typing import Iterable - -import pandas as pd - - -class _SourceShard: - def prefix(self) -> str: - raise NotImplementedError - - @property - def shard_id(self) -> int: - raise NotImplementedError - - def __iter__(self) -> Iterable[pd.DataFrame]: - raise NotImplementedError - - def __str__(self): - return repr(self) - - def __repr__(self): - return f"{self.prefix()}SourceShard[{self.shard_id}]" diff --git a/python/ray/util/data/parquet.py b/python/ray/util/data/parquet.py deleted file mode 100644 index 916b91fa9bb2..000000000000 --- a/python/ray/util/data/parquet.py +++ /dev/null @@ -1,130 +0,0 @@ -import random -from typing import Iterable -from typing import List, Optional, Union - -import pyarrow.parquet as pq -from pandas import DataFrame - -from ray.util.annotations import Deprecated -import ray.util.iter as para_iter -from .dataset import MLDataset -from .interface import _SourceShard - - -class ParquetSourceShard(_SourceShard): - def __init__( - self, - data_pieces: List[pq.ParquetDatasetPiece], - columns: Optional[List[str]], - partitions: Optional[pq.ParquetPartitions], - shard_id: int, - ): - self._data_pieces = data_pieces - self._columns = columns - self._partitions = partitions - self._shard_id = shard_id - - def prefix(self) -> str: - return "Parquet" - - @property - def shard_id(self) -> int: - return self._shard_id - - def __iter__(self) -> Iterable[DataFrame]: - for piece in self._data_pieces: - yield piece.read( - columns=self._columns, use_threads=False, partitions=self._partitions - ).to_pandas() - - -@Deprecated -def read_parquet( - paths: Union[str, List[str]], - num_shards: int, - rowgroup_split: bool = True, - shuffle: bool = False, - shuffle_seed: int = None, - columns: Optional[List[str]] = None, - **kwargs, -) -> MLDataset: - """Read parquet format data from hdfs like filesystem into a MLDataset. - - .. code-block:: python - - # create dummy data - spark.range(...).write.parquet(...) - # create MLDataset - data = ray.util.data.read_parquet(...) - # convert to TorchMLDataset - ds = data.to_torch(feature_columns=..., label_column=...) - # get the given shard data - shard = ds.get_shard(0) - # create the DataLoader from the shard data and this can be used for - # the TorchTrainer data creator as well - data = DataLoader(shard, batch_size=32) - - Args: - paths (Union[str, List[str]): a single file path or a list of file path - num_shards (int): the number of shards - rowgroup_split (bool): whether split the files into shards based on - rowgroup. If set False, each shard will have a list of files. - shuffle (bool): whether shuffle the ParquetDatasetPiece order when - divide into shards - shuffle_seed (int): the shuffle seed - columns (Optional[List[str]]): a list of column names to read - kwargs: the other parquet read options - Returns: - A MLDataset - """ - pq_ds = pq.ParquetDataset(paths, **kwargs) - pieces = pq_ds.pieces - data_pieces = [] - if rowgroup_split: - # split base on rowgroup - for piece in pieces: - num_row_groups = piece.get_metadata().to_dict()["num_row_groups"] - for i in range(num_row_groups): - data_pieces.append( - pq.ParquetDatasetPiece( - piece.path, - piece.open_file_func, - piece.file_options, - i, - piece.partition_keys, - ) - ) - else: - # split base on file pieces - data_pieces = pieces.copy() - - if len(data_pieces) < num_shards: - raise ValueError( - f"number of data pieces: {len(data_pieces)} should " - f"larger than num_shards: {num_shards}" - ) - - if shuffle: - random_shuffle = random.Random(shuffle_seed) - random_shuffle.shuffle(data_pieces) - shards = [[] for _ in range(num_shards)] - for i, item in enumerate(data_pieces): - shard = shards[i % num_shards] - if item.row_group is None: - for number in item.get_metadata().to_dict()["num_row_groups"]: - shard.append( - pq.ParquetDatasetPiece( - item.path, - item.open_file_func, - item.file_options, - number, - item.partition_keys, - ) - ) - else: - shard.append(item) - - for i, shard in enumerate(shards): - shards[i] = ParquetSourceShard(shard, columns, pq_ds.partitions, i) - it = para_iter.from_iterators(shards, False, "parquet") - return MLDataset.from_parallel_it(it, batch_size=0, repeated=False) diff --git a/python/ray/util/ml_utils/BUILD b/python/ray/util/ml_utils/BUILD index 8f43babe6148..b5ec98ba3328 100644 --- a/python/ray/util/ml_utils/BUILD +++ b/python/ray/util/ml_utils/BUILD @@ -2,6 +2,14 @@ # Tests from the python/ray/util/ml_util/tests directory. # Please keep these sorted alphabetically. # -------------------------------------------------------------------- +py_test( + name = "test_checkpoint_manager", + size = "small", + srcs = ["tests/test_checkpoint_manager.py"], + tags = ["team:ml", "exclusive"], + deps = [":ml_util_lib"] +) + py_test( name = "test_mlflow", size = "medium", diff --git a/python/ray/util/ml_utils/checkpoint_manager.py b/python/ray/util/ml_utils/checkpoint_manager.py new file mode 100644 index 000000000000..f4e6653011a5 --- /dev/null +++ b/python/ray/util/ml_utils/checkpoint_manager.py @@ -0,0 +1,446 @@ +import copy +import enum +import gc +import heapq +import logging +import numbers +import os +import shutil + +from dataclasses import dataclass +from pathlib import Path +from typing import Optional, Dict, Union, Callable, Tuple, List, Any + +import ray +from ray.ml import Checkpoint +from ray.tune.result import NODE_IP +from ray.util import PublicAPI +from ray.util.annotations import DeveloperAPI +from ray.util.ml_utils.util import is_nan + +MAX = "max" +MIN = "min" + +logger = logging.getLogger(__name__) + + +@DeveloperAPI +class CheckpointStorage(enum.Enum): + MEMORY = enum.auto() + PERSISTENT = enum.auto() + + +class _TrackedCheckpoint: + """Checkpoint tracked by a checkpoint manager. + + This class is used to track checkpoints generated by trainables and trainers in + order to add metadata (e.g. the result, or the node where it has been created) + and for bookkeeping purposes. + + The data can be an object, a checkpoint directory, or a future to either. Because + we can't know if it's data or a directory from a future, this class expects + a ``storage_mode`` that makes the data type explicit. + + The passed metrics can be used to compare performance of different checkpoints. + The ``checkpoint_id`` is passed as an alternative to be able to order + checkpoints in time. + + Args: + dir_or_data: Checkpoint directory, checkpoint data, or a future to either. + storage_mode: Either MEMORY or PERSISTENT. + checkpoint_id: Checkpoint number. Will be used to determine checkpoint order + if metrics are not available. Usually this should be monotonically + increasing for each tracked checkpoint. + metrics: Observed metrics for this checkpoint. This is used to determine + the value of the ``checkpoint_score_attr``. + node_ip: IP of the node where the checkpoint was generated. Defaults + to the current node. + """ + + def __init__( + self, + dir_or_data: Optional[Union[str, Path, Dict, ray.ObjectRef]], + storage_mode: CheckpointStorage, + checkpoint_id: Optional[int] = None, + metrics: Optional[Dict] = None, + node_ip: Optional[str] = None, + ): + self.dir_or_data = dir_or_data + self.id = checkpoint_id + self.storage_mode = storage_mode + + self.metrics = metrics or {} + self.node_ip = node_ip or self.metrics.get(NODE_IP, None) + + if storage_mode == CheckpointStorage.MEMORY and not isinstance( + dir_or_data, (dict, ray.ObjectRef) + ): + raise ValueError( + f"Memory checkpoints only support Ray object references and dicts " + f"as their data. Got: {dir_or_data}" + ) + + def commit(self, path: Optional[Path] = None) -> None: + """Commit checkpoint to disk, if needed. + + Args: + path: Path to commit checkpoint to. + """ + if self.storage_mode == CheckpointStorage.MEMORY: + # Do not persist memory checkpoints + return + + if not path: + # If no path is given, skip + return + + if not isinstance(self.dir_or_data, dict): + # Only persist dictionaries + return + + checkpoint = Checkpoint.from_dict(self.dir_or_data) + self.dir_or_data = checkpoint.to_directory(str(path)) + + def delete( + self, delete_fn: Optional[Callable[["_TrackedCheckpoint"], None]] = None + ) -> None: + """Delete checkpoint from disk, if needed. + + Args: + delete_fn: Function to be called with the tracked checkpoint as an + argument. Defaults to removing the local directory/file. + """ + delete_fn = delete_fn or _default_delete_fn + try: + delete_fn(self) + except Exception as e: + logger.warning(f"Checkpoint deletion failed: {e}") + + def __repr__(self): + if self.storage_mode == CheckpointStorage.MEMORY: + return f"" + + return ( + f"" + ) + + +def _default_delete_fn(checkpoint: _TrackedCheckpoint): + if checkpoint.storage_mode != CheckpointStorage.PERSISTENT: + return + + if isinstance(checkpoint.dir_or_data, (str, bytes, os.PathLike)): + if os.path.isfile(checkpoint.dir_or_data): + os.remove(checkpoint.dir_or_data) + return + elif os.path.isdir(checkpoint.dir_or_data): + shutil.rmtree(checkpoint.dir_or_data) + return + raise RuntimeError( + f"Could not delete checkpoint {checkpoint} from disk as it is " + f"neither file not directory. Path: {checkpoint.dir_or_data}." + ) + + +class _HeapCheckpointWrapper: + def __init__(self, priority: Any, tracked_checkpoint: _TrackedCheckpoint): + self.priority = priority + self.tracked_checkpoint = tracked_checkpoint + + def __lt__(self, other): + return self.priority < other.priority + + def __repr__(self): + return f"_HeapCheckpoint({repr(self.tracked_checkpoint)})" + + +@PublicAPI(stability="beta") +@dataclass +class CheckpointStrategy: + """Configurable parameters for defining the checkpointing strategy. + + Default behavior is to persist all checkpoints to disk. If + ``num_to_keep`` is set, the default retention policy is to keep the + checkpoints with maximum timestamp, i.e. the most recent checkpoints. + + Args: + num_to_keep (Optional[int]): The number of checkpoints to keep + on disk for this run. If a checkpoint is persisted to disk after + there are already this many checkpoints, then an existing + checkpoint will be deleted. If this is ``None`` then checkpoints + will not be deleted. If this is ``0`` then no checkpoints will be + persisted to disk. + checkpoint_score_attribute (str): The attribute that will be used to + score checkpoints to determine which checkpoints should be kept + on disk when there are greater than ``num_to_keep`` checkpoints. + This attribute must be a key from the checkpoint + dictionary which has a numerical value. Per default, the last + checkpoints will be kept. + checkpoint_score_order (str). Either "max" or "min". + If "max", then checkpoints with highest values of + ``checkpoint_score_attribute`` will be kept. + If "min", then checkpoints with lowest values of + ``checkpoint_score_attribute`` will be kept. + """ + + num_to_keep: Optional[int] = None + checkpoint_score_attribute: Optional[str] = None + checkpoint_score_order: str = MAX + + def __post_init__(self): + if self.num_to_keep is not None and self.num_to_keep < 0: + raise ValueError( + f"Received invalid num_to_keep: " + f"{self.num_to_keep}. " + f"Must be None or non-negative integer." + ) + if self.checkpoint_score_order not in (MAX, MIN): + raise ValueError( + f"checkpoint_score_order must be either " f'"{MAX}" or "{MIN}".' + ) + + +class _CheckpointManager: + """Common checkpoint management and bookkeeping class for Ray Train and Tune. + + This class acts as the common core for checkpoint bookkeeping in Ray ML libraries. + On a high level, this manager keeps a reference to all stored checkpoints + (both in-memory and on-disk checkpoints). For on-disk checkpoints, it + keeps a configured number of checkpoints according to specified metrics. + + The manager supports lazy data writing by utilizing the + ``TrackedCheckpoint.commit()`` API, which is only invoked if the checkpoint + should be persisted to disk. + + Args: + checkpoint_strategy: Checkpoint strategy defining how many and which + checkpoints to keep. + latest_checkpoint_id: First checkpoint ID to use (e.g. in case we + continue training an existing experiment). + delete_fn: Function that takes a TrackedCheckpoint and deletes it from disk + or memory upon request. + + """ + + # If memory checkpoints should be persisted + _persist_memory_checkpoints: bool = False + + def __init__( + self, + checkpoint_strategy: CheckpointStrategy, + latest_checkpoint_id: int = 0, + delete_fn: Optional[Callable[["_TrackedCheckpoint"], None]] = None, + ): + self._checkpoint_strategy = checkpoint_strategy or CheckpointStrategy() + + # Incremental unique checkpoint ID of this run. + self._latest_checkpoint_id = latest_checkpoint_id + + # Used for keeping top K checkpoints. + self._top_persisted_checkpoints: List[_HeapCheckpointWrapper] = [] + + # Best checkpoint altogether. + # Used for exposing best_checkpoint_path. + self._best_persisted_checkpoint: Optional[_TrackedCheckpoint] = None + self._latest_persisted_checkpoint: Optional[_TrackedCheckpoint] = None + self._latest_memory_checkpoint: Optional[_TrackedCheckpoint] = None + + # Deletion of some checkpoints should be deferred. Specifically, if the + # latest persisted checkpoint should be deleted, we will only delete it + # once a new checkpoint came in (so that `_latest_persisted_checkpoint` is + # always available). + self._checkpoints_to_clean_up = set() + + self._delete_fn = delete_fn + + def set_delete_fn( + self, delete_fn: Optional[Callable[["_TrackedCheckpoint"], None]] + ): + """Update the function called to delete persisted checkpoints. + + Args: + delete_fn: Function that takes a tracked checkpoint as an argument and + deletes it from disk. + """ + self._delete_fn = delete_fn + + def register_checkpoint(self, checkpoint: _TrackedCheckpoint): + """Register new checkpoint and add to bookkeeping. + + This method will register a new checkpoint and add it to the internal + bookkeeping logic. This means the checkpoint manager will decide if + this checkpoint should be kept, and if older or worse performing + checkpoints should be deleted. + + Args: + checkpoint: Tracked checkpoint object to add to bookkeeping. + """ + checkpoint.id = checkpoint.id or self._latest_checkpoint_id + + if checkpoint.storage_mode == CheckpointStorage.MEMORY: + self._replace_latest_memory_checkpoint(checkpoint) + + if self._persist_memory_checkpoints: + persisted_checkpoint = copy.copy(checkpoint) + persisted_checkpoint.storage_mode = CheckpointStorage.PERSISTENT + else: + persisted_checkpoint = None + else: + persisted_checkpoint = checkpoint + + if persisted_checkpoint and self._checkpoint_strategy.num_to_keep != 0: + self._process_persistent_checkpoint(persisted_checkpoint) + + self._latest_checkpoint_id += 1 + + def _replace_latest_memory_checkpoint(self, memory_checkpoint: _TrackedCheckpoint): + assert memory_checkpoint.storage_mode == CheckpointStorage.MEMORY + self._latest_memory_checkpoint = memory_checkpoint + # Avoid memory leaks on k8s pods + gc.collect() + + def _replace_latest_persisted_checkpoint( + self, persisted_checkpoint: _TrackedCheckpoint + ): + second_to_latest_persisted_checkpoint = self._latest_persisted_checkpoint + self._latest_persisted_checkpoint = persisted_checkpoint + + if self._checkpoint_strategy.num_to_keep == 0: + self._maybe_delete_persisted_checkpoint( + second_to_latest_persisted_checkpoint + ) + + def _maybe_replace_best_persisted_checkpoint( + self, persisted_checkpoint: _TrackedCheckpoint + ): + if self._best_persisted_checkpoint is None: + self._best_persisted_checkpoint = persisted_checkpoint + else: + old_score = self._get_checkpoint_score(self._best_persisted_checkpoint) + candidate_score = self._get_checkpoint_score(persisted_checkpoint) + if candidate_score >= old_score: + self._best_persisted_checkpoint = persisted_checkpoint + + def _get_checkpoint_score( + self, checkpoint: _TrackedCheckpoint + ) -> Tuple[bool, numbers.Number, int]: + checkpoint_score_attribute = ( + self._checkpoint_strategy.checkpoint_score_attribute + ) + if checkpoint_score_attribute not in checkpoint.metrics: + logger.error( + f"Result dict has no key: {checkpoint_score_attribute}. " + f"checkpoint_score_attr must be set to a key in the " + f"result dict. Valid keys are: {list(checkpoint.metrics.keys())}" + ) + checkpoint_result = float("-inf") + else: + checkpoint_result = checkpoint.metrics[checkpoint_score_attribute] + + checkpoint_score_order = self._checkpoint_strategy.checkpoint_score_order + if checkpoint_score_order == MAX: + order_factor = 1.0 + else: + order_factor = -1.0 + + checkpoint_score = order_factor * checkpoint_result + + if not isinstance(checkpoint_score, numbers.Number): + raise ValueError( + f"Unable to persist checkpoint for " + f"checkpoint_score_attribute: " + f"{checkpoint_score_attribute} with value " + f"{checkpoint_score}. " + f"This attribute must be numerical." + ) + + return ( + not is_nan(checkpoint_score), + checkpoint_score if not is_nan(checkpoint_score) else 0, + checkpoint.id, + ) + + def _process_persistent_checkpoint(self, checkpoint: _TrackedCheckpoint): + assert checkpoint.storage_mode == CheckpointStorage.PERSISTENT + + checkpoint_score = self._get_checkpoint_score(checkpoint) + wrapped_checkpoint = _HeapCheckpointWrapper( + priority=checkpoint_score, tracked_checkpoint=checkpoint + ) + + if self._checkpoint_strategy.num_to_keep is None: + # Keep all checkpoints + checkpoint.commit(path=self._get_next_checkpoint_path()) + self._replace_latest_persisted_checkpoint(checkpoint) + self._top_persisted_checkpoints.append(wrapped_checkpoint) + elif ( + len(self._top_persisted_checkpoints) < self._checkpoint_strategy.num_to_keep + ): + # Heap is not full yet, so keep this checkpoint + checkpoint.commit(path=self._get_next_checkpoint_path()) + heapq.heappush(self._top_persisted_checkpoints, wrapped_checkpoint) + self._replace_latest_persisted_checkpoint(checkpoint) + elif wrapped_checkpoint.priority >= self._top_persisted_checkpoints[0].priority: + # Priority is higher than current worst checkpoint, so replace worst + checkpoint.commit(path=self._get_next_checkpoint_path()) + worst_checkpoint = heapq.heappushpop( + self._top_persisted_checkpoints, wrapped_checkpoint + ).tracked_checkpoint + + # Only remove if checkpoint data is different + if worst_checkpoint.dir_or_data != checkpoint.dir_or_data: + self._maybe_delete_persisted_checkpoint(worst_checkpoint) + logger.debug(f"Removed worst checkpoint from " f"{worst_checkpoint}.") + + self._replace_latest_persisted_checkpoint(checkpoint) + else: + # If the latest checkpoint has the same or lower priority, skip it. + self._skip_persisted_checkpoint(checkpoint) + + self._maybe_replace_best_persisted_checkpoint(persisted_checkpoint=checkpoint) + self._cleanup_checkpoints() + + def _maybe_delete_persisted_checkpoint( + self, persisted_checkpoint: _TrackedCheckpoint + ): + if persisted_checkpoint == self._latest_persisted_checkpoint: + self._checkpoints_to_clean_up.add(persisted_checkpoint) + else: + self._delete_persisted_checkpoint(persisted_checkpoint=persisted_checkpoint) + + def _delete_persisted_checkpoint(self, persisted_checkpoint: _TrackedCheckpoint): + persisted_checkpoint.delete(delete_fn=self._delete_fn) + self._checkpoints_to_clean_up.discard(persisted_checkpoint) + + def _cleanup_checkpoints(self): + for checkpoint in list(self._checkpoints_to_clean_up): + self._maybe_delete_persisted_checkpoint(persisted_checkpoint=checkpoint) + + def _skip_persisted_checkpoint(self, persisted_checkpoint: _TrackedCheckpoint): + logger.debug(f"Skipping checkpoint due to low score: {persisted_checkpoint}.") + self._checkpoints_to_clean_up.add(persisted_checkpoint) + + def _get_next_checkpoint_path(self) -> Optional[Path]: + return None + + def __del__(self): + self._cleanup_checkpoints() + + def __getstate__(self): + state = self.__dict__.copy() + + # Do not serialize the delete fn + state.pop("_delete_fn", None) + + # Avoid serializing the memory checkpoint. + state["_newest_memory_checkpoint"] = _TrackedCheckpoint( + dir_or_data=None, + checkpoint_id=0, + storage_mode=CheckpointStorage.MEMORY, + ) + return state + + def __setstate__(self, state): + state["_delete_fn"] = None + self.__dict__.update(state) diff --git a/python/ray/util/ml_utils/tests/test_checkpoint_manager.py b/python/ray/util/ml_utils/tests/test_checkpoint_manager.py new file mode 100644 index 000000000000..16fd83a8ecb8 --- /dev/null +++ b/python/ray/util/ml_utils/tests/test_checkpoint_manager.py @@ -0,0 +1,95 @@ +import pytest +from ray.util.ml_utils.checkpoint_manager import ( + _CheckpointManager, + CheckpointStorage, + CheckpointStrategy, + _TrackedCheckpoint, +) + + +def test_unlimited_persistent_checkpoints(): + cpm = _CheckpointManager(checkpoint_strategy=CheckpointStrategy(num_to_keep=None)) + + for i in range(10): + cpm.register_checkpoint( + _TrackedCheckpoint({"data": i}, storage_mode=CheckpointStorage.PERSISTENT) + ) + + assert len(cpm._top_persisted_checkpoints) == 10 + + +def test_limited_persistent_checkpoints(): + cpm = _CheckpointManager(checkpoint_strategy=CheckpointStrategy(num_to_keep=2)) + + for i in range(10): + cpm.register_checkpoint( + _TrackedCheckpoint({"data": i}, storage_mode=CheckpointStorage.PERSISTENT) + ) + + assert len(cpm._top_persisted_checkpoints) == 2 + + +def test_no_persistent_checkpoints(): + cpm = _CheckpointManager(checkpoint_strategy=CheckpointStrategy(num_to_keep=0)) + + for i in range(10): + cpm.register_checkpoint( + _TrackedCheckpoint({"data": i}, storage_mode=CheckpointStorage.PERSISTENT) + ) + + assert len(cpm._top_persisted_checkpoints) == 0 + + +def test_dont_persist_memory_checkpoints(): + cpm = _CheckpointManager(checkpoint_strategy=CheckpointStrategy(num_to_keep=None)) + cpm._persist_memory_checkpoints = False + + for i in range(10): + cpm.register_checkpoint( + _TrackedCheckpoint({"data": i}, storage_mode=CheckpointStorage.MEMORY) + ) + + assert len(cpm._top_persisted_checkpoints) == 0 + + +def test_persist_memory_checkpoints(): + cpm = _CheckpointManager(checkpoint_strategy=CheckpointStrategy(num_to_keep=None)) + cpm._persist_memory_checkpoints = True + + for i in range(10): + cpm.register_checkpoint( + _TrackedCheckpoint({"data": i}, storage_mode=CheckpointStorage.MEMORY) + ) + + assert len(cpm._top_persisted_checkpoints) == 10 + + +def test_keep_best_checkpoints(): + cpm = _CheckpointManager( + checkpoint_strategy=CheckpointStrategy( + num_to_keep=2, + checkpoint_score_attribute="metric", + checkpoint_score_order="min", + ) + ) + cpm._persist_memory_checkpoints = True + + for i in range(10): + cpm.register_checkpoint( + _TrackedCheckpoint( + {"data": i}, + storage_mode=CheckpointStorage.MEMORY, + metrics={"metric": i}, + ) + ) + + # Sorted from worst (max) to best (min) + assert [ + cp.tracked_checkpoint.metrics["metric"] for cp in cpm._top_persisted_checkpoints + ] == [1, 0] + + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/workflow/api.py b/python/ray/workflow/api.py index cb2b8ec693fa..87962e941832 100644 --- a/python/ray/workflow/api.py +++ b/python/ray/workflow/api.py @@ -286,7 +286,8 @@ def list_all( f" {status_filter}" ) elif status_filter is None: - status_filter = set(WorkflowStatus.__members__.keys()) + status_filter = set(WorkflowStatus) + status_filter.discard(WorkflowStatus.NONE) else: raise TypeError( "status_filter must be WorkflowStatus or a set of WorkflowStatus." diff --git a/python/ray/workflow/common.py b/python/ray/workflow/common.py index 82407ccecbbe..8345ba728fee 100644 --- a/python/ray/workflow/common.py +++ b/python/ray/workflow/common.py @@ -140,6 +140,8 @@ def __reduce__(self): @PublicAPI(stability="beta") @unique class WorkflowStatus(str, Enum): + # No status is set for this workflow. + NONE = "NONE" # There is at least a remote task running in ray cluster RUNNING = "RUNNING" # It got canceled and can't be resumed later. diff --git a/python/ray/workflow/execution.py b/python/ray/workflow/execution.py index 45887b854934..7e8b00d0846e 100644 --- a/python/ray/workflow/execution.py +++ b/python/ray/workflow/execution.py @@ -10,7 +10,6 @@ from ray.workflow.common import ( Workflow, WorkflowStatus, - WorkflowMetaData, StepType, WorkflowNotFoundError, validate_user_metadata, @@ -128,10 +127,15 @@ def get_output(workflow_id: str, name: Optional[str]) -> ray.ObjectRef: def cancel(workflow_id: str) -> None: try: workflow_manager = get_management_actor() - ray.get(workflow_manager.cancel_workflow.remote(workflow_id)) except ValueError: wf_store = workflow_storage.get_workflow_storage(workflow_id) - wf_store.save_workflow_meta(WorkflowMetaData(WorkflowStatus.CANCELED)) + # TODO(suquark): Here we update workflow status "offline", so it is likely + # thread-safe because there is no workflow management actor updating the + # workflow concurrently. But we should be careful if we are going to + # update more workflow status offline in the future. + wf_store.update_workflow_status(WorkflowStatus.CANCELED) + return + ray.get(workflow_manager.cancel_workflow.remote(workflow_id)) def get_status(workflow_id: str) -> Optional[WorkflowStatus]: @@ -143,12 +147,12 @@ def get_status(workflow_id: str) -> Optional[WorkflowStatus]: if running: return WorkflowStatus.RUNNING store = workflow_storage.get_workflow_storage(workflow_id) - meta = store.load_workflow_meta() - if meta is None: + status = store.load_workflow_status() + if status == WorkflowStatus.NONE: raise WorkflowNotFoundError(workflow_id) - if meta.status == WorkflowStatus.RUNNING: + if status == WorkflowStatus.RUNNING: return WorkflowStatus.RESUMABLE - return meta.status + return status def get_metadata(workflow_id: str, name: Optional[str]) -> Dict[str, Any]: @@ -178,10 +182,24 @@ def list_all(status_filter: Set[WorkflowStatus]) -> List[Tuple[str, WorkflowStat runnings = set(runnings) # Here we don't have workflow id, so use empty one instead store = workflow_storage.get_workflow_storage("") + + exclude_running = False + if ( + WorkflowStatus.RESUMABLE in status_filter + and WorkflowStatus.RUNNING not in status_filter + ): + # Here we have to add "RUNNING" to the status filter, because some "RESUMABLE" + # workflows are converted from "RUNNING" workflows below. + exclude_running = True + status_filter.add(WorkflowStatus.RUNNING) + status_from_storage = store.list_workflow(status_filter) ret = [] - for (k, s) in store.list_workflow(): - if s == WorkflowStatus.RUNNING and k not in runnings: - s = WorkflowStatus.RESUMABLE + for (k, s) in status_from_storage: + if s == WorkflowStatus.RUNNING: + if k not in runnings: + s = WorkflowStatus.RESUMABLE + elif exclude_running: + continue if s in status_filter: ret.append((k, s)) return ret diff --git a/python/ray/workflow/tests/test_workflow_indexing.py b/python/ray/workflow/tests/test_workflow_indexing.py new file mode 100644 index 000000000000..d8a959caa393 --- /dev/null +++ b/python/ray/workflow/tests/test_workflow_indexing.py @@ -0,0 +1,83 @@ +import pytest + +from ray.workflow.common import WorkflowStatus +from ray.workflow.workflow_storage import WorkflowIndexingStorage + + +def test_workflow_status_update(workflow_start_regular): + # Test workflow status update is working. + store = WorkflowIndexingStorage() + assert not store.list_workflow() + for i in range(100): + assert store.load_workflow_status(workflow_id=str(i)) == WorkflowStatus.NONE + + for i in range(100): + store.update_workflow_status(str(i), WorkflowStatus.RUNNING) + + assert sorted(store.list_workflow()) == sorted( + [(str(i), WorkflowStatus.RUNNING) for i in range(100)] + ) + + assert sorted(store.list_workflow({WorkflowStatus.RUNNING})) == sorted( + [(str(i), WorkflowStatus.RUNNING) for i in range(100)] + ) + + assert sorted(store.list_workflow({WorkflowStatus.RESUMABLE})) == [] + + for i in range(100): + store.update_workflow_status(str(i), WorkflowStatus.RESUMABLE) + + assert sorted(store.list_workflow({WorkflowStatus.RESUMABLE})) == sorted( + [(str(i), WorkflowStatus.RESUMABLE) for i in range(100)] + ) + + assert sorted(store.list_workflow({WorkflowStatus.FAILED})) == [] + + for i in range(100): + store.update_workflow_status(str(i), WorkflowStatus.FAILED) + + assert sorted(store.list_workflow()) == sorted( + [(str(i), WorkflowStatus.FAILED) for i in range(100)] + ) + + assert sorted(store.list_workflow({WorkflowStatus.FAILED})) == sorted( + [(str(i), WorkflowStatus.FAILED) for i in range(100)] + ) + + assert sorted(store.list_workflow({WorkflowStatus.RUNNING})) == [] + + +def test_workflow_auto_fix_status(workflow_start_regular): + # Test workflow can recovery from corrupted status updating. + store = WorkflowIndexingStorage() + assert not store.list_workflow() + # this is a hack to crash status updating + _key_workflow_with_status = store._key_workflow_with_status + store._key_workflow_with_status = None + for i in range(100): + try: + store.update_workflow_status(str(i), WorkflowStatus.RUNNING) + except TypeError: + pass + + store._key_workflow_with_status = _key_workflow_with_status + + assert sorted(store.list_workflow()) == sorted( + [(str(i), WorkflowStatus.RUNNING) for i in range(100)] + ) + + for i in range(100): + try: + # when update workflow, we fix failed status + store.update_workflow_status(str(i), WorkflowStatus.RESUMABLE) + except TypeError: + pass + + for i in range(100): + assert store.load_workflow_status(str(i)) == WorkflowStatus.RESUMABLE + + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/workflow/workflow_access.py b/python/ray/workflow/workflow_access.py index a075e79631cc..bb06b6b8756f 100644 --- a/python/ray/workflow/workflow_access.py +++ b/python/ray/workflow/workflow_access.py @@ -135,6 +135,7 @@ def __init__(self): self._step_output_cache: Dict[Tuple[str, str], LatestWorkflowOutput] = {} self._actor_initialized: Dict[str, ray.ObjectRef] = {} self._step_status: Dict[str, Dict[str, common.WorkflowStatus]] = {} + self._workflow_status: Dict[str, common.WorkflowStatus] = {} def get_cached_step_output( self, workflow_id: str, step_id: "StepID" @@ -194,9 +195,7 @@ def run_or_resume( ) self._step_output_cache[(workflow_id, step_id)] = latest_output - wf_store.save_workflow_meta( - common.WorkflowMetaData(common.WorkflowStatus.RUNNING) - ) + self._update_workflow_status(workflow_id, common.WorkflowStatus.RUNNING) if workflow_id not in self._step_status: self._step_status[workflow_id] = {} @@ -211,6 +210,11 @@ def gen_step_id(self, workflow_id: str, step_name: str) -> str: else: return f"{step_name}_{idx}" + def _update_workflow_status(self, workflow_id: str, status: common.WorkflowStatus): + wf_store = workflow_storage.WorkflowStorage(workflow_id) + wf_store.update_workflow_status(status) + self._workflow_status[workflow_id] = status + def update_step_status( self, workflow_id: str, @@ -233,30 +237,21 @@ def update_step_status( if status != common.WorkflowStatus.FAILED and remaining != 0: return - wf_store = workflow_storage.WorkflowStorage(workflow_id) - if status == common.WorkflowStatus.FAILED: if workflow_id in self._workflow_outputs: cancel_job(self._workflow_outputs.pop(workflow_id).output) - wf_store.save_workflow_meta( - common.WorkflowMetaData(common.WorkflowStatus.FAILED) - ) + self._update_workflow_status(workflow_id, common.WorkflowStatus.FAILED) self._step_status.pop(workflow_id) else: - wf_store.save_workflow_meta( - common.WorkflowMetaData(common.WorkflowStatus.SUCCESSFUL) - ) + self._update_workflow_status(workflow_id, common.WorkflowStatus.SUCCESSFUL) self._step_status.pop(workflow_id) - workflow_postrun_metadata = {"end_time": time.time()} - wf_store.save_workflow_postrun_metadata(workflow_postrun_metadata) + wf_store = workflow_storage.WorkflowStorage(workflow_id) + wf_store.save_workflow_postrun_metadata({"end_time": time.time()}) def cancel_workflow(self, workflow_id: str) -> None: self._step_status.pop(workflow_id) cancel_job(self._workflow_outputs.pop(workflow_id).output) - wf_store = workflow_storage.WorkflowStorage(workflow_id) - wf_store.save_workflow_meta( - common.WorkflowMetaData(common.WorkflowStatus.CANCELED) - ) + self._update_workflow_status(workflow_id, common.WorkflowStatus.CANCELED) def is_workflow_running(self, workflow_id: str) -> bool: return ( @@ -318,15 +313,15 @@ def get_output(self, workflow_id: str, name: Optional[str]) -> WorkflowStaticRef if workflow_id in self._workflow_outputs and name is None: return self._workflow_outputs[workflow_id].output wf_store = workflow_storage.WorkflowStorage(workflow_id) - meta = wf_store.load_workflow_meta() - if meta is None: + status = wf_store.load_workflow_status() + if status == common.WorkflowStatus.NONE: raise ValueError(f"No such workflow {workflow_id}") - if meta == common.WorkflowStatus.CANCELED: + if status == common.WorkflowStatus.CANCELED: raise ValueError(f"Workflow {workflow_id} is canceled") if name is None: # For resumable workflow, the workflow result is not ready. # It has to be resumed first. - if meta == common.WorkflowStatus.RESUMABLE: + if status == common.WorkflowStatus.RESUMABLE: raise ValueError( f"Workflow {workflow_id} is in resumable status, " "please resume it" diff --git a/python/ray/workflow/workflow_storage.py b/python/ray/workflow/workflow_storage.py index 2e7167c285ce..b057084e0619 100644 --- a/python/ray/workflow/workflow_storage.py +++ b/python/ray/workflow/workflow_storage.py @@ -5,7 +5,7 @@ import json import os -from typing import Dict, List, Optional, Any, Callable, Tuple, Union +from typing import Dict, List, Optional, Any, Callable, Tuple, Union, Set from dataclasses import dataclass import logging @@ -16,7 +16,6 @@ from ray.workflow.common import ( Workflow, StepID, - WorkflowMetaData, WorkflowStatus, WorkflowRef, WorkflowNotFoundError, @@ -51,6 +50,8 @@ WORKFLOW_PRERUN_METADATA = "pre_run_metadata.json" WORKFLOW_POSTRUN_METADATA = "post_run_metadata.json" WORKFLOW_PROGRESS = "progress.json" +WORKFLOW_STATUS_DIR = "__status__" +WORKFLOW_STATUS_DIRTY_DIR = "dirty" # Without this counter, we're going to scan all steps to get the number of # steps with a given name. This can be very expensive if there are too # many duplicates. @@ -86,6 +87,144 @@ def is_recoverable(self) -> bool: ) +class WorkflowIndexingStorage: + """Access and maintenance the indexing of workflow status. + + It runs a protocol that guarantees we can recover from any interrupted + status updating. This protocol is **not thread-safe** for updating the + status of the same workflow, currently it is executed by workflow management + actor with a single thread. + + Here is how the protocol works: + + Update the status of a workflow + 1. Load workflow status from workflow data. If it is the same as the new status, + return. + 2. Check if the workflow status updating is dirty. If it is, fix the + workflow status; otherwise, mark the workflow status updating dirty. + 3. Update status in the workflow metadata. + 4. Insert the workflow ID key in the status indexing directory of the new status. + 5. Delete the workflow ID key in the status indexing directory of + the previous status. + 6. Remove the workflow status updating dirty mark. + + Load a status of a workflow + 1. Read the status of the workflow from the workflow metadata. + 2. Return the status. + + List the status of all workflows + 1. Get status of all workflows by listing workflow ID keys in each workflow + status indexing directory. + 2. List all workflows with dirty updating status. Get their status from + workflow data. Override the status of the corresponding workflow. + 3. Return all the status. + """ + + def __init__(self): + self._storage = storage.get_client(WORKFLOW_ROOT) + + def update_workflow_status(self, workflow_id: str, status: WorkflowStatus): + """Update the status of the workflow. + Try fixing indexing if workflow status updating was marked dirty. + + This method is NOT thread-safe. It is handled by the workflow management actor. + """ + prev_status = self.load_workflow_status(workflow_id) + if prev_status != status: + # Try fixing indexing if workflow status updating was marked dirty. + if ( + self._storage.get_info(self._key_workflow_status_dirty(workflow_id)) + is not None + ): + # This means the previous status update failed. Fix it. + self._storage.put( + self._key_workflow_with_status(workflow_id, prev_status), b"" + ) + for s in WorkflowStatus: + if s != prev_status: + self._storage.delete( + self._key_workflow_with_status(workflow_id, s) + ) + else: + self._storage.put(self._key_workflow_status_dirty(workflow_id), b"") + # Transactional update of workflow status + self._storage.put( + self._key_workflow_metadata(workflow_id), + json.dumps({"status": status.value}).encode(), + ) + self._storage.put(self._key_workflow_with_status(workflow_id, status), b"") + if prev_status is not WorkflowStatus.NONE: + self._storage.delete( + self._key_workflow_with_status(workflow_id, prev_status) + ) + self._storage.delete(self._key_workflow_status_dirty(workflow_id)) + + def load_workflow_status(self, workflow_id: str): + """Load the committed workflow status.""" + raw_data = self._storage.get(self._key_workflow_metadata(workflow_id)) + if raw_data is not None: + metadata = json.loads(raw_data) + return WorkflowStatus(metadata["status"]) + return WorkflowStatus.NONE + + def list_workflow( + self, status_filter: Optional[Set[WorkflowStatus]] = None + ) -> List[Tuple[str, WorkflowStatus]]: + """List workflow status. Override status of the workflows whose status updating + were marked dirty with the workflow status from workflow metadata. + + Args: + status_filter: If given, only returns workflow with that status. This can + be a single status or set of statuses. + """ + if status_filter is None: + status_filter = set(WorkflowStatus) + status_filter.discard(WorkflowStatus.NONE) + elif not isinstance(status_filter, set): + raise TypeError("'status_filter' should either be 'None' or a set.") + elif WorkflowStatus.NONE in status_filter: + raise ValueError("'WorkflowStatus.NONE' is not a valid filter value.") + + results = {} + for status in status_filter: + try: + # empty string points the key to the dir + for p in self._storage.list(self._key_workflow_with_status("", status)): + workflow_id = p.base_name + results[workflow_id] = status + except FileNotFoundError: + pass + # Get "correct" status of workflows + try: + for p in self._storage.list(self._key_workflow_status_dirty("")): + workflow_id = p.base_name + # overwrite status + results.pop(workflow_id, None) + status = self.load_workflow_status(workflow_id) + if status in status_filter: + results[workflow_id] = status + except FileNotFoundError: + pass + return list(results.items()) + + def delete_workflow_status(self, workflow_id: str): + """Delete status indexing for the workflow.""" + for status in WorkflowStatus: + self._storage.delete(self._key_workflow_with_status(workflow_id, status)) + self._storage.delete(self._key_workflow_status_dirty(workflow_id)) + + def _key_workflow_with_status(self, workflow_id: str, status: WorkflowStatus): + """A key whose existence marks the status of the workflow.""" + return os.path.join(WORKFLOW_STATUS_DIR, status.value, workflow_id) + + def _key_workflow_status_dirty(self, workflow_id: str): + """A key marks the workflow status dirty, because it is under change.""" + return os.path.join(WORKFLOW_STATUS_DIR, WORKFLOW_STATUS_DIRTY_DIR, workflow_id) + + def _key_workflow_metadata(self, workflow_id: str): + return os.path.join(workflow_id, WORKFLOW_META) + + class WorkflowStorage: """Access workflow in storage. This is a higher-level abstraction, which does not care about the underlining storage implementation.""" @@ -96,6 +235,7 @@ def __init__(self, workflow_id: str): _ensure_workflow_initialized() self._storage = storage.get_client(os.path.join(WORKFLOW_ROOT, workflow_id)) + self._status_storage = WorkflowIndexingStorage() self._workflow_id = workflow_id def load_step_output(self, step_id: StepID) -> Any: @@ -545,45 +685,16 @@ def _load_workflow_metadata(): return _load_workflow_metadata() - def save_workflow_meta(self, metadata: WorkflowMetaData) -> None: - """Save the metadata of the current workflow. + def list_workflow( + self, status_filter: Optional[Set[WorkflowStatus]] = None + ) -> List[Tuple[str, WorkflowStatus]]: + """List all workflows matching a given status filter. Args: - metadata: WorkflowMetaData of the current workflow. - - Raises: - DataSaveError: if we fail to save the class body. - """ - - metadata = {"status": metadata.status.value} - self._put(self._key_workflow_metadata(), metadata, True) - - def load_workflow_meta(self) -> Optional[WorkflowMetaData]: - """Load the metadata of the current workflow. - - Returns: - The metadata of the current workflow. If it doesn't exist, - return None. + status_filter: If given, only returns workflow with that status. This can + be a single status or set of statuses. """ - - try: - metadata = self._get(self._key_workflow_metadata(), True) - return WorkflowMetaData(status=WorkflowStatus(metadata["status"])) - except KeyNotFoundError: - return None - - def _list_workflow(self) -> List[Tuple[str, WorkflowStatus]]: - results = [] - for workflow_id in self._scan("", ignore_errors=True): - try: - metadata = self._get(os.path.join(workflow_id, WORKFLOW_META), True) - results.append((workflow_id, WorkflowStatus(metadata["status"]))) - except KeyNotFoundError: - pass - return results - - def list_workflow(self) -> List[Tuple[str, WorkflowStatus]]: - return self._list_workflow() + return self._status_storage.list_workflow(status_filter) def advance_progress(self, finished_step_id: "StepID") -> None: """Save the latest progress of a workflow. This is used by a @@ -613,6 +724,7 @@ def get_latest_progress(self) -> "StepID": def delete_workflow(self) -> None: # TODO (Alex): There's a race condition here if someone tries to # start the workflow between these ops. + self._status_storage.delete_workflow_status(self._workflow_id) found = self._storage.delete_dir("") # TODO (Alex): Different file systems seem to have different # behavior when deleting a prefix that doesn't exist, so we may @@ -621,6 +733,17 @@ def delete_workflow(self) -> None: if not found: raise WorkflowNotFoundError(self._workflow_id) + def update_workflow_status(self, status: WorkflowStatus): + """Update the status of the workflow. + This method is NOT thread-safe. It is handled by the workflow management actor. + """ + self._status_storage.update_workflow_status(self._workflow_id, status) + + def load_workflow_status(self): + """Load workflow status. If we find the previous status updating failed, + fix it with redo-log transaction recovery.""" + return self._status_storage.load_workflow_status(self._workflow_id) + def _put(self, key: str, data: Any, is_json: bool = False) -> str: """Serialize and put an object in the object store. diff --git a/python/requirements.txt b/python/requirements.txt index 87dbd53f7dec..c0c482eff7e6 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -35,7 +35,7 @@ gym==0.19.0; python_version < '3.7' lz4 scikit-image pandas>=1.0.5; python_version < '3.7' -pandas>=1.2.0; python_version >= '3.7' +pandas>=1.3.0; python_version >= '3.7' scipy==1.4.1 tabulate tensorboardX >= 1.9 diff --git a/python/requirements/ml/requirements_upstream.txt b/python/requirements/ml/requirements_upstream.txt index bd437a4c2735..686335e6c1a6 100644 --- a/python/requirements/ml/requirements_upstream.txt +++ b/python/requirements/ml/requirements_upstream.txt @@ -4,6 +4,6 @@ ray_lightning==0.2.0 tune-sklearn==0.4.1 -xgboost_ray==0.1.8 -lightgbm_ray==0.1.3 +xgboost_ray==0.1.9 +lightgbm_ray==0.1.4 modin>=0.11.0; python_version >= '3.7' diff --git a/release/ray_release/cluster_manager/minimal.py b/release/ray_release/cluster_manager/minimal.py index 37feee9d5c21..c527e5e4e698 100644 --- a/release/ray_release/cluster_manager/minimal.py +++ b/release/ray_release/cluster_manager/minimal.py @@ -89,26 +89,22 @@ def build_cluster_env(self, timeout: float = 600.0): error_message = None config_json = None result = self.sdk.list_cluster_environment_builds(self.cluster_env_id) - for build in sorted(result.results, key=lambda b: b.created_at): - build_id = build.id - last_status = build.status - error_message = build.error_message - config_json = build.config_json + if not result or not result.results: + raise ClusterEnvBuildError(f"No build found for cluster env: {result}") - if build.status == "failed": - continue + build = sorted(result.results, key=lambda b: b.created_at)[-1] + build_id = build.id + last_status = build.status + error_message = build.error_message + config_json = build.config_json - elif build.status == "succeeded": - logger.info( - f"Link to cluster env build: " - f"{format_link(anyscale_cluster_env_build_url(build_id))}" - ) - self.cluster_env_build_id = build_id - return - else: - # If the build is neither failed nor succeeded, it is still - # going on - break + if last_status == "succeeded": + logger.info( + f"Link to succeeded cluster env build: " + f"{format_link(anyscale_cluster_env_build_url(build_id))}" + ) + self.cluster_env_build_id = build_id + return if last_status == "failed": logger.info(f"Previous cluster env build failed: {error_message}") @@ -123,13 +119,10 @@ def build_cluster_env(self, timeout: float = 600.0): build_id = result.result.id logger.info( - f"Link to cluster env build: " + f"Link to created cluster env build: " f"{format_link(anyscale_cluster_env_build_url(build_id))}" ) - if not build_id: - raise ClusterEnvBuildError("No build found for cluster env.") - # Build found but not failed/finished yet completed = False start_wait = time.time() diff --git a/release/ray_release/tests/test_cluster_manager.py b/release/ray_release/tests/test_cluster_manager.py index 29f90b51057f..5d1e989481f8 100644 --- a/release/ray_release/tests/test_cluster_manager.py +++ b/release/ray_release/tests/test_cluster_manager.py @@ -410,6 +410,37 @@ def testBuildClusterEnvPreBuildSucceeded(self): self.assertEqual(self.sdk.call_counter["list_cluster_environment_builds"], 1) self.assertEqual(len(self.sdk.call_counter), 1) + @patch("time.sleep", lambda *a, **kw: None) + def testBuildClusterEnvSelectLastBuild(self): + self.cluster_manager.set_cluster_env(self.cluster_env) + self.cluster_manager.cluster_env_id = "correct" + # (Second) build succeeded + self.cluster_manager.cluster_env_build_id = None + self.sdk.reset() + self.sdk.returns["list_cluster_environment_builds"] = APIDict( + results=[ + APIDict( + id="build_succeeded", + status="succeeded", + created_at=0, + error_message=None, + config_json={}, + ), + APIDict( + id="build_succeeded_2", + status="succeeded", + created_at=1, + error_message=None, + config_json={}, + ), + ] + ) + self.cluster_manager.build_cluster_env(timeout=600) + self.assertTrue(self.cluster_manager.cluster_env_build_id) + self.assertEqual(self.cluster_manager.cluster_env_build_id, "build_succeeded_2") + self.assertEqual(self.sdk.call_counter["list_cluster_environment_builds"], 1) + self.assertEqual(len(self.sdk.call_counter), 1) + @patch("time.sleep", lambda *a, **kw: None) def testBuildClusterBuildFails(self): self.cluster_manager.set_cluster_env(self.cluster_env) diff --git a/release/release_tests.yaml b/release/release_tests.yaml index cfde7d27a4f5..bdc52ba458a1 100644 --- a/release/release_tests.yaml +++ b/release/release_tests.yaml @@ -3425,13 +3425,11 @@ test_name: dataset_shuffle_random_shuffle_1tb test_suite: dataset_test - stable: true - frequency: nightly team: core cluster: cluster_env: shuffle/shuffle_app_config.yaml - cluster_compute: shuffle/shuffle_compute_large_scale.yaml + cluster_compute: shuffle/datasets_large_scale_compute_small_instances.yaml run: timeout: 7200 @@ -3448,50 +3446,6 @@ test_name: dataset_shuffle_sort_1tb test_suite: dataset_test - stable: true - - frequency: nightly - team: core - cluster: - cluster_env: shuffle/shuffle_app_config.yaml - cluster_compute: shuffle/shuffle_compute_large_scale.yaml - - run: - timeout: 7200 - script: python dataset/sort.py --num-partitions=1000 --partition-size=1e9 - wait_for_nodes: - num_nodes: 20 - type: sdk_command - file_manager: sdk - -- name: dataset_shuffle_random_shuffle_1tb_small_instances - group: core-dataset-tests - working_dir: nightly_tests - legacy: - test_name: dataset_shuffle_random_shuffle_1tb_small_instances - test_suite: dataset_test - - frequency: nightly - team: core - cluster: - cluster_env: shuffle/shuffle_app_config.yaml - cluster_compute: shuffle/datasets_large_scale_compute_small_instances.yaml - - run: - timeout: 7200 - script: python dataset/sort.py --num-partitions=1000 --partition-size=1e9 --shuffle - wait_for_nodes: - num_nodes: 20 - type: sdk_command - file_manager: sdk - -- name: dataset_shuffle_sort_1tb_small_instances - group: core-dataset-tests - working_dir: nightly_tests - legacy: - test_name: dataset_shuffle_sort_1tb_small_instances - test_suite: dataset_test - frequency: nightly team: core cluster: @@ -3536,8 +3490,6 @@ test_name: dataset_shuffle_push_based_sort_1tb test_suite: dataset_test - stable: false - frequency: nightly team: core cluster: @@ -3673,3 +3625,51 @@ type: sdk_command file_manager: sdk + +- name: chaos_dataset_shuffle_push_based_sort_1tb + group: core-dataset-tests + working_dir: nightly_tests + legacy: + test_name: chaos_dataset_shuffle_push_based_sort_1tb + test_suite: chaos_test + + stable: false + + frequency: nightly + team: core + cluster: + cluster_env: shuffle/shuffle_app_config.yaml + cluster_compute: shuffle/datasets_large_scale_compute_small_instances.yaml + + run: + timeout: 7200 + prepare: ' python setup_chaos.py --node-kill-interval 1200 --max-nodes-to-kill 3' + script: RAY_DATASET_PUSH_BASED_SHUFFLE=1 python dataset/sort.py --num-partitions=1000 --partition-size=1e9 + wait_for_nodes: + num_nodes: 20 + type: sdk_command + file_manager: sdk + +- name: chaos_dataset_shuffle_sort_1tb + group: core-dataset-tests + working_dir: nightly_tests + legacy: + test_name: chaos_dataset_shuffle_sort_1tb + test_suite: chaos_test + + stable: false + + frequency: nightly + team: core + cluster: + cluster_env: shuffle/shuffle_app_config.yaml + cluster_compute: shuffle/datasets_large_scale_compute_small_instances.yaml + + run: + timeout: 7200 + prepare: ' python setup_chaos.py --node-kill-interval 900 --max-nodes-to-kill 3' + script: python dataset/sort.py --num-partitions=1000 --partition-size=1e9 + wait_for_nodes: + num_nodes: 20 + type: sdk_command + file_manager: sdk diff --git a/rllib/BUILD b/rllib/BUILD index 2ebf81e9008f..afcf26ced1ad 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -379,6 +379,17 @@ py_test( args = ["--yaml-dir=tuned_examples/impala"] ) +# MADDPG +py_test( + name = "learning_tests_two_step_game_maddpg", + main = "tests/run_regression_tests.py", + tags = ["team:ml", "tf_only", "no_tf_eager_tracing", "learning_tests", "learning_tests_discrete"], + size = "large", + srcs = ["tests/run_regression_tests.py"], + data = ["tuned_examples/maddpg/two-step-game-maddpg.yaml"], + args = ["--yaml-dir=tuned_examples/maddpg", "--framework=tf"] +) + # Working, but takes a long time to learn (>15min). # Removed due to Higher API conflicts with Pytorch-Import tests ## MB-MPO @@ -729,7 +740,7 @@ py_test( py_test( name = "test_dreamer", tags = ["team:ml", "trainers_dir"], - size = "small", + size = "medium", srcs = ["algorithms/dreamer/tests/test_dreamer.py"] ) @@ -775,6 +786,14 @@ py_test( srcs = ["algorithms/marwil/tests/test_bc.py"] ) +# MADDPGTrainer +py_test( + name = "test_maddpg", + tags = ["team:ml", "trainers_dir"], + size = "medium", + srcs = ["algorithms/maddpg/tests/test_maddpg.py"] +) + # MAMLTrainer py_test( name = "test_maml", @@ -1388,6 +1407,19 @@ py_test( size = "small", srcs = ["evaluation/tests/test_episode.py"] ) +# -------------------------------------------------------------------- +# Execution Utils +# rllib/execution/ +# +# Tag: execution +# -------------------------------------------------------------------- + +py_test( + name = "test_async_requests_manager", + tags = ["team:ml", "execution"], + size = "small", + srcs = ["execution/tests/test_async_requests_manager.py"] +) # -------------------------------------------------------------------- # Models and Distributions @@ -1488,6 +1520,14 @@ py_test( srcs = ["policy/tests/test_sample_batch.py"] ) +py_test( + name = "policy/tests/test_view_requirement", + tags = ["team:ml", "policy"], + size = "small", + srcs = ["policy/tests/test_view_requirement.py"] +) + + # -------------------------------------------------------------------- # Utils: # rllib/utils/ @@ -1495,6 +1535,13 @@ py_test( # Tag: utils # -------------------------------------------------------------------- +py_test( + name = "test_serialization", + tags = ["team:ml", "utils"], + size = "large", + srcs = ["utils/tests/test_serialization.py"] +) + py_test( name = "test_curiosity", tags = ["team:ml", "utils"], @@ -2951,15 +2998,6 @@ py_test( args = ["--as-test", "--mixed-torch-tf", "--stop-reward=450.0"] ) -py_test( - name = "examples/two_step_game_maddpg", - main = "examples/two_step_game.py", - tags = ["team:ml", "examples", "examples_T"], - size = "medium", - srcs = ["examples/two_step_game.py"], - args = ["--as-test", "--stop-reward=7.1", "--run=MADDPG"] -) - py_test( name = "examples/two_step_game_pg_tf", main = "examples/two_step_game.py", diff --git a/rllib/README.rst b/rllib/README.rst index 439c4cb376f6..32acd186764a 100644 --- a/rllib/README.rst +++ b/rllib/README.rst @@ -105,7 +105,7 @@ Multi-agent: - `Single-Player Alpha Zero (contrib/AlphaZero) `__ - `Parameter Sharing `__ - `QMIX Monotonic Value Factorisation (QMIX, VDN, IQN)) `__ -- `Multi-Agent Deep Deterministic Policy Gradient (contrib/MADDPG) `__ +- `Multi-Agent Deep Deterministic Policy Gradient (MADDPG) `__ - `Shared Critic Methods `__ Others: diff --git a/rllib/agents/a3c/a3c.py b/rllib/agents/a3c/a3c.py index 43c8da4b63ba..703b37b83c66 100644 --- a/rllib/agents/a3c/a3c.py +++ b/rllib/agents/a3c/a3c.py @@ -6,7 +6,9 @@ from ray.rllib.agents.trainer import Trainer from ray.rllib.agents.trainer_config import TrainerConfig from ray.rllib.evaluation.rollout_worker import RolloutWorker -from ray.rllib.execution.parallel_requests import asynchronous_parallel_requests +from ray.rllib.execution.parallel_requests import ( + AsyncRequestsManager, +) from ray.rllib.policy.policy import Policy from ray.rllib.utils.annotations import override from ray.rllib.utils.deprecation import Deprecated @@ -20,7 +22,11 @@ SYNCH_WORKER_WEIGHTS_TIMER, ) from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder -from ray.rllib.utils.typing import ResultDict, TrainerConfigDict +from ray.rllib.utils.typing import ( + ResultDict, + TrainerConfigDict, + PartialTrainerConfigDict, +) logger = logging.getLogger(__name__) @@ -153,6 +159,13 @@ class A3CTrainer(Trainer): def get_default_config(cls) -> TrainerConfigDict: return A3CConfig().to_dict() + @override(Trainer) + def setup(self, config: PartialTrainerConfigDict): + super().setup(config) + self._worker_manager = AsyncRequestsManager( + self.workers.remote_workers(), max_remote_requests_in_flight_per_worker=1 + ) + @override(Trainer) def validate_config(self, config: TrainerConfigDict) -> None: # Call super's validation method. @@ -194,13 +207,8 @@ def sample_and_compute_grads(worker: RolloutWorker) -> Dict[str, Any]: with self._timers[GRAD_WAIT_TIMER]: # Results are a mapping from ActorHandle (RolloutWorker) to their # returned gradient calculation results. - async_results: Dict[ActorHandle, Dict] = asynchronous_parallel_requests( - remote_requests_in_flight=self.remote_requests_in_flight, - actors=self.workers.remote_workers(), - ray_wait_timeout_s=0.0, - max_remote_requests_in_flight_per_actor=1, - remote_fn=sample_and_compute_grads, - ) + self._worker_manager.call_on_all_available(sample_and_compute_grads) + async_results = self._worker_manager.get_ready() # Loop through all fetched worker-computed gradients (if any) # and apply them - one by one - to the local worker's model. @@ -243,6 +251,19 @@ def sample_and_compute_grads(worker: RolloutWorker) -> Dict[str, Any]: return learner_info_builder.finalize() + @override(Trainer) + def on_worker_failures( + self, removed_workers: List[ActorHandle], new_workers: List[ActorHandle] + ): + """Handle failures on remote A3C workers. + + Args: + removed_workers: removed worker ids. + new_workers: ids of newly created workers. + """ + self._worker_manager.remove_workers(removed_workers) + self._worker_manager.add_workers(new_workers) + # Deprecated: Use ray.rllib.agents.a3c.A3CConfig instead! class _deprecated_default_config(dict): diff --git a/rllib/agents/callbacks.py b/rllib/agents/callbacks.py index 07b57db6b5b6..3a80e09af0c4 100644 --- a/rllib/agents/callbacks.py +++ b/rllib/agents/callbacks.py @@ -12,12 +12,12 @@ from ray.rllib.utils.annotations import PublicAPI from ray.rllib.utils.deprecation import deprecation_warning from ray.rllib.utils.exploration.random_encoder import ( - MovingMeanStd, + _MovingMeanStd, compute_states_entropy, update_beta, ) from ray.rllib.utils.typing import AgentID, EnvType, PolicyID -from ray.tune.callback import CallbackMeta +from ray.tune.callback import _CallbackMeta # Import psutil after ray so the packaged version is used. import psutil @@ -28,7 +28,7 @@ @PublicAPI -class DefaultCallbacks(metaclass=CallbackMeta): +class DefaultCallbacks(metaclass=_CallbackMeta): """Abstract base class for RLlib callbacks (similar to Keras callbacks). These callbacks can be used for custom metrics and custom postprocessing. @@ -529,7 +529,7 @@ def __init__( self.beta = beta self.rho = rho self.beta_schedule = beta_schedule - self._rms = MovingMeanStd() + self._rms = _MovingMeanStd() super().__init__(*args, **kwargs) def on_learn_on_batch( diff --git a/rllib/agents/dqn/apex.py b/rllib/agents/dqn/apex.py index 10e414a967a1..03418b21b1ce 100644 --- a/rllib/agents/dqn/apex.py +++ b/rllib/agents/dqn/apex.py @@ -16,10 +16,11 @@ import copy import platform import random -from typing import Dict, List, DefaultDict, Set +from typing import Dict, List, Type, Optional, Callable import ray from ray.actor import ActorHandle +from ray.rllib import Policy from ray.rllib.agents import Trainer from ray.rllib.algorithms.dqn.dqn import DQNConfig, DQNTrainer from ray.rllib.algorithms.dqn.learner_thread import LearnerThread @@ -29,9 +30,9 @@ STEPS_TRAINED_THIS_ITER_COUNTER, ) from ray.rllib.execution.parallel_requests import ( - asynchronous_parallel_requests, - wait_asynchronous_requests, + AsyncRequestsManager, ) +from ray.rllib.policy.sample_batch import MultiAgentBatch from ray.rllib.utils.actors import create_colocated_actors from ray.rllib.utils.annotations import override from ray.rllib.utils.deprecation import Deprecated, DEPRECATED_VALUE @@ -50,7 +51,6 @@ TrainerConfigDict, ResultDict, PartialTrainerConfigDict, - T, ) from ray.tune.trainable import Trainable from ray.tune.utils.placement_groups import PlacementGroupFactory @@ -139,6 +139,22 @@ def __init__(self, trainer_class=None): self.train_batch_size = 512 self.target_network_update_freq = 500000 self.training_intensity = 1 + + # max number of inflight requests to each sampling worker + # see the AsyncRequestsManager class for more details + # Tuning these values is important when running experimens with large sample + # batches. If the sample batches are large in size, then there is the risk that + # the object store may fill up, causing the store to spill objects to disk. + # This can cause any asynchronous requests to become very slow, making your + # experiment run slowly. You can inspect the object store during your + # experiment via a call to ray memory on your headnode, and by using the ray + # dashboard. If you're seeing that the object store is filling up, turn down + # the number of remote requests in flight, or enable compression in your + # experiment of timesteps. + self.max_requests_in_flight_per_sampler_worker = 2 + self.max_requests_in_flight_per_replay_worker = float("inf") + self.timeout_s_sampler_manager = 0.0 + self.timeout_s_replay_manager = 0.0 # APEX-DQN is using a distributed (non local) replay buffer. self.replay_buffer_config = { "no_local_replay_buffer": True, @@ -146,7 +162,6 @@ def __init__(self, trainer_class=None): # prioritization "type": "MultiAgentPrioritizedReplayBuffer", "capacity": 2000000, - "replay_batch_size": 32, # Alpha parameter for prioritized replay buffer. "prioritized_replay_alpha": 0.6, # Beta parameter for sampling from prioritized replay buffer. @@ -185,6 +200,137 @@ def __init__(self, trainer_class=None): # fmt: on # __sphinx_doc_end__ + def training( + self, + *, + num_atoms: Optional[int] = None, + v_min: Optional[float] = None, + v_max: Optional[float] = None, + noisy: Optional[bool] = None, + sigma0: Optional[float] = None, + dueling: Optional[bool] = None, + hiddens: Optional[int] = None, + double_q: Optional[bool] = None, + n_step: Optional[int] = None, + before_learn_on_batch: Callable[ + [Type[MultiAgentBatch], List[Type[Policy]], Type[int]], + Type[MultiAgentBatch], + ] = None, + training_intensity: Optional[float] = None, + replay_buffer_config: Optional[dict] = None, + max_requests_in_flight_per_sampler_worker: Optional[int] = None, + max_requests_in_flight_per_replay_worker: Optional[int] = None, + timeout_s_sampler_manager: Optional[float] = None, + timeout_s_replay_manager: Optional[float] = None, + **kwargs, + ) -> "ApexConfig": + """Sets the training related configuration. + + Args: + num_atoms: Number of atoms for representing the distribution of return. + When this is greater than 1, distributional Q-learning is used. + v_min: Minimum value estimation + v_max: Maximum value estimation + noisy: Whether to use noisy network to aid exploration. This adds + parametric noise to the model weights. + sigma0: Control the initial parameter noise for noisy nets. + dueling: Whether to use dueling DQN policy. + hiddens: Dense-layer setup for each the advantage branch and the value + branch + double_q: Whether to use double DQN for the policy. + n_step: N-step for Q-learning. + before_learn_on_batch: Callback to run before learning on a multi-agent + batch of experiences. + training_intensity: The ratio of timesteps to train on for every + timestep that is sampled. This must be greater than 0. + replay_buffer_config: Replay buffer config. + Examples: + { + "_enable_replay_buffer_api": True, + "type": "MultiAgentReplayBuffer", + "learning_starts": 1000, + "capacity": 50000, + "replay_batch_size": 32, + "replay_sequence_length": 1, + } + - OR - + { + "_enable_replay_buffer_api": True, + "type": "MultiAgentPrioritizedReplayBuffer", + "capacity": 50000, + "prioritized_replay_alpha": 0.6, + "prioritized_replay_beta": 0.4, + "prioritized_replay_eps": 1e-6, + "replay_sequence_length": 1, + } + - Where - + prioritized_replay_alpha: Alpha parameter controls the degree of + prioritization in the buffer. In other words, when a buffer sample has + a higher temporal-difference error, with how much more probability + should it drawn to use to update the parametrized Q-network. 0.0 + corresponds to uniform probability. Setting much above 1.0 may quickly + result as the sampling distribution could become heavily “pointy” with + low entropy. + prioritized_replay_beta: Beta parameter controls the degree of + importance sampling which suppresses the influence of gradient updates + from samples that have higher probability of being sampled via alpha + parameter and the temporal-difference error. + prioritized_replay_eps: Epsilon parameter sets the baseline probability + for sampling so that when the temporal-difference error of a sample is + zero, there is still a chance of drawing the sample. + max_requests_in_flight_per_sampler_worker: Level of queuing for sampling + operations. + max_requests_in_flight_per_replay_worker: Level of queuing for replay + aggregator operations (if using aggregator workers). + timeout_s_sampler_manager: The timeout for waiting for sampling results + for workers -- typically if this is too low, the manager won't be able + to retrieve ready sampling results. + timeout_s_replay_manager: The timeout for waiting for replay worker + results -- typically if this is too low, the manager won't be able to + retrieve ready replay requests. + """ + # Pass kwargs onto super's `training()` method. + super().training(**kwargs) + + if num_atoms is not None: + self.num_atoms = num_atoms + if v_min is not None: + self.v_min = v_min + if v_max is not None: + self.v_max = v_max + if noisy is not None: + self.noisy = noisy + if sigma0 is not None: + self.sigma0 = sigma0 + if dueling is not None: + self.dueling = dueling + if hiddens is not None: + self.hiddens = hiddens + if double_q is not None: + self.double_q = double_q + if n_step is not None: + self.n_step = n_step + if before_learn_on_batch is not None: + self.before_learn_on_batch = before_learn_on_batch + if training_intensity is not None: + self.training_intensity = training_intensity + if replay_buffer_config is not None: + self.replay_buffer_config = replay_buffer_config + if max_requests_in_flight_per_sampler_worker is not None: + self.max_requests_in_flight_per_sampler_worker = ( + max_requests_in_flight_per_sampler_worker + ) + if max_requests_in_flight_per_replay_worker is not None: + self.max_requests_in_flight_per_replay_worker = ( + max_requests_in_flight_per_replay_worker + ) + if timeout_s_sampler_manager is not None: + self.timeout_s_sampler_manager = timeout_s_sampler_manager + if timeout_s_replay_manager is not None: + self.timeout_s_replay_manager = timeout_s_replay_manager + + return self + class ApexTrainer(DQNTrainer): @override(Trainable) @@ -216,7 +362,7 @@ def setup(self, config: PartialTrainerConfigDict): # Place all replay buffer shards on the same node as the learner # (driver process that runs this execution plan). if replay_actor_config["replay_buffer_shards_colocated_with_driver"]: - self.replay_actors = create_colocated_actors( + self._replay_actors = create_colocated_actors( actor_specs=[ # (class, args, kwargs={}, count) ( ReplayActor, @@ -231,21 +377,29 @@ def setup(self, config: PartialTrainerConfigDict): ] # [0]=only one item in `actor_specs`. # Place replay buffer shards on any node(s). else: - self.replay_actors = [ + self._replay_actors = [ ReplayActor.remote(*replay_actor_config) for _ in range(num_replay_buffer_shards) ] + self._replay_actor_manager = AsyncRequestsManager( + self._replay_actors, + max_remote_requests_in_flight_per_worker=self.config[ + "max_requests_in_flight_per_replay_worker" + ], + ray_wait_timeout_s=self.config["timeout_s_replay_manager"], + ) + self._sampling_actor_manager = AsyncRequestsManager( + self.workers.remote_workers(), + max_remote_requests_in_flight_per_worker=self.config[ + "max_requests_in_flight_per_sampler_worker" + ], + ray_wait_timeout_s=self.config["timeout_s_sampler_manager"], + ) self.learner_thread = LearnerThread(self.workers.local_worker()) self.learner_thread.start() self.steps_since_update = defaultdict(int) weights = self.workers.local_worker().get_weights() self.curr_learner_weights = ray.put(weights) - self.remote_sampling_requests_in_flight: DefaultDict[ - ActorHandle, Set[ray.ObjectRef] - ] = defaultdict(set) - self.remote_replay_requests_in_flight: DefaultDict[ - ActorHandle, Set[ray.ObjectRef] - ] = defaultdict(set) self.curr_num_samples_collected = 0 self.replay_sample_batches = [] self._num_ts_trained_since_last_target_update = 0 @@ -261,9 +415,6 @@ def validate_config(self, config): raise ValueError("`num_gpus` > 1 not yet supported for APEX-DQN!") # Call DQN's validation method. super().validate_config(config) - # if config["_disable_execution_plan_api"]: - # if not config.get("training_intensity", 1.0) > 0: - # raise ValueError("training_intensity must be > 0") @override(Trainable) def training_iteration(self) -> ResultDict: @@ -295,7 +446,7 @@ def get_samples_and_store_to_replay_buffers(self): with self._timers[SAMPLE_TIMER]: local_sampling_worker = self.workers.local_worker() batch = local_sampling_worker.sample() - actor = random.choice(self.replay_actors) + actor = random.choice(self._replay_actors) ray.get(actor.add.remote(batch)) batch_statistics = { local_sampling_worker: [ @@ -327,19 +478,11 @@ def remote_worker_sample_and_store( # Sample and Store in the Replay Actors on the sampling workers. with self._timers[SAMPLE_TIMER]: - # Results are a mapping from ActorHandle (RolloutWorker) to their - # returned gradient calculation results. - num_samples_ready_dict: Dict[ - ActorHandle, T - ] = asynchronous_parallel_requests( - remote_requests_in_flight=self.remote_sampling_requests_in_flight, - actors=self.workers.remote_workers(), - ray_wait_timeout_s=0.1, - max_remote_requests_in_flight_per_actor=4, - remote_fn=remote_worker_sample_and_store, - remote_kwargs=[{"replay_actors": self.replay_actors}] - * len(self.workers.remote_workers()), + self._sampling_actor_manager.call_on_all_available( + remote_worker_sample_and_store, + fn_kwargs={"replay_actors": self._replay_actors}, ) + num_samples_ready_dict = self._sampling_actor_manager.get_ready() return num_samples_ready_dict def update_workers(self, _num_samples_ready: Dict[ActorHandle, int]) -> int: @@ -389,23 +532,20 @@ def sample_from_replay_buffer_place_on_learner_queue_non_blocking( """ - def wait_on_replay_actors(timeout: float) -> None: + def wait_on_replay_actors() -> None: """Wait for the replay actors to finish sampling for timeout seconds. If the timeout is None, then block on the actors indefinitely. """ - replay_samples_ready: Dict[ActorHandle, T] = wait_asynchronous_requests( - remote_requests_in_flight=self.remote_replay_requests_in_flight, - ray_wait_timeout_s=timeout, - ) + _replay_samples_ready = self._replay_actor_manager.get_ready() - for replay_actor, sample_batches in replay_samples_ready.items(): - for sample_batch in sample_batches: - self.replay_sample_batches.append((replay_actor, sample_batch)) + for _replay_actor, _sample_batches in _replay_samples_ready.items(): + for _sample_batch in _sample_batches: + self.replay_sample_batches.append((_replay_actor, _sample_batch)) num_samples_collected = sum(num_samples_collected.values()) self.curr_num_samples_collected += num_samples_collected + wait_on_replay_actors() if self.curr_num_samples_collected >= self.config["train_batch_size"]: - wait_on_replay_actors(None) training_intensity = int(self.config["training_intensity"] or 1) num_requests_to_launch = ( self.curr_num_samples_collected / self.config["train_batch_size"] @@ -413,22 +553,11 @@ def wait_on_replay_actors(timeout: float) -> None: num_requests_to_launch = max(1, round(num_requests_to_launch)) self.curr_num_samples_collected = 0 for _ in range(num_requests_to_launch): - rand_actor = random.choice(self.replay_actors) - replay_samples_ready: Dict[ - ActorHandle, T - ] = asynchronous_parallel_requests( - remote_requests_in_flight=self.remote_replay_requests_in_flight, - actors=[rand_actor], - ray_wait_timeout_s=0.1, - max_remote_requests_in_flight_per_actor=num_requests_to_launch, - remote_args=[[self.config["train_batch_size"]]], - remote_fn=lambda actor, num_items: actor.sample(num_items), + self._replay_actor_manager.call( + lambda actor, num_items: actor.sample(num_items), + fn_args=[self.config["train_batch_size"]], ) - for replay_actor, sample_batches in replay_samples_ready.items(): - for sample_batch in sample_batches: - self.replay_sample_batches.append((replay_actor, sample_batch)) - - wait_on_replay_actors(0.1) + wait_on_replay_actors() # add the sample batches to the learner queue while self.replay_sample_batches: @@ -495,13 +624,26 @@ def update_target_networks(self, num_new_trained_samples) -> None: STEPS_TRAINED_COUNTER ] + @override(Trainer) + def on_worker_failures( + self, removed_workers: List[ActorHandle], new_workers: List[ActorHandle] + ): + """Handle the failures of remote sampling workers + + Args: + removed_workers: removed worker ids. + new_workers: ids of newly created workers. + """ + self._sampling_actor_manager.remove_workers(removed_workers) + self._sampling_actor_manager.add_workers(new_workers) + @override(Trainer) def _compile_step_results(self, *, step_ctx, step_attempt_results=None): result = super()._compile_step_results( step_ctx=step_ctx, step_attempt_results=step_attempt_results ) replay_stats = ray.get( - self.replay_actors[0].stats.remote(self.config["optimizer"].get("debug")) + self._replay_actors[0].stats.remote(self.config["optimizer"].get("debug")) ) exploration_infos_list = self.workers.foreach_policy_to_train( lambda p, pid: {pid: p.get_exploration_state()} diff --git a/rllib/agents/dqn/r2d2.py b/rllib/agents/dqn/r2d2.py index 9f7d99734ede..f02e88a8de57 100644 --- a/rllib/agents/dqn/r2d2.py +++ b/rllib/agents/dqn/r2d2.py @@ -90,8 +90,9 @@ def __init__(self, trainer_class=None): self.adam_epsilon = 1e-3 self.lr = 1e-4 self.gamma = 0.997 - self.train_batch_size = 64 + self.train_batch_size = 1000 self.target_network_update_freq = 2500 + self.training_intensity = 1000 # R2D2 is using a buffer that stores sequences. self.replay_buffer_config = { "type": "MultiAgentReplayBuffer", @@ -100,6 +101,10 @@ def __init__(self, trainer_class=None): "prioritized_replay": DEPRECATED_VALUE, # Size of the replay buffer (in sequences, not timesteps). "capacity": 100000, + # This algorithm learns on sequences. We therefore require the replay buffer + # to slice sampled batches into sequences before replay. How sequences + # are sliced depends on the parameters `replay_sequence_length`, + # `replay_burn_in`, and `replay_zero_init_states`. "storage_unit": "sequences", # Set automatically: The number # of contiguous environment steps to diff --git a/rllib/agents/dqn/r2d2_tf_policy.py b/rllib/agents/dqn/r2d2_tf_policy.py index 7bcac140dfa7..8af348b5deeb 100644 --- a/rllib/agents/dqn/r2d2_tf_policy.py +++ b/rllib/agents/dqn/r2d2_tf_policy.py @@ -11,7 +11,6 @@ postprocess_nstep_and_prio, ) from ray.rllib.algorithms.dqn.dqn_tf_policy import build_q_model -from ray.rllib.algorithms.dqn.simple_q_tf_policy import TargetNetworkMixin from ray.rllib.models.action_dist import ActionDistribution from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.tf.tf_action_dist import Categorical @@ -19,7 +18,10 @@ from ray.rllib.policy.policy import Policy from ray.rllib.policy.tf_policy_template import build_tf_policy from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.policy.tf_mixins import LearningRateSchedule +from ray.rllib.policy.tf_mixins import ( + LearningRateSchedule, + TargetNetworkMixin, +) from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.tf_utils import huber_loss from ray.rllib.utils.typing import ModelInputDict, TensorType, TrainerConfigDict diff --git a/rllib/agents/dqn/r2d2_torch_policy.py b/rllib/agents/dqn/r2d2_torch_policy.py index 02a04b8ab0ac..5da35180b1e5 100644 --- a/rllib/agents/dqn/r2d2_torch_policy.py +++ b/rllib/agents/dqn/r2d2_torch_policy.py @@ -14,13 +14,15 @@ compute_q_values, ) from ray.rllib.agents.dqn.r2d2_tf_policy import get_distribution_inputs_and_class -from ray.rllib.algorithms.dqn.simple_q_torch_policy import TargetNetworkMixin from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper from ray.rllib.policy.policy import Policy from ray.rllib.policy.policy_template import build_policy_class from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.policy.torch_mixins import LearningRateSchedule +from ray.rllib.policy.torch_mixins import ( + LearningRateSchedule, + TargetNetworkMixin, +) from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.torch_utils import ( apply_grad_clipping, diff --git a/rllib/agents/dqn/tests/test_apex_dqn.py b/rllib/agents/dqn/tests/test_apex_dqn.py index 5a428fde224f..79b7f79d70c6 100644 --- a/rllib/agents/dqn/tests/test_apex_dqn.py +++ b/rllib/agents/dqn/tests/test_apex_dqn.py @@ -112,7 +112,6 @@ def test_apex_lr_schedule(self): "type": "MultiAgentPrioritizedReplayBuffer", "learning_starts": 10, "capacity": 100, - "replay_batch_size": 10, "prioritized_replay_alpha": 0.6, # Beta parameter for sampling from prioritized replay buffer. "prioritized_replay_beta": 0.4, @@ -146,7 +145,7 @@ def _step_n_times(trainer, n: int): for _ in framework_iterator(config): trainer = config.build(env="CartPole-v0") - lr = _step_n_times(trainer, 5) # 50 timesteps + lr = _step_n_times(trainer, 3) # 50 timesteps # Close to 0.2 self.assertGreaterEqual(lr, 0.1) diff --git a/rllib/agents/dqn/tests/test_r2d2.py b/rllib/agents/dqn/tests/test_r2d2.py index 5d7834474711..82145963ccc4 100644 --- a/rllib/agents/dqn/tests/test_r2d2.py +++ b/rllib/agents/dqn/tests/test_r2d2.py @@ -26,13 +26,21 @@ def check_batch_sizes(train_results): configured_b = train_results["config"]["train_batch_size"] actual_b = policy_stats["td_error"].shape[0] if (configured_b - actual_b) / actual_b > 0.1: - assert ( - configured_b - / ( - train_results["config"]["model"]["max_seq_len"] - + train_results["config"]["replay_buffer_config"]["replay_burn_in"] + # Since R2D2 learns on sequences of a fixed length but with variable + # amount of timesteps that are padded, the batch size is almost never the + # `train_batch_size`, which is specified in timesteps, but close to it. + assert 0.8 < ( + abs( + configured_b + / ( + train_results["config"]["model"]["max_seq_len"] + + train_results["config"]["replay_buffer_config"][ + "replay_burn_in" + ] + ) + / actual_b ) - == actual_b + < 1.2 ) diff --git a/rllib/agents/impala/impala.py b/rllib/agents/impala/impala.py index 3ba2db7b1d21..34d4172b3780 100644 --- a/rllib/agents/impala/impala.py +++ b/rllib/agents/impala/impala.py @@ -1,23 +1,19 @@ import copy import logging import platform -import random -from collections import defaultdict import queue -from typing import Optional, Type, List, Dict, Union, DefaultDict, Set, Callable, Any +from typing import Optional, Type, List, Dict, Union, Callable, Any import ray from ray.actor import ActorHandle from ray.rllib import SampleBatch -from ray.rllib.agents.impala.vtrace_tf_policy import VTraceTFPolicy from ray.rllib.agents.trainer import Trainer, TrainerConfig from ray.rllib.execution.buffers.mixin_replay_buffer import MixInMultiAgentReplayBuffer from ray.rllib.execution.learner_thread import LearnerThread from ray.rllib.execution.multi_gpu_learner_thread import MultiGPULearnerThread from ray.rllib.execution.parallel_requests import ( - asynchronous_parallel_requests, - wait_asynchronous_requests, + AsyncRequestsManager, ) from ray.rllib.execution.tree_agg import gather_experiences_tree_aggregation from ray.rllib.execution.common import ( @@ -112,9 +108,10 @@ def __init__(self, trainer_class=None): self.replay_buffer_num_slots = 0 self.learner_queue_size = 16 self.learner_queue_timeout = 300 - self.max_sample_requests_in_flight_per_worker = 2 - self.aggregator_wait_timeout = 0.03 - self.sample_wait_timeout = 0.03 + self.max_requests_in_flight_per_sampler_worker = 2 + self.max_requests_in_flight_per_aggregator_worker = 2 + self.timeout_s_sampler_manager = 0.0 + self.timeout_s_aggregator_manager = 0.0 self.broadcast_interval = 1 self.num_aggregation_workers = 0 self.grad_clip = 40.0 @@ -158,9 +155,10 @@ def training( replay_buffer_num_slots: Optional[int] = None, learner_queue_size: Optional[int] = None, learner_queue_timeout: Optional[float] = None, - max_sample_requests_in_flight_per_worker: Optional[int] = None, - aggregator_wait_timeout: Optional[float] = None, - sample_wait_timeout: Optional[float] = None, + max_requests_in_flight_per_sampler_worker: Optional[int] = None, + max_requests_in_flight_per_aggregator_worker: Optional[int] = None, + timeout_s_sampler_manager: Optional[float] = None, + timeout_s_aggregator_manager: Optional[float] = None, broadcast_interval: Optional[int] = None, num_aggregation_workers: Optional[int] = None, grad_clip: Optional[float] = None, @@ -216,12 +214,16 @@ def training( learner_queue_timeout: Wait for train batches to be available in minibatch buffer queue this many seconds. This may need to be increased e.g. when training with a slow environment. - max_sample_requests_in_flight_per_worker: Level of queuing for sampling - and replay aggregator operations (if using aggregator workers). - aggregator_wait_timeout: Amount of time to block and wait on pending calls - to replay aggregator workers. - sample_wait_timeout: Amount of time to block and wait on pending calls to - sampling workers. + max_requests_in_flight_per_sampler_worker: Level of queuing for sampling + operations. + max_requests_in_flight_per_aggregator_worker: Level of queuing for replay + aggregator operations (if using aggregator workers). + timeout_s_sampler_manager: The timeout for waiting for sampling results + for workers -- typically if this is too low, the manager won't be able + to retrieve ready sampling results. + timeout_s_aggregator_manager: The timeout for waiting for replay worker + results -- typically if this is too low, the manager won't be able to + retrieve ready replay requests. broadcast_interval: Max number of workers to broadcast one set of weights to. num_aggregation_workers: Use n (`num_aggregation_workers`) extra Actors for @@ -251,6 +253,18 @@ def training( after_train_step: Callback for APPO to use to update KL, target network periodically. The input to the callback is the learner fetches dict. + Note: + Tuning max_requests_in_flight_per_sampler_worker and + max_requests_in_flight_per_aggregator_worker is important when running + experiments with large sample batches. If the sample batches are large in + size, then there is the risk that the object store may fill up, causing + the store to spill sample batches to disk. This can cause any asynchronous + requests to become very slow, making your experiment run slowly. You can + inspect the object store during your experiment via a call to ray memory + on your headnode, and by using the ray dashboard. If you're seeing that + the object store is filling up, turn down the number of remote requests + in flight, or enable compression in your experiment of timesteps. + Returns: This updated TrainerConfig object. """ @@ -279,18 +293,22 @@ def training( self.learner_queue_size = learner_queue_size if learner_queue_timeout is not None: self.learner_queue_timeout = learner_queue_timeout - if max_sample_requests_in_flight_per_worker is not None: - self.max_sample_requests_in_flight_per_worker = ( - max_sample_requests_in_flight_per_worker - ) - if aggregator_wait_timeout is not None: - self.aggregator_wait_timeout = aggregator_wait_timeout - if sample_wait_timeout is not None: - self.sample_wait_timeout = sample_wait_timeout if broadcast_interval is not None: self.broadcast_interval = broadcast_interval if num_aggregation_workers is not None: self.num_aggregation_workers = num_aggregation_workers + if max_requests_in_flight_per_sampler_worker is not None: + self.max_requests_in_flight_per_sampler_worker = ( + max_requests_in_flight_per_sampler_worker + ) + if max_requests_in_flight_per_aggregator_worker is not None: + self.max_requests_in_flight_per_aggregator_worker = ( + max_requests_in_flight_per_aggregator_worker + ) + if timeout_s_sampler_manager is not None: + self.timeout_s_sampler_manager = timeout_s_sampler_manager + if timeout_s_aggregator_manager is not None: + self.timeout_s_aggregator_manager = timeout_s_aggregator_manager if grad_clip is not None: self.grad_clip = grad_clip if opt_type is not None: @@ -364,7 +382,7 @@ def gather_experiences_directly(workers, config): rollouts = ParallelRollouts( workers, mode="async", - num_async=config["max_sample_requests_in_flight_per_worker"], + num_async=config["max_requests_in_flight_per_sampler_worker"], ) # Augment with replay and concat to desired train batch size. @@ -449,9 +467,22 @@ def get_default_policy_class( from ray.rllib.agents.a3c.a3c_torch_policy import A3CTorchPolicy return A3CTorchPolicy + elif config["framework"] == "tf": + if config["vtrace"]: + from ray.rllib.agents.impala.vtrace_tf_policy import ( + VTraceStaticGraphTFPolicy, + ) + + return VTraceStaticGraphTFPolicy + else: + from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy + + return A3CTFPolicy else: if config["vtrace"]: - return VTraceTFPolicy + from ray.rllib.agents.impala.vtrace_tf_policy import VTraceEagerTFPolicy + + return VTraceEagerTFPolicy else: from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy @@ -508,9 +539,6 @@ def validate_config(self, config): @override(Trainer) def setup(self, config: PartialTrainerConfigDict): super().setup(config) - self.remote_sampling_requests_in_flight: DefaultDict[ - ActorHandle, Set[ray.ObjectRef] - ] = defaultdict(set) if self.config["_disable_execution_plan_api"]: # Create extra aggregation workers and assign each rollout worker to @@ -541,12 +569,16 @@ def setup(self, config: PartialTrainerConfigDict): ], node=localhost, ) - self.aggregator_workers = [ + self._aggregator_workers = [ actor for actor_groups in all_co_located for actor in actor_groups ] - self.remote_aggregator_requests_in_flight: DefaultDict[ - ActorHandle, Set[ray.ObjectRef] - ] = defaultdict(set) + self._aggregator_actor_manager = AsyncRequestsManager( + self._aggregator_workers, + max_remote_requests_in_flight_per_worker=self.config[ + "max_requests_in_flight_per_aggregator_worker" + ], + ray_wait_timeout_s=self.config["timeout_s_aggregator_manager"], + ) else: # Create our local mixin buffer if the num of aggregation workers is 0. @@ -559,6 +591,15 @@ def setup(self, config: PartialTrainerConfigDict): replay_ratio=self.config["replay_ratio"], ) + self._sampling_actor_manager = AsyncRequestsManager( + self.workers.remote_workers(), + max_remote_requests_in_flight_per_worker=self.config[ + "max_requests_in_flight_per_sampler_worker" + ], + return_object_refs=True, + ray_wait_timeout_s=self.config["timeout_s_sampler_manager"], + ) + # Create and start the learner thread. self._learner_thread = make_learner_thread( self.workers.local_worker(), self.config @@ -721,17 +762,12 @@ def aggregate_into_larger_batch(): def get_samples_from_workers(self) -> Dict[ActorHandle, List[SampleBatch]]: # Perform asynchronous sampling on all (remote) rollout workers. if self.workers.remote_workers(): + self._sampling_actor_manager.call_on_all_available( + lambda worker: worker.sample() + ) sample_batches: Dict[ ActorHandle, List[ObjectRef] - ] = asynchronous_parallel_requests( - remote_requests_in_flight=self.remote_requests_in_flight, - actors=self.workers.remote_workers(), - ray_wait_timeout_s=self.config["sample_wait_timeout"], - max_remote_requests_in_flight_per_actor=self.config[ - "max_sample_requests_in_flight_per_worker" - ], - return_result_obj_ref_ids=True, - ) + ] = self._sampling_actor_manager.get_ready() else: # only sampling on the local worker sample_batches = { @@ -812,26 +848,13 @@ def process_experiences_tree_aggregation( ] ready_processed_batches = [] for batch in batches: - aggregator = random.choice(self.aggregator_workers) - processed_sample_batches: Dict[ - ActorHandle, List[ObjectRef] - ] = asynchronous_parallel_requests( - remote_requests_in_flight=self.remote_aggregator_requests_in_flight, - actors=[aggregator], - remote_fn=lambda actor, b: actor.process_episodes(b), - remote_kwargs=[{"b": batch}], - ray_wait_timeout_s=self.config["aggregator_wait_timeout"], - max_remote_requests_in_flight_per_actor=float("inf"), + self._aggregator_actor_manager.call( + lambda actor, b: actor.process_episodes(b), fn_kwargs={"b": batch} ) - for ready_sub_batches in processed_sample_batches.values(): - ready_processed_batches.extend(ready_sub_batches) waiting_processed_sample_batches: Dict[ ActorHandle, List[ObjectRef] - ] = wait_asynchronous_requests( - remote_requests_in_flight=self.remote_aggregator_requests_in_flight, - ray_wait_timeout_s=self.config["aggregator_wait_timeout"], - ) + ] = self._aggregator_actor_manager.get_ready() for ready_sub_batches in waiting_processed_sample_batches.values(): ready_processed_batches.extend(ready_sub_batches) @@ -859,6 +882,19 @@ def update_workers_if_necessary(self) -> None: # Update global vars of the local worker. self.workers.local_worker().set_global_vars(global_vars) + @override(Trainer) + def on_worker_failures( + self, removed_workers: List[ActorHandle], new_workers: List[ActorHandle] + ): + """Handle the failures of remote sampling workers + + Args: + removed_workers: removed worker ids. + new_workers: ids of newly created workers. + """ + self._sampling_actor_manager.remove_workers(removed_workers) + self._sampling_actor_manager.add_workers(new_workers) + @override(Trainer) def _compile_step_results(self, *, step_ctx, step_attempt_results=None): result = super()._compile_step_results( diff --git a/rllib/agents/impala/vtrace_tf_policy.py b/rllib/agents/impala/vtrace_tf_policy.py index 9cccc2c44b2f..4a8d1281c80f 100644 --- a/rllib/agents/impala/vtrace_tf_policy.py +++ b/rllib/agents/impala/vtrace_tf_policy.py @@ -5,16 +5,25 @@ import numpy as np import logging import gym +from typing import Dict, List, Type, Union import ray from ray.rllib.agents.impala import vtrace_tf as vtrace -from ray.rllib.models.tf.tf_action_dist import Categorical +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.models.tf.tf_action_dist import Categorical, TFActionDistribution +from ray.rllib.policy.dynamic_tf_policy_v2 import DynamicTFPolicyV2 +from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2 from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.policy.tf_policy_template import build_tf_policy from ray.rllib.policy.tf_mixins import LearningRateSchedule, EntropyCoeffSchedule from ray.rllib.utils import force_list +from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.tf_utils import explained_variance +from ray.rllib.utils.typing import ( + LocalOptimizer, + ModelGradients, + TensorType, +) tf1, tf, tfv = try_import_tf() @@ -158,197 +167,275 @@ def _make_time_major(policy, seq_lens, tensor, drop_last=False): return res -def build_vtrace_loss(policy, model, dist_class, train_batch): - model_out, _ = model(train_batch) - action_dist = dist_class(model_out, model) +class VTraceClipGradients: + """VTrace version of gradient computation logic.""" + + def __init__(self): + """No special initialization required.""" + pass + + def compute_gradients_fn( + self, optimizer: LocalOptimizer, loss: TensorType + ) -> ModelGradients: + # Supporting more than one loss/optimizer. + if self.config["_tf_policy_handles_more_than_one_loss"]: + optimizers = force_list(optimizer) + losses = force_list(loss) + assert len(optimizers) == len(losses) + clipped_grads_and_vars = [] + for optim, loss_ in zip(optimizers, losses): + grads_and_vars = optim.compute_gradients( + loss_, self.model.trainable_variables() + ) + clipped_g_and_v = [] + for g, v in grads_and_vars: + if g is not None: + clipped_g, _ = tf.clip_by_global_norm( + [g], self.config["grad_clip"] + ) + clipped_g_and_v.append((clipped_g[0], v)) + clipped_grads_and_vars.append(clipped_g_and_v) + + self.grads = [g for g_and_v in clipped_grads_and_vars for (g, v) in g_and_v] + # Only one optimizer and and loss term. + else: + grads_and_vars = optimizer.compute_gradients( + loss, self.model.trainable_variables() + ) + grads = [g for (g, v) in grads_and_vars] + self.grads, _ = tf.clip_by_global_norm(grads, self.config["grad_clip"]) + clipped_grads_and_vars = list( + zip(self.grads, self.model.trainable_variables()) + ) - if isinstance(policy.action_space, gym.spaces.Discrete): - is_multidiscrete = False - output_hidden_shape = [policy.action_space.n] - elif isinstance(policy.action_space, gym.spaces.MultiDiscrete): - is_multidiscrete = True - output_hidden_shape = policy.action_space.nvec.astype(np.int32) - else: - is_multidiscrete = False - output_hidden_shape = 1 + return clipped_grads_and_vars - def make_time_major(*args, **kw): - return _make_time_major( - policy, train_batch.get(SampleBatch.SEQ_LENS), *args, **kw - ) - actions = train_batch[SampleBatch.ACTIONS] - dones = train_batch[SampleBatch.DONES] - rewards = train_batch[SampleBatch.REWARDS] - behaviour_action_logp = train_batch[SampleBatch.ACTION_LOGP] - behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS] - unpacked_behaviour_logits = tf.split(behaviour_logits, output_hidden_shape, axis=1) - unpacked_outputs = tf.split(model_out, output_hidden_shape, axis=1) - values = model.value_function() +class VTraceOptimizer: + """Optimizer function for VTrace policies.""" - if policy.is_recurrent(): - max_seq_len = tf.reduce_max(train_batch[SampleBatch.SEQ_LENS]) - mask = tf.sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len) - mask = tf.reshape(mask, [-1]) - else: - mask = tf.ones_like(rewards) - - # Prepare actions for loss - loss_actions = actions if is_multidiscrete else tf.expand_dims(actions, axis=1) - - # Inputs are reshaped from [B * T] => [(T|T-1), B] for V-trace calc. - drop_last = policy.config["vtrace_drop_last_ts"] - policy.loss = VTraceLoss( - actions=make_time_major(loss_actions, drop_last=drop_last), - actions_logp=make_time_major(action_dist.logp(actions), drop_last=drop_last), - actions_entropy=make_time_major( - action_dist.multi_entropy(), drop_last=drop_last - ), - dones=make_time_major(dones, drop_last=drop_last), - behaviour_action_logp=make_time_major( - behaviour_action_logp, drop_last=drop_last - ), - behaviour_logits=make_time_major( - unpacked_behaviour_logits, drop_last=drop_last - ), - target_logits=make_time_major(unpacked_outputs, drop_last=drop_last), - discount=policy.config["gamma"], - rewards=make_time_major(rewards, drop_last=drop_last), - values=make_time_major(values, drop_last=drop_last), - bootstrap_value=make_time_major(values)[-1], - dist_class=Categorical if is_multidiscrete else dist_class, - model=model, - valid_mask=make_time_major(mask, drop_last=drop_last), - config=policy.config, - vf_loss_coeff=policy.config["vf_loss_coeff"], - entropy_coeff=policy.entropy_coeff, - clip_rho_threshold=policy.config["vtrace_clip_rho_threshold"], - clip_pg_rho_threshold=policy.config["vtrace_clip_pg_rho_threshold"], - ) - - if policy.config.get("_separate_vf_optimizer"): - return policy.loss.loss_wo_vf, policy.loss.vf_loss - else: - return policy.loss.total_loss - - -def stats(policy, train_batch): - drop_last = policy.config["vtrace"] and policy.config["vtrace_drop_last_ts"] - values_batched = _make_time_major( - policy, - train_batch.get(SampleBatch.SEQ_LENS), - policy.model.value_function(), - drop_last=drop_last, - ) - - return { - "cur_lr": tf.cast(policy.cur_lr, tf.float64), - "policy_loss": policy.loss.mean_pi_loss, - "entropy": policy.loss.mean_entropy, - "entropy_coeff": tf.cast(policy.entropy_coeff, tf.float64), - "var_gnorm": tf.linalg.global_norm(policy.model.trainable_variables()), - "vf_loss": policy.loss.mean_vf_loss, - "vf_explained_var": explained_variance( - tf.reshape(policy.loss.value_targets, [-1]), - tf.reshape(values_batched, [-1]), - ), - } - - -def grad_stats(policy, train_batch, grads): - # We have support for more than one loss (list of lists of grads). - if policy.config.get("_tf_policy_handles_more_than_one_loss"): - grad_gnorm = [tf.linalg.global_norm(g) for g in grads] - # Old case: We have a single list of grads (only one loss term and - # optimizer). - else: - grad_gnorm = tf.linalg.global_norm(grads) + def __init__(self): + pass - return { - "grad_gnorm": grad_gnorm, - } + # TODO: maybe standardize this function, so the choice of optimizers are more + # predictable for common agents. + def optimizer( + self, + ) -> Union["tf.keras.optimizers.Optimizer", List["tf.keras.optimizers.Optimizer"]]: + config = self.config + if config["opt_type"] == "adam": + if config["framework"] in ["tf2", "tfe"]: + optim = tf.keras.optimizers.Adam(self.cur_lr) + if config["_separate_vf_optimizer"]: + return optim, tf.keras.optimizers.Adam(config["_lr_vf"]) + else: + optim = tf1.train.AdamOptimizer(self.cur_lr) + if config["_separate_vf_optimizer"]: + return optim, tf1.train.AdamOptimizer(config["_lr_vf"]) + else: + if config["_separate_vf_optimizer"]: + raise ValueError( + "RMSProp optimizer not supported for separate" + "vf- and policy losses yet! Set `opt_type=adam`" + ) + if tfv == 2: + optim = tf.keras.optimizers.RMSprop( + self.cur_lr, config["decay"], config["momentum"], config["epsilon"] + ) + else: + optim = tf1.train.RMSPropOptimizer( + self.cur_lr, config["decay"], config["momentum"], config["epsilon"] + ) -def choose_optimizer(policy, config): - if policy.config["opt_type"] == "adam": - if policy.config["framework"] in ["tf2", "tfe"]: - optim = tf.keras.optimizers.Adam(policy.cur_lr) - if policy.config["_separate_vf_optimizer"]: - return optim, tf.keras.optimizers.Adam(policy.config["_lr_vf"]) - else: - optim = tf1.train.AdamOptimizer(policy.cur_lr) - if policy.config["_separate_vf_optimizer"]: - return optim, tf1.train.AdamOptimizer(policy.config["_lr_vf"]) - else: - if policy.config["_separate_vf_optimizer"]: - raise ValueError( - "RMSProp optimizer not supported for separate" - "vf- and policy losses yet! Set `opt_type=adam`" - ) + return optim - if tfv == 2: - optim = tf.keras.optimizers.RMSprop( - policy.cur_lr, config["decay"], config["momentum"], config["epsilon"] - ) - else: - optim = tf1.train.RMSPropOptimizer( - policy.cur_lr, config["decay"], config["momentum"], config["epsilon"] - ) - return optim +# We need this builder function because we want to share the same +# custom logics between TF1 dynamic and TF2 eager policies. +def get_vtrace_tf_policy(base: type) -> type: + """Construct an VTraceTFPolicy inheriting either dynamic or eager base policies. + Args: + base: Base class for this policy. DynamicTFPolicyV2 or EagerTFPolicyV2. -def clip_gradients(policy, optimizer, loss): - # Supporting more than one loss/optimizer. - if policy.config["_tf_policy_handles_more_than_one_loss"]: - optimizers = force_list(optimizer) - losses = force_list(loss) - assert len(optimizers) == len(losses) - clipped_grads_and_vars = [] - for optim, loss_ in zip(optimizers, losses): - grads_and_vars = optim.compute_gradients( - loss_, policy.model.trainable_variables() + Returns: + A TF Policy to be used with ImpalaTrainer. + """ + # VTrace mixins are placed in front of more general mixins to make sure + # their functions like optimizer() overrides all the other implementations + # (e.g., LearningRateSchedule.optimizer()) + class VTraceTFPolicy( + VTraceClipGradients, + VTraceOptimizer, + LearningRateSchedule, + EntropyCoeffSchedule, + base, + ): + def __init__( + self, + obs_space, + action_space, + config, + existing_model=None, + existing_inputs=None, + ): + # First thing first, enable eager execution if necessary. + base.enable_eager_execution_if_necessary() + + config = dict(ray.rllib.agents.impala.impala.DEFAULT_CONFIG, **config) + + # Initialize base class. + base.__init__( + self, + obs_space, + action_space, + config, + existing_inputs=existing_inputs, + existing_model=existing_model, ) - clipped_g_and_v = [] - for g, v in grads_and_vars: - if g is not None: - clipped_g, _ = tf.clip_by_global_norm( - [g], policy.config["grad_clip"] - ) - clipped_g_and_v.append((clipped_g[0], v)) - clipped_grads_and_vars.append(clipped_g_and_v) - - policy.grads = [g for g_and_v in clipped_grads_and_vars for (g, v) in g_and_v] - # Only one optimizer and and loss term. - else: - grads_and_vars = optimizer.compute_gradients( - loss, policy.model.trainable_variables() - ) - grads = [g for (g, v) in grads_and_vars] - policy.grads, _ = tf.clip_by_global_norm(grads, policy.config["grad_clip"]) - clipped_grads_and_vars = list( - zip(policy.grads, policy.model.trainable_variables()) - ) - return clipped_grads_and_vars + VTraceClipGradients.__init__(self) + VTraceOptimizer.__init__(self) + LearningRateSchedule.__init__(self, config["lr"], config["lr_schedule"]) + EntropyCoeffSchedule.__init__( + self, config["entropy_coeff"], config["entropy_coeff_schedule"] + ) + # Note: this is a bit ugly, but loss and optimizer initialization must + # happen after all the MixIns are initialized. + self.maybe_initialize_optimizer_and_loss() + + @override(base) + def loss( + self, + model: Union[ModelV2, "tf.keras.Model"], + dist_class: Type[TFActionDistribution], + train_batch: SampleBatch, + ) -> Union[TensorType, List[TensorType]]: + model_out, _ = model(train_batch) + action_dist = dist_class(model_out, model) + + if isinstance(self.action_space, gym.spaces.Discrete): + is_multidiscrete = False + output_hidden_shape = [self.action_space.n] + elif isinstance(self.action_space, gym.spaces.MultiDiscrete): + is_multidiscrete = True + output_hidden_shape = self.action_space.nvec.astype(np.int32) + else: + is_multidiscrete = False + output_hidden_shape = 1 + + def make_time_major(*args, **kw): + return _make_time_major( + self, train_batch.get(SampleBatch.SEQ_LENS), *args, **kw + ) + + actions = train_batch[SampleBatch.ACTIONS] + dones = train_batch[SampleBatch.DONES] + rewards = train_batch[SampleBatch.REWARDS] + behaviour_action_logp = train_batch[SampleBatch.ACTION_LOGP] + behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS] + unpacked_behaviour_logits = tf.split( + behaviour_logits, output_hidden_shape, axis=1 + ) + unpacked_outputs = tf.split(model_out, output_hidden_shape, axis=1) + values = model.value_function() + + if self.is_recurrent(): + max_seq_len = tf.reduce_max(train_batch[SampleBatch.SEQ_LENS]) + mask = tf.sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len) + mask = tf.reshape(mask, [-1]) + else: + mask = tf.ones_like(rewards) + + # Prepare actions for loss + loss_actions = ( + actions if is_multidiscrete else tf.expand_dims(actions, axis=1) + ) -def setup_mixins(policy, obs_space, action_space, config): - LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"]) - EntropyCoeffSchedule.__init__( - policy, config["entropy_coeff"], config["entropy_coeff_schedule"] - ) + # Inputs are reshaped from [B * T] => [(T|T-1), B] for V-trace calc. + drop_last = self.config["vtrace_drop_last_ts"] + self.vtrace_loss = VTraceLoss( + actions=make_time_major(loss_actions, drop_last=drop_last), + actions_logp=make_time_major( + action_dist.logp(actions), drop_last=drop_last + ), + actions_entropy=make_time_major( + action_dist.multi_entropy(), drop_last=drop_last + ), + dones=make_time_major(dones, drop_last=drop_last), + behaviour_action_logp=make_time_major( + behaviour_action_logp, drop_last=drop_last + ), + behaviour_logits=make_time_major( + unpacked_behaviour_logits, drop_last=drop_last + ), + target_logits=make_time_major(unpacked_outputs, drop_last=drop_last), + discount=self.config["gamma"], + rewards=make_time_major(rewards, drop_last=drop_last), + values=make_time_major(values, drop_last=drop_last), + bootstrap_value=make_time_major(values)[-1], + dist_class=Categorical if is_multidiscrete else dist_class, + model=model, + valid_mask=make_time_major(mask, drop_last=drop_last), + config=self.config, + vf_loss_coeff=self.config["vf_loss_coeff"], + entropy_coeff=self.entropy_coeff, + clip_rho_threshold=self.config["vtrace_clip_rho_threshold"], + clip_pg_rho_threshold=self.config["vtrace_clip_pg_rho_threshold"], + ) + if self.config.get("_separate_vf_optimizer"): + return self.vtrace_loss.loss_wo_vf, self.vtrace_loss.vf_loss + else: + return self.vtrace_loss.total_loss + + @override(base) + def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]: + drop_last = self.config["vtrace"] and self.config["vtrace_drop_last_ts"] + values_batched = _make_time_major( + self, + train_batch.get(SampleBatch.SEQ_LENS), + self.model.value_function(), + drop_last=drop_last, + ) -VTraceTFPolicy = build_tf_policy( - name="VTraceTFPolicy", - get_default_config=lambda: ray.rllib.agents.impala.impala.DEFAULT_CONFIG, - loss_fn=build_vtrace_loss, - stats_fn=stats, - grad_stats_fn=grad_stats, - optimizer_fn=choose_optimizer, - compute_gradients_fn=clip_gradients, - before_loss_init=setup_mixins, - mixins=[LearningRateSchedule, EntropyCoeffSchedule], - get_batch_divisibility_req=lambda p: p.config["rollout_fragment_length"], -) + return { + "cur_lr": tf.cast(self.cur_lr, tf.float64), + "policy_loss": self.vtrace_loss.mean_pi_loss, + "entropy": self.vtrace_loss.mean_entropy, + "entropy_coeff": tf.cast(self.entropy_coeff, tf.float64), + "var_gnorm": tf.linalg.global_norm(self.model.trainable_variables()), + "vf_loss": self.vtrace_loss.mean_vf_loss, + "vf_explained_var": explained_variance( + tf.reshape(self.vtrace_loss.value_targets, [-1]), + tf.reshape(values_batched, [-1]), + ), + } + + @override(base) + def grad_stats_fn( + self, train_batch: SampleBatch, grads: ModelGradients + ) -> Dict[str, TensorType]: + # We have support for more than one loss (list of lists of grads). + if self.config.get("_tf_policy_handles_more_than_one_loss"): + grad_gnorm = [tf.linalg.global_norm(g) for g in grads] + # Old case: We have a single list of grads (only one loss term and + # optimizer). + else: + grad_gnorm = tf.linalg.global_norm(grads) + + return { + "grad_gnorm": grad_gnorm, + } + + @override(base) + def get_batch_divisibility_req(self) -> int: + return self.config["rollout_fragment_length"] + + return VTraceTFPolicy + + +VTraceStaticGraphTFPolicy = get_vtrace_tf_policy(DynamicTFPolicyV2) +VTraceEagerTFPolicy = get_vtrace_tf_policy(EagerTFPolicyV2) diff --git a/rllib/agents/impala/vtrace_torch_policy.py b/rllib/agents/impala/vtrace_torch_policy.py index cce831416a76..cdb7ad7492d4 100644 --- a/rllib/agents/impala/vtrace_torch_policy.py +++ b/rllib/agents/impala/vtrace_torch_policy.py @@ -1,22 +1,29 @@ import gym import logging import numpy as np -from typing import Any, Dict +from typing import Dict, List, Type, Union import ray import ray.rllib.agents.impala.vtrace_torch as vtrace +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.models.action_dist import ActionDistribution from ray.rllib.models.torch.torch_action_dist import TorchCategorical -from ray.rllib.policy.policy import Policy -from ray.rllib.policy.policy_template import build_policy_class from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.policy.torch_mixins import LearningRateSchedule, EntropyCoeffSchedule +from ray.rllib.policy.torch_mixins import ( + EntropyCoeffSchedule, + LearningRateSchedule, +) +from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2 +from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.numpy import convert_to_numpy from ray.rllib.utils.torch_utils import ( apply_grad_clipping, explained_variance, global_norm, sequence_mask, ) +from ray.rllib.utils.typing import TensorType torch, nn = try_import_torch() @@ -119,101 +126,6 @@ def __init__( ) -def build_vtrace_loss(policy, model, dist_class, train_batch): - model_out, _ = model(train_batch) - action_dist = dist_class(model_out, model) - - if isinstance(policy.action_space, gym.spaces.Discrete): - is_multidiscrete = False - output_hidden_shape = [policy.action_space.n] - elif isinstance(policy.action_space, gym.spaces.MultiDiscrete): - is_multidiscrete = True - output_hidden_shape = policy.action_space.nvec.astype(np.int32) - else: - is_multidiscrete = False - output_hidden_shape = 1 - - def _make_time_major(*args, **kw): - return make_time_major( - policy, train_batch.get(SampleBatch.SEQ_LENS), *args, **kw - ) - - actions = train_batch[SampleBatch.ACTIONS] - dones = train_batch[SampleBatch.DONES] - rewards = train_batch[SampleBatch.REWARDS] - behaviour_action_logp = train_batch[SampleBatch.ACTION_LOGP] - behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS] - if isinstance(output_hidden_shape, (list, tuple, np.ndarray)): - unpacked_behaviour_logits = torch.split( - behaviour_logits, list(output_hidden_shape), dim=1 - ) - unpacked_outputs = torch.split(model_out, list(output_hidden_shape), dim=1) - else: - unpacked_behaviour_logits = torch.chunk( - behaviour_logits, output_hidden_shape, dim=1 - ) - unpacked_outputs = torch.chunk(model_out, output_hidden_shape, dim=1) - values = model.value_function() - - if policy.is_recurrent(): - max_seq_len = torch.max(train_batch[SampleBatch.SEQ_LENS]) - mask_orig = sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len) - mask = torch.reshape(mask_orig, [-1]) - else: - mask = torch.ones_like(rewards) - - # Prepare actions for loss. - loss_actions = actions if is_multidiscrete else torch.unsqueeze(actions, dim=1) - - # Inputs are reshaped from [B * T] => [(T|T-1), B] for V-trace calc. - drop_last = policy.config["vtrace_drop_last_ts"] - loss = VTraceLoss( - actions=_make_time_major(loss_actions, drop_last=drop_last), - actions_logp=_make_time_major(action_dist.logp(actions), drop_last=drop_last), - actions_entropy=_make_time_major(action_dist.entropy(), drop_last=drop_last), - dones=_make_time_major(dones, drop_last=drop_last), - behaviour_action_logp=_make_time_major( - behaviour_action_logp, drop_last=drop_last - ), - behaviour_logits=_make_time_major( - unpacked_behaviour_logits, drop_last=drop_last - ), - target_logits=_make_time_major(unpacked_outputs, drop_last=drop_last), - discount=policy.config["gamma"], - rewards=_make_time_major(rewards, drop_last=drop_last), - values=_make_time_major(values, drop_last=drop_last), - bootstrap_value=_make_time_major(values)[-1], - dist_class=TorchCategorical if is_multidiscrete else dist_class, - model=model, - valid_mask=_make_time_major(mask, drop_last=drop_last), - config=policy.config, - vf_loss_coeff=policy.config["vf_loss_coeff"], - entropy_coeff=policy.entropy_coeff, - clip_rho_threshold=policy.config["vtrace_clip_rho_threshold"], - clip_pg_rho_threshold=policy.config["vtrace_clip_pg_rho_threshold"], - ) - - # Store values for stats function in model (tower), such that for - # multi-GPU, we do not override them during the parallel loss phase. - model.tower_stats["pi_loss"] = loss.pi_loss - model.tower_stats["vf_loss"] = loss.vf_loss - model.tower_stats["entropy"] = loss.entropy - model.tower_stats["mean_entropy"] = loss.mean_entropy - model.tower_stats["total_loss"] = loss.total_loss - - values_batched = make_time_major( - policy, - train_batch.get(SampleBatch.SEQ_LENS), - values, - drop_last=policy.config["vtrace"] and drop_last, - ) - model.tower_stats["vf_explained_var"] = explained_variance( - torch.reshape(loss.value_targets, [-1]), torch.reshape(values_batched, [-1]) - ) - - return loss.total_loss - - def make_time_major(policy, seq_lens, tensor, drop_last=False): """Swaps batch and trajectory axis. @@ -251,51 +163,192 @@ def make_time_major(policy, seq_lens, tensor, drop_last=False): return res -def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, Any]: +class VTraceOptimizer: + """Optimizer function for VTrace torch policies.""" - return { - "cur_lr": policy.cur_lr, - "total_loss": torch.mean(torch.stack(policy.get_tower_stats("total_loss"))), - "policy_loss": torch.mean(torch.stack(policy.get_tower_stats("pi_loss"))), - "entropy": torch.mean(torch.stack(policy.get_tower_stats("mean_entropy"))), - "entropy_coeff": policy.entropy_coeff, - "var_gnorm": global_norm(policy.model.trainable_variables()), - "vf_loss": torch.mean(torch.stack(policy.get_tower_stats("vf_loss"))), - "vf_explained_var": torch.mean( - torch.stack(policy.get_tower_stats("vf_explained_var")) - ), - } + def __init__(self): + pass + @override(TorchPolicyV2) + def optimizer( + self, + ) -> Union[List["torch.optim.Optimizer"], "torch.optim.Optimizer"]: + if self.config["opt_type"] == "adam": + return torch.optim.Adam(params=self.model.parameters(), lr=self.cur_lr) + else: + return torch.optim.RMSprop( + params=self.model.parameters(), + lr=self.cur_lr, + weight_decay=self.config["decay"], + momentum=self.config["momentum"], + eps=self.config["epsilon"], + ) + + +# VTrace mixins are placed in front of more general mixins to make sure +# their functions like optimizer() overrides all the other implementations +# (e.g., LearningRateSchedule.optimizer()) +class VTraceTorchPolicy( + VTraceOptimizer, + LearningRateSchedule, + EntropyCoeffSchedule, + TorchPolicyV2, +): + """PyTorch policy class used with ImpalaTrainer.""" + + def __init__(self, observation_space, action_space, config): + config = dict(ray.rllib.agents.impala.impala.DEFAULT_CONFIG, **config) + + VTraceOptimizer.__init__(self) + # Need to initialize learning rate variable before calling + # TorchPolicyV2.__init__. + LearningRateSchedule.__init__(self, config["lr"], config["lr_schedule"]) + EntropyCoeffSchedule.__init__( + self, config["entropy_coeff"], config["entropy_coeff_schedule"] + ) -def choose_optimizer(policy, config): - if policy.config["opt_type"] == "adam": - return torch.optim.Adam(params=policy.model.parameters(), lr=policy.cur_lr) - else: - return torch.optim.RMSprop( - params=policy.model.parameters(), - lr=policy.cur_lr, - weight_decay=config["decay"], - momentum=config["momentum"], - eps=config["epsilon"], + TorchPolicyV2.__init__( + self, + observation_space, + action_space, + config, + max_seq_len=config["model"]["max_seq_len"], ) + # TODO: Don't require users to call this manually. + self._initialize_loss_from_dummy_batch() -def setup_mixins(policy, obs_space, action_space, config): - EntropyCoeffSchedule.__init__( - policy, config["entropy_coeff"], config["entropy_coeff_schedule"] - ) - LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"]) - - -VTraceTorchPolicy = build_policy_class( - name="VTraceTorchPolicy", - framework="torch", - loss_fn=build_vtrace_loss, - get_default_config=lambda: ray.rllib.agents.impala.impala.DEFAULT_CONFIG, - stats_fn=stats, - extra_grad_process_fn=apply_grad_clipping, - optimizer_fn=choose_optimizer, - before_init=setup_mixins, - mixins=[LearningRateSchedule, EntropyCoeffSchedule], - get_batch_divisibility_req=lambda p: p.config["rollout_fragment_length"], -) + @override(TorchPolicyV2) + def loss( + self, + model: ModelV2, + dist_class: Type[ActionDistribution], + train_batch: SampleBatch, + ) -> Union[TensorType, List[TensorType]]: + model_out, _ = model(train_batch) + action_dist = dist_class(model_out, model) + + if isinstance(self.action_space, gym.spaces.Discrete): + is_multidiscrete = False + output_hidden_shape = [self.action_space.n] + elif isinstance(self.action_space, gym.spaces.MultiDiscrete): + is_multidiscrete = True + output_hidden_shape = self.action_space.nvec.astype(np.int32) + else: + is_multidiscrete = False + output_hidden_shape = 1 + + def _make_time_major(*args, **kw): + return make_time_major( + self, train_batch.get(SampleBatch.SEQ_LENS), *args, **kw + ) + + actions = train_batch[SampleBatch.ACTIONS] + dones = train_batch[SampleBatch.DONES] + rewards = train_batch[SampleBatch.REWARDS] + behaviour_action_logp = train_batch[SampleBatch.ACTION_LOGP] + behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS] + if isinstance(output_hidden_shape, (list, tuple, np.ndarray)): + unpacked_behaviour_logits = torch.split( + behaviour_logits, list(output_hidden_shape), dim=1 + ) + unpacked_outputs = torch.split(model_out, list(output_hidden_shape), dim=1) + else: + unpacked_behaviour_logits = torch.chunk( + behaviour_logits, output_hidden_shape, dim=1 + ) + unpacked_outputs = torch.chunk(model_out, output_hidden_shape, dim=1) + values = model.value_function() + + if self.is_recurrent(): + max_seq_len = torch.max(train_batch[SampleBatch.SEQ_LENS]) + mask_orig = sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len) + mask = torch.reshape(mask_orig, [-1]) + else: + mask = torch.ones_like(rewards) + + # Prepare actions for loss. + loss_actions = actions if is_multidiscrete else torch.unsqueeze(actions, dim=1) + + # Inputs are reshaped from [B * T] => [(T|T-1), B] for V-trace calc. + drop_last = self.config["vtrace_drop_last_ts"] + loss = VTraceLoss( + actions=_make_time_major(loss_actions, drop_last=drop_last), + actions_logp=_make_time_major( + action_dist.logp(actions), drop_last=drop_last + ), + actions_entropy=_make_time_major( + action_dist.entropy(), drop_last=drop_last + ), + dones=_make_time_major(dones, drop_last=drop_last), + behaviour_action_logp=_make_time_major( + behaviour_action_logp, drop_last=drop_last + ), + behaviour_logits=_make_time_major( + unpacked_behaviour_logits, drop_last=drop_last + ), + target_logits=_make_time_major(unpacked_outputs, drop_last=drop_last), + discount=self.config["gamma"], + rewards=_make_time_major(rewards, drop_last=drop_last), + values=_make_time_major(values, drop_last=drop_last), + bootstrap_value=_make_time_major(values)[-1], + dist_class=TorchCategorical if is_multidiscrete else dist_class, + model=model, + valid_mask=_make_time_major(mask, drop_last=drop_last), + config=self.config, + vf_loss_coeff=self.config["vf_loss_coeff"], + entropy_coeff=self.entropy_coeff, + clip_rho_threshold=self.config["vtrace_clip_rho_threshold"], + clip_pg_rho_threshold=self.config["vtrace_clip_pg_rho_threshold"], + ) + + # Store values for stats function in model (tower), such that for + # multi-GPU, we do not override them during the parallel loss phase. + model.tower_stats["pi_loss"] = loss.pi_loss + model.tower_stats["vf_loss"] = loss.vf_loss + model.tower_stats["entropy"] = loss.entropy + model.tower_stats["mean_entropy"] = loss.mean_entropy + model.tower_stats["total_loss"] = loss.total_loss + + values_batched = make_time_major( + self, + train_batch.get(SampleBatch.SEQ_LENS), + values, + drop_last=self.config["vtrace"] and drop_last, + ) + model.tower_stats["vf_explained_var"] = explained_variance( + torch.reshape(loss.value_targets, [-1]), torch.reshape(values_batched, [-1]) + ) + + return loss.total_loss + + @override(TorchPolicyV2) + def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]: + return convert_to_numpy( + { + "cur_lr": self.cur_lr, + "total_loss": torch.mean( + torch.stack(self.get_tower_stats("total_loss")) + ), + "policy_loss": torch.mean(torch.stack(self.get_tower_stats("pi_loss"))), + "entropy": torch.mean( + torch.stack(self.get_tower_stats("mean_entropy")) + ), + "entropy_coeff": self.entropy_coeff, + "var_gnorm": global_norm(self.model.trainable_variables()), + "vf_loss": torch.mean(torch.stack(self.get_tower_stats("vf_loss"))), + "vf_explained_var": torch.mean( + torch.stack(self.get_tower_stats("vf_explained_var")) + ), + } + ) + + @override(TorchPolicyV2) + def extra_grad_process( + self, optimizer: "torch.optim.Optimizer", loss: TensorType + ) -> Dict[str, TensorType]: + return apply_grad_clipping(self, optimizer, loss) + + @override(TorchPolicyV2) + def get_batch_divisibility_req(self) -> int: + return self.config["rollout_fragment_length"] diff --git a/rllib/agents/maddpg/__init__.py b/rllib/agents/maddpg/__init__.py index 792350b8c452..84d50c49ba0b 100644 --- a/rllib/agents/maddpg/__init__.py +++ b/rllib/agents/maddpg/__init__.py @@ -1,3 +1,19 @@ -from ray.rllib.agents.maddpg.maddpg import MADDPGTrainer, DEFAULT_CONFIG +from ray.rllib.algorithms.maddpg.maddpg import ( + MADDPGTrainer, + MADDPGTFPolicy, + DEFAULT_CONFIG, +) -__all__ = ["MADDPGTrainer", "DEFAULT_CONFIG"] +__all__ = [ + "MADDPGTrainer", + "MADDPGTFPolicy", + "DEFAULT_CONFIG", +] + +from ray.rllib.utils.deprecation import deprecation_warning + +deprecation_warning( + "ray.rllib.agents.maddpg", + "ray.rllib.algorithms.maddpg", + error=False, +) diff --git a/rllib/agents/marwil/__init__.py b/rllib/agents/marwil/__init__.py index af4c778db225..ff9db88d842b 100644 --- a/rllib/agents/marwil/__init__.py +++ b/rllib/agents/marwil/__init__.py @@ -5,7 +5,7 @@ MARWILTrainer, ) from ray.rllib.algorithms.marwil.marwil_tf_policy import ( - MARWILDynamicTFPolicy, + MARWILStaticGraphTFPolicy, MARWILEagerTFPolicy, ) from ray.rllib.algorithms.marwil.marwil_torch_policy import MARWILTorchPolicy @@ -14,7 +14,7 @@ "BCConfig", "BCTrainer", "MARWILConfig", - "MARWILDynamicTFPolicy", + "MARWILStaticGraphTFPolicy", "MARWILEagerTFPolicy", "MARWILTorchPolicy", "MARWILTrainer", diff --git a/rllib/agents/ppo/__init__.py b/rllib/agents/ppo/__init__.py index 0b6736d14406..dd998c21d951 100644 --- a/rllib/agents/ppo/__init__.py +++ b/rllib/agents/ppo/__init__.py @@ -1,5 +1,5 @@ from ray.rllib.agents.ppo.ppo import PPOConfig, PPOTrainer, DEFAULT_CONFIG -from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy +from ray.rllib.agents.ppo.ppo_tf_policy import PPOStaticGraphTFPolicy, PPOEagerTFPolicy from ray.rllib.agents.ppo.ppo_torch_policy import PPOTorchPolicy from ray.rllib.agents.ppo.appo import APPOConfig, APPOTrainer from ray.rllib.agents.ppo.ddppo import DDPPOConfig, DDPPOTrainer @@ -11,7 +11,8 @@ "DDPPOTrainer", "DEFAULT_CONFIG", "PPOConfig", - "PPOTFPolicy", + "PPOStaticGraphTFPolicy", + "PPOEagerTFPolicy", "PPOTorchPolicy", "PPOTrainer", ] diff --git a/rllib/agents/ppo/appo.py b/rllib/agents/ppo/appo.py index 3200320fd491..f17be987546a 100644 --- a/rllib/agents/ppo/appo.py +++ b/rllib/agents/ppo/appo.py @@ -12,7 +12,6 @@ from typing import Optional, Type import logging -from ray.rllib.agents.ppo.appo_tf_policy import AsyncPPOTFPolicy from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY from ray.rllib.agents.impala import ImpalaTrainer, ImpalaConfig from ray.rllib.policy.policy import Policy @@ -246,11 +245,17 @@ def get_default_policy_class( self, config: PartialTrainerConfigDict ) -> Optional[Type[Policy]]: if config["framework"] == "torch": - from ray.rllib.agents.ppo.appo_torch_policy import AsyncPPOTorchPolicy + from ray.rllib.agents.ppo.appo_torch_policy import APPOTorchPolicy - return AsyncPPOTorchPolicy + return APPOTorchPolicy + elif config["framework"] == "tf": + from ray.rllib.agents.ppo.appo_tf_policy import APPOStaticGraphTFPolicy + + return APPOStaticGraphTFPolicy else: - return AsyncPPOTFPolicy + from ray.rllib.agents.ppo.appo_tf_policy import APPOEagerTFPolicy + + return APPOEagerTFPolicy # Deprecated: Use ray.rllib.agents.ppo.APPOConfig instead! diff --git a/rllib/agents/ppo/appo_tf_policy.py b/rllib/agents/ppo/appo_tf_policy.py index 62ddab58c9cb..3d5e893ee06e 100644 --- a/rllib/agents/ppo/appo_tf_policy.py +++ b/rllib/agents/ppo/appo_tf_policy.py @@ -10,11 +10,12 @@ import gym from typing import Dict, List, Optional, Type, Union +import ray from ray.rllib.agents.impala import vtrace_tf as vtrace from ray.rllib.agents.impala.vtrace_tf_policy import ( _make_time_major, - clip_gradients, - choose_optimizer, + VTraceClipGradients, + VTraceOptimizer, ) from ray.rllib.evaluation.episode import Episode from ray.rllib.evaluation.postprocessing import ( @@ -22,23 +23,25 @@ Postprocessing, ) from ray.rllib.models.tf.tf_action_dist import Categorical -from ray.rllib.policy.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.dynamic_tf_policy_v2 import DynamicTFPolicyV2 +from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2 from ray.rllib.policy.tf_mixins import ( EntropyCoeffSchedule, LearningRateSchedule, KLCoeffMixin, ValueNetworkMixin, ) -from ray.rllib.policy.tf_policy_template import build_tf_policy -from ray.rllib.policy.tf_policy import TFPolicy from ray.rllib.models.catalog import ModelCatalog from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.tf.tf_action_dist import TFActionDistribution -from ray.rllib.utils.annotations import override +from ray.rllib.utils.annotations import ( + DeveloperAPI, + override, +) from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.tf_utils import explained_variance, make_tf_callable -from ray.rllib.utils.typing import AgentID, TensorType, TrainerConfigDict +from ray.rllib.utils.typing import TensorType tf1, tf, tfv = try_import_tf() @@ -48,47 +51,39 @@ logger = logging.getLogger(__name__) -def make_appo_model( - policy: Policy, - obs_space: gym.spaces.Space, - action_space: gym.spaces.Space, - config: TrainerConfigDict, -) -> ModelV2: +@DeveloperAPI +def make_appo_model(policy) -> ModelV2: """Builds model and target model for APPO. - Args: - policy (Policy): The Policy, which will use the model for optimization. - obs_space (gym.spaces.Space): The policy's observation space. - action_space (gym.spaces.Space): The policy's action space. - config (TrainerConfigDict): - Returns: ModelV2: The Model for the Policy to use. Note: The target model will not be returned, just assigned to `policy.target_model`. """ # Get the num_outputs for the following model construction calls. - _, logit_dim = ModelCatalog.get_action_dist(action_space, config["model"]) + _, logit_dim = ModelCatalog.get_action_dist( + policy.action_space, policy.config["model"] + ) # Construct the (main) model. policy.model = ModelCatalog.get_model_v2( - obs_space, - action_space, + policy.observation_space, + policy.action_space, logit_dim, - config["model"], + policy.config["model"], name=POLICY_SCOPE, - framework="torch" if config["framework"] == "torch" else "tf", + framework=policy.framework, ) policy.model_variables = policy.model.variables() # Construct the target model. policy.target_model = ModelCatalog.get_model_v2( - obs_space, - action_space, + policy.observation_space, + policy.action_space, logit_dim, - config["model"], + policy.config["model"], name=TARGET_POLICY_SCOPE, - framework="torch" if config["framework"] == "torch" else "tf", + framework=policy.framework, ) policy.target_model_variables = policy.target_model.variables() @@ -96,313 +91,6 @@ def make_appo_model( return policy.model -def appo_surrogate_loss( - policy: Policy, - model: ModelV2, - dist_class: Type[TFActionDistribution], - train_batch: SampleBatch, -) -> Union[TensorType, List[TensorType]]: - """Constructs the loss for APPO. - - With IS modifications and V-trace for Advantage Estimation. - - Args: - policy (Policy): The Policy to calculate the loss for. - model (ModelV2): The Model to calculate the loss for. - dist_class (Type[ActionDistribution]): The action distr. class. - train_batch (SampleBatch): The training data. - - Returns: - Union[TensorType, List[TensorType]]: A single loss tensor or a list - of loss tensors. - """ - model_out, _ = model(train_batch) - action_dist = dist_class(model_out, model) - - if isinstance(policy.action_space, gym.spaces.Discrete): - is_multidiscrete = False - output_hidden_shape = [policy.action_space.n] - elif isinstance(policy.action_space, gym.spaces.multi_discrete.MultiDiscrete): - is_multidiscrete = True - output_hidden_shape = policy.action_space.nvec.astype(np.int32) - else: - is_multidiscrete = False - output_hidden_shape = 1 - - # TODO: (sven) deprecate this when trajectory view API gets activated. - def make_time_major(*args, **kw): - return _make_time_major( - policy, train_batch.get(SampleBatch.SEQ_LENS), *args, **kw - ) - - actions = train_batch[SampleBatch.ACTIONS] - dones = train_batch[SampleBatch.DONES] - rewards = train_batch[SampleBatch.REWARDS] - behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS] - - target_model_out, _ = policy.target_model(train_batch) - prev_action_dist = dist_class(behaviour_logits, policy.model) - values = policy.model.value_function() - values_time_major = make_time_major(values) - - policy.model_vars = policy.model.variables() - policy.target_model_vars = policy.target_model.variables() - - if policy.is_recurrent(): - max_seq_len = tf.reduce_max(train_batch[SampleBatch.SEQ_LENS]) - mask = tf.sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len) - mask = tf.reshape(mask, [-1]) - mask = make_time_major(mask, drop_last=policy.config["vtrace"]) - - def reduce_mean_valid(t): - return tf.reduce_mean(tf.boolean_mask(t, mask)) - - else: - reduce_mean_valid = tf.reduce_mean - - if policy.config["vtrace"]: - drop_last = policy.config["vtrace_drop_last_ts"] - logger.debug( - "Using V-Trace surrogate loss (vtrace=True; " f"drop_last={drop_last})" - ) - - # Prepare actions for loss. - loss_actions = actions if is_multidiscrete else tf.expand_dims(actions, axis=1) - - old_policy_behaviour_logits = tf.stop_gradient(target_model_out) - old_policy_action_dist = dist_class(old_policy_behaviour_logits, model) - - # Prepare KL for Loss - mean_kl = make_time_major( - old_policy_action_dist.multi_kl(action_dist), drop_last=drop_last - ) - - unpacked_behaviour_logits = tf.split( - behaviour_logits, output_hidden_shape, axis=1 - ) - unpacked_old_policy_behaviour_logits = tf.split( - old_policy_behaviour_logits, output_hidden_shape, axis=1 - ) - - # Compute vtrace on the CPU for better perf. - with tf.device("/cpu:0"): - vtrace_returns = vtrace.multi_from_logits( - behaviour_policy_logits=make_time_major( - unpacked_behaviour_logits, drop_last=drop_last - ), - target_policy_logits=make_time_major( - unpacked_old_policy_behaviour_logits, drop_last=drop_last - ), - actions=tf.unstack( - make_time_major(loss_actions, drop_last=drop_last), axis=2 - ), - discounts=tf.cast( - ~make_time_major(tf.cast(dones, tf.bool), drop_last=drop_last), - tf.float32, - ) - * policy.config["gamma"], - rewards=make_time_major(rewards, drop_last=drop_last), - values=values_time_major[:-1] if drop_last else values_time_major, - bootstrap_value=values_time_major[-1], - dist_class=Categorical if is_multidiscrete else dist_class, - model=model, - clip_rho_threshold=tf.cast( - policy.config["vtrace_clip_rho_threshold"], tf.float32 - ), - clip_pg_rho_threshold=tf.cast( - policy.config["vtrace_clip_pg_rho_threshold"], tf.float32 - ), - ) - - actions_logp = make_time_major(action_dist.logp(actions), drop_last=drop_last) - prev_actions_logp = make_time_major( - prev_action_dist.logp(actions), drop_last=drop_last - ) - old_policy_actions_logp = make_time_major( - old_policy_action_dist.logp(actions), drop_last=drop_last - ) - - is_ratio = tf.clip_by_value( - tf.math.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0 - ) - logp_ratio = is_ratio * tf.exp(actions_logp - prev_actions_logp) - policy._is_ratio = is_ratio - - advantages = vtrace_returns.pg_advantages - surrogate_loss = tf.minimum( - advantages * logp_ratio, - advantages - * tf.clip_by_value( - logp_ratio, - 1 - policy.config["clip_param"], - 1 + policy.config["clip_param"], - ), - ) - - action_kl = tf.reduce_mean(mean_kl, axis=0) if is_multidiscrete else mean_kl - mean_kl_loss = reduce_mean_valid(action_kl) - mean_policy_loss = -reduce_mean_valid(surrogate_loss) - - # The value function loss. - if drop_last: - delta = values_time_major[:-1] - vtrace_returns.vs - else: - delta = values_time_major - vtrace_returns.vs - value_targets = vtrace_returns.vs - mean_vf_loss = 0.5 * reduce_mean_valid(tf.math.square(delta)) - - # The entropy loss. - actions_entropy = make_time_major(action_dist.multi_entropy(), drop_last=True) - mean_entropy = reduce_mean_valid(actions_entropy) - - else: - logger.debug("Using PPO surrogate loss (vtrace=False)") - - # Prepare KL for Loss - mean_kl = make_time_major(prev_action_dist.multi_kl(action_dist)) - - logp_ratio = tf.math.exp( - make_time_major(action_dist.logp(actions)) - - make_time_major(prev_action_dist.logp(actions)) - ) - - advantages = make_time_major(train_batch[Postprocessing.ADVANTAGES]) - surrogate_loss = tf.minimum( - advantages * logp_ratio, - advantages - * tf.clip_by_value( - logp_ratio, - 1 - policy.config["clip_param"], - 1 + policy.config["clip_param"], - ), - ) - - action_kl = tf.reduce_mean(mean_kl, axis=0) if is_multidiscrete else mean_kl - mean_kl_loss = reduce_mean_valid(action_kl) - mean_policy_loss = -reduce_mean_valid(surrogate_loss) - - # The value function loss. - value_targets = make_time_major(train_batch[Postprocessing.VALUE_TARGETS]) - delta = values_time_major - value_targets - mean_vf_loss = 0.5 * reduce_mean_valid(tf.math.square(delta)) - - # The entropy loss. - mean_entropy = reduce_mean_valid(make_time_major(action_dist.multi_entropy())) - - # The summed weighted loss. - total_loss = mean_policy_loss - mean_entropy * policy.entropy_coeff - # Optional KL loss. - if policy.config["use_kl_loss"]: - total_loss += policy.kl_coeff * mean_kl_loss - # Optional vf loss (or in a separate term due to separate - # optimizers/networks). - loss_wo_vf = total_loss - if not policy.config["_separate_vf_optimizer"]: - total_loss += mean_vf_loss * policy.config["vf_loss_coeff"] - - # Store stats in policy for stats_fn. - policy._total_loss = total_loss - policy._loss_wo_vf = loss_wo_vf - policy._mean_policy_loss = mean_policy_loss - # Backward compatibility: Deprecate policy._mean_kl. - policy._mean_kl_loss = policy._mean_kl = mean_kl_loss - policy._mean_vf_loss = mean_vf_loss - policy._mean_entropy = mean_entropy - policy._value_targets = value_targets - - # Return one total loss or two losses: vf vs rest (policy + kl). - if policy.config["_separate_vf_optimizer"]: - return loss_wo_vf, mean_vf_loss - else: - return total_loss - - -def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]: - """Stats function for APPO. Returns a dict with important loss stats. - - Args: - policy (Policy): The Policy to generate stats for. - train_batch (SampleBatch): The SampleBatch (already) used for training. - - Returns: - Dict[str, TensorType]: The stats dict. - """ - values_batched = _make_time_major( - policy, - train_batch.get(SampleBatch.SEQ_LENS), - policy.model.value_function(), - drop_last=policy.config["vtrace"] and policy.config["vtrace_drop_last_ts"], - ) - - stats_dict = { - "cur_lr": tf.cast(policy.cur_lr, tf.float64), - "total_loss": policy._total_loss, - "policy_loss": policy._mean_policy_loss, - "entropy": policy._mean_entropy, - "var_gnorm": tf.linalg.global_norm(policy.model.trainable_variables()), - "vf_loss": policy._mean_vf_loss, - "vf_explained_var": explained_variance( - tf.reshape(policy._value_targets, [-1]), tf.reshape(values_batched, [-1]) - ), - "entropy_coeff": tf.cast(policy.entropy_coeff, tf.float64), - } - - if policy.config["vtrace"]: - is_stat_mean, is_stat_var = tf.nn.moments(policy._is_ratio, [0, 1]) - stats_dict["mean_IS"] = is_stat_mean - stats_dict["var_IS"] = is_stat_var - - if policy.config["use_kl_loss"]: - stats_dict["kl"] = policy._mean_kl_loss - stats_dict["KL_Coeff"] = policy.kl_coeff - - return stats_dict - - -def postprocess_trajectory( - policy: Policy, - sample_batch: SampleBatch, - other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None, - episode: Optional[Episode] = None, -) -> SampleBatch: - """Postprocesses a trajectory and returns the processed trajectory. - - The trajectory contains only data from one episode and from one agent. - - If `config.batch_mode=truncate_episodes` (default), sample_batch may - contain a truncated (at-the-end) episode, in case the - `config.rollout_fragment_length` was reached by the sampler. - - If `config.batch_mode=complete_episodes`, sample_batch will contain - exactly one episode (no matter how long). - New columns can be added to sample_batch and existing ones may be altered. - - Args: - policy (Policy): The Policy used to generate the trajectory - (`sample_batch`) - sample_batch (SampleBatch): The SampleBatch to postprocess. - other_agent_batches (Optional[Dict[PolicyID, SampleBatch]]): Optional - dict of AgentIDs mapping to other agents' trajectory data (from the - same episode). NOTE: The other agents use the same policy. - episode (Optional[Episode]): Optional multi-agent episode - object in which the agents operated. - - Returns: - SampleBatch: The postprocessed, modified SampleBatch (or a new one). - """ - if not policy.config["vtrace"]: - sample_batch = compute_gae_for_sample_batch( - policy, sample_batch, other_agent_batches, episode - ) - - return sample_batch - - -def add_values(policy): - out = {} - if not policy.config["vtrace"]: - out[SampleBatch.VF_PREDS] = policy.model.value_function() - return out - - class TargetNetworkMixin: """Target NN is updated by master learner via the `update_target` method. @@ -422,69 +110,365 @@ def do_update(): self.update_target = do_update - @override(TFPolicy) def variables(self): return self.model_vars + self.target_model_vars -def setup_mixins( - policy: Policy, - obs_space: gym.spaces.Space, - action_space: gym.spaces.Space, - config: TrainerConfigDict, -) -> None: - """Call all mixin classes' constructors before APPOPolicy initialization. +# We need this builder function because we want to share the same +# custom logics between TF1 dynamic and TF2 eager policies. +def get_appo_tf_policy(base: type) -> type: + """Construct an APPOTFPolicy inheriting either dynamic or eager base policies. Args: - policy (Policy): The Policy object. - obs_space (gym.spaces.Space): The Policy's observation space. - action_space (gym.spaces.Space): The Policy's action space. - config (TrainerConfigDict): The Policy's config. - """ - LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"]) - KLCoeffMixin.__init__(policy, config) - ValueNetworkMixin.__init__(policy, config) - EntropyCoeffSchedule.__init__( - policy, config["entropy_coeff"], config["entropy_coeff_schedule"] - ) - + base: Base class for this policy. DynamicTFPolicyV2 or EagerTFPolicyV2. -def setup_late_mixins( - policy: Policy, - obs_space: gym.spaces.Space, - action_space: gym.spaces.Space, - config: TrainerConfigDict, -) -> None: - """Call all mixin classes' constructors after APPOPolicy initialization. - - Args: - policy (Policy): The Policy object. - obs_space (gym.spaces.Space): The Policy's observation space. - action_space (gym.spaces.Space): The Policy's action space. - config (TrainerConfigDict): The Policy's config. + Returns: + A TF Policy to be used with ImpalaTrainer. """ - TargetNetworkMixin.__init__(policy, obs_space, action_space, config) - - -# Build a child class of `DynamicTFPolicy`, given the custom functions defined -# above. -AsyncPPOTFPolicy = build_tf_policy( - name="AsyncPPOTFPolicy", - make_model=make_appo_model, - loss_fn=appo_surrogate_loss, - stats_fn=stats, - postprocess_fn=postprocess_trajectory, - optimizer_fn=choose_optimizer, - compute_gradients_fn=clip_gradients, - extra_action_out_fn=add_values, - before_loss_init=setup_mixins, - after_init=setup_late_mixins, - mixins=[ + + class APPOTFPolicy( + VTraceClipGradients, + VTraceOptimizer, LearningRateSchedule, KLCoeffMixin, - TargetNetworkMixin, - ValueNetworkMixin, EntropyCoeffSchedule, - ], - get_batch_divisibility_req=lambda p: p.config["rollout_fragment_length"], -) + ValueNetworkMixin, + TargetNetworkMixin, + base, + ): + def __init__( + self, + obs_space, + action_space, + config, + existing_model=None, + existing_inputs=None, + ): + # First thing first, enable eager execution if necessary. + base.enable_eager_execution_if_necessary() + + config = dict(ray.rllib.agents.ppo.appo.DEFAULT_CONFIG, **config) + + # Although this is a no-op, we call __init__ here to make it clear + # that base.__init__ will use the make_model() call. + VTraceClipGradients.__init__(self) + VTraceOptimizer.__init__(self) + LearningRateSchedule.__init__(self, config["lr"], config["lr_schedule"]) + + # Initialize base class. + base.__init__( + self, + obs_space, + action_space, + config, + existing_inputs=existing_inputs, + existing_model=existing_model, + ) + + EntropyCoeffSchedule.__init__( + self, config["entropy_coeff"], config["entropy_coeff_schedule"] + ) + ValueNetworkMixin.__init__(self, config) + KLCoeffMixin.__init__(self, config) + + # Note: this is a bit ugly, but loss and optimizer initialization must + # happen after all the MixIns are initialized. + self.maybe_initialize_optimizer_and_loss() + + # Initiate TargetNetwork ops after loss initialization. + TargetNetworkMixin.__init__(self, obs_space, action_space, config) + + @override(base) + def make_model(self) -> ModelV2: + return make_appo_model(self) + + @override(base) + def loss( + self, + model: Union[ModelV2, "tf.keras.Model"], + dist_class: Type[TFActionDistribution], + train_batch: SampleBatch, + ) -> Union[TensorType, List[TensorType]]: + model_out, _ = model(train_batch) + action_dist = dist_class(model_out, model) + + if isinstance(self.action_space, gym.spaces.Discrete): + is_multidiscrete = False + output_hidden_shape = [self.action_space.n] + elif isinstance(self.action_space, gym.spaces.multi_discrete.MultiDiscrete): + is_multidiscrete = True + output_hidden_shape = self.action_space.nvec.astype(np.int32) + else: + is_multidiscrete = False + output_hidden_shape = 1 + + # TODO: (sven) deprecate this when trajectory view API gets activated. + def make_time_major(*args, **kw): + return _make_time_major( + self, train_batch.get(SampleBatch.SEQ_LENS), *args, **kw + ) + + actions = train_batch[SampleBatch.ACTIONS] + dones = train_batch[SampleBatch.DONES] + rewards = train_batch[SampleBatch.REWARDS] + behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS] + + target_model_out, _ = self.target_model(train_batch) + prev_action_dist = dist_class(behaviour_logits, self.model) + values = self.model.value_function() + values_time_major = make_time_major(values) + + self.model_vars = self.model.variables() + self.target_model_vars = self.target_model.variables() + + if self.is_recurrent(): + max_seq_len = tf.reduce_max(train_batch[SampleBatch.SEQ_LENS]) + mask = tf.sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len) + mask = tf.reshape(mask, [-1]) + mask = make_time_major(mask, drop_last=self.config["vtrace"]) + + def reduce_mean_valid(t): + return tf.reduce_mean(tf.boolean_mask(t, mask)) + + else: + reduce_mean_valid = tf.reduce_mean + + if self.config["vtrace"]: + drop_last = self.config["vtrace_drop_last_ts"] + logger.debug( + "Using V-Trace surrogate loss (vtrace=True; " + f"drop_last={drop_last})" + ) + + # Prepare actions for loss. + loss_actions = ( + actions if is_multidiscrete else tf.expand_dims(actions, axis=1) + ) + + old_policy_behaviour_logits = tf.stop_gradient(target_model_out) + old_policy_action_dist = dist_class(old_policy_behaviour_logits, model) + + # Prepare KL for Loss + mean_kl = make_time_major( + old_policy_action_dist.multi_kl(action_dist), drop_last=drop_last + ) + + unpacked_behaviour_logits = tf.split( + behaviour_logits, output_hidden_shape, axis=1 + ) + unpacked_old_policy_behaviour_logits = tf.split( + old_policy_behaviour_logits, output_hidden_shape, axis=1 + ) + + # Compute vtrace on the CPU for better perf. + with tf.device("/cpu:0"): + vtrace_returns = vtrace.multi_from_logits( + behaviour_policy_logits=make_time_major( + unpacked_behaviour_logits, drop_last=drop_last + ), + target_policy_logits=make_time_major( + unpacked_old_policy_behaviour_logits, drop_last=drop_last + ), + actions=tf.unstack( + make_time_major(loss_actions, drop_last=drop_last), axis=2 + ), + discounts=tf.cast( + ~make_time_major( + tf.cast(dones, tf.bool), drop_last=drop_last + ), + tf.float32, + ) + * self.config["gamma"], + rewards=make_time_major(rewards, drop_last=drop_last), + values=values_time_major[:-1] + if drop_last + else values_time_major, + bootstrap_value=values_time_major[-1], + dist_class=Categorical if is_multidiscrete else dist_class, + model=model, + clip_rho_threshold=tf.cast( + self.config["vtrace_clip_rho_threshold"], tf.float32 + ), + clip_pg_rho_threshold=tf.cast( + self.config["vtrace_clip_pg_rho_threshold"], tf.float32 + ), + ) + + actions_logp = make_time_major( + action_dist.logp(actions), drop_last=drop_last + ) + prev_actions_logp = make_time_major( + prev_action_dist.logp(actions), drop_last=drop_last + ) + old_policy_actions_logp = make_time_major( + old_policy_action_dist.logp(actions), drop_last=drop_last + ) + + is_ratio = tf.clip_by_value( + tf.math.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0 + ) + logp_ratio = is_ratio * tf.exp(actions_logp - prev_actions_logp) + self._is_ratio = is_ratio + + advantages = vtrace_returns.pg_advantages + surrogate_loss = tf.minimum( + advantages * logp_ratio, + advantages + * tf.clip_by_value( + logp_ratio, + 1 - self.config["clip_param"], + 1 + self.config["clip_param"], + ), + ) + + action_kl = ( + tf.reduce_mean(mean_kl, axis=0) if is_multidiscrete else mean_kl + ) + mean_kl_loss = reduce_mean_valid(action_kl) + mean_policy_loss = -reduce_mean_valid(surrogate_loss) + + # The value function loss. + if drop_last: + delta = values_time_major[:-1] - vtrace_returns.vs + else: + delta = values_time_major - vtrace_returns.vs + value_targets = vtrace_returns.vs + mean_vf_loss = 0.5 * reduce_mean_valid(tf.math.square(delta)) + + # The entropy loss. + actions_entropy = make_time_major( + action_dist.multi_entropy(), drop_last=True + ) + mean_entropy = reduce_mean_valid(actions_entropy) + + else: + logger.debug("Using PPO surrogate loss (vtrace=False)") + + # Prepare KL for Loss + mean_kl = make_time_major(prev_action_dist.multi_kl(action_dist)) + + logp_ratio = tf.math.exp( + make_time_major(action_dist.logp(actions)) + - make_time_major(prev_action_dist.logp(actions)) + ) + + advantages = make_time_major(train_batch[Postprocessing.ADVANTAGES]) + surrogate_loss = tf.minimum( + advantages * logp_ratio, + advantages + * tf.clip_by_value( + logp_ratio, + 1 - self.config["clip_param"], + 1 + self.config["clip_param"], + ), + ) + + action_kl = ( + tf.reduce_mean(mean_kl, axis=0) if is_multidiscrete else mean_kl + ) + mean_kl_loss = reduce_mean_valid(action_kl) + mean_policy_loss = -reduce_mean_valid(surrogate_loss) + + # The value function loss. + value_targets = make_time_major( + train_batch[Postprocessing.VALUE_TARGETS] + ) + delta = values_time_major - value_targets + mean_vf_loss = 0.5 * reduce_mean_valid(tf.math.square(delta)) + + # The entropy loss. + mean_entropy = reduce_mean_valid( + make_time_major(action_dist.multi_entropy()) + ) + + # The summed weighted loss. + total_loss = mean_policy_loss - mean_entropy * self.entropy_coeff + # Optional KL loss. + if self.config["use_kl_loss"]: + total_loss += self.kl_coeff * mean_kl_loss + # Optional vf loss (or in a separate term due to separate + # optimizers/networks). + loss_wo_vf = total_loss + if not self.config["_separate_vf_optimizer"]: + total_loss += mean_vf_loss * self.config["vf_loss_coeff"] + + # Store stats in policy for stats_fn. + self._total_loss = total_loss + self._loss_wo_vf = loss_wo_vf + self._mean_policy_loss = mean_policy_loss + # Backward compatibility: Deprecate policy._mean_kl. + self._mean_kl_loss = self._mean_kl = mean_kl_loss + self._mean_vf_loss = mean_vf_loss + self._mean_entropy = mean_entropy + self._value_targets = value_targets + + # Return one total loss or two losses: vf vs rest (policy + kl). + if self.config["_separate_vf_optimizer"]: + return loss_wo_vf, mean_vf_loss + else: + return total_loss + + @override(base) + def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]: + values_batched = _make_time_major( + self, + train_batch.get(SampleBatch.SEQ_LENS), + self.model.value_function(), + drop_last=self.config["vtrace"] and self.config["vtrace_drop_last_ts"], + ) + + stats_dict = { + "cur_lr": tf.cast(self.cur_lr, tf.float64), + "total_loss": self._total_loss, + "policy_loss": self._mean_policy_loss, + "entropy": self._mean_entropy, + "var_gnorm": tf.linalg.global_norm(self.model.trainable_variables()), + "vf_loss": self._mean_vf_loss, + "vf_explained_var": explained_variance( + tf.reshape(self._value_targets, [-1]), + tf.reshape(values_batched, [-1]), + ), + "entropy_coeff": tf.cast(self.entropy_coeff, tf.float64), + } + + if self.config["vtrace"]: + is_stat_mean, is_stat_var = tf.nn.moments(self._is_ratio, [0, 1]) + stats_dict["mean_IS"] = is_stat_mean + stats_dict["var_IS"] = is_stat_var + + if self.config["use_kl_loss"]: + stats_dict["kl"] = self._mean_kl_loss + stats_dict["KL_Coeff"] = self.kl_coeff + + return stats_dict + + @override(base) + def postprocess_trajectory( + self, + sample_batch: SampleBatch, + other_agent_batches: Optional[SampleBatch] = None, + episode: Optional["Episode"] = None, + ): + if not self.config["vtrace"]: + sample_batch = compute_gae_for_sample_batch( + self, sample_batch, other_agent_batches, episode + ) + + return sample_batch + + @override(base) + def extra_action_out_fn(self) -> Dict[str, TensorType]: + extra_action_fetches = super().extra_action_out_fn() + if not self.config["vtrace"]: + extra_action_fetches[SampleBatch.VF_PREDS] = self.model.value_function() + return extra_action_fetches + + @override(base) + def get_batch_divisibility_req(self) -> int: + return self.config["rollout_fragment_length"] + + return APPOTFPolicy + + +APPOStaticGraphTFPolicy = get_appo_tf_policy(DynamicTFPolicyV2) +APPOEagerTFPolicy = get_appo_tf_policy(EagerTFPolicyV2) diff --git a/rllib/agents/ppo/appo_torch_policy.py b/rllib/agents/ppo/appo_torch_policy.py index 2b45a79f89bb..0ff3f20f5d18 100644 --- a/rllib/agents/ppo/appo_torch_policy.py +++ b/rllib/agents/ppo/appo_torch_policy.py @@ -8,391 +8,404 @@ import gym import numpy as np import logging -from typing import Type +from typing import Any, Dict, List, Optional, Type, Union -from ray.rllib.algorithms.dqn.simple_q_torch_policy import TargetNetworkMixin +import ray import ray.rllib.agents.impala.vtrace_torch as vtrace from ray.rllib.agents.impala.vtrace_torch_policy import ( make_time_major, - choose_optimizer, + VTraceOptimizer, ) -from ray.rllib.agents.ppo.appo_tf_policy import make_appo_model, postprocess_trajectory -from ray.rllib.evaluation.postprocessing import Postprocessing +from ray.rllib.agents.ppo.appo_tf_policy import make_appo_model +from ray.rllib.evaluation.episode import Episode +from ray.rllib.evaluation.postprocessing import ( + compute_gae_for_sample_batch, + Postprocessing, +) +from ray.rllib.models.action_dist import ActionDistribution from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.torch.torch_action_dist import ( TorchDistributionWrapper, TorchCategorical, ) -from ray.rllib.policy.policy import Policy -from ray.rllib.policy.policy_template import build_policy_class +from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.torch_mixins import ( EntropyCoeffSchedule, LearningRateSchedule, + KLCoeffMixin, ValueNetworkMixin, + TargetNetworkMixin, ) +from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2 +from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.numpy import convert_to_numpy from ray.rllib.utils.torch_utils import ( apply_grad_clipping, explained_variance, global_norm, sequence_mask, ) -from ray.rllib.utils.typing import TensorType, TrainerConfigDict +from ray.rllib.utils.typing import TensorType torch, nn = try_import_torch() logger = logging.getLogger(__name__) -def appo_surrogate_loss( - policy: Policy, - model: ModelV2, - dist_class: Type[TorchDistributionWrapper], - train_batch: SampleBatch, -) -> TensorType: - """Constructs the loss for APPO. - - With IS modifications and V-trace for Advantage Estimation. - - Args: - policy (Policy): The Policy to calculate the loss for. - model (ModelV2): The Model to calculate the loss for. - dist_class (Type[ActionDistribution]): The action distr. class. - train_batch (SampleBatch): The training data. - - Returns: - Union[TensorType, List[TensorType]]: A single loss tensor or a list - of loss tensors. - """ - target_model = policy.target_models[model] - - model_out, _ = model(train_batch) - action_dist = dist_class(model_out, model) - - if isinstance(policy.action_space, gym.spaces.Discrete): - is_multidiscrete = False - output_hidden_shape = [policy.action_space.n] - elif isinstance(policy.action_space, gym.spaces.multi_discrete.MultiDiscrete): - is_multidiscrete = True - output_hidden_shape = policy.action_space.nvec.astype(np.int32) - else: - is_multidiscrete = False - output_hidden_shape = 1 - - def _make_time_major(*args, **kwargs): - return make_time_major( - policy, train_batch.get(SampleBatch.SEQ_LENS), *args, **kwargs +# We need this builder function because we want to share the same +# custom logics between TF1 dynamic and TF2 eager policies. +class APPOTorchPolicy( + VTraceOptimizer, + LearningRateSchedule, + EntropyCoeffSchedule, + KLCoeffMixin, + ValueNetworkMixin, + TargetNetworkMixin, + TorchPolicyV2, +): + """PyTorch policy class used with APPOTrainer.""" + + def __init__(self, observation_space, action_space, config): + config = dict(ray.rllib.agents.ppo.appo.DEFAULT_CONFIG, **config) + + # Although this is a no-op, we call __init__ here to make it clear + # that base.__init__ will use the make_model() call. + VTraceOptimizer.__init__(self) + LearningRateSchedule.__init__(self, config["lr"], config["lr_schedule"]) + + TorchPolicyV2.__init__( + self, + observation_space, + action_space, + config, + max_seq_len=config["model"]["max_seq_len"], ) - actions = train_batch[SampleBatch.ACTIONS] - dones = train_batch[SampleBatch.DONES] - rewards = train_batch[SampleBatch.REWARDS] - behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS] + EntropyCoeffSchedule.__init__( + self, config["entropy_coeff"], config["entropy_coeff_schedule"] + ) + ValueNetworkMixin.__init__(self, config) + KLCoeffMixin.__init__(self, config) + + # TODO: Don't require users to call this manually. + self._initialize_loss_from_dummy_batch() + + # Initiate TargetNetwork ops after loss initialization. + TargetNetworkMixin.__init__(self) + + @override(TorchPolicyV2) + def init_view_requirements(self): + self.view_requirements = self._get_default_view_requirements() + + @override(TorchPolicyV2) + def make_model(self) -> ModelV2: + return make_appo_model(self) + + @override(TorchPolicyV2) + def loss( + self, + model: ModelV2, + dist_class: Type[ActionDistribution], + train_batch: SampleBatch, + ) -> Union[TensorType, List[TensorType]]: + """Constructs the loss for APPO. + + With IS modifications and V-trace for Advantage Estimation. + + Args: + model (ModelV2): The Model to calculate the loss for. + dist_class (Type[ActionDistribution]): The action distr. class. + train_batch (SampleBatch): The training data. + + Returns: + Union[TensorType, List[TensorType]]: A single loss tensor or a list + of loss tensors. + """ + target_model = self.target_models[model] + + model_out, _ = model(train_batch) + action_dist = dist_class(model_out, model) + + if isinstance(self.action_space, gym.spaces.Discrete): + is_multidiscrete = False + output_hidden_shape = [self.action_space.n] + elif isinstance(self.action_space, gym.spaces.multi_discrete.MultiDiscrete): + is_multidiscrete = True + output_hidden_shape = self.action_space.nvec.astype(np.int32) + else: + is_multidiscrete = False + output_hidden_shape = 1 - target_model_out, _ = target_model(train_batch) + def _make_time_major(*args, **kwargs): + return make_time_major( + self, train_batch.get(SampleBatch.SEQ_LENS), *args, **kwargs + ) - prev_action_dist = dist_class(behaviour_logits, model) - values = model.value_function() - values_time_major = _make_time_major(values) + actions = train_batch[SampleBatch.ACTIONS] + dones = train_batch[SampleBatch.DONES] + rewards = train_batch[SampleBatch.REWARDS] + behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS] - drop_last = policy.config["vtrace"] and policy.config["vtrace_drop_last_ts"] + target_model_out, _ = target_model(train_batch) - if policy.is_recurrent(): - max_seq_len = torch.max(train_batch[SampleBatch.SEQ_LENS]) - mask = sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len) - mask = torch.reshape(mask, [-1]) - mask = _make_time_major(mask, drop_last=drop_last) - num_valid = torch.sum(mask) + prev_action_dist = dist_class(behaviour_logits, model) + values = model.value_function() + values_time_major = _make_time_major(values) - def reduce_mean_valid(t): - return torch.sum(t[mask]) / num_valid + drop_last = self.config["vtrace"] and self.config["vtrace_drop_last_ts"] - else: - reduce_mean_valid = torch.mean + if self.is_recurrent(): + max_seq_len = torch.max(train_batch[SampleBatch.SEQ_LENS]) + mask = sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len) + mask = torch.reshape(mask, [-1]) + mask = _make_time_major(mask, drop_last=drop_last) + num_valid = torch.sum(mask) - if policy.config["vtrace"]: - logger.debug( - "Using V-Trace surrogate loss (vtrace=True; " f"drop_last={drop_last})" - ) + def reduce_mean_valid(t): + return torch.sum(t[mask]) / num_valid - old_policy_behaviour_logits = target_model_out.detach() - old_policy_action_dist = dist_class(old_policy_behaviour_logits, model) + else: + reduce_mean_valid = torch.mean - if isinstance(output_hidden_shape, (list, tuple, np.ndarray)): - unpacked_behaviour_logits = torch.split( - behaviour_logits, list(output_hidden_shape), dim=1 + if self.config["vtrace"]: + logger.debug( + "Using V-Trace surrogate loss (vtrace=True; " f"drop_last={drop_last})" ) - unpacked_old_policy_behaviour_logits = torch.split( - old_policy_behaviour_logits, list(output_hidden_shape), dim=1 + + old_policy_behaviour_logits = target_model_out.detach() + old_policy_action_dist = dist_class(old_policy_behaviour_logits, model) + + if isinstance(output_hidden_shape, (list, tuple, np.ndarray)): + unpacked_behaviour_logits = torch.split( + behaviour_logits, list(output_hidden_shape), dim=1 + ) + unpacked_old_policy_behaviour_logits = torch.split( + old_policy_behaviour_logits, list(output_hidden_shape), dim=1 + ) + else: + unpacked_behaviour_logits = torch.chunk( + behaviour_logits, output_hidden_shape, dim=1 + ) + unpacked_old_policy_behaviour_logits = torch.chunk( + old_policy_behaviour_logits, output_hidden_shape, dim=1 + ) + + # Prepare actions for loss. + loss_actions = ( + actions if is_multidiscrete else torch.unsqueeze(actions, dim=1) ) - else: - unpacked_behaviour_logits = torch.chunk( - behaviour_logits, output_hidden_shape, dim=1 + + # Prepare KL for loss. + action_kl = _make_time_major( + old_policy_action_dist.kl(action_dist), drop_last=drop_last ) - unpacked_old_policy_behaviour_logits = torch.chunk( - old_policy_behaviour_logits, output_hidden_shape, dim=1 + + # Compute vtrace on the CPU for better perf. + vtrace_returns = vtrace.multi_from_logits( + behaviour_policy_logits=_make_time_major( + unpacked_behaviour_logits, drop_last=drop_last + ), + target_policy_logits=_make_time_major( + unpacked_old_policy_behaviour_logits, drop_last=drop_last + ), + actions=torch.unbind( + _make_time_major(loss_actions, drop_last=drop_last), dim=2 + ), + discounts=(1.0 - _make_time_major(dones, drop_last=drop_last).float()) + * self.config["gamma"], + rewards=_make_time_major(rewards, drop_last=drop_last), + values=values_time_major[:-1] if drop_last else values_time_major, + bootstrap_value=values_time_major[-1], + dist_class=TorchCategorical if is_multidiscrete else dist_class, + model=model, + clip_rho_threshold=self.config["vtrace_clip_rho_threshold"], + clip_pg_rho_threshold=self.config["vtrace_clip_pg_rho_threshold"], ) - # Prepare actions for loss. - loss_actions = actions if is_multidiscrete else torch.unsqueeze(actions, dim=1) + actions_logp = _make_time_major( + action_dist.logp(actions), drop_last=drop_last + ) + prev_actions_logp = _make_time_major( + prev_action_dist.logp(actions), drop_last=drop_last + ) + old_policy_actions_logp = _make_time_major( + old_policy_action_dist.logp(actions), drop_last=drop_last + ) + is_ratio = torch.clamp( + torch.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0 + ) + logp_ratio = is_ratio * torch.exp(actions_logp - prev_actions_logp) + self._is_ratio = is_ratio + + advantages = vtrace_returns.pg_advantages.to(logp_ratio.device) + surrogate_loss = torch.min( + advantages * logp_ratio, + advantages + * torch.clamp( + logp_ratio, + 1 - self.config["clip_param"], + 1 + self.config["clip_param"], + ), + ) - # Prepare KL for loss. - action_kl = _make_time_major( - old_policy_action_dist.kl(action_dist), drop_last=drop_last - ) + mean_kl_loss = reduce_mean_valid(action_kl) + mean_policy_loss = -reduce_mean_valid(surrogate_loss) - # Compute vtrace on the CPU for better perf. - vtrace_returns = vtrace.multi_from_logits( - behaviour_policy_logits=_make_time_major( - unpacked_behaviour_logits, drop_last=drop_last - ), - target_policy_logits=_make_time_major( - unpacked_old_policy_behaviour_logits, drop_last=drop_last - ), - actions=torch.unbind( - _make_time_major(loss_actions, drop_last=drop_last), dim=2 - ), - discounts=(1.0 - _make_time_major(dones, drop_last=drop_last).float()) - * policy.config["gamma"], - rewards=_make_time_major(rewards, drop_last=drop_last), - values=values_time_major[:-1] if drop_last else values_time_major, - bootstrap_value=values_time_major[-1], - dist_class=TorchCategorical if is_multidiscrete else dist_class, - model=model, - clip_rho_threshold=policy.config["vtrace_clip_rho_threshold"], - clip_pg_rho_threshold=policy.config["vtrace_clip_pg_rho_threshold"], - ) + # The value function loss. + value_targets = vtrace_returns.vs.to(values_time_major.device) + if drop_last: + delta = values_time_major[:-1] - value_targets + else: + delta = values_time_major - value_targets + mean_vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0)) - actions_logp = _make_time_major(action_dist.logp(actions), drop_last=drop_last) - prev_actions_logp = _make_time_major( - prev_action_dist.logp(actions), drop_last=drop_last - ) - old_policy_actions_logp = _make_time_major( - old_policy_action_dist.logp(actions), drop_last=drop_last - ) - is_ratio = torch.clamp( - torch.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0 - ) - logp_ratio = is_ratio * torch.exp(actions_logp - prev_actions_logp) - policy._is_ratio = is_ratio - - advantages = vtrace_returns.pg_advantages.to(logp_ratio.device) - surrogate_loss = torch.min( - advantages * logp_ratio, - advantages - * torch.clamp( - logp_ratio, - 1 - policy.config["clip_param"], - 1 + policy.config["clip_param"], - ), - ) - - mean_kl_loss = reduce_mean_valid(action_kl) - mean_policy_loss = -reduce_mean_valid(surrogate_loss) + # The entropy loss. + mean_entropy = reduce_mean_valid( + _make_time_major(action_dist.entropy(), drop_last=drop_last) + ) - # The value function loss. - value_targets = vtrace_returns.vs.to(values_time_major.device) - if drop_last: - delta = values_time_major[:-1] - value_targets else: - delta = values_time_major - value_targets - mean_vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0)) - - # The entropy loss. - mean_entropy = reduce_mean_valid( - _make_time_major(action_dist.entropy(), drop_last=drop_last) - ) + logger.debug("Using PPO surrogate loss (vtrace=False)") + + # Prepare KL for Loss + action_kl = _make_time_major(prev_action_dist.kl(action_dist)) + + actions_logp = _make_time_major(action_dist.logp(actions)) + prev_actions_logp = _make_time_major(prev_action_dist.logp(actions)) + logp_ratio = torch.exp(actions_logp - prev_actions_logp) + + advantages = _make_time_major(train_batch[Postprocessing.ADVANTAGES]) + surrogate_loss = torch.min( + advantages * logp_ratio, + advantages + * torch.clamp( + logp_ratio, + 1 - self.config["clip_param"], + 1 + self.config["clip_param"], + ), + ) - else: - logger.debug("Using PPO surrogate loss (vtrace=False)") + mean_kl_loss = reduce_mean_valid(action_kl) + mean_policy_loss = -reduce_mean_valid(surrogate_loss) - # Prepare KL for Loss - action_kl = _make_time_major(prev_action_dist.kl(action_dist)) + # The value function loss. + value_targets = _make_time_major(train_batch[Postprocessing.VALUE_TARGETS]) + delta = values_time_major - value_targets + mean_vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0)) - actions_logp = _make_time_major(action_dist.logp(actions)) - prev_actions_logp = _make_time_major(prev_action_dist.logp(actions)) - logp_ratio = torch.exp(actions_logp - prev_actions_logp) + # The entropy loss. + mean_entropy = reduce_mean_valid(_make_time_major(action_dist.entropy())) - advantages = _make_time_major(train_batch[Postprocessing.ADVANTAGES]) - surrogate_loss = torch.min( - advantages * logp_ratio, - advantages - * torch.clamp( - logp_ratio, - 1 - policy.config["clip_param"], - 1 + policy.config["clip_param"], - ), + # The summed weighted loss + total_loss = ( + mean_policy_loss + + mean_vf_loss * self.config["vf_loss_coeff"] + - mean_entropy * self.entropy_coeff ) - mean_kl_loss = reduce_mean_valid(action_kl) - mean_policy_loss = -reduce_mean_valid(surrogate_loss) - - # The value function loss. - value_targets = _make_time_major(train_batch[Postprocessing.VALUE_TARGETS]) - delta = values_time_major - value_targets - mean_vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0)) - - # The entropy loss. - mean_entropy = reduce_mean_valid(_make_time_major(action_dist.entropy())) - - # The summed weighted loss - total_loss = ( - mean_policy_loss - + mean_vf_loss * policy.config["vf_loss_coeff"] - - mean_entropy * policy.entropy_coeff - ) - - # Optional additional KL Loss - if policy.config["use_kl_loss"]: - total_loss += policy.kl_coeff * mean_kl_loss - - # Store values for stats function in model (tower), such that for - # multi-GPU, we do not override them during the parallel loss phase. - model.tower_stats["total_loss"] = total_loss - model.tower_stats["mean_policy_loss"] = mean_policy_loss - model.tower_stats["mean_kl_loss"] = mean_kl_loss - model.tower_stats["mean_vf_loss"] = mean_vf_loss - model.tower_stats["mean_entropy"] = mean_entropy - model.tower_stats["value_targets"] = value_targets - model.tower_stats["vf_explained_var"] = explained_variance( - torch.reshape(value_targets, [-1]), - torch.reshape(values_time_major[:-1] if drop_last else values_time_major, [-1]), - ) - - return total_loss - - -def stats(policy: Policy, train_batch: SampleBatch): - """Stats function for APPO. Returns a dict with important loss stats. - - Args: - policy (Policy): The Policy to generate stats for. - train_batch (SampleBatch): The SampleBatch (already) used for training. - - Returns: - Dict[str, TensorType]: The stats dict. - """ - stats_dict = { - "cur_lr": policy.cur_lr, - "total_loss": torch.mean(torch.stack(policy.get_tower_stats("total_loss"))), - "policy_loss": torch.mean( - torch.stack(policy.get_tower_stats("mean_policy_loss")) - ), - "entropy": torch.mean(torch.stack(policy.get_tower_stats("mean_entropy"))), - "entropy_coeff": policy.entropy_coeff, - "var_gnorm": global_norm(policy.model.trainable_variables()), - "vf_loss": torch.mean(torch.stack(policy.get_tower_stats("mean_vf_loss"))), - "vf_explained_var": torch.mean( - torch.stack(policy.get_tower_stats("vf_explained_var")) - ), - } - - if policy.config["vtrace"]: - is_stat_mean = torch.mean(policy._is_ratio, [0, 1]) - is_stat_var = torch.var(policy._is_ratio, [0, 1]) - stats_dict["mean_IS"] = is_stat_mean - stats_dict["var_IS"] = is_stat_var - - if policy.config["use_kl_loss"]: - stats_dict["kl"] = torch.mean( - torch.stack(policy.get_tower_stats("mean_kl_loss")) + # Optional additional KL Loss + if self.config["use_kl_loss"]: + total_loss += self.kl_coeff * mean_kl_loss + + # Store values for stats function in model (tower), such that for + # multi-GPU, we do not override them during the parallel loss phase. + model.tower_stats["total_loss"] = total_loss + model.tower_stats["mean_policy_loss"] = mean_policy_loss + model.tower_stats["mean_kl_loss"] = mean_kl_loss + model.tower_stats["mean_vf_loss"] = mean_vf_loss + model.tower_stats["mean_entropy"] = mean_entropy + model.tower_stats["value_targets"] = value_targets + model.tower_stats["vf_explained_var"] = explained_variance( + torch.reshape(value_targets, [-1]), + torch.reshape( + values_time_major[:-1] if drop_last else values_time_major, [-1] + ), ) - stats_dict["KL_Coeff"] = policy.kl_coeff - - return stats_dict - -def add_values(policy, input_dict, state_batches, model, action_dist): - out = {} - if not policy.config["vtrace"]: - out[SampleBatch.VF_PREDS] = model.value_function() - return out + return total_loss + @override(TorchPolicyV2) + def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]: + """Stats function for APPO. Returns a dict with important loss stats. -class KLCoeffMixin: - """Assigns the `update_kl()` method to the PPOPolicy. + Args: + policy (Policy): The Policy to generate stats for. + train_batch (SampleBatch): The SampleBatch (already) used for training. - This is used in PPO's execution plan (see ppo.py) for updating the KL - coefficient after each learning step based on `config.kl_target` and - the measured KL value (from the train_batch). - """ - - def __init__(self, config): - # The current KL value (as python float). - self.kl_coeff = config["kl_coeff"] - # Constant target value. - self.kl_target = config["kl_target"] - - def update_kl(self, sampled_kl): - # Update the current KL value based on the recently measured value. - if sampled_kl > 2.0 * self.kl_target: - self.kl_coeff *= 1.5 - elif sampled_kl < 0.5 * self.kl_target: - self.kl_coeff *= 0.5 - # Return the current KL value. - return self.kl_coeff + Returns: + Dict[str, TensorType]: The stats dict. + """ + stats_dict = { + "cur_lr": self.cur_lr, + "total_loss": torch.mean(torch.stack(self.get_tower_stats("total_loss"))), + "policy_loss": torch.mean( + torch.stack(self.get_tower_stats("mean_policy_loss")) + ), + "entropy": torch.mean(torch.stack(self.get_tower_stats("mean_entropy"))), + "entropy_coeff": self.entropy_coeff, + "var_gnorm": global_norm(self.model.trainable_variables()), + "vf_loss": torch.mean(torch.stack(self.get_tower_stats("mean_vf_loss"))), + "vf_explained_var": torch.mean( + torch.stack(self.get_tower_stats("vf_explained_var")) + ), + } + if self.config["vtrace"]: + is_stat_mean = torch.mean(self._is_ratio, [0, 1]) + is_stat_var = torch.var(self._is_ratio, [0, 1]) + stats_dict["mean_IS"] = is_stat_mean + stats_dict["var_IS"] = is_stat_var -def setup_early_mixins( - policy: Policy, - obs_space: gym.spaces.Space, - action_space: gym.spaces.Space, - config: TrainerConfigDict, -): - """Call all mixin classes' constructors before APPOPolicy initialization. - - Args: - policy (Policy): The Policy object. - obs_space (gym.spaces.Space): The Policy's observation space. - action_space (gym.spaces.Space): The Policy's action space. - config (TrainerConfigDict): The Policy's config. - """ - LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"]) - EntropyCoeffSchedule.__init__( - policy, config["entropy_coeff"], config["entropy_coeff_schedule"] - ) - - -def setup_late_mixins( - policy: Policy, - obs_space: gym.spaces.Space, - action_space: gym.spaces.Space, - config: TrainerConfigDict, -): - """Call all mixin classes' constructors after APPOPolicy initialization. - - Args: - policy (Policy): The Policy object. - obs_space (gym.spaces.Space): The Policy's observation space. - action_space (gym.spaces.Space): The Policy's action space. - config (TrainerConfigDict): The Policy's config. - """ - KLCoeffMixin.__init__(policy, config) - ValueNetworkMixin.__init__(policy, config) - TargetNetworkMixin.__init__(policy) - - -# Build a child class of `TorchPolicy`, given the custom functions defined -# above. -AsyncPPOTorchPolicy = build_policy_class( - name="AsyncPPOTorchPolicy", - framework="torch", - loss_fn=appo_surrogate_loss, - stats_fn=stats, - postprocess_fn=postprocess_trajectory, - extra_action_out_fn=add_values, - extra_grad_process_fn=apply_grad_clipping, - optimizer_fn=choose_optimizer, - before_init=setup_early_mixins, - before_loss_init=setup_late_mixins, - make_model=make_appo_model, - mixins=[ - LearningRateSchedule, - KLCoeffMixin, - TargetNetworkMixin, - ValueNetworkMixin, - EntropyCoeffSchedule, - ], - get_batch_divisibility_req=lambda p: p.config["rollout_fragment_length"], -) + if self.config["use_kl_loss"]: + stats_dict["kl"] = torch.mean( + torch.stack(self.get_tower_stats("mean_kl_loss")) + ) + stats_dict["KL_Coeff"] = self.kl_coeff + + return convert_to_numpy(stats_dict) + + @override(TorchPolicyV2) + def extra_action_out( + self, + input_dict: Dict[str, TensorType], + state_batches: List[TensorType], + model: TorchModelV2, + action_dist: TorchDistributionWrapper, + ) -> Dict[str, TensorType]: + out = {} + if not self.config["vtrace"]: + out[SampleBatch.VF_PREDS] = model.value_function() + return out + + @override(TorchPolicyV2) + def postprocess_trajectory( + self, + sample_batch: SampleBatch, + other_agent_batches: Optional[Dict[Any, SampleBatch]] = None, + episode: Optional["Episode"] = None, + ): + # Call super's postprocess_trajectory first. + sample_batch = super().postprocess_trajectory( + sample_batch, other_agent_batches, episode + ) + if not self.config["vtrace"]: + # Do all post-processing always with no_grad(). + # Not using this here will introduce a memory leak + # in torch (issue #6962). + with torch.no_grad(): + sample_batch = compute_gae_for_sample_batch( + self, sample_batch, other_agent_batches, episode + ) + return sample_batch + + @override(TorchPolicyV2) + def extra_grad_process( + self, optimizer: "torch.optim.Optimizer", loss: TensorType + ) -> Dict[str, TensorType]: + return apply_grad_clipping(self, optimizer, loss) + + @override(TorchPolicyV2) + def get_batch_divisibility_req(self) -> int: + return self.config["rollout_fragment_length"] diff --git a/rllib/agents/ppo/ddppo.py b/rllib/agents/ppo/ddppo.py index 80cb5bbe4b6c..dcf0eca2ce85 100644 --- a/rllib/agents/ppo/ddppo.py +++ b/rllib/agents/ppo/ddppo.py @@ -27,7 +27,7 @@ from ray.rllib.execution.common import ( STEPS_TRAINED_THIS_ITER_COUNTER, ) -from ray.rllib.execution.parallel_requests import asynchronous_parallel_requests +from ray.rllib.execution.parallel_requests import AsyncRequestsManager from ray.rllib.utils.annotations import override from ray.rllib.utils.deprecation import Deprecated from ray.rllib.utils.metrics import ( @@ -275,20 +275,21 @@ def setup(self, config: PartialTrainerConfigDict): ] ) logger.info("Torch process group init completed") + self._ddppo_worker_manager = AsyncRequestsManager( + self.workers.remote_workers(), + max_remote_requests_in_flight_per_worker=1, + ray_wait_timeout_s=0.03, + ) @override(PPOTrainer) def training_iteration(self) -> ResultDict: # Shortcut. first_worker = self.workers.remote_workers()[0] - # Run sampling and update steps on each worker in asynchronous fashion. - sample_and_update_results = asynchronous_parallel_requests( - remote_requests_in_flight=self.remote_requests_in_flight, - actors=self.workers.remote_workers(), - ray_wait_timeout_s=0.0, - max_remote_requests_in_flight_per_actor=1, # 2 - remote_fn=self._sample_and_train_torch_distributed, + self._ddppo_worker_manager.call_on_all_available( + self._sample_and_train_torch_distributed ) + sample_and_update_results = self._ddppo_worker_manager.get_ready() # For all results collected: # - Update our counters and timers. diff --git a/rllib/agents/ppo/ppo.py b/rllib/agents/ppo/ppo.py index cd45ea2e17e4..c9bd73bf3bdb 100644 --- a/rllib/agents/ppo/ppo.py +++ b/rllib/agents/ppo/ppo.py @@ -13,7 +13,6 @@ from typing import List, Optional, Type, Union from ray.util.debug import log_once -from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy from ray.rllib.agents.trainer import Trainer from ray.rllib.agents.trainer_config import TrainerConfig from ray.rllib.execution.rollout_ops import ( @@ -369,8 +368,14 @@ def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]: from ray.rllib.agents.ppo.ppo_torch_policy import PPOTorchPolicy return PPOTorchPolicy + elif config["framework"] == "tf": + from ray.rllib.agents.ppo.ppo_tf_policy import PPOStaticGraphTFPolicy + + return PPOStaticGraphTFPolicy else: - return PPOTFPolicy + from ray.rllib.agents.ppo.ppo_tf_policy import PPOEagerTFPolicy + + return PPOEagerTFPolicy @ExperimentalAPI def training_iteration(self) -> ResultDict: diff --git a/rllib/agents/ppo/ppo_tf_policy.py b/rllib/agents/ppo/ppo_tf_policy.py index 462fb579489c..dee7cd211863 100644 --- a/rllib/agents/ppo/ppo_tf_policy.py +++ b/rllib/agents/ppo/ppo_tf_policy.py @@ -2,36 +2,30 @@ TensorFlow policy class used for PPO. """ -import gym import logging -from typing import Dict, List, Optional, Type, Union +from typing import Dict, List, Type, Union import ray -from ray.rllib.evaluation.episode import Episode from ray.rllib.evaluation.postprocessing import ( - compute_gae_for_sample_batch, Postprocessing, + compute_gae_for_sample_batch, ) from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.tf.tf_action_dist import TFActionDistribution -from ray.rllib.policy.policy import Policy +from ray.rllib.policy.dynamic_tf_policy_v2 import DynamicTFPolicyV2 +from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2 from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.tf_mixins import ( EntropyCoeffSchedule, - KLCoeffMixin, LearningRateSchedule, + KLCoeffMixin, ValueNetworkMixin, + compute_gradients, ) -from ray.rllib.policy.tf_policy_template import build_tf_policy -from ray.rllib.utils.deprecation import ( - Deprecated, - DEPRECATED_VALUE, - deprecation_warning, -) +from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.tf_utils import explained_variance from ray.rllib.utils.typing import ( - AgentID, LocalOptimizer, ModelGradients, TensorType, @@ -43,230 +37,11 @@ logger = logging.getLogger(__name__) -def ppo_surrogate_loss( - policy: Policy, - model: Union[ModelV2, "tf.keras.Model"], - dist_class: Type[TFActionDistribution], - train_batch: SampleBatch, -) -> Union[TensorType, List[TensorType]]: - """Constructs the loss for Proximal Policy Objective. - - Args: - policy (Policy): The Policy to calculate the loss for. - model (Union[ModelV2, tf.keras.Model]): The Model to calculate - the loss for. - dist_class (Type[ActionDistribution]: The action distr. class. - train_batch (SampleBatch): The training data. - - Returns: - Union[TensorType, List[TensorType]]: A single loss tensor or a list - of loss tensors. - """ - if isinstance(model, tf.keras.Model): - logits, state, extra_outs = model(train_batch) - value_fn_out = extra_outs[SampleBatch.VF_PREDS] - else: - logits, state = model(train_batch) - value_fn_out = model.value_function() - - curr_action_dist = dist_class(logits, model) - - # RNN case: Mask away 0-padded chunks at end of time axis. - if state: - # Derive max_seq_len from the data itself, not from the seq_lens - # tensor. This is in case e.g. seq_lens=[2, 3], but the data is still - # 0-padded up to T=5 (as it's the case for attention nets). - B = tf.shape(train_batch[SampleBatch.SEQ_LENS])[0] - max_seq_len = tf.shape(logits)[0] // B - - mask = tf.sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len) - mask = tf.reshape(mask, [-1]) - - def reduce_mean_valid(t): - return tf.reduce_mean(tf.boolean_mask(t, mask)) - - # non-RNN case: No masking. - else: - mask = None - reduce_mean_valid = tf.reduce_mean - - prev_action_dist = dist_class(train_batch[SampleBatch.ACTION_DIST_INPUTS], model) - - logp_ratio = tf.exp( - curr_action_dist.logp(train_batch[SampleBatch.ACTIONS]) - - train_batch[SampleBatch.ACTION_LOGP] - ) - - # Only calculate kl loss if necessary (kl-coeff > 0.0). - if policy.config["kl_coeff"] > 0.0: - action_kl = prev_action_dist.kl(curr_action_dist) - mean_kl_loss = reduce_mean_valid(action_kl) - else: - mean_kl_loss = tf.constant(0.0) - - curr_entropy = curr_action_dist.entropy() - mean_entropy = reduce_mean_valid(curr_entropy) - - surrogate_loss = tf.minimum( - train_batch[Postprocessing.ADVANTAGES] * logp_ratio, - train_batch[Postprocessing.ADVANTAGES] - * tf.clip_by_value( - logp_ratio, 1 - policy.config["clip_param"], 1 + policy.config["clip_param"] - ), - ) - mean_policy_loss = reduce_mean_valid(-surrogate_loss) - - # Compute a value function loss. - if policy.config["use_critic"]: - vf_loss = tf.math.square( - value_fn_out - train_batch[Postprocessing.VALUE_TARGETS] - ) - vf_loss_clipped = tf.clip_by_value( - vf_loss, - 0, - policy.config["vf_clip_param"], - ) - mean_vf_loss = reduce_mean_valid(vf_loss_clipped) - # Ignore the value function. - else: - vf_loss_clipped = mean_vf_loss = tf.constant(0.0) - - total_loss = reduce_mean_valid( - -surrogate_loss - + policy.config["vf_loss_coeff"] * vf_loss_clipped - - policy.entropy_coeff * curr_entropy - ) - # Add mean_kl_loss (already processed through `reduce_mean_valid`), - # if necessary. - if policy.config["kl_coeff"] > 0.0: - total_loss += policy.kl_coeff * mean_kl_loss - - # Store stats in policy for stats_fn. - policy._total_loss = total_loss - policy._mean_policy_loss = mean_policy_loss - policy._mean_vf_loss = mean_vf_loss - policy._mean_entropy = mean_entropy - # Backward compatibility: Deprecate policy._mean_kl. - policy._mean_kl_loss = policy._mean_kl = mean_kl_loss - policy._value_fn_out = value_fn_out - - return total_loss - - -def kl_and_loss_stats( - policy: Policy, train_batch: SampleBatch -) -> Dict[str, TensorType]: - """Stats function for PPO. Returns a dict with important KL and loss stats. - - Args: - policy (Policy): The Policy to generate stats for. - train_batch (SampleBatch): The SampleBatch (already) used for training. - - Returns: - Dict[str, TensorType]: The stats dict. - """ - return { - "cur_kl_coeff": tf.cast(policy.kl_coeff, tf.float64), - "cur_lr": tf.cast(policy.cur_lr, tf.float64), - "total_loss": policy._total_loss, - "policy_loss": policy._mean_policy_loss, - "vf_loss": policy._mean_vf_loss, - "vf_explained_var": explained_variance( - train_batch[Postprocessing.VALUE_TARGETS], policy._value_fn_out - ), - "kl": policy._mean_kl_loss, - "entropy": policy._mean_entropy, - "entropy_coeff": tf.cast(policy.entropy_coeff, tf.float64), - } - - -# TODO: (sven) Deprecate once we only allow native keras models. -def vf_preds_fetches(policy: Policy) -> Dict[str, TensorType]: - """Defines extra fetches per action computation. - - Args: - policy (Policy): The Policy to perform the extra action fetch on. - - Returns: - Dict[str, TensorType]: Dict with extra tf fetches to perform per - action computation. - """ - # Keras models return values for each call in third return argument - # (dict). - if isinstance(policy.model, tf.keras.Model): - return {} - # Return value function outputs. VF estimates will hence be added to the - # SampleBatches produced by the sampler(s) to generate the train batches - # going into the loss function. - return { - SampleBatch.VF_PREDS: policy.model.value_function(), - } - - -def compute_and_clip_gradients( - policy: Policy, optimizer: LocalOptimizer, loss: TensorType -) -> ModelGradients: - """Gradients computing function (from loss tensor, using local optimizer). - - Args: - policy (Policy): The Policy object that generated the loss tensor and - that holds the given local optimizer. - optimizer (LocalOptimizer): The tf (local) optimizer object to - calculate the gradients with. - loss (TensorType): The loss tensor for which gradients should be - calculated. - - Returns: - ModelGradients: List of the possibly clipped gradients- and variable - tuples. - """ - # Compute the gradients. - variables = policy.model.trainable_variables - if isinstance(policy.model, ModelV2): - variables = variables() - grads_and_vars = optimizer.compute_gradients(loss, variables) - - # Clip by global norm, if necessary. - if policy.config["grad_clip"] is not None: - # Defuse inf gradients (due to super large losses). - grads = [g for (g, v) in grads_and_vars] - grads, _ = tf.clip_by_global_norm(grads, policy.config["grad_clip"]) - # If the global_norm is inf -> All grads will be NaN. Stabilize this - # here by setting them to 0.0. This will simply ignore destructive loss - # calculations. - policy.grads = [tf.where(tf.math.is_nan(g), tf.zeros_like(g), g) for g in grads] - clipped_grads_and_vars = list(zip(policy.grads, variables)) - return clipped_grads_and_vars - else: - return grads_and_vars - - -def validate_config( - policy: Policy, - obs_space: gym.spaces.Space, - action_space: gym.spaces.Space, - config: TrainerConfigDict, -) -> None: +def validate_config(config: TrainerConfigDict) -> None: """Executed before Policy is "initialized" (at beginning of constructor). - Args: - policy (Policy): The Policy object. - obs_space (gym.spaces.Space): The Policy's observation space. - action_space (gym.spaces.Space): The Policy's action space. config (TrainerConfigDict): The Policy's config. """ - # Setting `vf_share_layers` in the top-level config is deprecated. - # It's confusing as some users might (correctly!) set it in their - # model config and then won't notice that it's silently overwritten - # here. - if config.get("vf_share_layers", DEPRECATED_VALUE) != DEPRECATED_VALUE: - deprecation_warning( - old="config[vf_share_layers]", - new="config[model][vf_share_layers]", - error=True, - ) - config["model"]["vf_share_layers"] = config["vf_share_layers"] - # If vf_share_layers is True, inform about the need to tune vf_loss_coeff. if config.get("model", {}).get("vf_share_layers") is True: logger.info( @@ -275,61 +50,195 @@ def validate_config( ) -def setup_mixins( - policy: Policy, - obs_space: gym.spaces.Space, - action_space: gym.spaces.Space, - config: TrainerConfigDict, -) -> None: - """Call mixin classes' constructors before Policy's loss initialization. +# We need this builder function because we want to share the same +# custom logics between TF1 dynamic and TF2 eager policies. +def get_ppo_tf_policy(base: type) -> type: + """Construct a PPOTFPolicy inheriting either dynamic or eager base policies. Args: - policy (Policy): The Policy object. - obs_space (gym.spaces.Space): The Policy's observation space. - action_space (gym.spaces.Space): The Policy's action space. - config (TrainerConfigDict): The Policy's config. + base: Base class for this policy. DynamicTFPolicyV2 or EagerTFPolicyV2. + + Returns: + A TF Policy to be used with PPOTrainer. """ - ValueNetworkMixin.__init__(policy, config) - KLCoeffMixin.__init__(policy, config) - EntropyCoeffSchedule.__init__( - policy, config["entropy_coeff"], config["entropy_coeff_schedule"] - ) - LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"]) - - -@Deprecated( - old="rllib.agents.ppo.ppo_tf_policy.postprocess_ppo_gae", - new="rllib.evaluation.postprocessing.compute_gae_for_sample_batch", - error=False, -) -def postprocess_ppo_gae( - policy: Policy, - sample_batch: SampleBatch, - other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None, - episode: Optional[Episode] = None, -) -> SampleBatch: - - return compute_gae_for_sample_batch( - policy, sample_batch, other_agent_batches, episode - ) - - -# Build a child class of `DynamicTFPolicy`, given the custom functions defined -# above. -PPOTFPolicy = build_tf_policy( - name="PPOTFPolicy", - loss_fn=ppo_surrogate_loss, - get_default_config=lambda: ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG, - postprocess_fn=compute_gae_for_sample_batch, - stats_fn=kl_and_loss_stats, - compute_gradients_fn=compute_and_clip_gradients, - extra_action_out_fn=vf_preds_fetches, - before_init=validate_config, - before_loss_init=setup_mixins, - mixins=[ - LearningRateSchedule, + + class PPOTFPolicy( EntropyCoeffSchedule, + LearningRateSchedule, KLCoeffMixin, ValueNetworkMixin, - ], -) + base, + ): + def __init__( + self, + obs_space, + action_space, + config, + existing_model=None, + existing_inputs=None, + ): + # First thing first, enable eager execution if necessary. + base.enable_eager_execution_if_necessary() + + config = dict(ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG, **config) + validate_config(config) + + # Initialize base class. + base.__init__( + self, + obs_space, + action_space, + config, + existing_inputs=existing_inputs, + existing_model=existing_model, + ) + + # Initialize MixIns. + ValueNetworkMixin.__init__(self, config) + KLCoeffMixin.__init__(self, config) + EntropyCoeffSchedule.__init__( + self, config["entropy_coeff"], config["entropy_coeff_schedule"] + ) + LearningRateSchedule.__init__(self, config["lr"], config["lr_schedule"]) + + # Note: this is a bit ugly, but loss and optimizer initialization must + # happen after all the MixIns are initialized. + self.maybe_initialize_optimizer_and_loss() + + @override(base) + def loss( + self, + model: Union[ModelV2, "tf.keras.Model"], + dist_class: Type[TFActionDistribution], + train_batch: SampleBatch, + ) -> Union[TensorType, List[TensorType]]: + if isinstance(model, tf.keras.Model): + logits, state, extra_outs = model(train_batch) + value_fn_out = extra_outs[SampleBatch.VF_PREDS] + else: + logits, state = model(train_batch) + value_fn_out = model.value_function() + + curr_action_dist = dist_class(logits, model) + + # RNN case: Mask away 0-padded chunks at end of time axis. + if state: + # Derive max_seq_len from the data itself, not from the seq_lens + # tensor. This is in case e.g. seq_lens=[2, 3], but the data is still + # 0-padded up to T=5 (as it's the case for attention nets). + B = tf.shape(train_batch[SampleBatch.SEQ_LENS])[0] + max_seq_len = tf.shape(logits)[0] // B + + mask = tf.sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len) + mask = tf.reshape(mask, [-1]) + + def reduce_mean_valid(t): + return tf.reduce_mean(tf.boolean_mask(t, mask)) + + # non-RNN case: No masking. + else: + mask = None + reduce_mean_valid = tf.reduce_mean + + prev_action_dist = dist_class( + train_batch[SampleBatch.ACTION_DIST_INPUTS], model + ) + + logp_ratio = tf.exp( + curr_action_dist.logp(train_batch[SampleBatch.ACTIONS]) + - train_batch[SampleBatch.ACTION_LOGP] + ) + + # Only calculate kl loss if necessary (kl-coeff > 0.0). + if self.config["kl_coeff"] > 0.0: + action_kl = prev_action_dist.kl(curr_action_dist) + mean_kl_loss = reduce_mean_valid(action_kl) + else: + mean_kl_loss = tf.constant(0.0) + + curr_entropy = curr_action_dist.entropy() + mean_entropy = reduce_mean_valid(curr_entropy) + + surrogate_loss = tf.minimum( + train_batch[Postprocessing.ADVANTAGES] * logp_ratio, + train_batch[Postprocessing.ADVANTAGES] + * tf.clip_by_value( + logp_ratio, + 1 - self.config["clip_param"], + 1 + self.config["clip_param"], + ), + ) + mean_policy_loss = reduce_mean_valid(-surrogate_loss) + + # Compute a value function loss. + if self.config["use_critic"]: + vf_loss = tf.math.square( + value_fn_out - train_batch[Postprocessing.VALUE_TARGETS] + ) + vf_loss_clipped = tf.clip_by_value( + vf_loss, + 0, + self.config["vf_clip_param"], + ) + mean_vf_loss = reduce_mean_valid(vf_loss_clipped) + # Ignore the value function. + else: + vf_loss_clipped = mean_vf_loss = tf.constant(0.0) + + total_loss = reduce_mean_valid( + -surrogate_loss + + self.config["vf_loss_coeff"] * vf_loss_clipped + - self.entropy_coeff * curr_entropy + ) + # Add mean_kl_loss (already processed through `reduce_mean_valid`), + # if necessary. + if self.config["kl_coeff"] > 0.0: + total_loss += self.kl_coeff * mean_kl_loss + + # Store stats in policy for stats_fn. + self._total_loss = total_loss + self._mean_policy_loss = mean_policy_loss + self._mean_vf_loss = mean_vf_loss + self._mean_entropy = mean_entropy + # Backward compatibility: Deprecate self._mean_kl. + self._mean_kl_loss = self._mean_kl = mean_kl_loss + self._value_fn_out = value_fn_out + + return total_loss + + @override(base) + def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]: + return { + "cur_kl_coeff": tf.cast(self.kl_coeff, tf.float64), + "cur_lr": tf.cast(self.cur_lr, tf.float64), + "total_loss": self._total_loss, + "policy_loss": self._mean_policy_loss, + "vf_loss": self._mean_vf_loss, + "vf_explained_var": explained_variance( + train_batch[Postprocessing.VALUE_TARGETS], self._value_fn_out + ), + "kl": self._mean_kl_loss, + "entropy": self._mean_entropy, + "entropy_coeff": tf.cast(self.entropy_coeff, tf.float64), + } + + @override(base) + def postprocess_trajectory( + self, sample_batch, other_agent_batches=None, episode=None + ): + sample_batch = super().postprocess_trajectory(sample_batch) + return compute_gae_for_sample_batch( + self, sample_batch, other_agent_batches, episode + ) + + @override(base) + def compute_gradients_fn( + self, optimizer: LocalOptimizer, loss: TensorType + ) -> ModelGradients: + return compute_gradients(self, optimizer, loss) + + return PPOTFPolicy + + +PPOStaticGraphTFPolicy = get_ppo_tf_policy(DynamicTFPolicyV2) +PPOEagerTFPolicy = get_ppo_tf_policy(EagerTFPolicyV2) diff --git a/rllib/agents/ppo/ppo_torch_policy.py b/rllib/agents/ppo/ppo_torch_policy.py index 0e3a0d54cf1b..cd6d0667d702 100644 --- a/rllib/agents/ppo/ppo_torch_policy.py +++ b/rllib/agents/ppo/ppo_torch_policy.py @@ -4,8 +4,8 @@ import ray from ray.rllib.agents.ppo.ppo_tf_policy import validate_config from ray.rllib.evaluation.postprocessing import ( - compute_gae_for_sample_batch, Postprocessing, + compute_gae_for_sample_batch, ) from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.action_dist import ActionDistribution @@ -16,7 +16,7 @@ LearningRateSchedule, ValueNetworkMixin, ) -from ray.rllib.policy.torch_policy import TorchPolicy +from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2 from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.numpy import convert_to_numpy @@ -37,15 +37,15 @@ class PPOTorchPolicy( LearningRateSchedule, EntropyCoeffSchedule, KLCoeffMixin, - TorchPolicy, + TorchPolicyV2, ): """PyTorch policy class used with PPOTrainer.""" def __init__(self, observation_space, action_space, config): config = dict(ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG, **config) - validate_config(self, observation_space, action_space, config) + validate_config(config) - TorchPolicy.__init__( + TorchPolicyV2.__init__( self, observation_space, action_space, @@ -54,42 +54,23 @@ def __init__(self, observation_space, action_space, config): ) ValueNetworkMixin.__init__(self, config) + LearningRateSchedule.__init__(self, config["lr"], config["lr_schedule"]) EntropyCoeffSchedule.__init__( self, config["entropy_coeff"], config["entropy_coeff_schedule"] ) - LearningRateSchedule.__init__(self, config["lr"], config["lr_schedule"]) - - # The current KL value (as python float). - self.kl_coeff = self.config["kl_coeff"] - # Constant target value. - self.kl_target = self.config["kl_target"] + KLCoeffMixin.__init__(self, config) # TODO: Don't require users to call this manually. self._initialize_loss_from_dummy_batch() - @override(TorchPolicy) - def postprocess_trajectory( - self, sample_batch, other_agent_batches=None, episode=None - ): - # Do all post-processing always with no_grad(). - # Not using this here will introduce a memory leak - # in torch (issue #6962). - # TODO: no_grad still necessary? - with torch.no_grad(): - return compute_gae_for_sample_batch( - self, sample_batch, other_agent_batches, episode - ) - - # TODO: Add method to Policy base class (as the new way of defining loss - # functions (instead of passing 'loss` to the super's constructor)). - @override(TorchPolicy) + @override(TorchPolicyV2) def loss( self, model: ModelV2, dist_class: Type[ActionDistribution], train_batch: SampleBatch, ) -> Union[TensorType, List[TensorType]]: - """Constructs the loss for Proximal Policy Objective. + """Compute loss for Proximal Policy Objective. Args: model: The Model to calculate the loss for. @@ -190,14 +171,12 @@ def reduce_mean_valid(t): # TODO: Make this an event-style subscription (e.g.: # "after_gradients_computed"). - @override(TorchPolicy) + @override(TorchPolicyV2) def extra_grad_process(self, local_optimizer, loss): return apply_grad_clipping(self, local_optimizer, loss) - # TODO: Make this an event-style subscription (e.g.: - # "after_losses_computed"). - @override(TorchPolicy) - def extra_grad_info(self, train_batch: SampleBatch) -> Dict[str, TensorType]: + @override(TorchPolicyV2) + def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]: return convert_to_numpy( { "cur_kl_coeff": self.kl_coeff, @@ -221,3 +200,16 @@ def extra_grad_info(self, train_batch: SampleBatch) -> Dict[str, TensorType]: "entropy_coeff": self.entropy_coeff, } ) + + @override(TorchPolicyV2) + def postprocess_trajectory( + self, sample_batch, other_agent_batches=None, episode=None + ): + # Do all post-processing always with no_grad(). + # Not using this here will introduce a memory leak + # in torch (issue #6962). + # TODO: no_grad still necessary? + with torch.no_grad(): + return compute_gae_for_sample_batch( + self, sample_batch, other_agent_batches, episode + ) diff --git a/rllib/agents/ppo/tests/test_appo.py b/rllib/agents/ppo/tests/test_appo.py index 36cbec0bee44..2f55d8ac4691 100644 --- a/rllib/agents/ppo/tests/test_appo.py +++ b/rllib/agents/ppo/tests/test_appo.py @@ -89,6 +89,7 @@ def test_appo_two_tf_optimizers(self): trainer.stop() def test_appo_entropy_coeff_schedule(self): + # Initial lr, doesn't really matter because of the schedule below. config = ( ppo.appo.APPOConfig() .rollouts( @@ -99,12 +100,12 @@ def test_appo_entropy_coeff_schedule(self): .resources(num_gpus=0) .training( train_batch_size=20, - # Initial entropy_coeff, doesn't really matter because of the schedule - # below. - entropy_coeff=0.1, + entropy_coeff=0.01, entropy_coeff_schedule=[ [0, 0.1], - [200, 0.001], + [100, 0.01], + [300, 0.001], + [500, 0.0001], ], ) ) @@ -131,16 +132,15 @@ def _step_n_times(trainer, n: int): for _ in framework_iterator(config): trainer = config.build(env="CartPole-v0") - coeff = _step_n_times(trainer, 5) # 100 timesteps - # Should be somewhere between starting coeff 0.1 and end coeff 0.001. - self.assertLessEqual(coeff, 0.075) - self.assertGreaterEqual(coeff, 0.03) - - coeff = _step_n_times(trainer, 5) # 200 timesteps - # Should have annealed to the final coeff of 0.001. - self.assertLessEqual(coeff, 0.03) + coeff = _step_n_times(trainer, 10) # 200 timesteps + # Should be close to the starting coeff of 0.01. + self.assertLessEqual(coeff, 0.01) self.assertGreaterEqual(coeff, 0.001) + coeff = _step_n_times(trainer, 20) # 400 timesteps + # Should have annealed to the final coeff of 0.0001. + self.assertLessEqual(coeff, 0.001) + trainer.stop() diff --git a/rllib/agents/ppo/tests/test_ppo.py b/rllib/agents/ppo/tests/test_ppo.py index de7fdd67893e..ad070ca40ccf 100644 --- a/rllib/agents/ppo/tests/test_ppo.py +++ b/rllib/agents/ppo/tests/test_ppo.py @@ -4,9 +4,7 @@ import ray from ray.rllib.agents.callbacks import DefaultCallbacks import ray.rllib.agents.ppo as ppo -from ray.rllib.agents.ppo.ppo_tf_policy import ( - ppo_surrogate_loss as ppo_surrogate_loss_tf, -) +from ray.rllib.agents.ppo.ppo_tf_policy import PPOEagerTFPolicy from ray.rllib.agents.ppo.ppo_torch_policy import PPOTorchPolicy from ray.rllib.evaluation.postprocessing import ( compute_gae_for_sample_batch, @@ -315,7 +313,7 @@ def test_ppo_loss_function(self): # Calculate actual PPO loss. if fw in ["tf2", "tfe"]: - ppo_surrogate_loss_tf(policy, policy.model, Categorical, train_batch) + PPOEagerTFPolicy.loss(policy, policy.model, Categorical, train_batch) elif fw == "torch": PPOTorchPolicy.loss( policy, policy.model, policy.dist_class, train_batch diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py index ea7e3a08a922..077562fc771c 100644 --- a/rllib/agents/trainer.py +++ b/rllib/agents/trainer.py @@ -32,7 +32,7 @@ from ray.rllib.agents.callbacks import DefaultCallbacks from ray.rllib.agents.trainer_config import TrainerConfig from ray.rllib.env.env_context import EnvContext -from ray.rllib.env.utils import gym_env_creator +from ray.rllib.env.utils import _gym_env_creator from ray.rllib.evaluation.episode import Episode from ray.rllib.evaluation.metrics import ( collect_episodes, @@ -58,6 +58,8 @@ DeveloperAPI, ExperimentalAPI, override, + OverrideToImplementCustomLogic, + OverrideToImplementCustomLogic_CallToSuperRecommended, PublicAPI, ) from ray.rllib.utils.debug import update_global_seed_if_necessary @@ -288,10 +290,12 @@ def default_logger_creator(config): config, logger_creator, remote_checkpoint_dir, sync_function_tpl ) + @OverrideToImplementCustomLogic @classmethod def get_default_config(cls) -> TrainerConfigDict: return TrainerConfig().to_dict() + @OverrideToImplementCustomLogic_CallToSuperRecommended @override(Trainable) def setup(self, config: PartialTrainerConfigDict): @@ -487,6 +491,7 @@ def setup(self, config: PartialTrainerConfigDict): def _init(self, config: TrainerConfigDict, env_creator: EnvCreator) -> None: raise NotImplementedError + @OverrideToImplementCustomLogic def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]: """Returns a default Policy class to use, given a config. @@ -518,12 +523,27 @@ def step(self) -> ResultDict: and - if required - evaluation. """ step_attempt_results = None - + self._rollout_worker_metrics = [] + local_worker = ( + self.workers.local_worker() + if hasattr(self.workers, "local_worker") + else None + ) with self._step_context() as step_ctx: while not step_ctx.should_stop(step_attempt_results): # Try to train one step. try: step_attempt_results = self.step_attempt() + # Collect rollout worker metrics. + episodes, self._episodes_to_be_collected = collect_episodes( + local_worker, + self._remote_workers_for_metrics, + self._episodes_to_be_collected, + timeout_seconds=self.config[ + "metrics_episode_collection_timeout_s" + ], + ) + self._rollout_worker_metrics.extend(episodes) # @ray.remote RolloutWorker failure. except RayError as e: # Try to recover w/o the failed worker. @@ -845,6 +865,7 @@ def duration_fn(num_units_done): # Also return the results here for convenience. return self.evaluation_metrics + @OverrideToImplementCustomLogic @DeveloperAPI def training_iteration(self) -> ResultDict: """Default single iteration logic of an algorithm. @@ -1467,9 +1488,12 @@ def log_result(self, result: ResultDict) -> None: @override(Trainable) def cleanup(self) -> None: # Stop all workers. - if hasattr(self, "workers"): + if hasattr(self, "workers") and self.workers is not None: self.workers.stop() + if hasattr(self, "evaluation_workers") and self.evaluation_workers is not None: + self.evaluation_workers.stop() + @OverrideToImplementCustomLogic @classmethod @override(Trainable) def default_resource_request( @@ -1569,7 +1593,7 @@ def env_creator_from_classpath(env_context): # Try gym/PyBullet/Vizdoom. else: return env_specifier, functools.partial( - gym_env_creator, env_descriptor=env_specifier + _gym_env_creator, env_descriptor=env_specifier ) elif isinstance(env_specifier, type): @@ -1777,6 +1801,7 @@ def resolve_tf_settings(): check_if_correct_nn_framework_installed() resolve_tf_settings() + @OverrideToImplementCustomLogic_CallToSuperRecommended @DeveloperAPI def validate_config(self, config: TrainerConfigDict) -> None: """Validates a given config dict for this Trainer. @@ -2042,11 +2067,13 @@ def try_recover_from_step_attempt(self) -> None: if not isinstance(workers, WorkerSet): return + removed_workers, new_workers = [], [] # Search for failed workers and try to recover (restart) them. if self.config["recreate_failed_workers"] is True: - workers.recreate_failed_workers() + removed_workers, new_workers = workers.recreate_failed_workers() elif self.config["ignore_worker_failures"] is True: - workers.remove_failed_workers() + removed_workers = workers.remove_failed_workers() + self.on_worker_failures(removed_workers, new_workers) if not self.config.get("_disable_execution_plan_api") and callable( self.execution_plan @@ -2056,6 +2083,17 @@ def try_recover_from_step_attempt(self) -> None: workers, self.config, **self._kwargs_for_execution_plan() ) + def on_worker_failures( + self, removed_workers: List[ActorHandle], new_workers: List[ActorHandle] + ): + """Called after a worker failure is detected. + + Args: + removed_workers: List of removed workers. + new_workers: List of new workers. + """ + pass + @override(Trainable) def _export_model( self, export_formats: List[str], export_dir: str @@ -2276,13 +2314,7 @@ def _compile_step_results(self, *, step_ctx, step_attempt_results=None): # Learner info. results["info"] = {LEARNER_INFO: step_attempt_results} - # Collect rollout worker metrics. - episodes, self._episodes_to_be_collected = collect_episodes( - self.workers.local_worker(), - self._remote_workers_for_metrics, - self._episodes_to_be_collected, - timeout_seconds=self.config["metrics_episode_collection_timeout_s"], - ) + episodes = self._rollout_worker_metrics orig_episodes = list(episodes) missing = self.config["metrics_num_episodes_for_smoothing"] - len(episodes) if missing > 0: diff --git a/rllib/algorithms/alpha_star/alpha_star.py b/rllib/algorithms/alpha_star/alpha_star.py index cf856f46e95b..639e1446e3d2 100644 --- a/rllib/algorithms/alpha_star/alpha_star.py +++ b/rllib/algorithms/alpha_star/alpha_star.py @@ -14,7 +14,9 @@ from ray.rllib.agents.trainer import Trainer import ray.rllib.agents.ppo.appo as appo from ray.rllib.evaluation.rollout_worker import RolloutWorker -from ray.rllib.execution.parallel_requests import asynchronous_parallel_requests +from ray.rllib.execution.parallel_requests import ( + AsyncRequestsManager, +) from ray.rllib.execution.buffers.mixin_replay_buffer import MixInMultiAgentReplayBuffer from ray.rllib.policy.policy import Policy, PolicySpec from ray.rllib.policy.sample_batch import MultiAgentBatch @@ -86,8 +88,15 @@ def __init__(self, trainer_class=None): # AlphaStar specific settings: self.replay_buffer_capacity = 20 self.replay_buffer_replay_ratio = 0.5 - self.sample_wait_timeout = 0.01 - self.learn_wait_timeout = 0.1 + # Tuning max_requests_in_flight_per_sampler_worker and + # max_requests_in_flight_per_learner_worker is important so backpressure is + # created on the remote workers and the object store doesn't fill up + # unexpectedly. If the workers spend time idle, consider increasing these. + self.max_requests_in_flight_per_sampler_worker = 2 + self.max_requests_in_flight_per_learner_worker = 2 + + self.timeout_s_sampler_manager = 0.0 + self.timeout_s_learner_manager = 0.0 # League-building parameters. # The LeagueBuilder class to be used for league building logic. @@ -132,6 +141,7 @@ def __init__(self, trainer_class=None): # values. self.vtrace_drop_last_ts = False self.min_time_s_per_reporting = 2 + self._disable_execution_plan_api = True # __sphinx_doc_end__ # fmt: on @@ -141,8 +151,10 @@ def training( *, replay_buffer_capacity: Optional[int] = None, replay_buffer_replay_ratio: Optional[float] = None, - sample_wait_timeout: Optional[float] = None, - learn_wait_timeout: Optional[float] = None, + max_requests_in_flight_per_sampler_worker: Optional[int] = None, + max_requests_in_flight_per_learner_worker: Optional[int] = None, + timeout_s_sampler_manager: Optional[float] = None, + timeout_s_learner_manager: Optional[float] = None, league_builder_config: Optional[Dict[str, Any]] = None, max_num_policies_to_train: Optional[int] = None, **kwargs, @@ -154,15 +166,26 @@ def training( policy. replay_buffer_replay_ratio: For example, ratio=0.2 -> 20% of samples in each train batch are old (replayed) ones. - sample_wait_timeout: Timeout to use for `ray.wait()` when waiting for + timeout_s_sampler_manager: Timeout to use for `ray.wait()` when waiting for samplers to have placed new data into the buffers. If no samples are ready within the timeout, the buffers used for mixin-sampling will return only older samples. - learn_wait_timeout: Timeout to use for `ray.wait()` when waiting for the - policy learner actors to have performed an update and returned learning - stats. If no learner actors have produced any learning results in the - meantime, their learner-stats in the results will be empty for that - iteration. + timeout_s_learner_manager: Timeout to use for `ray.wait()` when waiting for + the policy learner actors to have performed an update and returned + learning stats. If no learner actors have produced any learning + results in the meantime, their learner-stats in the results will be + empty for that iteration. + max_requests_in_flight_per_sampler_worker: Maximum number of ray remote + calls that can be run in parallel for each sampler worker. This is + particularly important when dealing with many sampler workers or + sample batches that are large, and when could potentially fill up + the object store. + max_requests_in_flight_per_learner_worker: Maximum number of ray remote + calls that can be run in parallel for each learner worker. This is + important to tune when dealing with many learner workers so that the + object store doesn't fill up and so that learner actors don't become + backed up with too many requests that could become stale if not + attended to in a timely manner. league_builder_config: League-building config dict. The dict Must contain a `type` key indicating the LeagueBuilder class to be used for league building logic. All other keys (that are not @@ -189,14 +212,22 @@ def training( self.replay_buffer_capacity = replay_buffer_capacity if replay_buffer_replay_ratio is not None: self.replay_buffer_replay_ratio = replay_buffer_replay_ratio - if sample_wait_timeout is not None: - self.sample_wait_timeout = sample_wait_timeout - if learn_wait_timeout is not None: - self.learn_wait_timeout = learn_wait_timeout + if timeout_s_sampler_manager is not None: + self.timeout_s_sampler_manager = timeout_s_sampler_manager + if timeout_s_learner_manager is not None: + self.timeout_s_learner_manager = timeout_s_learner_manager if league_builder_config is not None: self.league_builder_config = league_builder_config if max_num_policies_to_train is not None: self.max_num_policies_to_train = max_num_policies_to_train + if max_requests_in_flight_per_sampler_worker is not None: + self.max_requests_in_flight_per_sampler_worker = ( + max_requests_in_flight_per_sampler_worker + ) + if max_requests_in_flight_per_learner_worker is not None: + self.max_requests_in_flight_per_learner_worker = ( + max_requests_in_flight_per_learner_worker + ) return self @@ -356,6 +387,21 @@ def _set_policy_learners(worker): ) self.distributed_learners = distributed_learners + self._sampling_actor_manager = AsyncRequestsManager( + self.workers.remote_workers(), + max_remote_requests_in_flight_per_worker=self.config[ + "max_requests_in_flight_per_sampler_worker" + ], + ray_wait_timeout_s=self.config["timeout_s_sampler_manager"], + ) + policy_actors = [policy_actor for _, policy_actor, _ in distributed_learners] + self._learner_worker_manager = AsyncRequestsManager( + workers=policy_actors, + max_remote_requests_in_flight_per_worker=self.config[ + "max_requests_in_flight_per_learner_worker" + ], + ray_wait_timeout_s=self.config["timeout_s_learner_manager"], + ) @override(Trainer) def step(self) -> ResultDict: @@ -374,13 +420,16 @@ def training_iteration(self) -> ResultDict: # - Rollout results are sent directly to correct replay buffer # shards, instead of here (to the driver). with self._timers[SAMPLE_TIMER]: - sample_results = asynchronous_parallel_requests( - remote_requests_in_flight=self.remote_requests_in_flight, - actors=self.workers.remote_workers() or [self.workers.local_worker()], - ray_wait_timeout_s=self.config["sample_wait_timeout"], - max_remote_requests_in_flight_per_actor=2, - remote_fn=self._sample_and_send_to_buffer, - ) + # if there are no remote workers (e.g. num_workers=0) + if not self.workers.remote_workers(): + worker = self.workers.local_worker() + statistics = worker.apply(self._sample_and_send_to_buffer) + sample_results = {worker: [statistics]} + else: + self._sampling_actor_manager.call_on_all_available( + self._sample_and_send_to_buffer + ) + sample_results = self._sampling_actor_manager.get_ready() # Update sample counters. for sample_result in sample_results.values(): for (env_steps, agent_steps) in sample_result: @@ -390,19 +439,13 @@ def training_iteration(self) -> ResultDict: # Trigger asynchronous training update requests on all learning # policies. with self._timers[LEARN_ON_BATCH_TIMER]: - pol_actors = [] - args = [] for pid, pol_actor, repl_actor in self.distributed_learners: - pol_actors.append(pol_actor) - args.append([repl_actor, pid]) - train_results = asynchronous_parallel_requests( - remote_requests_in_flight=self.remote_requests_in_flight, - actors=pol_actors, - ray_wait_timeout_s=self.config["learn_wait_timeout"], - max_remote_requests_in_flight_per_actor=2, - remote_fn=self._update_policy, - remote_args=args, - ) + if pol_actor not in self._learner_worker_manager.workers: + self._learner_worker_manager.add_workers(pol_actor) + self._learner_worker_manager.call( + self._update_policy, actor=pol_actor, fn_args=[repl_actor, pid] + ) + train_results = self._learner_worker_manager.get_ready() # Update sample counters. for train_result in train_results.values(): diff --git a/rllib/algorithms/cql/cql.py b/rllib/algorithms/cql/cql.py index 827854bece6c..7fd06c61a91c 100644 --- a/rllib/algorithms/cql/cql.py +++ b/rllib/algorithms/cql/cql.py @@ -12,6 +12,7 @@ multi_gpu_train_one_step, train_one_step, ) +from ray.rllib.utils.replay_buffers.utils import sample_min_n_steps_from_buffer from ray.rllib.offline.shuffled_input import ShuffledInput from ray.rllib.policy.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch @@ -150,7 +151,7 @@ def __init__(self, *args, **kwargs): ) batch[SampleBatch.DONES][-1] = True self.local_replay_buffer.add_batch(batch) - print( + logger.info( f"Loaded {num_batches} batches ({total_timesteps} ts) into the" " replay buffer, which has capacity " f"{self.local_replay_buffer.capacity}." @@ -215,7 +216,11 @@ def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]: def training_iteration(self) -> ResultDict: # Sample training batch from replay buffer. - train_batch = self.local_replay_buffer.sample(self.config["train_batch_size"]) + train_batch = sample_min_n_steps_from_buffer( + self.local_replay_buffer, + self.config["train_batch_size"], + count_by_agent_steps=self._by_agent_steps, + ) # Old-style replay buffers return None if learning has not started if not train_batch: diff --git a/rllib/algorithms/cql/tests/test_cql.py b/rllib/algorithms/cql/tests/test_cql.py index aed354d06ddf..b36ffb089f9c 100644 --- a/rllib/algorithms/cql/tests/test_cql.py +++ b/rllib/algorithms/cql/tests/test_cql.py @@ -95,10 +95,12 @@ def test_cql_compilation(self): # Example on how to do evaluation on the trained Trainer # using the data from CQL's global replay buffer. - # Get a sample (MultiAgentBatch -> SampleBatch). - batch = trainer.local_replay_buffer.replay().policy_batches[ - "default_policy" - ] + # Get a sample (MultiAgentBatch). + multi_agent_batch = trainer.local_replay_buffer.sample( + num_items=config.train_batch_size + ) + # All experiences have been buffered for `default_policy` + batch = multi_agent_batch.policy_batches["default_policy"] if fw == "torch": obs = torch.from_numpy(batch["obs"]) diff --git a/rllib/algorithms/ddpg/apex.py b/rllib/algorithms/ddpg/apex.py index 94f65e9ecdac..66bb5d62ba2c 100644 --- a/rllib/algorithms/ddpg/apex.py +++ b/rllib/algorithms/ddpg/apex.py @@ -1,3 +1,7 @@ +from typing import List + +from ray.actor import ActorHandle +from ray.rllib.agents import Trainer from ray.rllib.agents.dqn.apex import ApexTrainer from ray.rllib.algorithms.ddpg.ddpg import DDPGConfig, DDPGTrainer from ray.rllib.evaluation.worker_set import WorkerSet @@ -44,6 +48,22 @@ "target_network_update_freq": 500000, "min_sample_timesteps_per_reporting": 25000, "min_time_s_per_reporting": 30, + "training_intensity": 1, + # max number of inflight requests to each sampling worker + # see the AsyncRequestsManager class for more details + # Tuning these values is important when running experimens with large sample + # batches. If the sample batches are large in size, then there is the risk that + # the object store may fill up, causing the store to spill objects to disk. + # This can cause any asynchronous requests to become very slow, making your + # experiment run slowly. You can inspect the object store during your + # experiment via a call to ray memory on your headnode, and by using the ray + # dashboard. If you're seeing that the object store is filling up, turn down + # the number of remote requests in flight, or enable compression in your + # experiment of timesteps. + "max_requests_in_flight_per_sampler_worker": 2, + "max_requests_in_flight_per_replay_worker": float("inf"), + "timeout_s_sampler_manager": 0.0, + "timeout_s_replay_manager": 0.0, }, _allow_unknown_configs=True, ) @@ -64,6 +84,19 @@ def training_iteration(self) -> ResultDict: """Use APEX-DQN's training iteration function.""" return ApexTrainer.training_iteration(self) + @override(Trainer) + def on_worker_failures( + self, removed_workers: List[ActorHandle], new_workers: List[ActorHandle] + ): + """Handle the failures of remote sampling workers + + Args: + removed_workers: removed worker ids. + new_workers: ids of newly created workers. + """ + self._sampling_actor_manager.remove_workers(removed_workers) + self._sampling_actor_manager.add_workers(new_workers) + @staticmethod @override(DDPGTrainer) def execution_plan( diff --git a/rllib/algorithms/dqn/dqn.py b/rllib/algorithms/dqn/dqn.py index d1bd28606e82..60730b88d641 100644 --- a/rllib/algorithms/dqn/dqn.py +++ b/rllib/algorithms/dqn/dqn.py @@ -11,6 +11,7 @@ import logging from typing import List, Optional, Type, Callable +import numpy as np from ray.rllib.algorithms.dqn.dqn_tf_policy import DQNTFPolicy from ray.rllib.algorithms.dqn.dqn_torch_policy import DQNTorchPolicy @@ -47,6 +48,7 @@ NUM_TARGET_UPDATES, ) from ray.rllib.utils.deprecation import DEPRECATED_VALUE +from ray.rllib.utils.replay_buffers.utils import sample_min_n_steps_from_buffer logger = logging.getLogger(__name__) @@ -199,7 +201,7 @@ def training( `train_batch_size` / (`rollout_fragment_length` x `num_workers` x `num_envs_per_worker`). If not None, will make sure that the ratio between timesteps inserted - into and sampled from th buffer matches the given values. + into and sampled from the buffer matches the given values. Example: training_intensity=1000.0 train_batch_size=250 @@ -217,7 +219,6 @@ def training( "type": "MultiAgentReplayBuffer", "learning_starts": 1000, "capacity": 50000, - "replay_batch_size": 32, "replay_sequence_length": 1, } - OR - @@ -293,13 +294,18 @@ def calculate_rr_weights(config: TrainerConfigDict) -> List[float]: native_ratio = config["train_batch_size"] / ( config["rollout_fragment_length"] * config["num_envs_per_worker"] - * config["num_workers"] + # Add one to workers because the local + # worker usually collects experiences as well, and we avoid division by zero. + * max(config["num_workers"] + 1, 1) ) # Training intensity is specified in terms of # (steps_replayed / steps_sampled), so adjust for the native ratio. - weights = [1, config["training_intensity"] / native_ratio] - return weights + sample_and_train_weight = config["training_intensity"] / native_ratio + if sample_and_train_weight < 1: + return [int(np.round(1 / sample_and_train_weight)), 1] + else: + return [1, int(np.round(sample_and_train_weight))] class DQNTrainer(SimpleQTrainer): @@ -366,8 +372,10 @@ def training_iteration(self) -> ResultDict: for _ in range(sample_and_train_weight): # Sample training batch (MultiAgentBatch) from replay buffer. - train_batch = self.local_replay_buffer.sample( - self.config["train_batch_size"] + train_batch = sample_min_n_steps_from_buffer( + self.local_replay_buffer, + self.config["train_batch_size"], + count_by_agent_steps=self._by_agent_steps, ) # Old-style replay buffers return None if learning has not started @@ -379,6 +387,10 @@ def training_iteration(self) -> ResultDict: post_fn = self.config.get("before_learn_on_batch") or (lambda b, *a: b) train_batch = post_fn(train_batch, self.workers, self.config) + # for policy_id, sample_batch in train_batch.policy_batches.items(): + # print(len(sample_batch["obs"])) + # print(sample_batch.count) + # Learn on training batch. # Use simple optimizer (only for multi-agent or tf-eager; all other # cases should use the multi-GPU optimizer, even if only using 1 GPU) diff --git a/rllib/algorithms/dqn/dqn_tf_policy.py b/rllib/algorithms/dqn/dqn_tf_policy.py index b35e3dee0fed..441ee82cd3ab 100644 --- a/rllib/algorithms/dqn/dqn_tf_policy.py +++ b/rllib/algorithms/dqn/dqn_tf_policy.py @@ -6,14 +6,16 @@ import numpy as np import ray from ray.rllib.algorithms.dqn.distributional_q_tf_model import DistributionalQTFModel -from ray.rllib.algorithms.dqn.simple_q_tf_policy import TargetNetworkMixin from ray.rllib.evaluation.postprocessing import adjust_nstep from ray.rllib.models import ModelCatalog from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.tf.tf_action_dist import Categorical from ray.rllib.policy.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.policy.tf_mixins import LearningRateSchedule +from ray.rllib.policy.tf_mixins import ( + LearningRateSchedule, + TargetNetworkMixin, +) from ray.rllib.policy.tf_policy_template import build_tf_policy from ray.rllib.utils.error import UnsupportedSpaceException from ray.rllib.utils.exploration import ParameterNoise diff --git a/rllib/algorithms/dqn/dqn_torch_policy.py b/rllib/algorithms/dqn/dqn_torch_policy.py index c3bf1dcccf6e..e2d1b349a24e 100644 --- a/rllib/algorithms/dqn/dqn_torch_policy.py +++ b/rllib/algorithms/dqn/dqn_torch_policy.py @@ -11,7 +11,6 @@ postprocess_nstep_and_prio, ) from ray.rllib.algorithms.dqn.dqn_torch_model import DQNTorchModel -from ray.rllib.algorithms.dqn.simple_q_torch_policy import TargetNetworkMixin from ray.rllib.models.catalog import ModelCatalog from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.torch.torch_action_dist import ( @@ -21,7 +20,10 @@ from ray.rllib.policy.policy import Policy from ray.rllib.policy.policy_template import build_policy_class from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.policy.torch_mixins import LearningRateSchedule +from ray.rllib.policy.torch_mixins import ( + LearningRateSchedule, + TargetNetworkMixin, +) from ray.rllib.utils.error import UnsupportedSpaceException from ray.rllib.utils.exploration.parameter_noise import ParameterNoise from ray.rllib.utils.framework import try_import_torch diff --git a/rllib/algorithms/dqn/simple_q.py b/rllib/algorithms/dqn/simple_q.py index f50920b8d207..291db786f6b9 100644 --- a/rllib/algorithms/dqn/simple_q.py +++ b/rllib/algorithms/dqn/simple_q.py @@ -57,7 +57,6 @@ class SimpleQConfig(TrainerConfig): >>> replay_config = config.replay_buffer_config.update( >>> { >>> "capacity": 40000, - >>> "replay_batch_size": 64, >>> } >>> ) >>> config.training(replay_buffer_config=replay_config)\ @@ -112,11 +111,10 @@ def __init__(self, trainer_class=None): # __sphinx_doc_begin__ self.target_network_update_freq = 500 self.replay_buffer_config = { - "type": "MultiAgentReplayBuffer", - "capacity": 50000, # How many steps of the model to sample before learning starts. "learning_starts": 1000, - "replay_batch_size": 32, + "type": "MultiAgentReplayBuffer", + "capacity": 50000, # The number of contiguous environment steps to replay at once. This # may be set to greater than 1 to support recurrent models. "replay_sequence_length": 1, @@ -189,7 +187,6 @@ def training( "type": "MultiAgentReplayBuffer", "learning_starts": 1000, "capacity": 50000, - "replay_batch_size": 32, "replay_sequence_length": 1, } - OR - diff --git a/rllib/algorithms/dqn/simple_q_tf_policy.py b/rllib/algorithms/dqn/simple_q_tf_policy.py index f550cf60e47f..958849f40662 100644 --- a/rllib/algorithms/dqn/simple_q_tf_policy.py +++ b/rllib/algorithms/dqn/simple_q_tf_policy.py @@ -12,12 +12,11 @@ from ray.rllib.policy import Policy from ray.rllib.policy.dynamic_tf_policy import DynamicTFPolicy from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.policy.tf_policy import TFPolicy +from ray.rllib.policy.tf_mixins import TargetNetworkMixin from ray.rllib.policy.tf_policy_template import build_tf_policy -from ray.rllib.utils.annotations import override from ray.rllib.utils.error import UnsupportedSpaceException from ray.rllib.utils.framework import try_import_tf -from ray.rllib.utils.tf_utils import huber_loss, make_tf_callable +from ray.rllib.utils.tf_utils import huber_loss from ray.rllib.utils.typing import TensorType, TrainerConfigDict tf1, tf, tfv = try_import_tf() @@ -27,52 +26,6 @@ Q_TARGET_SCOPE = "target_q_func" -class TargetNetworkMixin: - """Assign the `update_target` method to the SimpleQTFPolicy - - The function is called every `target_network_update_freq` steps by the - master learner. - """ - - def __init__( - self, - obs_space: gym.spaces.Space, - action_space: gym.spaces.Space, - config: TrainerConfigDict, - ): - @make_tf_callable(self.get_session()) - def do_update(): - # update_target_fn will be called periodically to copy Q network to - # target Q network - update_target_expr = [] - assert len(self.q_func_vars) == len(self.target_q_func_vars), ( - self.q_func_vars, - self.target_q_func_vars, - ) - for var, var_target in zip(self.q_func_vars, self.target_q_func_vars): - update_target_expr.append(var_target.assign(var)) - logger.debug("Update target op {}".format(var_target)) - return tf.group(*update_target_expr) - - self.update_target = do_update - - @property - def q_func_vars(self): - if not hasattr(self, "_q_func_vars"): - self._q_func_vars = self.model.variables() - return self._q_func_vars - - @property - def target_q_func_vars(self): - if not hasattr(self, "_target_q_func_vars"): - self._target_q_func_vars = self.target_model.variables() - return self._target_q_func_vars - - @override(TFPolicy) - def variables(self): - return self.q_func_vars + self.target_q_func_vars - - def build_q_models( policy: Policy, obs_space: gym.spaces.Space, diff --git a/rllib/algorithms/dqn/simple_q_torch_policy.py b/rllib/algorithms/dqn/simple_q_torch_policy.py index 0c89d9b12880..bc3869fc3091 100644 --- a/rllib/algorithms/dqn/simple_q_torch_policy.py +++ b/rllib/algorithms/dqn/simple_q_torch_policy.py @@ -18,8 +18,7 @@ from ray.rllib.policy import Policy from ray.rllib.policy.policy_template import build_policy_class from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.policy.torch_policy import TorchPolicy -from ray.rllib.utils.annotations import override +from ray.rllib.policy.torch_mixins import TargetNetworkMixin from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.torch_utils import concat_multi_gpu_td_errors, huber_loss from ray.rllib.utils.typing import TensorType, TrainerConfigDict @@ -31,33 +30,6 @@ logger = logging.getLogger(__name__) -class TargetNetworkMixin: - """Assign the `update_target` method to the SimpleQTorchPolicy - - The function is called every `target_network_update_freq` steps by the - master learner. - """ - - def __init__(self): - # Hard initial update from Q-net(s) to target Q-net(s). - self.update_target() - - def update_target(self): - # Update_target_fn will be called periodically to copy Q network to - # target Q networks. - state_dict = self.model.state_dict() - for target in self.target_models.values(): - target.load_state_dict(state_dict) - - @override(TorchPolicy) - def set_weights(self, weights): - # Makes sure that whenever we restore weights for this policy's - # model, we sync the target network (from the main model) - # at the same time. - TorchPolicy.set_weights(self, weights) - self.update_target() - - def build_q_model_and_distribution( policy: Policy, obs_space: gym.spaces.Space, diff --git a/rllib/agents/maddpg/README.md b/rllib/algorithms/maddpg/README.md similarity index 100% rename from rllib/agents/maddpg/README.md rename to rllib/algorithms/maddpg/README.md diff --git a/rllib/algorithms/maddpg/__init__.py b/rllib/algorithms/maddpg/__init__.py new file mode 100644 index 000000000000..2ae788f1ebd6 --- /dev/null +++ b/rllib/algorithms/maddpg/__init__.py @@ -0,0 +1,3 @@ +from ray.rllib.algorithms.maddpg.maddpg import MADDPGTrainer, DEFAULT_CONFIG + +__all__ = ["MADDPGTrainer", "DEFAULT_CONFIG"] diff --git a/rllib/agents/maddpg/maddpg.py b/rllib/algorithms/maddpg/maddpg.py similarity index 95% rename from rllib/agents/maddpg/maddpg.py rename to rllib/algorithms/maddpg/maddpg.py index 85b186a8e82b..e63321586169 100644 --- a/rllib/agents/maddpg/maddpg.py +++ b/rllib/algorithms/maddpg/maddpg.py @@ -12,12 +12,11 @@ import logging from typing import Type -from ray.rllib.agents.maddpg.maddpg_tf_policy import MADDPGTFPolicy from ray.rllib.algorithms.dqn.dqn import DQNTrainer -from ray.rllib.agents.trainer import COMMON_CONFIG, with_common_config +from ray.rllib.algorithms.maddpg.maddpg_tf_policy import MADDPGTFPolicy +from ray.rllib.agents.trainer import with_common_config from ray.rllib.policy.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch -from ray.rllib.utils import merge_dicts from ray.rllib.utils.annotations import override from ray.rllib.utils.typing import TrainerConfigDict from ray.rllib.utils.deprecation import DEPRECATED_VALUE @@ -77,6 +76,8 @@ "capacity": int(1e6), # How many steps of the model to sample before learning starts. "learning_starts": 1024 * 25, + # Force lockstep replay mode for MADDPG. + "replay_mode": "lockstep", }, # Observation compression. Note that compression makes simulation slow in # MPE. @@ -86,10 +87,6 @@ # timesteps. Otherwise, the replay will proceed at the native ratio # determined by (train_batch_size / rollout_fragment_length). "training_intensity": None, - # Force lockstep replay mode for MADDPG. - "multiagent": merge_dicts(COMMON_CONFIG["multiagent"], { - "replay_mode": "lockstep", - }), # === Optimization === # Learning rate for the critic (Q-function) optimizer. diff --git a/rllib/agents/maddpg/maddpg_tf_policy.py b/rllib/algorithms/maddpg/maddpg_tf_policy.py similarity index 98% rename from rllib/agents/maddpg/maddpg_tf_policy.py rename to rllib/algorithms/maddpg/maddpg_tf_policy.py index 8bf5f93a89b8..6b02fc09d6ea 100644 --- a/rllib/agents/maddpg/maddpg_tf_policy.py +++ b/rllib/algorithms/maddpg/maddpg_tf_policy.py @@ -43,7 +43,7 @@ def postprocess_trajectory( class MADDPGTFPolicy(MADDPGPostprocessing, TFPolicy): def __init__(self, obs_space, act_space, config): # _____ Initial Configuration - config = dict(ray.rllib.agents.maddpg.DEFAULT_CONFIG, **config) + config = dict(ray.rllib.algorithms.maddpg.maddpg.DEFAULT_CONFIG, **config) self.config = config self.global_step = tf1.train.get_or_create_global_step() @@ -69,11 +69,11 @@ def _make_continuous_space(space): ) obs_space_n = [ - _make_continuous_space(space) + _make_continuous_space(space or obs_space) for _, (_, space, _, _) in config["multiagent"]["policies"].items() ] act_space_n = [ - _make_continuous_space(space) + _make_continuous_space(space or act_space) for _, (_, _, space, _) in config["multiagent"]["policies"].items() ] diff --git a/rllib/algorithms/maddpg/tests/test_maddpg.py b/rllib/algorithms/maddpg/tests/test_maddpg.py new file mode 100644 index 000000000000..c6181f7822be --- /dev/null +++ b/rllib/algorithms/maddpg/tests/test_maddpg.py @@ -0,0 +1,57 @@ +import unittest + +import ray +import ray.rllib.algorithms.maddpg as maddpg +from ray.rllib.examples.env.two_step_game import TwoStepGame +from ray.rllib.policy.policy import PolicySpec +from ray.rllib.utils.test_utils import ( + check_train_results, + framework_iterator, +) + + +class TestMADDPG(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + ray.init() + + @classmethod + def tearDownClass(cls) -> None: + ray.shutdown() + + def test_maddpg_compilation(self): + """Test whether an MADDPGTrainer can be built with all frameworks.""" + config = maddpg.DEFAULT_CONFIG.copy() + config["env"] = TwoStepGame + config["env_config"] = { + "actions_are_logits": True, + } + config["multiagent"] = { + "policies": { + "pol1": PolicySpec( + config={"agent_id": 0}, + ), + "pol2": PolicySpec( + config={"agent_id": 1}, + ), + }, + "policy_mapping_fn": (lambda aid, **kwargs: "pol2" if aid else "pol1"), + } + + num_iterations = 1 + + # Only working for tf right now. + for _ in framework_iterator(config, frameworks="tf"): + trainer = maddpg.MADDPGTrainer(config) + for i in range(num_iterations): + results = trainer.train() + check_train_results(results) + print(results) + trainer.stop() + + +if __name__ == "__main__": + import pytest + import sys + + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/algorithms/maml/maml.py b/rllib/algorithms/maml/maml.py index bf1369f69b24..ec096e9d194e 100644 --- a/rllib/algorithms/maml/maml.py +++ b/rllib/algorithms/maml/maml.py @@ -288,9 +288,9 @@ def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]: return MAMLTorchPolicy elif config["framework"] == "tf": - from ray.rllib.algorithms.maml.maml_tf_policy import MAMLDynamicTFPolicy + from ray.rllib.algorithms.maml.maml_tf_policy import MAMLStaticGraphTFPolicy - return MAMLDynamicTFPolicy + return MAMLStaticGraphTFPolicy else: from ray.rllib.algorithms.maml.maml_tf_policy import MAMLEagerTFPolicy diff --git a/rllib/algorithms/maml/maml_tf_policy.py b/rllib/algorithms/maml/maml_tf_policy.py index 05b1da704d7d..5fe1682b289a 100644 --- a/rllib/algorithms/maml/maml_tf_policy.py +++ b/rllib/algorithms/maml/maml_tf_policy.py @@ -3,7 +3,10 @@ import ray from ray.rllib.agents.ppo.ppo_tf_policy import validate_config -from ray.rllib.evaluation.postprocessing import Postprocessing +from ray.rllib.evaluation.postprocessing import ( + Postprocessing, + compute_gae_for_sample_batch, +) from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.tf.tf_action_dist import TFActionDistribution from ray.rllib.models.utils import get_activation_fn @@ -11,9 +14,10 @@ from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2 from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.tf_mixins import ( - ComputeAndClipGradsMixIn, - ComputeGAEMixIn, + LocalOptimizer, + ModelGradients, ValueNetworkMixin, + compute_gradients, ) from ray.rllib.utils import try_import_tf from ray.rllib.utils.annotations import override @@ -364,9 +368,7 @@ def get_maml_tf_policy(base: type) -> type: A TF Policy to be used with MAMLTrainer. """ - class MAMLTFPolicy( - ComputeGAEMixIn, ComputeAndClipGradsMixIn, KLCoeffMixin, ValueNetworkMixin, base - ): + class MAMLTFPolicy(KLCoeffMixin, ValueNetworkMixin, base): def __init__( self, obs_space, @@ -379,7 +381,7 @@ def __init__( base.enable_eager_execution_if_necessary() config = dict(ray.rllib.algorithms.maml.maml.DEFAULT_CONFIG, **config) - validate_config(self, obs_space, action_space, config) + validate_config(config) # Initialize base class. base.__init__( @@ -391,8 +393,6 @@ def __init__( existing_model=existing_model, ) - ComputeGAEMixIn.__init__(self) - ComputeAndClipGradsMixIn.__init__(self) KLCoeffMixin.__init__(self, config) ValueNetworkMixin.__init__(self, config) @@ -498,8 +498,23 @@ def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]: "entropy": self.loss_obj.mean_entropy, } + @override(base) + def postprocess_trajectory( + self, sample_batch, other_agent_batches=None, episode=None + ): + sample_batch = super().postprocess_trajectory(sample_batch) + return compute_gae_for_sample_batch( + self, sample_batch, other_agent_batches, episode + ) + + @override(base) + def compute_gradients_fn( + self, optimizer: LocalOptimizer, loss: TensorType + ) -> ModelGradients: + return compute_gradients(self, optimizer, loss) + return MAMLTFPolicy -MAMLDynamicTFPolicy = get_maml_tf_policy(DynamicTFPolicyV2) +MAMLStaticGraphTFPolicy = get_maml_tf_policy(DynamicTFPolicyV2) MAMLEagerTFPolicy = get_maml_tf_policy(EagerTFPolicyV2) diff --git a/rllib/algorithms/maml/maml_torch_policy.py b/rllib/algorithms/maml/maml_torch_policy.py index 73bef4e90567..7f8083c2d89b 100644 --- a/rllib/algorithms/maml/maml_torch_policy.py +++ b/rllib/algorithms/maml/maml_torch_policy.py @@ -4,11 +4,14 @@ import ray from ray.rllib.agents.ppo.ppo_tf_policy import validate_config -from ray.rllib.evaluation.postprocessing import Postprocessing +from ray.rllib.evaluation.postprocessing import ( + Postprocessing, + compute_gae_for_sample_batch, +) from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.policy.torch_mixins import ComputeGAEMixIn, ValueNetworkMixin +from ray.rllib.policy.torch_mixins import ValueNetworkMixin from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2 from ray.rllib.utils.annotations import override from ray.rllib.utils.numpy import convert_to_numpy @@ -284,12 +287,12 @@ def update_kls(self, sampled_kls): return self.kl_coeff_val -class MAMLTorchPolicy(ComputeGAEMixIn, ValueNetworkMixin, KLCoeffMixin, TorchPolicyV2): +class MAMLTorchPolicy(ValueNetworkMixin, KLCoeffMixin, TorchPolicyV2): """PyTorch policy class used with MAMLTrainer.""" def __init__(self, observation_space, action_space, config): config = dict(ray.rllib.algorithms.maml.maml.DEFAULT_CONFIG, **config) - validate_config(self, observation_space, action_space, config) + validate_config(config) TorchPolicyV2.__init__( self, @@ -299,7 +302,6 @@ def __init__(self, observation_space, action_space, config): max_seq_len=config["model"]["max_seq_len"], ) - ComputeGAEMixIn.__init__(self) KLCoeffMixin.__init__(self, config) ValueNetworkMixin.__init__(self, config) @@ -422,3 +424,16 @@ def extra_grad_process( self, optimizer: "torch.optim.Optimizer", loss: TensorType ) -> Dict[str, TensorType]: return apply_grad_clipping(self, optimizer, loss) + + @override(TorchPolicyV2) + def postprocess_trajectory( + self, sample_batch, other_agent_batches=None, episode=None + ): + # Do all post-processing always with no_grad(). + # Not using this here will introduce a memory leak + # in torch (issue #6962). + # TODO: no_grad still necessary? + with torch.no_grad(): + return compute_gae_for_sample_batch( + self, sample_batch, other_agent_batches, episode + ) diff --git a/rllib/algorithms/marwil/__init__.py b/rllib/algorithms/marwil/__init__.py index 5e071c2b3a7a..125588b9b85c 100644 --- a/rllib/algorithms/marwil/__init__.py +++ b/rllib/algorithms/marwil/__init__.py @@ -5,7 +5,7 @@ MARWILTrainer, ) from ray.rllib.algorithms.marwil.marwil_tf_policy import ( - MARWILDynamicTFPolicy, + MARWILStaticGraphTFPolicy, MARWILEagerTFPolicy, ) from ray.rllib.algorithms.marwil.marwil_torch_policy import MARWILTorchPolicy @@ -14,7 +14,7 @@ "BCConfig", "BCTrainer", "MARWILConfig", - "MARWILDynamicTFPolicy", + "MARWILStaticGraphTFPolicy", "MARWILEagerTFPolicy", "MARWILTorchPolicy", "MARWILTrainer", diff --git a/rllib/algorithms/marwil/marwil.py b/rllib/algorithms/marwil/marwil.py index ecad6315d8fd..92ae2a02aa3c 100644 --- a/rllib/algorithms/marwil/marwil.py +++ b/rllib/algorithms/marwil/marwil.py @@ -23,6 +23,7 @@ ResultDict, TrainerConfigDict, ) +from ray.rllib.utils.replay_buffers.utils import sample_min_n_steps_from_buffer class MARWILConfig(TrainerConfig): @@ -238,10 +239,10 @@ def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]: return MARWILTorchPolicy elif config["framework"] == "tf": from ray.rllib.algorithms.marwil.marwil_tf_policy import ( - MARWILDynamicTFPolicy, + MARWILStaticGraphTFPolicy, ) - return MARWILDynamicTFPolicy + return MARWILStaticGraphTFPolicy else: from ray.rllib.algorithms.marwil.marwil_tf_policy import MARWILEagerTFPolicy @@ -258,7 +259,11 @@ def training_iteration(self) -> ResultDict: self.local_replay_buffer.add(batch) # Pull batch from replay buffer and train on it. - train_batch = self.local_replay_buffer.sample(self.config["train_batch_size"]) + train_batch = sample_min_n_steps_from_buffer( + self.local_replay_buffer, + self.config["train_batch_size"], + count_by_agent_steps=self._by_agent_steps, + ) # Train. if self.config["simple_optimizer"]: train_results = train_one_step(self, train_batch) diff --git a/rllib/algorithms/marwil/marwil_tf_policy.py b/rllib/algorithms/marwil/marwil_tf_policy.py index fb2f7e49a304..44be5161534e 100644 --- a/rllib/algorithms/marwil/marwil_tf_policy.py +++ b/rllib/algorithms/marwil/marwil_tf_policy.py @@ -11,11 +11,18 @@ from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2 from ray.rllib.policy.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.policy.tf_mixins import ComputeAndClipGradsMixIn, ValueNetworkMixin +from ray.rllib.policy.tf_mixins import ( + ValueNetworkMixin, + compute_gradients, +) from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_tf, get_variable from ray.rllib.utils.tf_utils import explained_variance -from ray.rllib.utils.typing import TensorType +from ray.rllib.utils.typing import ( + LocalOptimizer, + ModelGradients, + TensorType, +) tf1, tf, tfv = try_import_tf() @@ -155,9 +162,7 @@ def get_marwil_tf_policy(base: type) -> type: A TF Policy to be used with MAMLTrainer. """ - class MARWILTFPolicy( - ComputeAndClipGradsMixIn, ValueNetworkMixin, PostprocessAdvantages, base - ): + class MARWILTFPolicy(ValueNetworkMixin, PostprocessAdvantages, base): def __init__( self, obs_space, @@ -181,7 +186,6 @@ def __init__( existing_model=existing_model, ) - ComputeAndClipGradsMixIn.__init__(self) ValueNetworkMixin.__init__(self, config) PostprocessAdvantages.__init__(self) @@ -235,8 +239,14 @@ def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]: return stats + @override(base) + def compute_gradients_fn( + self, optimizer: LocalOptimizer, loss: TensorType + ) -> ModelGradients: + return compute_gradients(self, optimizer, loss) + return MARWILTFPolicy -MARWILDynamicTFPolicy = get_marwil_tf_policy(DynamicTFPolicyV2) +MARWILStaticGraphTFPolicy = get_marwil_tf_policy(DynamicTFPolicyV2) MARWILEagerTFPolicy = get_marwil_tf_policy(EagerTFPolicyV2) diff --git a/rllib/algorithms/sac/rnnsac.py b/rllib/algorithms/sac/rnnsac.py index 5ee181ce014e..d5271f62005f 100644 --- a/rllib/algorithms/sac/rnnsac.py +++ b/rllib/algorithms/sac/rnnsac.py @@ -28,15 +28,29 @@ def __init__(self, trainer_class=None): super().__init__(trainer_class=trainer_class or RNNSACTrainer) # fmt: off # __sphinx_doc_begin__ - self.burn_in = DEPRECATED_VALUE self.batch_mode = "complete_episodes" self.zero_init_states = True - self.replay_buffer_config["replay_burn_in"] = 0 - # Set automatically: The number of contiguous environment steps to - # replay at once. Will be calculated via - # model->max_seq_len + burn_in. - # Do not set this to any valid value! - self.replay_buffer_config["replay_sequence_length"] = -1 + self.replay_buffer_config = { + # This algorithm learns on sequences. We therefore require the replay buffer + # to slice sampled batches into sequences before replay. How sequences + # are sliced depends on the parameters `replay_sequence_length`, + # `replay_burn_in`, and `replay_zero_init_states`. + "storage_unit": "sequences", + # If > 0, use the `burn_in` first steps of each replay-sampled sequence + # (starting either from all 0.0-values if `zero_init_state=True` or + # from the already stored values) to calculate an even more accurate + # initial states for the actual sequence (starting after this burn-in + # window). In the burn-in case, the actual length of the sequence + # used for loss calculation is `n - burn_in` time steps + # (n=LSTM’s/attention net’s max_seq_len). + "replay_burn_in": 0, + # Set automatically: The number of contiguous environment steps to + # replay at once. Will be calculated via + # model->max_seq_len + burn_in. + # Do not set this to any valid value! + "replay_sequence_length": -1, + }, + self.burn_in = DEPRECATED_VALUE # fmt: on # __sphinx_doc_end__ diff --git a/rllib/algorithms/sac/tests/test_sac.py b/rllib/algorithms/sac/tests/test_sac.py index 54221b44cc1a..36e26665af55 100644 --- a/rllib/algorithms/sac/tests/test_sac.py +++ b/rllib/algorithms/sac/tests/test_sac.py @@ -228,7 +228,7 @@ def test_sac_loss_function(self): } env = SimpleEnv - batch_size = 100 + batch_size = 64 obs_size = (batch_size, 1) actions = np.random.random(size=(batch_size, 2)) diff --git a/rllib/contrib/registry.py b/rllib/contrib/registry.py index 154dd36fb27f..9ff06adbb6d0 100644 --- a/rllib/contrib/registry.py +++ b/rllib/contrib/registry.py @@ -17,7 +17,7 @@ def _import_alphazero(): def _import_maddpg(): - from ray.rllib.agents.maddpg import maddpg + from ray.rllib.algorithms.maddpg import maddpg return maddpg.MADDPGTrainer, maddpg.DEFAULT_CONFIG diff --git a/rllib/env/base_env.py b/rllib/env/base_env.py index 896cb958f525..123bf756ccea 100644 --- a/rllib/env/base_env.py +++ b/rllib/env/base_env.py @@ -3,7 +3,7 @@ import gym import ray -from ray.rllib.utils.annotations import Deprecated, override, PublicAPI +from ray.rllib.utils.annotations import Deprecated, override, PublicAPI, DeveloperAPI from ray.rllib.utils.typing import AgentID, EnvID, EnvType, MultiAgentDict, MultiEnvDict if TYPE_CHECKING: @@ -366,6 +366,7 @@ def _with_dummy_agent_id( return {k: {dummy_id: v} for (k, v) in env_id_to_values.items()} +@PublicAPI def with_dummy_agent_id( env_id_to_values: Dict[EnvID, Any], dummy_id: "AgentID" = _DUMMY_AGENT_ID ) -> MultiEnvDict: @@ -720,6 +721,7 @@ def reset(self) -> MultiAgentDict: return self.last_obs +@DeveloperAPI def convert_to_base_env( env: EnvType, make_env: Callable[[int], EnvType] = None, diff --git a/rllib/env/external_env.py b/rllib/env/external_env.py index 018fc3a1e0df..716cbbb83535 100644 --- a/rllib/env/external_env.py +++ b/rllib/env/external_env.py @@ -337,6 +337,7 @@ def _send(self): self.results_avail_condition.notify() +@PublicAPI class ExternalEnvWrapper(BaseEnv): """Internal adapter of ExternalEnv to BaseEnv.""" diff --git a/rllib/env/multi_agent_env.py b/rllib/env/multi_agent_env.py index 952fbd23b794..1cfb987704e4 100644 --- a/rllib/env/multi_agent_env.py +++ b/rllib/env/multi_agent_env.py @@ -362,6 +362,7 @@ def _check_if_space_maps_agent_id_to_sub_space(self) -> bool: return obs_space_check and action_space_check +@PublicAPI def make_multi_agent( env_name_or_creator: Union[str, EnvCreator], ) -> Type["MultiAgentEnv"]: @@ -472,6 +473,7 @@ def render(self, mode=None): return MultiEnv +@PublicAPI class MultiAgentEnvWrapper(BaseEnv): """Internal adapter of MultiAgentEnv to BaseEnv. diff --git a/rllib/env/policy_client.py b/rllib/env/policy_client.py index e7fb3f273d74..dfcc6cbbbf83 100644 --- a/rllib/env/policy_client.py +++ b/rllib/env/policy_client.py @@ -345,7 +345,7 @@ def wrapped_creator(env_config): else: external_cls = ExternalEnv - class ExternalEnvWrapper(external_cls): + class _ExternalEnvWrapper(external_cls): def __init__(self, real_env): super().__init__( observation_space=real_env.observation_space, @@ -357,7 +357,7 @@ def run(self): # client, run doesn't need to do anything. time.sleep(999999) - return ExternalEnvWrapper(real_env) + return _ExternalEnvWrapper(real_env) return real_env return wrapped_creator diff --git a/rllib/env/policy_server_input.py b/rllib/env/policy_server_input.py index 33e04c9826a3..90d9d56a200f 100644 --- a/rllib/env/policy_server_input.py +++ b/rllib/env/policy_server_input.py @@ -22,6 +22,7 @@ logger = logging.getLogger(__name__) +@PublicAPI class PolicyServerInput(ThreadingMixIn, HTTPServer, InputReader): """REST policy server that acts as an offline data source. diff --git a/rllib/env/utils.py b/rllib/env/utils.py index bcfbee1c4c44..58fa614c7aa8 100644 --- a/rllib/env/utils.py +++ b/rllib/env/utils.py @@ -4,7 +4,7 @@ from ray.rllib.utils.error import ERR_MSG_INVALID_ENV_DESCRIPTOR, EnvError -def gym_env_creator(env_context: EnvContext, env_descriptor: str) -> gym.Env: +def _gym_env_creator(env_context: EnvContext, env_descriptor: str) -> gym.Env: """Tries to create a gym env given an EnvContext object and descriptor. Note: This function tries to construct the env from a string descriptor diff --git a/rllib/env/vector_env.py b/rllib/env/vector_env.py index 68b9fd44a842..41664933a42b 100644 --- a/rllib/env/vector_env.py +++ b/rllib/env/vector_env.py @@ -258,6 +258,7 @@ def try_render_at(self, index: Optional[int] = None): return self.envs[index].render() +@PublicAPI class VectorEnvWrapper(BaseEnv): """Internal adapter of VectorEnv to BaseEnv. diff --git a/rllib/env/wrappers/atari_wrappers.py b/rllib/env/wrappers/atari_wrappers.py index b36bf6913127..5216153ec777 100644 --- a/rllib/env/wrappers/atari_wrappers.py +++ b/rllib/env/wrappers/atari_wrappers.py @@ -3,9 +3,11 @@ from gym import spaces import numpy as np +from ray.rllib.utils.annotations import Deprecated, PublicAPI from ray.rllib.utils.images import rgb2gray, resize +@PublicAPI def is_atari(env): if ( hasattr(env.observation_space, "shape") @@ -16,6 +18,7 @@ def is_atari(env): return hasattr(env, "unwrapped") and hasattr(env.unwrapped, "ale") +@PublicAPI def get_wrapper_by_cls(env, cls): """Returns the gym env wrapper of the given class, or None.""" currentenv = env @@ -28,6 +31,7 @@ def get_wrapper_by_cls(env, cls): return None +@PublicAPI class MonitorEnv(gym.Wrapper): def __init__(self, env=None): """Record episodes stats prior to EpisodicLifeEnv, etc.""" @@ -78,6 +82,7 @@ def next_episode_results(self): self._num_returned = len(self._episode_rewards) +@PublicAPI class NoopResetEnv(gym.Wrapper): def __init__(self, env, noop_max=30): """Sample initial states by taking random number of no-ops on reset. @@ -114,6 +119,7 @@ def step(self, ac): return self.env.step(ac) +@PublicAPI class ClipRewardEnv(gym.RewardWrapper): def __init__(self, env): gym.RewardWrapper.__init__(self, env) @@ -123,6 +129,7 @@ def reward(self, reward): return np.sign(reward) +@PublicAPI class FireResetEnv(gym.Wrapper): def __init__(self, env): """Take action on reset. @@ -146,6 +153,7 @@ def step(self, ac): return self.env.step(ac) +@PublicAPI class EpisodicLifeEnv(gym.Wrapper): def __init__(self, env): """Make end-of-life == end-of-episode, but only reset on true game over. @@ -183,6 +191,7 @@ def reset(self, **kwargs): return obs +@PublicAPI class MaxAndSkipEnv(gym.Wrapper): def __init__(self, env, skip=4): """Return only every `skip`-th frame""" @@ -214,6 +223,7 @@ def reset(self, **kwargs): return self.env.reset(**kwargs) +@PublicAPI class WarpFrame(gym.ObservationWrapper): def __init__(self, env, dim): """Warp frames to the specified size (dim x dim).""" @@ -230,7 +240,7 @@ def observation(self, frame): return frame[:, :, None] -# TODO: (sven) Deprecated class. Remove once traj. view is the norm. +@Deprecated(error=False) class FrameStack(gym.Wrapper): def __init__(self, env, k): """Stack k last frames.""" @@ -261,6 +271,7 @@ def _get_ob(self): return np.concatenate(self.frames, axis=2) +@PublicAPI class FrameStackTrajectoryView(gym.ObservationWrapper): def __init__(self, env): """No stacking. Trajectory View API takes care of this.""" @@ -275,6 +286,7 @@ def observation(self, observation): return np.squeeze(observation, axis=-1) +@PublicAPI class ScaledFloatFrame(gym.ObservationWrapper): def __init__(self, env): gym.ObservationWrapper.__init__(self, env) @@ -288,6 +300,7 @@ def observation(self, observation): return np.array(observation).astype(np.float32) / 255.0 +@PublicAPI def wrap_deepmind(env, dim=84, framestack=True): """Configure environment for DeepMind-style Atari. diff --git a/rllib/env/wrappers/dm_control_wrapper.py b/rllib/env/wrappers/dm_control_wrapper.py index a4bce06c4216..b912f53f12f0 100644 --- a/rllib/env/wrappers/dm_control_wrapper.py +++ b/rllib/env/wrappers/dm_control_wrapper.py @@ -40,6 +40,8 @@ suite = None import numpy as np +from ray.rllib.utils.annotations import PublicAPI + def _spec_to_box(spec): def extract_min_max(s): @@ -71,6 +73,7 @@ def _flatten_obs(obs): return np.concatenate(obs_pieces, axis=0) +@PublicAPI class DMCEnv(core.Env): def __init__( self, diff --git a/rllib/env/wrappers/dm_env_wrapper.py b/rllib/env/wrappers/dm_env_wrapper.py index 9d8bbf3a1186..5748f4bcdb8f 100644 --- a/rllib/env/wrappers/dm_env_wrapper.py +++ b/rllib/env/wrappers/dm_env_wrapper.py @@ -8,6 +8,8 @@ except ImportError: specs = None +from ray.rllib.utils.annotations import PublicAPI + def _convert_spec_to_space(spec): if isinstance(spec, dict): @@ -34,6 +36,7 @@ def _convert_spec_to_space(spec): ) +@PublicAPI class DMEnv(gym.Env): """A `gym.Env` wrapper for the `dm_env` API.""" diff --git a/rllib/env/wrappers/group_agents_wrapper.py b/rllib/env/wrappers/group_agents_wrapper.py index 646090f4807d..9811add06ba2 100644 --- a/rllib/env/wrappers/group_agents_wrapper.py +++ b/rllib/env/wrappers/group_agents_wrapper.py @@ -3,6 +3,7 @@ from typing import Dict, List, Optional from ray.rllib.env.multi_agent_env import MultiAgentEnv +from ray.rllib.utils.annotations import DeveloperAPI from ray.rllib.utils.typing import AgentID # info key for the individual rewards of an agent, for example: @@ -22,6 +23,7 @@ GROUP_INFO = "_group_info" +@DeveloperAPI class GroupAgentsWrapper(MultiAgentEnv): """Wraps a MultiAgentEnv environment with agents grouped as specified. diff --git a/rllib/env/wrappers/pettingzoo_env.py b/rllib/env/wrappers/pettingzoo_env.py index b4e9c384b258..0f9babfd0f36 100644 --- a/rllib/env/wrappers/pettingzoo_env.py +++ b/rllib/env/wrappers/pettingzoo_env.py @@ -1,6 +1,8 @@ from ray.rllib.env.multi_agent_env import MultiAgentEnv +from ray.rllib.utils.annotations import PublicAPI +@PublicAPI class PettingZooEnv(MultiAgentEnv): """An interface to the PettingZoo MARL environment library. @@ -141,6 +143,7 @@ def get_sub_environments(self): return self.env.unwrapped +@PublicAPI class ParallelPettingZooEnv(MultiAgentEnv): def __init__(self, env): super().__init__() diff --git a/rllib/env/wrappers/unity3d_env.py b/rllib/env/wrappers/unity3d_env.py index 63712d67b341..6d953a38813c 100644 --- a/rllib/env/wrappers/unity3d_env.py +++ b/rllib/env/wrappers/unity3d_env.py @@ -7,11 +7,13 @@ from ray.rllib.env.multi_agent_env import MultiAgentEnv from ray.rllib.policy.policy import PolicySpec +from ray.rllib.utils.annotations import PublicAPI from ray.rllib.utils.typing import MultiAgentDict, PolicyID, AgentID logger = logging.getLogger(__name__) +@PublicAPI class Unity3DEnv(MultiAgentEnv): """A MultiAgentEnv representing a single Unity3D game instance. diff --git a/rllib/evaluation/collectors/sample_collector.py b/rllib/evaluation/collectors/sample_collector.py index e0b301e0ff80..3c70ef3e8da9 100644 --- a/rllib/evaluation/collectors/sample_collector.py +++ b/rllib/evaluation/collectors/sample_collector.py @@ -6,6 +6,7 @@ from ray.rllib.policy.policy_map import PolicyMap from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch from ray.rllib.utils.typing import AgentID, EnvID, EpisodeID, PolicyID, TensorType +from ray.rllib.utils.annotations import PublicAPI if TYPE_CHECKING: from ray.rllib.agents.callbacks import DefaultCallbacks @@ -15,6 +16,7 @@ # fmt: off # __sphinx_doc_begin__ +@PublicAPI class SampleCollector(metaclass=ABCMeta): """Collects samples for all policies and agents from a multi-agent env. diff --git a/rllib/evaluation/collectors/simple_list_collector.py b/rllib/evaluation/collectors/simple_list_collector.py index 18e6327cf709..b11f3e00118a 100644 --- a/rllib/evaluation/collectors/simple_list_collector.py +++ b/rllib/evaluation/collectors/simple_list_collector.py @@ -12,7 +12,7 @@ from ray.rllib.policy.policy import Policy from ray.rllib.policy.policy_map import PolicyMap from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch -from ray.rllib.utils.annotations import override +from ray.rllib.utils.annotations import override, PublicAPI from ray.rllib.utils.debug import summarize from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.spaces.space_utils import get_dummy_batch_for_space @@ -35,7 +35,7 @@ logger = logging.getLogger(__name__) -def to_float_np_array(v: List[Any]) -> np.ndarray: +def _to_float_np_array(v: List[Any]) -> np.ndarray: if torch and torch.is_tensor(v[0]): raise ValueError arr = np.array(v) @@ -229,7 +229,7 @@ def build(self, view_requirements: ViewRequirementsDict) -> SampleBatch: # np-array for different view_cols using to the same data_col. if data_col not in np_data: np_data[data_col] = [ - to_float_np_array(d) for d in self.buffers[data_col] + _to_float_np_array(d) for d in self.buffers[data_col] ] # Range of indices on time-axis, e.g. "-50:-1". Together with @@ -335,7 +335,7 @@ def build(self, view_requirements: ViewRequirementsDict) -> SampleBatch: # Shift is positive: We still need to 0-pad at the end. elif shift > 0: data = [ - to_float_np_array( + _to_float_np_array( np.concatenate( [ d[self.shift_before + shift :], @@ -519,6 +519,7 @@ def __init__(self, policy_map): self.agent_steps = 0 +@PublicAPI class SimpleListCollector(SampleCollector): """Util to build SampleBatches for each policy in a multi-agent env. diff --git a/rllib/evaluation/metrics.py b/rllib/evaluation/metrics.py index d7e99bc181c3..b90f922f4c91 100644 --- a/rllib/evaluation/metrics.py +++ b/rllib/evaluation/metrics.py @@ -17,22 +17,24 @@ logger = logging.getLogger(__name__) -RolloutMetrics = collections.namedtuple( - "RolloutMetrics", - [ - "episode_length", - "episode_reward", - "agent_rewards", - "custom_metrics", - "perf_stats", - "hist_data", - "media", - ], +RolloutMetrics = DeveloperAPI( + collections.namedtuple( + "RolloutMetrics", + [ + "episode_length", + "episode_reward", + "agent_rewards", + "custom_metrics", + "perf_stats", + "hist_data", + "media", + ], + ) ) RolloutMetrics.__new__.__defaults__ = (0, 0, {}, {}, {}, {}, {}) -def extract_stats(stats: Dict, key: str) -> Dict[str, Any]: +def _extract_stats(stats: Dict, key: str) -> Dict[str, Any]: if key in stats: return stats[key] diff --git a/rllib/evaluation/postprocessing.py b/rllib/evaluation/postprocessing.py index bd7cd61ef853..d501bf299b0b 100644 --- a/rllib/evaluation/postprocessing.py +++ b/rllib/evaluation/postprocessing.py @@ -9,6 +9,7 @@ from ray.rllib.utils.typing import AgentID +@DeveloperAPI class Postprocessing: """Constant definitions for postprocessing.""" @@ -16,6 +17,7 @@ class Postprocessing: VALUE_TARGETS = "value_targets" +@DeveloperAPI def adjust_nstep(n_step: int, gamma: float, batch: SampleBatch) -> None: """Rewrites `batch` to encode n-step rewards, dones, and next-obs. @@ -134,6 +136,7 @@ def compute_advantages( return rollout +@DeveloperAPI def compute_gae_for_sample_batch( policy: Policy, sample_batch: SampleBatch, @@ -191,6 +194,7 @@ def compute_gae_for_sample_batch( return batch +@DeveloperAPI def discount_cumsum(x: np.ndarray, gamma: float) -> np.ndarray: """Calculates the discounted cumulative sum over a reward sequence `x`. diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py index 54610921c459..71659a873183 100644 --- a/rllib/evaluation/rollout_worker.py +++ b/rllib/evaluation/rollout_worker.py @@ -39,6 +39,7 @@ from ray.rllib.policy.policy import Policy, PolicySpec from ray.rllib.policy.policy_map import PolicyMap from ray.rllib.policy.torch_policy import TorchPolicy +from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2 from ray.rllib.utils import force_list, merge_dicts, check_env from ray.rllib.utils.annotations import DeveloperAPI, ExperimentalAPI from ray.rllib.utils.debug import summarize, update_global_seed_if_necessary @@ -48,7 +49,7 @@ from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.sgd import do_minibatch_sgd from ray.rllib.utils.tf_utils import get_gpu_devices as get_tf_gpu_devices -from ray.rllib.utils.tf_run_builder import TFRunBuilder +from ray.rllib.utils.tf_run_builder import _TFRunBuilder from ray.rllib.utils.typing import ( AgentID, EnvConfigDict, @@ -937,7 +938,7 @@ def learn_on_batch(self, samples: SampleBatchType) -> Dict: policy = self.policy_map[pid] tf_session = policy.get_session() if tf_session and hasattr(policy, "_build_learn_on_batch"): - builders[pid] = TFRunBuilder(tf_session, "learn_on_batch") + builders[pid] = _TFRunBuilder(tf_session, "learn_on_batch") to_fetch[pid] = policy._build_learn_on_batch(builders[pid], batch) else: info_out[pid] = policy.learn_on_batch(batch) @@ -1059,7 +1060,7 @@ def compute_gradients( if not self.is_policy_to_train(pid, samples): continue policy = self.policy_map[pid] - builder = TFRunBuilder(policy.get_session(), "compute_gradients") + builder = _TFRunBuilder(policy.get_session(), "compute_gradients") grad_out[pid], info_out[pid] = policy._build_compute_gradients( builder, batch ) @@ -1698,7 +1699,7 @@ def setup_torch_data_parallel( ) for pid, policy in self.policy_map.items(): - if not isinstance(policy, TorchPolicy): + if not isinstance(policy, (TorchPolicy, TorchPolicyV2)): raise ValueError( "This policy does not support torch distributed", policy ) diff --git a/rllib/evaluation/sample_batch_builder.py b/rllib/evaluation/sample_batch_builder.py index 4ba79a2e7feb..63af002423c6 100644 --- a/rllib/evaluation/sample_batch_builder.py +++ b/rllib/evaluation/sample_batch_builder.py @@ -20,7 +20,7 @@ logger = logging.getLogger(__name__) -def to_float_array(v: List[Any]) -> np.ndarray: +def _to_float_array(v: List[Any]) -> np.ndarray: arr = np.array(v) if arr.dtype == np.float64: return arr.astype(np.float32) # save some memory @@ -58,7 +58,7 @@ def add_batch(self, batch: SampleBatch) -> None: def build_and_reset(self) -> SampleBatch: """Returns a sample batch including all previously added values.""" - batch = SampleBatch({k: to_float_array(v) for k, v in self.buffers.items()}) + batch = SampleBatch({k: _to_float_array(v) for k, v in self.buffers.items()}) if SampleBatch.UNROLL_ID not in batch: batch[SampleBatch.UNROLL_ID] = np.repeat( SampleBatchBuilder._next_unroll_id, batch.count diff --git a/rllib/evaluation/sampler.py b/rllib/evaluation/sampler.py index a86491ff513a..fd292a6191b9 100644 --- a/rllib/evaluation/sampler.py +++ b/rllib/evaluation/sampler.py @@ -60,8 +60,8 @@ tf1, tf, _ = try_import_tf() logger = logging.getLogger(__name__) -PolicyEvalData = namedtuple( - "PolicyEvalData", +_PolicyEvalData = namedtuple( + "_PolicyEvalData", ["env_id", "agent_id", "obs", "info", "rnn_state", "prev_action", "prev_reward"], ) @@ -69,7 +69,7 @@ StateBatch = List[List[Any]] -class NewEpisodeDefaultDict(defaultdict): +class _NewEpisodeDefaultDict(defaultdict): def __missing__(self, env_id): if self.default_factory is None: raise KeyError(env_id) @@ -650,7 +650,7 @@ def new_episode(env_id): ) return episode - active_episodes: Dict[EnvID, Episode] = NewEpisodeDefaultDict(new_episode) + active_episodes: Dict[EnvID, Episode] = _NewEpisodeDefaultDict(new_episode) while True: perf_stats.iters += 1 @@ -666,7 +666,7 @@ def new_episode(env_id): # Process observations and prepare for policy evaluation. t1 = time.time() - # types: Set[EnvID], Dict[PolicyID, List[PolicyEvalData]], + # types: Set[EnvID], Dict[PolicyID, List[_PolicyEvalData]], # List[Union[RolloutMetrics, SampleBatchType]] active_envs, to_eval, outputs = _process_observations( worker=worker, @@ -771,7 +771,7 @@ def _process_observations( sample_collector: SampleCollector, ) -> Tuple[ Set[EnvID], - Dict[PolicyID, List[PolicyEvalData]], + Dict[PolicyID, List[_PolicyEvalData]], List[Union[RolloutMetrics, SampleBatchType]], ]: """Record new data from the environment and prepare for policy evaluation. @@ -806,13 +806,13 @@ def _process_observations( Returns: Tuple consisting of 1) active_envs: Set of non-terminated env ids. - 2) to_eval: Map of policy_id to list of agent PolicyEvalData. + 2) to_eval: Map of policy_id to list of agent _PolicyEvalData. 3) outputs: List of metrics and samples to return from the sampler. """ # Output objects. active_envs: Set[EnvID] = set() - to_eval: Dict[PolicyID, List[PolicyEvalData]] = defaultdict(list) + to_eval: Dict[PolicyID, List[_PolicyEvalData]] = defaultdict(list) outputs: List[Union[RolloutMetrics, SampleBatchType]] = [] # For each (vectorized) sub-environment. @@ -959,7 +959,7 @@ def _process_observations( ) if not agent_done: - item = PolicyEvalData( + item = _PolicyEvalData( env_id, agent_id, filtered_obs, @@ -1086,7 +1086,7 @@ def _process_observations( filtered_obs, ) - item = PolicyEvalData( + item = _PolicyEvalData( env_id, agent_id, filtered_obs, @@ -1110,7 +1110,7 @@ def _process_observations( def _do_policy_eval( *, - to_eval: Dict[PolicyID, List[PolicyEvalData]], + to_eval: Dict[PolicyID, List[_PolicyEvalData]], policies: PolicyMap, sample_collector: SampleCollector, active_episodes: Dict[EnvID, Episode], @@ -1118,7 +1118,7 @@ def _do_policy_eval( """Call compute_actions on collected episode/model data to get next action. Args: - to_eval: Mapping of policy IDs to lists of PolicyEvalData objects + to_eval: Mapping of policy IDs to lists of _PolicyEvalData objects (items in these lists will be the batch's items for the model forward pass). policies: Mapping from policy ID to Policy obj. @@ -1167,7 +1167,7 @@ def _do_policy_eval( def _process_policy_eval_results( *, - to_eval: Dict[PolicyID, List[PolicyEvalData]], + to_eval: Dict[PolicyID, List[_PolicyEvalData]], eval_results: Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]], active_episodes: Dict[EnvID, Episode], active_envs: Set[int], @@ -1182,7 +1182,7 @@ def _process_policy_eval_results( returns replies to send back to agents in the env. Args: - to_eval: Mapping of policy IDs to lists of PolicyEvalData objects. + to_eval: Mapping of policy IDs to lists of _PolicyEvalData objects. eval_results: Mapping of policy IDs to list of actions, rnn-out states, extra-action-fetches dicts. active_episodes: Mapping from episode ID to currently ongoing @@ -1206,7 +1206,7 @@ def _process_policy_eval_results( for env_id in active_envs: actions_to_send[env_id] = {} # at minimum send empty dict - # types: PolicyID, List[PolicyEvalData] + # types: PolicyID, List[_PolicyEvalData] for policy_id, eval_data in to_eval.items(): actions: TensorStructType = eval_results[policy_id][0] actions = convert_to_numpy(actions) diff --git a/rllib/evaluation/worker_set.py b/rllib/evaluation/worker_set.py index 5403e1ddfce1..14a035ac86b7 100644 --- a/rllib/evaluation/worker_set.py +++ b/rllib/evaluation/worker_set.py @@ -265,13 +265,14 @@ def reset(self, new_remote_workers: List[ActorHandle]) -> None: def remove_failed_workers(self): faulty_indices = self._worker_health_check() - + removed_workers = [] # Terminate faulty workers. for worker_index in faulty_indices: worker = self.remote_workers()[worker_index - 1] logger.info(f"Trying to terminate faulty worker {worker_index}.") try: worker.__ray_terminate__.remote() + removed_workers.append(worker) except Exception: logger.exception("Error terminating faulty worker.") @@ -286,12 +287,15 @@ def remove_failed_workers(self): f"No healthy workers remaining (worker indices {faulty_indices} have " f"died)! Can't continue training." ) + return removed_workers - def recreate_failed_workers(self): + def recreate_failed_workers(self) -> Tuple[List[ActorHandle], List[ActorHandle]]: faulty_indices = self._worker_health_check() - + removed_workers = [] + new_workers = [] for worker_index in faulty_indices: worker = self.remote_workers()[worker_index - 1] + removed_workers.append(worker) logger.info(f"Trying to recreate faulty worker {worker_index}") try: worker.__ray_terminate__.remote() @@ -315,6 +319,8 @@ def recreate_failed_workers(self): ) # Add new worker to list of remote workers. self._remote_workers[worker_index - 1] = new_worker + new_workers.append(new_worker) + return removed_workers, new_workers def stop(self) -> None: """Calls `stop` on all rollout workers (including the local one).""" diff --git a/rllib/examples/centralized_critic.py b/rllib/examples/centralized_critic.py index 7df45ec42516..273b539d3a08 100644 --- a/rllib/examples/centralized_critic.py +++ b/rllib/examples/centralized_critic.py @@ -20,14 +20,10 @@ import ray from ray import tune -from ray.rllib.algorithms.maml.maml_torch_policy import ( - KLCoeffMixin as TorchKLCoeffMixin, -) from ray.rllib.agents.ppo.ppo import PPOTrainer from ray.rllib.agents.ppo.ppo_tf_policy import ( - PPOTFPolicy, - KLCoeffMixin, - ppo_surrogate_loss as tf_loss, + PPOStaticGraphTFPolicy, + PPOEagerTFPolicy, ) from ray.rllib.agents.ppo.ppo_torch_policy import PPOTorchPolicy from ray.rllib.evaluation.postprocessing import compute_advantages, Postprocessing @@ -38,13 +34,9 @@ ) from ray.rllib.models import ModelCatalog from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.policy.tf_mixins import LearningRateSchedule, EntropyCoeffSchedule -from ray.rllib.policy.torch_mixins import ( - LearningRateSchedule as TorchLR, - EntropyCoeffSchedule as TorchEntropyCoeffSchedule, -) from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_tf, try_import_torch +from ray.rllib.utils.numpy import convert_to_numpy from ray.rllib.utils.test_utils import check_learning_achieved from ray.rllib.utils.tf_utils import explained_variance, make_tf_callable from ray.rllib.utils.torch_utils import convert_to_torch_tensor @@ -124,10 +116,12 @@ def centralized_critic_postprocessing( .numpy() ) else: - sample_batch[SampleBatch.VF_PREDS] = policy.compute_central_vf( - sample_batch[SampleBatch.CUR_OBS], - sample_batch[OPPONENT_OBS], - sample_batch[OPPONENT_ACTION], + sample_batch[SampleBatch.VF_PREDS] = convert_to_numpy( + policy.compute_central_vf( + sample_batch[SampleBatch.CUR_OBS], + sample_batch[OPPONENT_OBS], + sample_batch[OPPONENT_ACTION], + ) ) else: # Policy hasn't been initialized yet, use zeros. @@ -154,44 +148,26 @@ def centralized_critic_postprocessing( # Copied from PPO but optimizing the central value function. -def loss_with_central_critic(policy, model, dist_class, train_batch): - CentralizedValueMixin.__init__(policy) - func = tf_loss if not policy.config["framework"] == "torch" else PPOTorchPolicy.loss - +def loss_with_central_critic(policy, base_policy, model, dist_class, train_batch): + # Save original value function. vf_saved = model.value_function + + # Calculate loss with a custom value function. model.value_function = lambda: policy.model.central_value_function( train_batch[SampleBatch.CUR_OBS], train_batch[OPPONENT_OBS], train_batch[OPPONENT_ACTION], ) - policy._central_value_out = model.value_function() - loss = func(policy, model, dist_class, train_batch) + loss = base_policy.loss(model, dist_class, train_batch) + # Restore original value function. model.value_function = vf_saved return loss -def setup_tf_mixins(policy, obs_space, action_space, config): - # Copied from PPOTFPolicy (w/o ValueNetworkMixin). - KLCoeffMixin.__init__(policy, config) - EntropyCoeffSchedule.__init__( - policy, config["entropy_coeff"], config["entropy_coeff_schedule"] - ) - LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"]) - - -def setup_torch_mixins(policy, obs_space, action_space, config): - # Copied from PPOTorchPolicy (w/o ValueNetworkMixin). - TorchKLCoeffMixin.__init__(policy, config) - TorchEntropyCoeffSchedule.__init__( - policy, config["entropy_coeff"], config["entropy_coeff_schedule"] - ) - TorchLR.__init__(policy, config["lr"], config["lr_schedule"]) - - -def central_vf_stats(policy, train_batch, grads): +def central_vf_stats(policy, train_batch): # Report the explained variance of the central value function. return { "vf_explained_var": explained_variance( @@ -200,29 +176,51 @@ def central_vf_stats(policy, train_batch, grads): } -CCPPOTFPolicy = PPOTFPolicy.with_updates( - name="CCPPOTFPolicy", - postprocess_fn=centralized_critic_postprocessing, - loss_fn=loss_with_central_critic, - before_loss_init=setup_tf_mixins, - grad_stats_fn=central_vf_stats, - mixins=[ - LearningRateSchedule, - EntropyCoeffSchedule, - KLCoeffMixin, - CentralizedValueMixin, - ], -) +def get_ccppo_policy(base): + class CCPPOTFPolicy(CentralizedValueMixin, base): + def __init__(self, observation_space, action_space, config): + base.__init__(self, observation_space, action_space, config) + CentralizedValueMixin.__init__(self) + + @override(base) + def loss(self, model, dist_class, train_batch): + # Use super() to get to the base PPO policy. + # This special loss function utilizes a shared + # value function defined on self, and the loss function + # defined on PPO policies. + return loss_with_central_critic( + self, super(), model, dist_class, train_batch + ) + + @override(base) + def postprocess_trajectory( + self, sample_batch, other_agent_batches=None, episode=None + ): + return centralized_critic_postprocessing( + self, sample_batch, other_agent_batches, episode + ) + + @override(base) + def stats_fn(self, train_batch: SampleBatch): + stats = super().stats_fn(train_batch) + stats.update(central_vf_stats(self, train_batch)) + return stats + + return CCPPOTFPolicy + + +CCPPOStaticGraphTFPolicy = get_ccppo_policy(PPOStaticGraphTFPolicy) +CCPPOEagerTFPolicy = get_ccppo_policy(PPOEagerTFPolicy) -class CCPPOTorchPolicy(PPOTorchPolicy): +class CCPPOTorchPolicy(CentralizedValueMixin, PPOTorchPolicy): def __init__(self, observation_space, action_space, config): - super().__init__(observation_space, action_space, config) - self.compute_central_vf = self.model.central_value_function + PPOTorchPolicy.__init__(self, observation_space, action_space, config) + CentralizedValueMixin.__init__(self) @override(PPOTorchPolicy) def loss(self, model, dist_class, train_batch): - return loss_with_central_critic(self, model, dist_class, train_batch) + return loss_with_central_critic(self, super(), model, dist_class, train_batch) @override(PPOTorchPolicy) def postprocess_trajectory( @@ -238,8 +236,10 @@ class CCTrainer(PPOTrainer): def get_default_policy_class(self, config): if config["framework"] == "torch": return CCPPOTorchPolicy + elif config["framework"] == "tf": + return CCPPOStaticGraphTFPolicy else: - return CCPPOTFPolicy + return CCPPOEagerTFPolicy if __name__ == "__main__": diff --git a/rllib/examples/multi_agent_two_trainers.py b/rllib/examples/multi_agent_two_trainers.py index d4e63885ed55..0a6a6dd0817b 100644 --- a/rllib/examples/multi_agent_two_trainers.py +++ b/rllib/examples/multi_agent_two_trainers.py @@ -14,7 +14,12 @@ import ray from ray.rllib.algorithms.dqn import DQNTrainer, DQNTFPolicy, DQNTorchPolicy -from ray.rllib.agents.ppo import PPOTrainer, PPOTFPolicy, PPOTorchPolicy +from ray.rllib.agents.ppo import ( + PPOTrainer, + PPOStaticGraphTFPolicy, + PPOEagerTFPolicy, + PPOTorchPolicy, +) from ray.rllib.examples.env.multi_agent import MultiAgentCartPole from ray.tune.logger import pretty_print from ray.tune.registry import register_env @@ -56,17 +61,33 @@ obs_space = single_dummy_env.observation_space act_space = single_dummy_env.action_space + def seelct_policy(algorithm, framework): + if algorithm == "PPO": + if framework == "torch": + return PPOTorchPolicy + elif framework == "tf": + return PPOStaticGraphTFPolicy + else: + return PPOEagerTFPolicy + elif algorithm == "DQN": + if framework == "torch": + return DQNTorchPolicy + else: + return DQNTFPolicy + else: + raise ValueError("Unknown algorithm: ", algorithm) + # You can also have multiple policies per trainer, but here we just # show one each for PPO and DQN. policies = { "ppo_policy": ( - PPOTorchPolicy if args.framework == "torch" else PPOTFPolicy, + seelct_policy("PPO", args.framework), obs_space, act_space, {}, ), "dqn_policy": ( - DQNTorchPolicy if args.framework == "torch" else DQNTFPolicy, + seelct_policy("DQN", args.framework), obs_space, act_space, {}, diff --git a/rllib/examples/offline_rl.py b/rllib/examples/offline_rl.py index 0052dd06bce1..76fc635c129a 100644 --- a/rllib/examples/offline_rl.py +++ b/rllib/examples/offline_rl.py @@ -118,8 +118,10 @@ # Example on how to do evaluation on the trained Trainer # using the data from our buffer. - # Get a sample (MultiAgentBatch -> SampleBatch). - batch = replay_buffer.replay().policy_batches["default_policy"] + # Get a sample (MultiAgentBatch). + multi_agent_batch = replay_buffer.sample(num_items=config["train_batch_size"]) + # All experiences have been buffered for `default_policy` + batch = multi_agent_batch.policy_batches["default_policy"] obs = torch.from_numpy(batch["obs"]) # Pass the observations through our model to get the # features, which then to pass through the Q-head. diff --git a/rllib/examples/rnnsac_stateless_cartpole.py b/rllib/examples/rnnsac_stateless_cartpole.py index 0293124f2e94..8930d5625fa5 100644 --- a/rllib/examples/rnnsac_stateless_cartpole.py +++ b/rllib/examples/rnnsac_stateless_cartpole.py @@ -24,17 +24,19 @@ "mode": "max", "verbose": 2, "config": { - "seed": 42, "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")), "framework": "torch", "num_workers": 4, - # "env": RepeatAfterMeEnv, + "num_envs_per_worker": 1, + "num_cpus_per_worker": 1, + "log_level": "INFO", "env": StatelessCartPole, "horizon": 1000, "gamma": 0.95, "batch_mode": "complete_episodes", "replay_buffer_config": { "type": "MultiAgentReplayBuffer", + "storage_unit": "sequences", "capacity": 100000, "learning_starts": 1000, "replay_burn_in": 4, diff --git a/rllib/examples/sumo_env_local.py b/rllib/examples/sumo_env_local.py index 6afedddaca1b..4b195c7de6fa 100644 --- a/rllib/examples/sumo_env_local.py +++ b/rllib/examples/sumo_env_local.py @@ -78,7 +78,7 @@ tune.register_env("sumo_test_env", marlenvironment.env_creator) # Algorithm. - policy_class = ppo.PPOTFPolicy + policy_class = ppo.PPOStaticGraphTFPolicy config = ppo.DEFAULT_CONFIG config["framework"] = "tf" config["gamma"] = 0.99 diff --git a/rllib/examples/two_trainer_workflow.py b/rllib/examples/two_trainer_workflow.py index eed2fc1166b4..2cbfc6d1bdef 100644 --- a/rllib/examples/two_trainer_workflow.py +++ b/rllib/examples/two_trainer_workflow.py @@ -16,7 +16,7 @@ from ray.rllib.algorithms.dqn.dqn_tf_policy import DQNTFPolicy from ray.rllib.algorithms.dqn.dqn_torch_policy import DQNTorchPolicy from ray.rllib.agents.ppo.ppo import DEFAULT_CONFIG as PPO_CONFIG -from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy +from ray.rllib.agents.ppo.ppo_tf_policy import PPOStaticGraphTFPolicy from ray.rllib.agents.ppo.ppo_torch_policy import PPOTorchPolicy from ray.rllib.evaluation.postprocessing import Postprocessing from ray.rllib.execution.rollout_ops import synchronous_parallel_sample @@ -168,20 +168,29 @@ def training_iteration(self) -> ResultDict: "multi_agent_cartpole", lambda _: MultiAgentCartPole({"num_agents": 4}) ) + # framework can be changed, so removed the hardcoded framework key + # from policy configs. + ppo_config = PPO_CONFIG + del ppo_config["framework"] + dqn_config = DQN_CONFIG + del dqn_config["framework"] + # Note that since the trainer below does not include a default policy or # policy configs, we have to explicitly set it in the multiagent config: policies = { "ppo_policy": ( - PPOTorchPolicy if args.torch or args.mixed_torch_tf else PPOTFPolicy, + PPOTorchPolicy + if args.torch or args.mixed_torch_tf + else PPOStaticGraphTFPolicy, None, None, - PPO_CONFIG, + ppo_config, ), "dqn_policy": ( DQNTorchPolicy if args.torch else DQNTFPolicy, None, None, - DQN_CONFIG, + dqn_config, ), } diff --git a/rllib/execution/buffers/multi_agent_replay_buffer.py b/rllib/execution/buffers/multi_agent_replay_buffer.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/rllib/execution/parallel_requests.py b/rllib/execution/parallel_requests.py index f609001945e9..e8ca6aa5398d 100644 --- a/rllib/execution/parallel_requests.py +++ b/rllib/execution/parallel_requests.py @@ -1,6 +1,6 @@ import logging from collections import defaultdict -from typing import Any, Callable, DefaultDict, Dict, List, Optional, Set +from typing import Any, Callable, DefaultDict, Dict, List, Optional, Set, Union import ray from ray.actor import ActorHandle @@ -14,7 +14,7 @@ def asynchronous_parallel_requests( remote_requests_in_flight: DefaultDict[ActorHandle, Set[ray.ObjectRef]], actors: List[ActorHandle], ray_wait_timeout_s: Optional[float] = None, - max_remote_requests_in_flight_per_actor: int = 2, + max_remote_requests_in_flight_per_worker: int = 2, remote_fn: Optional[ Callable[[Any, Optional[Any], Optional[Any]], Any] ] = lambda actor: actor.sample(), @@ -27,7 +27,7 @@ def asynchronous_parallel_requests( May use a timeout (if provided) on `ray.wait()` and returns only those samples that could be gathered in the timeout window. Allows a maximum - of `max_remote_requests_in_flight_per_actor` remote calls to be in-flight + of `max_remote_requests_in_flight_per_worker` remote calls to be in-flight per remote actor. Alternatively to calling `actor.sample.remote()`, the user can provide a @@ -43,7 +43,7 @@ def asynchronous_parallel_requests( ray_wait_timeout_s: Timeout (in sec) to be used for the underlying `ray.wait()` calls. If None (default), never time out (block until at least one actor returns something). - max_remote_requests_in_flight_per_actor: Maximum number of remote + max_remote_requests_in_flight_per_worker: Maximum number of remote requests sent to each actor. 2 (default) is probably sufficient to avoid idle times between two requests. remote_fn: If provided, use `actor.apply.remote(remote_fn)` instead of @@ -105,12 +105,12 @@ def asynchronous_parallel_requests( remote_to_actor[r] = actor # Add new requests, if possible (if - # `max_remote_requests_in_flight_per_actor` setting allows it). + # `max_remote_requests_in_flight_per_worker` setting allows it). for actor_idx, actor in enumerate(actors): # Still room for another request to this actor. if ( len(remote_requests_in_flight[actor]) - < max_remote_requests_in_flight_per_actor + < max_remote_requests_in_flight_per_worker ): if remote_fn is not None: args = remote_args[actor_idx] if remote_args else [] @@ -118,7 +118,7 @@ def asynchronous_parallel_requests( for _ in range(num_requests_to_launch): if ( len(remote_requests_in_flight[actor]) - >= max_remote_requests_in_flight_per_actor + >= max_remote_requests_in_flight_per_worker ): break req = actor.apply.remote(remote_fn, *args, **kwargs) @@ -178,7 +178,235 @@ def wait_asynchronous_requests( remote_requests_in_flight=remote_requests_in_flight, actors=list(remote_requests_in_flight.keys()), ray_wait_timeout_s=ray_wait_timeout_s, - max_remote_requests_in_flight_per_actor=float("inf"), + max_remote_requests_in_flight_per_worker=float("inf"), remote_fn=None, ) return ready_requests + + +class AsyncRequestsManager: + """A manager for asynchronous requests to actors. + + Args: + workers: A list of ray remote workers to operate on. These workers must have an + `apply` method which takes a function and a list of arguments to that + function. + max_remote_requests_in_flight_per_worker: The maximum number of remote + requests that can be in flight per actor. Any requests made to the pool + that cannot be scheduled because the + max_remote_requests_in_flight_per_worker per actor has been reached will + be queued. + ray_wait_timeout_s: The maximum amount of time to wait for inflight requests + to be done and ready when calling + AsyncRequestsManager.get_ready_results(). + + Example: + >>> import time + >>> import ray + >>> from ray.rllib.execution.parallel_requests_manager import ( + ... AsyncRequestsManager) + >>> + >>> @ray.remote + ... class MyActor: + ... def apply(self, fn, *args: List[Any], **kwargs: Dict[str, Any]) -> Any: + ... return fn(*args, **kwargs) + ... + ... def task(self, a: int, b: int) -> Any: + ... time.sleep(0.5) + ... return a + b + >>> + >>> workers = [MyActor.remote() for _ in range(3)] + >>> manager = AsyncRequestsManager(workers, + ... max_remote_requests_in_flight_per_worker=2) + >>> manager.call(lambda worker, a, b: worker.task(a, b), fn_args=[1, 2]) + >>> print(manager.get_ready()) + >>> manager.call(lambda worker, a, b: worker.task(a, b), + ... fn_kwargs={"a": 1, "b": 2}) + >>> time.sleep(2) # Wait for the tasks to finish. + >>> print(manager.get_ready()) + """ + + def __init__( + self, + workers: List[ActorHandle], + max_remote_requests_in_flight_per_worker: int = 2, + ray_wait_timeout_s: Optional[float] = 0.0, + return_object_refs: bool = False, + ): + self._ray_wait_timeout_s = ray_wait_timeout_s + self._return_object_refs = return_object_refs + self._max_remote_requests_in_flight = max_remote_requests_in_flight_per_worker + self._pending_to_actor = {} + self._pending_remotes = [] + self._remote_requests_in_flight = defaultdict(set) + + self._all_workers = ( + list(workers) if not isinstance(workers, list) else workers.copy() + ) + self._curr_actor_ptr = 0 + + def call( + self, + remote_fn: Callable, + *, + actor: ActorHandle = None, + fn_args: List[Any] = None, + fn_kwargs: Dict[str, Any] = None, + ) -> bool: + """Call a remote function on an available worker or on actor + if actor is specified. + + Args: + remote_fn: The remote function to call + actor: The actor to call the remote function on. + fn_args: The arguments to pass to the remote function + fn_kwargs: The keyword arguments to pass to the remote function + Raises: + ValueError: If actor has not been added to the manager. + ValueError: If there are no actors available to submit a request to. + + Returns: + True if the remoted_fn was scheduled on an actor. False if it was unable + to be scheduled. + """ + if actor and actor not in self._all_workers: + raise ValueError( + f"Actor {actor} has not been added to the manager." + f" You must call manager.add_worker(actor) first " + f"before submitting requests to actor." + ) + if fn_args is None: + fn_args = [] + if fn_kwargs is None: + fn_kwargs = {} + + def actor_available(a): + return ( + len(self._remote_requests_in_flight[a]) + < self._max_remote_requests_in_flight + ) + + num_workers = len(self._all_workers) + + if not actor: # If no actor is specified, use a random actor. + for _ in range(num_workers): + if actor_available(self._all_workers[self._curr_actor_ptr]): + actor = self._all_workers[self._curr_actor_ptr] + self._curr_actor_ptr = (self._curr_actor_ptr + 1) % num_workers + break + self._curr_actor_ptr = (self._curr_actor_ptr + 1) % num_workers + if not actor: # No actors available to schedule the request on. + return False + else: + if not actor_available(actor): + return False + req = actor.apply.remote(remote_fn, *fn_args, **fn_kwargs) + self._remote_requests_in_flight[actor].add(req) + self._pending_to_actor[req] = actor + self._pending_remotes.append(req) + return True + + def call_on_all_available( + self, + remote_fn: Callable, + *, + fn_args: List[Any] = None, + fn_kwargs: Dict[str, Any] = None, + ) -> int: + """ "Call remote_fn on all available workers + + Args: + remote_fn: The remote function to call + fn_args: The arguments to pass to the remote function + fn_kwargs: The keyword arguments to pass to the remote function + + Returns: + The number of remote calls of remote_fn that were able to be launched. + """ + num_launched = 0 + for worker in self._all_workers: + launched = self.call( + remote_fn, actor=worker, fn_args=fn_args, fn_kwargs=fn_kwargs + ) + num_launched += int(launched) + return num_launched + + def get_ready(self) -> Dict[ActorHandle, List[Any]]: + """Get results that are ready to be returned + + Returns: + A dictionary of actor handles to lists of returns from tasks that were + previously submitted to this actor pool that are now ready to be returned. + If return_object_refs + + """ + ready_requests_dict = defaultdict(list) + ready_requests, self._pending_remotes = ray.wait( + self._pending_remotes, + timeout=self._ray_wait_timeout_s, + num_returns=len(self._pending_remotes), + ) + if not self._return_object_refs: + objs = ray.get(ready_requests) + else: + objs = ready_requests + for req, obj in zip(ready_requests, objs): + actor = self._pending_to_actor[req] + self._remote_requests_in_flight[actor].remove(req) + ready_requests_dict[actor].append(obj) + del self._pending_to_actor[req] + del ready_requests + return dict(ready_requests_dict) + + def add_workers(self, new_workers: Union[List[ActorHandle], ActorHandle]) -> None: + """Add a new worker to the manager + + Args: + new_workers: The actors to add + + """ + if isinstance(new_workers, ActorHandle): + new_workers = [new_workers] + for new_worker in new_workers: + if new_worker not in self._all_workers: + self._all_workers.append(new_worker) + + def remove_workers(self, workers: Union[List[ActorHandle], ActorHandle]) -> None: + """Make workers unschedulable and remove them from this manager. + + Note: + This will not stop their inflight requests. ray.kill can be used to kill + the workers and their inflight requests. + + Args: + workers: The actors to remove + """ + if isinstance(workers, ActorHandle): + workers = [workers] + workers_to_remove = set(workers) + self._all_workers[:] = [ + el for el in self._all_workers if el not in workers_to_remove + ] + if self._all_workers and (self._curr_actor_ptr >= len(self._all_workers)): + # Move current pointer to the new tail of the list. + self._curr_actor_ptr = len(self._all_workers) - 1 + elif not self._all_workers: + self._curr_actor_ptr = 0 + + def get_manager_statistics(self) -> Dict[str, Any]: + """Get statistics about the the manager + + Some of the statistics include the number of actors that are available, + the number of pending inflight requests, and the number of pending requests + to be scheduled on the available actors. + + Returns: + A dictionary of statistics about the manager. + """ + return { + "num_pending_inflight_requests": len(self._pending_remotes), + } + + @property + def workers(self): + return frozenset(self._all_workers) diff --git a/rllib/execution/replay_ops.py b/rllib/execution/replay_ops.py index 8f2c8965603b..b24a17ad61c2 100644 --- a/rllib/execution/replay_ops.py +++ b/rllib/execution/replay_ops.py @@ -71,6 +71,7 @@ def __call__(self, batch: SampleBatchType): def Replay( *, local_buffer: Optional[MultiAgentReplayBuffer] = None, + num_items_to_replay: int = 1, actors: Optional[List[ActorHandle]] = None, num_async: int = 4, ) -> LocalIterator[SampleBatchType]: @@ -82,6 +83,7 @@ def Replay( Args: local_buffer: Local buffer to use. Only one of this and replay_actors can be specified. + num_items_to_replay: Number of items to sample from buffer actors: List of replay actors. Only one of this and local_buffer can be specified. num_async: In async mode, the max number of async requests in flight @@ -91,7 +93,8 @@ def Replay( >>> from ray.rllib.utils.replay_buffers import multi_agent_replay_buffer >>> actors = [ # doctest: +SKIP ... multi_agent_replay_buffer.ReplayActor.remote() for _ in range(4)] - >>> replay_op = Replay(actors=actors) # doctest: +SKIP + >>> replay_op = Replay(actors=actors, # doctest: +SKIP + ... num_items_to_replay=batch_size) >>> next(replay_op) # doctest: +SKIP SampleBatch(...) """ @@ -100,12 +103,14 @@ def Replay( raise ValueError("Exactly one of local_buffer and replay_actors must be given.") if actors is not None: + for actor in actors: + actor.make_iterator.remote(num_items_to_replay=num_items_to_replay) replay = from_actors(actors) return replay.gather_async(num_async=num_async).filter(lambda x: x is not None) def gen_replay(_): while True: - item = local_buffer.replay() + item = local_buffer.sample(num_items_to_replay) if item is None: yield _NextValueNotReady() else: diff --git a/rllib/execution/tests/test_async_requests_manager.py b/rllib/execution/tests/test_async_requests_manager.py new file mode 100644 index 000000000000..dae92ae60fb1 --- /dev/null +++ b/rllib/execution/tests/test_async_requests_manager.py @@ -0,0 +1,227 @@ +import random +import pytest +import unittest + +import ray +import time + +from ray.rllib.execution.parallel_requests import AsyncRequestsManager + + +@ray.remote +class RemoteRLlibActor: + def __init__(self, sleep_time): + self.sleep_time = sleep_time + + def apply(self, func, *_args, **_kwargs): + return func(self, *_args, **_kwargs) + + def task(self): + time.sleep(self.sleep_time) + return "done" + + def task2(self, a, b): + time.sleep(self.sleep_time) + return a + b + + +class TestAsyncRequestsManager(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + ray.init(num_cpus=4) + random.seed(0) + + @classmethod + def tearDownClass(cls) -> None: + ray.shutdown() + + @classmethod + def shutdown_method(cls): + ray.shutdown() + + def test_async_requests_manager_num_returns(self): + """Tests that an async manager can properly handle actors with tasks that + vary in the amount of time that they take to run""" + workers = [RemoteRLlibActor.remote(sleep_time=0.1) for _ in range(2)] + workers += [RemoteRLlibActor.remote(sleep_time=5) for _ in range(2)] + manager = AsyncRequestsManager( + workers, max_remote_requests_in_flight_per_worker=1 + ) + for _ in range(4): + manager.call(lambda w: w.task()) + time.sleep(3) + if not len(manager.get_ready()) == 2: + raise Exception( + "We should return the 2 ready requests in this case from the actors" + " that have shorter tasks" + ) + time.sleep(7) + if not len(manager.get_ready()) == 2: + raise Exception( + "We should return the 2 ready requests in this case from the actors" + " that have longer tasks" + ) + + def test_round_robin_scheduling(self): + """Test that the async manager schedules actors in a round robin fashion""" + workers = [RemoteRLlibActor.remote(sleep_time=0.1) for _ in range(2)] + manager = AsyncRequestsManager( + workers, max_remote_requests_in_flight_per_worker=2 + ) + for i in range(4): + scheduled_actor = workers[i % len(workers)] + manager.call(lambda w: w.task()) + if i < 2: + assert len(manager._remote_requests_in_flight[scheduled_actor]) == 1, ( + "We should have 1 request in flight for the actor that we just " + "scheduled on" + ) + else: + assert len(manager._remote_requests_in_flight[scheduled_actor]) == 2, ( + "We should have 2 request in flight for the actor that we just " + "scheduled on" + ) + + def test_test_async_requests_task_doesnt_buffering(self): + """Tests that the async manager drops""" + workers = [RemoteRLlibActor.remote(sleep_time=0.1) for _ in range(2)] + manager = AsyncRequestsManager( + workers, max_remote_requests_in_flight_per_worker=2 + ) + for i in range(8): + scheduled = manager.call(lambda w: w.task()) + if i < 4: + assert scheduled, "We should have scheduled the task" + else: + assert not scheduled, ( + "We should not have scheduled the task because" + " all workers are busy." + ) + assert len(manager._pending_remotes) == 4, "We should have 4 pending requests" + time.sleep(3) + ready_requests = manager.get_ready() + for worker in workers: + if not len(ready_requests[worker]) == 2: + raise Exception( + "We should return the 2 ready requests in this case from each " + "actors." + ) + for _ in range(4): + manager.call(lambda w: w.task()) + # new tasks scheduled from the buffer + time.sleep(3) + ready_requests = manager.get_ready() + for worker in workers: + if not len(ready_requests[worker]) == 2: + raise Exception( + "We should return the 2 ready requests in this case from each " + "actors" + ) + + def test_args_kwargs(self): + """Tests that the async manager can properly handle actors with tasks that + vary in the amount of time that they take to run""" + workers = [RemoteRLlibActor.remote(sleep_time=0.1)] + manager = AsyncRequestsManager( + workers, max_remote_requests_in_flight_per_worker=2 + ) + for _ in range(2): + manager.call(lambda w, a, b: w.task2(a, b), fn_args=[1, 2]) + time.sleep(3) + if not len(manager.get_ready()[workers[0]]) == 2: + raise Exception( + "We should return the 2 ready requests in this case from the actors" + " that have shorter tasks" + ) + for _ in range(2): + manager.call(lambda w, a, b: w.task2(a, b), fn_kwargs=dict(a=1, b=2)) + time.sleep(3) + if not len(manager.get_ready()[workers[0]]) == 2: + raise Exception( + "We should return the 2 ready requests in this case from the actors" + " that have longer tasks" + ) + + def test_add_remove_actors(self): + """Tests that the async manager can properly add and remove actors""" + + workers = [] + manager = AsyncRequestsManager( + workers, max_remote_requests_in_flight_per_worker=2 + ) + if not ( + ( + len(manager._all_workers) + == len(manager._remote_requests_in_flight) + == len(manager._pending_to_actor) + == len(manager._pending_remotes) + == 0 + ) + ): + raise ValueError("We should have no workers in this case.") + + assert not manager.call(lambda w: w.task()), ( + "Task shouldn't have been " + "launched since there are no " + "workers in the manager." + ) + worker = RemoteRLlibActor.remote(sleep_time=0.1) + manager.add_workers(worker) + manager.call(lambda w: w.task()) + if not ( + len(manager._remote_requests_in_flight[worker]) + == len(manager._pending_to_actor) + == len(manager._all_workers) + == len(manager._pending_remotes) + == 1 + ): + raise ValueError("We should have 1 worker and 1 pending request") + time.sleep(3) + manager.get_ready() + # test worker removal + for i in range(2): + manager.call(lambda w: w.task()) + assert len(manager._pending_remotes) == i + 1 + manager.remove_workers(worker) + if not ((len(manager._all_workers) == 0)): + raise ValueError("We should have no workers that we can schedule tasks to") + if not ( + (len(manager._pending_remotes) == 2 and len(manager._pending_to_actor) == 2) + ): + raise ValueError( + "We should still have 2 pending requests in flight from the worker" + ) + time.sleep(3) + result = manager.get_ready() + if not ( + len(result) == 1 + and len(result[worker]) == 2 + and len(manager._pending_remotes) == 0 + and len(manager._pending_to_actor) == 0 + ): + raise ValueError( + "We should have 2 ready results from the worker and no pending requests" + ) + + def test_call_to_actor(self): + workers = [RemoteRLlibActor.remote(sleep_time=0.1) for _ in range(2)] + worker_not_in_manager = RemoteRLlibActor.remote(sleep_time=0.1) + manager = AsyncRequestsManager( + workers, max_remote_requests_in_flight_per_worker=2 + ) + manager.call(lambda w: w.task(), actor=workers[0]) + time.sleep(3) + results = manager.get_ready() + if not len(results) == 1 and workers[0] not in results: + raise Exception( + "We should return the 1 ready requests in this case from the worker we " + "called to" + ) + with pytest.raises(ValueError, match=".*has not been added to the manager.*"): + manager.call(lambda w: w.task(), actor=worker_not_in_manager) + + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/models/preprocessors.py b/rllib/models/preprocessors.py index c49f5bf8d75c..1a1fd5dc15ac 100644 --- a/rllib/models/preprocessors.py +++ b/rllib/models/preprocessors.py @@ -4,7 +4,7 @@ import gym from typing import Any, List -from ray.rllib.utils.annotations import override, PublicAPI +from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI from ray.rllib.utils.spaces.repeated import Repeated from ray.rllib.utils.typing import TensorType from ray.rllib.utils.images import resize @@ -30,7 +30,7 @@ class Preprocessor: @PublicAPI def __init__(self, obs_space: gym.Space, options: dict = None): - legacy_patch_shapes(obs_space) + _legacy_patch_shapes(obs_space) self._obs_space = obs_space if not options: from ray.rllib.models.catalog import MODEL_DEFAULTS @@ -109,6 +109,7 @@ def observation_space(self) -> gym.Space: return obs_space +@DeveloperAPI class GenericPixelPreprocessor(Preprocessor): """Generic image preprocessor. @@ -151,6 +152,7 @@ def transform(self, observation: TensorType) -> np.ndarray: return scaled +@DeveloperAPI class AtariRamPreprocessor(Preprocessor): @override(Preprocessor) def _init_shape(self, obs_space: gym.Space, options: dict) -> List[int]: @@ -162,6 +164,7 @@ def transform(self, observation: TensorType) -> np.ndarray: return (observation.astype("float32") - 128) / 128 +@DeveloperAPI class OneHotPreprocessor(Preprocessor): """One-hot preprocessor for Discrete and MultiDiscrete spaces. @@ -195,6 +198,7 @@ def write(self, observation: TensorType, array: np.ndarray, offset: int) -> None array[offset : offset + self.size] = self.transform(observation) +@PublicAPI class NoPreprocessor(Preprocessor): @override(Preprocessor) def _init_shape(self, obs_space: gym.Space, options: dict) -> List[int]: @@ -215,6 +219,7 @@ def observation_space(self) -> gym.Space: return self._obs_space +@DeveloperAPI class TupleFlatteningPreprocessor(Preprocessor): """Preprocesses each tuple element, then flattens it all into a vector. @@ -254,6 +259,7 @@ def write(self, observation: TensorType, array: np.ndarray, offset: int) -> None offset += p.size +@DeveloperAPI class DictFlatteningPreprocessor(Preprocessor): """Preprocesses each dict value, then flattens it all into a vector. @@ -297,6 +303,7 @@ def write(self, observation: TensorType, array: np.ndarray, offset: int) -> None offset += p.size +@DeveloperAPI class RepeatedValuesPreprocessor(Preprocessor): """Pads and batches the variable-length list value.""" @@ -345,7 +352,7 @@ def write(self, observation: TensorType, array: np.ndarray, offset: int) -> None def get_preprocessor(space: gym.Space) -> type: """Returns an appropriate preprocessor class for the given space.""" - legacy_patch_shapes(space) + _legacy_patch_shapes(space) obs_shape = space.shape if isinstance(space, (gym.spaces.Discrete, gym.spaces.MultiDiscrete)): @@ -366,7 +373,7 @@ def get_preprocessor(space: gym.Space) -> type: return preprocessor -def legacy_patch_shapes(space: gym.Space) -> List[int]: +def _legacy_patch_shapes(space: gym.Space) -> List[int]: """Assigns shapes to spaces that don't have shapes. This is only needed for older gym versions that don't set shapes properly @@ -379,7 +386,7 @@ def legacy_patch_shapes(space: gym.Space) -> List[int]: elif isinstance(space, gym.spaces.Tuple): shapes = [] for s in space.spaces: - shape = legacy_patch_shapes(s) + shape = _legacy_patch_shapes(s) shapes.append(shape) space.shape = tuple(shapes) diff --git a/rllib/models/tf/fcnet.py b/rllib/models/tf/fcnet.py index 0c921bf20442..abf2b4c201a7 100644 --- a/rllib/models/tf/fcnet.py +++ b/rllib/models/tf/fcnet.py @@ -8,11 +8,13 @@ from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.typing import TensorType, List, ModelConfigDict +from ray.rllib.utils.annotations import DeveloperAPI tf1, tf, tfv = try_import_tf() # TODO: (sven) obsolete this class once we only support native keras models. +@DeveloperAPI class FullyConnectedNetwork(TFModelV2): """Generic fully connected network implemented in ModelV2 API.""" @@ -150,6 +152,7 @@ def value_function(self) -> TensorType: return tf.reshape(self._value_out, [-1]) +@DeveloperAPI class Keras_FullyConnectedNetwork(tf.keras.Model if tf else object): """Generic fully connected network implemented in tf Keras.""" diff --git a/rllib/models/tf/misc.py b/rllib/models/tf/misc.py index 4467c5fea331..2e293917b94b 100644 --- a/rllib/models/tf/misc.py +++ b/rllib/models/tf/misc.py @@ -1,12 +1,14 @@ import numpy as np from typing import Tuple, Any, Optional +from ray.rllib.utils.annotations import DeveloperAPI from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.typing import TensorType tf1, tf, tfv = try_import_tf() +@DeveloperAPI def normc_initializer(std: float = 1.0) -> Any: def _initializer(shape, dtype=None, partition_info=None): out = np.random.randn(*shape).astype( @@ -18,6 +20,7 @@ def _initializer(shape, dtype=None, partition_info=None): return _initializer +@DeveloperAPI def conv2d( x: TensorType, num_filters: int, @@ -65,6 +68,7 @@ def conv2d( return tf1.nn.conv2d(x, w, stride_shape, pad) + b +@DeveloperAPI def linear( x: TensorType, size: int, @@ -79,5 +83,6 @@ def linear( return tf.matmul(x, w) + b +@DeveloperAPI def flatten(x: TensorType) -> TensorType: return tf.reshape(x, [-1, np.prod(x.get_shape().as_list()[1:])]) diff --git a/rllib/models/tf/recurrent_net.py b/rllib/models/tf/recurrent_net.py index 2a1fa7e583ed..849444b51675 100644 --- a/rllib/models/tf/recurrent_net.py +++ b/rllib/models/tf/recurrent_net.py @@ -117,6 +117,7 @@ def get_initial_state(self): raise NotImplementedError("You must implement this for a RNN model") +@DeveloperAPI class LSTMWrapper(RecurrentNetwork): """An LSTM wrapper serving as an interface for ModelV2s that set use_lstm.""" @@ -280,6 +281,7 @@ def value_function(self) -> TensorType: return tf.reshape(self._value_out, [-1]) +@DeveloperAPI class Keras_LSTMWrapper(tf.keras.Model if tf else object): """A tf keras auto-LSTM wrapper used when `use_lstm`=True.""" diff --git a/rllib/models/tf/tf_action_dist.py b/rllib/models/tf/tf_action_dist.py index b3a6d433e6c4..ba887b350407 100644 --- a/rllib/models/tf/tf_action_dist.py +++ b/rllib/models/tf/tf_action_dist.py @@ -46,6 +46,7 @@ def sampled_action_logp(self) -> TensorType: return self.sampled_action_logp_op +@DeveloperAPI class Categorical(TFActionDistribution): """Categorical distribution for discrete action spaces.""" diff --git a/rllib/models/tf/visionnet.py b/rllib/models/tf/visionnet.py index 110ed848dac2..26ba9ba92a05 100644 --- a/rllib/models/tf/visionnet.py +++ b/rllib/models/tf/visionnet.py @@ -1,6 +1,7 @@ import gym from typing import Dict, List, Optional, Sequence +from ray.rllib.utils.annotations import DeveloperAPI from ray.rllib.models.tf.tf_modelv2 import TFModelV2 from ray.rllib.models.tf.misc import normc_initializer from ray.rllib.models.utils import get_activation_fn, get_filter_config @@ -12,6 +13,7 @@ # TODO: (sven) obsolete this class once we only support native keras models. +@DeveloperAPI class VisionNetwork(TFModelV2): """Generic vision network implemented in ModelV2 API. @@ -264,6 +266,7 @@ def value_function(self) -> TensorType: return tf.reshape(self._value_out, [-1]) +@DeveloperAPI class Keras_VisionNetwork(tf.keras.Model if tf else object): """Generic vision network implemented in tf keras. diff --git a/rllib/models/torch/misc.py b/rllib/models/torch/misc.py index 546690eb8988..4948cec918e6 100644 --- a/rllib/models/torch/misc.py +++ b/rllib/models/torch/misc.py @@ -3,12 +3,14 @@ from typing import Union, Tuple, Any, List from ray.rllib.models.utils import get_activation_fn +from ray.rllib.utils.annotations import DeveloperAPI from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.typing import TensorType torch, nn = try_import_torch() +@DeveloperAPI def normc_initializer(std: float = 1.0) -> Any: def initializer(tensor): tensor.data.normal_(0, 1) @@ -17,6 +19,7 @@ def initializer(tensor): return initializer +@DeveloperAPI def same_padding( in_size: Tuple[int, int], filter_size: Tuple[int, int], @@ -61,6 +64,7 @@ def same_padding( return padding, output +@DeveloperAPI class SlimConv2d(nn.Module): """Simple mock of tf.slim Conv2d""" @@ -120,6 +124,7 @@ def forward(self, x: TensorType) -> TensorType: return self._model(x) +@DeveloperAPI class SlimFC(nn.Module): """Simple PyTorch version of `linear` function""" @@ -164,6 +169,7 @@ def forward(self, x: TensorType) -> TensorType: return self._model(x) +@DeveloperAPI class AppendBiasLayer(nn.Module): """Simple bias appending layer for free_log_std.""" @@ -177,6 +183,7 @@ def forward(self, x: TensorType) -> TensorType: return out +@DeveloperAPI class Reshape(nn.Module): """Standard module that reshapes/views a tensor""" diff --git a/rllib/models/utils.py b/rllib/models/utils.py index 6c1c580bf074..1eae1520d227 100644 --- a/rllib/models/utils.py +++ b/rllib/models/utils.py @@ -1,8 +1,10 @@ from typing import Optional +from ray.rllib.utils.annotations import DeveloperAPI from ray.rllib.utils.framework import try_import_jax, try_import_tf, try_import_torch +@DeveloperAPI def get_activation_fn(name: Optional[str] = None, framework: str = "tf"): """Returns a framework specific activation function, given a name string. @@ -65,6 +67,7 @@ def get_activation_fn(name: Optional[str] = None, framework: str = "tf"): ) +@DeveloperAPI def get_filter_config(shape): """Returns a default Conv2D filter config (list) for a given image shape. @@ -134,6 +137,7 @@ def get_filter_config(shape): ) +@DeveloperAPI def get_initializer(name, framework="tf"): """Returns a framework specific initializer, given a name string. diff --git a/rllib/offline/d4rl_reader.py b/rllib/offline/d4rl_reader.py index 80022d1b1c5a..a1063bf8ac9d 100644 --- a/rllib/offline/d4rl_reader.py +++ b/rllib/offline/d4rl_reader.py @@ -26,7 +26,7 @@ def __init__(self, inputs: str, ioctx: IOContext = None): import d4rl self.env = gym.make(inputs) - self.dataset = convert_to_batch(d4rl.qlearning_dataset(self.env)) + self.dataset = _convert_to_batch(d4rl.qlearning_dataset(self.env)) assert self.dataset.count >= 1 self.counter = 0 @@ -39,7 +39,7 @@ def next(self) -> SampleBatchType: return self.dataset.slice(start=self.counter, end=self.counter + 1) -def convert_to_batch(dataset: Dict) -> SampleBatchType: +def _convert_to_batch(dataset: Dict) -> SampleBatchType: # Converts D4RL dataset to SampleBatch d = {} d[SampleBatch.OBS] = dataset["observations"] diff --git a/rllib/offline/dataset_reader.py b/rllib/offline/dataset_reader.py index 3d1a5d6c31b8..4790f1eb4586 100644 --- a/rllib/offline/dataset_reader.py +++ b/rllib/offline/dataset_reader.py @@ -14,7 +14,7 @@ DEFAULT_NUM_CPUS_PER_TASK = 0.5 -def get_resource_bundles(config: TrainerConfigDict): +def _get_resource_bundles(config: TrainerConfigDict): input_config = config.get("input_config", {}) parallelism = input_config.get("parallelism", config.get("num_workers", 1)) cpus_per_task = input_config.get( @@ -23,6 +23,7 @@ def get_resource_bundles(config: TrainerConfigDict): return [{"CPU": math.ceil(parallelism * cpus_per_task)}] +@PublicAPI def get_dataset_and_shards( config: TrainerConfigDict, num_workers: int, local_worker: bool ) -> (ray.data.dataset.Dataset, List[ray.data.dataset.Dataset]): diff --git a/rllib/offline/estimators/importance_sampling.py b/rllib/offline/estimators/importance_sampling.py index f00cad42ac0f..7138125fdc72 100644 --- a/rllib/offline/estimators/importance_sampling.py +++ b/rllib/offline/estimators/importance_sampling.py @@ -1,8 +1,9 @@ from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, OffPolicyEstimate -from ray.rllib.utils.annotations import override +from ray.rllib.utils.annotations import override, DeveloperAPI from ray.rllib.utils.typing import SampleBatchType +@DeveloperAPI class ImportanceSampling(OffPolicyEstimator): """The step-wise IS estimator. diff --git a/rllib/offline/estimators/weighted_importance_sampling.py b/rllib/offline/estimators/weighted_importance_sampling.py index 7c6876b8a574..bf772cea6f99 100644 --- a/rllib/offline/estimators/weighted_importance_sampling.py +++ b/rllib/offline/estimators/weighted_importance_sampling.py @@ -1,9 +1,10 @@ from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, OffPolicyEstimate from ray.rllib.policy import Policy -from ray.rllib.utils.annotations import override +from ray.rllib.utils.annotations import override, DeveloperAPI from ray.rllib.utils.typing import SampleBatchType +@DeveloperAPI class WeightedImportanceSampling(OffPolicyEstimator): """The weighted step-wise IS estimator. diff --git a/rllib/offline/json_reader.py b/rllib/offline/json_reader.py index 4d7a1502a2c1..7c0aa3d34cf0 100644 --- a/rllib/offline/json_reader.py +++ b/rllib/offline/json_reader.py @@ -24,7 +24,7 @@ MultiAgentBatch, SampleBatch, ) -from ray.rllib.utils.annotations import override, PublicAPI +from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI from ray.rllib.utils.compression import unpack_if_needed from ray.rllib.utils.spaces.space_utils import clip_action, normalize_action from ray.rllib.utils.typing import Any, FileType, SampleBatchType @@ -71,6 +71,7 @@ def _adjust_obs_actions_for_policy(json_data: dict, policy: Policy) -> dict: return json_data +@DeveloperAPI def from_json_data(json_data: Any, worker: Optional["RolloutWorker"]): # Try to infer the SampleBatchType (SampleBatch or MultiAgentBatch). if "type" in json_data: diff --git a/rllib/offline/off_policy_estimator.py b/rllib/offline/off_policy_estimator.py index f4bcb3416ba2..9f1e2068daa1 100644 --- a/rllib/offline/off_policy_estimator.py +++ b/rllib/offline/off_policy_estimator.py @@ -14,7 +14,9 @@ logger = logging.getLogger(__name__) -OffPolicyEstimate = namedtuple("OffPolicyEstimate", ["estimator_name", "metrics"]) +OffPolicyEstimate = DeveloperAPI( + namedtuple("OffPolicyEstimate", ["estimator_name", "metrics"]) +) @DeveloperAPI diff --git a/rllib/offline/output_writer.py b/rllib/offline/output_writer.py index 2389c3d741b6..ca26c5a538fa 100644 --- a/rllib/offline/output_writer.py +++ b/rllib/offline/output_writer.py @@ -16,6 +16,7 @@ def write(self, sample_batch: SampleBatchType): raise NotImplementedError +@PublicAPI class NoopOutput(OutputWriter): """Output writer that discards its outputs.""" diff --git a/rllib/offline/resource.py b/rllib/offline/resource.py index d176395e41f4..084d634781d3 100644 --- a/rllib/offline/resource.py +++ b/rllib/offline/resource.py @@ -1,10 +1,12 @@ from ray.rllib.offline.dataset_reader import ( - get_resource_bundles as dataset_reader_get_resource_bundles, + _get_resource_bundles as dataset_reader_get_resource_bundles, ) +from ray.rllib.utils.annotations import PublicAPI from ray.rllib.utils.typing import PartialTrainerConfigDict from typing import Dict, List +@PublicAPI def get_offline_io_resource_bundles( config: PartialTrainerConfigDict, ) -> List[Dict[str, float]]: diff --git a/rllib/policy/dynamic_tf_policy.py b/rllib/policy/dynamic_tf_policy.py index 249672d51dec..cc6ab1143184 100644 --- a/rllib/policy/dynamic_tf_policy.py +++ b/rllib/policy/dynamic_tf_policy.py @@ -893,6 +893,7 @@ def _do_loss_init(self, train_batch: SampleBatch): return losses +@DeveloperAPI class TFMultiGPUTowerStack: """Optimizer that runs in parallel across multiple local devices. @@ -1002,7 +1003,7 @@ def __init__( if self.policy.config["_tf_policy_handles_more_than_one_loss"]: avgs = [] for i, optim in enumerate(self.optimizers): - avg = average_gradients([t.grads[i] for t in self._towers]) + avg = _average_gradients([t.grads[i] for t in self._towers]) if grad_norm_clipping: clipped = [] for grad, _ in avg: @@ -1031,7 +1032,7 @@ def __init__( [o.apply_gradients(a) for o, a in zip(self.optimizers, avgs)] ) else: - avg = average_gradients([t.grads for t in self._towers]) + avg = _average_gradients([t.grads for t in self._towers]) if grad_norm_clipping: clipped = [] for grad, _ in avg: @@ -1133,7 +1134,7 @@ def load_data(self, sess, inputs, state_inputs): if len(smallest_array) < sequences_per_minibatch: # Dynamically shrink the batch size if insufficient data - sequences_per_minibatch = make_divisible_by( + sequences_per_minibatch = _make_divisible_by( len(smallest_array), len(self.devices) ) @@ -1160,7 +1161,7 @@ def load_data(self, sess, inputs, state_inputs): if len(state_inputs) > 0: # First truncate the RNN state arrays to the sequences_per_minib. state_inputs = [ - make_divisible_by(arr, sequences_per_minibatch) for arr in state_inputs + _make_divisible_by(arr, sequences_per_minibatch) for arr in state_inputs ] # Then truncate the data inputs to match inputs = [arr[: len(state_inputs[0]) * seq_len] for arr in inputs] @@ -1176,7 +1177,7 @@ def load_data(self, sess, inputs, state_inputs): else: truncated_len = 0 for ph, arr in zip(self.loss_inputs, inputs): - truncated_arr = make_divisible_by(arr, sequences_per_minibatch) + truncated_arr = _make_divisible_by(arr, sequences_per_minibatch) feed_dict[ph] = truncated_arr if truncated_len == 0: truncated_len = len(truncated_arr) @@ -1259,7 +1260,7 @@ def _setup_device(self, tower_i, device, device_input_placeholders, num_data_in) device_input_slices.append(current_slice) graph_obj = self.policy_copy(device_input_slices) device_grads = graph_obj.gradients(self.optimizers, graph_obj._losses) - return Tower( + return _Tower( tf.group(*[batch.initializer for batch in device_input_batches]), device_grads, graph_obj, @@ -1267,16 +1268,16 @@ def _setup_device(self, tower_i, device, device_input_placeholders, num_data_in) # Each tower is a copy of the loss graph pinned to a specific device. -Tower = namedtuple("Tower", ["init_op", "grads", "loss_graph"]) +_Tower = namedtuple("Tower", ["init_op", "grads", "loss_graph"]) -def make_divisible_by(a, n): +def _make_divisible_by(a, n): if type(a) is int: return a - a % n return a[0 : a.shape[0] - a.shape[0] % n] -def average_gradients(tower_grads): +def _average_gradients(tower_grads): """Averages gradients across towers. Calculate the average gradient for each shared variable across all towers. diff --git a/rllib/policy/eager_tf_policy.py b/rllib/policy/eager_tf_policy.py index cf3dbabc06b3..d871a6849bc0 100644 --- a/rllib/policy/eager_tf_policy.py +++ b/rllib/policy/eager_tf_policy.py @@ -16,7 +16,7 @@ from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils import add_mixins, force_list -from ray.rllib.utils.annotations import override +from ray.rllib.utils.annotations import override, DeveloperAPI from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.metrics import NUM_AGENT_STEPS_TRAINED @@ -71,7 +71,7 @@ def _map(x): ) -def convert_eager_inputs(func): +def _convert_eager_inputs(func): @functools.wraps(func) def _func(*args, **kwargs): if tf.executing_eagerly(): @@ -89,7 +89,7 @@ def _func(*args, **kwargs): return _func -def convert_eager_outputs(func): +def _convert_eager_outputs(func): @functools.wraps(func) def _func(*args, **kwargs): out = func(*args, **kwargs) @@ -109,7 +109,7 @@ def _disallow_var_creation(next_creator, **kw): ) -def check_too_many_retraces(obj): +def _check_too_many_retraces(obj): """Asserts that a given number of re-traces is not breached.""" def _func(self_, *args, **kwargs): @@ -129,13 +129,14 @@ def _func(self_, *args, **kwargs): return _func +@DeveloperAPI class EagerTFPolicy(Policy): """Dummy class to recognize any eagerized TFPolicy by its inheritance.""" pass -def traced_eager_policy(eager_policy_cls): +def _traced_eager_policy(eager_policy_cls): """Wrapper class that enables tracing for all eager policy methods. This is enabled by the `--trace`/`eager_tracing=True` config when @@ -150,7 +151,7 @@ def __init__(self, *args, **kwargs): self._traced_apply_gradients_helper = False super(TracedEagerPolicy, self).__init__(*args, **kwargs) - @check_too_many_retraces + @_check_too_many_retraces @override(Policy) def compute_actions_from_input_dict( self, @@ -164,7 +165,7 @@ def compute_actions_from_input_dict( # Create a traced version of `self._compute_actions_helper`. if self._traced_compute_actions_helper is False and not self._no_tracing: - self._compute_actions_helper = convert_eager_inputs( + self._compute_actions_helper = _convert_eager_inputs( tf.function( super(TracedEagerPolicy, self)._compute_actions_helper, autograph=False, @@ -183,14 +184,14 @@ def compute_actions_from_input_dict( **kwargs, ) - @check_too_many_retraces + @_check_too_many_retraces @override(eager_policy_cls) def learn_on_batch(self, samples): """Traced version of Policy.learn_on_batch.""" # Create a traced version of `self._learn_on_batch_helper`. if self._traced_learn_on_batch_helper is False and not self._no_tracing: - self._learn_on_batch_helper = convert_eager_inputs( + self._learn_on_batch_helper = _convert_eager_inputs( tf.function( super(TracedEagerPolicy, self)._learn_on_batch_helper, autograph=False, @@ -203,14 +204,14 @@ def learn_on_batch(self, samples): # apply_gradients (which will call the traced helper). return super(TracedEagerPolicy, self).learn_on_batch(samples) - @check_too_many_retraces + @_check_too_many_retraces @override(eager_policy_cls) def compute_gradients(self, samples: SampleBatch) -> ModelGradients: """Traced version of Policy.compute_gradients.""" # Create a traced version of `self._compute_gradients_helper`. if self._traced_compute_gradients_helper is False and not self._no_tracing: - self._compute_gradients_helper = convert_eager_inputs( + self._compute_gradients_helper = _convert_eager_inputs( tf.function( super(TracedEagerPolicy, self)._compute_gradients_helper, autograph=False, @@ -223,14 +224,14 @@ def compute_gradients(self, samples: SampleBatch) -> ModelGradients: # `compute_gradients()` (which will call the traced helper). return super(TracedEagerPolicy, self).compute_gradients(samples) - @check_too_many_retraces + @_check_too_many_retraces @override(Policy) def apply_gradients(self, grads: ModelGradients) -> None: """Traced version of Policy.apply_gradients.""" # Create a traced version of `self._apply_gradients_helper`. if self._traced_apply_gradients_helper is False and not self._no_tracing: - self._apply_gradients_helper = convert_eager_inputs( + self._apply_gradients_helper = _convert_eager_inputs( tf.function( super(TracedEagerPolicy, self)._apply_gradients_helper, autograph=False, @@ -253,7 +254,7 @@ def with_tracing(cls): return TracedEagerPolicy -class OptimizerWrapper: +class _OptimizerWrapper: def __init__(self, tape): self.tape = tape @@ -261,7 +262,7 @@ def compute_gradients(self, loss, var_list): return list(zip(self.tape.gradient(loss, var_list), var_list)) -def build_eager_tf_policy( +def _build_eager_tf_policy( name, loss_fn, get_default_config=None, @@ -909,7 +910,7 @@ def _compute_gradients_helper(self, samples): # object looks like a "classic" tf.optimizer. This way, custom # compute_gradients_fn will work on both tf static graph # and tf-eager. - optimizer = OptimizerWrapper(tape) + optimizer = _OptimizerWrapper(tape) # More than one loss terms/optimizers. if self.config["_tf_policy_handles_more_than_one_loss"]: grads_and_vars = compute_gradients_fn( @@ -993,7 +994,7 @@ def _lazy_tensor_dict(self, postprocessed_batch: SampleBatch): @classmethod def with_tracing(cls): - return traced_eager_policy(cls) + return _traced_eager_policy(cls) eager_policy_cls.__name__ = name + "_eager" eager_policy_cls.__qualname__ = name + "_eager" diff --git a/rllib/policy/eager_tf_policy_v2.py b/rllib/policy/eager_tf_policy_v2.py index 5de55eb17ff6..505ca3eaf39d 100644 --- a/rllib/policy/eager_tf_policy_v2.py +++ b/rllib/policy/eager_tf_policy_v2.py @@ -17,8 +17,8 @@ from ray.rllib.policy.eager_tf_policy import ( _convert_to_tf, _disallow_var_creation, - OptimizerWrapper, - traced_eager_policy, + _OptimizerWrapper, + _traced_eager_policy, ) from ray.rllib.policy.policy import Policy from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size @@ -846,7 +846,7 @@ def _compute_gradients_helper(self, samples): # object looks like a "classic" tf.optimizer. This way, custom # compute_gradients_fn will work on both tf static graph # and tf-eager. - optimizer = OptimizerWrapper(tape) + optimizer = _OptimizerWrapper(tape) # More than one loss terms/optimizers. if self.config["_tf_policy_handles_more_than_one_loss"]: grads_and_vars = self.compute_gradients_fn( @@ -924,4 +924,4 @@ def _lazy_tensor_dict(self, postprocessed_batch: SampleBatch): @classmethod def with_tracing(cls): - return traced_eager_policy(cls) + return _traced_eager_policy(cls) diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py index f2f5db9d247b..33c400fb8f3e 100644 --- a/rllib/policy/policy.py +++ b/rllib/policy/policy.py @@ -26,6 +26,7 @@ from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.view_requirement import ViewRequirement from ray.rllib.utils.annotations import ( + PublicAPI, DeveloperAPI, ExperimentalAPI, OverrideToImplementCustomLogic, @@ -69,22 +70,24 @@ # "pol1": PolicySpec(None, Box, Discrete(2), {"lr": 0.0001}), # "pol2": PolicySpec(config={"lr": 0.001}), # } -PolicySpec = namedtuple( - "PolicySpec", - [ - # If None, use the Trainer's default policy class stored under - # `Trainer._policy_class`. - "policy_class", - # If None, use the env's observation space. If None and there is no Env - # (e.g. offline RL), an error is thrown. - "observation_space", - # If None, use the env's action space. If None and there is no Env - # (e.g. offline RL), an error is thrown. - "action_space", - # Overrides defined keys in the main Trainer config. - # If None, use {}. - "config", - ], +PolicySpec = PublicAPI( + namedtuple( + "PolicySpec", + [ + # If None, use the Trainer's default policy class stored under + # `Trainer._policy_class`. + "policy_class", + # If None, use the env's observation space. If None and there is no Env + # (e.g. offline RL), an error is thrown. + "observation_space", + # If None, use the env's action space. If None and there is no Env + # (e.g. offline RL), an error is thrown. + "action_space", + # Overrides defined keys in the main Trainer config. + # If None, use {}. + "config", + ], + ) ) # defaults=(None, None, None, None) # TODO: From 3.7 on, we could pass `defaults` into the above constructor. # We still support py3.6. @@ -818,7 +821,7 @@ def _create_exploration(self) -> Exploration: This method only exists b/c some Trainers do not use TfPolicy nor TorchPolicy, but inherit directly from Policy. Others inherit from - TfPolicy w/o using DynamicTfPolicy. + TfPolicy w/o using DynamicTFPolicy. TODO(sven): unify these cases. Returns: diff --git a/rllib/policy/policy_map.py b/rllib/policy/policy_map.py index 52e0ce7ea07f..1c309d8c5efe 100644 --- a/rllib/policy/policy_map.py +++ b/rllib/policy/policy_map.py @@ -6,7 +6,7 @@ from typing import Callable, Dict, Optional, Set, Type, TYPE_CHECKING from ray.rllib.policy.policy import PolicySpec -from ray.rllib.utils.annotations import override +from ray.rllib.utils.annotations import override, PublicAPI from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.tf_utils import get_tf_eager_cls_if_necessary from ray.rllib.utils.threading import with_lock @@ -19,6 +19,7 @@ tf1, tf, tfv = try_import_tf() +@PublicAPI class PolicyMap(dict): """Maps policy IDs to Policy objects. diff --git a/rllib/policy/rnn_sequencing.py b/rllib/policy/rnn_sequencing.py index f58e5d2e7bfc..ef02aae2486e 100644 --- a/rllib/policy/rnn_sequencing.py +++ b/rllib/policy/rnn_sequencing.py @@ -360,6 +360,7 @@ def chop_into_sequences( return feature_sequences, initial_states, seq_lens +@DeveloperAPI def timeslice_along_seq_lens_with_overlap( sample_batch: SampleBatchType, seq_lens: Optional[List[int]] = None, diff --git a/rllib/policy/tests/test_view_requirement.py b/rllib/policy/tests/test_view_requirement.py new file mode 100644 index 000000000000..381314a2a8e0 --- /dev/null +++ b/rllib/policy/tests/test_view_requirement.py @@ -0,0 +1,38 @@ +import gym +import json +import unittest + +from ray.rllib.policy.view_requirement import ViewRequirement + + +class TestViewRequirement(unittest.TestCase): + def test_serialize_view_requirement(self): + """Test serializing simple ViewRequirement into JSON serializable dict""" + vr = ViewRequirement( + "obs", + shift=[-1], + used_for_training=False, + used_for_compute_actions=True, + batch_repeat_value=1, + ) + d = vr.to_dict() + self.assertEqual(d["data_col"], "obs") + self.assertEqual(d["space"]["space"], "box") + + # Make sure serialized dict is JSON serializable. + s = json.dumps(d) + d2 = json.loads(s) + + self.assertEqual(d2["used_for_training"], False) + self.assertEqual(d2["used_for_compute_actions"], True) + + vr2 = ViewRequirement.from_dict(d2) + self.assertEqual(vr2.data_col, "obs") + self.assertTrue(isinstance(vr2.space, gym.spaces.Box)) + + +if __name__ == "__main__": + import pytest + import sys + + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/policy/tf_mixins.py b/rllib/policy/tf_mixins.py index 085d0329eab4..558da1e5e14e 100644 --- a/rllib/policy/tf_mixins.py +++ b/rllib/policy/tf_mixins.py @@ -1,6 +1,7 @@ +import gym +import logging from typing import Dict, List, Union -from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.policy.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch @@ -13,8 +14,11 @@ LocalOptimizer, ModelGradients, TensorType, + TrainerConfigDict, ) +logger = logging.getLogger(__name__) + tf1, tf, tfv = try_import_tf() @@ -230,10 +234,31 @@ def value(*args, **kwargs): return tf.constant(0.0) self._value = value + self._should_cache_extra_action = config["framework"] == "tf" self._cached_extra_action_fetches = None + def _extra_action_out_impl(self) -> Dict[str, TensorType]: + extra_action_out = super().extra_action_out_fn() + # Keras models return values for each call in third return argument + # (dict). + if isinstance(self.model, tf.keras.Model): + return extra_action_out + # Return value function outputs. VF estimates will hence be added to the + # SampleBatches produced by the sampler(s) to generate the train batches + # going into the loss function. + extra_action_out.update( + { + SampleBatch.VF_PREDS: self.model.value_function(), + } + ) + return extra_action_out + def extra_action_out_fn(self) -> Dict[str, TensorType]: - # Note: there are 2 reasons we are caching the extra_action_fetches here. + if not self._should_cache_extra_action: + return self._extra_action_out_impl() + + # Note: there are 2 reasons we are caching the extra_action_fetches for + # TF1 static graph here. # 1. for better performance, so we don't query base class and model for # extra fetches every single time. # 2. for correctness. TF1 is special because the static graph may contain @@ -248,67 +273,77 @@ def extra_action_out_fn(self) -> Dict[str, TensorType]: if self._cached_extra_action_fetches is not None: return self._cached_extra_action_fetches - # TODO: (sven) Deprecate once we only allow native keras models. - self._cached_extra_action_fetches = super().extra_action_out_fn() - # Keras models return values for each call in third return argument - # (dict). - if isinstance(self.model, tf.keras.Model): - return self._cached_extra_action_fetches - # Return value function outputs. VF estimates will hence be added to the - # SampleBatches produced by the sampler(s) to generate the train batches - # going into the loss function. - self._cached_extra_action_fetches.update( - { - SampleBatch.VF_PREDS: self.model.value_function(), - } - ) + self._cached_extra_action_fetches = self._extra_action_out_impl() return self._cached_extra_action_fetches -class ComputeGAEMixIn: - """Postprocess SampleBatch to Compute GAE before they get used for training.""" +class TargetNetworkMixin: + """Assign the `update_target` method to the SimpleQTFPolicy - def __init__(self): - pass + The function is called every `target_network_update_freq` steps by the + master learner. + """ - @DeveloperAPI - def postprocess_trajectory( - self, sample_batch, other_agent_batches=None, episode=None + def __init__( + self, + obs_space: gym.spaces.Space, + action_space: gym.spaces.Space, + config: TrainerConfigDict, ): - sample_batch = super().postprocess_trajectory(sample_batch) - return compute_gae_for_sample_batch( - self, sample_batch, other_agent_batches, episode - ) + @make_tf_callable(self.get_session()) + def do_update(): + # update_target_fn will be called periodically to copy Q network to + # target Q network + update_target_expr = [] + assert len(self.q_func_vars) == len(self.target_q_func_vars), ( + self.q_func_vars, + self.target_q_func_vars, + ) + for var, var_target in zip(self.q_func_vars, self.target_q_func_vars): + update_target_expr.append(var_target.assign(var)) + logger.debug("Update target op {}".format(var_target)) + return tf.group(*update_target_expr) + self.update_target = do_update -class ComputeAndClipGradsMixIn: - """Compute and maybe clip gradients.""" + @property + def q_func_vars(self): + if not hasattr(self, "_q_func_vars"): + self._q_func_vars = self.model.variables() + return self._q_func_vars - def __init__(self): - pass + @property + def target_q_func_vars(self): + if not hasattr(self, "_target_q_func_vars"): + self._target_q_func_vars = self.target_model.variables() + return self._target_q_func_vars - @DeveloperAPI - def compute_gradients_fn( - self, optimizer: LocalOptimizer, loss: TensorType - ) -> ModelGradients: - # Compute the gradients. - variables = self.model.trainable_variables - if isinstance(self.model, ModelV2): - variables = variables() - grads_and_vars = optimizer.compute_gradients(loss, variables) - - # Clip by global norm, if necessary. - if self.config["grad_clip"] is not None: - # Defuse inf gradients (due to super large losses). - grads = [g for (g, v) in grads_and_vars] - grads, _ = tf.clip_by_global_norm(grads, self.config["grad_clip"]) - # If the global_norm is inf -> All grads will be NaN. Stabilize this - # here by setting them to 0.0. This will simply ignore destructive loss - # calculations. - self.grads = [ - tf.where(tf.math.is_nan(g), tf.zeros_like(g), g) for g in grads - ] - clipped_grads_and_vars = list(zip(self.grads, variables)) - return clipped_grads_and_vars - else: - return grads_and_vars + @override(TFPolicy) + def variables(self): + return self.q_func_vars + self.target_q_func_vars + + +# TODO: find a better place for this util, since it's not technically MixIns. +@DeveloperAPI +def compute_gradients( + policy, optimizer: LocalOptimizer, loss: TensorType +) -> ModelGradients: + # Compute the gradients. + variables = policy.model.trainable_variables + if isinstance(policy.model, ModelV2): + variables = variables() + grads_and_vars = optimizer.compute_gradients(loss, variables) + + # Clip by global norm, if necessary. + if policy.config["grad_clip"] is not None: + # Defuse inf gradients (due to super large losses). + grads = [g for (g, v) in grads_and_vars] + grads, _ = tf.clip_by_global_norm(grads, policy.config["grad_clip"]) + # If the global_norm is inf -> All grads will be NaN. Stabilize this + # here by setting them to 0.0. This will simply ignore destructive loss + # calculations. + policy.grads = [tf.where(tf.math.is_nan(g), tf.zeros_like(g), g) for g in grads] + clipped_grads_and_vars = list(zip(policy.grads, variables)) + return clipped_grads_and_vars + else: + return grads_and_vars diff --git a/rllib/policy/tf_policy.py b/rllib/policy/tf_policy.py index 491557b2a13d..112f2a4c33f7 100644 --- a/rllib/policy/tf_policy.py +++ b/rllib/policy/tf_policy.py @@ -23,7 +23,7 @@ from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY from ray.rllib.utils.spaces.space_utils import normalize_action from ray.rllib.utils.tf_utils import get_gpu_devices -from ray.rllib.utils.tf_run_builder import TFRunBuilder +from ray.rllib.utils.tf_run_builder import _TFRunBuilder from ray.rllib.utils.typing import ( LocalOptimizer, ModelGradients, @@ -317,7 +317,7 @@ def compute_actions_from_input_dict( # Deprecated dict input. input_dict["is_training"] = False - builder = TFRunBuilder(self.get_session(), "compute_actions_from_input_dict") + builder = _TFRunBuilder(self.get_session(), "compute_actions_from_input_dict") obs_batch = input_dict[SampleBatch.OBS] to_fetch = self._build_compute_actions( builder, input_dict=input_dict, explore=explore, timestep=timestep @@ -354,7 +354,7 @@ def compute_actions( explore = explore if explore is not None else self.config["explore"] timestep = timestep if timestep is not None else self.global_timestep - builder = TFRunBuilder(self.get_session(), "compute_actions") + builder = _TFRunBuilder(self.get_session(), "compute_actions") input_dict = {SampleBatch.OBS: obs_batch, "is_training": False} if state_batches: @@ -402,7 +402,7 @@ def compute_log_likelihoods( explore=False, tf_sess=self.get_session() ) - builder = TFRunBuilder(self.get_session(), "compute_log_likelihoods") + builder = _TFRunBuilder(self.get_session(), "compute_log_likelihoods") # Normalize actions if necessary. if actions_normalized is False and self.config["normalize_actions"]: @@ -440,7 +440,7 @@ def learn_on_batch(self, postprocessed_batch: SampleBatch) -> Dict[str, TensorTy # Switch on is_training flag in our batch. postprocessed_batch.set_training(True) - builder = TFRunBuilder(self.get_session(), "learn_on_batch") + builder = _TFRunBuilder(self.get_session(), "learn_on_batch") # Callback handling. learn_stats = {} @@ -466,7 +466,7 @@ def compute_gradients( assert self.loss_initialized() # Switch on is_training flag in our batch. postprocessed_batch.set_training(True) - builder = TFRunBuilder(self.get_session(), "compute_gradients") + builder = _TFRunBuilder(self.get_session(), "compute_gradients") fetches = self._build_compute_gradients(builder, postprocessed_batch) return builder.get(fetches) @@ -474,7 +474,7 @@ def compute_gradients( @DeveloperAPI def apply_gradients(self, gradients: ModelGradients) -> None: assert self.loss_initialized() - builder = TFRunBuilder(self.get_session(), "apply_gradients") + builder = _TFRunBuilder(self.get_session(), "apply_gradients") fetches = self._build_apply_gradients(builder, gradients) builder.get(fetches) diff --git a/rllib/policy/tf_policy_template.py b/rllib/policy/tf_policy_template.py index 4c895d2bf923..dff3dccb89a0 100644 --- a/rllib/policy/tf_policy_template.py +++ b/rllib/policy/tf_policy_template.py @@ -370,7 +370,7 @@ def with_updates(**overrides): return build_tf_policy(**dict(original_kwargs, **overrides)) def as_eager(): - return eager_tf_policy.build_eager_tf_policy(**original_kwargs) + return eager_tf_policy._build_eager_tf_policy(**original_kwargs) policy_cls.with_updates = staticmethod(with_updates) policy_cls.as_eager = staticmethod(as_eager) diff --git a/rllib/policy/torch_mixins.py b/rllib/policy/torch_mixins.py index 651b726dc925..5f568f21c641 100644 --- a/rllib/policy/torch_mixins.py +++ b/rllib/policy/torch_mixins.py @@ -1,13 +1,11 @@ from typing import Dict, List, Union -from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch from ray.rllib.policy.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.torch_policy import TorchPolicy from ray.rllib.utils.annotations import DeveloperAPI, override from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.schedules import PiecewiseSchedule -from ray.rllib.utils.torch_utils import apply_grad_clipping from ray.rllib.utils.typing import ( TensorType, ) @@ -199,36 +197,3 @@ def set_weights(self, weights): # at the same time. TorchPolicy.set_weights(self, weights) self.update_target() - - -class ComputeGAEMixIn: - """Postprocess SampleBatch to Compute GAE before they get used for training.""" - - def __init__(self): - pass - - @DeveloperAPI - def postprocess_trajectory( - self, sample_batch, other_agent_batches=None, episode=None - ): - # Do all post-processing always with no_grad(). - # Not using this here will introduce a memory leak - # in torch (issue #6962). - # TODO: no_grad still necessary? - with torch.no_grad(): - return compute_gae_for_sample_batch( - self, sample_batch, other_agent_batches, episode - ) - - -class GradClippingMixin: - """Apply gradient clipping.""" - - def __init__(self): - pass - - @DeveloperAPI - def extra_grad_process( - self, optimizer: "torch.optim.Optimizer", loss: TensorType - ) -> Dict[str, TensorType]: - return apply_grad_clipping(self, optimizer, loss) diff --git a/rllib/policy/view_requirement.py b/rllib/policy/view_requirement.py index 89f6bcd3f4e0..aac9bcfe3454 100644 --- a/rllib/policy/view_requirement.py +++ b/rllib/policy/view_requirement.py @@ -1,12 +1,17 @@ import gym -import numpy as np -from typing import List, Optional, Union +from typing import Dict, List, Optional, Union +from ray.rllib.utils.annotations import PublicAPI from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.serialization import ( + gym_space_to_dict, + gym_space_from_dict, +) torch, _ = try_import_torch() +@PublicAPI class ViewRequirement: """Single view requirement (for one column in an SampleBatch/input_dict). @@ -74,8 +79,6 @@ def __init__( ) self.shift = shift - if isinstance(self.shift, (list, tuple)): - self.shift = np.array(self.shift) # Special case: Providing a (probably larger) range of indices, e.g. # "-100:0" (past 100 timesteps plus current one). @@ -90,3 +93,40 @@ def __init__( self.used_for_compute_actions = used_for_compute_actions self.used_for_training = used_for_training + + def __str__(self): + """For easier inspection of view requirements.""" + return "|".join( + [ + str(v) + for v in [ + self.data_col, + self.space, + self.shift, + self.shift_from, + self.shift_to, + self.index, + self.batch_repeat_value, + self.used_for_training, + self.used_for_compute_actions, + ] + ] + ) + + def to_dict(self) -> Dict: + """Return a dict for this ViewRequirement that can be JSON serialized.""" + return { + "data_col": self.data_col, + "space": gym_space_to_dict(self.space), + "shift": self.shift, + "index": self.index, + "batch_repeat_value": self.batch_repeat_value, + "used_for_training": self.used_for_training, + "used_for_compute_actions": self.used_for_compute_actions, + } + + @classmethod + def from_dict(cls, d: Dict): + """Construct a ViewRequirement instance from JSON deserialized dict.""" + d["space"] = gym_space_from_dict(d["space"]) + return cls(**d) diff --git a/rllib/tests/test_execution.py b/rllib/tests/test_execution.py index 0b6965af0227..a587f870cd92 100644 --- a/rllib/tests/test_execution.py +++ b/rllib/tests/test_execution.py @@ -5,7 +5,7 @@ import unittest import ray -from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy +from ray.rllib.agents.ppo.ppo_tf_policy import PPOStaticGraphTFPolicy from ray.rllib.evaluation.worker_set import WorkerSet from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, STEPS_TRAINED_COUNTER @@ -38,13 +38,13 @@ def iter_list(values): def make_workers(n): local = RolloutWorker( env_creator=lambda _: gym.make("CartPole-v0"), - policy_spec=PPOTFPolicy, + policy_spec=PPOStaticGraphTFPolicy, rollout_fragment_length=100, ) remotes = [ RolloutWorker.as_remote().remote( env_creator=lambda _: gym.make("CartPole-v0"), - policy_spec=PPOTFPolicy, + policy_spec=PPOStaticGraphTFPolicy, rollout_fragment_length=100, ) for _ in range(n) @@ -210,7 +210,6 @@ def test_store_to_replay_local(self): num_shards=1, learning_starts=200, capacity=1000, - replay_batch_size=100, prioritized_replay_alpha=0.6, prioritized_replay_beta=0.4, prioritized_replay_eps=0.0001, @@ -226,7 +225,7 @@ def test_store_to_replay_local(self): next(b) assert buf.sample(100).count == 100 - replay_op = Replay(local_buffer=buf) + replay_op = Replay(local_buffer=buf, num_items_to_replay=100) assert next(replay_op).count == 100 def test_store_to_replay_actor(self): @@ -235,7 +234,6 @@ def test_store_to_replay_actor(self): num_shards=1, learning_starts=200, capacity=1000, - replay_batch_size=100, prioritized_replay_alpha=0.6, prioritized_replay_beta=0.4, prioritized_replay_eps=0.0001, @@ -251,7 +249,7 @@ def test_store_to_replay_actor(self): next(b) assert ray.get(actor.sample.remote(100)).count == 100 - replay_op = Replay(actors=[actor]) + replay_op = Replay(actors=[actor], num_items_to_replay=100) assert next(replay_op).count == 100 diff --git a/rllib/tuned_examples/dqn/stateless-cartpole-r2d2-fake-gpus.yaml b/rllib/tuned_examples/dqn/stateless-cartpole-r2d2-fake-gpus.yaml index 7730634e5652..81414b7d84f2 100644 --- a/rllib/tuned_examples/dqn/stateless-cartpole-r2d2-fake-gpus.yaml +++ b/rllib/tuned_examples/dqn/stateless-cartpole-r2d2-fake-gpus.yaml @@ -2,17 +2,18 @@ stateless-cartpole-r2d2: env: ray.rllib.examples.env.stateless_cartpole.StatelessCartPole run: R2D2 stop: - episode_reward_mean: 150 - timesteps_total: 1000000 + episode_reward_mean: 100 + timesteps_total: 50000 config: # Works for both torch and tf. framework: tf - num_workers: 0 + num_workers: 4 # R2D2 settings. replay_buffer_config: type: MultiAgentReplayBuffer + storage_unit: sequences replay_burn_in: 20 - zero_init_states: true + zero_init_states: true #dueling: false lr: 0.0005 # Give some more time to explore. diff --git a/rllib/tuned_examples/dqn/stateless-cartpole-r2d2.yaml b/rllib/tuned_examples/dqn/stateless-cartpole-r2d2.yaml index a3e89487f114..72241f1e4790 100644 --- a/rllib/tuned_examples/dqn/stateless-cartpole-r2d2.yaml +++ b/rllib/tuned_examples/dqn/stateless-cartpole-r2d2.yaml @@ -2,17 +2,18 @@ stateless-cartpole-r2d2: env: ray.rllib.examples.env.stateless_cartpole.StatelessCartPole run: R2D2 stop: - episode_reward_mean: 150 - timesteps_total: 1000000 + episode_reward_mean: 100 + timesteps_total: 50000 config: # Works for both torch and tf. framework: tf - num_workers: 0 + num_workers: 4 # R2D2 settings. replay_buffer_config: type: MultiAgentReplayBuffer + storage_unit: sequences replay_burn_in: 20 - zero_init_states: true + zero_init_states: true #dueling: false lr: 0.0005 # Give some more time to explore. diff --git a/rllib/tuned_examples/maddpg/two-step-game-maddpg.yaml b/rllib/tuned_examples/maddpg/two-step-game-maddpg.yaml index 6cc9f7a159c0..9cee89a5f9e0 100644 --- a/rllib/tuned_examples/maddpg/two-step-game-maddpg.yaml +++ b/rllib/tuned_examples/maddpg/two-step-game-maddpg.yaml @@ -1,8 +1,8 @@ -two-step-game-qmix-with-qmix-mixer: +two-step-game-maddpg: env: ray.rllib.examples.env.two_step_game.TwoStepGame run: MADDPG stop: - episode_reward_mean: 8.0 + episode_reward_mean: 7.2 timesteps_total: 20000 config: # MADDPG only supports tf for now. diff --git a/rllib/utils/__init__.py b/rllib/utils/__init__.py index fc307073c6ac..3b8e426d05de 100644 --- a/rllib/utils/__init__.py +++ b/rllib/utils/__init__.py @@ -35,6 +35,7 @@ from ray.tune.utils import merge_dicts, deep_update +@DeveloperAPI def add_mixins(base, mixins, reversed=False): """Returns a new class with mixins applied in priority order.""" @@ -56,6 +57,7 @@ class new_base(mixins.pop(), base): return base +@DeveloperAPI def force_list(elements=None, to_tuple=False): """ Makes sure `elements` is returned as a list, whether `elements` is a single @@ -83,6 +85,7 @@ def force_list(elements=None, to_tuple=False): ) +@DeveloperAPI class NullContextManager(contextlib.AbstractContextManager): """No-op context manager""" diff --git a/rllib/utils/debug/deterministic.py b/rllib/utils/debug/deterministic.py index 15cefc645a12..a95511c07af2 100644 --- a/rllib/utils/debug/deterministic.py +++ b/rllib/utils/debug/deterministic.py @@ -3,9 +3,11 @@ import random from typing import Optional +from ray.rllib.utils.annotations import DeveloperAPI from ray.rllib.utils.framework import try_import_tf, try_import_torch +@DeveloperAPI def update_global_seed_if_necessary( framework: Optional[str] = None, seed: Optional[int] = None ) -> None: diff --git a/rllib/utils/debug/memory.py b/rllib/utils/debug/memory.py index a33eb17f0036..a6acab4a5133 100644 --- a/rllib/utils/debug/memory.py +++ b/rllib/utils/debug/memory.py @@ -7,30 +7,34 @@ import tree # pip install dm_tree from typing import Callable, DefaultDict, List, Optional, Set +from ray.rllib.utils.annotations import DeveloperAPI from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch # A suspicious memory-allocating stack-trace that we should re-test # to make sure it's not a false positive. -Suspect = namedtuple( - "Suspect", - [ - # The stack trace of the allocation, going back n frames, depending - # on the tracemalloc.start(n) call. - "traceback", - # The amount of memory taken by this particular stack trace - # over the course of the experiment. - "memory_increase", - # The slope of the scipy linear regression (x=iteration; y=memory size). - "slope", - # The rvalue of the scipy linear regression. - "rvalue", - # The memory size history (list of all memory sizes over all iterations). - "hist", - ], +Suspect = DeveloperAPI( + namedtuple( + "Suspect", + [ + # The stack trace of the allocation, going back n frames, depending + # on the tracemalloc.start(n) call. + "traceback", + # The amount of memory taken by this particular stack trace + # over the course of the experiment. + "memory_increase", + # The slope of the scipy linear regression (x=iteration; y=memory size). + "slope", + # The rvalue of the scipy linear regression. + "rvalue", + # The memory size history (list of all memory sizes over all iterations). + "hist", + ], + ) ) +@DeveloperAPI def check_memory_leaks( trainer, to_check: Optional[Set[str]] = None, diff --git a/rllib/utils/debug/summary.py b/rllib/utils/debug/summary.py index 58370f2df815..e0e6737a63c0 100644 --- a/rllib/utils/debug/summary.py +++ b/rllib/utils/debug/summary.py @@ -3,10 +3,12 @@ from typing import Any, Mapping from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch +from ray.rllib.utils.annotations import DeveloperAPI _printer = pprint.PrettyPrinter(indent=2, width=60) +@DeveloperAPI def summarize(obj: Any) -> Any: """Return a pretty-formatted string for an object. diff --git a/rllib/utils/deprecation.py b/rllib/utils/deprecation.py index 89898d97b303..70d677837a6d 100644 --- a/rllib/utils/deprecation.py +++ b/rllib/utils/deprecation.py @@ -3,6 +3,7 @@ from typing import Optional, Union from ray.util import log_once +from ray.util.annotations import _mark_annotated logger = logging.getLogger(__name__) @@ -102,6 +103,7 @@ def patched_init(*args, **kwargs): return obj_init(*args, **kwargs) obj.__init__ = patched_init + _mark_annotated(obj) # Return the patched class (with the warning/error when # instantiated). return obj diff --git a/rllib/utils/exploration/curiosity.py b/rllib/utils/exploration/curiosity.py index 7b3f80fe3b52..17a9af88359c 100644 --- a/rllib/utils/exploration/curiosity.py +++ b/rllib/utils/exploration/curiosity.py @@ -2,6 +2,7 @@ import numpy as np from typing import Optional, Tuple, Union +from ray.rllib.utils.annotations import PublicAPI from ray.rllib.models.action_dist import ActionDistribution from ray.rllib.models.catalog import ModelCatalog from ray.rllib.models.modelv2 import ModelV2 @@ -29,6 +30,7 @@ F = nn.functional +@PublicAPI class Curiosity(Exploration): """Implementation of: [1] Curiosity-driven Exploration by Self-supervised Prediction @@ -341,8 +343,12 @@ def _postprocess_torch(self, policy, sample_batch): { SampleBatch.OBS: torch.cat( [ - torch.from_numpy(sample_batch[SampleBatch.OBS]), - torch.from_numpy(sample_batch[SampleBatch.NEXT_OBS]), + torch.from_numpy(sample_batch[SampleBatch.OBS]).to( + policy.device + ), + torch.from_numpy(sample_batch[SampleBatch.NEXT_OBS]).to( + policy.device + ), ] ) } diff --git a/rllib/utils/exploration/epsilon_greedy.py b/rllib/utils/exploration/epsilon_greedy.py index ec1df00f8c5a..52e8118a1c12 100644 --- a/rllib/utils/exploration/epsilon_greedy.py +++ b/rllib/utils/exploration/epsilon_greedy.py @@ -4,6 +4,7 @@ import random from typing import Union, Optional +from ray.rllib.utils.annotations import PublicAPI from ray.rllib.models.torch.torch_action_dist import TorchMultiActionDistribution from ray.rllib.models.action_dist import ActionDistribution from ray.rllib.utils.annotations import override @@ -18,6 +19,7 @@ torch, _ = try_import_torch() +@PublicAPI class EpsilonGreedy(Exploration): """Epsilon-greedy Exploration class that produces exploration actions. diff --git a/rllib/utils/exploration/gaussian_noise.py b/rllib/utils/exploration/gaussian_noise.py index 3ab59f495e37..234287852cf3 100644 --- a/rllib/utils/exploration/gaussian_noise.py +++ b/rllib/utils/exploration/gaussian_noise.py @@ -2,6 +2,7 @@ import numpy as np from typing import Union, Optional +from ray.rllib.utils.annotations import PublicAPI from ray.rllib.models.action_dist import ActionDistribution from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.utils.annotations import override @@ -22,6 +23,7 @@ torch, _ = try_import_torch() +@PublicAPI class GaussianNoise(Exploration): """An exploration that adds white noise to continuous actions. diff --git a/rllib/utils/exploration/ornstein_uhlenbeck_noise.py b/rllib/utils/exploration/ornstein_uhlenbeck_noise.py index 5fe2ef7b0b31..f6ace79d397b 100644 --- a/rllib/utils/exploration/ornstein_uhlenbeck_noise.py +++ b/rllib/utils/exploration/ornstein_uhlenbeck_noise.py @@ -1,6 +1,7 @@ import numpy as np from typing import Optional, Union +from ray.rllib.utils.annotations import PublicAPI from ray.rllib.models.action_dist import ActionDistribution from ray.rllib.utils.annotations import override from ray.rllib.utils.exploration.gaussian_noise import GaussianNoise @@ -18,6 +19,7 @@ torch, _ = try_import_torch() +@PublicAPI class OrnsteinUhlenbeckNoise(GaussianNoise): """An exploration that adds Ornstein-Uhlenbeck noise to continuous actions. diff --git a/rllib/utils/exploration/parameter_noise.py b/rllib/utils/exploration/parameter_noise.py index c7df3f018cdd..332d94327799 100644 --- a/rllib/utils/exploration/parameter_noise.py +++ b/rllib/utils/exploration/parameter_noise.py @@ -2,6 +2,7 @@ import numpy as np from typing import Optional, TYPE_CHECKING, Union +from ray.rllib.utils.annotations import PublicAPI from ray.rllib.env.base_env import BaseEnv from ray.rllib.models.action_dist import ActionDistribution from ray.rllib.models.modelv2 import ModelV2 @@ -25,6 +26,7 @@ torch, _ = try_import_torch() +@PublicAPI class ParameterNoise(Exploration): """An exploration that changes a Model's parameters. diff --git a/rllib/utils/exploration/per_worker_epsilon_greedy.py b/rllib/utils/exploration/per_worker_epsilon_greedy.py index 3873deb88143..e40672d2778a 100644 --- a/rllib/utils/exploration/per_worker_epsilon_greedy.py +++ b/rllib/utils/exploration/per_worker_epsilon_greedy.py @@ -1,10 +1,12 @@ from gym.spaces import Space from typing import Optional +from ray.rllib.utils.annotations import PublicAPI from ray.rllib.utils.exploration.epsilon_greedy import EpsilonGreedy from ray.rllib.utils.schedules import ConstantSchedule +@PublicAPI class PerWorkerEpsilonGreedy(EpsilonGreedy): """A per-worker epsilon-greedy class for distributed algorithms. diff --git a/rllib/utils/exploration/per_worker_gaussian_noise.py b/rllib/utils/exploration/per_worker_gaussian_noise.py index 1ff1d1801f0c..06913373be23 100644 --- a/rllib/utils/exploration/per_worker_gaussian_noise.py +++ b/rllib/utils/exploration/per_worker_gaussian_noise.py @@ -1,10 +1,12 @@ from gym.spaces import Space from typing import Optional +from ray.rllib.utils.annotations import PublicAPI from ray.rllib.utils.exploration.gaussian_noise import GaussianNoise from ray.rllib.utils.schedules import ConstantSchedule +@PublicAPI class PerWorkerGaussianNoise(GaussianNoise): """A per-worker Gaussian noise class for distributed algorithms. diff --git a/rllib/utils/exploration/per_worker_ornstein_uhlenbeck_noise.py b/rllib/utils/exploration/per_worker_ornstein_uhlenbeck_noise.py index 13e2b2bbd6b1..8ce537d5f0ce 100644 --- a/rllib/utils/exploration/per_worker_ornstein_uhlenbeck_noise.py +++ b/rllib/utils/exploration/per_worker_ornstein_uhlenbeck_noise.py @@ -1,10 +1,12 @@ from gym.spaces import Space from typing import Optional +from ray.rllib.utils.annotations import PublicAPI from ray.rllib.utils.exploration.ornstein_uhlenbeck_noise import OrnsteinUhlenbeckNoise from ray.rllib.utils.schedules import ConstantSchedule +@PublicAPI class PerWorkerOrnsteinUhlenbeckNoise(OrnsteinUhlenbeckNoise): """A per-worker Ornstein Uhlenbeck noise class for distributed algorithms. diff --git a/rllib/utils/exploration/random.py b/rllib/utils/exploration/random.py index 0ad2aa703302..91782b861303 100644 --- a/rllib/utils/exploration/random.py +++ b/rllib/utils/exploration/random.py @@ -3,6 +3,7 @@ import tree # pip install dm_tree from typing import Union, Optional +from ray.rllib.utils.annotations import PublicAPI from ray.rllib.models.action_dist import ActionDistribution from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.utils.annotations import override @@ -17,6 +18,7 @@ torch, _ = try_import_torch() +@PublicAPI class Random(Exploration): """A random action selector (deterministic/greedy for explore=False). diff --git a/rllib/utils/exploration/random_encoder.py b/rllib/utils/exploration/random_encoder.py index 3a16607cab6d..7a5a3401f757 100644 --- a/rllib/utils/exploration/random_encoder.py +++ b/rllib/utils/exploration/random_encoder.py @@ -2,6 +2,7 @@ import numpy as np from typing import List, Optional, Union +from ray.rllib.utils.annotations import PublicAPI from ray.rllib.models.action_dist import ActionDistribution from ray.rllib.models.catalog import ModelCatalog from ray.rllib.models.modelv2 import ModelV2 @@ -16,7 +17,7 @@ tf1, tf, tfv = try_import_tf() -class MovingMeanStd: +class _MovingMeanStd: """Track moving mean, std and count.""" def __init__(self, epsilon: float = 1e-4, shape: Optional[List[int]] = None): @@ -78,6 +79,7 @@ def std(self) -> float: return np.sqrt(self.var) +@PublicAPI def update_beta(beta_schedule: str, beta: float, rho: float, step: int) -> float: """Update beta based on schedule and training step. @@ -95,6 +97,7 @@ def update_beta(beta_schedule: str, beta: float, rho: float, step: int) -> float return beta +@PublicAPI def compute_states_entropy( obs_embeds: np.ndarray, embed_dim: int, k_nn: int ) -> np.ndarray: @@ -114,6 +117,7 @@ def compute_states_entropy( return dist.argsort(axis=-1)[:, :k_nn][:, -1] +@PublicAPI class RE3(Exploration): """Random Encoder for Efficient Exploration. diff --git a/rllib/utils/exploration/slate_epsilon_greedy.py b/rllib/utils/exploration/slate_epsilon_greedy.py index 35b2e1ed3a1a..94cbea7a2b10 100644 --- a/rllib/utils/exploration/slate_epsilon_greedy.py +++ b/rllib/utils/exploration/slate_epsilon_greedy.py @@ -1,5 +1,6 @@ from typing import Union +from ray.rllib.utils.annotations import PublicAPI from ray.rllib.models.action_dist import ActionDistribution from ray.rllib.utils.annotations import override from ray.rllib.utils.exploration.epsilon_greedy import EpsilonGreedy @@ -10,6 +11,7 @@ torch, _ = try_import_torch() +@PublicAPI class SlateEpsilonGreedy(EpsilonGreedy): @override(EpsilonGreedy) def _get_tf_exploration_action_op( diff --git a/rllib/utils/exploration/slate_soft_q.py b/rllib/utils/exploration/slate_soft_q.py index 3fed8157d950..4d43ebee70dc 100644 --- a/rllib/utils/exploration/slate_soft_q.py +++ b/rllib/utils/exploration/slate_soft_q.py @@ -1,5 +1,6 @@ from typing import Union +from ray.rllib.utils.annotations import PublicAPI from ray.rllib.models.action_dist import ActionDistribution from ray.rllib.utils.annotations import override from ray.rllib.utils.exploration.exploration import TensorType @@ -10,6 +11,7 @@ torch, _ = try_import_torch() +@PublicAPI class SlateSoftQ(SoftQ): @override(SoftQ) def get_exploration_action( diff --git a/rllib/utils/exploration/soft_q.py b/rllib/utils/exploration/soft_q.py index 4ed480a641da..3aba619346e2 100644 --- a/rllib/utils/exploration/soft_q.py +++ b/rllib/utils/exploration/soft_q.py @@ -1,6 +1,7 @@ from gym.spaces import Discrete, MultiDiscrete, Space from typing import Union, Optional +from ray.rllib.utils.annotations import PublicAPI from ray.rllib.models.action_dist import ActionDistribution from ray.rllib.models.tf.tf_action_dist import Categorical from ray.rllib.models.torch.torch_action_dist import TorchCategorical @@ -9,6 +10,7 @@ from ray.rllib.utils.framework import TensorType +@PublicAPI class SoftQ(StochasticSampling): """Special case of StochasticSampling w/ Categorical and temperature param. diff --git a/rllib/utils/exploration/stochastic_sampling.py b/rllib/utils/exploration/stochastic_sampling.py index c46116fcdf46..704779f8f8da 100644 --- a/rllib/utils/exploration/stochastic_sampling.py +++ b/rllib/utils/exploration/stochastic_sampling.py @@ -3,6 +3,7 @@ import numpy as np from typing import Optional, Union +from ray.rllib.utils.annotations import PublicAPI from ray.rllib.models.action_dist import ActionDistribution from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.utils.annotations import override @@ -20,6 +21,7 @@ torch, _ = try_import_torch() +@PublicAPI class StochasticSampling(Exploration): """An exploration that simply samples from a distribution. diff --git a/rllib/utils/exploration/thompson_sampling.py b/rllib/utils/exploration/thompson_sampling.py index ded1efac2c54..8dbf5d373299 100644 --- a/rllib/utils/exploration/thompson_sampling.py +++ b/rllib/utils/exploration/thompson_sampling.py @@ -1,5 +1,6 @@ from typing import Union +from ray.rllib.utils.annotations import PublicAPI from ray.rllib.models.action_dist import ActionDistribution from ray.rllib.utils.annotations import override from ray.rllib.utils.exploration.exploration import Exploration @@ -11,6 +12,7 @@ tf1, tf, tfv = try_import_tf() +@PublicAPI class ThompsonSampling(Exploration): @override(Exploration) def get_exploration_action( diff --git a/rllib/utils/exploration/upper_confidence_bound.py b/rllib/utils/exploration/upper_confidence_bound.py index 919561fc7b04..68cbdd2e84de 100644 --- a/rllib/utils/exploration/upper_confidence_bound.py +++ b/rllib/utils/exploration/upper_confidence_bound.py @@ -1,5 +1,6 @@ from typing import Union +from ray.rllib.utils.annotations import PublicAPI from ray.rllib.models.action_dist import ActionDistribution from ray.rllib.utils.annotations import override from ray.rllib.utils.exploration.exploration import Exploration @@ -11,6 +12,7 @@ tf1, tf, tfv = try_import_tf() +@PublicAPI class UpperConfidenceBound(Exploration): @override(Exploration) def get_exploration_action( diff --git a/rllib/utils/filter.py b/rllib/utils/filter.py index 1a9f590c86e2..630440b2bfea 100644 --- a/rllib/utils/filter.py +++ b/rllib/utils/filter.py @@ -3,6 +3,7 @@ import threading import tree # pip install dm_tree +from ray.rllib.utils.annotations import DeveloperAPI from ray.rllib.utils.numpy import SMALL_NUMBER from ray.rllib.utils.deprecation import Deprecated from ray.rllib.utils.typing import TensorStructType @@ -10,6 +11,7 @@ logger = logging.getLogger(__name__) +@DeveloperAPI class Filter: """Processes input, possibly statefully.""" @@ -41,6 +43,7 @@ def clear_buffer(self): return self.reset_buffer() +@DeveloperAPI class NoFilter(Filter): is_concurrent = True @@ -71,6 +74,7 @@ def as_serializable(self) -> "NoFilter": # http://www.johndcook.com/blog/standard_deviation/ +@DeveloperAPI class RunningStat: def __init__(self, shape=None): self._n = 0 @@ -143,6 +147,7 @@ def shape(self): return self._M.shape +@DeveloperAPI class MeanStdFilter(Filter): """Keeps track of a running mean for seen states""" @@ -287,6 +292,7 @@ def __repr__(self) -> str: ) +@DeveloperAPI class ConcurrentMeanStdFilter(MeanStdFilter): is_concurrent = True @@ -321,6 +327,7 @@ def __repr__(self) -> str: ) +@DeveloperAPI def get_filter(filter_config, shape): # TODO(rliaw): move this into filter manager if filter_config == "MeanStdFilter": diff --git a/rllib/utils/framework.py b/rllib/utils/framework.py index ece32c3721ee..16a9995f9cd5 100644 --- a/rllib/utils/framework.py +++ b/rllib/utils/framework.py @@ -4,12 +4,14 @@ import sys from typing import Any, Optional +from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI from ray.rllib.utils.deprecation import Deprecated from ray.rllib.utils.typing import TensorShape, TensorType logger = logging.getLogger(__name__) +@PublicAPI def try_import_jax(error: bool = False): """Tries importing JAX and FLAX and returns both modules (or Nones). @@ -41,6 +43,7 @@ def try_import_jax(error: bool = False): return jax, flax +@PublicAPI def try_import_tf(error: bool = False): """Tries importing tf and returns the module (or None). @@ -105,6 +108,7 @@ def try_import_tf(error: bool = False): return tf1_module, tf_module, version +@DeveloperAPI def tf_function(tf_module): """Conditional decorator for @tf.function. @@ -121,6 +125,7 @@ def decorator(func): return decorator +@PublicAPI def try_import_tfp(error: bool = False): """Tries importing tfp and returns the module (or None). @@ -148,19 +153,20 @@ def try_import_tfp(error: bool = False): # Fake module for torch.nn. -class NNStub: +class _NNStub: def __init__(self, *a, **kw): # Fake nn.functional module within torch.nn. self.functional = None - self.Module = ModuleStub + self.Module = _ModuleStub # Fake class for torch.nn.Module to allow it to be inherited from. -class ModuleStub: +class _ModuleStub: def __init__(self, *a, **kw): raise ImportError("Could not import `torch`.") +@PublicAPI def try_import_torch(error: bool = False): """Tries importing torch and returns the module (or None). @@ -193,10 +199,11 @@ def try_import_torch(error: bool = False): def _torch_stubs(): - nn = NNStub() + nn = _NNStub() return None, nn +@DeveloperAPI def get_variable( value: Any, framework: str = "tf", diff --git a/rllib/utils/from_config.py b/rllib/utils/from_config.py index f1ea842de0cc..872a318a7824 100644 --- a/rllib/utils/from_config.py +++ b/rllib/utils/from_config.py @@ -6,9 +6,11 @@ import re import yaml +from ray.rllib.utils.annotations import DeveloperAPI from ray.rllib.utils import force_list, merge_dicts +@DeveloperAPI def from_config(cls, config=None, **kwargs): """Uses the given config to create an object. @@ -117,7 +119,7 @@ def from_config(cls, config=None, **kwargs): constructor = cls # Try the __type_registry__ of this class. else: - constructor = lookup_type(cls, type_) + constructor = _lookup_type(cls, type_) # Found in cls.__type_registry__. if constructor is not None: @@ -208,6 +210,7 @@ def from_config(cls, config=None, **kwargs): return object_ +@DeveloperAPI def from_file(cls, filename, *args, **kwargs): """ Create object from config saved in filename. Expects json or yaml file. @@ -233,7 +236,7 @@ def from_file(cls, filename, *args, **kwargs): return from_config(cls, config=config, **kwargs) -def lookup_type(cls, type_): +def _lookup_type(cls, type_): if ( cls is not None and hasattr(cls, "__type_registry__") diff --git a/rllib/utils/images.py b/rllib/utils/images.py index 30337e67c2e4..91e6bc610843 100644 --- a/rllib/utils/images.py +++ b/rllib/utils/images.py @@ -2,6 +2,8 @@ import numpy as np +from ray.rllib.utils.annotations import DeveloperAPI + logger = logging.getLogger(__name__) try: @@ -22,18 +24,21 @@ raise ModuleNotFoundError("Either scikit-image or opencv is required") +@DeveloperAPI def resize(img: np.ndarray, height: int, width: int) -> np.ndarray: if cv2: return cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) return transform.resize(img, (height, width)) +@DeveloperAPI def rgb2gray(img: np.ndarray) -> np.ndarray: if cv2: return cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) return color.rgb2gray(img) +@DeveloperAPI def imread(img_file: str) -> np.ndarray: if cv2: return cv2.imread(img_file).astype(np.float32) diff --git a/rllib/utils/metrics/learner_info.py b/rllib/utils/metrics/learner_info.py index d94c4c5c859e..6f96d6956ef6 100644 --- a/rllib/utils/metrics/learner_info.py +++ b/rllib/utils/metrics/learner_info.py @@ -3,6 +3,7 @@ import tree # pip install dm_tree from typing import Dict +from ray.rllib.utils.annotations import DeveloperAPI from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.utils.typing import PolicyID @@ -13,6 +14,7 @@ LEARNER_STATS_KEY = "learner_stats" +@DeveloperAPI class LearnerInfoBuilder: def __init__(self, num_devices: int = 1): self.num_devices = num_devices @@ -43,7 +45,7 @@ def add_learn_on_batch_results( else: self.results_all_towers[policy_id].append( tree.map_structure_with_path( - lambda p, *s: all_tower_reduce(p, *s), + lambda p, *s: _all_tower_reduce(p, *s), *( results.pop("tower_{}".format(tower_num)) for tower_num in range(self.num_devices) @@ -82,13 +84,13 @@ def finalize(self): # Reduce mean across all minibatch SGD steps (axis=0 to keep # all shapes as-is). info[policy_id] = tree.map_structure_with_path( - all_tower_reduce, *results_all_towers + _all_tower_reduce, *results_all_towers ) return info -def all_tower_reduce(path, *tower_data): +def _all_tower_reduce(path, *tower_data): """Reduces stats across towers based on their stats-dict paths.""" # TD-errors: Need to stay per batch item in order to be able to update # each item's weight in a prioritized replay buffer. diff --git a/rllib/utils/numpy.py b/rllib/utils/numpy.py index b7b297419ee0..2ca83356b63c 100644 --- a/rllib/utils/numpy.py +++ b/rllib/utils/numpy.py @@ -5,6 +5,7 @@ from types import MappingProxyType from typing import List, Optional +from ray.rllib.utils.annotations import PublicAPI from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.typing import SpaceStruct, TensorType, TensorStructType, Union @@ -22,6 +23,7 @@ MAX_LOG_NN_OUTPUT = 2 +@PublicAPI def aligned_array(size: int, dtype, align: int = 64) -> np.ndarray: """Returns an array of a given size that is 64-byte aligned. @@ -51,6 +53,7 @@ def aligned_array(size: int, dtype, align: int = 64) -> np.ndarray: return output +@PublicAPI def concat_aligned( items: List[np.ndarray], time_major: Optional[bool] = None ) -> np.ndarray: @@ -105,6 +108,7 @@ def concat_aligned( return np.concatenate(items, axis=1 if time_major else 0) +@PublicAPI def convert_to_numpy( x: TensorStructType, reduce_type: bool = True, reduce_floats=DEPRECATED_VALUE ): @@ -153,6 +157,7 @@ def mapping(item): return tree.map_structure(mapping, x) +@PublicAPI def fc( x: np.ndarray, weights: np.ndarray, @@ -194,6 +199,7 @@ def map_(data, transpose=False): return np.matmul(x, weights) + (0.0 if biases is None else biases) +@PublicAPI def flatten_inputs_to_1d_tensor( inputs: TensorStructType, spaces_struct: Optional[SpaceStruct] = None, @@ -302,6 +308,7 @@ def flatten_inputs_to_1d_tensor( return merged +@PublicAPI def make_action_immutable(obj): """Flags actions immutable to notify users when trying to change them. @@ -338,6 +345,7 @@ def make_action_immutable(obj): return obj +@PublicAPI def huber_loss(x: np.ndarray, delta: float = 1.0) -> np.ndarray: """Reference: https://en.wikipedia.org/wiki/Huber_loss.""" return np.where( @@ -345,6 +353,7 @@ def huber_loss(x: np.ndarray, delta: float = 1.0) -> np.ndarray: ) +@PublicAPI def l2_loss(x: np.ndarray) -> np.ndarray: """Computes half the L2 norm of a tensor (w/o the sqrt): sum(x**2) / 2. @@ -357,6 +366,7 @@ def l2_loss(x: np.ndarray) -> np.ndarray: return np.sum(np.square(x)) / 2.0 +@PublicAPI def lstm( x, weights: np.ndarray, @@ -426,6 +436,7 @@ def lstm( return unrolled_outputs, (c_states, h_states) +@PublicAPI def one_hot( x: Union[TensorType, int], depth: int = 0, @@ -488,6 +499,7 @@ def one_hot( return out +@PublicAPI def relu(x: np.ndarray, alpha: float = 0.0) -> np.ndarray: """Implementation of the leaky ReLU function. @@ -503,6 +515,7 @@ def relu(x: np.ndarray, alpha: float = 0.0) -> np.ndarray: return np.maximum(x, x * alpha, x) +@PublicAPI def sigmoid(x: np.ndarray, derivative: bool = False) -> np.ndarray: """ Returns the sigmoid function applied to x. @@ -522,6 +535,7 @@ def sigmoid(x: np.ndarray, derivative: bool = False) -> np.ndarray: return 1 / (1 + np.exp(-x)) +@PublicAPI def softmax( x: Union[np.ndarray, list], axis: int = -1, epsilon: Optional[float] = None ) -> np.ndarray: diff --git a/rllib/utils/pre_checks/env.py b/rllib/utils/pre_checks/env.py index 060cdd875b5d..b054c978cd97 100644 --- a/rllib/utils/pre_checks/env.py +++ b/rllib/utils/pre_checks/env.py @@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Set from ray.actor import ActorHandle +from ray.rllib.utils.annotations import DeveloperAPI from ray.rllib.utils.spaces.space_utils import convert_element_to_space_type from ray.rllib.utils.typing import EnvType from ray.util import log_once @@ -17,6 +18,7 @@ logger = logging.getLogger(__name__) +@DeveloperAPI def check_env(env: EnvType) -> None: """Run pre-checks on env that uncover common errors in environments. @@ -87,6 +89,7 @@ def check_env(env: EnvType) -> None: ) +@DeveloperAPI def check_gym_environments(env: gym.Env) -> None: """Checking for common errors in gym environments. @@ -195,6 +198,7 @@ def get_type(var): _check_info(info) +@DeveloperAPI def check_multiagent_environments(env: "MultiAgentEnv") -> None: """Checking for common errors in RLlib MultiAgentEnvs. @@ -283,6 +287,7 @@ def check_multiagent_environments(env: "MultiAgentEnv") -> None: raise ValueError(error) +@DeveloperAPI def check_base_env(env: "BaseEnv") -> None: """Checking for common errors in RLlib BaseEnvs. diff --git a/rllib/utils/pre_checks/multi_agent.py b/rllib/utils/pre_checks/multi_agent.py index 4871dc6ec0a0..e57fb1538060 100644 --- a/rllib/utils/pre_checks/multi_agent.py +++ b/rllib/utils/pre_checks/multi_agent.py @@ -3,12 +3,14 @@ from ray.rllib.policy.policy import PolicySpec from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID +from ray.rllib.utils.annotations import DeveloperAPI from ray.rllib.utils.from_config import from_config from ray.rllib.utils.typing import MultiAgentPolicyConfigDict, PartialTrainerConfigDict logger = logging.getLogger(__name__) +@DeveloperAPI def check_multi_agent( config: PartialTrainerConfigDict, ) -> Tuple[MultiAgentPolicyConfigDict, bool]: diff --git a/rllib/utils/replay_buffers/multi_agent_mixin_replay_buffer.py b/rllib/utils/replay_buffers/multi_agent_mixin_replay_buffer.py index b67c7235945d..6194bf50a49f 100644 --- a/rllib/utils/replay_buffers/multi_agent_mixin_replay_buffer.py +++ b/rllib/utils/replay_buffers/multi_agent_mixin_replay_buffer.py @@ -13,7 +13,6 @@ from ray.rllib.utils.replay_buffers.multi_agent_prioritized_replay_buffer import ( MultiAgentPrioritizedReplayBuffer, ) -from ray.rllib.policy.rnn_sequencing import timeslice_along_seq_lens_with_overlap from ray.rllib.utils.replay_buffers.replay_buffer import StorageUnit from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import ( merge_dicts_with_warning, @@ -47,16 +46,16 @@ class MultiAgentMixInReplayBuffer(MultiAgentPrioritizedReplayBuffer): ... replay_ratio=0.66) >>> buffer.add() >>> buffer.add() - >>> buffer.replay() + >>> buffer.sample(1) ... [, , ] >>> buffer.add() - >>> buffer.sample() + >>> buffer.sample(1) ... [, , ] >>> # or: [, , ], [, , ] or [, , ], >>> # but always as it is the newest sample >>> buffer.add() - >>> buffer.sample() + >>> buffer.sample(1) ... [, , ] >>> # or: [, , ], [, , ] or [, , ], etc.. >>> # but always as it is the newest sample @@ -80,7 +79,6 @@ def __init__( prioritized_replay_beta: float = 0.4, prioritized_replay_eps: float = 1e-6, learning_starts: int = 1000, - replay_batch_size: int = 1, replay_sequence_length: int = 1, replay_burn_in: int = 0, replay_zero_init_states: bool = True, @@ -100,13 +98,7 @@ def __init__( learning_starts: Number of timesteps after which a call to `replay()` will yield samples (before that, `replay()` will return None). - capacity: The capacity of the buffer. Note that when - `replay_sequence_length` > 1, this is the number of sequences - (not single timesteps) stored. - replay_batch_size: The batch size to be sampled (in timesteps). - Note that if `replay_sequence_length` > 1, - `self.replay_batch_size` will be set to the number of - sequences sampled (B). + capacity: The capacity of the buffer, measured in `storage_unit`. replay_sequence_length: The sequence length (T) of a single sample. If > 1, we will sample B x T from this buffer. replay_burn_in: The burn-in length in case @@ -157,7 +149,6 @@ def __init__( num_shards=num_shards, replay_mode="independent", learning_starts=learning_starts, - replay_batch_size=replay_batch_size, replay_sequence_length=replay_sequence_length, replay_burn_in=replay_burn_in, replay_zero_init_states=replay_zero_init_states, @@ -195,15 +186,7 @@ def add(self, batch: SampleBatchType, **kwargs) -> None: with self.add_batch_timer: if self._storage_unit == StorageUnit.TIMESTEPS: for policy_id, sample_batch in batch.policy_batches.items(): - if self.replay_sequence_length == 1: - timeslices = sample_batch.timeslices(1) - else: - timeslices = timeslice_along_seq_lens_with_overlap( - sample_batch=sample_batch, - zero_pad_max_seq_len=self.replay_sequence_length, - pre_overlap=self.replay_burn_in, - zero_init_states=self.replay_zero_init_states, - ) + timeslices = sample_batch.timeslices(1) for time_slice in timeslices: self.replay_buffers[policy_id].add(time_slice, **kwargs) self.last_added_batches[policy_id].append(time_slice) @@ -269,6 +252,9 @@ def sample( # Merge kwargs, overwriting standard call arguments kwargs = merge_dicts_with_warning(self.underlying_buffer_call_args, kwargs) + if self._num_added < self.replay_starts: + return MultiAgentBatch({}, 0) + def mix_batches(_policy_id): """Mixes old with new samples. diff --git a/rllib/utils/replay_buffers/multi_agent_prioritized_replay_buffer.py b/rllib/utils/replay_buffers/multi_agent_prioritized_replay_buffer.py index 24c06553c524..924f38c55da6 100644 --- a/rllib/utils/replay_buffers/multi_agent_prioritized_replay_buffer.py +++ b/rllib/utils/replay_buffers/multi_agent_prioritized_replay_buffer.py @@ -36,7 +36,6 @@ def __init__( capacity: int = 10000, storage_unit: str = "timesteps", num_shards: int = 1, - replay_batch_size: int = 1, learning_starts: int = 1000, replay_mode: str = "independent", replay_sequence_length: int = 1, @@ -61,13 +60,7 @@ def __init__( learning_starts: Number of timesteps after which a call to `replay()` will yield samples (before that, `replay()` will return None). - capacity: The capacity of the buffer. Note that when - `replay_sequence_length` > 1, this is the number of sequences - (not single timesteps) stored. - replay_batch_size: The batch size to be sampled (in timesteps). - Note that if `replay_sequence_length` > 1, - `self.replay_batch_size` will be set to the number of - sequences sampled (B). + capacity: The capacity of the buffer, measured in `storage_unit`. prioritized_replay_alpha: Alpha parameter for a prioritized replay buffer. Use 0.0 for no prioritization. prioritized_replay_beta: Beta parameter for a prioritized @@ -132,7 +125,6 @@ def __init__( storage_unit, **kwargs, underlying_buffer_config=prioritized_replay_buffer_config, - replay_batch_size=replay_batch_size, learning_starts=learning_starts, replay_mode=replay_mode, replay_sequence_length=replay_sequence_length, @@ -164,49 +156,49 @@ def _add_to_underlying_buffer( # simply store the samples how they arrive. For sequences and # episodes, the underlying buffer may split them itself. if self._storage_unit is StorageUnit.TIMESTEPS: - if self.replay_sequence_length == 1: - timeslices = batch.timeslices(1) - else: - timeslices = timeslice_along_seq_lens_with_overlap( - sample_batch=batch, - zero_pad_max_seq_len=self.replay_sequence_length, - pre_overlap=self.replay_burn_in, - zero_init_states=self.replay_zero_init_states, - ) - for time_slice in timeslices: - # If SampleBatch has prio-replay weights, average - # over these to use as a weight for the entire - # sequence. - if self.replay_mode is ReplayMode.INDEPENDENT: - if "weights" in time_slice and len(time_slice["weights"]): - weight = np.mean(time_slice["weights"]) - else: - weight = None - - if "weight" in kwargs and weight is not None: - if log_once("overwrite_weight"): - logger.warning( - "Adding batches with column " - "`weights` to this buffer while " - "providing weights as a call argument " - "to the add method results in the " - "column being overwritten." - ) - - kwargs = {"weight": weight, **kwargs} - else: - if "weight" in kwargs: - if log_once("lockstep_no_weight_allowed"): - logger.warning( - "Settings weights for batches in " - "lockstep mode is not allowed." - "Weights are being ignored." - ) - - kwargs = {**kwargs, "weight": None} - self.replay_buffers[policy_id].add(time_slice, **kwargs) + timeslices = batch.timeslices(1) + elif self._storage_unit is StorageUnit.SEQUENCES: + timeslices = timeslice_along_seq_lens_with_overlap( + sample_batch=batch, + zero_pad_max_seq_len=self.replay_sequence_length, + pre_overlap=self.replay_burn_in, + zero_init_states=self.replay_zero_init_states, + ) else: - self.replay_buffers[policy_id].add(batch, **kwargs) + timeslices = [batch] + + for time_slice in timeslices: + # If SampleBatch has prio-replay weights, average + # over these to use as a weight for the entire + # sequence. + if self.replay_mode is ReplayMode.INDEPENDENT: + if "weights" in time_slice and len(time_slice["weights"]): + weight = np.mean(time_slice["weights"]) + else: + weight = None + + if "weight" in kwargs and weight is not None: + if log_once("overwrite_weight"): + logger.warning( + "Adding batches with column " + "`weights` to this buffer while " + "providing weights as a call argument " + "to the add method results in the " + "column being overwritten." + ) + + kwargs = {"weight": weight, **kwargs} + else: + if "weight" in kwargs: + if log_once("lockstep_no_weight_allowed"): + logger.warning( + "Settings weights for batches in " + "lockstep mode is not allowed." + "Weights are being ignored." + ) + + kwargs = {**kwargs, "weight": None} + self.replay_buffers[policy_id].add(time_slice, **kwargs) @DeveloperAPI @override(PrioritizedReplayBuffer) diff --git a/rllib/utils/replay_buffers/multi_agent_replay_buffer.py b/rllib/utils/replay_buffers/multi_agent_replay_buffer.py index bed61ee245f4..af0dbff7cbb9 100644 --- a/rllib/utils/replay_buffers/multi_agent_replay_buffer.py +++ b/rllib/utils/replay_buffers/multi_agent_replay_buffer.py @@ -64,7 +64,6 @@ def __init__( capacity: int = 10000, storage_unit: str = "timesteps", num_shards: int = 1, - replay_batch_size: int = 1, learning_starts: int = 1000, replay_mode: str = "independent", replay_sequence_length: int = 1, @@ -84,25 +83,18 @@ def __init__( learning_starts: Number of timesteps after which a call to `sample()` will yield samples (before that, `sample()` will return None). - capacity: Max number of total timesteps in all policy buffers. - After reaching this number, older samples will be - dropped to make space for new ones. - replay_batch_size: The batch size to be sampled (in timesteps). - Note that if `replay_sequence_length` > 1, - `self.replay_batch_size` will be set to the number of - sequences sampled (B). + capacity: The capacity of the buffer, measured in `storage_unit`. replay_mode: One of "independent" or "lockstep". Determines, whether batches are sampled independently or to an equal amount. replay_sequence_length: The sequence length (T) of a single sample. If > 1, we will sample B x T from this buffer. This only has an effect if storage_unit is 'timesteps'. - replay_burn_in: The burn-in length in case - `replay_sequence_length` > 0. This is the number of timesteps + replay_burn_in: This is the number of timesteps each sequence overlaps with the previous one to generate a better internal state (=state after the burn-in), instead of starting from 0.0 each RNN rollout. This only has an effect - if storage_unit is 'timesteps'. + if storage_unit is `sequences`. replay_zero_init_states: Whether the initial states in the buffer (if replay_sequence_length > 0) are alwayas 0.0 or should be updated with the previous train_batch state outputs. @@ -122,25 +114,34 @@ def __init__( else: self.underlying_buffer_call_args = {} - if replay_sequence_length > 1 and self._storage_unit == "timesteps": - self.replay_batch_size = int( - max(1, replay_batch_size // replay_sequence_length) - ) - logger.info( - "Since replay_sequence_length={} and replay_batch_size={}, " - "we will replay {} sequences at a time.".format( - replay_sequence_length, replay_batch_size, self.replay_batch_size - ) - ) - else: - self.replay_batch_size = replay_batch_size - self.replay_starts = learning_starts // num_shards self.replay_mode = replay_mode self.replay_sequence_length = replay_sequence_length self.replay_burn_in = replay_burn_in self.replay_zero_init_states = replay_zero_init_states + if ( + replay_sequence_length > 1 + and self._storage_unit is not StorageUnit.SEQUENCES + ): + logger.warning( + "MultiAgentReplayBuffer configured with " + "`replay_sequence_length={}`, but `storage_unit={}`. " + "replay_sequence_length will be ignored and set to 1.".format( + replay_sequence_length, storage_unit + ) + ) + self.replay_sequence_length = 1 + + if replay_sequence_length == 1 and self._storage_unit is StorageUnit.SEQUENCES: + logger.warning( + "MultiAgentReplayBuffer configured with " + "`replay_sequence_length={}`, but `storage_unit={}`. " + "This will result in sequences equal to timesteps.".format( + replay_sequence_length, storage_unit + ) + ) + if replay_mode in ["lockstep", ReplayMode.LOCKSTEP]: self.replay_mode = ReplayMode.LOCKSTEP if self._storage_unit in [StorageUnit.EPISODES, StorageUnit.SEQUENCES]: @@ -156,7 +157,7 @@ def __init__( if self.underlying_buffer_config: ctor_args = { - **{"capacity": shard_capacity, "storage_unit": storage_unit}, + **{"capacity": shard_capacity, "storage_unit": StorageUnit.FRAGMENTS}, **self.underlying_buffer_config, } @@ -169,7 +170,7 @@ def new_buffer(): self.underlying_buffer_call_args = {} return ReplayBuffer( self.capacity, - storage_unit=storage_unit, + storage_unit=StorageUnit.FRAGMENTS, ) self.replay_buffers = collections.defaultdict(new_buffer) @@ -184,12 +185,14 @@ def __len__(self) -> int: return sum(len(buffer._storage) for buffer in self.replay_buffers.values()) @DeveloperAPI - @Deprecated(old="replay", new="sample", error=False) + @Deprecated( + old="ReplayBuffer.replay()", + new="ReplayBuffer.sample(num_items)", + error=True, + ) def replay(self, num_items: int = None, **kwargs) -> Optional[SampleBatchType]: """Deprecated in favor of new ReplayBuffer API.""" - if num_items is None: - num_items = self.replay_batch_size - return self.sample(num_items, **kwargs) + pass @DeveloperAPI @override(ReplayBuffer) @@ -252,6 +255,10 @@ def _add_to_underlying_buffer( # simply store the samples how they arrive. For sequences and # episodes, the underlying buffer may split them itself. if self._storage_unit is StorageUnit.TIMESTEPS: + timeslices = batch.timeslices(1) + for time_slice in timeslices: + self.replay_buffers[policy_id].add(time_slice, **kwargs) + elif self._storage_unit is StorageUnit.SEQUENCES: if self.replay_sequence_length == 1: timeslices = batch.timeslices(1) else: @@ -263,8 +270,10 @@ def _add_to_underlying_buffer( ) for time_slice in timeslices: self.replay_buffers[policy_id].add(time_slice, **kwargs) - else: + elif self._storage_unit in [StorageUnit.FRAGMENTS, StorageUnit.EPISODES]: self.replay_buffers[policy_id].add(batch, **kwargs) + else: + raise ValueError("Unknown `storage_unit={}`".format(self._storage_unit)) @DeveloperAPI @override(ReplayBuffer) diff --git a/rllib/utils/replay_buffers/replay_buffer.py b/rllib/utils/replay_buffers/replay_buffer.py index 44c62950d266..5f370024e097 100644 --- a/rllib/utils/replay_buffers/replay_buffer.py +++ b/rllib/utils/replay_buffers/replay_buffer.py @@ -1,6 +1,6 @@ import logging import platform -from typing import Any, Dict, List, Optional, Callable +from typing import Any, Dict, List, Optional, Callable, Union import numpy as np import random @@ -11,7 +11,7 @@ import psutil # noqa E402 from ray.util.debug import log_once -from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch +from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.deprecation import Deprecated from ray.rllib.utils.metrics.window_stat import WindowStat from ray.rllib.utils.typing import SampleBatchType, T @@ -58,7 +58,10 @@ def warn_replay_capacity(*, item: SampleBatchType, num_items: int) -> None: @DeveloperAPI class ReplayBuffer(ParallelIteratorWorker): def __init__( - self, capacity: int = 10000, storage_unit: str = "timesteps", **kwargs + self, + capacity: int = 10000, + storage_unit: Union[str, StorageUnit] = "timesteps", + **kwargs, ): """Initializes a (FIFO) ReplayBuffer instance. @@ -82,7 +85,7 @@ def __init__( else: raise ValueError( "storage_unit must be either 'timesteps', `sequences` or `episodes` " - "or `fragments`." + "or `fragments`, but is {}".format(storage_unit) ) # The actual storage (list of SampleBatches or MultiAgentBatches). @@ -119,12 +122,6 @@ def __init__( self.batch_size = None - def gen_replay(): - while True: - yield self.replay() - - ParallelIteratorWorker.__init__(self, gen_replay, False) - def __len__(self) -> int: """Returns the number of items currently stored in this buffer.""" return len(self._storage) @@ -141,18 +138,10 @@ def add(self, batch: SampleBatchType, **kwargs) -> None: batch: Batch to add to this buffer's storage. **kwargs: Forward compatibility kwargs. """ - assert batch.count > 0, batch - warn_replay_capacity(item=batch, num_items=self.capacity / batch.count) + if not batch.count > 0: + return - if ( - type(batch) == MultiAgentBatch - and self._storage_unit != StorageUnit.TIMESTEPS - ): - raise ValueError( - "Can not add MultiAgentBatch to ReplayBuffer " - "with storage_unit {}" - "".format(str(self._storage_unit)) - ) + warn_replay_capacity(item=batch, num_items=self.capacity / batch.count) if self._storage_unit == StorageUnit.TIMESTEPS: self._add_single_batch(batch, **kwargs) @@ -352,14 +341,34 @@ def apply( """ return func(self, *_args, **kwargs) - @Deprecated(old="ReplayBuffer.add_batch()", new="RepayBuffer.add()", error=False) + @Deprecated(old="ReplayBuffer.add_batch()", new="ReplayBuffer.add()", error=False) def add_batch(self, *args, **kwargs): return self.add(*args, **kwargs) @Deprecated( - old="RepayBuffer.replay(num_items)", - new="RepayBuffer.sample(" "num_items)", + old="ReplayBuffer.replay(num_items)", + new="ReplayBuffer.sample(num_items)", error=False, ) def replay(self, num_items): return self.sample(num_items) + + @Deprecated( + help="ReplayBuffers could be iterated over by default before. " + "Making a buffer an iterator will soon " + "be deprecated altogether. Consider switching to the training " + "iteration API to resolve this.", + error=False, + ) + def make_iterator(self, num_items_to_replay: int): + """Make this buffer a ParallelIteratorWorker to retain compatibility. + + Execution plans have made heavy use of buffers as ParallelIteratorWorkers. + This method provides an easy way to support this for now. + """ + + def gen_replay(): + while True: + yield self.sample(num_items_to_replay) + + ParallelIteratorWorker.__init__(self, gen_replay, False) diff --git a/rllib/utils/replay_buffers/tests/test_multi_agent_mixin_replay_buffer.py b/rllib/utils/replay_buffers/tests/test_multi_agent_mixin_replay_buffer.py index 2576b5bbb494..d7cc6c70dc0c 100644 --- a/rllib/utils/replay_buffers/tests/test_multi_agent_mixin_replay_buffer.py +++ b/rllib/utils/replay_buffers/tests/test_multi_agent_mixin_replay_buffer.py @@ -48,7 +48,10 @@ def test_mixin_sampling_episodes(self): """Test sampling of episodes.""" # 50% replay ratio. buffer = MultiAgentMixInReplayBuffer( - capacity=self.capacity, storage_unit="episodes", replay_ratio=0.5 + capacity=self.capacity, + storage_unit="episodes", + replay_ratio=0.5, + learning_starts=0, ) # If we insert and replay n times, expect roughly return batches of @@ -69,7 +72,7 @@ def test_mixin_sampling_sequences(self): """Test sampling of sequences.""" # 50% replay ratio. buffer = MultiAgentMixInReplayBuffer( - capacity=100, storage_unit="sequences", replay_ratio=0.5 + capacity=100, storage_unit="sequences", replay_ratio=0.5, learning_starts=0 ) # If we insert and replay n times, expect roughly return batches of @@ -88,7 +91,10 @@ def test_mixin_sampling_timesteps(self): """Test different mixin ratios with timesteps.""" # 33% replay ratio. buffer = MultiAgentMixInReplayBuffer( - capacity=self.capacity, storage_unit="timesteps", replay_ratio=0.333 + capacity=self.capacity, + storage_unit="timesteps", + replay_ratio=0.333, + learning_starts=0, ) # Expect exactly 0 samples to be returned (buffer empty). sample = buffer.sample(10) @@ -120,7 +126,9 @@ def test_mixin_sampling_timesteps(self): self.assertAlmostEqual(np.mean(results), 1.5, delta=0.2) # 90% replay ratio. - buffer = MultiAgentMixInReplayBuffer(capacity=self.capacity, replay_ratio=0.9) + buffer = MultiAgentMixInReplayBuffer( + capacity=self.capacity, replay_ratio=0.9, learning_starts=0 + ) # If we insert and replay n times, expect roughly return batches of # len 10 (replay_ratio=0.9 -> 90% replayed samples -> 1 new and 9 old @@ -134,7 +142,9 @@ def test_mixin_sampling_timesteps(self): self.assertAlmostEqual(np.mean(results), 10.0, delta=0.2) # 0% replay ratio -> Only new samples. - buffer = MultiAgentMixInReplayBuffer(capacity=self.capacity, replay_ratio=0.0) + buffer = MultiAgentMixInReplayBuffer( + capacity=self.capacity, replay_ratio=0.0, learning_starts=0 + ) # Add a new batch. batch = self._generate_single_timesteps() buffer.add(batch) @@ -159,7 +169,9 @@ def test_mixin_sampling_timesteps(self): self.assertAlmostEqual(np.mean(results), 1.0, delta=0.2) # 100% replay ratio -> Only new samples. - buffer = MultiAgentMixInReplayBuffer(capacity=self.capacity, replay_ratio=1.0) + buffer = MultiAgentMixInReplayBuffer( + capacity=self.capacity, replay_ratio=1.0, learning_starts=0 + ) # Expect exactly 0 samples to be returned (buffer empty). sample = buffer.sample(1) assert len(sample.policy_batches) == 0 diff --git a/rllib/utils/replay_buffers/utils.py b/rllib/utils/replay_buffers/utils.py index 3485eed5cd0d..418e008a8095 100644 --- a/rllib/utils/replay_buffers/utils.py +++ b/rllib/utils/replay_buffers/utils.py @@ -4,7 +4,7 @@ from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils import deprecation_warning -from ray.rllib.utils.annotations import ExperimentalAPI +from ray.rllib.utils.annotations import DeveloperAPI from ray.rllib.utils.deprecation import DEPRECATED_VALUE from ray.rllib.utils.from_config import from_config from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY @@ -120,20 +120,18 @@ def sample_min_n_steps_from_buffer( train_batches = [] while train_batch_size < min_steps: batch = replay_buffer.sample(num_items=1) - if batch is None: - return None + batch_len = batch.agent_steps() if count_by_agent_steps else batch.env_steps() + if batch_len == 0: + # Replay has not started, so we can't accumulate timesteps here + return batch train_batches.append(batch) - train_batch_size += ( - train_batches[-1].agent_steps() - if count_by_agent_steps - else train_batches[-1].env_steps() - ) + train_batch_size += batch_len # All batch types are the same type, hence we can use any concat_samples() train_batch = SampleBatch.concat_samples(train_batches) return train_batch -@ExperimentalAPI +@DeveloperAPI def validate_buffer_config(config: dict): if config.get("replay_buffer_config", None) is None: config["replay_buffer_config"] = {} @@ -178,6 +176,22 @@ def validate_buffer_config(config: dict): help="config['replay_buffer_config']['replay_burn_in']", ) + replay_batch_size = config.get("replay_batch_size", DEPRECATED_VALUE) + if replay_batch_size == DEPRECATED_VALUE: + replay_batch_size = config["replay_buffer_config"].get( + "replay_batch_size", DEPRECATED_VALUE + ) + if replay_batch_size != DEPRECATED_VALUE: + deprecation_warning( + old="config['replay_batch_size'] or config['replay_buffer_config'][" + "'replay_batch_size']", + help="Specification of replay_batch_size is not supported anymore but is " + "derived from `train_batch_size`. Specify the number of " + "items you want to replay upon calling the sample() method of replay " + "buffers if this does not work for you.", + error=True, + ) + # Deprecation of old-style replay buffer args # Warnings before checking of we need local buffer so that algorithms # Without local buffer also get warned @@ -186,7 +200,6 @@ def validate_buffer_config(config: dict): "prioritized_replay_beta", "prioritized_replay_eps", "no_local_replay_buffer", - "replay_batch_size", "replay_zero_init_states", "learning_starts", "replay_buffer_shards_colocated_with_driver", @@ -242,15 +255,6 @@ def validate_buffer_config(config: dict): "ray.rllib.utils.replay_buffers." + buffer_type ) - if config["replay_buffer_config"].get("replay_batch_size", None) is None: - # Fall back to train batch size if no replay batch size was provided - logger.info( - "No value for key `replay_batch_size` in replay_buffer_config. " - "config['replay_buffer_config']['replay_batch_size'] will be " - "automatically set to config['train_batch_size']" - ) - config["replay_buffer_config"]["replay_batch_size"] = config["train_batch_size"] - # Instantiate a dummy buffer to fail early on misconfiguration and find out about # inferred buffer class dummy_buffer = from_config(buffer_type, config["replay_buffer_config"]) @@ -329,7 +333,7 @@ def patch_buffer_with_fake_sampling_method( ): fake_sample_output = SampleBatch(fake_sample_output).as_multi_agent() - def fake_sample(_: Any, __: Any = None, **kwargs) -> Optional[SampleBatchType]: + def fake_sample(_: Any = None, **kwargs) -> Optional[SampleBatchType]: """Always returns a predefined batch. Args: diff --git a/rllib/utils/schedules/constant_schedule.py b/rllib/utils/schedules/constant_schedule.py index 44f66d2888c5..9681c5167040 100644 --- a/rllib/utils/schedules/constant_schedule.py +++ b/rllib/utils/schedules/constant_schedule.py @@ -1,6 +1,6 @@ from typing import Optional -from ray.rllib.utils.annotations import override +from ray.rllib.utils.annotations import override, PublicAPI from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.schedules.schedule import Schedule from ray.rllib.utils.typing import TensorType @@ -8,6 +8,7 @@ tf1, tf, tfv = try_import_tf() +@PublicAPI class ConstantSchedule(Schedule): """A Schedule where the value remains constant over time.""" diff --git a/rllib/utils/schedules/exponential_schedule.py b/rllib/utils/schedules/exponential_schedule.py index 3adfcc916fed..0c6571d928a9 100644 --- a/rllib/utils/schedules/exponential_schedule.py +++ b/rllib/utils/schedules/exponential_schedule.py @@ -1,6 +1,6 @@ from typing import Optional -from ray.rllib.utils.annotations import override +from ray.rllib.utils.annotations import override, PublicAPI from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.schedules.schedule import Schedule from ray.rllib.utils.typing import TensorType @@ -8,6 +8,7 @@ torch, _ = try_import_torch() +@PublicAPI class ExponentialSchedule(Schedule): """Exponential decay schedule from `initial_p` to `final_p`. diff --git a/rllib/utils/schedules/linear_schedule.py b/rllib/utils/schedules/linear_schedule.py index 5cbbdb8f5ceb..df892548eac7 100644 --- a/rllib/utils/schedules/linear_schedule.py +++ b/rllib/utils/schedules/linear_schedule.py @@ -1,6 +1,8 @@ +from ray.rllib.utils.annotations import PublicAPI from ray.rllib.utils.schedules.polynomial_schedule import PolynomialSchedule +@PublicAPI class LinearSchedule(PolynomialSchedule): """Linear interpolation between `initial_p` and `final_p`. diff --git a/rllib/utils/schedules/piecewise_schedule.py b/rllib/utils/schedules/piecewise_schedule.py index fa34047614bd..21236b7b78e4 100644 --- a/rllib/utils/schedules/piecewise_schedule.py +++ b/rllib/utils/schedules/piecewise_schedule.py @@ -1,6 +1,6 @@ from typing import Callable, List, Optional, Tuple -from ray.rllib.utils.annotations import override +from ray.rllib.utils.annotations import override, PublicAPI from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.schedules.schedule import Schedule from ray.rllib.utils.typing import TensorType @@ -12,6 +12,7 @@ def _linear_interpolation(left, right, alpha): return left + alpha * (right - left) +@PublicAPI class PiecewiseSchedule(Schedule): def __init__( self, diff --git a/rllib/utils/schedules/polynomial_schedule.py b/rllib/utils/schedules/polynomial_schedule.py index 9e69dcd955f4..cd72b69cea51 100644 --- a/rllib/utils/schedules/polynomial_schedule.py +++ b/rllib/utils/schedules/polynomial_schedule.py @@ -1,6 +1,6 @@ from typing import Optional -from ray.rllib.utils.annotations import override +from ray.rllib.utils.annotations import override, PublicAPI from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.schedules.schedule import Schedule from ray.rllib.utils.typing import TensorType @@ -9,6 +9,7 @@ torch, _ = try_import_torch() +@PublicAPI class PolynomialSchedule(Schedule): """Polynomial interpolation between `initial_p` and `final_p`. diff --git a/rllib/utils/serialization.py b/rllib/utils/serialization.py new file mode 100644 index 000000000000..3160188de118 --- /dev/null +++ b/rllib/utils/serialization.py @@ -0,0 +1,163 @@ +import base64 +import gym +import io +import numpy as np +from typing import Dict +import zlib + +from ray.rllib.utils.annotations import DeveloperAPI + + +def _serialize_ndarray(array: np.ndarray) -> str: + """Pack numpy ndarray into Base64 encoded strings for serialization. + + This function uses numpy.save() instead of pickling to ensure + compatibility. + + Args: + array: numpy ndarray. + + Returns: + b64 escaped string. + """ + buf = io.BytesIO() + np.save(buf, array) + return base64.b64encode(zlib.compress(buf.getvalue())).decode("ascii") + + +def _deserialize_ndarray(b64_string: str) -> np.ndarray: + """Unpack b64 escaped string into numpy ndarray. + + This function assumes the unescaped bytes are of npy format. + + Args: + b64_string: Base64 escaped string. + + Returns: + numpy ndarray. + """ + return np.load(io.BytesIO(zlib.decompress(base64.b64decode(b64_string)))) + + +@DeveloperAPI +def gym_space_to_dict(space: gym.spaces.Space) -> Dict: + """Serialize a gym Space into JSON-serializable dict. + + Args: + space: gym.spaces.Space + + Returns: + Serialized JSON string. + """ + + def _box(sp: gym.spaces.Box) -> Dict: + return { + "space": "box", + "low": _serialize_ndarray(sp.low), + "high": _serialize_ndarray(sp.high), + "shape": sp._shape, # shape is a tuple. + "dtype": sp.dtype.str, + } + + def _discrete(sp: gym.spaces.Discrete) -> Dict: + d = { + "space": "discrete", + "n": sp.n, + } + # Offset is a relatively new Discrete space feature. + if hasattr(sp, "start"): + d["start"] = sp.start + return d + + def _multi_discrete(sp: gym.spaces.MultiDiscrete) -> Dict: + return { + "space": "multi-discrete", + "nvec": _serialize_ndarray(sp.nvec), + "dtype": sp.dtype.str, + } + + def _tuple(sp: gym.spaces.Tuple) -> Dict: + return { + "space": "tuple", + "spaces": [gym_space_to_dict(sp) for sp in sp.spaces], + } + + def _dict(sp: gym.spaces.Dict) -> Dict: + return { + "space": "dict", + "spaces": {k: gym_space_to_dict(sp) for k, sp in sp.spaces.items()}, + } + + if isinstance(space, gym.spaces.Box): + return _box(space) + elif isinstance(space, gym.spaces.Discrete): + return _discrete(space) + elif isinstance(space, gym.spaces.MultiDiscrete): + return _multi_discrete(space) + elif isinstance(space, gym.spaces.Tuple): + return _tuple(space) + elif isinstance(space, gym.spaces.Dict): + return _dict(space) + else: + raise ValueError("Unknown space type for serialization, ", type(space)) + + +@DeveloperAPI +def gym_space_from_dict(d: Dict) -> gym.spaces.Space: + """De-serialize a dict into gym Space. + + Args: + str: serialized JSON str. + + Returns: + De-serialized gym space. + """ + + def __common(d: Dict): + """Common updates to the dict before we use it to construct spaces""" + del d["space"] + if "dtype" in d: + d["dtype"] = np.dtype(d["dtype"]) + return d + + def _box(d: Dict) -> gym.spaces.Box: + d.update( + { + "low": _deserialize_ndarray(d["low"]), + "high": _deserialize_ndarray(d["high"]), + } + ) + return gym.spaces.Box(**__common(d)) + + def _discrete(d: Dict) -> gym.spaces.Discrete: + return gym.spaces.Discrete(**__common(d)) + + def _multi_discrete(d: Dict) -> gym.spaces.Discrete: + d.update( + { + "nvec": _deserialize_ndarray(d["nvec"]), + } + ) + return gym.spaces.MultiDiscrete(**__common(d)) + + def _tuple(d: Dict) -> gym.spaces.Discrete: + spaces = [gym_space_from_dict(sp) for sp in d["spaces"]] + return gym.spaces.Tuple(spaces=spaces) + + def _dict(d: Dict) -> gym.spaces.Discrete: + spaces = {k: gym_space_from_dict(sp) for k, sp in d["spaces"].items()} + return gym.spaces.Dict(spaces=spaces) + + space_map = { + "box": _box, + "discrete": _discrete, + "multi-discrete": _multi_discrete, + "tuple": _tuple, + "dict": _dict, + } + + space_type = d["space"] + if space_type not in space_map: + raise ValueError("Unknown space type for de-serialization, ", space_type) + + return space_map[space_type](d) diff --git a/rllib/utils/sgd.py b/rllib/utils/sgd.py index 41a9ac27feeb..0da6b0d1082a 100644 --- a/rllib/utils/sgd.py +++ b/rllib/utils/sgd.py @@ -4,12 +4,14 @@ import numpy as np import random +from ray.rllib.utils.annotations import DeveloperAPI from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder logger = logging.getLogger(__name__) +@DeveloperAPI def standardized(array: np.ndarray): """Normalize the values in an array. @@ -22,6 +24,7 @@ def standardized(array: np.ndarray): return (array - array.mean()) / max(1e-4, array.std()) +@DeveloperAPI def minibatches(samples: SampleBatch, sgd_minibatch_size: int, shuffle: bool = True): """Return a generator yielding minibatches from a sample batch. @@ -65,6 +68,7 @@ def minibatches(samples: SampleBatch, sgd_minibatch_size: int, shuffle: bool = T yield samples.slice(i, j, si, sj) +@DeveloperAPI def do_minibatch_sgd( samples, policies, diff --git a/rllib/utils/spaces/space_utils.py b/rllib/utils/spaces/space_utils.py index a693b9ab1770..8e33a833e437 100644 --- a/rllib/utils/spaces/space_utils.py +++ b/rllib/utils/spaces/space_utils.py @@ -6,6 +6,7 @@ from typing import Any, List, Optional, Union +@DeveloperAPI def flatten_space(space: gym.Space) -> List[gym.Space]: """Flattens a gym.Space into its primitive components. @@ -37,6 +38,7 @@ def _helper_flatten(space_, return_list): return ret +@DeveloperAPI def get_base_struct_from_space(space): """Returns a Tuple/Dict Space as native (equally structured) py tuple/dict. @@ -67,6 +69,7 @@ def _helper_struct(space_): return _helper_struct(space) +@DeveloperAPI def get_dummy_batch_for_space( space: gym.Space, batch_size: int = 32, @@ -144,6 +147,7 @@ def get_dummy_batch_for_space( ) +@DeveloperAPI def flatten_to_single_ndarray(input_): """Returns a single np.ndarray given a list/tuple of np.ndarrays. @@ -173,6 +177,7 @@ def flatten_to_single_ndarray(input_): return input_ +@DeveloperAPI def unbatch(batches_struct): """Converts input from (nested) struct of batches to batch of structs. @@ -210,6 +215,7 @@ def unbatch(batches_struct): return out +@DeveloperAPI def clip_action(action, action_space): """Clips all components in `action` according to the given Space. @@ -234,6 +240,7 @@ def map_(a, s): return tree.map_structure(map_, action, action_space) +@DeveloperAPI def unsquash_action(action, action_space_struct): """Unsquashes all components in `action` according to the given Space. @@ -277,6 +284,7 @@ def map_(a, s): return tree.map_structure(map_, action, action_space_struct) +@DeveloperAPI def normalize_action(action, action_space_struct): """Normalizes all (Box) components in `action` to be in [-1.0, 1.0]. diff --git a/rllib/utils/tests/test_serialization.py b/rllib/utils/tests/test_serialization.py new file mode 100644 index 000000000000..360624ec420a --- /dev/null +++ b/rllib/utils/tests/test_serialization.py @@ -0,0 +1,105 @@ +import gym +import numpy as np +import unittest + +from ray.rllib.utils.serialization import ( + gym_space_from_dict, + gym_space_to_dict, +) + + +def _assert_array_equal(eq, a1, a2, margin=None): + for a in zip(a1, a2): + eq(a[0], a[1], margin) + + +class TestGymCheckEnv(unittest.TestCase): + def test_box_space(self): + env = gym.make("CartPole-v0") + d = gym_space_to_dict(env.observation_space) + sp = gym_space_from_dict(d) + + obs_space = env.observation_space + _assert_array_equal( + self.assertAlmostEqual, sp.low.tolist(), obs_space.low.tolist(), 0.001 + ) + _assert_array_equal( + self.assertAlmostEqual, sp.high.tolist(), obs_space.high.tolist(), 0.001 + ) + _assert_array_equal(self.assertEqual, sp._shape, obs_space._shape) + self.assertEqual(sp.dtype, obs_space.dtype) + + def test_discrete_space(self): + env = gym.make("CartPole-v0") + d = gym_space_to_dict(env.action_space) + sp = gym_space_from_dict(d) + + action_space = env.action_space + self.assertEqual(sp.n, action_space.n) + + def test_multi_discrete_space(self): + md_space = gym.spaces.MultiDiscrete(nvec=np.array([3, 4, 5])) + d = gym_space_to_dict(md_space) + sp = gym_space_from_dict(d) + + _assert_array_equal(self.assertAlmostEqual, sp.nvec, md_space.nvec, 0.001) + self.assertEqual(md_space.dtype, sp.dtype) + + def test_tuple_space(self): + env = gym.make("CartPole-v0") + space = gym.spaces.Tuple(spaces=[env.observation_space, env.action_space]) + d = gym_space_to_dict(space) + sp = gym_space_from_dict(d) + + _assert_array_equal( + self.assertAlmostEqual, + sp.spaces[0].low.tolist(), + space.spaces[0].low.tolist(), + 0.001, + ) + _assert_array_equal( + self.assertAlmostEqual, + sp.spaces[0].high.tolist(), + space.spaces[0].high.tolist(), + 0.001, + ) + _assert_array_equal( + self.assertEqual, sp.spaces[0]._shape, space.spaces[0]._shape + ) + self.assertEqual(sp.dtype, space.dtype) + + self.assertEqual(sp.spaces[1].n, space.spaces[1].n) + + def test_dict_space(self): + env = gym.make("CartPole-v0") + space = gym.spaces.Dict( + spaces={"obs": env.observation_space, "action": env.action_space} + ) + d = gym_space_to_dict(space) + sp = gym_space_from_dict(d) + + _assert_array_equal( + self.assertAlmostEqual, + sp.spaces["obs"].low.tolist(), + space.spaces["obs"].low.tolist(), + 0.001, + ) + _assert_array_equal( + self.assertAlmostEqual, + sp.spaces["obs"].high.tolist(), + space.spaces["obs"].high.tolist(), + 0.001, + ) + _assert_array_equal( + self.assertEqual, sp.spaces["obs"]._shape, space.spaces["obs"]._shape + ) + self.assertEqual(sp.dtype, space.dtype) + + self.assertEqual(sp.spaces["action"].n, space.spaces["action"].n) + + +if __name__ == "__main__": + import pytest + import sys + + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/utils/tf_run_builder.py b/rllib/utils/tf_run_builder.py index 6f6abc82b0d0..8fc1ec907320 100644 --- a/rllib/utils/tf_run_builder.py +++ b/rllib/utils/tf_run_builder.py @@ -9,7 +9,7 @@ logger = logging.getLogger(__name__) -class TFRunBuilder: +class _TFRunBuilder: """Used to incrementally build up a TensorFlow run. This is particularly useful for batching ops from multiple different @@ -39,7 +39,7 @@ def add_fetches(self, fetches): def get(self, to_fetch): if self._executed is None: try: - self._executed = run_timeline( + self._executed = _run_timeline( self.session, self.fetches, self.debug_name, @@ -66,7 +66,7 @@ def get(self, to_fetch): _count = 0 -def run_timeline(sess, ops, debug_name, feed_dict=None, timeline_dir=None): +def _run_timeline(sess, ops, debug_name, feed_dict=None, timeline_dir=None): if feed_dict is None: feed_dict = {} diff --git a/rllib/utils/tf_utils.py b/rllib/utils/tf_utils.py index 0950cd2a0ac7..1cbb659c16e4 100644 --- a/rllib/utils/tf_utils.py +++ b/rllib/utils/tf_utils.py @@ -4,6 +4,7 @@ import tree # pip install dm_tree from typing import Any, Callable, List, Optional, Type, TYPE_CHECKING, Union +from ray.rllib.utils.annotations import PublicAPI from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space from ray.rllib.utils.typing import ( @@ -21,6 +22,7 @@ tf1, tf, tfv = try_import_tf() +@PublicAPI def explained_variance(y: TensorType, pred: TensorType) -> TensorType: """Computes the explained variance for a pair of labels and predictions. @@ -39,6 +41,7 @@ def explained_variance(y: TensorType, pred: TensorType) -> TensorType: return tf.maximum(-1.0, 1 - (diff_var / y_var)) +@PublicAPI def flatten_inputs_to_1d_tensor( inputs: TensorStructType, spaces_struct: Optional[SpaceStruct] = None, @@ -138,6 +141,7 @@ def flatten_inputs_to_1d_tensor( return merged +@PublicAPI def get_gpu_devices() -> List[str]: """Returns a list of GPU device names, e.g. ["/gpu:0", "/gpu:1"]. @@ -160,6 +164,7 @@ def get_gpu_devices() -> List[str]: return [d.name for d in devices if "GPU" in d.device_type] +@PublicAPI def get_placeholder( *, space: Optional[gym.Space] = None, @@ -218,6 +223,7 @@ def get_placeholder( ) +@PublicAPI def get_tf_eager_cls_if_necessary( orig_cls: Type["TFPolicy"], config: PartialTrainerConfigDict ) -> Type["TFPolicy"]: @@ -242,6 +248,7 @@ class for. from ray.rllib.policy.tf_policy import TFPolicy from ray.rllib.policy.eager_tf_policy import EagerTFPolicy + from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2 # Create eager-class (if not already one). if hasattr(orig_cls, "as_eager") and not issubclass(orig_cls, EagerTFPolicy): @@ -256,11 +263,14 @@ class for. ) # Now that we know, policy is an eager one, add tracing, if necessary. - if config.get("eager_tracing") and issubclass(cls, EagerTFPolicy): + if config.get("eager_tracing") and issubclass( + cls, (EagerTFPolicy, EagerTFPolicyV2) + ): cls = cls.with_tracing() return cls +@PublicAPI def huber_loss(x: TensorType, delta: float = 1.0) -> TensorType: """Computes the huber loss for a given term and delta parameter. @@ -285,6 +295,7 @@ def huber_loss(x: TensorType, delta: float = 1.0) -> TensorType: ) +@PublicAPI def make_tf_callable( session_or_none: Optional["tf1.Session"], dynamic_shape: bool = False ) -> Callable: @@ -379,6 +390,7 @@ def _create_placeholders(path, value): return make_wrapper +@PublicAPI def minimize_and_clip( optimizer: LocalOptimizer, objective: TensorType, @@ -419,6 +431,7 @@ def minimize_and_clip( ] +@PublicAPI def one_hot(x: TensorType, space: gym.Space) -> TensorType: """Returns a one-hot tensor, given and int tensor and a space. @@ -464,6 +477,7 @@ def one_hot(x: TensorType, space: gym.Space) -> TensorType: raise ValueError("Unsupported space for `one_hot`: {}".format(space)) +@PublicAPI def reduce_mean_ignore_inf(x: TensorType, axis: Optional[int] = None) -> TensorType: """Same as tf.reduce_mean() but ignores -inf values. @@ -481,6 +495,7 @@ def reduce_mean_ignore_inf(x: TensorType, axis: Optional[int] = None) -> TensorT ) +@PublicAPI def scope_vars( scope: Union[str, "tf1.VariableScope"], trainable_only: bool = False ) -> List["tf.Variable"]: @@ -502,6 +517,7 @@ def scope_vars( ) +@PublicAPI def zero_logps_from_actions(actions: TensorStructType) -> TensorType: """Helper function useful for returning dummy logp's (0) for some actions. diff --git a/rllib/utils/threading.py b/rllib/utils/threading.py index 866c277fd7d3..409bf37fa9cd 100644 --- a/rllib/utils/threading.py +++ b/rllib/utils/threading.py @@ -1,6 +1,9 @@ from typing import Callable +from ray.rllib.utils.annotations import DeveloperAPI + +@DeveloperAPI def with_lock(func: Callable) -> Callable: """Use as decorator (@withlock) around object methods that need locking. diff --git a/rllib/utils/torch_utils.py b/rllib/utils/torch_utils.py index be2f267d4ea8..9c40f28cc128 100644 --- a/rllib/utils/torch_utils.py +++ b/rllib/utils/torch_utils.py @@ -7,7 +7,7 @@ import warnings from ray.rllib.models.repeated_values import RepeatedValues -from ray.rllib.utils.deprecation import Deprecated +from ray.rllib.utils.annotations import Deprecated, PublicAPI from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.numpy import SMALL_NUMBER from ray.rllib.utils.typing import ( @@ -28,6 +28,7 @@ FLOAT_MAX = 3.4e38 +@PublicAPI def apply_grad_clipping( policy: "TorchPolicy", optimizer: LocalOptimizer, loss: TensorType ) -> Dict[str, TensorType]: @@ -68,6 +69,7 @@ def atanh(x: TensorType) -> TensorType: ) +@PublicAPI def concat_multi_gpu_td_errors(policy: "TorchPolicy") -> Dict[str, TensorType]: """Concatenates multi-GPU (per-tower) TD error tensors given TorchPolicy. @@ -122,6 +124,7 @@ def mapping(item): return tree.map_structure(mapping, stats) +@PublicAPI def convert_to_torch_tensor(x: TensorStructType, device: Optional[str] = None): """Converts any struct to torch.Tensors. @@ -167,6 +170,7 @@ def mapping(item): return tree.map_structure(mapping, x) +@PublicAPI def explained_variance(y: TensorType, pred: TensorType) -> TensorType: """Computes the explained variance for a pair of labels and predictions. @@ -186,6 +190,7 @@ def explained_variance(y: TensorType, pred: TensorType) -> TensorType: return torch.max(min_, 1 - (diff_var / y_var))[0] +@PublicAPI def flatten_inputs_to_1d_tensor( inputs: TensorStructType, spaces_struct: Optional[SpaceStruct] = None, @@ -284,6 +289,7 @@ def flatten_inputs_to_1d_tensor( return merged +@PublicAPI def global_norm(tensors: List[TensorType]) -> TensorType: """Returns the global L2 norm over a list of tensors. @@ -302,6 +308,7 @@ def global_norm(tensors: List[TensorType]) -> TensorType: return torch.pow(sum(torch.pow(l2, 2.0) for l2 in single_l2s), 0.5) +@PublicAPI def huber_loss(x: TensorType, delta: float = 1.0) -> TensorType: """Computes the huber loss for a given term and delta parameter. @@ -326,6 +333,7 @@ def huber_loss(x: TensorType, delta: float = 1.0) -> TensorType: ) +@PublicAPI def l2_loss(x: TensorType) -> TensorType: """Computes half the L2 norm over a tensor's values without the sqrt. @@ -340,6 +348,7 @@ def l2_loss(x: TensorType) -> TensorType: return 0.5 * torch.sum(torch.pow(x, 2.0)) +@PublicAPI def minimize_and_clip( optimizer: "torch.optim.Optimizer", clip_val: float = 10.0 ) -> None: @@ -360,6 +369,7 @@ def minimize_and_clip( torch.nn.utils.clip_grad_norm_(p.grad, clip_val) +@PublicAPI def one_hot(x: TensorType, space: gym.Space) -> TensorType: """Returns a one-hot tensor, given and int tensor and a space. @@ -408,6 +418,7 @@ def one_hot(x: TensorType, space: gym.Space) -> TensorType: raise ValueError("Unsupported space for `one_hot`: {}".format(space)) +@PublicAPI def reduce_mean_ignore_inf(x: TensorType, axis: Optional[int] = None) -> TensorType: """Same as torch.mean() but ignores -inf values. @@ -423,6 +434,7 @@ def reduce_mean_ignore_inf(x: TensorType, axis: Optional[int] = None) -> TensorT return torch.sum(x_zeroed, axis) / torch.sum(mask.float(), axis) +@PublicAPI def sequence_mask( lengths: TensorType, maxlen: Optional[int] = None, @@ -464,6 +476,7 @@ def sequence_mask( return mask +@PublicAPI def set_torch_seed(seed: Optional[int] = None) -> None: """Sets the torch random seed to the given value. @@ -483,6 +496,7 @@ def set_torch_seed(seed: Optional[int] = None) -> None: torch.backends.cudnn.deterministic = True +@PublicAPI def softmax_cross_entropy_with_logits( logits: TensorType, labels: TensorType, @@ -499,6 +513,7 @@ def softmax_cross_entropy_with_logits( return torch.sum(-labels * nn.functional.log_softmax(logits, -1), -1) +@PublicAPI class Swish(nn.Module): def __init__(self): super().__init__() diff --git a/src/mock/ray/core_worker/lease_policy.h b/src/mock/ray/core_worker/lease_policy.h index 6a0282797afe..9bdb30a219fe 100644 --- a/src/mock/ray/core_worker/lease_policy.h +++ b/src/mock/ray/core_worker/lease_policy.h @@ -30,7 +30,7 @@ class MockLocalityDataProviderInterface : public LocalityDataProviderInterface { MOCK_METHOD(absl::optional, GetLocalityData, (const ObjectID &object_id), - (override)); + (const override)); }; } // namespace core diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index d2f51a33c220..fda8c7c736b7 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -2427,7 +2427,7 @@ bool CoreWorker::PinExistingReturnObject(const ObjectID &return_id, {return_id}, [return_id, pinned_return_object](const Status &status, const rpc::PinObjectIDsReply &reply) { - if (!status.ok()) { + if (!status.ok() || !reply.successes(0)) { RAY_LOG(INFO) << "Failed to pin existing copy of the task return object " << return_id << ". This object may get evicted while there are still " diff --git a/src/ray/core_worker/lease_policy.h b/src/ray/core_worker/lease_policy.h index 6634f78a8062..fe8cf272fd4b 100644 --- a/src/ray/core_worker/lease_policy.h +++ b/src/ray/core_worker/lease_policy.h @@ -32,7 +32,8 @@ struct LocalityData { /// Interface for providers of locality data to the lease policy. class LocalityDataProviderInterface { public: - virtual absl::optional GetLocalityData(const ObjectID &object_id) = 0; + virtual absl::optional GetLocalityData( + const ObjectID &object_id) const = 0; virtual ~LocalityDataProviderInterface() {} }; diff --git a/src/ray/core_worker/object_recovery_manager.cc b/src/ray/core_worker/object_recovery_manager.cc index 3c7f8470a7ba..448a3bdce4f2 100644 --- a/src/ray/core_worker/object_recovery_manager.cc +++ b/src/ray/core_worker/object_recovery_manager.cc @@ -122,7 +122,7 @@ void ObjectRecoveryManager::PinExistingObjectCopy( {object_id}, [this, object_id, other_locations, node_id]( const Status &status, const rpc::PinObjectIDsReply &reply) { - if (status.ok()) { + if (status.ok() && reply.successes(0)) { // TODO(swang): Make sure that the node is still alive when // marking the object as pinned. RAY_CHECK(in_memory_store_->Put( diff --git a/src/ray/core_worker/reference_count.cc b/src/ray/core_worker/reference_count.cc index 1ee295b640b1..f347b181532c 100644 --- a/src/ray/core_worker/reference_count.cc +++ b/src/ray/core_worker/reference_count.cc @@ -695,8 +695,6 @@ void ReferenceCounter::UpdateObjectPinnedAtRaylet(const ObjectID &object_id, if (!it->second.OutOfScope(lineage_pinning_enabled_)) { if (check_node_alive_(raylet_id)) { it->second.pinned_at_raylet_id = raylet_id; - // We eagerly add the pinned location to the set of object locations. - AddObjectLocationInternal(it, raylet_id); } else { ReleasePlasmaObject(it); objects_to_recover_.push_back(object_id); @@ -1254,7 +1252,7 @@ bool ReferenceCounter::HandleObjectSpilled(const ObjectID &object_id, } absl::optional ReferenceCounter::GetLocalityData( - const ObjectID &object_id) { + const ObjectID &object_id) const { absl::MutexLock lock(&mutex_); // Uses the reference table to return locality data for an object. auto it = object_id_refs_.find(object_id); @@ -1281,11 +1279,16 @@ absl::optional ReferenceCounter::GetLocalityData( // locations. // - If we don't own this object, this will contain a snapshot of the object locations // at future resolution time. - const auto &node_ids = it->second.locations; + auto node_ids = it->second.locations; + // Add location of the primary copy since the object must be there: either in memory or + // spilled. + if (it->second.pinned_at_raylet_id.has_value()) { + node_ids.emplace(it->second.pinned_at_raylet_id.value()); + } // We should only reach here if we have valid locality data to return. absl::optional locality_data( - {static_cast(object_size), node_ids}); + {static_cast(object_size), std::move(node_ids)}); return locality_data; } diff --git a/src/ray/core_worker/reference_count.h b/src/ray/core_worker/reference_count.h index 25f539069e2a..87bd3692e566 100644 --- a/src/ray/core_worker/reference_count.h +++ b/src/ray/core_worker/reference_count.h @@ -448,7 +448,7 @@ class ReferenceCounter : public ReferenceCounterInterface, /// /// \param[in] object_id Object whose locality data we want. /// \return Locality data. - absl::optional GetLocalityData(const ObjectID &object_id); + absl::optional GetLocalityData(const ObjectID &object_id) const; /// Report locality data for object. This is used by the FutureResolver to report /// locality data for borrowed refs. diff --git a/src/ray/core_worker/task_manager.cc b/src/ray/core_worker/task_manager.cc index 992bf6a2a6d2..17edd8e63d0e 100644 --- a/src/ray/core_worker/task_manager.cc +++ b/src/ray/core_worker/task_manager.cc @@ -252,11 +252,14 @@ void TaskManager::CompletePendingTask(const TaskID &task_id, const auto nested_refs = VectorFromProtobuf(return_object.nested_inlined_refs()); if (return_object.in_plasma()) { + // NOTE(swang): We need to add the location of the object before marking + // it as local in the in-memory store so that the data locality policy + // will choose the right raylet for any queued dependent tasks. + const auto pinned_at_raylet_id = NodeID::FromBinary(worker_addr.raylet_id()); + reference_counter_->UpdateObjectPinnedAtRaylet(object_id, pinned_at_raylet_id); // Mark it as in plasma with a dummy object. RAY_CHECK( in_memory_store_->Put(RayObject(rpc::ErrorType::OBJECT_IN_PLASMA), object_id)); - const auto pinned_at_raylet_id = NodeID::FromBinary(worker_addr.raylet_id()); - reference_counter_->UpdateObjectPinnedAtRaylet(object_id, pinned_at_raylet_id); } else { // NOTE(swang): If a direct object was promoted to plasma, then we do not // record the node ID that it was pinned at, which means that we will not diff --git a/src/ray/core_worker/test/lease_policy_test.cc b/src/ray/core_worker/test/lease_policy_test.cc index 04b3a557a222..036b4316bfb7 100644 --- a/src/ray/core_worker/test/lease_policy_test.cc +++ b/src/ray/core_worker/test/lease_policy_test.cc @@ -40,15 +40,15 @@ class MockLocalityDataProvider : public LocalityDataProviderInterface { MockLocalityDataProvider(absl::flat_hash_map locality_data) : locality_data_(locality_data) {} - absl::optional GetLocalityData(const ObjectID &object_id) { + absl::optional GetLocalityData(const ObjectID &object_id) const { num_locality_data_fetches++; return locality_data_[object_id]; }; ~MockLocalityDataProvider() {} - int num_locality_data_fetches = 0; - absl::flat_hash_map locality_data_; + mutable int num_locality_data_fetches = 0; + mutable absl::flat_hash_map locality_data_; }; absl::optional MockNodeAddrFactory(const NodeID &node_id) { diff --git a/src/ray/core_worker/test/object_recovery_manager_test.cc b/src/ray/core_worker/test/object_recovery_manager_test.cc index cbfc92b5f785..a9884a547940 100644 --- a/src/ray/core_worker/test/object_recovery_manager_test.cc +++ b/src/ray/core_worker/test/object_recovery_manager_test.cc @@ -66,12 +66,15 @@ class MockRayletClient : public PinObjectsInterface { callbacks.push_back(callback); } - size_t Flush() { - size_t flushed = callbacks.size(); - for (const auto &callback : callbacks) { - callback(Status::OK(), rpc::PinObjectIDsReply()); + size_t Flush(bool success = true) { + std::list> callbacks_snapshot; + std::swap(callbacks_snapshot, callbacks); + size_t flushed = callbacks_snapshot.size(); + for (const auto &callback : callbacks_snapshot) { + rpc::PinObjectIDsReply reply; + reply.add_successes(success); + callback(Status::OK(), reply); } - callbacks.clear(); return flushed; } @@ -230,12 +233,15 @@ TEST_F(ObjectRecoveryManagerTest, TestPinNewCopy) { 0, true, /*add_local_ref=*/true); - std::vector addresses({rpc::Address()}); + std::vector addresses({rpc::Address(), rpc::Address()}); object_directory_->SetLocations(object_id, addresses); ASSERT_TRUE(manager_.RecoverObject(object_id)); ASSERT_TRUE(object_directory_->Flush() == 1); - ASSERT_TRUE(raylet_client_->Flush() == 1); + // First copy is evicted so pin fails. + ASSERT_TRUE(raylet_client_->Flush(false) == 1); + // Second copy is present so pin succeeds. + ASSERT_TRUE(raylet_client_->Flush(true) == 1); ASSERT_TRUE(failed_reconstructions_.empty()); ASSERT_EQ(task_resubmitter_->num_tasks_resubmitted, 0); } diff --git a/src/ray/core_worker/test/reference_count_test.cc b/src/ray/core_worker/test/reference_count_test.cc index c0e86ccea3a3..2b571d75f4d2 100644 --- a/src/ray/core_worker/test/reference_count_test.cc +++ b/src/ray/core_worker/test/reference_count_test.cc @@ -648,6 +648,7 @@ TEST_F(ReferenceCountTest, TestHandleObjectSpilled) { TEST_F(ReferenceCountTest, TestGetLocalityData) { ObjectID obj1 = ObjectID::FromRandom(); ObjectID obj2 = ObjectID::FromRandom(); + ObjectID obj3 = ObjectID::FromRandom(); NodeID node1 = NodeID::FromRandom(); NodeID node2 = NodeID::FromRandom(); rpc::Address address; @@ -696,6 +697,13 @@ TEST_F(ReferenceCountTest, TestGetLocalityData) { ASSERT_EQ(locality_data_obj1->nodes_containing_object, absl::flat_hash_set({node1})); + // Include spilled locations in locality data. + rc->RemoveObjectLocation(obj1, node1); + rc->HandleObjectSpilled(obj1, "spill_loc", node1); + locality_data_obj1 = rc->GetLocalityData(obj1); + ASSERT_EQ(locality_data_obj1->nodes_containing_object, + absl::flat_hash_set({node1})); + // Borrowed object with defined object size and at least one node location should // return valid locality data. rc->AddLocalReference(obj2, "file.py:43"); @@ -735,8 +743,25 @@ TEST_F(ReferenceCountTest, TestGetLocalityData) { auto locality_data_obj2_no_object_size = rc->GetLocalityData(obj2); ASSERT_FALSE(locality_data_obj2_no_object_size.has_value()); + // Primary copy location is always returned + // even if it's not in-memory (i.e. spilled). + rc->AddOwnedObject(obj3, + {}, + address, + "file2.py:43", + -1, + false, + /*add_local_ref=*/true); + rc->UpdateObjectSize(obj3, 101); + rc->UpdateObjectPinnedAtRaylet(obj3, node1); + auto locality_data_obj3 = rc->GetLocalityData(obj3); + ASSERT_TRUE(locality_data_obj3.has_value()); + ASSERT_EQ(locality_data_obj3->nodes_containing_object, + absl::flat_hash_set({node1})); + rc->RemoveLocalReference(obj1, nullptr); rc->RemoveLocalReference(obj2, nullptr); + rc->RemoveLocalReference(obj3, nullptr); } // Tests that we can get the owner address correctly for objects that we own, @@ -2625,6 +2650,7 @@ TEST_F(ReferenceCountLineageEnabledTest, TestPlasmaLocation) { ASSERT_TRUE(rc->IsPlasmaObjectPinnedOrSpilled(id, &owned_by_us, &pinned_at, &spilled)); ASSERT_TRUE(owned_by_us); ASSERT_FALSE(pinned_at.IsNil()); + ASSERT_TRUE(rc->GetObjectLocations(id)->empty()); rc->RemoveLocalReference(id, nullptr); ASSERT_FALSE(rc->IsPlasmaObjectPinnedOrSpilled(id, &owned_by_us, &pinned_at, &spilled)); diff --git a/src/ray/core_worker/test/task_manager_test.cc b/src/ray/core_worker/test/task_manager_test.cc index 8549eb71cd46..ef005771b6dc 100644 --- a/src/ray/core_worker/test/task_manager_test.cc +++ b/src/ray/core_worker/test/task_manager_test.cc @@ -346,6 +346,31 @@ TEST_F(TaskManagerTest, TestLineageEvicted) { ASSERT_FALSE(reference_counter_->HasReference(return_id)); } +TEST_F(TaskManagerTest, TestLocalityDataAdded) { + auto spec = CreateTaskHelper(1, {}); + auto return_id = spec.ReturnId(0); + auto node_id = NodeID::FromRandom(); + int object_size = 100; + store_->GetAsync(return_id, [&](std::shared_ptr obj) { + // By the time the return object is available to get, we should be able + // to get the locality data too. + auto locality_data = reference_counter_->GetLocalityData(return_id); + ASSERT_TRUE(locality_data.has_value()); + ASSERT_EQ(locality_data->object_size, object_size); + ASSERT_TRUE(locality_data->nodes_containing_object.contains(node_id)); + }); + + rpc::PushTaskReply reply; + auto return_object = reply.add_return_objects(); + return_object->set_object_id(return_id.Binary()); + return_object->set_in_plasma(true); + return_object->set_size(object_size); + rpc::Address worker_addr; + worker_addr.set_raylet_id(node_id.Binary()); + manager_.AddPendingTask(rpc::Address(), spec, "", 0); + manager_.CompletePendingTask(spec.TaskId(), reply, worker_addr); +} + // Test to make sure that the task spec and dependencies for an object are // pinned when lineage pinning is enabled in the ReferenceCounter. TEST_F(TaskManagerLineageTest, TestLineagePinned) { diff --git a/src/ray/gcs/gcs_server/gcs_actor_manager.cc b/src/ray/gcs/gcs_server/gcs_actor_manager.cc index 1955796f5b37..af02808b172e 100644 --- a/src/ray/gcs/gcs_server/gcs_actor_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_actor_manager.cc @@ -1176,7 +1176,12 @@ void GcsActorManager::Initialize(const GcsInitData &gcs_init_data) { for (const auto &[actor_id, actor_table_data] : gcs_init_data.Actors()) { auto job_iter = jobs.find(actor_id.JobId()); auto is_job_dead = (job_iter == jobs.end() || job_iter->second.is_dead()); - if (actor_table_data.state() != ray::rpc::ActorTableData::DEAD && !is_job_dead) { + // We only load actors which are supposed to be alive: + // - Actors which are not dead. + // - Non-deatched actors whoes owner is alive. + // - Detached actors which lives even when their original owner is dead. + if (actor_table_data.state() != ray::rpc::ActorTableData::DEAD && + (!is_job_dead || actor_table_data.is_detached())) { const auto &iter = actor_task_specs.find(actor_id); RAY_CHECK(iter != actor_task_specs.end()); auto actor = std::make_shared(actor_table_data, iter->second); diff --git a/src/ray/protobuf/node_manager.proto b/src/ray/protobuf/node_manager.proto index ec1b308cdab2..7ed23b18a85c 100644 --- a/src/ray/protobuf/node_manager.proto +++ b/src/ray/protobuf/node_manager.proto @@ -169,7 +169,11 @@ message PinObjectIDsRequest { repeated bytes object_ids = 2; } -message PinObjectIDsReply {} +message PinObjectIDsReply { + // Whether pinning the corresponding object succeeded or not. + // Pin can fail if the object is already evicted. + repeated bool successes = 1; +} message GetNodeStatsRequest { // Whether to include memory stats. This could be large since it includes diff --git a/src/ray/protobuf/serve.proto b/src/ray/protobuf/serve.proto index 843effae6ea1..f49a869bb6fa 100644 --- a/src/ray/protobuf/serve.proto +++ b/src/ray/protobuf/serve.proto @@ -198,3 +198,20 @@ message DeploymentStatusInfo { message DeploymentStatusInfoList { repeated DeploymentStatusInfo deployment_status_infos = 1; } + +enum ApplicationStatus { + DEPLOYING = 0; + RUNNING = 1; + DEPLOY_FAILED = 2; +} + +message ApplicationStatusInfo { + ApplicationStatus status = 1; + string message = 2; + double deployment_timestamp = 3; +} + +message StatusOverview { + ApplicationStatusInfo app_status = 1; + DeploymentStatusInfoList deployment_statuses = 2; +} diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 4af64f1ebe0b..00e559bb509e 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -2373,16 +2373,32 @@ void NodeManager::HandlePinObjectIDs(const rpc::PinObjectIDsRequest &request, } std::vector> results; if (!GetObjectsFromPlasma(object_ids, &results)) { - RAY_LOG(WARNING) - << "Failed to get objects that should have been in the object store. These " - "objects may have been evicted while there are still references in scope."; - // TODO(suquark): Maybe "Status::ObjectNotFound" is more accurate here. - send_reply_callback(Status::Invalid("Failed to get objects."), nullptr, nullptr); - return; + for (size_t i = 0; i < object_ids.size(); ++i) { + reply->add_successes(false); + } + } else { + RAY_CHECK_EQ(object_ids.size(), results.size()); + auto object_id_it = object_ids.begin(); + auto result_it = results.begin(); + while (object_id_it != object_ids.end()) { + if (*result_it == nullptr) { + RAY_LOG(DEBUG) << "Failed to get object in the object store: " << *object_id_it + << ". This should only happen when the owner tries to pin a " + << "secondary copy and it's evicted in the meantime"; + object_id_it = object_ids.erase(object_id_it); + result_it = results.erase(result_it); + reply->add_successes(false); + } else { + ++object_id_it; + ++result_it; + reply->add_successes(true); + } + } + // Wait for the object to be freed by the owner, which keeps the ref count. + local_object_manager_.PinObjectsAndWaitForFree( + object_ids, std::move(results), owner_address); } - // Wait for the object to be freed by the owner, which keeps the ref count. - local_object_manager_.PinObjectsAndWaitForFree( - object_ids, std::move(results), owner_address); + RAY_CHECK_EQ(request.object_ids_size(), reply->successes_size()); send_reply_callback(Status::OK(), nullptr, nullptr); }