diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b701079..7f8127f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,19 +7,16 @@ repos: - id: trailing-whitespace - id: end-of-file-fixer -- repo: https://github.com/pycqa/pylint - rev: v3.3.1 +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.4.7 hooks: - - id: pylint - args: - - "--rcfile=.pylintrc" + - id: ruff + args: [ + "--config=ruff.toml", + "--fix", + ] exclude: tests(/\w*)*/ -- repo: https://github.com/google/yapf - rev: v0.40.2 - hooks: - - id: yapf - - repo: https://github.com/pre-commit/mirrors-mypy rev: v1.13.0 hooks: diff --git a/.pylintrc b/.pylintrc deleted file mode 100644 index f8bf8d2..0000000 --- a/.pylintrc +++ /dev/null @@ -1,430 +0,0 @@ -# This Pylint rcfile contains a best-effort configuration to uphold the -# best-practices and style described in the Google Python style guide: -# https://google.github.io/styleguide/pyguide.html -# -# Its canonical open-source location is: -# https://google.github.io/styleguide/pylintrc - -[MASTER] - -# Files or directories to be skipped. They should be base names, not paths. -ignore=third_party - -# Files or directories matching the regex patterns are skipped. The regex -# matches against base names, not paths. -ignore-patterns= - -# Pickle collected data for later comparisons. -persistent=no - -# List of plugins (as comma separated values of python modules names) to load, -# usually to register additional checkers. -load-plugins= - -# Use multiple processes to speed up Pylint. -jobs=4 - -# Allow loading of arbitrary C extensions. Extensions are imported into the -# active Python interpreter and may run arbitrary code. -unsafe-load-any-extension=no - - -[MESSAGES CONTROL] - -# Only show warnings with the listed confidence levels. Leave empty to show -# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED -confidence= - -# Enable the message, report, category or checker with the given id(s). You can -# either give multiple identifier separated by comma (,) or put this option -# multiple time (only on the command line, not in the configuration file where -# it should appear only once). See also the "--disable" option for examples. -#enable= - -# Disable the message, report, category or checker with the given id(s). You -# can either give multiple identifiers separated by comma (,) or put this -# option multiple times (only on the command line, not in the configuration -# file where it should appear only once).You can also use "--disable=all" to -# disable everything first and then reenable specific checks. For example, if -# you want to run only the similarities checker, you can use "--disable=all -# --enable=similarities". 
If you want to run only the classes checker, but have -# no Warning level messages displayed, use"--disable=all --enable=classes -# --disable=W" -disable=abstract-method, - apply-builtin, - arguments-differ, - attribute-defined-outside-init, - backtick, - bad-option-value, - basestring-builtin, - buffer-builtin, - c-extension-no-member, - consider-using-enumerate, - cmp-builtin, - cmp-method, - coerce-builtin, - coerce-method, - delslice-method, - div-method, - duplicate-code, - eq-without-hash, - execfile-builtin, - file-builtin, - filter-builtin-not-iterating, - fixme, - getslice-method, - global-statement, - hex-method, - idiv-method, - implicit-str-concat, - import-error, - import-self, - import-star-module-level, - inconsistent-return-statements, - input-builtin, - intern-builtin, - invalid-str-codec, - locally-disabled, - long-builtin, - long-suffix, - map-builtin-not-iterating, - misplaced-comparison-constant, - missing-function-docstring, - metaclass-assignment, - next-method-called, - next-method-defined, - no-absolute-import, - no-else-break, - no-else-continue, - no-else-raise, - no-else-return, - no-init, # added - no-member, - no-name-in-module, - no-self-use, - nonzero-method, - oct-method, - old-division, - old-ne-operator, - old-octal-literal, - old-raise-syntax, - parameter-unpacking, - print-statement, - raising-string, - range-builtin-not-iterating, - raw_input-builtin, - rdiv-method, - reduce-builtin, - relative-import, - reload-builtin, - round-builtin, - setslice-method, - signature-differs, - standarderror-builtin, - suppressed-message, - sys-max-int, - too-few-public-methods, - too-many-ancestors, - too-many-arguments, - too-many-boolean-expressions, - too-many-branches, - too-many-instance-attributes, - too-many-locals, - too-many-nested-blocks, - too-many-public-methods, - too-many-return-statements, - too-many-statements, - trailing-newlines, - unichr-builtin, - unicode-builtin, - unnecessary-pass, - unpacking-in-except, - useless-else-on-loop, - useless-object-inheritance, - useless-suppression, - using-cmp-argument, - wrong-import-order, - xrange-builtin, - zip-builtin-not-iterating, - too-many-positional-arguments - - -[REPORTS] - -# Set the output format. Available formats are text, parseable, colorized, msvs -# (visual studio) and html. You can also give a reporter class, eg -# mypackage.mymodule.MyReporterClass. -output-format=text - -# Tells whether to display a full report or only the messages -reports=no - -# Python expression which should return a note less than 10 (10 is the highest -# note). You have access to the variables errors warning, statement which -# respectively contain the number of errors / warnings messages and the total -# number of statements analyzed. This is used by the global evaluation report -# (RP0004). -evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) - -# Template used to display messages. This is a python new-style format string -# used to format the message information. See doc for all details -#msg-template= - - -[BASIC] - -# Good variable names which should always be accepted, separated by a comma -good-names=main,_ - -# Bad variable names which should always be refused, separated by a comma -bad-names= - -# Colon-delimited sets of names that determine each other's naming style when -# the name regexes allow several styles. 
-name-group= - -# Include a hint for the correct naming format with invalid-name -include-naming-hint=no - -# List of decorators that produce properties, such as abc.abstractproperty. Add -# to this list to register other decorators that produce valid properties. -property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl - -# Regular expression matching correct function names -function-rgx=^(?:(?PsetUp|tearDown|setUpModule|tearDownModule)|(?P_?[A-Z][a-zA-Z0-9]*)|(?P_?[a-z][a-z0-9_]*))$ - -# Regular expression matching correct variable names -variable-rgx=^[a-z][a-z0-9_]*$ - -# Regular expression matching correct constant names -const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ - -# Regular expression matching correct attribute names -attr-rgx=^_{0,2}[a-z][a-z0-9_]*$ - -# Regular expression matching correct argument names -argument-rgx=^[a-z][a-z0-9_]*$ - -# Regular expression matching correct class attribute names -class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ - -# Regular expression matching correct inline iteration names -inlinevar-rgx=^[a-z][a-z0-9_]*$ - -# Regular expression matching correct class names -class-rgx=^_?[A-Z][a-zA-Z0-9]*$ - -# Regular expression matching correct module names -module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$ - -# Regular expression matching correct method names -method-rgx=(?x)^(?:(?P_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P_{0,2}[a-z][a-z0-9_]*))$ - -# Regular expression which should only match function or class names that do -# not require a docstring. -no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$ - -# Minimum line length for functions/classes that require docstrings, shorter -# ones are exempt. -docstring-min-length=10 - - -[TYPECHECK] - -# List of decorators that produce context managers, such as -# contextlib.contextmanager. Add to this list to register other decorators that -# produce valid context managers. -contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager - -# Tells whether missing members accessed in mixin class should be ignored. A -# mixin class is detected if its name ends with "mixin" (case insensitive). -ignore-mixin-members=yes - -# List of module names for which member attributes should not be checked -# (useful for modules/projects where namespaces are manipulated during runtime -# and thus existing member attributes cannot be deduced by static analysis. It -# supports qualified module names, as well as Unix pattern matching. -ignored-modules= - -# List of class names for which member attributes should not be checked (useful -# for classes with dynamically set attributes). This supports the use of -# qualified names. -ignored-classes=optparse.Values,thread._local,_thread._local - -# List of members which are set dynamically and missed by pylint inference -# system, and so shouldn't trigger E1101 when accessed. Python regular -# expressions are accepted. -generated-members= - - -[FORMAT] - -# Maximum number of characters on a single line. -max-line-length=80 - -# TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt -# lines made too long by directives to pytype. - -# Regexp for a line that is allowed to be longer than the limit. 
-ignore-long-lines=(?x)( - ^\s*(\#\ )??$| - ^\s*(from\s+\S+\s+)?import\s+.+$) - -# Allow the body of an if to be on the same line as the test if there is no -# else. -single-line-if-stmt=yes - -# Maximum number of lines in a module -max-module-lines=99999 - -# String used as indentation unit. The internal Google style guide mandates 2 -# spaces. Google's externaly-published style guide says 4, consistent with -# PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google -# projects (like TensorFlow). -indent-string=' ' - -# Number of spaces of indent required inside a hanging or continued line. -indent-after-paren=4 - -# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. -expected-line-ending-format= - - -[MISCELLANEOUS] - -# List of note tags to take in consideration, separated by a comma. -notes=TODO - - -[STRING] - -# This flag controls whether inconsistent-quotes generates a warning when the -# character used as a quote delimiter is used inconsistently within a module. -check-quote-consistency=yes - - -[VARIABLES] - -# Tells whether we should check for unused import in __init__ files. -init-import=no - -# A regular expression matching the name of dummy variables (i.e. expectedly -# not used). -dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_) - -# List of additional names supposed to be defined in builtins. Remember that -# you should avoid to define new builtins when possible. -additional-builtins= - -# List of strings which can identify a callback function by name. A callback -# name must start or end with one of those strings. -callbacks=cb_,_cb - -# List of qualified module names which can have objects that can redefine -# builtins. -redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools - - -[LOGGING] - -# Logging modules to check that the string format arguments are in logging -# function parameter format -logging-modules=logging,absl.logging,tensorflow.io.logging - - -[SIMILARITIES] - -# Minimum lines number of a similarity. -min-similarity-lines=4 - -# Ignore comments when computing similarities. -ignore-comments=yes - -# Ignore docstrings when computing similarities. -ignore-docstrings=yes - -# Ignore imports when computing similarities. -ignore-imports=no - - -[SPELLING] - -# Spelling dictionary name. Available dictionaries: none. To make it working -# install python-enchant package. -spelling-dict= - -# List of comma separated words that should not be checked. -spelling-ignore-words= - -# A path to a file that contains private dictionary; one word per line. -spelling-private-dict-file= - -# Tells whether to store unknown words to indicated private dictionary in -# --spelling-private-dict-file option instead of raising a message. -spelling-store-unknown-words=no - - -[IMPORTS] - -# Deprecated modules which should not be used, separated by a comma -deprecated-modules=regsub, - TERMIOS, - Bastion, - rexec, - sets - -# Create a graph of every (i.e. internal and external) dependencies in the -# given file (report RP0402 must not be disabled) -import-graph= - -# Create a graph of external dependencies in the given file (report RP0402 must -# not be disabled) -ext-import-graph= - -# Create a graph of internal dependencies in the given file (report RP0402 must -# not be disabled) -int-import-graph= - -# Force import order to recognize a module as part of the standard -# compatibility libraries. -known-standard-library= - -# Force import order to recognize a module as part of a third party library. 
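For reference, the re-grouped import blocks in the main.py hunks that follow use the standard isort layout (standard library, then third-party, then first-party, each group separated by a blank line and sorted alphabetically), which ruff can enforce through its isort-compatible `I` rules. The ruff.toml referenced by the new pre-commit hook is not part of this diff, so whether those rules are enabled there is an assumption. The resulting layout, taken from the AC/main.py hunk below:

```python
# Standard library
import logging
import os
from collections import deque
from itertools import repeat

# Third-party
import gymnasium as gym
import numpy as np
import torch

# First-party (repository modules)
from util import generate_gif
from util.buffer import Experience, Trajectory
from util.wrappers import TrainMonitor
```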
-known-third-party=enchant, absl - -# Analyse import fallback blocks. This can be used to support both Python 2 and -# 3 compatible code, which means that the block might have code that exists -# only in one or another interpreter, leading to false positives when analysed. -analyse-fallback-blocks=no - - -[CLASSES] - -# List of method names used to declare (i.e. assign) instance attributes. -defining-attr-methods=__init__, - __new__, - setUp - -# List of member names, which should be excluded from the protected access -# warning. -exclude-protected=_asdict, - _fields, - _replace, - _source, - _make - -# List of valid names for the first argument in a class method. -valid-classmethod-first-arg=cls, - class_ - -# List of valid names for the first argument in a metaclass class method. -valid-metaclass-classmethod-first-arg=mcs - - -[EXCEPTIONS] - -# Exceptions that will emit a warning when being caught. Defaults to -# "Exception" -overgeneral-exceptions=StandardError, - Exception, - BaseException diff --git a/.style.yapf b/.style.yapf deleted file mode 100644 index 5004558..0000000 --- a/.style.yapf +++ /dev/null @@ -1,8 +0,0 @@ -[style] -based_on_style: pep8 -column_limit: 80 -indent_width: 2 -split_before_named_assigns: True -spaces_around_power_operator: True -dedent_closing_brackets: True -coalesce_brackets: True diff --git a/.vscode/settings.json b/.vscode/settings.json index 67aa94a..b06e040 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -6,7 +6,7 @@ "pylint.path": ["conda", "run", "-n", "rltorch", "python", "-m", "pylint"], "yapf.args": ["--style", "${workspaceFolder}/.style.yapf"], "[python]": { - "editor.defaultFormatter": "eeyore.yapf", + "editor.defaultFormatter": "charliermarsh.ruff", }, "editor.tokenColorCustomizations": { "[*Dark*]": { diff --git a/AC/__init__.py b/AC/__init__.py index 8875f60..b7b77a8 100644 --- a/AC/__init__.py +++ b/AC/__init__.py @@ -1 +1 @@ -# pylint: disable=all +# noqa: N999 diff --git a/AC/a2c.py b/AC/a2c.py index 0073051..1c5ddee 100644 --- a/AC/a2c.py +++ b/AC/a2c.py @@ -6,8 +6,8 @@ from torch.nn import functional as F from util.agent import Agent -from util.buffer import ReplayBuffer, Trajectory from util.algo import calc_gaes, calc_nstep_return, standardize +from util.buffer import ReplayBuffer, Trajectory class Actor(nn.Module): diff --git a/AC/main.py b/AC/main.py index 63207d5..295112c 100644 --- a/AC/main.py +++ b/AC/main.py @@ -1,16 +1,18 @@ """main executable file for A2C""" -import os import logging +import os +from collections import deque from itertools import repeat + import gymnasium as gym -import torch import numpy as np -from util import generate_gif -from util.wrappers import TrainMonitor -from util.buffer import Experience, Trajectory -from collections import deque +import torch + # pylint: disable=invalid-name from AC.a2c import A2CAgent as A2C_torch +from util import generate_gif +from util.buffer import Experience, Trajectory +from util.wrappers import TrainMonitor Agent = A2C_torch logging.basicConfig(level=logging.INFO) @@ -20,7 +22,6 @@ EPSILON_DECAY_STEPS = 100 - def main( n_episodes=20000, max_t=500, @@ -30,7 +31,6 @@ def main( score_term_rules=lambda s: False, time_interval="25ms" ): - # pylint: disable=line-too-long """Deep Q-Learning Params diff --git a/AWR/__init__.py b/AWR/__init__.py index 8875f60..b7b77a8 100644 --- a/AWR/__init__.py +++ b/AWR/__init__.py @@ -1 +1 @@ -# pylint: disable=all +# noqa: N999 diff --git a/AWR/awr.py b/AWR/awr.py index 295541a..c1d9e74 100644 --- a/AWR/awr.py +++ 
b/AWR/awr.py @@ -1,5 +1,6 @@ """AWR implementation with pytorch.""" from functools import partial + import numpy as np import torch from torch import nn @@ -7,8 +8,8 @@ from torch.nn import functional as F from util.agent import Agent +from util.algo import scale_down_values, scale_up_values, standardize from util.buffer import ReplayBuffer, Trajectory -from util.algo import standardize, scale_down_values, scale_up_values class Actor(nn.Module): @@ -21,7 +22,7 @@ def __init__( seed=0, fc1_unit=256, fc2_unit=256, - init_weight_gain=np.sqrt(2), + init_weight_gain=np.sqrt(2), # noqa: B008 init_policy_weight_gain=0.01, init_bias=0 ): @@ -71,7 +72,7 @@ def __init__( seed=0, fc1_unit=256, fc2_unit=256, - init_weight_gain=np.sqrt(2), + init_weight_gain=np.sqrt(2), # noqa: B008 init_value_weight_gain=1, init_bias=0 ): diff --git a/AWR/main.py b/AWR/main.py index 7700406..3052617 100644 --- a/AWR/main.py +++ b/AWR/main.py @@ -1,16 +1,18 @@ """main executable file for awr""" -import os import logging +import os +from collections import deque from itertools import repeat + import gymnasium as gym -import torch import numpy as np -from util import generate_gif -from util.wrappers import TrainMonitor -from util.buffer import Experience, Trajectory -from collections import deque +import torch + # pylint: disable=invalid-name from AWR.awr import AWRAgent as AWR_torch +from util import generate_gif +from util.buffer import Experience, Trajectory +from util.wrappers import TrainMonitor Agent = AWR_torch logging.basicConfig(level=logging.INFO) @@ -30,7 +32,6 @@ def main( score_term_rules=lambda s: False, time_interval="25ms" ): - # pylint: disable=line-too-long """Deep Q-Learning Params diff --git a/C51/__init__.py b/C51/__init__.py index 8875f60..b7b77a8 100644 --- a/C51/__init__.py +++ b/C51/__init__.py @@ -1 +1 @@ -# pylint: disable=all +# noqa: N999 diff --git a/C51/c51.py b/C51/c51.py index b8e2550..bbe593f 100644 --- a/C51/c51.py +++ b/C51/c51.py @@ -3,9 +3,9 @@ import torch from torch import nn from torch.nn import functional as F -from util.buffer import ReplayBuffer + from util.agent import Agent -from util.buffer import Experience +from util.buffer import Experience, ReplayBuffer class Q(nn.Module): @@ -146,7 +146,6 @@ def remember(self, scenario: Experience): self.memory.enqueue(scenario) def _learn(self, experiences): - # pylint: disable=line-too-long """Update value parameters using given batch of experience tuples. 
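The `# noqa: B008` markers added to defaults such as `init_weight_gain=np.sqrt(2)` in the AWR hunks above (and repeated in most later files) suppress ruff's flake8-bugbear rule B008, which flags function calls in argument defaults because Python evaluates them once, at definition time. For a pure constant like `np.sqrt(2)` that is harmless, so suppressing is reasonable; a hedged sketch of the pattern and the two usual alternatives (function names here are illustrative, not repository code):

```python
import numpy as np

SQRT2 = np.sqrt(2)  # option 1: hoist the call into a module-level constant


def init_layer(gain=np.sqrt(2)):  # noqa: B008  # what the diff keeps, with a suppression
    return gain


def init_layer_hoisted(gain=SQRT2):  # option 1 in use: no call in the default
    return gain


def init_layer_lazy(gain=None):  # option 2: None sentinel, evaluated per call
    if gain is None:
        gain = np.sqrt(2)
    return gain
```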
Params ======= @@ -211,7 +210,7 @@ def calc_q_target(self, rewards, next_states, terminate): ) # bⱼ ∈ [0, N - 1], shape: (batch, n_atom) b_j = (tau_z - self.v_min) / self.delta - l = torch.floor(b_j).clamp(max=self.n_atoms - 1).long() + l = torch.floor(b_j).clamp(max=self.n_atoms - 1).long() # noqa: E741 u = torch.ceil(b_j).clamp(max=self.n_atoms - 1 ).long() # Prevent out of bounds # m: shape (batch, n_atoms) @@ -231,11 +230,11 @@ def calc_q_target(self, rewards, next_states, terminate): delta_m_l = p_j * (u - b_j) delta_m_u = p_j * (b_j - l) - # mₗ ← mₗ + pⱼ(xt+1, a*)(u − bj ) + # mₗ ← mₗ + pⱼ(xt+1, a*)(u - bj ) m.scatter_add_(1, l, delta_m_l) - # mᵤ ← mᵤ + pⱼ(xt+1, a*)(bj − l) + # mᵤ ← mᵤ + pⱼ(xt+1, a*)(bj - l) m.scatter_add_(1, u, delta_m_u) return m - def update_targe_q(self): + def update_target_q(self): self.qnetwork_target.load_state_dict(self.qnetwork_local.state_dict()) diff --git a/C51/main.py b/C51/main.py index 810e53a..f319a79 100644 --- a/C51/main.py +++ b/C51/main.py @@ -1,17 +1,19 @@ """main executable file for Distribution Q learning.""" -import os -import math import logging +import math +import os +from collections import deque from itertools import repeat + import gymnasium as gym -import torch import numpy as np -from util import generate_gif -from util.wrappers import TrainMonitor -from util.buffer import Experience -from collections import deque +import torch + # pylint: disable=invalid-name from C51.c51 import C51Agent as C51_torch +from util import generate_gif +from util.buffer import Experience +from util.wrappers import TrainMonitor Agent = C51_torch logging.basicConfig(level=logging.INFO) @@ -31,7 +33,6 @@ def main( score_term_rules=lambda s: False, time_interval="25ms" ): - # pylint: disable=line-too-long """Deep Q-Learning Params @@ -83,7 +84,7 @@ def main( score += reward if (t * i_episode) % update_q_target_freq: - agent.update_targe_q() + agent.update_target_q() if done or score_term_rules(score): break diff --git a/DDPG/__init__.py b/DDPG/__init__.py index 8875f60..b7b77a8 100644 --- a/DDPG/__init__.py +++ b/DDPG/__init__.py @@ -1 +1 @@ -# pylint: disable=all +# noqa: N999 diff --git a/DDPG/ddpg.py b/DDPG/ddpg.py index c5aebca..c788c41 100644 --- a/DDPG/ddpg.py +++ b/DDPG/ddpg.py @@ -3,9 +3,9 @@ import torch from torch import nn from torch.nn import functional as F -from util.buffer import ReplayBuffer + from util.agent import Agent -from util.buffer import Experience +from util.buffer import Experience, ReplayBuffer from util.dist import OrnsteinUhlenbeckNoise @@ -20,7 +20,7 @@ def __init__( fc1_unit=64, fc2_unit=64, max_action=1, - init_weight_gain=np.sqrt(2), + init_weight_gain=np.sqrt(2), # noqa: B008 init_policy_weight_gain=1, init_bias=0 ): @@ -73,7 +73,7 @@ def __init__( seed=0, fc1_unit=64, fc2_unit=64, - init_weight_gain=np.sqrt(2), + init_weight_gain=np.sqrt(2), # noqa: B008 init_bias=0 ): """ @@ -235,7 +235,6 @@ def remember(self, scenario: Experience): self.memory.enqueue(scenario) def _learn(self, experiences): - # pylint: disable=line-too-long """Update value parameters using given batch of experience tuples. 
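The `calc_q_target` hunk above only touches cosmetics of the C51 categorical projection (the ambiguous name `l` is kept with `# noqa: E741`, and the Unicode minus signs in the mₗ/mᵤ comments are replaced), but the projection itself is easy to lose in the diff noise. A standalone sketch of that step, mirroring the hunk and assuming `(batch, n_atoms)` tensors; the function and argument names are illustrative, not part of the repository:

```python
import torch


def project_distribution(rewards, p_next, z, gamma, v_min, v_max, n_atoms, terminate):
    """C51 categorical projection sketch matching the calc_q_target hunk.

    rewards, terminate: (batch, 1); p_next: (batch, n_atoms) probabilities of the
    greedy next action under the target network; z: (n_atoms,) fixed support.
    """
    delta = (v_max - v_min) / (n_atoms - 1)
    # Tz_j = r + (1 - done) * gamma * z_j, clipped to the support
    tau_z = (rewards + (1 - terminate) * gamma * z).clamp(v_min, v_max)
    b_j = (tau_z - v_min) / delta                        # b_j in [0, N - 1]
    l = torch.floor(b_j).clamp(max=n_atoms - 1).long()   # noqa: E741
    u = torch.ceil(b_j).clamp(max=n_atoms - 1).long()
    m = torch.zeros_like(p_next)
    # m_l <- m_l + p_j * (u - b_j);  m_u <- m_u + p_j * (b_j - l)
    m.scatter_add_(1, l, p_next * (u.float() - b_j))
    m.scatter_add_(1, u, p_next * (b_j - l.float()))
    return m
```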
Params ======= diff --git a/DDPG/main.py b/DDPG/main.py index 6a40765..ddad28b 100644 --- a/DDPG/main.py +++ b/DDPG/main.py @@ -1,16 +1,19 @@ """main executable file for DDPG""" -import os import logging +import os +from collections import deque from itertools import repeat + import gymnasium as gym -import torch import numpy as np -from util import generate_gif -from util.wrappers import TrainMonitor -from util.buffer import Experience -from collections import deque +import torch + # pylint: disable=invalid-name from DDPG.ddpg import DDPGAgent as DDPG_torch +from util import generate_gif +from util.buffer import Experience +from util.wrappers import TrainMonitor + # from DQN.dqn_torch import DQNAgent as DQN_torch Agent = DDPG_torch @@ -31,7 +34,6 @@ def main( score_term_rules=lambda s: False, time_interval="25ms" ): - # pylint: disable=line-too-long """Deep Q-Learning Params diff --git a/DDQN/__init__.py b/DDQN/__init__.py index 8875f60..b7b77a8 100644 --- a/DDQN/__init__.py +++ b/DDQN/__init__.py @@ -1 +1 @@ -# pylint: disable=all +# noqa: N999 diff --git a/DDQN/ddqn.py b/DDQN/ddqn.py index 05d7ca9..1a9834f 100644 --- a/DDQN/ddqn.py +++ b/DDQN/ddqn.py @@ -3,9 +3,9 @@ import torch from torch import nn from torch.nn import functional as F + from util.agent import Agent -from util.buffer import Experience -from util.buffer import ProportionalPrioritizedReplayBuffer +from util.buffer import Experience, ProportionalPrioritizedReplayBuffer class Q(nn.Module): @@ -115,7 +115,6 @@ def remember(self, scenario: Experience): self.memory.enqueue(scenario) def _learn(self, experiences): - # pylint: disable=line-too-long """Update value parameters using given batch of experience tuples. Params ======= @@ -153,7 +152,7 @@ def _learn(self, experiences): ) with torch.no_grad(): - # r + (1 − done) × γ × Q(state, argmax Q(state', a')) + # r + (1 - done) × γ × Q(state, argmax Q(state', a')) # noqa: RUF003 next_wanted_action = F.one_hot( (torch.argmax(self.qnetwork_local.forward(next_states), dim=1)), self.action_space @@ -184,5 +183,5 @@ def _learn(self, experiences): ) self.optimizer.step() - def update_targe_q(self): + def update_target_q(self): self.qnetwork_target.load_state_dict(self.qnetwork_local.state_dict()) diff --git a/DDQN/main.py b/DDQN/main.py index eda704d..c631ff4 100644 --- a/DDQN/main.py +++ b/DDQN/main.py @@ -1,17 +1,20 @@ """main executable file for DDQN""" +import logging import math import os -import logging +from collections import deque from itertools import repeat + import gymnasium as gym -import torch import numpy as np -from util import generate_gif -from util.wrappers import TrainMonitor -from util.buffer import Experience -from collections import deque +import torch + # pylint: disable=invalid-name from DDQN.ddqn import DDQNAgent as DDQN_torch +from util import generate_gif +from util.buffer import Experience +from util.wrappers import TrainMonitor + # from DQN.dqn_torch import DQNAgent as DQN_torch Agent = DDQN_torch @@ -32,7 +35,6 @@ def main( score_term_rules=lambda s: False, time_interval="25ms" ): - # pylint: disable=line-too-long """Deep Q-Learning Params @@ -80,7 +82,7 @@ def main( score += reward if (t * i_episode) % update_q_target_freq: - agent.update_targe_q() + agent.update_target_q() if done or score_term_rules(score): break diff --git a/DQN/__init__.py b/DQN/__init__.py index 8875f60..b7b77a8 100644 --- a/DQN/__init__.py +++ b/DQN/__init__.py @@ -1 +1 @@ -# pylint: disable=all +# noqa: N999 diff --git a/DQN/dqn.py b/DQN/dqn.py index 958cd68..8da4725 100644 --- 
a/DQN/dqn.py +++ b/DQN/dqn.py @@ -3,9 +3,9 @@ import torch from torch import nn from torch.nn import functional as F + from util.agent import Agent -from util.buffer import Experience -from util.buffer import ProportionalPrioritizedReplayBuffer +from util.buffer import Experience, ProportionalPrioritizedReplayBuffer class Q(nn.Module): @@ -115,7 +115,6 @@ def remember(self, scenario: Experience): self.memory.enqueue(scenario) def _learn(self, experiences): - # pylint: disable=line-too-long """Update value parameters using given batch of experience tuples. Params ======= @@ -153,7 +152,7 @@ def _learn(self, experiences): ) with torch.no_grad(): - # r + (1 − done) × γ × max(Q(state)) + # r + (1 − done) × γ × max(Q(state)) # noqa: RUF003 labels = rewards + (1 - terminate) * self.gamma * torch.max( self.qnetwork_target.forward(next_states).detach(), dim=1, @@ -179,5 +178,5 @@ def _learn(self, experiences): ) self.optimizer.step() - def update_targe_q(self): + def update_target_q(self): self.qnetwork_target.load_state_dict(self.qnetwork_local.state_dict()) diff --git a/DQN/main.py b/DQN/main.py index cb8933a..31ca942 100644 --- a/DQN/main.py +++ b/DQN/main.py @@ -1,17 +1,20 @@ """main executable file for DQN""" -import os -import math import logging +import math +import os +from collections import deque from itertools import repeat + import gymnasium as gym -import torch import numpy as np -from util import generate_gif -from util.wrappers import TrainMonitor -from util.buffer import Experience -from collections import deque +import torch + # pylint: disable=invalid-name from DQN.dqn import DQNAgent as DQN_torch +from util import generate_gif +from util.buffer import Experience +from util.wrappers import TrainMonitor + # from DQN.dqn_torch import DQNAgent as DQN_torch Agent = DQN_torch @@ -32,7 +35,6 @@ def main( score_term_rules=lambda s: False, time_interval="25ms" ): - # pylint: disable=line-too-long """Deep Q-Learning Params @@ -79,7 +81,7 @@ def main( score += reward if (t * i_episode) % update_q_target_freq: - agent.update_targe_q() + agent.update_target_q() if done or score_term_rules(score): break diff --git a/MPO/__init__.py b/MPO/__init__.py index 8875f60..b7b77a8 100644 --- a/MPO/__init__.py +++ b/MPO/__init__.py @@ -1 +1 @@ -# pylint: disable=all +# noqa: N999 diff --git a/MPO/main.py b/MPO/main.py index 5cba98a..5c132c5 100644 --- a/MPO/main.py +++ b/MPO/main.py @@ -1,16 +1,18 @@ """main executable file for mpo""" -import os import logging +import os +from collections import deque from itertools import repeat + import gymnasium as gym -import torch import numpy as np -from util import generate_gif -from util.wrappers import TrainMonitor -from util.buffer import Experience, Trajectory -from collections import deque +import torch + # pylint: disable=invalid-name from MPO.mpo import MPOAgent as MPO_torch +from util import generate_gif +from util.buffer import Experience, Trajectory +from util.wrappers import TrainMonitor Agent = MPO_torch logging.basicConfig(level=logging.INFO) @@ -30,7 +32,6 @@ def main( score_term_rules=lambda s: False, time_interval="25ms" ): - # pylint: disable=line-too-long """Deep Q-Learning Params diff --git a/MPO/mpo.py b/MPO/mpo.py index 578e03e..0a63b76 100644 --- a/MPO/mpo.py +++ b/MPO/mpo.py @@ -1,9 +1,9 @@ """MPO implementation with pytorch.""" + import numpy as np import torch from torch import nn -from torch.distributions import kl_divergence -from torch.distributions import Categorical +from torch.distributions import Categorical, kl_divergence from 
torch.nn import functional as F from util.agent import Agent @@ -13,29 +13,29 @@ class Actor(nn.Module): - """ Actor (Policy) Model.""" + """Actor (Policy) Model.""" def __init__( - self, - state_dim, - action_space, - seed=0, - fc1_unit=256, - fc2_unit=256, - init_weight_gain=np.sqrt(2), - init_policy_weight_gain=0.01, - init_bias=0 + self, + state_dim, + action_space, + seed=0, + fc1_unit=256, + fc2_unit=256, + init_weight_gain=np.sqrt(2), # noqa: B008 + init_policy_weight_gain=0.01, + init_bias=0, ): """ - Initialize parameters and build model. - Params - ======= - state_size (int): Dimension of each state - action_size (int): Dimension of each action - seed (int): Random seed - fc1_unit (int): Number of nodes in first hidden layer - fc2_unit (int): Number of nodes in second hidden layer - """ + Initialize parameters and build model. + Params + ======= + state_size (int): Dimension of each state + action_size (int): Dimension of each action + seed (int): Random seed + fc1_unit (int): Number of nodes in first hidden layer + fc2_unit (int): Number of nodes in second hidden layer + """ super().__init__() ## calls __init__ method of nn.Module class self.seed = torch.manual_seed(seed) self.fc1 = nn.Linear(state_dim, fc1_unit) @@ -52,8 +52,8 @@ def __init__( def forward(self, x): """ - Build a network that maps state -> action values. - """ + Build a network that maps state -> action values. + """ x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) pi = F.softmax(self.fc_policy(x), dim=1) @@ -61,29 +61,29 @@ def forward(self, x): class Critic(nn.Module): - """ Critic (Policy) Model.""" + """Critic (Policy) Model.""" def __init__( - self, - state_dim, - action_space=1, - seed=0, - fc1_unit=256, - fc2_unit=256, - init_weight_gain=np.sqrt(2), - init_value_weight_gain=1, - init_bias=0 + self, + state_dim, + action_space=1, + seed=0, + fc1_unit=256, + fc2_unit=256, + init_weight_gain=np.sqrt(2), # noqa: B008 + init_value_weight_gain=1, + init_bias=0, ): """ - Initialize parameters and build model. - Params - ======= - state_size (int): Dimension of each state - action_size (int): Dimension of each action - seed (int): Random seed - fc1_unit (int): Number of nodes in first hidden layer - fc2_unit (int): Number of nodes in second hidden layer - """ + Initialize parameters and build model. 
+ Params + ======= + state_size (int): Dimension of each state + action_size (int): Dimension of each action + seed (int): Random seed + fc1_unit (int): Number of nodes in first hidden layer + fc2_unit (int): Number of nodes in second hidden layer + """ super().__init__() ## calls __init__ method of nn.Module class self.seed = torch.manual_seed(seed) self.action_space = action_space @@ -116,31 +116,30 @@ class MPOAgent(Agent): """Interacts with and learns form environment.""" def __init__( - self, - state_dims, - action_space, - gamma=0.99, - lr_actor=0.001, - lr_critic=0.001, - batch_size=64, - epsilon=0.01, - mem_size=None, - forget_experience=True, - grad_clip=0.5, - init_eta=1.0, - lr_eta=0.001, - eta_epsilon=0.1, - action_sample_round=10, - kl_epsilon=0.01, - kl_alpha=1., - kl_alpha_max=1.0, - kl_clip_min=0.0, - kl_clip_max=1.0, - improved_policy_iteration=5, - update_tau=0.005, - seed=0, + self, + state_dims, + action_space, + gamma=0.99, + lr_actor=0.001, + lr_critic=0.001, + batch_size=64, + epsilon=0.01, + mem_size=None, + forget_experience=True, + grad_clip=0.5, + init_eta=1.0, + lr_eta=0.001, + eta_epsilon=0.1, + action_sample_round=10, + kl_epsilon=0.01, + kl_alpha=1.0, + kl_alpha_max=1.0, + kl_clip_min=0.0, + kl_clip_max=1.0, + improved_policy_iteration=5, + update_tau=0.005, + seed=0, ): - self.state_dims = state_dims self.action_space = action_space self.gamma = gamma @@ -156,7 +155,7 @@ def __init__( self.action_sample_round = action_sample_round self.kl_epsilon = kl_epsilon self.kl_alpha_scaler = kl_alpha - self.kl_alpha = torch.tensor(0., requires_grad=False).to(device) + self.kl_alpha = torch.tensor(0.0, requires_grad=False).to(device) self.kl_clip_min = kl_clip_min self.kl_clip_max = kl_clip_max self.kl_alpha_max = kl_alpha_max @@ -168,17 +167,13 @@ def __init__( self.actor_target = Actor(state_dims, action_space).to(device) self.actor_target.load_state_dict(self.actor.state_dict()) - #Q Network + # Q Network self.critic = Critic(self.state_dims, self.action_space).to(device) self.critic_target = Critic(self.state_dims, self.action_space).to(device) self.critic_target.load_state_dict(self.critic.state_dict()) - self.actor_optimizer = torch.optim.Adam( - self.actor.parameters(), lr=self.lr_actor - ) - self.critic_optimizer = torch.optim.Adam( - self.critic.parameters(), lr=self.lr_critic - ) + self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=self.lr_actor) + self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=self.lr_critic) self.eta = torch.tensor(init_eta).to(device) self.eta.requires_grad = True @@ -198,9 +193,7 @@ def learn(self, iteration: int = 10, replace=True): polcy_loss = [] val_loss = [] eta_loss = [] - trajectories = self.memory.sample_from( - num_samples=iteration, replace=replace - ) + trajectories = self.memory.sample_from(num_samples=iteration, replace=replace) if not trajectories: return polcy_loss, val_loss, eta_loss for trajectory in trajectories: @@ -209,15 +202,9 @@ def learn(self, iteration: int = 10, replace=True): val_loss.append(val_loss_.cpu().data.numpy()) eta_loss.append(eta_loss_.cpu().data.numpy()) - return ( - np.array(polcy_loss).mean(), - np.array(val_loss).mean(), - np.array(eta_loss).mean(), - ) + return (np.array(polcy_loss).mean(), np.array(val_loss).mean(), np.array(eta_loss).mean()) - def policy_evaluation( - self, states, actions, rewards, next_states, terminates - ): + def policy_evaluation(self, states, actions, rewards, next_states, terminates): self.critic.train() self.critic_target.eval() 
@@ -248,14 +235,9 @@ def find_action_weights(self, states, actions): sample_action = action_dist.sample().reshape(actions.shape) sample_actions.append(sample_action) sample_actions = torch.cat(sample_actions, dim=0) # shape [BxN, action_dim] - tiled_states = torch.tile( - states, (self.action_sample_round, 1) - ) # shape [BxN, state_dim] - target_q = self.critic_target.forward( - tiled_states, sample_actions - ) # shape [BxN, 1] - target_q = target_q.reshape(-1, self.action_sample_round).detach( - ) # shape [B, N] + tiled_states = torch.tile(states, (self.action_sample_round, 1)) # shape [BxN, state_dim] + target_q = self.critic_target.forward(tiled_states, sample_actions) # shape [BxN, 1] + target_q = target_q.reshape(-1, self.action_sample_round).detach() # shape [B, N] # η = argmin[ η * ε + η * Σₖ ( 1/K * log(Σₙ( 1/N * exp(Q(sₙ, aₖ)) / η)) ) ] # This is for numberic stability. @@ -265,9 +247,11 @@ def find_action_weights(self, states, actions): # ).mean() max_q = target_q.max(dim=-1, keepdim=True).values - eta_loss = self.eta * self.eta_epsilon + self.eta * torch.log( - torch.exp((target_q - max_q) / self.eta).mean(dim=-1) - ).mean() + torch.mean(max_q) + eta_loss = ( + self.eta * self.eta_epsilon + + self.eta * torch.log(torch.exp((target_q - max_q) / self.eta).mean(dim=-1)).mean() + + torch.mean(max_q) + ) self.eta_optimizer.zero_grad() eta_loss.backward() @@ -279,12 +263,9 @@ def find_action_weights(self, states, actions): return action_weights, sample_actions, tiled_states, eta_loss - def fit_an_improved_policy( - self, action_weights, sample_actions, tiled_states - ): + def fit_an_improved_policy(self, action_weights, sample_actions, tiled_states): _, action_dist = self.action(tiled_states, mode="train") - log_prob = action_dist.log_prob(sample_actions.detach().T - ).T.reshape(-1, self.action_sample_round) + log_prob = action_dist.log_prob(sample_actions.detach().T).T.reshape(-1, self.action_sample_round) # π(k+1) = argmax Σₖ Σₙ qₙₖ * log(πθ(aₙ|sₖ)) policy_loss = torch.mean(log_prob * action_weights.detach()) @@ -297,20 +278,15 @@ def fit_an_improved_policy( kl = torch.clamp(kl, min=self.kl_clip_min, max=self.kl_clip_max) if self.kl_alpha_scaler > 0: - # pylint: disable=line-too-long # Update lagrange multipliers by gradient descent # this equation is derived from last eq of [2] p.5, - # just differentiate with respect to α - # and update α so that the equation is to be minimized. + # just differentiate with respect to α # noqa: RUF003 + # and update α so that the equation is to be minimized. 
# noqa: RUF003 # inspired by https://github.com/daisatojp/mpo/blob/13da541861f901436c993d0e9b0d369bf7f771d1/mpo/mpo.py#L394 - # pylint: enable=line-too-long self.kl_alpha -= self.kl_alpha_scaler * (self.kl_epsilon - kl).detach() - self.kl_alpha = torch.clamp( - self.kl_alpha, min=1e-8, max=self.kl_alpha_max - ) - # pylint: disable=line-too-long - # max_θ min_α L(θ,η) = Σₖ Σₙ qₙₖ * log(πθ(aₙ|sₖ)) + α * (ε - Σₖ 1/K * KL(πₖ(a|sₖ)||πθ(a|sₖ))) - # pylint: enable=line-too-long + self.kl_alpha = torch.clamp(self.kl_alpha, min=1e-8, max=self.kl_alpha_max) + + # max_θ min_α L(θ,η) = Σₖ Σₙ qₙₖ * log(πθ(aₙ|sₖ)) + α * (ε - Σₖ 1/K * KL(πₖ(a|sₖ)||πθ(a|sₖ))) # noqa: RUF003, E501 policy_loss = -(policy_loss + self.kl_alpha * (self.kl_epsilon - kl)) self.actor_optimizer.zero_grad() @@ -322,31 +298,20 @@ def fit_an_improved_policy( def policy_improvement(self, states, actions): # step 2 - (action_weights, sample_actions, tiled_states, - eta_loss) = self.find_action_weights(states, actions) + (action_weights, sample_actions, tiled_states, eta_loss) = self.find_action_weights(states, actions) # step 3 for _ in range(self.improved_policy_iteration): - policy_loss = self.fit_an_improved_policy( - action_weights, sample_actions, tiled_states - ) + policy_loss = self.fit_an_improved_policy(action_weights, sample_actions, tiled_states) return policy_loss, eta_loss def _learn(self, trajectory: Trajectory): - states = torch.from_numpy(np.vstack([e.state for e in trajectory]) - ).float().to(device) - actions = torch.from_numpy(np.vstack([e.action for e in trajectory]) - ).long().to(device) - rewards = torch.from_numpy(np.vstack([e.reward for e in trajectory]) - ).float().to(device) - next_states = torch.from_numpy( - np.vstack([e.next_state for e in trajectory]) - ).float().to(device) - terminates = torch.from_numpy(np.vstack([e.done for e in trajectory]) - ).float().to(device) - - val_loss = self.policy_evaluation( - states, actions, rewards, next_states, terminates - ) + states = torch.from_numpy(np.vstack([e.state for e in trajectory])).float().to(device) + actions = torch.from_numpy(np.vstack([e.action for e in trajectory])).long().to(device) + rewards = torch.from_numpy(np.vstack([e.reward for e in trajectory])).float().to(device) + next_states = torch.from_numpy(np.vstack([e.next_state for e in trajectory])).float().to(device) + terminates = torch.from_numpy(np.vstack([e.done for e in trajectory])).float().to(device) + + val_loss = self.policy_evaluation(states, actions, rewards, next_states, terminates) policy_loss, eta_loss = self.policy_improvement(states, actions) @@ -374,11 +339,11 @@ def action(self, state, mode="eval", target_policy=False): def take_action(self, state, _=0): """Returns action for given state as per current policy - Params - ======= - state (array_like): current state - epsilon (float): epsilon, for epsilon-greedy action selection - """ + Params + ======= + state (array_like): current state + epsilon (float): epsilon, for epsilon-greedy action selection + """ state = torch.from_numpy(state).float().unsqueeze(0).to(device) with torch.no_grad(): @@ -387,25 +352,21 @@ def take_action(self, state, _=0): return action_values.item() def log_prob(self, action): - return self.dist.log_prob(torch.Tensor([action]).to(device) - ).data.cpu().item() + return self.dist.log_prob(torch.Tensor([action]).to(device)).data.cpu().item() def remember(self, scenario: Trajectory): self.memory.enqueue(scenario) def soft_update(self, local_model, target_model): """ - Soft update model parameters. 
- θ_target = τ * θ_local + (1 - τ) * θ_target - Token from - https://github.com/udacity/deep-reinforcement-learning/blob/master/dqn/exercise/dqn_agent.py + Soft update model parameters. + θ_target = τ * θ_local + (1 - τ) * θ_target + Token from + https://github.com/udacity/deep-reinforcement-learning/blob/master/dqn/exercise/dqn_agent.py """ - for target_param, local_param in zip( - target_model.parameters(), local_model.parameters() - ): + for target_param, local_param in zip(target_model.parameters(), local_model.parameters()): target_param.data.copy_( - self.update_tau * local_param.data + - (1.0 - self.update_tau) * target_param.data + self.update_tau * local_param.data + (1.0 - self.update_tau) * target_param.data ) def update_critic_target_network(self): diff --git a/PPG/__init__.py b/PPG/__init__.py index 8875f60..b7b77a8 100644 --- a/PPG/__init__.py +++ b/PPG/__init__.py @@ -1 +1 @@ -# pylint: disable=all +# noqa: N999 diff --git a/PPG/main.py b/PPG/main.py index cb1fc6b..d68a9a4 100644 --- a/PPG/main.py +++ b/PPG/main.py @@ -1,16 +1,18 @@ """main executable file for PPG""" -import os import logging +import os +from collections import deque from itertools import repeat + import gymnasium as gym -import torch import numpy as np -from util import generate_gif -from util.wrappers import TrainMonitor -from util.buffer import Experience, Trajectory -from collections import deque +import torch + # pylint: disable=invalid-name from PPG.ppg import PPGAgent as PPG_torch +from util import generate_gif +from util.buffer import Experience, Trajectory +from util.wrappers import TrainMonitor Agent = PPG_torch logging.basicConfig(level=logging.INFO) @@ -30,7 +32,6 @@ def main( score_term_rules=lambda s: False, time_interval="25ms" ): - # pylint: disable=line-too-long """Deep Q-Learning Params diff --git a/PPG/ppg.py b/PPG/ppg.py index 2098658..102eb48 100644 --- a/PPG/ppg.py +++ b/PPG/ppg.py @@ -6,8 +6,8 @@ from torch.nn import functional as F from util.agent import Agent -from util.buffer import ReplayBuffer, Trajectory, Experience from util.algo import calc_gaes, calc_nstep_return, standardize +from util.buffer import Experience, ReplayBuffer, Trajectory class Actor(nn.Module): @@ -20,7 +20,7 @@ def __init__( seed=0, fc1_unit=256, fc2_unit=256, - init_weight_gain=np.sqrt(2), + init_weight_gain=np.sqrt(2), # noqa: B008 init_policy_weight_gain=0.01, init_bias=0 ): @@ -75,7 +75,7 @@ def __init__( seed=0, fc1_unit=256, fc2_unit=256, - init_weight_gain=np.sqrt(2), + init_weight_gain=np.sqrt(2), # noqa: B008 init_value_weight_gain=1, init_bias=0 ): diff --git a/PPO/__init__.py b/PPO/__init__.py index 8875f60..b7b77a8 100644 --- a/PPO/__init__.py +++ b/PPO/__init__.py @@ -1 +1 @@ -# pylint: disable=all +# noqa: N999 diff --git a/PPO/main.py b/PPO/main.py index 2aa9cab..1c451c0 100644 --- a/PPO/main.py +++ b/PPO/main.py @@ -1,16 +1,18 @@ """main executable file for ppo""" -import os import logging +import os +from collections import deque from itertools import repeat + import gymnasium as gym -import torch import numpy as np -from util import generate_gif -from util.wrappers import TrainMonitor -from util.buffer import Experience, Trajectory -from collections import deque +import torch + # pylint: disable=invalid-name from PPO.ppo import PPOAgent as PPO_torch +from util import generate_gif +from util.buffer import Experience, Trajectory +from util.wrappers import TrainMonitor Agent = PPO_torch logging.basicConfig(level=logging.INFO) @@ -30,7 +32,6 @@ def main( score_term_rules=lambda s: False, 
time_interval="25ms" ): - # pylint: disable=line-too-long """Deep Q-Learning Params diff --git a/PPO/ppo.py b/PPO/ppo.py index 815fc42..e32ce30 100644 --- a/PPO/ppo.py +++ b/PPO/ppo.py @@ -6,8 +6,8 @@ from torch.nn import functional as F from util.agent import Agent -from util.buffer import ReplayBuffer, Trajectory from util.algo import calc_gaes, calc_nstep_return, standardize +from util.buffer import ReplayBuffer, Trajectory class Actor(nn.Module): @@ -20,7 +20,7 @@ def __init__( seed=0, fc1_unit=256, fc2_unit=256, - init_weight_gain=np.sqrt(2), + init_weight_gain=np.sqrt(2), # noqa: B008 init_policy_weight_gain=0.01, init_bias=0 ): @@ -68,7 +68,7 @@ def __init__( seed=0, fc1_unit=256, fc2_unit=256, - init_weight_gain=np.sqrt(2), + init_weight_gain=np.sqrt(2), # noqa: B008 init_value_weight_gain=1, init_bias=0 ): diff --git a/SACv1/__init__.py b/SACv1/__init__.py index 8875f60..b7b77a8 100644 --- a/SACv1/__init__.py +++ b/SACv1/__init__.py @@ -1 +1 @@ -# pylint: disable=all +# noqa: N999 diff --git a/SACv1/main.py b/SACv1/main.py index a06d473..d93e4bf 100644 --- a/SACv1/main.py +++ b/SACv1/main.py @@ -1,16 +1,19 @@ """main executable file for SAC""" -import os import logging +import os +from collections import deque from itertools import repeat + import gymnasium as gym -import torch import numpy as np -from util import generate_gif -from util.wrappers import TrainMonitor -from util.buffer import Experience -from collections import deque +import torch + # pylint: disable=invalid-name from SACv1.sac import SACv1Agent as SAC_torch +from util import generate_gif +from util.buffer import Experience +from util.wrappers import TrainMonitor + # from DQN.dqn_torch import DQNAgent as DQN_torch Agent = SAC_torch @@ -31,7 +34,6 @@ def main( score_term_rules=lambda s: False, time_interval="25ms" ): - # pylint: disable=line-too-long """Deep Q-Learning Params diff --git a/SACv1/sac.py b/SACv1/sac.py index 876838f..0751f4c 100644 --- a/SACv1/sac.py +++ b/SACv1/sac.py @@ -3,10 +3,10 @@ import torch from torch import nn from torch.nn import functional as F -from util.buffer import ReplayBuffer + from util.agent import Agent -from util.buffer import Experience -from util.dist import SquashedNormal, DiagonalGaussian +from util.buffer import Experience, ReplayBuffer +from util.dist import DiagonalGaussian, SquashedNormal class Actor(nn.Module): @@ -20,7 +20,7 @@ def __init__( fc1_unit=64, fc2_unit=64, max_action=1, - init_weight_gain=np.sqrt(2), + init_weight_gain=np.sqrt(2), # noqa: B008 init_policy_weight_gain=1, init_bias=0 ): @@ -80,7 +80,7 @@ def __init__( seed=0, fc1_unit=64, fc2_unit=64, - init_weight_gain=np.sqrt(2), + init_weight_gain=np.sqrt(2), # noqa: B008 init_bias=0 ): """ @@ -126,7 +126,7 @@ def __init__( seed=0, fc1_unit=64, fc2_unit=64, - init_weight_gain=np.sqrt(2), + init_weight_gain=np.sqrt(2), # noqa: B008 init_bias=0 ): """ @@ -327,7 +327,7 @@ def _train_critic(self, states, actions, rewards, next_states, terminate): min_target_q_value = torch.min(current_q, current_q_1) # Compute the target Value with - # V (sₜ₊₁) = E aₜ∼π [Q(sₜ₊₁, aₜ₊₁) − α log π(aₜ₊₁|sₜ₊₁)] + # V (sₜ₊₁) = E aₜ∼π [Q(sₜ₊₁, aₜ₊₁) − α log π(aₜ₊₁|sₜ₊₁)] # noqa: RUF003 target_v = min_target_q_value - self.log_alpha.exp().detach() * log_prob # Compute value loss @@ -339,7 +339,7 @@ def _train_critic(self, states, actions, rewards, next_states, terminate): self.value_optimizer.step() # Compute the target Q with - # JQ(θ)=E (sₜ₊₁, aₜ₊₁)∼D [ 1/2 (Q(st,at)− r(st,at)+γE sₜ₊₁∼p [V(st+1)])² ] + # JQ(θ)=E (sₜ₊₁, aₜ₊₁)∼D [ 1/2 
(Q(st,at)− r(st,at)+γE sₜ₊₁∼p [V(st+1)])² ] # noqa: RUF003 target_v = self.value_target.forward(next_states) target_q = rewards + ((1 - terminate) * self.gamma * target_v).detach() @@ -365,7 +365,7 @@ def _train_actor(self, states): min_target_q_value = torch.min(target_q, target_q_1) - # Jπ(φ)=E sₜ∼D [E aₜ∼π [αlog(π(aₜ|sₜ))−Qᶿ(sₜ, aₜ)]] + # Jπ(φ)=E sₜ∼D [E aₜ∼π [αlog(π(aₜ|sₜ))−Qᶿ(sₜ, aₜ)]] # noqa: RUF003 actor_loss = ( self.log_alpha.exp().detach() * log_prob - min_target_q_value ).mean() @@ -384,7 +384,6 @@ def _train_actor(self, states): self.alpha_optimizer.step() def _learn(self, experiences): - # pylint: disable=line-too-long """Update value parameters using given batch of experience tuples. Params ======= diff --git a/SACv2/__init__.py b/SACv2/__init__.py index 8875f60..b7b77a8 100644 --- a/SACv2/__init__.py +++ b/SACv2/__init__.py @@ -1 +1 @@ -# pylint: disable=all +# noqa: N999 diff --git a/SACv2/main.py b/SACv2/main.py index a3d2aad..cf57631 100644 --- a/SACv2/main.py +++ b/SACv2/main.py @@ -1,16 +1,19 @@ """main executable file for SAC""" -import os import logging +import os +from collections import deque from itertools import repeat + import gymnasium as gym -import torch import numpy as np -from util import generate_gif -from util.wrappers import TrainMonitor -from util.buffer import Experience -from collections import deque +import torch + # pylint: disable=invalid-name from SACv2.sac import SACv2Agent as SAC_torch +from util import generate_gif +from util.buffer import Experience +from util.wrappers import TrainMonitor + # from DQN.dqn_torch import DQNAgent as DQN_torch Agent = SAC_torch @@ -31,7 +34,6 @@ def main( score_term_rules=lambda s: False, time_interval="25ms" ): - # pylint: disable=line-too-long """Deep Q-Learning Params diff --git a/SACv2/sac.py b/SACv2/sac.py index adeb135..70cf76d 100644 --- a/SACv2/sac.py +++ b/SACv2/sac.py @@ -3,10 +3,10 @@ import torch from torch import nn from torch.nn import functional as F -from util.buffer import ReplayBuffer + from util.agent import Agent -from util.buffer import Experience -from util.dist import SquashedNormal, DiagonalGaussian +from util.buffer import Experience, ReplayBuffer +from util.dist import DiagonalGaussian, SquashedNormal class Actor(nn.Module): @@ -20,7 +20,7 @@ def __init__( fc1_unit=64, fc2_unit=64, max_action=1, - init_weight_gain=np.sqrt(2), + init_weight_gain=np.sqrt(2), # noqa: B008 init_policy_weight_gain=1, init_bias=0 ): @@ -81,7 +81,7 @@ def __init__( seed=0, fc1_unit=64, fc2_unit=64, - init_weight_gain=np.sqrt(2), + init_weight_gain=np.sqrt(2), # noqa: B008 init_bias=0 ): """ @@ -277,7 +277,7 @@ def _train_critic(self, states, actions, rewards, next_states, terminate): min_target_q_value = torch.min(target_q, target_q_1) # Compute the target Value with - # V (sₜ₊₁) = E aₜ∼π [Q(sₜ₊₁, aₜ₊₁) − α log π(aₜ₊₁|sₜ₊₁)] + # V (sₜ₊₁) = E aₜ∼π [Q(sₜ₊₁, aₜ₊₁) − α log π(aₜ₊₁|sₜ₊₁)] # noqa: RUF003 target_v = min_target_q_value - self.log_alpha.exp().detach() * log_prob target_q = rewards + ((1 - terminate) * self.gamma * target_v).detach() @@ -306,7 +306,7 @@ def _train_actor(self, states): min_target_q_value = torch.min(target_q, target_q_1) - # Jπ(φ)=E sₜ∼D [E aₜ∼π [αlog(π(aₜ|sₜ))−Qᶿ(sₜ, aₜ)]] + # Jπ(φ)=E sₜ∼D [E aₜ∼π [αlog(π(aₜ|sₜ))−Qᶿ(sₜ, aₜ)]] # noqa: RUF003 actor_loss = ( self.log_alpha.exp().detach() * log_prob - min_target_q_value ).mean() @@ -325,7 +325,6 @@ def _train_actor(self, states): self.alpha_optimizer.step() def _learn(self, experiences): - # pylint: disable=line-too-long """Update value 
parameters using given batch of experience tuples. Params ======= diff --git a/TD3/__init__.py b/TD3/__init__.py index 8875f60..b7b77a8 100644 --- a/TD3/__init__.py +++ b/TD3/__init__.py @@ -1 +1 @@ -# pylint: disable=all +# noqa: N999 diff --git a/TD3/main.py b/TD3/main.py index ce70615..df9f7d3 100644 --- a/TD3/main.py +++ b/TD3/main.py @@ -1,16 +1,18 @@ """main executable file for TD3""" -import os import logging +import os +from collections import deque from itertools import repeat + import gymnasium as gym -import torch import numpy as np -from util import generate_gif -from util.wrappers import TrainMonitor -from util.buffer import Experience -from collections import deque +import torch + # pylint: disable=invalid-name from TD3.td3 import TD3Agent as TD3_torch +from util import generate_gif +from util.buffer import Experience +from util.wrappers import TrainMonitor Agent = TD3_torch logging.basicConfig(level=logging.INFO) @@ -30,7 +32,6 @@ def main( score_term_rules=lambda s: False, time_interval="25ms" ): - # pylint: disable=line-too-long """Deep Q-Learning Params diff --git a/TD3/td3.py b/TD3/td3.py index b61b9a7..59d57e9 100644 --- a/TD3/td3.py +++ b/TD3/td3.py @@ -3,9 +3,9 @@ import torch from torch import nn from torch.nn import functional as F -from util.buffer import ReplayBuffer + from util.agent import Agent -from util.buffer import Experience +from util.buffer import Experience, ReplayBuffer from util.dist import OrnsteinUhlenbeckNoise @@ -20,7 +20,7 @@ def __init__( fc1_unit=64, fc2_unit=64, max_action=1, - init_weight_gain=np.sqrt(2), + init_weight_gain=np.sqrt(2), # noqa: B008 init_policy_weight_gain=1, init_bias=0 ): @@ -73,7 +73,7 @@ def __init__( seed=0, fc1_unit=64, fc2_unit=64, - init_weight_gain=np.sqrt(2), + init_weight_gain=np.sqrt(2), # noqa: B008 init_bias=0 ): """ @@ -249,7 +249,6 @@ def remember(self, scenario: Experience): self.memory.enqueue(scenario) def _learn(self, experiences): - # pylint: disable=line-too-long """Update value parameters using given batch of experience tuples. 
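The comments that gain `# noqa: RUF003` in the SACv1/SACv2 hunks above (and in XSAC below) spell out the soft value target V(s') = min(Q1(s', a'), Q2(s', a')) - α log π(a'|s') and the Bellman target r + (1 - done) * γ * V(s'); RUF003 objects only to the Unicode minus sign inside those comments, not to the math. A minimal sketch of that target computation, assuming two target critics and an actor that returns a sampled action with its log-probability (the helper names are illustrative):

```python
import torch


def soft_q_target(actor, critic1_t, critic2_t, log_alpha, rewards, next_states,
                  terminate, gamma):
    """Sketch of the SAC target described by the commented formulas in the diff.

    V(s') = min(Q1(s', a'), Q2(s', a')) - alpha * log pi(a'|s')
    y     = r + (1 - done) * gamma * V(s')
    """
    with torch.no_grad():
        next_actions, log_prob = actor(next_states)
        q1 = critic1_t(next_states, next_actions)
        q2 = critic2_t(next_states, next_actions)
        target_v = torch.min(q1, q2) - log_alpha.exp() * log_prob
        return rewards + (1 - terminate) * gamma * target_v
```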
Params ======= diff --git a/XAWR/__init__.py b/XAWR/__init__.py index 8875f60..b7b77a8 100644 --- a/XAWR/__init__.py +++ b/XAWR/__init__.py @@ -1 +1 @@ -# pylint: disable=all +# noqa: N999 diff --git a/XAWR/main.py b/XAWR/main.py index 921a7ce..cdc1d44 100644 --- a/XAWR/main.py +++ b/XAWR/main.py @@ -1,14 +1,17 @@ """main executable file for xawr""" -import os import logging +import os +from collections import deque from itertools import repeat + import gymnasium as gym -import torch import numpy as np +import torch + from util import generate_gif -from util.wrappers import TrainMonitor from util.buffer import Experience, Trajectory -from collections import deque +from util.wrappers import TrainMonitor + # pylint: disable=invalid-name from XAWR.xawr import XAWRAgent as XAWR_torch @@ -30,7 +33,6 @@ def main( score_term_rules=lambda s: False, time_interval="25ms" ): - # pylint: disable=line-too-long """Deep Q-Learning Params diff --git a/XAWR/xawr.py b/XAWR/xawr.py index 0dd165c..a0ff696 100644 --- a/XAWR/xawr.py +++ b/XAWR/xawr.py @@ -1,5 +1,6 @@ """AWR implementation with pytorch.""" from functools import partial + import numpy as np import torch from torch import nn @@ -7,9 +8,13 @@ from torch.nn import functional as F from util.agent import Agent +from util.algo import ( # pylint: disable=unused-import + gumbel_rescale_loss, + scale_down_values, + scale_up_values, + standardize, +) from util.buffer import ReplayBuffer, Trajectory -from util.algo import standardize, scale_down_values, scale_up_values -from util.algo import gumbel_rescale_loss, gumbel_loss # pylint: disable=unused-import class Actor(nn.Module): @@ -22,7 +27,7 @@ def __init__( seed=0, fc1_unit=256, fc2_unit=256, - init_weight_gain=np.sqrt(2), + init_weight_gain=np.sqrt(2), # noqa: B008 init_policy_weight_gain=0.01, init_bias=0 ): @@ -72,7 +77,7 @@ def __init__( seed=0, fc1_unit=256, fc2_unit=256, - init_weight_gain=np.sqrt(2), + init_weight_gain=np.sqrt(2), # noqa: B008 init_value_weight_gain=1, init_bias=0 ): diff --git a/XDDPG/__init__.py b/XDDPG/__init__.py index 8875f60..b7b77a8 100644 --- a/XDDPG/__init__.py +++ b/XDDPG/__init__.py @@ -1 +1 @@ -# pylint: disable=all +# noqa: N999 diff --git a/XDDPG/main.py b/XDDPG/main.py index ec8bd7f..2d2a627 100644 --- a/XDDPG/main.py +++ b/XDDPG/main.py @@ -1,16 +1,20 @@ """main executable file for XDDPG""" -import os import logging +import os +from collections import deque from itertools import repeat + import gymnasium as gym -import torch import numpy as np +import torch + from util import generate_gif -from util.wrappers import TrainMonitor from util.buffer import Experience -from collections import deque +from util.wrappers import TrainMonitor + # pylint: disable=invalid-name from XDDPG.xddpg import XDDPGAgent as XDDPG_torch + # from DQN.dqn_torch import DQNAgent as DQN_torch Agent = XDDPG_torch @@ -31,7 +35,6 @@ def main( score_term_rules=lambda s: False, time_interval="25ms" ): - # pylint: disable=line-too-long """Deep Q-Learning Params diff --git a/XDDPG/xddpg.py b/XDDPG/xddpg.py index 20b72fd..854725e 100644 --- a/XDDPG/xddpg.py +++ b/XDDPG/xddpg.py @@ -1,14 +1,17 @@ """DDPG implementation with pytorch.""" +from functools import partial + import numpy as np import torch from torch import nn from torch.nn import functional as F -from util.buffer import ReplayBuffer + from util.agent import Agent -from util.buffer import Experience -from functools import partial +from util.algo import ( # pylint: disable=unused-import + gumbel_rescale_loss, +) +from util.buffer import 
Experience, ReplayBuffer from util.dist import OrnsteinUhlenbeckNoise -from util.algo import gumbel_rescale_loss, gumbel_loss # pylint: disable=unused-import class Actor(nn.Module): @@ -22,7 +25,7 @@ def __init__( fc1_unit=64, fc2_unit=64, max_action=1, - init_weight_gain=np.sqrt(2), + init_weight_gain=np.sqrt(2), # noqa: B008 init_policy_weight_gain=1, init_bias=0 ): @@ -75,7 +78,7 @@ def __init__( seed=0, fc1_unit=64, fc2_unit=64, - init_weight_gain=np.sqrt(2), + init_weight_gain=np.sqrt(2), # noqa: B008 init_bias=0 ): """ @@ -246,7 +249,6 @@ def remember(self, scenario: Experience): self.memory.enqueue(scenario) def _learn(self, experiences): - # pylint: disable=line-too-long """Update value parameters using given batch of experience tuples. Params ======= diff --git a/XSAC/__init__.py b/XSAC/__init__.py index 8875f60..b7b77a8 100644 --- a/XSAC/__init__.py +++ b/XSAC/__init__.py @@ -1 +1 @@ -# pylint: disable=all +# noqa: N999 diff --git a/XSAC/main.py b/XSAC/main.py index 61ccb78..4a2b2de 100644 --- a/XSAC/main.py +++ b/XSAC/main.py @@ -1,16 +1,20 @@ """main executable file for XSAC""" -import os import logging +import os +from collections import deque from itertools import repeat + import gymnasium as gym -import torch import numpy as np +import torch + from util import generate_gif -from util.wrappers import TrainMonitor from util.buffer import Experience -from collections import deque +from util.wrappers import TrainMonitor + # pylint: disable=invalid-name from XSAC.xsac import XSACAgent as XSAC_torch + # from DQN.dqn_torch import DQNAgent as DQN_torch Agent = XSAC_torch @@ -31,7 +35,6 @@ def main( score_term_rules=lambda s: False, time_interval="25ms" ): - # pylint: disable=line-too-long """Deep Q-Learning Params diff --git a/XSAC/xsac.py b/XSAC/xsac.py index b5f47fc..69b8ff5 100644 --- a/XSAC/xsac.py +++ b/XSAC/xsac.py @@ -1,14 +1,17 @@ """SAC implementation with pytorch.""" +from functools import partial + import numpy as np import torch from torch import nn from torch.nn import functional as F -from util.buffer import ReplayBuffer + from util.agent import Agent -from util.buffer import Experience -from util.dist import SquashedNormal, DiagonalGaussian -from functools import partial -from util.algo import gumbel_rescale_loss, gumbel_loss # pylint: disable=unused-import +from util.algo import ( # pylint: disable=unused-import + gumbel_rescale_loss, +) +from util.buffer import Experience, ReplayBuffer +from util.dist import DiagonalGaussian, SquashedNormal class Actor(nn.Module): @@ -22,7 +25,7 @@ def __init__( fc1_unit=64, fc2_unit=64, max_action=1, - init_weight_gain=np.sqrt(2), + init_weight_gain=np.sqrt(2), # noqa: B008 init_policy_weight_gain=1, init_bias=0 ): @@ -82,7 +85,7 @@ def __init__( seed=0, fc1_unit=64, fc2_unit=64, - init_weight_gain=np.sqrt(2), + init_weight_gain=np.sqrt(2), # noqa: B008 init_bias=0 ): """ @@ -128,7 +131,7 @@ def __init__( seed=0, fc1_unit=64, fc2_unit=64, - init_weight_gain=np.sqrt(2), + init_weight_gain=np.sqrt(2), # noqa: B008 init_bias=0 ): """ @@ -336,7 +339,7 @@ def _train_critic(self, states, actions, rewards, next_states, terminate): min_target_q_value = torch.min(current_q, current_q_1) # Compute the target Value with - # V (sₜ₊₁) = E aₜ∼π [Q(sₜ₊₁, aₜ₊₁) − α log π(aₜ₊₁|sₜ₊₁)] + # V (sₜ₊₁) = E aₜ∼π [Q(sₜ₊₁, aₜ₊₁) − α log π(aₜ₊₁|sₜ₊₁)] # noqa: RUF003 target_v = min_target_q_value - self.log_alpha.exp().detach() * log_prob # Compute value loss @@ -348,7 +351,7 @@ def _train_critic(self, states, actions, rewards, next_states, terminate): 
self.value_optimizer.step() # Compute the target Q with - # JQ(θ)=E (sₜ₊₁, aₜ₊₁)∼D [ 1/2 (Q(st,at)− r(st,at)+γE sₜ₊₁∼p [V(st+1)])² ] + # JQ(θ)=E (sₜ₊₁, aₜ₊₁)∼D [ 1/2 (Q(st,at)− r(st,at)+γE sₜ₊₁∼p [V(st+1)])² ] # noqa: RUF003 target_v = self.value_target.forward(next_states) target_q = rewards + ((1 - terminate) * self.gamma * target_v).detach() @@ -374,7 +377,7 @@ def _train_actor(self, states): min_target_q_value = torch.min(target_q, target_q_1) - # Jπ(φ)=E sₜ∼D [E aₜ∼π [αlog(π(aₜ|sₜ))−Qᶿ(sₜ, aₜ)]] + # Jπ(φ)=E sₜ∼D [E aₜ∼π [αlog(π(aₜ|sₜ))−Qᶿ(sₜ, aₜ)]] # noqa: RUF003 actor_loss = ( self.log_alpha.exp().detach() * log_prob - min_target_q_value ).mean() @@ -393,7 +396,6 @@ def _train_actor(self, states): self.alpha_optimizer.step() def _learn(self, experiences): - # pylint: disable=line-too-long """Update value parameters using given batch of experience tuples. Params ======= diff --git a/XTD3/__init__.py b/XTD3/__init__.py index 8875f60..b7b77a8 100644 --- a/XTD3/__init__.py +++ b/XTD3/__init__.py @@ -1 +1 @@ -# pylint: disable=all +# noqa: N999 diff --git a/XTD3/main.py b/XTD3/main.py index 62e44f0..56ebc30 100644 --- a/XTD3/main.py +++ b/XTD3/main.py @@ -1,14 +1,17 @@ """main executable file for XTD3""" -import os import logging +import os +from collections import deque from itertools import repeat + import gymnasium as gym -import torch import numpy as np +import torch + from util import generate_gif -from util.wrappers import TrainMonitor from util.buffer import Experience -from collections import deque +from util.wrappers import TrainMonitor + # pylint: disable=invalid-name from XTD3.xtd3 import XTD3Agent as XTD3_torch @@ -30,7 +33,6 @@ def main( score_term_rules=lambda s: False, time_interval="25ms" ): - # pylint: disable=line-too-long """Deep Q-Learning Params diff --git a/XTD3/xtd3.py b/XTD3/xtd3.py index b3e0a37..681d23f 100644 --- a/XTD3/xtd3.py +++ b/XTD3/xtd3.py @@ -1,14 +1,17 @@ """XTD3 implementation with pytorch.""" +from functools import partial + import numpy as np import torch from torch import nn from torch.nn import functional as F -from util.buffer import ReplayBuffer + from util.agent import Agent -from util.buffer import Experience -from functools import partial +from util.algo import ( # pylint: disable=unused-import + gumbel_rescale_loss, +) +from util.buffer import Experience, ReplayBuffer from util.dist import OrnsteinUhlenbeckNoise -from util.algo import gumbel_rescale_loss, gumbel_loss # pylint: disable=unused-import class Actor(nn.Module): @@ -22,7 +25,7 @@ def __init__( fc1_unit=64, fc2_unit=64, max_action=1, - init_weight_gain=np.sqrt(2), + init_weight_gain=np.sqrt(2), # noqa: B008 init_policy_weight_gain=1, init_bias=0 ): @@ -75,7 +78,7 @@ def __init__( seed=0, fc1_unit=64, fc2_unit=64, - init_weight_gain=np.sqrt(2), + init_weight_gain=np.sqrt(2), # noqa: B008 init_bias=0 ): """ @@ -264,7 +267,6 @@ def remember(self, scenario: Experience): self.memory.enqueue(scenario) def _learn(self, experiences): - # pylint: disable=line-too-long """Update value parameters using given batch of experience tuples. 
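For context on the `alpha_optimizer.step()` call in XSAC/xsac.py above: with automatic temperature tuning, SAC trains the actor on J_pi(phi) = E[alpha * log pi(a|s) - Q(s, a)] (the comment the patch tags with RUF003) and the temperature on J(alpha) = E[-alpha * (log pi(a|s) + H_target)]. The sketch below restates both in plain ASCII; it is the generic SAC formulation, and the exact form used in this repository may differ.

import torch

def sac_actor_and_alpha_losses(log_alpha, log_prob, min_q, target_entropy):
  # J_pi(phi) = E[ alpha * log pi(a|s) - min_i Q_i(s, a) ]    (actor objective)
  # J(alpha)  = E[ -alpha * (log pi(a|s) + target_entropy) ]  (temperature objective)
  alpha = log_alpha.exp()
  actor_loss = (alpha.detach() * log_prob - min_q).mean()
  alpha_loss = -(alpha * (log_prob + target_entropy).detach()).mean()
  return actor_loss, alpha_loss
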
Params ======= diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 0000000..06312bc --- /dev/null +++ b/ruff.toml @@ -0,0 +1,34 @@ +target-version = "py312" +line-length = 108 +indent-width = 2 +respect-gitignore = true +exclude = ["third_party"] +fix = true +show-fixes = true + +[lint] +select = [ + "E", "F", "W", "I", "N", "UP", "B", "C4", "RUF" +] + +ignore = [ + "E731", "B905", "C901", "B006", + "N802", "N803", "N806", "N812", "N813", "N815", "N818" +] + +[lint.pep8-naming] +ignore-names = ["foo", "bar"] +extend-ignore-names = ["baz"] +classmethod-decorators = ["classmethod"] +staticmethod-decorators = ["staticmethod"] + +[format] +quote-style = "double" +line-ending = "lf" +indent-style = "space" +skip-magic-trailing-comma = true + +[lint.isort] +force-single-line = false +known-third-party = ["absl", "enchant"] +combine-as-imports = true diff --git a/tests/gym_test.py b/tests/gym_test.py index ade4e00..5a60479 100644 --- a/tests/gym_test.py +++ b/tests/gym_test.py @@ -1,4 +1,5 @@ import gymnasium as gym + from util import generate_gif from util.wrappers import TrainMonitor diff --git a/util/_misc.py b/util/_misc.py index b4642ea..f6a9971 100644 --- a/util/_misc.py +++ b/util/_misc.py @@ -1,9 +1,10 @@ """Some basic helper function""" -import os import logging +import os import numpy as np from PIL import Image + from .wrappers import TrainMonitor @@ -15,7 +16,6 @@ def generate_gif( duration=50, max_episode_steps=None ): - # pylint: disable=line-too-long r""" Store a gif from the episode frames. Parameters @@ -40,7 +40,7 @@ def generate_gif( """ logger = logging.getLogger('generate_gif') max_episode_steps = max_episode_steps \ - or getattr(getattr(env, 'spec'), 'max_episode_steps', 10000) + or getattr(env.spec, 'max_episode_steps', 10000) if isinstance(env, TrainMonitor): env = env.env # unwrap to strip off TrainMonitor diff --git a/util/agent.py b/util/agent.py index 605a021..4633297 100644 --- a/util/agent.py +++ b/util/agent.py @@ -1,7 +1,7 @@ """The basic frame for Agent""" -class Agent(): +class Agent: """The basic class for agent""" def __call__(self, *args, **kwds): diff --git a/util/algo.py b/util/algo.py index 7b4c21c..5e1bfd9 100644 --- a/util/algo.py +++ b/util/algo.py @@ -58,7 +58,7 @@ def calc_gaes( gae_lambda=0.95 ): # GAE = ∑ₗ (γλ)ˡδₜ₊ₗ - # δₜ₊ₗ = rₜ + γV(sₜ₊₁) − V(sₜ) + # δₜ₊ₗ = rₜ + γV(sₜ₊₁) − V(sₜ) # noqa: RUF003 T = len(rewards) # pylint: disable=invalid-name device = rewards.device gaes = torch.zeros_like(rewards, device=device) diff --git a/util/buffer.py b/util/buffer.py index f8eff47..8f926ed 100644 --- a/util/buffer.py +++ b/util/buffer.py @@ -1,25 +1,20 @@ """The buffer protocol described here""" + +import random +import warnings from collections import deque from copy import deepcopy -import warnings -import random + import numpy as np + from util.tree import SumTree class Experience: - # pylint: disable=line-too-long """Experience is a pickle of (state, action, reward, next_state, done, log_prob...)""" def __init__( - self, - state=None, - action=None, - reward=None, - next_state=None, - done=None, - log_prob=None, - **kwargs + self, state=None, action=None, reward=None, next_state=None, done=None, log_prob=None, **kwargs ) -> None: self.state = state self.action = action @@ -33,7 +28,6 @@ def __init__( class Trajectory: - # pylint: disable=line-too-long """ The Trajectory class is used to store the experiences of a whole trajectory policy. 
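Aside on the `# noqa: RUF003` added to the GAE comment in util/algo.py above: RUF003 flags Unicode characters in comments that look confusingly like ASCII ones (the minus sign and subscripts here). The recursion itself is unchanged by this patch; a plain-ASCII sketch follows for readers of `calc_gaes` (generic signature, not the repo's exact one).

import torch

def gae(rewards, values, next_values, dones, gamma=0.99, gae_lambda=0.95):
  # delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
  # A_t     = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}
  deltas = rewards + gamma * next_values * (1 - dones) - values
  advantages = torch.zeros_like(rewards)
  running = torch.zeros_like(rewards[-1])
  for t in reversed(range(len(rewards))):
    running = deltas[t] + gamma * gae_lambda * (1 - dones[t]) * running
    advantages[t] = running
  return advantages
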
@@ -68,13 +62,13 @@ def __iter__(self): return iter(self.q) def __repr__(self) -> str: - return f"{self.__class__.__name__}({", ".join(repr(e) for e in self.q)})" + return f"{self.__class__.__name__}({', '.join(repr(e) for e in self.q)})" def __str__(self) -> str: - return f"{self.__class__.__name__}({", ".join(str(e) for e in self.q)})" + return f"{self.__class__.__name__}({', '.join(str(e) for e in self.q)})" -class ReplayBuffer(): +class ReplayBuffer: """Replay Buffer for off-policy training""" def __init__(self, max_size=None) -> None: @@ -82,12 +76,7 @@ def __init__(self, max_size=None) -> None: self.q: deque = deque([], maxlen=max_size) def sample_from( - self, - sample_ratio=None, - num_samples=1, - drop_samples=False, - sample_distribution_fn=None, - replace=True, + self, sample_ratio=None, num_samples=1, drop_samples=False, sample_distribution_fn=None, replace=True ): """Sample a batch of experiences from the replay buffer""" if not self.q: @@ -100,10 +89,10 @@ def sample_from( return [] selected_sample_ids = np.random.choice( - range(len(self.q)), - size=(num_samples, ), - replace=replace, - p=sample_distribution_fn() if sample_distribution_fn else None + range(len(self.q)), + size=(num_samples,), + replace=replace, + p=sample_distribution_fn() if sample_distribution_fn else None, ) samples = [deepcopy(self.q[idx]) for idx in selected_sample_ids] @@ -118,7 +107,7 @@ def sample_from( def enqueue(self, sample): if not self.isfull(): return self.q.append(sample) - warnings.warn("the buffer is full, the first sample will be dropped.") + warnings.warn("the buffer is full, the first sample will be dropped.", stacklevel=1) self._dequeue() return self.q.append(sample) @@ -146,10 +135,10 @@ def delete_none(self): self.q = deque([sample for sample in self.q if sample is not None]) def __repr__(self) -> str: - return f"{self.__class__.__name__}({", ".join(repr(e) for e in self.q)})" + return f"{self.__class__.__name__}({', '.join(repr(e) for e in self.q)})" def __str__(self) -> str: - return f"{self.__class__.__name__}({", ".join(str(e) for e in self.q)})" + return f"{self.__class__.__name__}({', '.join(str(e) for e in self.q)})" def __contain__(self, e): return e in self.q @@ -161,7 +150,6 @@ def __iter__(self): return iter(self.q) -# pylint: disable=line-too-long # # The Code is taken from https://github.com/Howuhh/prioritized_experience_replay/blob/main/memory/buffer.py class ProportionalPrioritizedReplayBuffer: """Proportional Prioritized ReplayBuffer for off-policy training""" @@ -171,9 +159,14 @@ def __init__(self, max_size=None, eps=1e-2, alpha=0.1, beta=0.1): self.tree = SumTree(size=self.size) # PER params - self.eps = eps # minimal priority, prevents zero probabilities - self.alpha = alpha # determines how much prioritization is used, α = 0 corresponding to the uniform case - self.beta = beta # determines the amount of importance-sampling correction, b = 1 fully compensate for the non-uniform probabilities + # minimal priority, prevents zero probabilities + self.eps = eps + # determines how much prioritization is used + # α = 0 corresponding to the uniform case # noqa: RUF003 + self.alpha = alpha + # determines the amount of importance-sampling correction, + # β = 1 fully compensate for the non-uniform probabilities + self.beta = beta self.max_priority = eps # priority for new samples, init as eps self.count = 0 @@ -182,7 +175,6 @@ def __init__(self, max_size=None, eps=1e-2, alpha=0.1, beta=0.1): self.sample_indices = [] def enqueue(self, sample): - # store transition index with 
maximum priority in sum tree self.tree.add(self.max_priority, sample) @@ -205,7 +197,8 @@ def sample_from(self, sample_ratio=None, num_samples=1, **kwargs): # To sample a minibatch of size k, the range [0, p_total] is divided equally into k ranges. # Next, a value is uniformly sampled from each range. Finally the transitions that correspond - # to each of these sampled values are retrieved from the tree. (Appendix B.2.1, Proportional prioritization) + # to each of these sampled values are retrieved from the tree. + # (Appendix B.2.1, Proportional prioritization) segment = self.tree.total / num_samples for i in range(num_samples): a, b = segment * i, segment * (i + 1) @@ -219,24 +212,35 @@ def sample_from(self, sample_ratio=None, num_samples=1, **kwargs): indices.append(index) samples.append(sample_idx) - # Concretely, we define the probability of sampling transition i as P(i) = p_i^α / \sum_{k} p_k^α + # Concretely, we define the probability of sampling transition i as P(i) = p_i^α / \sum_{k} p_k^α # noqa: RUF003, E501 # where p_i > 0 is the priority of transition i. (Section 3.3) probs = priorities / self.tree.total - # The estimation of the expected value with stochastic updates relies on those updates corresponding - # to the same distribution as its expectation. Prioritized replay introduces bias because it changes this - # distribution in an uncontrolled fashion, and therefore changes the solution that the estimates will - # converge to (even if the policy and state distribution are fixed). We can correct this bias by using - # importance-sampling (IS) weights w_i = (1/N * 1/P(i))^β that fully compensates for the non-uniform - # probabilities P(i) if β = 1. These weights can be folded into the Q-learning update by using w_i * δ_i - # instead of δ_i (this is thus weighted IS, not ordinary IS, see e.g. Mahmood et al., 2014). - # For stability reasons, we always normalize weights by 1/maxi wi so that they only scale the + # The estimation of the expected value with stochastic updates + # relies on those updates corresponding + # to the same distribution as its expectation. + # Prioritized replay introduces bias because it changes this + # distribution in an uncontrolled fashion, and + # therefore changes the solution that the estimates will + # converge to (even if the policy and state distribution are fixed). + # We can correct this bias by using + # importance-sampling (IS) weights w_i = (1/N * 1/P(i))^β + # that fully compensates for the non-uniform + # probabilities P(i) if β = 1. These weights can be + # folded into the Q-learning update by using w_i * δ_i + # instead of δ_i (this is thus weighted IS, not ordinary IS, + # see e.g. Mahmood et al., 2014). + # For stability reasons, we always normalize weights + # by 1/max_i w_i so that they only scale the # update downwards (Section 3.4, first paragraph) weights = (self.real_size * probs) ** -self.beta - # As mentioned in Section 3.4, whenever importance sampling is used, all weights w_i were scaled - # so that max_i w_i = 1. We found that this worked better in practice as it kept all weights - # within a reasonable range, avoiding the possibility of extremely large updates. + # As mentioned in Section 3.4, whenever importance sampling is used, + # all weights w_i were scaled + # so that max_i w_i = 1. We found that this worked better + # in practice as it kept all weights + # within a reasonable range, avoiding the possibility + # of extremely large updates.
(Appendix B.2.1, Proportional prioritization) weights = weights / weights.max() self.sample_weights = weights diff --git a/util/dist.py b/util/dist.py index c7197d7..33e7f0e 100644 --- a/util/dist.py +++ b/util/dist.py @@ -21,21 +21,19 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # ============================================================================= -# pylint: disable=line-too-long,unused-argument # Borrow from `https://github.com/RLE-Foundation/rllte` """Distributions for action noise and policy.""" import math import re -from typing import Any, Tuple, Optional, Union +from typing import Any import numpy as np import torch as th -from torch.distributions import register_kl from torch import distributions as pyd -from torch.nn import functional as F -from torch.distributions import Distribution +from torch.distributions import Distribution, register_kl from torch.distributions.utils import _standard_normal +from torch.nn import functional as F def schedule(schdl: str, step: int) -> float: @@ -113,7 +111,7 @@ def logits(self) -> th.Tensor: """Returns the unnormalized log probabilities.""" return self.dist.logits - def sample(self, sample_shape: th.Size = th.Size()) -> th.Tensor: # B008 + def sample(self, sample_shape: th.Size = th.Size()) -> th.Tensor: # noqa: B008 """Generates a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched. @@ -179,7 +177,7 @@ def logits(self) -> th.Tensor: """Returns the unnormalized log probabilities.""" return self.dist.logits - def sample(self, sample_shape: th.Size = th.Size()) -> th.Tensor: # B008 + def sample(self, sample_shape: th.Size = th.Size()) -> th.Tensor: # noqa: B008 """Generates a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched. @@ -223,7 +221,7 @@ class MultiCategorical(BaseDistribution): def __init__(self) -> None: super().__init__() - def __call__(self, logits: Tuple[th.Tensor, ...]): + def __call__(self, logits: tuple[th.Tensor, ...]): """Create the distribution. Args: @@ -237,16 +235,16 @@ def __call__(self, logits: Tuple[th.Tensor, ...]): return self @property - def probs(self) -> Tuple[th.Tensor, ...]: + def probs(self) -> tuple[th.Tensor, ...]: """Return probabilities.""" return (dist.probs for dist in self.dist) # type: ignore @property - def logits(self) -> Tuple[th.Tensor, ...]: + def logits(self) -> tuple[th.Tensor, ...]: """Returns the unnormalized log probabilities.""" return (dist.logits for dist in self.dist) # type: ignore - def sample(self, sample_shape: th.Size = th.Size()) -> th.Tensor: # B008 + def sample(self, sample_shape: th.Size = th.Size()) -> th.Tensor: # noqa: B008 """Generates a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched. @@ -341,7 +339,7 @@ def __call__(self, mu: th.Tensor, sigma: th.Tensor): ) return self - def sample(self, sample_shape: th.Size = th.Size()) -> th.Tensor: # B008 + def sample(self, sample_shape: th.Size = th.Size()) -> th.Tensor: # noqa: B008 """Generates a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched. 
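Returning to the sampling math reflowed in util/buffer.py above: the stored priorities become probabilities P(i) = p_i^alpha / sum_k p_k^alpha and importance-sampling weights w_i = (N * P(i))^(-beta), normalized by their maximum. The NumPy sketch below shows just that arithmetic in isolation; the `p_i = (|delta_i| + eps)^alpha` step is an assumption about how priorities are assigned (suggested by the `eps`/`alpha` fields), since the buffer itself reads already-scaled priorities out of the SumTree.

import numpy as np

def per_probs_and_weights(td_errors, buffer_size, alpha=0.1, beta=0.1, eps=1e-2):
  # p_i = (|delta_i| + eps) ** alpha, then P(i) = p_i / sum_k p_k
  priorities = (np.abs(td_errors) + eps) ** alpha
  probs = priorities / priorities.sum()
  # w_i = (N * P(i)) ** -beta, normalized so that max_i w_i = 1
  weights = (buffer_size * probs) ** -beta
  return probs, weights / weights.max()
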
@@ -353,7 +351,7 @@ def sample(self, sample_shape: th.Size = th.Size()) -> th.Tensor: # B008 """ return self.dist.sample(sample_shape) - def rsample(self, sample_shape: th.Size = th.Size()) -> th.Tensor: # B008 + def rsample(self, sample_shape: th.Size = th.Size()) -> th.Tensor: # noqa: B008 """Generates a sample_shape shaped reparameterized sample or sample_shape shaped batch of reparameterized samples if the distribution parameters are batched. @@ -411,7 +409,7 @@ def __call__(self, mu: th.Tensor, sigma: th.Tensor): self.dist = pyd.Normal(loc=mu, scale=sigma) return self - def sample(self, sample_shape: th.Size = th.Size()) -> th.Tensor: # B008 + def sample(self, sample_shape: th.Size = th.Size()) -> th.Tensor: # noqa: B008 """Generates a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched. @@ -423,7 +421,7 @@ def sample(self, sample_shape: th.Size = th.Size()) -> th.Tensor: # B008 """ return self.dist.sample(sample_shape) - def rsample(self, sample_shape: th.Size = th.Size()) -> th.Tensor: # B008 + def rsample(self, sample_shape: th.Size = th.Size()) -> th.Tensor: # noqa: B008 """Generates a sample_shape shaped reparameterized sample or sample_shape shaped batch of reparameterized samples if the distribution parameters are batched. @@ -518,8 +516,8 @@ class NormalNoise(BaseDistribution): def __init__( self, - mu: Union[float, th.Tensor] = 0.0, - sigma: Union[float, th.Tensor] = 1.0, + mu: float | th.Tensor = 0.0, + sigma: float | th.Tensor = 1.0, low: float = -1.0, high: float = 1.0, eps: float = 1e-6, @@ -553,8 +551,8 @@ def _clamp(self, x: th.Tensor) -> th.Tensor: def sample( self, - clip: Optional[float] = None, - sample_shape: th.Size = th.Size() + clip: float | None = None, + sample_shape: th.Size = th.Size() # noqa: B008 ) -> th.Tensor: # type: ignore[override] """Generates a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched. @@ -611,8 +609,8 @@ class OrnsteinUhlenbeckNoise(BaseDistribution): def __init__( self, - mu: Union[float, th.Tensor] = 0.0, - sigma: Union[float, th.Tensor] = 1.0, + mu: float | th.Tensor = 0.0, + sigma: float | th.Tensor = 1.0, low: float = -1.0, high: float = 1.0, eps: float = 1e-6, @@ -629,7 +627,7 @@ def __init__( self.eps = eps self.theta = theta self.dt = dt - self.noise_prev: Union[None, th.Tensor] = None + self.noise_prev: None | th.Tensor = None if sigma_schedule and isinstance(sigma, float): self.sigma_schedule = sigma_schedule else: @@ -660,8 +658,8 @@ def _clamp(self, x: th.Tensor) -> th.Tensor: def sample( self, - clip: Optional[float] = None, - sample_shape: th.Size = th.Size() + clip: float | None = None, + sample_shape: th.Size = th.Size() # noqa: B008 ) -> th.Tensor: # type: ignore[override] """Generates a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched. 
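The `th.Size()` defaults in util/dist.py keep collecting `# noqa: B008` because they mirror the torch.distributions signatures. Where a pragma is unwanted, the usual alternative is a `None` sentinel resolved in the body, sketched below (illustrative, not part of the patch). Note that the `float | th.Tensor` unions and builtin generics introduced here need Python 3.10 or newer, which is consistent with `target-version = "py312"` in ruff.toml.

import torch as th

def sample(dist: th.distributions.Distribution, sample_shape: th.Size | None = None) -> th.Tensor:
  # Resolve the sentinel in the body instead of calling th.Size() in the
  # signature, which satisfies B008 without a noqa.
  return dist.sample(th.Size() if sample_shape is None else sample_shape)
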
@@ -731,8 +729,8 @@ class TruncatedNormalNoise(BaseDistribution): def __init__( self, - mu: Union[float, th.Tensor] = 0.0, - sigma: Union[float, th.Tensor] = 1.0, + mu: float | th.Tensor = 0.0, + sigma: float | th.Tensor = 1.0, low: float = -1.0, high: float = 1.0, eps: float = 1e-6, @@ -769,8 +767,8 @@ def _clamp(self, x: th.Tensor) -> th.Tensor: def sample( self, - clip: Optional[float] = None, - sample_shape: th.Size = th.Size() + clip: float | None = None, + sample_shape: th.Size = th.Size() # noqa: B008 ) -> th.Tensor: # type: ignore[override] """Generates a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched. diff --git a/util/tree.py b/util/tree.py index 9c64fb9..d2cd86f 100644 --- a/util/tree.py +++ b/util/tree.py @@ -1,7 +1,6 @@ """Sum Tree data structure for Prioritized Experience Replay.""" -# pylint: disable=line-too-long # Copy from https://github.com/Howuhh/prioritized_experience_replay/blob/main/memory/tree.py -# The ‘sum-tree’ data structure used here is very similar in spirit to the array representation +# The `sum-tree` data structure used here is very similar in spirit to the array representation # of a binary heap. However, instead of the usual heap property, the value of a parent node is # the sum of its children. Leaf nodes store the transition priorities and the internal nodes are # intermediate sums, with the parent node containing the sum over all priorities, p_total. This diff --git a/util/wrappers.py b/util/wrappers.py index 0524ae0..4e40ce5 100644 --- a/util/wrappers.py +++ b/util/wrappers.py @@ -1,21 +1,21 @@ """Some wrappers for gymnasium environment.""" + +import datetime import os import re -import datetime import time from collections import deque -from typing import Mapping +from collections.abc import Mapping import numpy as np from gymnasium import Wrapper from gymnasium.spaces import Discrete from tensorboardX import SummaryWriter -__all__ = ('TrainMonitor', ) +__all__ = ("TrainMonitor",) class StreamingSample: - # pylint: disable=line-too-long """Samples are being produced by a wrapped environment at some point in time.""" def __init__(self, maxlen, random_seed=None): @@ -50,9 +50,8 @@ def __bool__(self): return bool(self._deque) -#pylint: disable=invalid-name +# pylint: disable=invalid-name class TrainMonitor(Wrapper): - # pylint: disable=line-too-long r""" Environment wrapper for monitoring the training process. This wrapper logs some diagnostics at the end of each episode and it also gives us some handy @@ -100,20 +99,29 @@ class TrainMonitor(Wrapper): dt_ms : float The average wall time of a single step, in milliseconds. """ + _COUNTER_ATTRS = ( - 'T', 'ep', 't', 'G', 'avg_G', '_n_avg_G', '_ep_starttime', '_ep_metrics', - '_ep_actions', '_tensorboard_dir', '_period' + "T", + "ep", + "t", + "G", + "avg_G", + "_n_avg_G", + "_ep_starttime", + "_ep_metrics", + "_ep_actions", + "_tensorboard_dir", + "_period", ) def __init__( - self, - env, - tensorboard_dir=None, - tensorboard_write_all=False, - log_all_metrics=False, - smoothing=10 + self, + env, + tensorboard_dir=None, + tensorboard_write_all=False, + log_all_metrics=False, + smoothing=10, ): - super().__init__(env) self.log_all_metrics = log_all_metrics self.tensorboard_write_all = tensorboard_write_all @@ -122,7 +130,7 @@ def __init__( self._init_tensorboard(tensorboard_dir) def reset_global(self): - r""" Reset the global counters, not just the episodic ones. 
""" + r"""Reset the global counters, not just the episodic ones.""" self.T = 0 self.ep = 0 self.t = 0 @@ -132,7 +140,7 @@ def reset_global(self): self._ep_starttime = time.time() self._ep_metrics = {} self._ep_actions = StreamingSample(maxlen=1000) - self._period = {'T': {}, 'ep': {}} + self._period = {"T": {}, "ep": {}} def reset(self): # write logs from previous episode: @@ -170,77 +178,77 @@ def step(self, action): if info is None: info = {} - info['monitor'] = {'T': self.T, 'ep': self.ep} + info["monitor"] = {"T": self.T, "ep": self.ep} self.t += 1 self.T += 1 self.G += r if done: if self._n_avg_G < self.smoothing: - self._n_avg_G += 1. + self._n_avg_G += 1.0 self.avg_G += (self.G - self.avg_G) / self._n_avg_G return s_next, r, done, truncated, info def record_metrics(self, metrics): r""" - Record metrics during the training process. - These are used to print more diagnostics. - Parameters - ---------- - metrics : dict - A dict of metrics, of type ``{name : value }``. - """ + Record metrics during the training process. + These are used to print more diagnostics. + Parameters + ---------- + metrics : dict + A dict of metrics, of type ``{name : value }``. + """ if not isinstance(metrics, Mapping): - raise TypeError('metrics must be a Mapping') + raise TypeError("metrics must be a Mapping") # write metrics to tensoboard if self.tensorboard is not None and self.tensorboard_write_all: for name, metric in metrics.items(): self.tensorboard.add_scalar( - str(name), float(metric), global_step=self.T + str(name), float(metric), global_step=self.T ) # compute episode averages for k, v in metrics.items(): if k not in self._ep_metrics: - self._ep_metrics[k] = v, 1. + self._ep_metrics[k] = v, 1.0 else: x, n = self._ep_metrics[k] self._ep_metrics[k] = x + v, n + 1 def get_metrics(self): r""" - Return the current state of the metrics. - Returns - ------- - metrics : dict - A dict of metrics, of type ``{name : value }``. - """ + Return the current state of the metrics. + Returns + ------- + metrics : dict + A dict of metrics, of type ``{name : value }``. 
+ """ return {k: float(x) / n for k, (x, n) in self._ep_metrics.items()} def period(self, name, T_period=None, ep_period=None): if T_period is not None: T_period = int(T_period) assert T_period > 0 - if name not in self._period['T']: - self._period['T'][name] = 1 - if self.T >= self._period['T'][name] * T_period: - self._period['T'][name] += 1 + if name not in self._period["T"]: + self._period["T"][name] = 1 + if self.T >= self._period["T"][name] * T_period: + self._period["T"][name] += 1 return True or self.period(name, None, ep_period) return self.period(name, None, ep_period) if ep_period is not None: ep_period = int(ep_period) assert ep_period > 0 - if name not in self._period['ep']: - self._period['ep'][name] = 1 - if self.ep >= self._period['ep'][name] * ep_period: - self._period['ep'][name] += 1 + if name not in self._period["ep"]: + self._period["ep"][name] = 1 + if self.ep >= self._period["ep"][name] * ep_period: + self._period["ep"][name] += 1 return True return False @property def tensorboard(self): - if not hasattr(self, '_tensorboard'): + if not hasattr(self, "_tensorboard"): assert self._tensorboard_dir is not None self._tensorboard = SummaryWriter(self._tensorboard_dir) return self._tensorboard @@ -252,40 +260,39 @@ def _init_tensorboard(self, tensorboard_dir): return # append timestamp to disambiguate instances - if not re.match(r'.*/\d{8}_\d{6}$', tensorboard_dir): + if not re.match(r".*/\d{8}_\d{6}$", tensorboard_dir): tensorboard_dir = os.path.join( - tensorboard_dir, - datetime.datetime.now().strftime('%Y%m%d_%H%M%S') + tensorboard_dir, datetime.datetime.now().strftime("%Y%m%d_%H%M%S") ) # only set/update if necessary - if tensorboard_dir != getattr(self, '_tensorboard_dir', None): + if tensorboard_dir != getattr(self, "_tensorboard_dir", None): self._tensorboard_dir = tensorboard_dir - if hasattr(self, '_tensorboard'): + if hasattr(self, "_tensorboard"): del self._tensorboard if self.tensorboard is not None: metrics = { - 'episode/episode': self.ep, - 'episode/avg_reward': self.avg_r, - 'episode/return': self.G, - 'episode/steps': self.t, - 'episode/avg_step_duration_ms': self.dt_ms + "episode/episode": self.ep, + "episode/avg_reward": self.avg_r, + "episode/return": self.G, + "episode/steps": self.t, + "episode/avg_step_duration_ms": self.dt_ms, } for name, metric in metrics.items(): self.tensorboard.add_scalar( - str(name), float(metric), global_step=self.T + str(name), float(metric), global_step=self.T ) if self._ep_actions: if isinstance(self.action_space, Discrete): bins = np.arange(self.action_space.n + 1) else: - bins = 'auto' # see also: np.histogram_bin_edges.__doc__ + bins = "auto" # see also: np.histogram_bin_edges.__doc__ self.tensorboard.add_histogram( - tag='actions', - values=self._ep_actions.values, - global_step=self.T, - bins=bins + tag="actions", + values=self._ep_actions.values, + global_step=self.T, + bins=bins, ) if self._ep_metrics and not self.tensorboard_write_all: for k, (x, n) in self._ep_metrics.items(): @@ -294,36 +301,39 @@ def _init_tensorboard(self, tensorboard_dir): def _write_episode_logs(self): metrics = ( - f'{k:s}: {float(x) / n:.3g}' for k, (x, n) in self._ep_metrics.items() - if ( - self.log_all_metrics or str(k).endswith('/loss') - or str(k).endswith('/entropy') or str(k).endswith('/kl_div') - or str(k).startswith('throughput/') - ) + f"{k:s}: {float(x) / n:.3g}" + for k, (x, n) in self._ep_metrics.items() + if ( + self.log_all_metrics + or str(k).endswith("/loss") + or str(k).endswith("/entropy") + or str(k).endswith("/kl_div") 
+ or str(k).startswith("throughput/") + ) ) if self.tensorboard is not None: metrics = { - 'episode/episode': self.ep, - 'episode/avg_reward': self.avg_r, - 'episode/return': self.G, - 'episode/steps': self.t, - 'episode/avg_step_duration_ms': self.dt_ms + "episode/episode": self.ep, + "episode/avg_reward": self.avg_r, + "episode/return": self.G, + "episode/steps": self.t, + "episode/avg_step_duration_ms": self.dt_ms, } for name, metric in metrics.items(): self.tensorboard.add_scalar( - str(name), float(metric), global_step=self.T + str(name), float(metric), global_step=self.T ) if self._ep_actions: if isinstance(self.action_space, Discrete): bins = np.arange(self.action_space.n + 1) else: - bins = 'auto' # see also: np.histogram_bin_edges.__doc__ + bins = "auto" # see also: np.histogram_bin_edges.__doc__ self.tensorboard.add_histogram( - tag='actions', - values=self._ep_actions.values, - global_step=self.T, - bins=bins + tag="actions", + values=self._ep_actions.values, + global_step=self.T, + bins=bins, ) if self._ep_metrics and not self.tensorboard_write_all: for k, (x, n) in self._ep_metrics.items(): @@ -332,35 +342,34 @@ def _write_episode_logs(self): def __getstate__(self): state = self.__dict__.copy() # shallow copy - if '_tensorboard' in state: - del state['_tensorboard'] # remove reference to non-pickleable attr + if "_tensorboard" in state: + del state["_tensorboard"] # remove reference to non-pickleable attr return state def __setstate__(self, state): self.__dict__.update(state) - self._init_tensorboard(state['_tensorboard_dir']) + self._init_tensorboard(state["_tensorboard_dir"]) def get_counters(self): r""" - Get the current state of all internal counters. - Returns - ------- - counter : dict - The dict that contains the counters. - """ + Get the current state of all internal counters. + Returns + ------- + counter : dict + The dict that contains the counters. + """ return {k: getattr(self, k) for k in self._COUNTER_ATTRS} def set_counters(self, counters): r""" - Restore the state of all internal counters. - Parameters - ---------- - counter : dict - The dict that contains the counters. - """ + Restore the state of all internal counters. + Parameters + ---------- + counter : dict + The dict that contains the counters. + """ if not ( - isinstance(counters, dict) - and set(counters) == set(self._COUNTER_ATTRS) + isinstance(counters, dict) and set(counters) == set(self._COUNTER_ATTRS) ): - raise TypeError(f'invalid counters dict: {counters}') + raise TypeError(f"invalid counters dict: {counters}") self.__setstate__(counters)
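Finally, the `__getstate__`/`__setstate__` pair kept (and re-quoted) above is what lets TrainMonitor survive pickling despite holding a tensorboardX SummaryWriter: the writer is dropped from the pickled state and lazily rebuilt from `_tensorboard_dir` on load. A minimal standalone illustration of the same pattern (hypothetical class, not from util/wrappers.py):

import pickle

class WriterHolder:
  """Drops a non-picklable handle on dump and recreates it lazily on load."""

  def __init__(self, logdir):
    self._tensorboard_dir = logdir
    self._writer = lambda tag, value: None  # stand-in for the non-picklable SummaryWriter

  def __getstate__(self):
    state = self.__dict__.copy()
    state.pop("_writer", None)  # same idea as deleting _tensorboard above
    return state

  def __setstate__(self, state):
    self.__dict__.update(state)
    self._writer = lambda tag, value: None  # the real wrapper rebuilds it from _tensorboard_dir

clone = pickle.loads(pickle.dumps(WriterHolder("./runs/exp1")))
assert clone._tensorboard_dir == "./runs/exp1"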