diff --git a/.github/workflows/ci-lint-cvat-recording-oracle.yaml b/.github/workflows/ci-lint-cvat-recording-oracle.yaml new file mode 100644 index 0000000000..8c91ac49ec --- /dev/null +++ b/.github/workflows/ci-lint-cvat-recording-oracle.yaml @@ -0,0 +1,33 @@ +name: CVAT Recording Oracle Lint + +on: + push: + paths: + - 'packages/examples/cvat/recording-oracle/**' + - '.github/workflows/ci-lint-cvat-recording-oracle.yaml' + +env: + WORKING_DIR: ./packages/examples/cvat/recording-oracle + +defaults: + run: + working-directory: ./packages/examples/cvat/recording-oracle + +jobs: + cvat-exo-lint: + name: CVAT Recording Oracle Lint + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + cache: 'pip' + cache-dependency-path: ${{ env.WORKING_DIR }}/poetry.lock + - run: python -m pip install poetry + - uses: actions/setup-python@v5 + with: + python-version: '3.10' + cache: 'poetry' + cache-dependency-path: ${{ env.WORKING_DIR }}/poetry.lock + - run: poetry install --no-root --only lint + - run: poetry run pre-commit run --all-files diff --git a/packages/examples/cvat/recording-oracle/.pre-commit-config.yaml b/packages/examples/cvat/recording-oracle/.pre-commit-config.yaml index 211285f89c..395275745a 100644 --- a/packages/examples/cvat/recording-oracle/.pre-commit-config.yaml +++ b/packages/examples/cvat/recording-oracle/.pre-commit-config.yaml @@ -1,11 +1,17 @@ repos: - - repo: https://github.com/ambv/black - rev: 22.6.0 + - repo: local hooks: - - id: black - language_version: python3.10 - - repo: https://github.com/PyCQA/isort - rev: 5.12.0 - hooks: - - id: isort - language_version: python3.11 \ No newline at end of file + - id: lint + name: lint + entry: ruff check --fix --unsafe-fixes --show-fixes + language: system + require_serial: true + files: "^packages/examples/cvat/recording-oracle/.*" + types: [python] + - id: format + name: format + entry: ruff format + require_serial: true + language: system + files: "^packages/examples/cvat/recording-oracle/.*" + types: [python] diff --git a/packages/examples/cvat/recording-oracle/alembic.ini b/packages/examples/cvat/recording-oracle/alembic.ini index d9f94e0c5a..e5d0ea3454 100644 --- a/packages/examples/cvat/recording-oracle/alembic.ini +++ b/packages/examples/cvat/recording-oracle/alembic.ini @@ -68,11 +68,15 @@ sqlalchemy.url = driver://user:pass@localhost/dbname # on newly generated revision scripts. See the documentation for further # detail and examples -# format using "black" - use the console_scripts runner, against the "black" entrypoint -# hooks = black -# black.type = console_scripts -# black.entrypoint = black -# black.options = -l 79 REVISION_SCRIPT_FILENAME +hooks=ruff, ruff_format, types_update + +ruff.type = exec +ruff.executable = ruff +ruff.options = check --fix --unsafe-fixes REVISION_SCRIPT_FILENAME + +ruff_format.type = exec +ruff_format.executable = ruff +ruff_format.options = format REVISION_SCRIPT_FILENAME # Logging configuration [loggers] diff --git a/packages/examples/cvat/recording-oracle/alembic/env.py b/packages/examples/cvat/recording-oracle/alembic/env.py index e6c51d07f3..1e85b134f0 100644 --- a/packages/examples/cvat/recording-oracle/alembic/env.py +++ b/packages/examples/cvat/recording-oracle/alembic/env.py @@ -16,13 +16,12 @@ if config.config_file_name is not None: fileConfig(config.config_file_name) -from src.db import Base +from src.db import Base # noqa: E402 # add your model's MetaData object here # for 'autogenerate' support # from myapp import mymodel # target_metadata = mymodel.Base.metadata -from src.models.webhook import Webhook target_metadata = Base.metadata diff --git a/packages/examples/cvat/recording-oracle/alembic/versions/00271dfae3b1_add_task_iterations.py b/packages/examples/cvat/recording-oracle/alembic/versions/00271dfae3b1_add_task_iterations.py index d16700325f..971891752b 100644 --- a/packages/examples/cvat/recording-oracle/alembic/versions/00271dfae3b1_add_task_iterations.py +++ b/packages/examples/cvat/recording-oracle/alembic/versions/00271dfae3b1_add_task_iterations.py @@ -5,24 +5,25 @@ Create Date: 2024-05-08 18:48:53.897599 """ -from alembic import op + import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. -revision = '00271dfae3b1' -down_revision = 'a0c5c3a4c13f' +revision = "00271dfae3b1" +down_revision = "a0c5c3a4c13f" branch_labels = None depends_on = None def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.add_column('tasks', sa.Column('iteration', sa.Integer(), server_default='0', nullable=False)) + op.add_column("tasks", sa.Column("iteration", sa.Integer(), server_default="0", nullable=False)) # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.drop_column('tasks', 'iteration') + op.drop_column("tasks", "iteration") # ### end Alembic commands ### diff --git a/packages/examples/cvat/recording-oracle/alembic/versions/a0c5c3a4c13f_add_gt_stats.py b/packages/examples/cvat/recording-oracle/alembic/versions/a0c5c3a4c13f_add_gt_stats.py index eb6471554d..be55735af5 100644 --- a/packages/examples/cvat/recording-oracle/alembic/versions/a0c5c3a4c13f_add_gt_stats.py +++ b/packages/examples/cvat/recording-oracle/alembic/versions/a0c5c3a4c13f_add_gt_stats.py @@ -5,32 +5,34 @@ Create Date: 2024-03-08 11:34:02.458845 """ -from alembic import op + import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. -revision = 'a0c5c3a4c13f' -down_revision = 'ca93dce1a618' +revision = "a0c5c3a4c13f" +down_revision = "ca93dce1a618" branch_labels = None depends_on = None def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.create_table('gt_stats', - sa.Column('task_id', sa.String(), nullable=False), - sa.Column('gt_key', sa.String(), nullable=False), - sa.Column('failed_attempts', sa.Integer(), nullable=False), - sa.ForeignKeyConstraint(['task_id'], ['tasks.id'], ondelete='CASCADE'), - sa.PrimaryKeyConstraint('task_id', 'gt_key') + op.create_table( + "gt_stats", + sa.Column("task_id", sa.String(), nullable=False), + sa.Column("gt_key", sa.String(), nullable=False), + sa.Column("failed_attempts", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(["task_id"], ["tasks.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("task_id", "gt_key"), ) - op.create_index(op.f('ix_gt_stats_gt_key'), 'gt_stats', ['gt_key'], unique=False) + op.create_index(op.f("ix_gt_stats_gt_key"), "gt_stats", ["gt_key"], unique=False) # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.drop_index(op.f('ix_gt_stats_gt_key'), table_name='gt_stats') - op.drop_table('gt_stats') + op.drop_index(op.f("ix_gt_stats_gt_key"), table_name="gt_stats") + op.drop_table("gt_stats") # ### end Alembic commands ### diff --git a/packages/examples/cvat/recording-oracle/alembic/versions/ca93dce1a618_init.py b/packages/examples/cvat/recording-oracle/alembic/versions/ca93dce1a618_init.py index b737376fe0..76f102d52c 100644 --- a/packages/examples/cvat/recording-oracle/alembic/versions/ca93dce1a618_init.py +++ b/packages/examples/cvat/recording-oracle/alembic/versions/ca93dce1a618_init.py @@ -1,10 +1,11 @@ """init Revision ID: ca93dce1a618 -Revises: +Revises: Create Date: 2023-09-05 15:02:51.779529 """ + import sqlalchemy as sa from alembic import op diff --git a/packages/examples/cvat/recording-oracle/poetry.lock b/packages/examples/cvat/recording-oracle/poetry.lock index e8fb4542ef..8d5acb95b5 100644 --- a/packages/examples/cvat/recording-oracle/poetry.lock +++ b/packages/examples/cvat/recording-oracle/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "aiohttp" @@ -397,52 +397,6 @@ files = [ {file = "bitarray-2.9.2.tar.gz", hash = "sha256:a8f286a51a32323715d77755ed959f94bef13972e9a2fe71b609e40e6d27957e"}, ] -[[package]] -name = "black" -version = "23.12.1" -description = "The uncompromising code formatter." -optional = false -python-versions = ">=3.8" -files = [ - {file = "black-23.12.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e0aaf6041986767a5e0ce663c7a2f0e9eaf21e6ff87a5f95cbf3675bfd4c41d2"}, - {file = "black-23.12.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c88b3711d12905b74206227109272673edce0cb29f27e1385f33b0163c414bba"}, - {file = "black-23.12.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a920b569dc6b3472513ba6ddea21f440d4b4c699494d2e972a1753cdc25df7b0"}, - {file = "black-23.12.1-cp310-cp310-win_amd64.whl", hash = "sha256:3fa4be75ef2a6b96ea8d92b1587dd8cb3a35c7e3d51f0738ced0781c3aa3a5a3"}, - {file = "black-23.12.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:8d4df77958a622f9b5a4c96edb4b8c0034f8434032ab11077ec6c56ae9f384ba"}, - {file = "black-23.12.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:602cfb1196dc692424c70b6507593a2b29aac0547c1be9a1d1365f0d964c353b"}, - {file = "black-23.12.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c4352800f14be5b4864016882cdba10755bd50805c95f728011bcb47a4afd59"}, - {file = "black-23.12.1-cp311-cp311-win_amd64.whl", hash = "sha256:0808494f2b2df923ffc5723ed3c7b096bd76341f6213989759287611e9837d50"}, - {file = "black-23.12.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:25e57fd232a6d6ff3f4478a6fd0580838e47c93c83eaf1ccc92d4faf27112c4e"}, - {file = "black-23.12.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2d9e13db441c509a3763a7a3d9a49ccc1b4e974a47be4e08ade2a228876500ec"}, - {file = "black-23.12.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d1bd9c210f8b109b1762ec9fd36592fdd528485aadb3f5849b2740ef17e674e"}, - {file = "black-23.12.1-cp312-cp312-win_amd64.whl", hash = "sha256:ae76c22bde5cbb6bfd211ec343ded2163bba7883c7bc77f6b756a1049436fbb9"}, - {file = "black-23.12.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1fa88a0f74e50e4487477bc0bb900c6781dbddfdfa32691e780bf854c3b4a47f"}, - {file = "black-23.12.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a4d6a9668e45ad99d2f8ec70d5c8c04ef4f32f648ef39048d010b0689832ec6d"}, - {file = "black-23.12.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b18fb2ae6c4bb63eebe5be6bd869ba2f14fd0259bda7d18a46b764d8fb86298a"}, - {file = "black-23.12.1-cp38-cp38-win_amd64.whl", hash = "sha256:c04b6d9d20e9c13f43eee8ea87d44156b8505ca8a3c878773f68b4e4812a421e"}, - {file = "black-23.12.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3e1b38b3135fd4c025c28c55ddfc236b05af657828a8a6abe5deec419a0b7055"}, - {file = "black-23.12.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4f0031eaa7b921db76decd73636ef3a12c942ed367d8c3841a0739412b260a54"}, - {file = "black-23.12.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:97e56155c6b737854e60a9ab1c598ff2533d57e7506d97af5481141671abf3ea"}, - {file = "black-23.12.1-cp39-cp39-win_amd64.whl", hash = "sha256:dd15245c8b68fe2b6bd0f32c1556509d11bb33aec9b5d0866dd8e2ed3dba09c2"}, - {file = "black-23.12.1-py3-none-any.whl", hash = "sha256:78baad24af0f033958cad29731e27363183e140962595def56423e626f4bee3e"}, - {file = "black-23.12.1.tar.gz", hash = "sha256:4ce3ef14ebe8d9509188014d96af1c456a910d5b5cbf434a09fef7e024b3d0d5"}, -] - -[package.dependencies] -click = ">=8.0.0" -mypy-extensions = ">=0.4.3" -packaging = ">=22.0" -pathspec = ">=0.9.0" -platformdirs = ">=2" -tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} -typing-extensions = {version = ">=4.0.1", markers = "python_version < \"3.11\""} - -[package.extras] -colorama = ["colorama (>=0.4.3)"] -d = ["aiohttp (>=3.7.4)", "aiohttp (>=3.7.4,!=3.9.0)"] -jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] -uvloop = ["uvloop (>=0.15.2)"] - [[package]] name = "boto3" version = "1.34.30" @@ -1870,20 +1824,6 @@ files = [ {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, ] -[[package]] -name = "isort" -version = "5.13.2" -description = "A Python utility / library to sort Python imports." -optional = false -python-versions = ">=3.8.0" -files = [ - {file = "isort-5.13.2-py3-none-any.whl", hash = "sha256:8ca5e72a8d85860d5a3fa69b8745237f2939afe12dbf656afbcb47fe72d947a6"}, - {file = "isort-5.13.2.tar.gz", hash = "sha256:48fdfcb9face5d58a4f6dde2e72a1fb8dcaf8ab26f95ab49fab84c2ddefb0109"}, -] - -[package.extras] -colors = ["colorama (>=0.4.6)"] - [[package]] name = "jmespath" version = "1.0.1" @@ -2466,17 +2406,6 @@ files = [ {file = "multidict-6.0.4.tar.gz", hash = "sha256:3666906492efb76453c0e7b97f2cf459b0682e7402c0489a95484965dbc1da49"}, ] -[[package]] -name = "mypy-extensions" -version = "1.0.0" -description = "Type system extensions for programs checked with the mypy type checker." -optional = false -python-versions = ">=3.5" -files = [ - {file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"}, - {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, -] - [[package]] name = "networkx" version = "3.2.1" @@ -2762,17 +2691,6 @@ files = [ [package.dependencies] regex = ">=2022.3.15" -[[package]] -name = "pathspec" -version = "0.12.1" -description = "Utility library for gitignore style pattern matching of file paths." -optional = false -python-versions = ">=3.8" -files = [ - {file = "pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08"}, - {file = "pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712"}, -] - [[package]] name = "pgpy" version = "0.6.0" @@ -2951,6 +2869,8 @@ files = [ {file = "psycopg2-2.9.9-cp310-cp310-win_amd64.whl", hash = "sha256:426f9f29bde126913a20a96ff8ce7d73fd8a216cfb323b1f04da402d452853c3"}, {file = "psycopg2-2.9.9-cp311-cp311-win32.whl", hash = "sha256:ade01303ccf7ae12c356a5e10911c9e1c51136003a9a1d92f7aa9d010fb98372"}, {file = "psycopg2-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:121081ea2e76729acfb0673ff33755e8703d45e926e416cb59bae3a86c6a4981"}, + {file = "psycopg2-2.9.9-cp312-cp312-win32.whl", hash = "sha256:d735786acc7dd25815e89cc4ad529a43af779db2e25aa7c626de864127e5a024"}, + {file = "psycopg2-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:a7653d00b732afb6fc597e29c50ad28087dcb4fbfb28e86092277a559ae4e693"}, {file = "psycopg2-2.9.9-cp37-cp37m-win32.whl", hash = "sha256:5e0d98cade4f0e0304d7d6f25bbfbc5bd186e07b38eac65379309c4ca3193efa"}, {file = "psycopg2-2.9.9-cp37-cp37m-win_amd64.whl", hash = "sha256:7e2dacf8b009a1c1e843b5213a87f7c544b2b042476ed7755be813eaf4e8347a"}, {file = "psycopg2-2.9.9-cp38-cp38-win32.whl", hash = "sha256:ff432630e510709564c01dafdbe996cb552e0b9f3f065eb89bdce5bd31fabf4c"}, @@ -3597,35 +3517,82 @@ files = [ {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:b42169467c42b692c19cf539c38d4602069d8c1505e97b86387fcf7afb766e1d"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-macosx_13_0_arm64.whl", hash = "sha256:07238db9cbdf8fc1e9de2489a4f68474e70dffcb32232db7c08fa61ca0c7c462"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:fff3573c2db359f091e1589c3d7c5fc2f86f5bdb6f24252c2d8e539d4e45f412"}, + {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-manylinux_2_24_aarch64.whl", hash = "sha256:aa2267c6a303eb483de8d02db2871afb5c5fc15618d894300b88958f729ad74f"}, + {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:840f0c7f194986a63d2c2465ca63af8ccbbc90ab1c6001b1978f05119b5e7334"}, + {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:024cfe1fc7c7f4e1aff4a81e718109e13409767e4f871443cbff3dba3578203d"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-win32.whl", hash = "sha256:c69212f63169ec1cfc9bb44723bf2917cbbd8f6191a00ef3410f5a7fe300722d"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-win_amd64.whl", hash = "sha256:cabddb8d8ead485e255fe80429f833172b4cadf99274db39abc080e068cbcc31"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:bef08cd86169d9eafb3ccb0a39edb11d8e25f3dae2b28f5c52fd997521133069"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:b16420e621d26fdfa949a8b4b47ade8810c56002f5389970db4ddda51dbff248"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:25c515e350e5b739842fc3228d662413ef28f295791af5e5110b543cf0b57d9b"}, + {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-manylinux_2_24_aarch64.whl", hash = "sha256:1707814f0d9791df063f8c19bb51b0d1278b8e9a2353abbb676c2f685dee6afe"}, + {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:46d378daaac94f454b3a0e3d8d78cafd78a026b1d71443f4966c696b48a6d899"}, + {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:09b055c05697b38ecacb7ac50bdab2240bfca1a0c4872b0fd309bb07dc9aa3a9"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-win32.whl", hash = "sha256:53a300ed9cea38cf5a2a9b069058137c2ca1ce658a874b79baceb8f892f915a7"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-win_amd64.whl", hash = "sha256:c2a72e9109ea74e511e29032f3b670835f8a59bbdc9ce692c5b4ed91ccf1eedb"}, {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:ebc06178e8821efc9692ea7544aa5644217358490145629914d8020042c24aa1"}, {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:edaef1c1200c4b4cb914583150dcaa3bc30e592e907c01117c08b13a07255ec2"}, {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d176b57452ab5b7028ac47e7b3cf644bcfdc8cacfecf7e71759f7f51a59e5c92"}, + {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-manylinux_2_24_aarch64.whl", hash = "sha256:1dc67314e7e1086c9fdf2680b7b6c2be1c0d8e3a8279f2e993ca2a7545fecf62"}, + {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:3213ece08ea033eb159ac52ae052a4899b56ecc124bb80020d9bbceeb50258e9"}, + {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:aab7fd643f71d7946f2ee58cc88c9b7bfc97debd71dcc93e03e2d174628e7e2d"}, + {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-win32.whl", hash = "sha256:5c365d91c88390c8d0a8545df0b5857172824b1c604e867161e6b3d59a827eaa"}, + {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-win_amd64.whl", hash = "sha256:1758ce7d8e1a29d23de54a16ae867abd370f01b5a69e1a3ba75223eaa3ca1a1b"}, {file = "ruamel.yaml.clib-0.2.8-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:a5aa27bad2bb83670b71683aae140a1f52b0857a2deff56ad3f6c13a017a26ed"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:c58ecd827313af6864893e7af0a3bb85fd529f862b6adbefe14643947cfe2942"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-macosx_12_0_arm64.whl", hash = "sha256:f481f16baec5290e45aebdc2a5168ebc6d35189ae6fea7a58787613a25f6e875"}, + {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-manylinux_2_24_aarch64.whl", hash = "sha256:77159f5d5b5c14f7c34073862a6b7d34944075d9f93e681638f6d753606c6ce6"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:7f67a1ee819dc4562d444bbafb135832b0b909f81cc90f7aa00260968c9ca1b3"}, + {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:4ecbf9c3e19f9562c7fdd462e8d18dd902a47ca046a2e64dba80699f0b6c09b7"}, + {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:87ea5ff66d8064301a154b3933ae406b0863402a799b16e4a1d24d9fbbcbe0d3"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-win32.whl", hash = "sha256:75e1ed13e1f9de23c5607fe6bd1aeaae21e523b32d83bb33918245361e9cc51b"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-win_amd64.whl", hash = "sha256:3f215c5daf6a9d7bbed4a0a4f760f3113b10e82ff4c5c44bec20a68c8014f675"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1b617618914cb00bf5c34d4357c37aa15183fa229b24767259657746c9077615"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:a6a9ffd280b71ad062eae53ac1659ad86a17f59a0fdc7699fd9be40525153337"}, + {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-manylinux_2_24_aarch64.whl", hash = "sha256:305889baa4043a09e5b76f8e2a51d4ffba44259f6b4c72dec8ca56207d9c6fe1"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:700e4ebb569e59e16a976857c8798aee258dceac7c7d6b50cab63e080058df91"}, + {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:e2b4c44b60eadec492926a7270abb100ef9f72798e18743939bdbf037aab8c28"}, + {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:e79e5db08739731b0ce4850bed599235d601701d5694c36570a99a0c5ca41a9d"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-win32.whl", hash = "sha256:955eae71ac26c1ab35924203fda6220f84dce57d6d7884f189743e2abe3a9fbe"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-win_amd64.whl", hash = "sha256:56f4252222c067b4ce51ae12cbac231bce32aee1d33fbfc9d17e5b8d6966c312"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:03d1162b6d1df1caa3a4bd27aa51ce17c9afc2046c31b0ad60a0a96ec22f8001"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:bba64af9fa9cebe325a62fa398760f5c7206b215201b0ec825005f1b18b9bccf"}, + {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-manylinux_2_24_aarch64.whl", hash = "sha256:a1a45e0bb052edf6a1d3a93baef85319733a888363938e1fc9924cb00c8df24c"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:da09ad1c359a728e112d60116f626cc9f29730ff3e0e7db72b9a2dbc2e4beed5"}, + {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:184565012b60405d93838167f425713180b949e9d8dd0bbc7b49f074407c5a8b"}, + {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a75879bacf2c987c003368cf14bed0ffe99e8e85acfa6c0bfffc21a090f16880"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-win32.whl", hash = "sha256:84b554931e932c46f94ab306913ad7e11bba988104c5cff26d90d03f68258cd5"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-win_amd64.whl", hash = "sha256:25ac8c08322002b06fa1d49d1646181f0b2c72f5cbc15a85e80b4c30a544bb15"}, {file = "ruamel.yaml.clib-0.2.8.tar.gz", hash = "sha256:beb2e0404003de9a4cab9753a8805a8fe9320ee6673136ed7f04255fe60bb512"}, ] +[[package]] +name = "ruff" +version = "0.6.1" +description = "An extremely fast Python linter and code formatter, written in Rust." +optional = false +python-versions = ">=3.7" +files = [ + {file = "ruff-0.6.1-py3-none-linux_armv6l.whl", hash = "sha256:b4bb7de6a24169dc023f992718a9417380301b0c2da0fe85919f47264fb8add9"}, + {file = "ruff-0.6.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:45efaae53b360c81043e311cdec8a7696420b3d3e8935202c2846e7a97d4edae"}, + {file = "ruff-0.6.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:bc60c7d71b732c8fa73cf995efc0c836a2fd8b9810e115be8babb24ae87e0850"}, + {file = "ruff-0.6.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2c7477c3b9da822e2db0b4e0b59e61b8a23e87886e727b327e7dcaf06213c5cf"}, + {file = "ruff-0.6.1-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3a0af7ab3f86e3dc9f157a928e08e26c4b40707d0612b01cd577cc84b8905cc9"}, + {file = "ruff-0.6.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:392688dbb50fecf1bf7126731c90c11a9df1c3a4cdc3f481b53e851da5634fa5"}, + {file = "ruff-0.6.1-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:5278d3e095ccc8c30430bcc9bc550f778790acc211865520f3041910a28d0024"}, + {file = "ruff-0.6.1-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fe6d5f65d6f276ee7a0fc50a0cecaccb362d30ef98a110f99cac1c7872df2f18"}, + {file = "ruff-0.6.1-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b2e0dd11e2ae553ee5c92a81731d88a9883af8db7408db47fc81887c1f8b672e"}, + {file = "ruff-0.6.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d812615525a34ecfc07fd93f906ef5b93656be01dfae9a819e31caa6cfe758a1"}, + {file = "ruff-0.6.1-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:faaa4060f4064c3b7aaaa27328080c932fa142786f8142aff095b42b6a2eb631"}, + {file = "ruff-0.6.1-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:99d7ae0df47c62729d58765c593ea54c2546d5de213f2af2a19442d50a10cec9"}, + {file = "ruff-0.6.1-py3-none-musllinux_1_2_i686.whl", hash = "sha256:9eb18dfd7b613eec000e3738b3f0e4398bf0153cb80bfa3e351b3c1c2f6d7b15"}, + {file = "ruff-0.6.1-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:c62bc04c6723a81e25e71715aa59489f15034d69bf641df88cb38bdc32fd1dbb"}, + {file = "ruff-0.6.1-py3-none-win32.whl", hash = "sha256:9fb4c4e8b83f19c9477a8745e56d2eeef07a7ff50b68a6998f7d9e2e3887bdc4"}, + {file = "ruff-0.6.1-py3-none-win_amd64.whl", hash = "sha256:c2ebfc8f51ef4aca05dad4552bbcf6fe8d1f75b2f6af546cc47cc1c1ca916b5b"}, + {file = "ruff-0.6.1-py3-none-win_arm64.whl", hash = "sha256:3bc81074971b0ffad1bd0c52284b22411f02a11a012082a76ac6da153536e014"}, + {file = "ruff-0.6.1.tar.gz", hash = "sha256:af3ffd8c6563acb8848d33cd19a69b9bfe943667f0419ca083f8ebe4224a3436"}, +] + [[package]] name = "s3transfer" version = "0.10.0" @@ -4208,4 +4175,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.10, <3.13" -content-hash = "cedfa7b5254c80e53fb8a6a51fed164273319b165c45f085501eb4559db5787e" +content-hash = "eba0dfb48ad8e340536c8dc20a91fb5f1511e15d6c45438197d99ee79a72e85a" diff --git a/packages/examples/cvat/recording-oracle/pyproject.toml b/packages/examples/cvat/recording-oracle/pyproject.toml index b3a5b5180d..2e504f6678 100644 --- a/packages/examples/cvat/recording-oracle/pyproject.toml +++ b/packages/examples/cvat/recording-oracle/pyproject.toml @@ -24,20 +24,121 @@ google-cloud-storage = "^2.14.0" datumaro = {git = "https://github.com/cvat-ai/datumaro.git", rev = "ff83c00c2c1bc4b8fdfcc55067fcab0a9b5b6b11"} [tool.poetry.group.dev.dependencies] -black = "^23.3.0" -pre-commit = "^3.3.3" hypothesis = "^6.82.6" -isort = "^5.12.0" -[tool.isort] -profile = "black" -forced_separate = ["tests"] -line_length = 100 -skip_gitignore = true # align tool behavior with Black +[tool.poetry.group.lint.dependencies] +pre-commit = "^3.3.3" +ruff = "^0.6.0" -[tool.black] +[tool.ruff] line-length = 100 -target-version = ['py310'] +target-version = "py310" + + +[tool.ruff.lint] +select = ["ALL"] +unfixable = [ + "RUF005", # messes up concantenation with numpy structures +] +ignore = [ + "W191", # Rules conflicting with ruff format (https://docs.astral.sh/ruff/formatter/#conflicting-lint-rules) + "E111", # | + "E114", # | + "E117", # | + "D206", # | + "D300", # | + "Q000", # | + "Q001", # | + "Q002", # | + "Q003", # | + "COM812", # | + "COM819", # | + "ISC001", # | + "ISC002", # | + "ANN101", # Method args annotations (mypy will take care of that) + "ANN001", # | + "ANN202", # | + "ANN201", # | + "ANN401", # | + "ANN102", # | + "RUF001", # Allow cyrillic letters in comments + + "B904", # Raise from: modern pythons preserve previous exceptions + "EM", # Forbids using literal strings in exceptions. + # Sujested way of dealing with exceptions increases verbosity + # while giving little to no benefit in readability + "TRY003", # | + "G004", # Forbids using f-strings in logging. This project doesn't rely on lazy % formatting when using logging. + "A003", # Class attribute `id` is shadowing a Python builtin — it's ok in class body + "FIX001", # Forbids using TODOs, but TODOs are useful + "FIX002", # | + "TD001", # | + "TD002", # | + "TD003", # | + "E711", # Allow == None comparisons for sqlalchemy queries + "E712", # Allow == True comparisons for sqlalchemy queries + "PERF203", # Noisy microoptimisation + # Want to resolve eventually, but not now: + "S101", # Allow asserts (there are too many of them right now to fix) + "TRY401", # Checks for excesive logging of exception objects + "G001", # Forbid str.format for logging + "PTH123", # Checks for uses of `os.path.splitext` + "D", # Docstrings + "N806", # Variable in function should be lowercase + "RUF012", # Mutable class attributes should be annotated with `typing.ClassVar` + "SLF001", # Private member accessed + "F811", # Redefinition of unused + "RUF005", # Consider iterable unpacking instead of concatenation + "A002", # Argument is shadowing a Python builtin + "N818", # Exception name should be named with an Error suffix + "TRY002", # Create your own exception + "ANN003", # Missing type annotation for `**kwargs` + "ANN204", # Missing return type annotation for special method + "ERA001", # Found commented-out code + "N801", # Class name should use CapWords convention + "PLR0915", # Too many statements + "PLR2004", # Magic value used in comparison, consider replacing with a constant variable + "ANN002", # Missing type annotation for `*args` + "TRY300", # Consider moving this statement to an `else` block + "C901", # Function is too complex + "PLW2901", # Variable overwritten by assignment target + "PTH118", # Prefer pathlib instead of os.path + "PTH119", # `os.path.basename()` should be replaced by `Path.name` + "PTH122", # `os.path.splitext()` should be replaced by `Path.suffix`, `Path.stem`, and `Path.parent` + "PTH207", # Replace `glob` with `Path.glob` or `Path.rglob` +] + + +[tool.ruff.lint.per-file-ignores] +"tests/*" = [ + "PLR0913", # Annotations and args + "ANN202", # | + "ANN201", # | + "ANN001", # | + "ANN003", # | + "ARG001", # | + "SLF001", # Allow private attrs access + "PLR2004", # Allow magic values + "S", # security + "DTZ005", # allow datetimes without timezones +] +# alembic is not a package in a traditional sense, so putting __init__.py there doesn't make sense +"alembic/*" = ["INP001"] +"__init__.py" = ["F401"] + +[tool.ruff.lint.pep8-naming] +classmethod-decorators = [ + "pydantic.validator", +] + +[tool.ruff.lint.pylint] +max-args = 9 # Lower number might be beneficial to reduce cognitive load. Consider using data containers. + +[tool.ruff.lint.isort] +forced-separate = ["tests"] + +[tool.ruff.lint.flake8-tidy-imports] +ban-relative-imports = "all" [build-system] requires = ["poetry-core"] diff --git a/packages/examples/cvat/recording-oracle/run.py b/packages/examples/cvat/recording-oracle/run.py index 5b59b70a51..2f09b1a09f 100644 --- a/packages/examples/cvat/recording-oracle/run.py +++ b/packages/examples/cvat/recording-oracle/run.py @@ -11,7 +11,7 @@ uvicorn.run( app="src:app", - host="0.0.0.0", + host="0.0.0.0", # noqa: S104 port=int(Config.port), workers=Config.workers_amount, reload=is_dev, diff --git a/packages/examples/cvat/recording-oracle/src/chain/escrow.py b/packages/examples/cvat/recording-oracle/src/chain/escrow.py index 00381be5f4..5651f21e3c 100644 --- a/packages/examples/cvat/recording-oracle/src/chain/escrow.py +++ b/packages/examples/cvat/recording-oracle/src/chain/escrow.py @@ -1,8 +1,7 @@ import json -from typing import List from human_protocol_sdk.constants import ChainId, Status -from human_protocol_sdk.encryption import Encryption, EncryptionUtils +from human_protocol_sdk.encryption import Encryption from human_protocol_sdk.escrow import EscrowClient, EscrowData, EscrowUtils from human_protocol_sdk.storage import StorageUtils @@ -22,9 +21,11 @@ def validate_escrow( chain_id: int, escrow_address: str, *, - accepted_states: List[Status] = [Status.Pending], + accepted_states: list[Status] | None = None, allow_no_funds: bool = False, ) -> None: + if accepted_states is None: + accepted_states = [Status.Pending] assert accepted_states escrow = get_escrow(chain_id, escrow_address) @@ -37,9 +38,8 @@ def validate_escrow( ) ) - if status == Status.Pending and not allow_no_funds: - if int(escrow.balance) == 0: - raise ValueError("Escrow doesn't have funds") + if status == Status.Pending and not allow_no_funds and int(escrow.balance) == 0: + raise ValueError("Escrow doesn't have funds") def get_escrow_manifest(chain_id: int, escrow_address: str) -> dict: diff --git a/packages/examples/cvat/recording-oracle/src/chain/kvstore.py b/packages/examples/cvat/recording-oracle/src/chain/kvstore.py index b91f2ec3a1..0d7c763914 100644 --- a/packages/examples/cvat/recording-oracle/src/chain/kvstore.py +++ b/packages/examples/cvat/recording-oracle/src/chain/kvstore.py @@ -11,9 +11,7 @@ def get_role_by_address(chain_id: int, address: str) -> str: web3 = get_web3(chain_id) kvstore_client = KVStoreClient(web3) - role = kvstore_client.get(address, "role") - - return role + return kvstore_client.get(address, "role") def get_exchange_oracle_url(chain_id: int, escrow_address: str) -> str: diff --git a/packages/examples/cvat/recording-oracle/src/chain/web3.py b/packages/examples/cvat/recording-oracle/src/chain/web3.py index cc427c96b9..e76f77bf95 100644 --- a/packages/examples/cvat/recording-oracle/src/chain/web3.py +++ b/packages/examples/cvat/recording-oracle/src/chain/web3.py @@ -71,9 +71,7 @@ def sign_message(chain_id: Networks, message) -> str: def recover_signer(chain_id: Networks, message, signature: str) -> str: w3 = get_web3(chain_id) message_hash = encode_defunct(text=serialize_message(message)) - signer = w3.eth.account.recover_message(message_hash, signature=signature) - - return signer + return w3.eth.account.recover_message(message_hash, signature=signature) def validate_address(escrow_address: str) -> str: diff --git a/packages/examples/cvat/recording-oracle/src/core/__init__.py b/packages/examples/cvat/recording-oracle/src/core/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/packages/examples/cvat/recording-oracle/src/core/annotation_meta.py b/packages/examples/cvat/recording-oracle/src/core/annotation_meta.py index e19a77efd5..d539787315 100644 --- a/packages/examples/cvat/recording-oracle/src/core/annotation_meta.py +++ b/packages/examples/cvat/recording-oracle/src/core/annotation_meta.py @@ -1,5 +1,4 @@ from pathlib import Path -from typing import List from pydantic import BaseModel @@ -15,4 +14,4 @@ class JobMeta(BaseModel): class AnnotationMeta(BaseModel): - jobs: List[JobMeta] + jobs: list[JobMeta] diff --git a/packages/examples/cvat/recording-oracle/src/core/config.py b/packages/examples/cvat/recording-oracle/src/core/config.py index c948fe474e..db5534ac4c 100644 --- a/packages/examples/cvat/recording-oracle/src/core/config.py +++ b/packages/examples/cvat/recording-oracle/src/core/config.py @@ -1,8 +1,10 @@ # pylint: disable=too-few-public-methods,missing-class-docstring -""" Project configuration from env vars """ +"""Project configuration from env vars""" + import inspect import os -from typing import ClassVar, Iterable, Optional +from collections.abc import Iterable +from typing import ClassVar from attrs.converters import to_bool from dotenv import load_dotenv @@ -14,7 +16,7 @@ from src.utils.net import is_ipv4 dotenv_path = os.getenv("DOTENV_PATH", None) -if dotenv_path and not os.path.exists(dotenv_path): +if dotenv_path and not os.path.exists(dotenv_path): # noqa: PTH110 raise FileNotFoundError(dotenv_path) load_dotenv(dotenv_path) @@ -28,22 +30,22 @@ def validate(cls) -> None: class Postgres: port = os.environ.get("PG_PORT", "5434") - host = os.environ.get("PG_HOST", "0.0.0.0") + host = os.environ.get("PG_HOST", "0.0.0.0") # noqa: S104 user = os.environ.get("PG_USER", "admin") password = os.environ.get("PG_PASSWORD", "admin") database = os.environ.get("PG_DB", "recording_oracle") lock_timeout = int(os.environ.get("PG_LOCK_TIMEOUT", "3000")) # milliseconds @classmethod - def connection_url(cls): + def connection_url(cls) -> str: return f"postgresql://{cls.user}:{cls.password}@{cls.host}:{cls.port}/{cls.database}" class _NetworkConfig: chain_id: ClassVar[int] - rpc_api: ClassVar[Optional[str]] - private_key: ClassVar[Optional[str]] - addr: ClassVar[Optional[str]] + rpc_api: ClassVar[str | None] + private_key: ClassVar[str | None] + addr: ClassVar[str | None] @classmethod def is_configured(cls) -> bool: @@ -101,12 +103,12 @@ class IStorageConfig: data_bucket_name: ClassVar[str] secure: ClassVar[bool] endpoint_url: ClassVar[str] # TODO: probably should be optional - region: ClassVar[Optional[str]] + region: ClassVar[str | None] # AWS S3 specific attributes - access_key: ClassVar[Optional[str]] - secret_key: ClassVar[Optional[str]] + access_key: ClassVar[str | None] + secret_key: ClassVar[str | None] # GCS specific attributes - key_file_path: ClassVar[Optional[str]] + key_file_path: ClassVar[str | None] @classmethod def get_scheme(cls) -> str: @@ -120,8 +122,7 @@ def provider_endpoint_url(cls) -> str: def bucket_url(cls) -> str: if is_ipv4(cls.endpoint_url): return f"{cls.get_scheme()}{cls.endpoint_url}/{cls.data_bucket_name}/" - else: - return f"{cls.get_scheme()}{cls.data_bucket_name}.{cls.endpoint_url}/" + return f"{cls.get_scheme()}{cls.data_bucket_name}.{cls.endpoint_url}/" class StorageConfig(IStorageConfig): @@ -192,7 +193,7 @@ class ValidationConfig: Each such job will be accepted "blindly", as we can't validate the annotations. """ - max_escrow_iterations = int(os.getenv("MAX_ESCROW_ITERATIONS", 0)) + max_escrow_iterations = int(os.getenv("MAX_ESCROW_ITERATIONS", "0")) """ Maximum escrow annotation-validation iterations. After this, the escrow is finished automatically. @@ -210,12 +211,12 @@ def validate(cls) -> None: ex_prefix = "Wrong server configuration." if (cls.pgp_public_key_url or cls.pgp_passphrase) and not cls.pgp_private_key: - raise Exception(" ".join([ex_prefix, "The PGP_PRIVATE_KEY environment is not set."])) + raise Exception(f"{ex_prefix} The PGP_PRIVATE_KEY environment is not set.") if cls.pgp_private_key: try: Encryption(cls.pgp_private_key, passphrase=cls.pgp_passphrase) - except Exception as ex: + except Exception as ex: # noqa: BLE001 # Possible reasons: # - private key is invalid # - private key is locked but no passphrase is provided @@ -251,9 +252,12 @@ def validate(cls) -> None: attr_or_method.validate() @classmethod - def get_network_configs(cls, only_configured: bool = True) -> Iterable[_NetworkConfig]: + def get_network_configs(cls, *, only_configured: bool = True) -> Iterable[_NetworkConfig]: for attr_or_method in cls.__dict__: attr_or_method = getattr(cls, attr_or_method) - if inspect.isclass(attr_or_method) and issubclass(attr_or_method, _NetworkConfig): - if not only_configured or attr_or_method.is_configured(): - yield attr_or_method + if ( + inspect.isclass(attr_or_method) + and issubclass(attr_or_method, _NetworkConfig) + and (not only_configured or attr_or_method.is_configured()) + ): + yield attr_or_method diff --git a/packages/examples/cvat/recording-oracle/src/core/manifest.py b/packages/examples/cvat/recording-oracle/src/core/manifest.py index c8a4c2260d..c70c1ed91a 100644 --- a/packages/examples/cvat/recording-oracle/src/core/manifest.py +++ b/packages/examples/cvat/recording-oracle/src/core/manifest.py @@ -1,6 +1,6 @@ from decimal import Decimal from enum import Enum -from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union +from typing import Annotated, Any, Literal from pydantic import AnyUrl, BaseModel, Field, root_validator @@ -28,22 +28,22 @@ class AwsBucketUrl(BucketUrlBase, BaseModel): class GcsBucketUrl(BucketUrlBase, BaseModel): provider: Literal[BucketProviders.gcs] - service_account_key: Dict[str, Any] = {} # (optional) Contents of GCS key file + service_account_key: dict[str, Any] = {} # (optional) Contents of GCS key file -BucketUrl = Annotated[Union[AwsBucketUrl, GcsBucketUrl], Field(discriminator="provider")] +BucketUrl = Annotated[AwsBucketUrl | GcsBucketUrl, Field(discriminator="provider")] class DataInfo(BaseModel): - data_url: Union[AnyUrl, BucketUrl] + data_url: AnyUrl | BucketUrl "Bucket URL, AWS S3 | GCS, virtual-hosted-style access" # https://docs.aws.amazon.com/AmazonS3/latest/userguide/access-bucket-intro.html - points_url: Optional[Union[AnyUrl, BucketUrl]] = None + points_url: AnyUrl | BucketUrl | None = None "A path to an archive with a set of points in COCO Keypoints format, " "which provides information about all objects on images" - boxes_url: Optional[Union[AnyUrl, BucketUrl]] = None + boxes_url: AnyUrl | BucketUrl | None = None "A path to an archive with a set of boxes in COCO Instances format, " "which provides information about all objects on images" @@ -67,7 +67,7 @@ class PlainLabelInfo(LabelInfoBase): class SkeletonLabelInfo(LabelInfoBase): type: Literal[LabelTypes.skeleton] - nodes: List[str] = Field(min_items=1) + nodes: list[str] = Field(min_items=1) """ A list of node label names (only points are supposed to be nodes). Example: @@ -76,7 +76,7 @@ class SkeletonLabelInfo(LabelInfoBase): ] """ - joints: Optional[List[Tuple[int, int]]] = Field(default_factory=list) + joints: list[tuple[int, int]] | None = Field(default_factory=list) "A list of node adjacency, e.g. [[0, 1], [1, 2], [1, 3]]" @root_validator @@ -114,7 +114,7 @@ def validate_type(cls, values: dict) -> dict: return values -LabelInfo = Annotated[Union[PlainLabelInfo, SkeletonLabelInfo], Field(discriminator="type")] +LabelInfo = Annotated[PlainLabelInfo | SkeletonLabelInfo, Field(discriminator="type")] class AnnotationInfo(BaseModel): @@ -132,7 +132,7 @@ class AnnotationInfo(BaseModel): job_size: int = 10 "Frames per job, validation frames are not included" - max_time: Optional[int] = None # deprecated, TODO: mark deprecated with pydantic 2.7+ + max_time: int | None = None # deprecated, TODO: mark deprecated with pydantic 2.7+ "Maximum time per job (assignment) for an annotator, in seconds" @root_validator(pre=True) @@ -161,7 +161,7 @@ class ValidationInfo(BaseModel): val_size: int = Field(default=2, gt=0) "Validation frames per job" - gt_url: Union[AnyUrl, BucketUrl] + gt_url: AnyUrl | BucketUrl "URL to the archive with Ground Truth annotations, the format is COCO keypoints" diff --git a/packages/examples/cvat/recording-oracle/src/core/oracle_events.py b/packages/examples/cvat/recording-oracle/src/core/oracle_events.py index 519152e2de..721e3f0647 100644 --- a/packages/examples/cvat/recording-oracle/src/core/oracle_events.py +++ b/packages/examples/cvat/recording-oracle/src/core/oracle_events.py @@ -1,13 +1,8 @@ -from typing import Optional, Type, Union - from pydantic import BaseModel from src.core.types import ExchangeOracleEventTypes, OracleWebhookTypes, RecordingOracleEventTypes -EventTypeTag = Union[ - ExchangeOracleEventTypes, - RecordingOracleEventTypes, -] +EventTypeTag = ExchangeOracleEventTypes | RecordingOracleEventTypes class OracleEvent(BaseModel): @@ -50,7 +45,7 @@ class ExchangeOracleEvent_TaskFinished(OracleEvent): } -def get_class_for_event_type(event_type: str) -> Type[OracleEvent]: +def get_class_for_event_type(event_type: str) -> type[OracleEvent]: event_class = next((v for k, v in _event_type_map.items() if k == event_type), None) if not event_class: @@ -59,7 +54,7 @@ def get_class_for_event_type(event_type: str) -> Type[OracleEvent]: return event_class -def get_type_tag_for_event_class(event_class: Type[OracleEvent]) -> EventTypeTag: +def get_type_tag_for_event_class(event_class: type[OracleEvent]) -> EventTypeTag: event_type = next((k for k, v in _event_type_map.items() if v == event_class), None) if not event_type: @@ -69,7 +64,7 @@ def get_type_tag_for_event_class(event_class: Type[OracleEvent]) -> EventTypeTag def parse_event( - sender: OracleWebhookTypes, event_type: str, event_data: Optional[dict] = None + sender: OracleWebhookTypes, event_type: str, event_data: dict | None = None ) -> OracleEvent: sender_events_mapping = { OracleWebhookTypes.recording_oracle: RecordingOracleEventTypes, @@ -78,10 +73,10 @@ def parse_event( sender_events = sender_events_mapping.get(sender) if sender_events is not None: - if not event_type in sender_events: + if event_type not in sender_events: raise ValueError(f"Unknown event '{sender}.{event_type}'") else: - assert False, f"Unknown event sender type '{sender}'" + raise AssertionError(f"Unknown event sender type '{sender}'") event_class = get_class_for_event_type(event_type) return event_class.parse_obj(event_data or {}) diff --git a/packages/examples/cvat/recording-oracle/src/core/storage.py b/packages/examples/cvat/recording-oracle/src/core/storage.py index 6c5c53dbe2..939be882fd 100644 --- a/packages/examples/cvat/recording-oracle/src/core/storage.py +++ b/packages/examples/cvat/recording-oracle/src/core/storage.py @@ -7,4 +7,7 @@ def compose_data_bucket_filename(escrow_address: str, chain_id: Networks, filena def compose_results_bucket_filename(escrow_address: str, chain_id: Networks, filename: str) -> str: - return f"{escrow_address}@{chain_id}{Config.exchange_oracle_storage_config.results_dir_suffix}/{filename}" + return ( + f"{escrow_address}@{chain_id}{Config.exchange_oracle_storage_config.results_dir_suffix}" + f"/{filename}" + ) diff --git a/packages/examples/cvat/recording-oracle/src/core/tasks/boxes_from_points.py b/packages/examples/cvat/recording-oracle/src/core/tasks/boxes_from_points.py index c9320473b6..bd26aacc14 100644 --- a/packages/examples/cvat/recording-oracle/src/core/tasks/boxes_from_points.py +++ b/packages/examples/cvat/recording-oracle/src/core/tasks/boxes_from_points.py @@ -1,14 +1,14 @@ import os +from collections.abc import Sequence from pathlib import Path from tempfile import TemporaryDirectory -from typing import Dict, Sequence import attrs import datumaro as dm from attrs import frozen from datumaro.util import dump_json, parse_json -BboxPointMapping = Dict[int, int] +BboxPointMapping = dict[int, int] @frozen @@ -28,7 +28,7 @@ def asdict(self) -> dict: RoiInfos = Sequence[RoiInfo] -RoiFilenames = Dict[int, str] +RoiFilenames = dict[int, str] class TaskMetaLayout: diff --git a/packages/examples/cvat/recording-oracle/src/core/tasks/skeletons_from_boxes.py b/packages/examples/cvat/recording-oracle/src/core/tasks/skeletons_from_boxes.py index b6b680c19a..fb5f09b418 100644 --- a/packages/examples/cvat/recording-oracle/src/core/tasks/skeletons_from_boxes.py +++ b/packages/examples/cvat/recording-oracle/src/core/tasks/skeletons_from_boxes.py @@ -1,14 +1,14 @@ import os +from collections.abc import Sequence from pathlib import Path from tempfile import TemporaryDirectory -from typing import Dict, Sequence, Tuple import attrs import datumaro as dm from attrs import frozen from datumaro.util import dump_json, parse_json -SkeletonBboxMapping = Dict[int, int] +SkeletonBboxMapping = dict[int, int] # TODO: migrate to pydantic @@ -34,9 +34,9 @@ def asdict(self) -> dict: RoiInfos = Sequence[RoiInfo] -RoiFilenames = Dict[int, str] +RoiFilenames = dict[int, str] -PointLabelsMapping = Dict[Tuple[str, str], str] +PointLabelsMapping = dict[tuple[str, str], str] "(skeleton, point) -> job point name" diff --git a/packages/examples/cvat/recording-oracle/src/core/validation_meta.py b/packages/examples/cvat/recording-oracle/src/core/validation_meta.py index 027e3ed419..66b9aca8e0 100644 --- a/packages/examples/cvat/recording-oracle/src/core/validation_meta.py +++ b/packages/examples/cvat/recording-oracle/src/core/validation_meta.py @@ -1,5 +1,3 @@ -from typing import List - from pydantic import BaseModel VALIDATION_METAFILE_NAME = "validation_meta.json" @@ -19,5 +17,5 @@ class ResultMeta(BaseModel): class ValidationMeta(BaseModel): - jobs: List[JobMeta] - results: List[ResultMeta] + jobs: list[JobMeta] + results: list[ResultMeta] diff --git a/packages/examples/cvat/recording-oracle/src/core/validation_results.py b/packages/examples/cvat/recording-oracle/src/core/validation_results.py index 8d78bfc7c8..79d8ba001e 100644 --- a/packages/examples/cvat/recording-oracle/src/core/validation_results.py +++ b/packages/examples/cvat/recording-oracle/src/core/validation_results.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from typing import Dict from src.core.validation_errors import DatasetValidationError from src.core.validation_meta import ValidationMeta @@ -14,4 +13,4 @@ class ValidationSuccess: @dataclass class ValidationFailure: - rejected_jobs: Dict[int, DatasetValidationError] + rejected_jobs: dict[int, DatasetValidationError] diff --git a/packages/examples/cvat/recording-oracle/src/crons/process_exchange_oracle_webhooks.py b/packages/examples/cvat/recording-oracle/src/crons/process_exchange_oracle_webhooks.py index 86d1097ece..95f7cd0bd9 100644 --- a/packages/examples/cvat/recording-oracle/src/crons/process_exchange_oracle_webhooks.py +++ b/packages/examples/cvat/recording-oracle/src/crons/process_exchange_oracle_webhooks.py @@ -1,5 +1,4 @@ import logging -from typing import Dict import httpx from sqlalchemy.orm import Session @@ -75,7 +74,7 @@ def handle_exchange_oracle_event(webhook: Webhook, *, db_session: Session, logge ) case _: - assert False, f"Unknown exchange oracle event {webhook.event_type}" + raise AssertionError(f"Unknown exchange oracle event {webhook.event_type}") def process_outgoing_exchange_oracle_webhooks(): diff --git a/packages/examples/cvat/recording-oracle/src/db/__init__.py b/packages/examples/cvat/recording-oracle/src/db/__init__.py index 6e9c85cded..68068770e3 100644 --- a/packages/examples/cvat/recording-oracle/src/db/__init__.py +++ b/packages/examples/cvat/recording-oracle/src/db/__init__.py @@ -8,7 +8,7 @@ engine = sqlalchemy.create_engine( DATABASE_URL, echo="debug" if Config.loglevel <= src.utils.logging.TRACE else False, - connect_args={"options": "-c lock_timeout={:d}".format(Config.postgres_config.lock_timeout)}, + connect_args={"options": f"-c lock_timeout={Config.postgres_config.lock_timeout:d}"}, ) SessionLocal = sessionmaker(autocommit=False, bind=engine) diff --git a/packages/examples/cvat/recording-oracle/src/db/utils.py b/packages/examples/cvat/recording-oracle/src/db/utils.py index 24dfb41561..77651fa42a 100644 --- a/packages/examples/cvat/recording-oracle/src/db/utils.py +++ b/packages/examples/cvat/recording-oracle/src/db/utils.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import TypeVar, Union +from typing import TypeVar from sqlalchemy import Select from sqlalchemy.orm import Query @@ -14,14 +14,11 @@ class ForUpdateParams: T = TypeVar("T", Query, Select) -def maybe_for_update(query: T, enable: Union[bool, ForUpdateParams]) -> T: +def maybe_for_update(query: T, enable: bool | ForUpdateParams) -> T: if not enable: return query - if isinstance(enable, ForUpdateParams): - params = enable - else: - params = ForUpdateParams() + params = enable if isinstance(enable, ForUpdateParams) else ForUpdateParams() return query.with_for_update( skip_locked=params.skip_locked, diff --git a/packages/examples/cvat/recording-oracle/src/endpoints/__init__.py b/packages/examples/cvat/recording-oracle/src/endpoints/__init__.py index 4b1802bc67..f3b167a2d7 100644 --- a/packages/examples/cvat/recording-oracle/src/endpoints/__init__.py +++ b/packages/examples/cvat/recording-oracle/src/endpoints/__init__.py @@ -1,4 +1,5 @@ -""" API endpoints """ +"""API endpoints""" + from fastapi import APIRouter, FastAPI from src.core.config import Config @@ -21,11 +22,11 @@ def meta_route() -> MetaResponse: ] return MetaResponse.parse_obj( - dict( - message="Recording Oracle API", - version="0.1.0", - supported_networks=networks_info, - ) + { + "message": "Recording Oracle API", + "version": "0.1.0", + "supported_networks": networks_info, + } ) diff --git a/packages/examples/cvat/recording-oracle/src/endpoints/error_handlers.py b/packages/examples/cvat/recording-oracle/src/endpoints/error_handlers.py index b8d1c424b1..294867ea07 100644 --- a/packages/examples/cvat/recording-oracle/src/endpoints/error_handlers.py +++ b/packages/examples/cvat/recording-oracle/src/endpoints/error_handlers.py @@ -1,4 +1,5 @@ -""" Custom error handlers for the FastAPI""" +"""Custom error handlers for the FastAPI""" + from fastapi import FastAPI from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse diff --git a/packages/examples/cvat/recording-oracle/src/endpoints/webhook.py b/packages/examples/cvat/recording-oracle/src/endpoints/webhook.py index f5fee7b483..df2e44dc80 100644 --- a/packages/examples/cvat/recording-oracle/src/endpoints/webhook.py +++ b/packages/examples/cvat/recording-oracle/src/endpoints/webhook.py @@ -1,5 +1,3 @@ -from typing import Union - from fastapi import APIRouter, Header, HTTPException, Request import src.services.webhook as oracle_db_service @@ -15,7 +13,7 @@ async def receive_oracle_webhook( webhook: OracleWebhook, request: Request, - human_signature: Union[str, None] = Header(default=None), + human_signature: str | None = Header(default=None), ) -> OracleWebhookResponse: try: sender = await validate_oracle_webhook_signature(request, human_signature, webhook) diff --git a/packages/examples/cvat/recording-oracle/src/handlers/__init__.py b/packages/examples/cvat/recording-oracle/src/handlers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/packages/examples/cvat/recording-oracle/src/handlers/error_handlers.py b/packages/examples/cvat/recording-oracle/src/handlers/error_handlers.py index b8d1c424b1..294867ea07 100644 --- a/packages/examples/cvat/recording-oracle/src/handlers/error_handlers.py +++ b/packages/examples/cvat/recording-oracle/src/handlers/error_handlers.py @@ -1,4 +1,5 @@ -""" Custom error handlers for the FastAPI""" +"""Custom error handlers for the FastAPI""" + from fastapi import FastAPI from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse diff --git a/packages/examples/cvat/recording-oracle/src/handlers/process_intermediate_results.py b/packages/examples/cvat/recording-oracle/src/handlers/process_intermediate_results.py index 1ff839f4f7..17d9e2522f 100644 --- a/packages/examples/cvat/recording-oracle/src/handlers/process_intermediate_results.py +++ b/packages/examples/cvat/recording-oracle/src/handlers/process_intermediate_results.py @@ -6,11 +6,10 @@ from dataclasses import dataclass, field from pathlib import Path from tempfile import TemporaryDirectory -from typing import Dict, NamedTuple, Optional, Set, Type, TypeVar, Union +from typing import TYPE_CHECKING, NamedTuple, TypeVar import datumaro as dm import numpy as np -from sqlalchemy.orm import Session import src.core.tasks.boxes_from_points as boxes_from_points_task import src.core.tasks.simple as simple_task @@ -18,7 +17,6 @@ import src.services.validation as db_service from src.core.annotation_meta import AnnotationMeta from src.core.config import Config -from src.core.manifest import TaskManifest from src.core.storage import compose_data_bucket_filename from src.core.types import TaskTypes from src.core.validation_errors import DatasetValidationError, LowAccuracyError @@ -37,6 +35,11 @@ TooFewGtError, ) +if TYPE_CHECKING: + from sqlalchemy.orm import Session + + from src.core.manifest import TaskManifest + DM_DATASET_FORMAT_MAPPING = { TaskTypes.image_label_binary: "cvat_images", TaskTypes.image_points: "coco_person_keypoints", @@ -54,7 +57,7 @@ } -DATASET_COMPARATOR_TYPE_MAP: Dict[TaskTypes, Type[DatasetComparator]] = { +DATASET_COMPARATOR_TYPE_MAP: dict[TaskTypes, type[DatasetComparator]] = { # TaskType.image_label_binary: TagDatasetComparator, # TODO: implement if support is needed TaskTypes.image_boxes: BboxDatasetComparator, TaskTypes.image_points: PointsDatasetComparator, @@ -62,21 +65,21 @@ TaskTypes.image_skeletons_from_boxes: SkeletonDatasetComparator, } -_JobResults = Dict[int, float] +_JobResults = dict[int, float] -_RejectedJobs = Dict[int, DatasetValidationError] +_RejectedJobs = dict[int, DatasetValidationError] -_FailedGtAttempts = Dict[str, int] +_FailedGtAttempts = dict[str, int] "gt key -> attempts" @dataclass class _UpdatedFailedGtInfo: - failed_jobs: Set[int] = field(default_factory=set) + failed_jobs: set[int] = field(default_factory=set) occurrences: int = 0 -_UpdatedFailedGtStats = Dict[str, _UpdatedFailedGtInfo] +_UpdatedFailedGtStats = dict[str, _UpdatedFailedGtInfo] @dataclass @@ -100,27 +103,27 @@ def __init__( chain_id: int, manifest: TaskManifest, *, - job_annotations: Dict[int, io.IOBase], + job_annotations: dict[int, io.IOBase], merged_annotations: io.IOBase, - gt_stats: Optional[_FailedGtAttempts] = None, - ): + gt_stats: _FailedGtAttempts | None = None, + ) -> None: self.escrow_address = escrow_address self.chain_id = chain_id self.manifest = manifest self._initial_gt_attempts: _FailedGtAttempts = gt_stats or {} - self._job_annotations: Dict[int, io.IOBase] = job_annotations + self._job_annotations: dict[int, io.IOBase] = job_annotations self._merged_annotations: io.IOBase = merged_annotations - self._updated_merged_dataset_archive: Optional[io.IOBase] = None - self._updated_gt_stats: Optional[_UpdatedFailedGtStats] = None - self._job_results: Optional[_JobResults] = None - self._rejected_jobs: Optional[_RejectedJobs] = None + self._updated_merged_dataset_archive: io.IOBase | None = None + self._updated_gt_stats: _UpdatedFailedGtStats | None = None + self._job_results: _JobResults | None = None + self._rejected_jobs: _RejectedJobs | None = None - self._temp_dir: Optional[Path] = None - self._gt_dataset: Optional[dm.Dataset] = None + self._temp_dir: Path | None = None + self._gt_dataset: dm.Dataset | None = None - def _require_field(self, field: Optional[T]) -> T: + def _require_field(self, field: T | None) -> T: assert field is not None return field @@ -133,7 +136,7 @@ def _get_gt_weight(self, failed_attempts: int) -> float: return weight - def _get_gt_weights(self) -> Dict[str, float]: + def _get_gt_weights(self) -> dict[str, float]: weights = {} ban_threshold = Config.validation.gt_ban_threshold @@ -164,7 +167,7 @@ def _parse_gt(self): ) ) - def _load_job_dataset(self, job_id: int, job_dataset_path: Path) -> dm.Dataset: + def _load_job_dataset(self, job_id: int, job_dataset_path: Path) -> dm.Dataset: # noqa: ARG002 manifest = self._require_field(self.manifest) return dm.Dataset.import_from( @@ -217,7 +220,7 @@ def _validate_jobs(self): def _restore_original_image_paths(self, merged_dataset: dm.Dataset) -> dm.Dataset: class RemoveCommonPrefix(dm.ItemTransform): - def __init__(self, extractor: dm.IExtractor, *, prefix: str): + def __init__(self, extractor: dm.IExtractor, *, prefix: str) -> None: super().__init__(extractor) self._prefix = prefix @@ -262,7 +265,7 @@ def _prepare_merged_dataset(self): @classmethod def _put_gt_into_merged_dataset( cls, gt_dataset: dm.Dataset, merged_dataset: dm.Dataset, *, manifest: TaskManifest - ): + ) -> None: """ Updates the merged dataset inplace, writing GT annotations corresponding to the task type. """ @@ -309,7 +312,7 @@ def _put_gt_into_merged_dataset( ) merged_dataset.update(gt_dataset) case _: - assert False, f"Unknown task type {manifest.annotation.type}" + raise AssertionError(f"Unknown task type {manifest.annotation.type}") def validate(self) -> _ValidationResult: with TemporaryDirectory() as tempdir: @@ -331,7 +334,7 @@ class _TaskValidatorWithPerJobGt(_TaskValidator): def _make_gt_dataset_for_job(self, job_id: int, job_dataset: dm.Dataset) -> dm.Dataset: raise NotImplementedError - def _get_gt_weights(self, *, job_cvat_id: int, job_gt_dataset: dm.Dataset) -> Dict[str, float]: + def _get_gt_weights(self, *, job_cvat_id: int, job_gt_dataset: dm.Dataset) -> dict[str, float]: weights = {} ban_threshold = Config.validation.gt_ban_threshold @@ -350,8 +353,12 @@ def _get_gt_weights(self, *, job_cvat_id: int, job_gt_dataset: dm.Dataset) -> Di return weights def _gt_key_to_sample_id( - self, gt_key: str, *, job_cvat_id: int, job_gt_dataset: dm.Dataset - ) -> Optional[str]: + self, + gt_key: str, + *, + job_cvat_id: int, # noqa: ARG002 + job_gt_dataset: dm.Dataset, # noqa: ARG002 + ) -> str | None: return gt_key def _update_gt_stats( @@ -418,7 +425,7 @@ def _validate_jobs(self): class _BoxesFromPointsValidator(_TaskValidatorWithPerJobGt): - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) ( @@ -448,7 +455,7 @@ def __init__(self, *args, **kwargs): self._point_key_to_bbox_key = {v: k for k, v in boxes_to_points_mapping.items()} self._roi_info_by_id = {roi_info.point_id: roi_info for roi_info in roi_infos} - self._roi_name_to_roi_info: Dict[str, boxes_from_points_task.RoiInfo] = { + self._roi_name_to_roi_info: dict[str, boxes_from_points_task.RoiInfo] = { os.path.splitext(roi_filename)[0]: self._roi_info_by_id[roi_id] for roi_id, roi_filename in roi_filenames.items() } @@ -521,7 +528,7 @@ def _download_task_meta(self): return boxes_to_points_mapping, roi_filenames, rois, gt_dataset, points_dataset - def _make_gt_dataset_for_job(self, job_id: int, job_dataset: dm.Dataset) -> dm.Dataset: + def _make_gt_dataset_for_job(self, job_id: int, job_dataset: dm.Dataset) -> dm.Dataset: # noqa: ARG002 job_gt_dataset = dm.Dataset(categories=self._gt_dataset.categories(), media_type=dm.Image) for job_sample in job_dataset: @@ -554,7 +561,7 @@ def _prepare_merged_dataset(self): class _SkeletonsFromBoxesValidator(_TaskValidatorWithPerJobGt): - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) ( @@ -589,7 +596,7 @@ def __init__(self, *args, **kwargs): self._bbox_key_to_skeleton_key = {v: k for k, v in skeletons_to_boxes_mapping.items()} self._roi_info_by_id = {roi_info.bbox_id: roi_info for roi_info in roi_infos} - self._roi_name_to_roi_info: Dict[str, skeletons_from_boxes_task.RoiInfo] = { + self._roi_name_to_roi_info: dict[str, skeletons_from_boxes_task.RoiInfo] = { os.path.splitext(roi_filename)[0]: self._roi_info_by_id[roi_id] for roi_id, roi_filename in roi_filenames.items() } @@ -731,7 +738,7 @@ def _load_job_dataset(self, job_id: int, job_dataset_path: Path) -> dm.Dataset: return updated_dataset - def _make_gt_dataset_for_job(self, job_id: int, job_dataset: dm.Dataset) -> dm.Dataset: + def _make_gt_dataset_for_job(self, job_id: int, job_dataset: dm.Dataset) -> dm.Dataset: # noqa: ARG002 job_label_cat: dm.LabelCategories = job_dataset.categories()[dm.AnnotationType.label] assert len(job_label_cat) == 2 job_skeleton_label_id, job_skeleton_label = next( @@ -832,8 +839,8 @@ class _LabelId(NamedTuple): def _get_gt_dataset_label_id(self, job_gt_dataset: dm.Dataset) -> _LabelId: label_cat: dm.LabelCategories = job_gt_dataset.categories()[dm.AnnotationType.label] assert len(label_cat) == 2 - job_skeleton_label = next(l for l in label_cat if not l.parent) - job_point_label = next(l for l in label_cat if l.parent) + job_skeleton_label = next(label for label in label_cat if not label.parent) + job_point_label = next(label for label in label_cat if label.parent) return self._LabelId( *next( @@ -846,8 +853,12 @@ def _get_gt_dataset_label_id(self, job_gt_dataset: dm.Dataset) -> _LabelId: ) def _gt_key_to_sample_id( - self, gt_key: str, *, job_cvat_id: int, job_gt_dataset: dm.Dataset - ) -> Optional[str]: + self, + gt_key: str, + *, + job_cvat_id: int, # noqa: ARG002 + job_gt_dataset: dm.Dataset, + ) -> str | None: parsed_gt_key = self._parse_gt_key(gt_key) job_label_id = self._get_gt_dataset_label_id(job_gt_dataset) if (parsed_gt_key.skeleton_id, parsed_gt_key.point_id) != job_label_id: @@ -904,17 +915,17 @@ def _compute_gt_stats_update( return updated_gt_stats -def process_intermediate_results( +def process_intermediate_results( # noqa: PLR0912 session: Session, *, escrow_address: str, chain_id: int, meta: AnnotationMeta, - job_annotations: Dict[int, io.RawIOBase], + job_annotations: dict[int, io.RawIOBase], merged_annotations: io.RawIOBase, manifest: TaskManifest, logger: logging.Logger, -) -> Union[ValidationSuccess, ValidationFailure]: +) -> ValidationSuccess | ValidationFailure: # actually validate jobs task_type = manifest.annotation.type @@ -979,7 +990,7 @@ def process_intermediate_results( db_service.update_gt_stats(session, task.id, updated_gt_stats) - job_final_result_ids: Dict[int, str] = {} + job_final_result_ids: dict[int, str] = {} for job_meta in meta.jobs: job = db_service.get_job_by_cvat_id(session, job_meta.job_id) if not job: @@ -1006,13 +1017,12 @@ def process_intermediate_results( should_complete = False - if 0 < Config.validation.max_escrow_iterations: + if Config.validation.max_escrow_iterations > 0: escrow_iteration = task.iteration if escrow_iteration and Config.validation.max_escrow_iterations <= escrow_iteration: logger.info( - "Validation for escrow_address={}: too many iterations, stopping annotation".format( - escrow_address - ) + f"Validation for escrow_address={escrow_address}:" + f" too many iterations, stopping annotation" ) should_complete = True @@ -1028,24 +1038,18 @@ def process_intermediate_results( < unverifiable_jobs_count ): logger.info( - "Validation for escrow_address={}: " - "too many assignments have insufficient GT for validation ({} of {} ({:.2f}%)), " - "stopping annotation".format( - escrow_address, - unverifiable_jobs_count, - total_jobs, - unverifiable_jobs_count / total_jobs * 100, - ) + f"Validation for escrow_address={escrow_address}: " + f"too many assignments have insufficient GT for validation " + f"({unverifiable_jobs_count} of {total_jobs} " + f"({unverifiable_jobs_count / total_jobs * 100:.2f}%)), stopping annotation" ) should_complete = True elif len(rejected_jobs) == unverifiable_jobs_count: if unverifiable_jobs_count: logger.info( - "Validation for escrow_address={}: " - "only unverifiable assignments left ({}), stopping annotation".format( - escrow_address, - unverifiable_jobs_count, - ) + f"Validation for escrow_address={escrow_address}: " + f"only unverifiable assignments left ({unverifiable_jobs_count})," + f" stopping annotation" ) should_complete = True @@ -1082,7 +1086,7 @@ def process_intermediate_results( validation_meta=validation_meta, resulting_annotations=updated_merged_dataset_archive.getvalue(), average_quality=np.mean( - list(v for v in job_results.values() if v != _TaskValidator.UNKNOWN_QUALITY and v >= 0) + [v for v in job_results.values() if v != _TaskValidator.UNKNOWN_QUALITY and v >= 0] or [0] ), ) diff --git a/packages/examples/cvat/recording-oracle/src/handlers/validation.py b/packages/examples/cvat/recording-oracle/src/handlers/validation.py index e398e91a6f..b6be970315 100644 --- a/packages/examples/cvat/recording-oracle/src/handlers/validation.py +++ b/packages/examples/cvat/recording-oracle/src/handlers/validation.py @@ -2,14 +2,13 @@ import os from collections import Counter from logging import Logger -from typing import Dict, Optional, Union from sqlalchemy.orm import Session -import src.chain.escrow as escrow import src.core.annotation_meta as annotation import src.core.validation_meta as validation import src.services.webhook as oracle_db_service +from src.chain import escrow from src.core.config import Config from src.core.manifest import TaskManifest, parse_manifest from src.core.oracle_events import ( @@ -48,9 +47,9 @@ def __init__( self.data_bucket = BucketAccessInfo.parse_obj(Config.exchange_oracle_storage_config) - self.annotation_meta: Optional[annotation.AnnotationMeta] = None - self.job_annotations: Optional[Dict[int, bytes]] = None - self.merged_annotations: Optional[bytes] = None + self.annotation_meta: annotation.AnnotationMeta | None = None + self.job_annotations: dict[int, bytes] | None = None + self.merged_annotations: bytes | None = None def set_logger(self, logger: Logger): self.logger = logger @@ -94,7 +93,7 @@ def _download_results(self): self._download_results_meta() self._download_annotations() - ValidationResult = Union[ValidationSuccess, ValidationFailure] + ValidationResult = ValidationSuccess | ValidationFailure def _process_annotation_results(self) -> ValidationResult: assert self.annotation_meta is not None @@ -159,7 +158,7 @@ def _handle_validation_result(self, validation_result: ValidationResult): escrow.store_results( chain_id, escrow_address, - Config.storage_config.bucket_url() + os.path.dirname(recor_merged_annotations_path), + Config.storage_config.bucket_url() + os.path.dirname(recor_merged_annotations_path), # noqa: PTH120 compute_resulting_annotations_hash(validation_result.resulting_annotations), ) @@ -194,13 +193,13 @@ def _handle_validation_result(self, validation_result: ValidationResult): OracleWebhookTypes.exchange_oracle, event=RecordingOracleEvent_TaskRejected( # TODO: update wrt. M2 API changes, send reason - rejected_job_ids=list( + rejected_job_ids=[ jid for jid, reason in validation_result.rejected_jobs.items() if not isinstance( reason, TooFewGtError ) # prevent such jobs from reannotation, can also be handled in ExcOr - ) + ] ), ) diff --git a/packages/examples/cvat/recording-oracle/src/log.py b/packages/examples/cvat/recording-oracle/src/log.py index c533f89d0e..9a2f7e4e7e 100644 --- a/packages/examples/cvat/recording-oracle/src/log.py +++ b/packages/examples/cvat/recording-oracle/src/log.py @@ -1,4 +1,5 @@ -""" Config for the application logger""" +"""Config for the application logger""" + import logging from logging.config import dictConfig diff --git a/packages/examples/cvat/recording-oracle/src/models/__init__.py b/packages/examples/cvat/recording-oracle/src/models/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/packages/examples/cvat/recording-oracle/src/models/validation.py b/packages/examples/cvat/recording-oracle/src/models/validation.py index 80a250891f..51f8f681f5 100644 --- a/packages/examples/cvat/recording-oracle/src/models/validation.py +++ b/packages/examples/cvat/recording-oracle/src/models/validation.py @@ -1,8 +1,6 @@ # pylint: disable=too-few-public-methods from __future__ import annotations -from typing import List - from sqlalchemy import Column, DateTime, Enum, Float, ForeignKey, Integer, String from sqlalchemy.orm import Mapped, relationship from sqlalchemy.sql import func @@ -20,10 +18,10 @@ class Task(Base): updated_at = Column(DateTime(timezone=True), onupdate=func.now()) iteration = Column(Integer, server_default="0", nullable=False) - jobs: Mapped[List["Job"]] = relationship( + jobs: Mapped[list[Job]] = relationship( back_populates="task", cascade="all, delete", passive_deletes=True ) - gt_stats: Mapped[List["GtStats"]] = relationship( + gt_stats: Mapped[list[GtStats]] = relationship( back_populates="task", cascade="all, delete", passive_deletes=True ) @@ -34,8 +32,8 @@ class Job(Base): cvat_id = Column(Integer, unique=True, index=True, nullable=False) task_id = Column(String, ForeignKey("tasks.id", ondelete="CASCADE"), nullable=False) - task: Mapped["Task"] = relationship(back_populates="jobs") - validation_results: Mapped[List["ValidationResult"]] = relationship( + task: Mapped[Task] = relationship(back_populates="jobs") + validation_results: Mapped[list[ValidationResult]] = relationship( back_populates="job", cascade="all, delete", passive_deletes=True ) @@ -48,7 +46,7 @@ class ValidationResult(Base): annotator_wallet_address = Column(String, nullable=False) annotation_quality = Column(Float, nullable=False) - job: Mapped["Job"] = relationship(back_populates="validation_results") + job: Mapped[Job] = relationship(back_populates="validation_results") class GtStats(Base): @@ -64,4 +62,4 @@ class GtStats(Base): failed_attempts = Column(Integer, default=0, nullable=False) - task: Mapped["Task"] = relationship(back_populates="gt_stats") + task: Mapped[Task] = relationship(back_populates="gt_stats") diff --git a/packages/examples/cvat/recording-oracle/src/models/webhook.py b/packages/examples/cvat/recording-oracle/src/models/webhook.py index 10c4bec69c..64ca2fb0e7 100644 --- a/packages/examples/cvat/recording-oracle/src/models/webhook.py +++ b/packages/examples/cvat/recording-oracle/src/models/webhook.py @@ -26,5 +26,5 @@ class Webhook(Base): event_data = Column(JSON, nullable=True, server_default=None) direction = Column(String, nullable=False) - def __repr__(self): + def __repr__(self) -> str: return f"Webhook. id={self.id} type={self.type}.{self.event_type}" diff --git a/packages/examples/cvat/recording-oracle/src/schemas/__init__.py b/packages/examples/cvat/recording-oracle/src/schemas/__init__.py index a0811a145d..e26cbee23e 100644 --- a/packages/examples/cvat/recording-oracle/src/schemas/__init__.py +++ b/packages/examples/cvat/recording-oracle/src/schemas/__init__.py @@ -1,7 +1,4 @@ -# pylint: disable=too-few-public-methods -""" Schema for API input&output""" - -from typing import List, Optional +"""Schema for API input&output""" from pydantic import BaseModel @@ -27,7 +24,7 @@ class ResponseError(BaseModel): class SupportedNetwork(BaseModel): chain_id: int - addr: Optional[str] + addr: str | None class MetaResponse(BaseModel): @@ -35,4 +32,4 @@ class MetaResponse(BaseModel): message: str version: str - supported_networks: List[SupportedNetwork] + supported_networks: list[SupportedNetwork] diff --git a/packages/examples/cvat/recording-oracle/src/schemas/webhook.py b/packages/examples/cvat/recording-oracle/src/schemas/webhook.py index 62ae808f0a..c725d2cfa0 100644 --- a/packages/examples/cvat/recording-oracle/src/schemas/webhook.py +++ b/packages/examples/cvat/recording-oracle/src/schemas/webhook.py @@ -1,5 +1,4 @@ from datetime import datetime -from typing import Optional from pydantic import BaseModel, validator @@ -11,8 +10,8 @@ class OracleWebhook(BaseModel): escrow_address: str chain_id: Networks event_type: str - event_data: Optional[dict] = None - timestamp: Optional[datetime] = None # TODO: remove optional + event_data: dict | None = None + timestamp: datetime | None = None # TODO: remove optional @validator("escrow_address", allow_reuse=True) def validate_escrow_(cls, value): diff --git a/packages/examples/cvat/recording-oracle/src/services/cloud/client.py b/packages/examples/cvat/recording-oracle/src/services/cloud/client.py index 5bb92d77e3..53cf13dcc0 100644 --- a/packages/examples/cvat/recording-oracle/src/services/cloud/client.py +++ b/packages/examples/cvat/recording-oracle/src/services/cloud/client.py @@ -1,37 +1,29 @@ from abc import ABCMeta, abstractmethod -from typing import List, Optional from urllib.parse import unquote class StorageClient(metaclass=ABCMeta): def __init__( self, - bucket: Optional[str] = None, + bucket: str | None = None, ) -> None: self._bucket = unquote(bucket) if bucket else None @abstractmethod - def create_file(self, key: str, data: bytes = b"", *, bucket: Optional[str] = None): - ... + def create_file(self, key: str, data: bytes = b"", *, bucket: str | None = None): ... @abstractmethod - def remove_file(self, key: str, *, bucket: Optional[str] = None): - ... + def remove_file(self, key: str, *, bucket: str | None = None): ... @abstractmethod - def file_exists(self, key: str, *, bucket: Optional[str] = None) -> bool: - ... + def file_exists(self, key: str, *, bucket: str | None = None) -> bool: ... @abstractmethod - def download_file(self, key: str, *, bucket: Optional[str] = None) -> bytes: - ... + def download_file(self, key: str, *, bucket: str | None = None) -> bytes: ... @abstractmethod - def list_files( - self, *, bucket: Optional[str] = None, prefix: Optional[str] = None - ) -> List[str]: - ... + def list_files(self, *, bucket: str | None = None, prefix: str | None = None) -> list[str]: ... @staticmethod - def normalize_prefix(prefix: Optional[str]) -> Optional[str]: + def normalize_prefix(prefix: str | None) -> str | None: return unquote(prefix).strip("/\\") + "/" if prefix else prefix diff --git a/packages/examples/cvat/recording-oracle/src/services/cloud/gcs.py b/packages/examples/cvat/recording-oracle/src/services/cloud/gcs.py index 36611b363f..014be4ce1a 100644 --- a/packages/examples/cvat/recording-oracle/src/services/cloud/gcs.py +++ b/packages/examples/cvat/recording-oracle/src/services/cloud/gcs.py @@ -1,5 +1,4 @@ from io import BytesIO -from typing import Dict, List, Optional from urllib.parse import unquote from google.cloud import storage @@ -13,8 +12,8 @@ class GcsClient(StorageClient): def __init__( self, *, - bucket: Optional[str] = None, - service_account_key: Optional[Dict] = None, + bucket: str | None = None, + service_account_key: dict | None = None, ) -> None: super().__init__(bucket) @@ -23,22 +22,22 @@ def __init__( else: self.client = storage.Client.create_anonymous_client() - def create_file(self, key: str, data: bytes = b"", *, bucket: Optional[str] = None) -> None: + def create_file(self, key: str, data: bytes = b"", *, bucket: str | None = None) -> None: bucket = unquote(bucket) if bucket else self._bucket bucket_client = self.client.get_bucket(bucket) bucket_client.blob(unquote(key)).upload_from_string(data) - def remove_file(self, key: str, *, bucket: Optional[str] = None) -> None: + def remove_file(self, key: str, *, bucket: str | None = None) -> None: bucket = unquote(bucket) if bucket else self._bucket bucket_client = self.client.get_bucket(bucket) bucket_client.delete_blob(unquote(key)) - def file_exists(self, key: str, *, bucket: Optional[str] = None) -> bool: + def file_exists(self, key: str, *, bucket: str | None = None) -> bool: bucket = unquote(bucket) if bucket else self._bucket bucket_client = self.client.get_bucket(bucket) return bucket_client.blob(unquote(key)).exists() - def download_file(self, key: str, *, bucket: Optional[str] = None) -> bytes: + def download_file(self, key: str, *, bucket: str | None = None) -> bytes: bucket = unquote(bucket) if bucket else self._bucket bucket_client = self.client.get_bucket(bucket) blob = bucket_client.blob(unquote(key)) @@ -47,9 +46,7 @@ def download_file(self, key: str, *, bucket: Optional[str] = None) -> bytes: self.client.download_blob_to_file(blob, data) return data.getvalue() - def list_files( - self, *, bucket: Optional[str] = None, prefix: Optional[str] = None - ) -> List[str]: + def list_files(self, *, bucket: str | None = None, prefix: str | None = None) -> list[str]: bucket = unquote(bucket) if bucket else self._bucket prefix = self.normalize_prefix(prefix) diff --git a/packages/examples/cvat/recording-oracle/src/services/cloud/s3.py b/packages/examples/cvat/recording-oracle/src/services/cloud/s3.py index e8e608ce99..19b570216a 100644 --- a/packages/examples/cvat/recording-oracle/src/services/cloud/s3.py +++ b/packages/examples/cvat/recording-oracle/src/services/cloud/s3.py @@ -1,5 +1,4 @@ from io import BytesIO -from typing import List, Optional from urllib.parse import unquote import boto3 @@ -15,15 +14,15 @@ class S3Client(StorageClient): def __init__( self, *, - bucket: Optional[str] = None, - access_key: Optional[str] = None, - secret_key: Optional[str] = None, - endpoint_url: Optional[str] = None, + bucket: str | None = None, + access_key: str | None = None, + secret_key: str | None = None, + endpoint_url: str | None = None, ) -> None: super().__init__(bucket) session = boto3.Session( - **(dict(aws_access_key_id=access_key) if access_key else {}), - **(dict(aws_secret_access_key=secret_key) if secret_key else {}), + **({"aws_access_key_id": access_key} if access_key else {}), + **({"aws_secret_access_key": secret_key} if secret_key else {}), ) s3 = session.resource( "s3", **({"endpoint_url": unquote(endpoint_url)} if endpoint_url else {}) @@ -34,15 +33,15 @@ def __init__( if not access_key and not secret_key: self.client.meta.events.register("choose-signer.s3.*", disable_signing) - def create_file(self, key: str, data: bytes = b"", *, bucket: Optional[str] = None): + def create_file(self, key: str, data: bytes = b"", *, bucket: str | None = None): bucket = unquote(bucket) if bucket else self._bucket self.client.put_object(Body=data, Bucket=bucket, Key=unquote(key)) - def remove_file(self, key: str, *, bucket: Optional[str] = None): + def remove_file(self, key: str, *, bucket: str | None = None): bucket = unquote(bucket) if bucket else self._bucket self.client.delete_object(Bucket=bucket, Key=unquote(key)) - def file_exists(self, key: str, *, bucket: Optional[str] = None) -> bool: + def file_exists(self, key: str, *, bucket: str | None = None) -> bool: bucket = unquote(bucket) if bucket else self._bucket try: self.client.head_object(Bucket=bucket, Key=unquote(key)) @@ -50,22 +49,16 @@ def file_exists(self, key: str, *, bucket: Optional[str] = None) -> bool: except ClientError as e: if e.response["Error"]["Code"] == "404": return False - else: - raise + raise - def download_file(self, key: str, *, bucket: Optional[str] = None) -> bytes: + def download_file(self, key: str, *, bucket: str | None = None) -> bytes: bucket = unquote(bucket) if bucket else self._bucket with BytesIO() as data: self.client.download_fileobj(Bucket=bucket, Key=unquote(key), Fileobj=data) return data.getvalue() - def list_files( - self, *, bucket: Optional[str] = None, prefix: Optional[str] = None - ) -> List[str]: + def list_files(self, *, bucket: str | None = None, prefix: str | None = None) -> list[str]: bucket = unquote(bucket) if bucket else self._bucket objects = self.resource.Bucket(bucket).objects - if prefix: - objects = objects.filter(Prefix=self.normalize_prefix(prefix)) - else: - objects = objects.all() + objects = objects.filter(Prefix=self.normalize_prefix(prefix)) if prefix else objects.all() return [file_info.key for file_info in objects] diff --git a/packages/examples/cvat/recording-oracle/src/services/cloud/types.py b/packages/examples/cvat/recording-oracle/src/services/cloud/types.py index ec8cd9df9e..8a4c318966 100644 --- a/packages/examples/cvat/recording-oracle/src/services/cloud/types.py +++ b/packages/examples/cvat/recording-oracle/src/services/cloud/types.py @@ -4,7 +4,6 @@ from dataclasses import asdict, dataclass, is_dataclass from enum import Enum, auto from inspect import isclass -from typing import Dict, Optional, Type, Union from urllib.parse import urlparse from src.core import manifest @@ -31,14 +30,14 @@ def from_str(cls, provider: str) -> CloudProviders: class BucketCredentials: - def to_dict(self) -> Dict: + def to_dict(self) -> dict: if not is_dataclass(self): raise NotImplementedError return asdict(self) @classmethod - def from_storage_config(cls, config: Type[IStorageConfig]) -> Optional[BucketCredentials]: + def from_storage_config(cls, config: type[IStorageConfig]) -> BucketCredentials | None: credentials = None if (config.access_key or config.secret_key) and config.provider.lower() != "aws": @@ -46,9 +45,7 @@ def from_storage_config(cls, config: Type[IStorageConfig]) -> Optional[BucketCre "Invalid storage configuration. The access_key/secret_key pair" f"cannot be specified with {config.provider} provider" ) - elif ( - bool(config.access_key) ^ bool(config.secret_key) - ) and config.provider.lower() == "aws": + if (bool(config.access_key) ^ bool(config.secret_key)) and config.provider.lower() == "aws": raise ValueError( "Invalid storage configuration. " "Either none or both access_key and secret_key must be specified for an AWS storage" @@ -71,7 +68,7 @@ def from_storage_config(cls, config: Type[IStorageConfig]) -> Optional[BucketCre @dataclass class GcsBucketCredentials(BucketCredentials): - service_account_key: Dict + service_account_key: dict @dataclass @@ -85,8 +82,8 @@ class BucketAccessInfo: provider: CloudProviders host_url: str bucket_name: str - path: Optional[str] = None - credentials: Optional[BucketCredentials] = None + path: str | None = None + credentials: BucketCredentials | None = None @classmethod def from_url(cls, url: str) -> BucketAccessInfo: @@ -100,7 +97,7 @@ def from_url(cls, url: str) -> BucketAccessInfo: bucket_name=parsed_url.netloc.split(".")[0], path=parsed_url.path.lstrip("/"), ) - elif parsed_url.netloc.endswith(DEFAULT_GCS_HOST): + if parsed_url.netloc.endswith(DEFAULT_GCS_HOST): # Google Cloud Storage (GCS) bucket # Virtual hosted-style is expected: # https://BUCKET_NAME.storage.googleapis.com/OBJECT_NAME @@ -110,7 +107,7 @@ def from_url(cls, url: str) -> BucketAccessInfo: host_url=f"{parsed_url.scheme}://{DEFAULT_GCS_HOST}", path=parsed_url.path.lstrip("/"), ) - elif Config.features.enable_custom_cloud_host: + if Config.features.enable_custom_cloud_host: if is_ipv4(parsed_url.netloc): host = parsed_url.netloc bucket_name, path = parsed_url.path.lstrip("/").split("/", maxsplit=1) @@ -125,11 +122,10 @@ def from_url(cls, url: str) -> BucketAccessInfo: bucket_name=bucket_name, path=path, ) - else: - raise ValueError(f"{parsed_url.netloc} cloud provider is not supported.") + raise ValueError(f"{parsed_url.netloc} cloud provider is not supported.") @classmethod - def _from_dict(cls, data: Dict) -> BucketAccessInfo: + def _from_dict(cls, data: dict) -> BucketAccessInfo: for required_field in ( "provider", "bucket_name", @@ -159,7 +155,7 @@ def _from_dict(cls, data: Dict) -> BucketAccessInfo: return BucketAccessInfo(**data) @classmethod - def from_storage_config(cls, config: Type[IStorageConfig]) -> BucketAccessInfo: + def from_storage_config(cls, config: type[IStorageConfig]) -> BucketAccessInfo: credentials = BucketCredentials.from_storage_config(config) return BucketAccessInfo( @@ -174,14 +170,12 @@ def from_bucket_url(cls, bucket_url: manifest.BucketUrl) -> BucketAccessInfo: return cls._from_dict(bucket_url.dict()) @classmethod - def parse_obj( - cls, data: Union[str, Type[IStorageConfig], manifest.BucketUrl] - ) -> BucketAccessInfo: + def parse_obj(cls, data: str | type[IStorageConfig] | manifest.BucketUrl) -> BucketAccessInfo: if isinstance(data, manifest.BucketUrlBase): return cls.from_bucket_url(data) - elif isinstance(data, str): + if isinstance(data, str): return cls.from_url(data) - elif isclass(data) and issubclass(data, IStorageConfig): + if isclass(data) and issubclass(data, IStorageConfig): return cls.from_storage_config(data) raise TypeError(f"Unsupported data type ({type(data)}) was provided") diff --git a/packages/examples/cvat/recording-oracle/src/services/cloud/utils.py b/packages/examples/cvat/recording-oracle/src/services/cloud/utils.py index a9f821d174..bfc23305c7 100644 --- a/packages/examples/cvat/recording-oracle/src/services/cloud/utils.py +++ b/packages/examples/cvat/recording-oracle/src/services/cloud/utils.py @@ -1,5 +1,3 @@ -from typing import Optional - from src.services.cloud.client import StorageClient from src.services.cloud.gcs import DEFAULT_GCS_HOST, GcsClient from src.services.cloud.s3 import DEFAULT_S3_HOST, S3Client @@ -7,7 +5,7 @@ def compose_bucket_url( - bucket_name: str, provider: CloudProviders, *, bucket_host: Optional[str] = None + bucket_name: str, provider: CloudProviders, *, bucket_host: str | None = None ) -> str: match provider: case CloudProviders.aws: diff --git a/packages/examples/cvat/recording-oracle/src/services/validation.py b/packages/examples/cvat/recording-oracle/src/services/validation.py index ba139d3fec..4204bd90df 100644 --- a/packages/examples/cvat/recording-oracle/src/services/validation.py +++ b/packages/examples/cvat/recording-oracle/src/services/validation.py @@ -1,5 +1,4 @@ import uuid -from typing import Dict, List, Optional, Union from sqlalchemy import update from sqlalchemy.orm import Session @@ -20,8 +19,8 @@ def create_task(session: Session, escrow_address: str, chain_id: int) -> str: def get_task_by_escrow_address( - session: Session, escrow_address: str, *, for_update: Union[bool, ForUpdateParams] = False -) -> Optional[Task]: + session: Session, escrow_address: str, *, for_update: bool | ForUpdateParams = False +) -> Task | None: return ( _maybe_for_update(session.query(Task), enable=for_update) .where(Task.escrow_address == escrow_address) @@ -30,16 +29,16 @@ def get_task_by_escrow_address( def get_task_by_id( - session: Session, task_id: str, *, for_update: Union[bool, ForUpdateParams] = False -) -> Optional[Task]: + session: Session, task_id: str, *, for_update: bool | ForUpdateParams = False +) -> Task | None: return ( _maybe_for_update(session.query(Task), enable=for_update).where(Task.id == task_id).first() ) def get_task_validation_results( - session: Session, task_id: str, *, for_update: Union[bool, ForUpdateParams] = False -) -> List[ValidationResult]: + session: Session, task_id: str, *, for_update: bool | ForUpdateParams = False +) -> list[ValidationResult]: return ( _maybe_for_update(session.query(ValidationResult), enable=for_update) .where(ValidationResult.job.has(Job.task_id == task_id)) @@ -67,8 +66,8 @@ def create_job(session: Session, job_cvat_id: int, task_id: str) -> str: def get_job_by_cvat_id( - session: Session, job_cvat_id: int, *, for_update: Union[bool, ForUpdateParams] = False -) -> Optional[Job]: + session: Session, job_cvat_id: int, *, for_update: bool | ForUpdateParams = False +) -> Job | None: return ( _maybe_for_update(session.query(Job), enable=for_update) .where(Job.cvat_id == job_cvat_id) @@ -77,8 +76,8 @@ def get_job_by_cvat_id( def get_job_by_id( - session: Session, job_id: str, *, for_update: Union[bool, ForUpdateParams] = False -) -> Optional[Job]: + session: Session, job_id: str, *, for_update: bool | ForUpdateParams = False +) -> Job | None: return _maybe_for_update(session.query(Job), enable=for_update).where(Job.id == job_id).first() @@ -104,8 +103,8 @@ def create_validation_result( def get_validation_result_by_assignment_id( - session: Session, assignment_id: str, *, for_update: Union[bool, ForUpdateParams] = False -) -> Optional[ValidationResult]: + session: Session, assignment_id: str, *, for_update: bool | ForUpdateParams = False +) -> ValidationResult | None: return ( _maybe_for_update(session.query(ValidationResult), enable=for_update) .where(ValidationResult.assignment_id == assignment_id) @@ -114,8 +113,8 @@ def get_validation_result_by_assignment_id( def get_task_gt_stats( - session: Session, task_id: str, *, for_update: Union[bool, ForUpdateParams] = False -) -> List[GtStats]: + session: Session, task_id: str, *, for_update: bool | ForUpdateParams = False +) -> list[GtStats]: return ( _maybe_for_update(session.query(GtStats), enable=for_update) .where(GtStats.task_id == task_id) @@ -123,7 +122,7 @@ def get_task_gt_stats( ) -def update_gt_stats(session: Session, task_id: str, values: Dict[str, int]): +def update_gt_stats(session: Session, task_id: str, values: dict[str, int]): # Read more about upsert: # https://docs.sqlalchemy.org/en/20/orm/queryguide/dml.html#orm-upsert-statements @@ -138,11 +137,11 @@ def update_gt_stats(session: Session, task_id: str, values: Dict[str, int]): statement = psql_insert(GtStats).values( [ - dict( - task_id=task_id, - gt_key=gt_key, - failed_attempts=failed_attempts, - ) + { + "task_id": task_id, + "gt_key": gt_key, + "failed_attempts": failed_attempts, + } for gt_key, failed_attempts in values.items() ], ) diff --git a/packages/examples/cvat/recording-oracle/src/services/webhook.py b/packages/examples/cvat/recording-oracle/src/services/webhook.py index e14af05321..41e9ca1c10 100644 --- a/packages/examples/cvat/recording-oracle/src/services/webhook.py +++ b/packages/examples/cvat/recording-oracle/src/services/webhook.py @@ -1,7 +1,6 @@ import datetime import uuid from enum import Enum -from typing import List, Optional, Union from attrs import define from sqlalchemy import case, update @@ -26,7 +25,7 @@ class OracleWebhookDirectionTags(str, Enum, metaclass=BetterEnumMeta): @define class OracleWebhookQueue: direction: OracleWebhookDirectionTags - default_sender: Optional[OracleWebhookTypes] = None + default_sender: OracleWebhookTypes | None = None def create_webhook( self, @@ -34,10 +33,10 @@ def create_webhook( escrow_address: str, chain_id: int, type: OracleWebhookTypes, - signature: Optional[str] = None, - event_type: Optional[str] = None, - event_data: Optional[dict] = None, - event: Optional[OracleEvent] = None, + signature: str | None = None, + event_type: str | None = None, + event_data: dict | None = None, + event: OracleEvent | None = None, ) -> str: """ Creates a webhook in a database @@ -45,7 +44,7 @@ def create_webhook( assert not event_data or event_type, "'event_data' requires 'event_type'" assert bool(event) ^ bool( event_type - ), f"'event' and 'event_type' cannot be used together. Please use only one of the fields" + ), "'event' and 'event_type' cannot be used together. Please use only one of the fields" if event_type: if self.direction == OracleWebhookDirectionTags.incoming: @@ -60,7 +59,7 @@ def create_webhook( if self.direction == OracleWebhookDirectionTags.incoming and not signature: raise ValueError("Webhook signature must be specified for incoming events") - elif self.direction == OracleWebhookDirectionTags.outgoing and signature: + if self.direction == OracleWebhookDirectionTags.outgoing and signature: raise ValueError("Webhook signature must not be specified for outgoing events") if signature: @@ -93,9 +92,9 @@ def get_pending_webhooks( type: OracleWebhookTypes, *, limit: int = 10, - for_update: Union[bool, ForUpdateParams] = False, - ) -> List[Webhook]: - webhooks = ( + for_update: bool | ForUpdateParams = False, + ) -> list[Webhook]: + return ( _maybe_for_update(session.query(Webhook), enable=for_update) .where( Webhook.direction == self.direction.value, @@ -106,7 +105,6 @@ def get_pending_webhooks( .limit(limit) .all() ) - return webhooks def update_webhook_status( self, session: Session, webhook_id: str, status: OracleWebhookStatuses diff --git a/packages/examples/cvat/recording-oracle/src/utils/annotations.py b/packages/examples/cvat/recording-oracle/src/utils/annotations.py index 75ec29ec15..ea138a703b 100644 --- a/packages/examples/cvat/recording-oracle/src/utils/annotations.py +++ b/packages/examples/cvat/recording-oracle/src/utils/annotations.py @@ -1,5 +1,6 @@ +from argparse import ArgumentParser +from collections.abc import Iterable from copy import deepcopy -from typing import Iterable, Optional, Tuple, Union import datumaro as dm import numpy as np @@ -38,7 +39,7 @@ def shift_ann( ] ) else: - assert False, f"Unsupported annotation type '{ann.type}'" + raise TypeError(f"Unsupported annotation type '{ann.type}'") return shifted_ann @@ -65,7 +66,7 @@ class ProjectLabels(dm.ItemTransform): """ @classmethod - def build_cmdline_parser(cls, **kwargs): + def build_cmdline_parser(cls, **kwargs) -> ArgumentParser: parser = super().build_cmdline_parser(**kwargs) parser.add_argument( "-l", @@ -76,19 +77,19 @@ def build_cmdline_parser(cls, **kwargs): ) return parser - def __init__( + def __init__( # noqa: PLR0912 self, extractor: dm.IExtractor, - dst_labels: Union[Iterable[Union[str, Tuple[str, str]]], dm.LabelCategories], - ): + dst_labels: Iterable[str | tuple[str, str]] | dm.LabelCategories, + ) -> None: super().__init__(extractor) self._categories = {} src_categories = self._extractor.categories() - src_label_cat: Optional[dm.LabelCategories] = src_categories.get(dm.AnnotationType.label) - src_point_cat: Optional[dm.PointsCategories] = src_categories.get(dm.AnnotationType.points) + src_label_cat: dm.LabelCategories | None = src_categories.get(dm.AnnotationType.label) + src_point_cat: dm.PointsCategories | None = src_categories.get(dm.AnnotationType.points) if isinstance(dst_labels, dm.LabelCategories): dst_label_cat = deepcopy(dst_labels) @@ -99,7 +100,7 @@ def __init__( dst_label_cat = dm.LabelCategories(attributes=deepcopy(src_label_cat.attributes)) for dst_label in dst_labels: - assert isinstance(dst_label, str) or isinstance(dst_label, tuple) + assert isinstance(dst_label, str | tuple) dst_parent = "" if isinstance(dst_label, tuple): @@ -181,7 +182,7 @@ def _make_label_id_map(self, src_label_cat, dst_label_cat): src_id: dst_label_cat.find(src_label_cat[src_id].name, src_label_cat[src_id].parent)[0] for src_id in range(len(src_label_cat or ())) } - self._map_id = lambda src_id: id_mapping.get(src_id, None) + self._map_id = lambda src_id: id_mapping.get(src_id) def categories(self): return self._categories diff --git a/packages/examples/cvat/recording-oracle/src/utils/enums.py b/packages/examples/cvat/recording-oracle/src/utils/enums.py index 4f3d688251..d4c133b0e5 100644 --- a/packages/examples/cvat/recording-oracle/src/utils/enums.py +++ b/packages/examples/cvat/recording-oracle/src/utils/enums.py @@ -6,5 +6,5 @@ class BetterEnumMeta(EnumMeta): Extends the default enum metaclass with extra methods for better usability """ - def __contains__(cls, item): + def __contains__(cls, item) -> bool: return isinstance(item, cls) or item in [v.value for v in cls.__members__.values()] diff --git a/packages/examples/cvat/recording-oracle/src/utils/logging.py b/packages/examples/cvat/recording-oracle/src/utils/logging.py index be2c5feba3..e7660eb0d7 100644 --- a/packages/examples/cvat/recording-oracle/src/utils/logging.py +++ b/packages/examples/cvat/recording-oracle/src/utils/logging.py @@ -1,5 +1,5 @@ import logging -from typing import NewType, Optional, Union +from typing import NewType from src.utils.stack import current_function_name @@ -15,7 +15,7 @@ def parse_log_level(level: str) -> LogLevel: def get_function_logger( - parent_logger: Optional[Union[str, logging.Logger]] = None, + parent_logger: str | logging.Logger | None = None, ) -> logging.Logger: if isinstance(parent_logger, str): parent_logger = logging.getLogger(parent_logger) diff --git a/packages/examples/cvat/recording-oracle/src/utils/net.py b/packages/examples/cvat/recording-oracle/src/utils/net.py index 360dafaebb..04ad799510 100644 --- a/packages/examples/cvat/recording-oracle/src/utils/net.py +++ b/packages/examples/cvat/recording-oracle/src/utils/net.py @@ -1,7 +1,7 @@ import ipaddress -def is_ipv4(addr: str, allow_port: bool = True) -> bool: +def is_ipv4(addr: str, *, allow_port: bool = True) -> bool: try: if allow_port: addr = addr.split(":", maxsplit=1)[0] diff --git a/packages/examples/cvat/recording-oracle/src/utils/requests.py b/packages/examples/cvat/recording-oracle/src/utils/requests.py index 785c2cfc87..ef2174f9b9 100644 --- a/packages/examples/cvat/recording-oracle/src/utils/requests.py +++ b/packages/examples/cvat/recording-oracle/src/utils/requests.py @@ -1,4 +1,4 @@ -from typing import Optional, TypeVar +from typing import TypeVar from fastapi import HTTPException @@ -7,11 +7,11 @@ def get_or_404( - obj: Optional[T], + obj: T | None, object_id: V, object_type_name: str, *, - reason: Optional[str] = None, + reason: str | None = None, ) -> T: if obj is None: raise HTTPException( diff --git a/packages/examples/cvat/recording-oracle/src/utils/webhooks.py b/packages/examples/cvat/recording-oracle/src/utils/webhooks.py index e6d39202a3..a1bf128887 100644 --- a/packages/examples/cvat/recording-oracle/src/utils/webhooks.py +++ b/packages/examples/cvat/recording-oracle/src/utils/webhooks.py @@ -1,5 +1,4 @@ from datetime import datetime -from typing import Dict, Optional, Tuple from src.chain.web3 import sign_message from src.core.oracle_events import parse_event @@ -11,8 +10,8 @@ def prepare_outgoing_webhook_body( chain_id: Networks, event_type: str, event_data: dict, - timestamp: Optional[datetime], -) -> Dict: + timestamp: datetime | None, +) -> dict: body = {"escrow_address": escrow_address, "chain_id": chain_id} if timestamp: @@ -29,11 +28,11 @@ def prepare_outgoing_webhook_body( def prepare_signed_message( - escrow_address: str, + escrow_address: str, # noqa: ARG001 chain_id: Networks, - message: Optional[str] = None, - body: Optional[dict] = None, -) -> Tuple[str, str]: + message: str | None = None, + body: dict | None = None, +) -> tuple[str, str]: """ Sign the message with the service identity. Optionally, can serialize the input structure. diff --git a/packages/examples/cvat/recording-oracle/src/validation/__init__.py b/packages/examples/cvat/recording-oracle/src/validation/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/packages/examples/cvat/recording-oracle/src/validation/annotation_matching.py b/packages/examples/cvat/recording-oracle/src/validation/annotation_matching.py index fa3a965f7f..b5d7855a85 100644 --- a/packages/examples/cvat/recording-oracle/src/validation/annotation_matching.py +++ b/packages/examples/cvat/recording-oracle/src/validation/annotation_matching.py @@ -1,5 +1,6 @@ import itertools -from typing import Callable, List, NamedTuple, Sequence, Tuple, TypeVar +from collections.abc import Callable, Sequence +from typing import NamedTuple, TypeVar import numpy as np from scipy.optimize import linear_sum_assignment @@ -77,10 +78,10 @@ def point_to_bbox_cmp( class MatchResult(NamedTuple): - matches: List[Tuple[Annotation, Annotation]] - mispred: List[Tuple[Annotation, Annotation]] - a_extra: List[Annotation] - b_extra: List[Annotation] + matches: list[tuple[Annotation, Annotation]] + mispred: list[tuple[Annotation, Annotation]] + a_extra: list[Annotation] + b_extra: list[Annotation] def match_annotations( @@ -121,7 +122,7 @@ def match_annotations( a_unmatched = [] b_unmatched = [] - for a_idx, b_idx in zip(a_matches, b_matches): + for a_idx, b_idx in zip(a_matches, b_matches, strict=False): dist = distances[a_idx, b_idx] if dist > 1 - min_similarity or dist == 1: if a_idx < len(a_anns): diff --git a/packages/examples/cvat/recording-oracle/src/validation/dataset_comparison.py b/packages/examples/cvat/recording-oracle/src/validation/dataset_comparison.py index cce97e89aa..168d01c820 100644 --- a/packages/examples/cvat/recording-oracle/src/validation/dataset_comparison.py +++ b/packages/examples/cvat/recording-oracle/src/validation/dataset_comparison.py @@ -2,17 +2,15 @@ import itertools from abc import ABCMeta, abstractmethod -from typing import Callable, Dict, Optional, Sequence, Set, Tuple, Union +from typing import TYPE_CHECKING import datumaro as dm import numpy as np from attrs import define, field -from datumaro.util.annotation_util import BboxCoords from src.core.config import Config from src.core.validation_errors import TooFewGtError - -from .annotation_matching import ( +from src.validation.annotation_matching import ( Bbox, MatchResult, Point, @@ -21,19 +19,24 @@ point_to_bbox_cmp, ) +if TYPE_CHECKING: + from collections.abc import Callable, Sequence + + from datumaro.util.annotation_util import BboxCoords + class SimilarityFunction(metaclass=ABCMeta): "A function to compute similarity between 2 annotations" - def __call__(self, gt_ann: dm.Annotation, ds_ann: dm.Annotation) -> float: - ... + @abstractmethod + def __call__(self, gt_ann: dm.Annotation, ds_ann: dm.Annotation) -> float: ... class CachedSimilarityFunction(SimilarityFunction): def __init__( - self, sim_fn: Callable, *, cache: Optional[Dict[Tuple[int, int], float]] = None + self, sim_fn: Callable, *, cache: dict[tuple[int, int], float] | None = None ) -> None: - self.cache: Dict[Tuple[int, int], float] = cache or {} + self.cache: dict[tuple[int, int], float] = cache or {} self.sim_fn = sim_fn def __call__(self, gt_ann: dm.Annotation, ds_ann: dm.Annotation) -> float: @@ -56,9 +59,9 @@ def clear_cache(self): @define class DatasetComparator(metaclass=ABCMeta): _min_similarity_threshold: float - _gt_weights: Dict[str, float] = field(factory=dict) + _gt_weights: dict[str, float] = field(factory=dict) - failed_gts: Set[str] = field(factory=set, init=False) + failed_gts: set[str] = field(factory=set, init=False) "Recorded list of failed GT samples, available after compare() call" def compare(self, gt_dataset: dm.Dataset, ds_dataset: dm.Dataset) -> float: @@ -106,7 +109,7 @@ def compare(self, gt_dataset: dm.Dataset, ds_dataset: dm.Dataset) -> float: dataset_failed_gts.add(gt_sample.id) if dataset_excluded_gts_count == len(gt_dataset): - raise TooFewGtError() + raise TooFewGtError dataset_accuracy = 0 if dataset_total_anns_to_compare: @@ -119,14 +122,13 @@ def compare(self, gt_dataset: dm.Dataset, ds_dataset: dm.Dataset) -> float: @abstractmethod def compare_sample_annotations( self, gt_sample: dm.DatasetItem, ds_sample: dm.DatasetItem, *, similarity_threshold: float - ) -> Tuple[MatchResult, SimilarityFunction]: - ... + ) -> tuple[MatchResult, SimilarityFunction]: ... class BboxDatasetComparator(DatasetComparator): def compare_sample_annotations( self, gt_sample: dm.DatasetItem, ds_sample: dm.DatasetItem, *, similarity_threshold: float - ) -> Tuple[MatchResult, SimilarityFunction]: + ) -> tuple[MatchResult, SimilarityFunction]: similarity_fn = CachedSimilarityFunction(bbox_iou) ds_boxes = [ @@ -153,7 +155,7 @@ def compare_sample_annotations( class PointsDatasetComparator(DatasetComparator): def compare_sample_annotations( self, gt_sample: dm.DatasetItem, ds_sample: dm.DatasetItem, *, similarity_threshold: float - ) -> Tuple[MatchResult, SimilarityFunction]: + ) -> tuple[MatchResult, SimilarityFunction]: similarity_fn = CachedSimilarityFunction(point_to_bbox_cmp) ds_points = [ @@ -186,8 +188,8 @@ def compare_sample_annotations( @define class SkeletonDatasetComparator(DatasetComparator): - _skeleton_info: Dict[int, _SkeletonInfo] = field(factory=dict, init=False) - _categories: Optional[dm.CategoriesInfo] = field(default=None, init=False) + _skeleton_info: dict[int, _SkeletonInfo] = field(factory=dict, init=False) + _categories: dm.CategoriesInfo | None = field(default=None, init=False) # TODO: find better strategy for sigma estimation _oks_sigma: float = Config.validation.default_oks_sigma @@ -198,7 +200,7 @@ def compare(self, gt_dataset: dm.Dataset, ds_dataset: dm.Dataset) -> float: def compare_sample_annotations( self, gt_sample: dm.DatasetItem, ds_sample: dm.DatasetItem, *, similarity_threshold: float - ) -> Tuple[MatchResult, SimilarityFunction]: + ) -> tuple[MatchResult, SimilarityFunction]: return self._match_skeletons( gt_sample, ds_sample, similarity_threshold=similarity_threshold ) @@ -220,7 +222,7 @@ def _get_skeleton_info(self, skeleton_label_id: int) -> _SkeletonInfo: def _match_skeletons( self, item_a: dm.DatasetItem, item_b: dm.DatasetItem, *, similarity_threshold: float - ) -> Tuple[MatchResult, SimilarityFunction]: + ) -> tuple[MatchResult, SimilarityFunction]: a_skeletons = [a for a in item_a.annotations if isinstance(a, dm.Skeleton)] b_skeletons = [a for a in item_b.annotations if isinstance(a, dm.Skeleton)] @@ -303,7 +305,7 @@ def _match_skeletons( def _instance_bbox( self, instance_anns: Sequence[dm.Annotation] - ) -> Tuple[float, float, float, float]: + ) -> tuple[float, float, float, float]: return dm.ops.max_bbox( a.get_bbox() if isinstance(a, dm.Skeleton) else a for a in instance_anns @@ -336,11 +338,11 @@ def _compute_oks( a: dm.Points, b: dm.Points, *, - sigma: Union[float, np.ndarray] = 0.1, - bbox: Optional[BboxCoords] = None, - scale: Union[None, float, np.ndarray] = None, - visibility_a: Union[None, bool, Sequence[bool]] = None, - visibility_b: Union[None, bool, Sequence[bool]] = None, + sigma: float | np.ndarray = 0.1, + bbox: BboxCoords | None = None, + scale: None | float | np.ndarray = None, + visibility_a: None | bool | Sequence[bool] = None, + visibility_b: None | bool | Sequence[bool] = None, ) -> float: """ Computes Object Keypoint Similarity metric for a pair of point sets. @@ -353,12 +355,12 @@ def _compute_oks( return 0 if visibility_a is None: - visibility_a = np.full(len(p1), True) + visibility_a = np.full(len(p1), fill_value=True) else: visibility_a = np.asarray(visibility_a, dtype=bool) if visibility_b is None: - visibility_b = np.full(len(p2), True) + visibility_b = np.full(len(p2), fill_value=True) else: visibility_b = np.asarray(visibility_b, dtype=bool) diff --git a/packages/examples/cvat/recording-oracle/src/validators/__init__.py b/packages/examples/cvat/recording-oracle/src/validators/__init__.py index c8ce69d5f2..4d9d0279c8 100644 --- a/packages/examples/cvat/recording-oracle/src/validators/__init__.py +++ b/packages/examples/cvat/recording-oracle/src/validators/__init__.py @@ -1,4 +1,4 @@ -""" Validation utils""" +"""Validation utils""" class ValidationResult: @@ -7,7 +7,7 @@ class ValidationResult: It encapsulates validation logic and helping during generating response body """ - def __init__(self): + def __init__(self) -> None: self.is_valid = True self.errors = [] diff --git a/packages/examples/cvat/recording-oracle/src/validators/validation.py b/packages/examples/cvat/recording-oracle/src/validators/validation.py index c8ce69d5f2..4d9d0279c8 100644 --- a/packages/examples/cvat/recording-oracle/src/validators/validation.py +++ b/packages/examples/cvat/recording-oracle/src/validators/validation.py @@ -1,4 +1,4 @@ -""" Validation utils""" +"""Validation utils""" class ValidationResult: @@ -7,7 +7,7 @@ class ValidationResult: It encapsulates validation logic and helping during generating response body """ - def __init__(self): + def __init__(self) -> None: self.is_valid = True self.errors = [] diff --git a/packages/examples/cvat/recording-oracle/tests/conftest.py b/packages/examples/cvat/recording-oracle/tests/conftest.py index e027999628..f39f40a8ca 100644 --- a/packages/examples/cvat/recording-oracle/tests/conftest.py +++ b/packages/examples/cvat/recording-oracle/tests/conftest.py @@ -1,4 +1,4 @@ -from typing import Generator +from collections.abc import Generator import pytest from fastapi.testclient import TestClient diff --git a/packages/examples/cvat/recording-oracle/tests/integration/chain/test_escrow.py b/packages/examples/cvat/recording-oracle/tests/integration/chain/test_escrow.py index b0565f5ee9..2bcb156c87 100644 --- a/packages/examples/cvat/recording-oracle/tests/integration/chain/test_escrow.py +++ b/packages/examples/cvat/recording-oracle/tests/integration/chain/test_escrow.py @@ -92,15 +92,16 @@ def test_validate_escrow_invalid_status(self): validate_escrow(self.w3.eth.chain_id, escrow_address) def test_get_escrow_manifest(self): - with patch("src.chain.escrow.get_escrow") as mock_get_escrow, patch( - "src.chain.escrow.StorageUtils.download_file_from_url" - ) as mock_download: + with ( + patch("src.chain.escrow.get_escrow") as mock_get_escrow, + patch("src.chain.escrow.StorageUtils.download_file_from_url") as mock_download, + ): mock_download.return_value = json.dumps({"title": "test"}).encode() mock_get_escrow.return_value = self.escrow() manifest = get_escrow_manifest(self.network_config.chain_id, self.escrow_address) - self.assertIsInstance(manifest, dict) - self.assertIsNotNone(manifest) + assert isinstance(manifest, dict) + assert manifest is not None def test_get_encrypted_escrow_manifest(self): with ( @@ -121,13 +122,13 @@ def test_get_encrypted_escrow_manifest(self): encrypted_manifest = EncryptionUtils.encrypt( original_manifest, public_keys=[PGP_PUBLIC_KEY1, PGP_PUBLIC_KEY2] ) - self.assertNotEqual(encrypted_manifest, original_manifest) + assert encrypted_manifest != original_manifest mock_download.return_value = encrypted_manifest.encode() downloaded_manifest_content = get_escrow_manifest( self.network_config.chain_id, self.escrow_address ) - self.assertDictEqual(downloaded_manifest_content, original_manifest_content) + assert downloaded_manifest_content == original_manifest_content def test_store_results(self): escrow_address = create_escrow(self.w3) @@ -136,38 +137,37 @@ def test_store_results(self): results = store_results( self.w3.eth.chain_id, escrow_address, DEFAULT_MANIFEST_URL, DEFAULT_HASH ) - self.assertIsNone(results) + assert results is None intermediate_results_url = get_intermediate_results_url(self.w3, escrow_address) - self.assertEqual(intermediate_results_url, DEFAULT_MANIFEST_URL) + assert intermediate_results_url == DEFAULT_MANIFEST_URL def test_store_results_invalid_url(self): escrow_address = create_escrow(self.w3) with patch("src.chain.escrow.get_web3") as mock_function: mock_function.return_value = self.w3 - with self.assertRaises(EscrowClientError) as error: + with pytest.raises(EscrowClientError, match="Invalid URL: invalid_url"): store_results(self.w3.eth.chain_id, escrow_address, "invalid_url", DEFAULT_HASH) - self.assertEqual(f"Invalid URL: invalid_url", str(error.exception)) def test_store_results_invalid_hash(self): escrow_address = create_escrow(self.w3) with patch("src.chain.escrow.get_web3") as mock_function: mock_function.return_value = self.w3 - with self.assertRaises(EscrowClientError) as error: + with pytest.raises(EscrowClientError, match="Invalid empty hash"): store_results(self.w3.eth.chain_id, escrow_address, DEFAULT_MANIFEST_URL, "") - self.assertEqual(f"Invalid empty hash", str(error.exception)) def test_get_reputation_oracle_address(self): escrow_address = create_escrow(self.w3) - with patch("src.chain.escrow.get_web3") as mock_get_web3, patch( - "src.chain.escrow.get_escrow" - ) as mock_get_escrow: + with ( + patch("src.chain.escrow.get_web3") as mock_get_web3, + patch("src.chain.escrow.get_escrow") as mock_get_escrow, + ): mock_get_web3.return_value = self.w3 mock_escrow = MagicMock() mock_escrow.reputation_oracle = REPUTATION_ORACLE_ADDRESS mock_get_escrow.return_value = mock_escrow address = get_reputation_oracle_address(self.w3.eth.chain_id, escrow_address) - self.assertIsInstance(address, str) - self.assertIsNotNone(address) + assert isinstance(address, str) + assert address is not None def test_get_reputation_oracle_address_invalid_address(self): with patch("src.chain.escrow.get_web3") as mock_function: diff --git a/packages/examples/cvat/recording-oracle/tests/integration/chain/test_kvstore.py b/packages/examples/cvat/recording-oracle/tests/integration/chain/test_kvstore.py index 245d50594e..081221c3d0 100644 --- a/packages/examples/cvat/recording-oracle/tests/integration/chain/test_kvstore.py +++ b/packages/examples/cvat/recording-oracle/tests/integration/chain/test_kvstore.py @@ -47,7 +47,7 @@ def test_get_reputation_oracle_url(self): mock_leader.return_value = MagicMock(webhook_url=DEFAULT_MANIFEST_URL) reputation_url = get_reputation_oracle_url(self.w3.eth.chain_id, escrow_address) - self.assertEqual(reputation_url, DEFAULT_MANIFEST_URL) + assert reputation_url == DEFAULT_MANIFEST_URL def test_get_reputation_oracle_url_invalid_escrow(self): with patch("src.chain.kvstore.get_web3") as mock_function: @@ -72,21 +72,20 @@ def test_get_reputation_oracle_url_invalid_address(self): reputation_url = get_reputation_oracle_url( self.w3.eth.chain_id, REPUTATION_ORACLE_ADDRESS ) - self.assertEqual(reputation_url, "") + assert reputation_url == "" def test_get_role_by_address(self): store_kvstore_value("role", "Reputation Oracle") with patch("src.chain.kvstore.get_web3") as mock_function: mock_function.return_value = self.w3 reputation_url = get_role_by_address(self.w3.eth.chain_id, REPUTATION_ORACLE_ADDRESS) - self.assertEqual(reputation_url, "Reputation Oracle") + assert reputation_url == "Reputation Oracle" def test_get_role_by_address_invalid_escrow(self): with patch("src.chain.kvstore.get_web3") as mock_function: mock_function.return_value = self.w3 - with self.assertRaises(KVStoreClientError) as error: + with pytest.raises(KVStoreClientError, match="Invalid address: invalid_address"): get_role_by_address(self.w3.eth.chain_id, "invalid_address") - self.assertEqual(f"Invalid address: invalid_address", str(error.exception)) def test_get_role_by_address_invalid_address(self): create_escrow(self.w3) @@ -94,7 +93,7 @@ def test_get_role_by_address_invalid_address(self): with patch("src.chain.kvstore.get_web3") as mock_function: mock_function.return_value = self.w3 reputation_url = get_role_by_address(self.w3.eth.chain_id, REPUTATION_ORACLE_ADDRESS) - self.assertEqual(reputation_url, "") + assert reputation_url == "" def test_store_public_key(self): PGP_PUBLIC_KEY_URL_1 = "http://pgp-public-key-url-1" @@ -114,7 +113,7 @@ def get_file_url_and_verify_hash(*args, **kwargs): hash_ = store["public_key_hash"] if hash_ != hash(public_key): - raise KVStoreClientError(f"Invalid hash") + raise KVStoreClientError("Invalid hash") return public_key @@ -141,7 +140,7 @@ def set_file_url_and_hash(url: str, key: str): mock_web3.return_value = self.w3 kvstore_client = KVStoreClient(self.w3) - self.assertIsNone(kvstore_client.get_file_url_and_verify_hash(LocalhostConfig.addr)) + assert kvstore_client.get_file_url_and_verify_hash(LocalhostConfig.addr) is None # check that public key will be set to KVStore at first time with patch( @@ -150,9 +149,9 @@ def set_file_url_and_hash(url: str, key: str): mock_set_file_url_and_hash.side_effect = set_file_url_and_hash register_in_kvstore() mock_set_file_url_and_hash.assert_called_once() - self.assertEquals( - kvstore_client.get_file_url_and_verify_hash(LocalhostConfig.addr), - PGP_PUBLIC_KEY_URL_1, + assert ( + kvstore_client.get_file_url_and_verify_hash(LocalhostConfig.addr) + == PGP_PUBLIC_KEY_URL_1 ) # check that the same public key URL is not written to KVStore a second time @@ -163,7 +162,8 @@ def set_file_url_and_hash(url: str, key: str): register_in_kvstore() mock_set_file_url_and_hash.assert_not_called() - # check that public key URL and hash will be updated in KVStore if previous hash is outdated/corrupted + # check that public key URL and hash will be updated in KVStore + # if previous hash is outdated/corrupted with patch( "human_protocol_sdk.kvstore.KVStoreClient.set_file_url_and_hash", Mock() ) as mock_set_file_url_and_hash: @@ -171,9 +171,10 @@ def set_file_url_and_hash(url: str, key: str): store["public_key_hash"] = "corrupted_hash" register_in_kvstore() mock_set_file_url_and_hash.assert_called_once() - self.assertNotEquals(store["public_key_hash"], "corrupted_hash") + assert store["public_key_hash"] != "corrupted_hash" - # check that a new public key URL will be written to KVStore when an outdated URL is stored there + # check that a new public key URL will be written to KVStore + # when an outdated URL is stored there with ( patch( "src.core.config.Config.encryption_config.pgp_public_key_url", @@ -186,7 +187,7 @@ def set_file_url_and_hash(url: str, key: str): mock_set_file_url_and_hash.side_effect = set_file_url_and_hash register_in_kvstore() mock_set_file_url_and_hash.assert_called_once() - self.assertEquals( - kvstore_client.get_file_url_and_verify_hash(LocalhostConfig.addr), - PGP_PUBLIC_KEY_URL_2, + assert ( + kvstore_client.get_file_url_and_verify_hash(LocalhostConfig.addr) + == PGP_PUBLIC_KEY_URL_2 ) diff --git a/packages/examples/cvat/recording-oracle/tests/integration/chain/test_web3.py b/packages/examples/cvat/recording-oracle/tests/integration/chain/test_web3.py index 4e37ea8a45..c6a97fd696 100644 --- a/packages/examples/cvat/recording-oracle/tests/integration/chain/test_web3.py +++ b/packages/examples/cvat/recording-oracle/tests/integration/chain/test_web3.py @@ -1,6 +1,7 @@ import unittest from unittest.mock import patch +import pytest from human_protocol_sdk.constants import ChainId from web3 import HTTPProvider, Web3 from web3.middleware import construct_sign_and_send_raw_middleware @@ -31,9 +32,9 @@ class PolygonMainnetConfig: with patch("src.chain.web3.Config.polygon_mainnet", PolygonMainnetConfig): w3 = get_web3(ChainId.POLYGON.value) - self.assertIsInstance(w3, Web3) - self.assertEqual(w3.eth.default_account, DEFAULT_GAS_PAYER) - self.assertEqual(w3.manager._provider.endpoint_uri, PolygonMainnetConfig.rpc_api) + assert isinstance(w3, Web3) + assert w3.eth.default_account == DEFAULT_GAS_PAYER + assert w3.manager._provider.endpoint_uri == PolygonMainnetConfig.rpc_api def test_get_web3_amoy(self): class PolygonAmoyConfig: @@ -43,67 +44,64 @@ class PolygonAmoyConfig: with patch("src.chain.web3.Config.polygon_amoy", PolygonAmoyConfig): w3 = get_web3(ChainId.POLYGON_AMOY.value) - self.assertIsInstance(w3, Web3) - self.assertEqual(w3.eth.default_account, DEFAULT_GAS_PAYER) - self.assertEqual(w3.manager._provider.endpoint_uri, PolygonAmoyConfig.rpc_api) + assert isinstance(w3, Web3) + assert w3.eth.default_account == DEFAULT_GAS_PAYER + assert w3.manager._provider.endpoint_uri == PolygonAmoyConfig.rpc_api def test_get_web3_localhost(self): w3 = get_web3(ChainId.LOCALHOST.value) - self.assertIsInstance(w3, Web3) - self.assertEqual(w3.eth.default_account, DEFAULT_GAS_PAYER) - self.assertEqual(w3.manager._provider.endpoint_uri, LocalhostConfig.rpc_api) + assert isinstance(w3, Web3) + assert w3.eth.default_account == DEFAULT_GAS_PAYER + assert w3.manager._provider.endpoint_uri == LocalhostConfig.rpc_api def test_get_web3_invalid_chain_id(self): - with self.assertRaises(ValueError) as error: - w3 = get_web3(1234) - self.assertEqual( - "1234 is not in available list of networks.", - str(error.exception), - ) + with pytest.raises(ValueError, match="1234 is not in available list of networks."): + get_web3(1234) def test_sign_message_polygon(self): - with patch("src.chain.web3.get_web3") as mock_function, patch( - "src.chain.web3.Config.polygon_mainnet.private_key", - DEFAULT_GAS_PAYER_PRIV, + with ( + patch("src.chain.web3.get_web3") as mock_function, + patch( + "src.chain.web3.Config.polygon_mainnet.private_key", + DEFAULT_GAS_PAYER_PRIV, + ), ): mock_function.return_value = self.w3 signed_message, _ = sign_message(ChainId.POLYGON.value, "message") - self.assertEqual(signed_message, SIGNATURE) + assert signed_message == SIGNATURE def test_sign_message_amoy(self): - with patch("src.chain.web3.get_web3") as mock_function, patch( - "src.chain.web3.Config.polygon_amoy.private_key", - DEFAULT_GAS_PAYER_PRIV, + with ( + patch("src.chain.web3.get_web3") as mock_function, + patch( + "src.chain.web3.Config.polygon_amoy.private_key", + DEFAULT_GAS_PAYER_PRIV, + ), ): mock_function.return_value = self.w3 signed_message, _ = sign_message(ChainId.POLYGON_AMOY.value, "message") - self.assertEqual(signed_message, SIGNATURE) + assert signed_message == SIGNATURE def test_sign_message_invalid_chain_id(self): - with self.assertRaises(ValueError) as error: + with pytest.raises(ValueError, match="1234 is not in available list of networks."): sign_message(1234, "message") - self.assertEqual( - "1234 is not in available list of networks.", - str(error.exception), - ) def test_recover_signer(self): with patch("src.chain.web3.get_web3") as mock_function: mock_function.return_value = self.w3 signer = recover_signer(ChainId.POLYGON.value, "message", SIGNATURE) - self.assertEqual(signer, DEFAULT_GAS_PAYER) + assert signer == DEFAULT_GAS_PAYER def test_recover_signer_invalid_signature(self): with patch("src.chain.web3.get_web3") as mock_function: mock_function.return_value = self.w3 signer = recover_signer(ChainId.POLYGON.value, "test", SIGNATURE) - self.assertNotEqual(signer, DEFAULT_GAS_PAYER) + assert signer != DEFAULT_GAS_PAYER def test_validate_address(self): address = validate_address(DEFAULT_GAS_PAYER) - self.assertEqual(address, DEFAULT_GAS_PAYER) + assert address == DEFAULT_GAS_PAYER def test_validate_address_invalid_address(self): - with self.assertRaises(ValueError) as error: + with pytest.raises(ValueError, match="invalid_address is not a correct Web3 address"): validate_address("invalid_address") - self.assertEqual(f"invalid_address is not a correct Web3 address", str(error.exception)) diff --git a/packages/examples/cvat/recording-oracle/tests/integration/cron/test_process_exchange_oracle_webhooks.py b/packages/examples/cvat/recording-oracle/tests/integration/cron/test_process_exchange_oracle_webhooks.py index 7204d319da..ab90d5cad2 100644 --- a/packages/examples/cvat/recording-oracle/tests/integration/cron/test_process_exchange_oracle_webhooks.py +++ b/packages/examples/cvat/recording-oracle/tests/integration/cron/test_process_exchange_oracle_webhooks.py @@ -63,8 +63,8 @@ def test_process_exchange_oracle_webhook(self): updated_webhook = ( self.session.execute(select(Webhook).where(Webhook.id == webhook.id)).scalars().first() ) - self.assertEqual(updated_webhook.status, OracleWebhookStatuses.completed) - self.assertEqual(updated_webhook.attempts, 1) + assert updated_webhook.status == OracleWebhookStatuses.completed + assert updated_webhook.attempts == 1 def test_process_recording_oracle_webhooks_invalid_escrow_address(self): escrow_address = "invalid_address" @@ -83,8 +83,8 @@ def test_process_recording_oracle_webhooks_invalid_escrow_address(self): self.session.execute(select(Webhook).where(Webhook.id == webhook.id)).scalars().first() ) - self.assertEqual(updated_webhook.status, OracleWebhookStatuses.pending.value) - self.assertEqual(updated_webhook.attempts, 1) + assert updated_webhook.status == OracleWebhookStatuses.pending.value + assert updated_webhook.attempts == 1 def test_process_recording_oracle_webhooks_invalid_escrow_balance(self): escrow_address = create_escrow(self.w3) @@ -103,8 +103,8 @@ def test_process_recording_oracle_webhooks_invalid_escrow_balance(self): self.session.execute(select(Webhook).where(Webhook.id == webhook.id)).scalars().first() ) - self.assertEqual(updated_webhook.status, OracleWebhookStatuses.pending.value) - self.assertEqual(updated_webhook.attempts, 1) + assert updated_webhook.status == OracleWebhookStatuses.pending.value + assert updated_webhook.attempts == 1 @patch("src.chain.escrow.EscrowClient.get_manifest_url") def test_process_job_launcher_webhooks_invalid_manifest_url(self, mock_manifest_url): @@ -123,5 +123,5 @@ def test_process_job_launcher_webhooks_invalid_manifest_url(self, mock_manifest_ self.session.execute(select(Webhook).where(Webhook.id == webhook.id)).scalars().first() ) - self.assertEqual(updated_webhook.status, OracleWebhookStatuses.pending.value) - self.assertEqual(updated_webhook.attempts, 1) + assert updated_webhook.status == OracleWebhookStatuses.pending.value + assert updated_webhook.attempts == 1 diff --git a/packages/examples/cvat/recording-oracle/tests/integration/cron/test_process_reputation_oracle_webhooks.py b/packages/examples/cvat/recording-oracle/tests/integration/cron/test_process_reputation_oracle_webhooks.py index a31c512f85..743c18af0a 100644 --- a/packages/examples/cvat/recording-oracle/tests/integration/cron/test_process_reputation_oracle_webhooks.py +++ b/packages/examples/cvat/recording-oracle/tests/integration/cron/test_process_reputation_oracle_webhooks.py @@ -53,13 +53,15 @@ def get_webhook(self, escrow_address, chain_id, event_data): def test_process_reputation_oracle_webhooks(self): expected_url = "expected_url" - with patch( - "src.crons.process_reputation_oracle_webhooks.httpx.Client.post" - ) as mock_httpx, patch( - "src.crons.process_reputation_oracle_webhooks.get_reputation_oracle_url" - ) as mock_get_repo_url, patch( - "src.crons.process_reputation_oracle_webhooks.prepare_signed_message" - ) as mock_signature: + with ( + patch("src.crons.process_reputation_oracle_webhooks.httpx.Client.post") as mock_httpx, + patch( + "src.crons.process_reputation_oracle_webhooks.get_reputation_oracle_url" + ) as mock_get_repo_url, + patch( + "src.crons.process_reputation_oracle_webhooks.prepare_signed_message" + ) as mock_signature, + ): mock_response = MagicMock() mock_response.raise_for_status.return_value = None mock_httpx.return_value = mock_response @@ -68,7 +70,7 @@ def test_process_reputation_oracle_webhooks(self): chain_id = Networks.localhost.value escrow_address = create_escrow(self.w3) store_kvstore_value("webhook_url", expected_url) - event_data = dict() + event_data = {} mock_signature.return_value = (None, SIGNATURE) webhook = self.get_webhook(escrow_address, chain_id, event_data) @@ -94,8 +96,8 @@ def test_process_reputation_oracle_webhooks(self): "event_type": RecordingOracleEventTypes.task_completed.value, }, ) - self.assertEqual(updated_webhook.status, OracleWebhookStatuses.completed.value) - self.assertEqual(updated_webhook.attempts, 1) + assert updated_webhook.status == OracleWebhookStatuses.completed.value + assert updated_webhook.attempts == 1 def test_process_reputation_oracle_webhooks_invalid_escrow_address(self): chain_id = Networks.localhost.value @@ -112,8 +114,8 @@ def test_process_reputation_oracle_webhooks_invalid_escrow_address(self): self.session.execute(select(Webhook).where(Webhook.id == webhook.id)).scalars().first() ) - self.assertEqual(updated_webhook.status, OracleWebhookStatuses.pending.value) - self.assertEqual(updated_webhook.attempts, 1) + assert updated_webhook.status == OracleWebhookStatuses.pending.value + assert updated_webhook.attempts == 1 def test_process_reputation_oracle_webhooks_invalid_reputation_oracle_url(self): with patch( @@ -131,5 +133,5 @@ def test_process_reputation_oracle_webhooks_invalid_reputation_oracle_url(self): .scalars() .first() ) - self.assertEqual(updated_webhook.status, OracleWebhookStatuses.pending.value) - self.assertEqual(updated_webhook.attempts, 1) + assert updated_webhook.status == OracleWebhookStatuses.pending.value + assert updated_webhook.attempts == 1 diff --git a/packages/examples/cvat/recording-oracle/tests/integration/services/cloud/test_client_service.py b/packages/examples/cvat/recording-oracle/tests/integration/services/cloud/test_client_service.py index 75cf94eec5..b26712d279 100644 --- a/packages/examples/cvat/recording-oracle/tests/integration/services/cloud/test_client_service.py +++ b/packages/examples/cvat/recording-oracle/tests/integration/services/cloud/test_client_service.py @@ -35,7 +35,7 @@ def test_file_operations(self): assert len(client.list_files()) == 0 file_name = "test_file" - data = "this is a test".encode("utf-8") + data = b"this is a test" assert not client.file_exists(file_name) client.create_file(file_name, data) @@ -68,7 +68,7 @@ def test_degenerate_file_operations(self): client.remove_file(invalid_file, bucket=self.bucket_name) def test_degenerate_client(self): - with pytest.raises(EndpointConnectionError): + with pytest.raises(EndpointConnectionError): # noqa: PT012 invalid_client = S3Client( endpoint_url="http://not.an.url:1234", access_key=self.access_key, @@ -76,5 +76,5 @@ def test_degenerate_client(self): ) invalid_client.create_file("test.txt", bucket=self.bucket_name) - with pytest.raises(ValueError): + with pytest.raises(ValueError): # noqa: PT011 S3Client(endpoint_url="nonsense-stuff") diff --git a/packages/examples/cvat/recording-oracle/tests/integration/services/test_webhook_service.py b/packages/examples/cvat/recording-oracle/tests/integration/services/test_webhook_service.py index a7cf4673d4..e7f5d8b828 100644 --- a/packages/examples/cvat/recording-oracle/tests/integration/services/test_webhook_service.py +++ b/packages/examples/cvat/recording-oracle/tests/integration/services/test_webhook_service.py @@ -2,6 +2,7 @@ import unittest import uuid +import pytest from sqlalchemy.exc import IntegrityError from src.core.types import Networks, OracleWebhookStatuses, OracleWebhookTypes @@ -13,14 +14,14 @@ class ServiceIntegrationTest(unittest.TestCase): def setUp(self): self.session = SessionLocal() - self.webhook_kwargs = dict( - session=self.session, - escrow_address="0x1234567890123456789012345678901234567890", - chain_id=Networks.polygon_mainnet.value, - type=OracleWebhookTypes.exchange_oracle, - signature="signature", - event_type="task_finished", - ) + self.webhook_kwargs = { + "session": self.session, + "escrow_address": "0x1234567890123456789012345678901234567890", + "chain_id": Networks.polygon_mainnet.value, + "type": OracleWebhookTypes.exchange_oracle, + "signature": "signature", + "event_type": "task_finished", + } random.seed(42) def tearDown(self): @@ -44,18 +45,18 @@ def test_create_webhook(self): webhook = self.session.query(Webhook).filter_by(id=webhook_id).first() - self.assertEqual(webhook.escrow_address, self.webhook_kwargs["escrow_address"]) - self.assertEqual(webhook.chain_id, self.webhook_kwargs["chain_id"]) - self.assertEqual(webhook.attempts, 0) - self.assertEqual(webhook.signature, self.webhook_kwargs["signature"]) - self.assertEqual(webhook.type, OracleWebhookTypes.exchange_oracle) - self.assertEqual(webhook.status, OracleWebhookStatuses.pending) + assert webhook.escrow_address == self.webhook_kwargs["escrow_address"] + assert webhook.chain_id == self.webhook_kwargs["chain_id"] + assert webhook.attempts == 0 + assert webhook.signature == self.webhook_kwargs["signature"] + assert webhook.type == OracleWebhookTypes.exchange_oracle + assert webhook.status == OracleWebhookStatuses.pending # TODO: check intended fields and verify those def _test_none_webhook_argument(self, argument_name, error_type): kwargs = dict(**self.webhook_kwargs) kwargs[argument_name] = None - with self.assertRaises(error_type): + with pytest.raises(error_type): # noqa: PT012 inbox.create_webhook(**kwargs) self.session.commit() @@ -95,29 +96,29 @@ def test_get_pending_webhooks(self): pending_webhooks = inbox.get_pending_webhooks( self.session, OracleWebhookTypes.exchange_oracle ) - self.assertEqual(len(pending_webhooks), 2) - self.assertEqual(pending_webhooks[0].id, webhook1.id) - self.assertEqual(pending_webhooks[1].id, webhook2.id) + assert len(pending_webhooks) == 2 + assert pending_webhooks[0].id == webhook1.id + assert pending_webhooks[1].id == webhook2.id pending_webhooks = inbox.get_pending_webhooks( self.session, OracleWebhookTypes.reputation_oracle ) - self.assertEqual(len(pending_webhooks), 1) - self.assertEqual(pending_webhooks[0].id, webhook4.id) + assert len(pending_webhooks) == 1 + assert pending_webhooks[0].id == webhook4.id def test_update_webhook_status(self): webhook_id = inbox.create_webhook(**self.webhook_kwargs) webhook = self.session.query(Webhook).filter_by(id=webhook_id).first() - self.assertEqual(webhook.status, OracleWebhookStatuses.pending) + assert webhook.status == OracleWebhookStatuses.pending inbox.update_webhook_status(self.session, webhook_id, OracleWebhookStatuses.completed) webhook = self.session.query(Webhook).filter_by(id=webhook_id).first() - self.assertEqual(webhook.status, OracleWebhookStatuses.completed) + assert webhook.status == OracleWebhookStatuses.completed def test_update_webhook_invalid_status(self): webhook_id = inbox.create_webhook(**self.webhook_kwargs) - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): inbox.update_webhook_status(self.session, webhook_id, "Invalid status") def test_handle_webhook_success(self): @@ -127,23 +128,23 @@ def test_handle_webhook_success(self): webhook = self.session.query(Webhook).filter_by(id=webhook_id).first() - self.assertEqual(webhook.attempts, 1) - self.assertEqual(webhook.status, OracleWebhookStatuses.completed.value) + assert webhook.attempts == 1 + assert webhook.status == OracleWebhookStatuses.completed.value def test_handle_webhook_fail(self): webhook_id = inbox.create_webhook(**self.webhook_kwargs) inbox.handle_webhook_fail(self.session, webhook_id) webhook = self.session.query(Webhook).filter_by(id=webhook_id).first() - self.assertEqual(webhook.attempts, 1) - self.assertEqual(webhook.type, OracleWebhookTypes.exchange_oracle.value) - self.assertEqual(webhook.status, OracleWebhookStatuses.pending.value) + assert webhook.attempts == 1 + assert webhook.type == OracleWebhookTypes.exchange_oracle.value + assert webhook.status == OracleWebhookStatuses.pending.value # assumes Config.webhook_max_retries == 5 - for i in range(4): + for _i in range(4): inbox.handle_webhook_fail(self.session, webhook_id) webhook = self.session.query(Webhook).filter_by(id=webhook_id).first() - self.assertEqual(webhook.attempts, 5) - self.assertEqual(webhook.status, OracleWebhookStatuses.failed.value) + assert webhook.attempts == 5 + assert webhook.status == OracleWebhookStatuses.failed.value diff --git a/packages/examples/cvat/recording-oracle/tests/utils/constants.py b/packages/examples/cvat/recording-oracle/tests/utils/constants.py index 839008f892..2df6f087da 100644 --- a/packages/examples/cvat/recording-oracle/tests/utils/constants.py +++ b/packages/examples/cvat/recording-oracle/tests/utils/constants.py @@ -17,14 +17,20 @@ DEFAULT_MANIFEST_URL = "http://host.docker.internal:9000/manifests/manifest.json" DEFAULT_HASH = "test" -SIGNATURE = "0xa0c5626301e3c198cb91356e492890c0c28db8c37044846134939246911a693c4d7116d04aa4bc40a41077493868b8dd533d30980f6addb28d1b3610a84cb4091c" +SIGNATURE = ( + "0xa0c5626301e3c198cb91356e492890c0c28db8c37044846134939246911a693c" + "4d7116d04aa4bc40a41077493868b8dd533d30980f6addb28d1b3610a84cb4091c" +) WEBHOOK_MESSAGE = { "escrow_address": "0xFE776895f6b00AA53969b20119a4777Ed920676a", "chain_id": 80002, } -WEBHOOK_MESSAGE_SIGNED = "0xfeef93fdb26b9b855da432ca3ccc4425366e656401d3a4f67c2f0cab053fe34c2c7d0ce026d47f9fa100c05c12945f94c60dcb63ed94124fa4c60e888c4156281c" +WEBHOOK_MESSAGE_SIGNED = ( + "0xfeef93fdb26b9b855da432ca3ccc4425366e656401d3a4f67c2f0cab053fe34c" + "2c7d0ce026d47f9fa100c05c12945f94c60dcb63ed94124fa4c60e888c4156281c" +) JOB_REQUESTER_ID = "9001" diff --git a/packages/examples/cvat/recording-oracle/tests/utils/setup_escrow.py b/packages/examples/cvat/recording-oracle/tests/utils/setup_escrow.py index afc161a1ed..b840fbc44b 100644 --- a/packages/examples/cvat/recording-oracle/tests/utils/setup_escrow.py +++ b/packages/examples/cvat/recording-oracle/tests/utils/setup_escrow.py @@ -25,7 +25,7 @@ def create_escrow(web3: Web3): escrow_client = EscrowClient(web3) staking_client.approve_stake(amount) staking_client.stake(amount) - escrow_address = escrow_client.create_and_setup_escrow( + return escrow_client.create_and_setup_escrow( token_address=NETWORKS[ChainId.LOCALHOST]["hmt_address"], trusted_handlers=[web3.eth.default_account], job_requester_id=JOB_REQUESTER_ID, @@ -40,7 +40,6 @@ def create_escrow(web3: Web3): hash=DEFAULT_HASH, ), ) - return escrow_address def fund_escrow(web3: Web3, escrow_address: str): @@ -57,7 +56,6 @@ def bulk_payout(web3: Web3, escrow_address: str, recipient: str, amount: Decimal def get_intermediate_results_url(web3: Web3, escrow_address: str): escrow_client = EscrowClient(web3) - intermediate_results_url = ( + return ( escrow_client._get_escrow_contract(escrow_address).functions.intermediateResultsUrl().call() ) - return intermediate_results_url