diff --git a/.github/workflows/ci-lint-cvat-exchange-oracle.yaml b/.github/workflows/ci-lint-cvat-exchange-oracle.yaml new file mode 100644 index 0000000000..cb43ed75d4 --- /dev/null +++ b/.github/workflows/ci-lint-cvat-exchange-oracle.yaml @@ -0,0 +1,33 @@ +name: CVAT Exchange Oracle Lint + +on: + push: + paths: + - 'packages/examples/cvat/exchange-oracle/**' + - '.github/workflows/ci-lint-cvat-exchange-oracle.yaml' + +env: + WORKING_DIR: ./packages/examples/cvat/exchange-oracle + +defaults: + run: + working-directory: ./packages/examples/cvat/exchange-oracle + +jobs: + cvat-exo-lint: + name: CVAT Exchange Oracle Lint + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + cache: 'pip' + cache-dependency-path: ${{ env.WORKING_DIR }}/poetry.lock + - run: python -m pip install poetry + - uses: actions/setup-python@v5 + with: + python-version: '3.10' + cache: 'poetry' + cache-dependency-path: ${{ env.WORKING_DIR }}/poetry.lock + - run: poetry install --no-root --only dev + - run: poetry run pre-commit run --all-files \ No newline at end of file diff --git a/packages/examples/cvat/exchange-oracle/.pre-commit-config.yaml b/packages/examples/cvat/exchange-oracle/.pre-commit-config.yaml index 08cd8c4879..75cdc64287 100644 --- a/packages/examples/cvat/exchange-oracle/.pre-commit-config.yaml +++ b/packages/examples/cvat/exchange-oracle/.pre-commit-config.yaml @@ -1,11 +1,17 @@ repos: - - repo: https://github.com/psf/black - rev: 22.6.0 + - repo: local hooks: - - id: black - language_version: python3.11 - - repo: https://github.com/PyCQA/isort - rev: 5.12.0 - hooks: - - id: isort - language_version: python3.11 \ No newline at end of file + - id: lint + name: lint + entry: ruff check --fix --unsafe-fixes --show-fixes + language: system + require_serial: true + files: "^packages/examples/cvat/exchange-oracle/.*" + types: [python] + - id: format + name: format + entry: ruff format + require_serial: true + language: system + files: "^packages/examples/cvat/exchange-oracle/.*" + types: [python] diff --git a/packages/examples/cvat/exchange-oracle/alembic.ini b/packages/examples/cvat/exchange-oracle/alembic.ini index 86d1800182..12c6610f1b 100644 --- a/packages/examples/cvat/exchange-oracle/alembic.ini +++ b/packages/examples/cvat/exchange-oracle/alembic.ini @@ -63,11 +63,15 @@ sqlalchemy.url = # on newly generated revision scripts. See the documentation for further # detail and examples -# format using "black" - use the console_scripts runner, against the "black" entrypoint -# hooks = black -# black.type = console_scripts -# black.entrypoint = black -# black.options = -l 79 REVISION_SCRIPT_FILENAME +hooks=ruff, ruff_format, types_update + +ruff.type = exec +ruff.executable = ruff +ruff.options = check --fix --unsafe-fixes REVISION_SCRIPT_FILENAME + +ruff_format.type = exec +ruff_format.executable = ruff +ruff_format.options = format REVISION_SCRIPT_FILENAME # Logging configuration [loggers] diff --git a/packages/examples/cvat/exchange-oracle/alembic/env.py b/packages/examples/cvat/exchange-oracle/alembic/env.py index 5ecaa3c3f2..1328f73712 100644 --- a/packages/examples/cvat/exchange-oracle/alembic/env.py +++ b/packages/examples/cvat/exchange-oracle/alembic/env.py @@ -21,9 +21,7 @@ # from myapp import mymodel # target_metadata = mymodel.Base.metadata -from src.db import Base -from src.models.cvat import Job, Task -from src.models.webhook import Webhook +from src.db import Base # noqa: E402 target_metadata = Base.metadata diff --git a/packages/examples/cvat/exchange-oracle/alembic/versions/0f3fb7bfcbcf_add_escrow_creation_tracking.py b/packages/examples/cvat/exchange-oracle/alembic/versions/0f3fb7bfcbcf_add_escrow_creation_tracking.py index 11c6800a5d..b89bf427c6 100644 --- a/packages/examples/cvat/exchange-oracle/alembic/versions/0f3fb7bfcbcf_add_escrow_creation_tracking.py +++ b/packages/examples/cvat/exchange-oracle/alembic/versions/0f3fb7bfcbcf_add_escrow_creation_tracking.py @@ -5,37 +5,45 @@ Create Date: 2024-04-12 18:51:51.504971 """ -from alembic import op + import sqlalchemy as sa -import sqlalchemy_utils +from alembic import op # revision identifiers, used by Alembic. -revision = '0f3fb7bfcbcf' -down_revision = 'c1e74c227cfe' +revision = "0f3fb7bfcbcf" +down_revision = "c1e74c227cfe" branch_labels = None depends_on = None def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.create_table('escrow_creations', - sa.Column('id', sa.String(), nullable=False), - sa.Column('escrow_address', sa.String(length=42), nullable=False), - sa.Column('chain_id', sa.Integer(), nullable=False), - sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True), - sa.Column('finished_at', sa.DateTime(timezone=True), nullable=True), - sa.Column('total_jobs', sa.Integer(), nullable=False), - sa.PrimaryKeyConstraint('id') + op.create_table( + "escrow_creations", + sa.Column("id", sa.String(), nullable=False), + sa.Column("escrow_address", sa.String(length=42), nullable=False), + sa.Column("chain_id", sa.Integer(), nullable=False), + sa.Column( + "created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True + ), + sa.Column("finished_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("total_jobs", sa.Integer(), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_escrow_creations_escrow_address"), + "escrow_creations", + ["escrow_address"], + unique=False, ) - op.create_index(op.f('ix_escrow_creations_escrow_address'), 'escrow_creations', ['escrow_address'], unique=False) - op.create_index(op.f('ix_escrow_creations_id'), 'escrow_creations', ['id'], unique=False) + op.create_index(op.f("ix_escrow_creations_id"), "escrow_creations", ["id"], unique=False) # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.drop_index(op.f('ix_escrow_creations_id'), table_name='escrow_creations') - op.drop_index(op.f('ix_escrow_creations_escrow_address'), table_name='escrow_creations') - op.drop_table('escrow_creations') + op.drop_index(op.f("ix_escrow_creations_id"), table_name="escrow_creations") + op.drop_index(op.f("ix_escrow_creations_escrow_address"), table_name="escrow_creations") + op.drop_table("escrow_creations") # ### end Alembic commands ### diff --git a/packages/examples/cvat/exchange-oracle/alembic/versions/16ecc586d685_init.py b/packages/examples/cvat/exchange-oracle/alembic/versions/16ecc586d685_init.py index 6ff8ca49b5..1aca7bcaaf 100644 --- a/packages/examples/cvat/exchange-oracle/alembic/versions/16ecc586d685_init.py +++ b/packages/examples/cvat/exchange-oracle/alembic/versions/16ecc586d685_init.py @@ -1,14 +1,14 @@ """init Revision ID: 16ecc586d685 -Revises: +Revises: Create Date: 2023-10-05 13:56:16.966151 """ -from alembic import op + import sqlalchemy as sa -import sqlalchemy_utils +from alembic import op # revision identifiers, used by Alembic. revision = "16ecc586d685" diff --git a/packages/examples/cvat/exchange-oracle/alembic/versions/__init__.py b/packages/examples/cvat/exchange-oracle/alembic/versions/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/packages/examples/cvat/exchange-oracle/alembic/versions/c1e74c227cfe_non_unique_escrows.py b/packages/examples/cvat/exchange-oracle/alembic/versions/c1e74c227cfe_non_unique_escrows.py index f2ea93a456..bc20bbb497 100644 --- a/packages/examples/cvat/exchange-oracle/alembic/versions/c1e74c227cfe_non_unique_escrows.py +++ b/packages/examples/cvat/exchange-oracle/alembic/versions/c1e74c227cfe_non_unique_escrows.py @@ -5,25 +5,23 @@ Create Date: 2024-02-05 22:54:42.478270 """ -from alembic import op -import sqlalchemy as sa -import sqlalchemy_utils +from alembic import op # revision identifiers, used by Alembic. -revision = 'c1e74c227cfe' -down_revision = '16ecc586d685' +revision = "c1e74c227cfe" +down_revision = "16ecc586d685" branch_labels = None depends_on = None def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.drop_constraint('projects_escrow_address_key', 'projects', type_='unique') + op.drop_constraint("projects_escrow_address_key", "projects", type_="unique") # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.create_unique_constraint('projects_escrow_address_key', 'projects', ['escrow_address']) + op.create_unique_constraint("projects_escrow_address_key", "projects", ["escrow_address"]) # ### end Alembic commands ### diff --git a/packages/examples/cvat/exchange-oracle/poetry.lock b/packages/examples/cvat/exchange-oracle/poetry.lock index b301689e1a..136c82ed69 100644 --- a/packages/examples/cvat/exchange-oracle/poetry.lock +++ b/packages/examples/cvat/exchange-oracle/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "aiohttp" @@ -397,52 +397,6 @@ files = [ {file = "bitarray-2.9.2.tar.gz", hash = "sha256:a8f286a51a32323715d77755ed959f94bef13972e9a2fe71b609e40e6d27957e"}, ] -[[package]] -name = "black" -version = "23.12.1" -description = "The uncompromising code formatter." -optional = false -python-versions = ">=3.8" -files = [ - {file = "black-23.12.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e0aaf6041986767a5e0ce663c7a2f0e9eaf21e6ff87a5f95cbf3675bfd4c41d2"}, - {file = "black-23.12.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c88b3711d12905b74206227109272673edce0cb29f27e1385f33b0163c414bba"}, - {file = "black-23.12.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a920b569dc6b3472513ba6ddea21f440d4b4c699494d2e972a1753cdc25df7b0"}, - {file = "black-23.12.1-cp310-cp310-win_amd64.whl", hash = "sha256:3fa4be75ef2a6b96ea8d92b1587dd8cb3a35c7e3d51f0738ced0781c3aa3a5a3"}, - {file = "black-23.12.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:8d4df77958a622f9b5a4c96edb4b8c0034f8434032ab11077ec6c56ae9f384ba"}, - {file = "black-23.12.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:602cfb1196dc692424c70b6507593a2b29aac0547c1be9a1d1365f0d964c353b"}, - {file = "black-23.12.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c4352800f14be5b4864016882cdba10755bd50805c95f728011bcb47a4afd59"}, - {file = "black-23.12.1-cp311-cp311-win_amd64.whl", hash = "sha256:0808494f2b2df923ffc5723ed3c7b096bd76341f6213989759287611e9837d50"}, - {file = "black-23.12.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:25e57fd232a6d6ff3f4478a6fd0580838e47c93c83eaf1ccc92d4faf27112c4e"}, - {file = "black-23.12.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2d9e13db441c509a3763a7a3d9a49ccc1b4e974a47be4e08ade2a228876500ec"}, - {file = "black-23.12.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d1bd9c210f8b109b1762ec9fd36592fdd528485aadb3f5849b2740ef17e674e"}, - {file = "black-23.12.1-cp312-cp312-win_amd64.whl", hash = "sha256:ae76c22bde5cbb6bfd211ec343ded2163bba7883c7bc77f6b756a1049436fbb9"}, - {file = "black-23.12.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1fa88a0f74e50e4487477bc0bb900c6781dbddfdfa32691e780bf854c3b4a47f"}, - {file = "black-23.12.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a4d6a9668e45ad99d2f8ec70d5c8c04ef4f32f648ef39048d010b0689832ec6d"}, - {file = "black-23.12.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b18fb2ae6c4bb63eebe5be6bd869ba2f14fd0259bda7d18a46b764d8fb86298a"}, - {file = "black-23.12.1-cp38-cp38-win_amd64.whl", hash = "sha256:c04b6d9d20e9c13f43eee8ea87d44156b8505ca8a3c878773f68b4e4812a421e"}, - {file = "black-23.12.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3e1b38b3135fd4c025c28c55ddfc236b05af657828a8a6abe5deec419a0b7055"}, - {file = "black-23.12.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4f0031eaa7b921db76decd73636ef3a12c942ed367d8c3841a0739412b260a54"}, - {file = "black-23.12.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:97e56155c6b737854e60a9ab1c598ff2533d57e7506d97af5481141671abf3ea"}, - {file = "black-23.12.1-cp39-cp39-win_amd64.whl", hash = "sha256:dd15245c8b68fe2b6bd0f32c1556509d11bb33aec9b5d0866dd8e2ed3dba09c2"}, - {file = "black-23.12.1-py3-none-any.whl", hash = "sha256:78baad24af0f033958cad29731e27363183e140962595def56423e626f4bee3e"}, - {file = "black-23.12.1.tar.gz", hash = "sha256:4ce3ef14ebe8d9509188014d96af1c456a910d5b5cbf434a09fef7e024b3d0d5"}, -] - -[package.dependencies] -click = ">=8.0.0" -mypy-extensions = ">=0.4.3" -packaging = ">=22.0" -pathspec = ">=0.9.0" -platformdirs = ">=2" -tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} -typing-extensions = {version = ">=4.0.1", markers = "python_version < \"3.11\""} - -[package.extras] -colorama = ["colorama (>=0.4.3)"] -d = ["aiohttp (>=3.7.4)", "aiohttp (>=3.7.4,!=3.9.0)"] -jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] -uvloop = ["uvloop (>=0.15.2)"] - [[package]] name = "boto3" version = "1.34.30" @@ -1874,20 +1828,6 @@ files = [ {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, ] -[[package]] -name = "isort" -version = "5.13.2" -description = "A Python utility / library to sort Python imports." -optional = false -python-versions = ">=3.8.0" -files = [ - {file = "isort-5.13.2-py3-none-any.whl", hash = "sha256:8ca5e72a8d85860d5a3fa69b8745237f2939afe12dbf656afbcb47fe72d947a6"}, - {file = "isort-5.13.2.tar.gz", hash = "sha256:48fdfcb9face5d58a4f6dde2e72a1fb8dcaf8ab26f95ab49fab84c2ddefb0109"}, -] - -[package.extras] -colors = ["colorama (>=0.4.6)"] - [[package]] name = "jmespath" version = "1.0.1" @@ -2470,17 +2410,6 @@ files = [ {file = "multidict-6.0.4.tar.gz", hash = "sha256:3666906492efb76453c0e7b97f2cf459b0682e7402c0489a95484965dbc1da49"}, ] -[[package]] -name = "mypy-extensions" -version = "1.0.0" -description = "Type system extensions for programs checked with the mypy type checker." -optional = false -python-versions = ">=3.5" -files = [ - {file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"}, - {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, -] - [[package]] name = "networkx" version = "3.2.1" @@ -2766,17 +2695,6 @@ files = [ [package.dependencies] regex = ">=2022.3.15" -[[package]] -name = "pathspec" -version = "0.12.1" -description = "Utility library for gitignore style pattern matching of file paths." -optional = false -python-versions = ">=3.8" -files = [ - {file = "pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08"}, - {file = "pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712"}, -] - [[package]] name = "pgpy" version = "0.6.0" @@ -2955,6 +2873,8 @@ files = [ {file = "psycopg2-2.9.9-cp310-cp310-win_amd64.whl", hash = "sha256:426f9f29bde126913a20a96ff8ce7d73fd8a216cfb323b1f04da402d452853c3"}, {file = "psycopg2-2.9.9-cp311-cp311-win32.whl", hash = "sha256:ade01303ccf7ae12c356a5e10911c9e1c51136003a9a1d92f7aa9d010fb98372"}, {file = "psycopg2-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:121081ea2e76729acfb0673ff33755e8703d45e926e416cb59bae3a86c6a4981"}, + {file = "psycopg2-2.9.9-cp312-cp312-win32.whl", hash = "sha256:d735786acc7dd25815e89cc4ad529a43af779db2e25aa7c626de864127e5a024"}, + {file = "psycopg2-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:a7653d00b732afb6fc597e29c50ad28087dcb4fbfb28e86092277a559ae4e693"}, {file = "psycopg2-2.9.9-cp37-cp37m-win32.whl", hash = "sha256:5e0d98cade4f0e0304d7d6f25bbfbc5bd186e07b38eac65379309c4ca3193efa"}, {file = "psycopg2-2.9.9-cp37-cp37m-win_amd64.whl", hash = "sha256:7e2dacf8b009a1c1e843b5213a87f7c544b2b042476ed7755be813eaf4e8347a"}, {file = "psycopg2-2.9.9-cp38-cp38-win32.whl", hash = "sha256:ff432630e510709564c01dafdbe996cb552e0b9f3f065eb89bdce5bd31fabf4c"}, @@ -3677,35 +3597,82 @@ files = [ {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:b42169467c42b692c19cf539c38d4602069d8c1505e97b86387fcf7afb766e1d"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-macosx_13_0_arm64.whl", hash = "sha256:07238db9cbdf8fc1e9de2489a4f68474e70dffcb32232db7c08fa61ca0c7c462"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:fff3573c2db359f091e1589c3d7c5fc2f86f5bdb6f24252c2d8e539d4e45f412"}, + {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-manylinux_2_24_aarch64.whl", hash = "sha256:aa2267c6a303eb483de8d02db2871afb5c5fc15618d894300b88958f729ad74f"}, + {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:840f0c7f194986a63d2c2465ca63af8ccbbc90ab1c6001b1978f05119b5e7334"}, + {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:024cfe1fc7c7f4e1aff4a81e718109e13409767e4f871443cbff3dba3578203d"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-win32.whl", hash = "sha256:c69212f63169ec1cfc9bb44723bf2917cbbd8f6191a00ef3410f5a7fe300722d"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-win_amd64.whl", hash = "sha256:cabddb8d8ead485e255fe80429f833172b4cadf99274db39abc080e068cbcc31"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:bef08cd86169d9eafb3ccb0a39edb11d8e25f3dae2b28f5c52fd997521133069"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:b16420e621d26fdfa949a8b4b47ade8810c56002f5389970db4ddda51dbff248"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:25c515e350e5b739842fc3228d662413ef28f295791af5e5110b543cf0b57d9b"}, + {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-manylinux_2_24_aarch64.whl", hash = "sha256:1707814f0d9791df063f8c19bb51b0d1278b8e9a2353abbb676c2f685dee6afe"}, + {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:46d378daaac94f454b3a0e3d8d78cafd78a026b1d71443f4966c696b48a6d899"}, + {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:09b055c05697b38ecacb7ac50bdab2240bfca1a0c4872b0fd309bb07dc9aa3a9"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-win32.whl", hash = "sha256:53a300ed9cea38cf5a2a9b069058137c2ca1ce658a874b79baceb8f892f915a7"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-win_amd64.whl", hash = "sha256:c2a72e9109ea74e511e29032f3b670835f8a59bbdc9ce692c5b4ed91ccf1eedb"}, {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:ebc06178e8821efc9692ea7544aa5644217358490145629914d8020042c24aa1"}, {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:edaef1c1200c4b4cb914583150dcaa3bc30e592e907c01117c08b13a07255ec2"}, {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d176b57452ab5b7028ac47e7b3cf644bcfdc8cacfecf7e71759f7f51a59e5c92"}, + {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-manylinux_2_24_aarch64.whl", hash = "sha256:1dc67314e7e1086c9fdf2680b7b6c2be1c0d8e3a8279f2e993ca2a7545fecf62"}, + {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:3213ece08ea033eb159ac52ae052a4899b56ecc124bb80020d9bbceeb50258e9"}, + {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:aab7fd643f71d7946f2ee58cc88c9b7bfc97debd71dcc93e03e2d174628e7e2d"}, + {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-win32.whl", hash = "sha256:5c365d91c88390c8d0a8545df0b5857172824b1c604e867161e6b3d59a827eaa"}, + {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-win_amd64.whl", hash = "sha256:1758ce7d8e1a29d23de54a16ae867abd370f01b5a69e1a3ba75223eaa3ca1a1b"}, {file = "ruamel.yaml.clib-0.2.8-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:a5aa27bad2bb83670b71683aae140a1f52b0857a2deff56ad3f6c13a017a26ed"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:c58ecd827313af6864893e7af0a3bb85fd529f862b6adbefe14643947cfe2942"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-macosx_12_0_arm64.whl", hash = "sha256:f481f16baec5290e45aebdc2a5168ebc6d35189ae6fea7a58787613a25f6e875"}, + {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-manylinux_2_24_aarch64.whl", hash = "sha256:77159f5d5b5c14f7c34073862a6b7d34944075d9f93e681638f6d753606c6ce6"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:7f67a1ee819dc4562d444bbafb135832b0b909f81cc90f7aa00260968c9ca1b3"}, + {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:4ecbf9c3e19f9562c7fdd462e8d18dd902a47ca046a2e64dba80699f0b6c09b7"}, + {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:87ea5ff66d8064301a154b3933ae406b0863402a799b16e4a1d24d9fbbcbe0d3"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-win32.whl", hash = "sha256:75e1ed13e1f9de23c5607fe6bd1aeaae21e523b32d83bb33918245361e9cc51b"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-win_amd64.whl", hash = "sha256:3f215c5daf6a9d7bbed4a0a4f760f3113b10e82ff4c5c44bec20a68c8014f675"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1b617618914cb00bf5c34d4357c37aa15183fa229b24767259657746c9077615"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:a6a9ffd280b71ad062eae53ac1659ad86a17f59a0fdc7699fd9be40525153337"}, + {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-manylinux_2_24_aarch64.whl", hash = "sha256:305889baa4043a09e5b76f8e2a51d4ffba44259f6b4c72dec8ca56207d9c6fe1"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:700e4ebb569e59e16a976857c8798aee258dceac7c7d6b50cab63e080058df91"}, + {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:e2b4c44b60eadec492926a7270abb100ef9f72798e18743939bdbf037aab8c28"}, + {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:e79e5db08739731b0ce4850bed599235d601701d5694c36570a99a0c5ca41a9d"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-win32.whl", hash = "sha256:955eae71ac26c1ab35924203fda6220f84dce57d6d7884f189743e2abe3a9fbe"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-win_amd64.whl", hash = "sha256:56f4252222c067b4ce51ae12cbac231bce32aee1d33fbfc9d17e5b8d6966c312"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:03d1162b6d1df1caa3a4bd27aa51ce17c9afc2046c31b0ad60a0a96ec22f8001"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:bba64af9fa9cebe325a62fa398760f5c7206b215201b0ec825005f1b18b9bccf"}, + {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-manylinux_2_24_aarch64.whl", hash = "sha256:a1a45e0bb052edf6a1d3a93baef85319733a888363938e1fc9924cb00c8df24c"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:da09ad1c359a728e112d60116f626cc9f29730ff3e0e7db72b9a2dbc2e4beed5"}, + {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:184565012b60405d93838167f425713180b949e9d8dd0bbc7b49f074407c5a8b"}, + {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a75879bacf2c987c003368cf14bed0ffe99e8e85acfa6c0bfffc21a090f16880"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-win32.whl", hash = "sha256:84b554931e932c46f94ab306913ad7e11bba988104c5cff26d90d03f68258cd5"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-win_amd64.whl", hash = "sha256:25ac8c08322002b06fa1d49d1646181f0b2c72f5cbc15a85e80b4c30a544bb15"}, {file = "ruamel.yaml.clib-0.2.8.tar.gz", hash = "sha256:beb2e0404003de9a4cab9753a8805a8fe9320ee6673136ed7f04255fe60bb512"}, ] +[[package]] +name = "ruff" +version = "0.6.0" +description = "An extremely fast Python linter and code formatter, written in Rust." +optional = false +python-versions = ">=3.7" +files = [ + {file = "ruff-0.6.0-py3-none-linux_armv6l.whl", hash = "sha256:92dcce923e5df265781e5fc76f9a1edad52201a7aafe56e586b90988d5239013"}, + {file = "ruff-0.6.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:31b90ff9dc79ed476c04e957ba7e2b95c3fceb76148f2079d0d68a908d2cfae7"}, + {file = "ruff-0.6.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:6d834a9ec9f8287dd6c3297058b3a265ed6b59233db22593379ee38ebc4b9768"}, + {file = "ruff-0.6.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f2089267692696aba342179471831a085043f218706e642564812145df8b8d0d"}, + {file = "ruff-0.6.0-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:aa62b423ee4bbd8765f2c1dbe8f6aac203e0583993a91453dc0a449d465c84da"}, + {file = "ruff-0.6.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7344e1a964b16b1137ea361d6516ce4ee61a0403fa94252a1913ecc1311adcae"}, + {file = "ruff-0.6.0-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:487f3a35c3f33bf82be212ce15dc6278ea854e35573a3f809442f73bec8b2760"}, + {file = "ruff-0.6.0-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:75db409984077a793cf344d499165298a6f65449e905747ac65983b12e3e64b1"}, + {file = "ruff-0.6.0-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:84908bd603533ecf1db456d8fc2665d1f4335d722e84bc871d3bbd2d1116c272"}, + {file = "ruff-0.6.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0f1749a0aef3ec41ed91a0e2127a6ae97d2e2853af16dbd4f3c00d7a3af726c5"}, + {file = "ruff-0.6.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:016fea751e2bcfbbd2f8cb19b97b37b3fd33148e4df45b526e87096f4e17354f"}, + {file = "ruff-0.6.0-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:6ae80f141b53b2e36e230017e64f5ea2def18fac14334ffceaae1b780d70c4f7"}, + {file = "ruff-0.6.0-py3-none-musllinux_1_2_i686.whl", hash = "sha256:eaaaf33ea4b3f63fd264d6a6f4a73fa224bbfda4b438ffea59a5340f4afa2bb5"}, + {file = "ruff-0.6.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:7667ddd1fc688150a7ca4137140867584c63309695a30016880caf20831503a0"}, + {file = "ruff-0.6.0-py3-none-win32.whl", hash = "sha256:ae48365aae60d40865a412356f8c6f2c0be1c928591168111eaf07eaefa6bea3"}, + {file = "ruff-0.6.0-py3-none-win_amd64.whl", hash = "sha256:774032b507c96f0c803c8237ce7d2ef3934df208a09c40fa809c2931f957fe5e"}, + {file = "ruff-0.6.0-py3-none-win_arm64.whl", hash = "sha256:a5366e8c3ae6b2dc32821749b532606c42e609a99b0ae1472cf601da931a048c"}, + {file = "ruff-0.6.0.tar.gz", hash = "sha256:272a81830f68f9bd19d49eaf7fa01a5545c5a2e86f32a9935bb0e4bb9a1db5b8"}, +] + [[package]] name = "s3transfer" version = "0.10.0" @@ -4370,4 +4337,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.10,<3.13" -content-hash = "9c5912834d202daf5607eebb1405e9f13c15bf5770fb5f58567ffd4331698b56" +content-hash = "cca3e02a568c585382d045430da0c2d18234634f2a86bb81a12b6cad52ef226e" diff --git a/packages/examples/cvat/exchange-oracle/pyproject.toml b/packages/examples/cvat/exchange-oracle/pyproject.toml index 2bca744724..bcca6ac50b 100644 --- a/packages/examples/cvat/exchange-oracle/pyproject.toml +++ b/packages/examples/cvat/exchange-oracle/pyproject.toml @@ -28,19 +28,121 @@ pyinstrument = "^4.6.2" [tool.poetry.group.dev.dependencies] -black = "^23.1.0" pre-commit = "^3.0.4" -isort = "^5.12.0" +ruff = "^0.6.0" -[tool.isort] -profile = "black" -forced_separate = ["tests"] -line_length = 100 -skip_gitignore = true # align tool behavior with Black - -[tool.black] +[tool.ruff] line-length = 100 -target-version = ['py310'] +target-version = "py310" + + +[tool.ruff.lint] +select = ["ALL"] +unfixable = [ + "RUF005", # messes up concantenation with numpy structures +] +ignore = [ + "W191", # Rules conflicting with ruff format (https://docs.astral.sh/ruff/formatter/#conflicting-lint-rules) + "E111", # | + "E114", # | + "E117", # | + "D206", # | + "D300", # | + "Q000", # | + "Q001", # | + "Q002", # | + "Q003", # | + "COM812", # | + "COM819", # | + "ISC001", # | + "ISC002", # | + "ANN101", # Method args annotations (mypy will take care of that) + "ANN001", # | + "ANN202", # | + "ANN201", # | + "ANN401", # | + "ANN102", # | + "RUF001", # Allow cyrillic letters in comments + + "B904", # Raise from: modern pythons preserve previous exceptions + "EM", # Forbids using literal strings in exceptions. + # Sujested way of dealing with exceptions increases verbosity + # while giving little to no benefit in readability + "TRY003", # | + "G004", # Forbids using f-strings in logging. This project doesn't rely on lazy % formatting when using logging. + "A003", # Class attribute `id` is shadowing a Python builtin — it's ok in class body + "FIX001", # Forbids using TODOs, but TODOs are useful + "FIX002", # | + "TD001", # | + "TD002", # | + "TD003", # | + "E711", # Allow == None comparisons for sqlalchemy queries + "E712", # Allow == True comparisons for sqlalchemy queries + "PERF203", # Noisy microoptimisation + # Want to resolve eventually, but not now: + "S101", # Allow asserts (there are too many of them right now to fix) + "TRY401", # Checks for excesive logging of exception objects + "G001", # Forbid str.format for logging + "PTH123", # Checks for uses of `os.path.splitext` + "D", # Docstrings + "N806", # Variable in function should be lowercase + "RUF012", # Mutable class attributes should be annotated with `typing.ClassVar` + "SLF001", # Private member accessed + "F811", # Redefinition of unused + "RUF005", # Consider iterable unpacking instead of concatenation + "A002", # Argument is shadowing a Python builtin + "N818", # Exception name should be named with an Error suffix + "TRY002", # Create your own exception + "ANN003", # Missing type annotation for `**kwargs` + "ANN204", # Missing return type annotation for special method + "ERA001", # Found commented-out code + "N801", # Class name should use CapWords convention + "PLR0915", # Too many statements + "F401", # Imported but unused + "PLR2004", # Magic value used in comparison, consider replacing with a constant variable + "ANN002", # Missing type annotation for `*args` + "TRY300", # Consider moving this statement to an `else` block + "C901", # Function is too complex + "PLW2901", # Variable overwritten by assignment target + "PTH118", # Prefer pathlib instead of os.path + "PTH119", # `os.path.basename()` should be replaced by `Path.name` + "PTH122", # `os.path.splitext()` should be replaced by `Path.suffix`, `Path.stem`, and `Path.parent` + "PTH207", # Replace `glob` with `Path.glob` or `Path.rglob` + "UP032", # Upgrades .format() to f-strings. Disabling until it ignores line-length + "C401", # set comprehensions are apparently can be hard to distinguish from dict comprehensions + "RET505", # Unnecessary elif/else statements after/before raise/return. +] + + +[tool.ruff.lint.per-file-ignores] +"tests/*" = [ + "PLR0913", # Annotations and args + "ANN202", # | + "ANN201", # | + "ANN001", # | + "ANN003", # | + "ARG001", # | + "SLF001", # Allow private attrs access + "PLR2004", # Allow magic values + "S", # security + "DTZ005", # allow datetimes without timezones +] +# alembic is not a package in a traditional sense, so putting __init__.py there doesn't make sense +"alembic/*" = ["INP001"] + +[tool.ruff.lint.pep8-naming] +classmethod-decorators = [ + "pydantic.validator", +] + +[tool.ruff.lint.pylint] +max-args = 9 # Lower number might be beneficial to reduce cognitive load. Consider using data containers. + +[tool.ruff.lint.isort] +forced-separate = ["tests"] + +[tool.ruff.lint.flake8-tidy-imports] +ban-relative-imports = "all" [build-system] requires = ["poetry-core"] diff --git a/packages/examples/cvat/exchange-oracle/run.py b/packages/examples/cvat/exchange-oracle/run.py index 8e313c1bf1..92b0fc8a8c 100644 --- a/packages/examples/cvat/exchange-oracle/run.py +++ b/packages/examples/cvat/exchange-oracle/run.py @@ -10,7 +10,7 @@ uvicorn.run( app="src:app", - host="0.0.0.0", + host="0.0.0.0", # noqa: S104 port=int(Config.port), workers=Config.workers_amount, # reload=is_dev, diff --git a/packages/examples/cvat/exchange-oracle/src/chain/__init__.py b/packages/examples/cvat/exchange-oracle/src/chain/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/packages/examples/cvat/exchange-oracle/src/chain/escrow.py b/packages/examples/cvat/exchange-oracle/src/chain/escrow.py index 6b85ebe7c7..1d76214f59 100644 --- a/packages/examples/cvat/exchange-oracle/src/chain/escrow.py +++ b/packages/examples/cvat/exchange-oracle/src/chain/escrow.py @@ -1,8 +1,7 @@ import json -from typing import List from human_protocol_sdk.constants import ChainId, Status -from human_protocol_sdk.encryption import Encryption, EncryptionUtils +from human_protocol_sdk.encryption import Encryption from human_protocol_sdk.escrow import EscrowData, EscrowUtils from human_protocol_sdk.storage import StorageUtils @@ -21,9 +20,11 @@ def validate_escrow( chain_id: int, escrow_address: str, *, - accepted_states: List[Status] = [Status.Pending], + accepted_states: list[Status] | None = None, allow_no_funds: bool = False, ) -> None: + if accepted_states is None: + accepted_states = [Status.Pending] assert accepted_states escrow = get_escrow(chain_id, escrow_address) @@ -36,9 +37,8 @@ def validate_escrow( ) ) - if status == Status.Pending and not allow_no_funds: - if int(escrow.balance) == 0: - raise ValueError("Escrow doesn't have funds") + if status == Status.Pending and not allow_no_funds and int(escrow.balance) == 0: + raise ValueError("Escrow doesn't have funds") def get_escrow_manifest(chain_id: int, escrow_address: str) -> dict: diff --git a/packages/examples/cvat/exchange-oracle/src/chain/web3.py b/packages/examples/cvat/exchange-oracle/src/chain/web3.py index cc427c96b9..e76f77bf95 100644 --- a/packages/examples/cvat/exchange-oracle/src/chain/web3.py +++ b/packages/examples/cvat/exchange-oracle/src/chain/web3.py @@ -71,9 +71,7 @@ def sign_message(chain_id: Networks, message) -> str: def recover_signer(chain_id: Networks, message, signature: str) -> str: w3 = get_web3(chain_id) message_hash = encode_defunct(text=serialize_message(message)) - signer = w3.eth.account.recover_message(message_hash, signature=signature) - - return signer + return w3.eth.account.recover_message(message_hash, signature=signature) def validate_address(escrow_address: str) -> str: diff --git a/packages/examples/cvat/exchange-oracle/src/core/__init__.py b/packages/examples/cvat/exchange-oracle/src/core/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/packages/examples/cvat/exchange-oracle/src/core/annotation_meta.py b/packages/examples/cvat/exchange-oracle/src/core/annotation_meta.py index e19a77efd5..d539787315 100644 --- a/packages/examples/cvat/exchange-oracle/src/core/annotation_meta.py +++ b/packages/examples/cvat/exchange-oracle/src/core/annotation_meta.py @@ -1,5 +1,4 @@ from pathlib import Path -from typing import List from pydantic import BaseModel @@ -15,4 +14,4 @@ class JobMeta(BaseModel): class AnnotationMeta(BaseModel): - jobs: List[JobMeta] + jobs: list[JobMeta] diff --git a/packages/examples/cvat/exchange-oracle/src/core/config.py b/packages/examples/cvat/exchange-oracle/src/core/config.py index 91669459bc..7a6481574b 100644 --- a/packages/examples/cvat/exchange-oracle/src/core/config.py +++ b/packages/examples/cvat/exchange-oracle/src/core/config.py @@ -1,8 +1,10 @@ # pylint: disable=too-few-public-methods,missing-class-docstring -""" Project configuration from env vars """ +"""Project configuration from env vars""" + import inspect import os -from typing import ClassVar, Iterable, Optional +from collections.abc import Iterable +from typing import ClassVar from attrs.converters import to_bool from dotenv import load_dotenv @@ -14,7 +16,7 @@ from src.utils.net import is_ipv4 dotenv_path = os.getenv("DOTENV_PATH", None) -if dotenv_path and not os.path.exists(dotenv_path): +if dotenv_path and not os.path.exists(dotenv_path): # noqa: PTH110 raise FileNotFoundError(dotenv_path) load_dotenv(dotenv_path) @@ -28,22 +30,22 @@ def validate(cls) -> None: class PostgresConfig: port = os.environ.get("PG_PORT", "5432") - host = os.environ.get("PG_HOST", "0.0.0.0") + host = os.environ.get("PG_HOST", "0.0.0.0") # noqa: S104 user = os.environ.get("PG_USER", "admin") password = os.environ.get("PG_PASSWORD", "admin") database = os.environ.get("PG_DB", "exchange_oracle") lock_timeout = int(os.environ.get("PG_LOCK_TIMEOUT", "3000")) # milliseconds @classmethod - def connection_url(cls): + def connection_url(cls) -> str: return f"postgresql://{cls.user}:{cls.password}@{cls.host}:{cls.port}/{cls.database}" class _NetworkConfig: chain_id: ClassVar[int] - rpc_api: ClassVar[Optional[str]] - private_key: ClassVar[Optional[str]] - addr: ClassVar[Optional[str]] + rpc_api: ClassVar[str | None] + private_key: ClassVar[str | None] + addr: ClassVar[str | None] @classmethod def is_configured(cls) -> bool: @@ -154,27 +156,27 @@ class StorageConfig: endpoint_url: ClassVar[str] = os.environ[ "STORAGE_ENDPOINT_URL" ] # TODO: probably should be optional - region: ClassVar[Optional[str]] = os.environ.get("STORAGE_REGION") + region: ClassVar[str | None] = os.environ.get("STORAGE_REGION") results_dir_suffix: ClassVar[str] = os.environ.get("STORAGE_RESULTS_DIR_SUFFIX", "-results") secure: ClassVar[bool] = to_bool(os.environ.get("STORAGE_USE_SSL", "true")) # S3 specific attributes - access_key: ClassVar[Optional[str]] = os.environ.get("STORAGE_ACCESS_KEY") - secret_key: ClassVar[Optional[str]] = os.environ.get("STORAGE_SECRET_KEY") + access_key: ClassVar[str | None] = os.environ.get("STORAGE_ACCESS_KEY") + secret_key: ClassVar[str | None] = os.environ.get("STORAGE_SECRET_KEY") # GCS specific attributes - key_file_path: ClassVar[Optional[str]] = os.environ.get("STORAGE_KEY_FILE_PATH") + key_file_path: ClassVar[str | None] = os.environ.get("STORAGE_KEY_FILE_PATH") @classmethod def get_scheme(cls) -> str: return "https://" if cls.secure else "http://" @classmethod - def provider_endpoint_url(cls): + def provider_endpoint_url(cls) -> str: return f"{cls.get_scheme()}{cls.endpoint_url}" @classmethod - def bucket_url(cls): + def bucket_url(cls) -> str: if is_ipv4(cls.endpoint_url): return f"{cls.get_scheme()}{cls.endpoint_url}/{cls.data_bucket_name}/" else: @@ -188,10 +190,10 @@ class FeaturesConfig: default_export_timeout = int(os.environ.get("DEFAULT_EXPORT_TIMEOUT", 60)) "Timeout, in seconds, for annotations or dataset export waiting" - request_logging_enabled = to_bool(os.getenv("REQUEST_LOGGING_ENABLED", False)) + request_logging_enabled = to_bool(os.getenv("REQUEST_LOGGING_ENABLED", "0")) "Allow to log request details for each request" - profiling_enabled = to_bool(os.getenv("PROFILING_ENABLED", False)) + profiling_enabled = to_bool(os.getenv("PROFILING_ENABLED", "0")) "Allow to profile specific requests" @@ -222,12 +224,12 @@ def validate(cls) -> None: ex_prefix = "Wrong server configuration." if (cls.pgp_public_key_url or cls.pgp_passphrase) and not cls.pgp_private_key: - raise Exception(" ".join([ex_prefix, "The PGP_PRIVATE_KEY environment is not set."])) + raise Exception(f"{ex_prefix} The PGP_PRIVATE_KEY environment is not set.") if cls.pgp_private_key: try: Encryption(cls.pgp_private_key, passphrase=cls.pgp_passphrase) - except Exception as ex: + except Exception as ex: # noqa: BLE001 # Possible reasons: # - private key is invalid # - private key is locked but no passphrase is provided @@ -264,9 +266,12 @@ def validate(cls) -> None: attr_or_method.validate() @classmethod - def get_network_configs(cls, only_configured: bool = True) -> Iterable[_NetworkConfig]: + def get_network_configs(cls, *, only_configured: bool = True) -> Iterable[_NetworkConfig]: for attr_or_method in cls.__dict__: attr_or_method = getattr(cls, attr_or_method) - if inspect.isclass(attr_or_method) and issubclass(attr_or_method, _NetworkConfig): - if not only_configured or attr_or_method.is_configured(): - yield attr_or_method + if ( + inspect.isclass(attr_or_method) + and issubclass(attr_or_method, _NetworkConfig) + and (not only_configured or attr_or_method.is_configured()) + ): + yield attr_or_method diff --git a/packages/examples/cvat/exchange-oracle/src/core/manifest.py b/packages/examples/cvat/exchange-oracle/src/core/manifest.py index c8a4c2260d..c70c1ed91a 100644 --- a/packages/examples/cvat/exchange-oracle/src/core/manifest.py +++ b/packages/examples/cvat/exchange-oracle/src/core/manifest.py @@ -1,6 +1,6 @@ from decimal import Decimal from enum import Enum -from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union +from typing import Annotated, Any, Literal from pydantic import AnyUrl, BaseModel, Field, root_validator @@ -28,22 +28,22 @@ class AwsBucketUrl(BucketUrlBase, BaseModel): class GcsBucketUrl(BucketUrlBase, BaseModel): provider: Literal[BucketProviders.gcs] - service_account_key: Dict[str, Any] = {} # (optional) Contents of GCS key file + service_account_key: dict[str, Any] = {} # (optional) Contents of GCS key file -BucketUrl = Annotated[Union[AwsBucketUrl, GcsBucketUrl], Field(discriminator="provider")] +BucketUrl = Annotated[AwsBucketUrl | GcsBucketUrl, Field(discriminator="provider")] class DataInfo(BaseModel): - data_url: Union[AnyUrl, BucketUrl] + data_url: AnyUrl | BucketUrl "Bucket URL, AWS S3 | GCS, virtual-hosted-style access" # https://docs.aws.amazon.com/AmazonS3/latest/userguide/access-bucket-intro.html - points_url: Optional[Union[AnyUrl, BucketUrl]] = None + points_url: AnyUrl | BucketUrl | None = None "A path to an archive with a set of points in COCO Keypoints format, " "which provides information about all objects on images" - boxes_url: Optional[Union[AnyUrl, BucketUrl]] = None + boxes_url: AnyUrl | BucketUrl | None = None "A path to an archive with a set of boxes in COCO Instances format, " "which provides information about all objects on images" @@ -67,7 +67,7 @@ class PlainLabelInfo(LabelInfoBase): class SkeletonLabelInfo(LabelInfoBase): type: Literal[LabelTypes.skeleton] - nodes: List[str] = Field(min_items=1) + nodes: list[str] = Field(min_items=1) """ A list of node label names (only points are supposed to be nodes). Example: @@ -76,7 +76,7 @@ class SkeletonLabelInfo(LabelInfoBase): ] """ - joints: Optional[List[Tuple[int, int]]] = Field(default_factory=list) + joints: list[tuple[int, int]] | None = Field(default_factory=list) "A list of node adjacency, e.g. [[0, 1], [1, 2], [1, 3]]" @root_validator @@ -114,7 +114,7 @@ def validate_type(cls, values: dict) -> dict: return values -LabelInfo = Annotated[Union[PlainLabelInfo, SkeletonLabelInfo], Field(discriminator="type")] +LabelInfo = Annotated[PlainLabelInfo | SkeletonLabelInfo, Field(discriminator="type")] class AnnotationInfo(BaseModel): @@ -132,7 +132,7 @@ class AnnotationInfo(BaseModel): job_size: int = 10 "Frames per job, validation frames are not included" - max_time: Optional[int] = None # deprecated, TODO: mark deprecated with pydantic 2.7+ + max_time: int | None = None # deprecated, TODO: mark deprecated with pydantic 2.7+ "Maximum time per job (assignment) for an annotator, in seconds" @root_validator(pre=True) @@ -161,7 +161,7 @@ class ValidationInfo(BaseModel): val_size: int = Field(default=2, gt=0) "Validation frames per job" - gt_url: Union[AnyUrl, BucketUrl] + gt_url: AnyUrl | BucketUrl "URL to the archive with Ground Truth annotations, the format is COCO keypoints" diff --git a/packages/examples/cvat/exchange-oracle/src/core/oracle_events.py b/packages/examples/cvat/exchange-oracle/src/core/oracle_events.py index 334fd2ee80..6188d598f0 100644 --- a/packages/examples/cvat/exchange-oracle/src/core/oracle_events.py +++ b/packages/examples/cvat/exchange-oracle/src/core/oracle_events.py @@ -1,4 +1,4 @@ -from typing import Optional, Type, Union +from typing import Union from pydantic import BaseModel @@ -9,11 +9,7 @@ RecordingOracleEventTypes, ) -EventTypeTag = Union[ - ExchangeOracleEventTypes, - JobLauncherEventTypes, - RecordingOracleEventTypes, -] +EventTypeTag = ExchangeOracleEventTypes | JobLauncherEventTypes | RecordingOracleEventTypes class OracleEvent(BaseModel): @@ -58,7 +54,7 @@ class ExchangeOracleEvent_TaskFinished(OracleEvent): } -def get_class_for_event_type(event_type: str) -> Type[OracleEvent]: +def get_class_for_event_type(event_type: str) -> type[OracleEvent]: event_class = next((v for k, v in _event_type_map.items() if k == event_type), None) if not event_class: @@ -68,7 +64,7 @@ def get_class_for_event_type(event_type: str) -> Type[OracleEvent]: def get_type_tag_for_event_class( - event_class: Type[OracleEvent], + event_class: type[OracleEvent], ) -> EventTypeTag: event_type = next((k for k, v in _event_type_map.items() if v == event_class), None) @@ -81,7 +77,7 @@ def get_type_tag_for_event_class( def parse_event( sender: OracleWebhookTypes, event_type: str, - event_data: Optional[dict] = None, + event_data: dict | None = None, ) -> OracleEvent: sender_events_mapping = { OracleWebhookTypes.job_launcher: JobLauncherEventTypes, @@ -91,10 +87,10 @@ def parse_event( sender_events = sender_events_mapping.get(sender) if sender_events is not None: - if not event_type in sender_events: + if event_type not in sender_events: raise ValueError(f"Unknown event '{sender}.{event_type}'") else: - assert False, f"Unknown event sender type '{sender}'" + raise AssertionError(f"Unknown event sender type '{sender}'") event_class = get_class_for_event_type(event_type) return event_class.parse_obj(event_data or {}) diff --git a/packages/examples/cvat/exchange-oracle/src/core/tasks/boxes_from_points.py b/packages/examples/cvat/exchange-oracle/src/core/tasks/boxes_from_points.py index c9320473b6..bd26aacc14 100644 --- a/packages/examples/cvat/exchange-oracle/src/core/tasks/boxes_from_points.py +++ b/packages/examples/cvat/exchange-oracle/src/core/tasks/boxes_from_points.py @@ -1,14 +1,14 @@ import os +from collections.abc import Sequence from pathlib import Path from tempfile import TemporaryDirectory -from typing import Dict, Sequence import attrs import datumaro as dm from attrs import frozen from datumaro.util import dump_json, parse_json -BboxPointMapping = Dict[int, int] +BboxPointMapping = dict[int, int] @frozen @@ -28,7 +28,7 @@ def asdict(self) -> dict: RoiInfos = Sequence[RoiInfo] -RoiFilenames = Dict[int, str] +RoiFilenames = dict[int, str] class TaskMetaLayout: diff --git a/packages/examples/cvat/exchange-oracle/src/core/tasks/skeletons_from_boxes.py b/packages/examples/cvat/exchange-oracle/src/core/tasks/skeletons_from_boxes.py index 404e3b06ea..dc70d33a91 100644 --- a/packages/examples/cvat/exchange-oracle/src/core/tasks/skeletons_from_boxes.py +++ b/packages/examples/cvat/exchange-oracle/src/core/tasks/skeletons_from_boxes.py @@ -1,7 +1,7 @@ import os +from collections.abc import Sequence from pathlib import Path from tempfile import TemporaryDirectory -from typing import Dict, Sequence, Tuple import attrs import datumaro as dm @@ -12,7 +12,7 @@ DEFAULT_ASSIGNMENT_SIZE_MULTIPLIER = Config.core_config.skeleton_assignment_size_mult -SkeletonBboxMapping = Dict[int, int] +SkeletonBboxMapping = dict[int, int] @frozen(kw_only=True) @@ -37,9 +37,9 @@ def asdict(self) -> dict: RoiInfos = Sequence[RoiInfo] -RoiFilenames = Dict[int, str] +RoiFilenames = dict[int, str] -PointLabelsMapping = Dict[Tuple[str, str], str] +PointLabelsMapping = dict[tuple[str, str], str] "(skeleton, point) -> job point name" diff --git a/packages/examples/cvat/exchange-oracle/src/crons/process_job_launcher_webhooks.py b/packages/examples/cvat/exchange-oracle/src/crons/process_job_launcher_webhooks.py index 00e873f8e4..2682f6ff95 100644 --- a/packages/examples/cvat/exchange-oracle/src/crons/process_job_launcher_webhooks.py +++ b/packages/examples/cvat/exchange-oracle/src/crons/process_job_launcher_webhooks.py @@ -109,53 +109,48 @@ def handle_job_launcher_event(webhook: Webhook, *, db_session: Session, logger: raise case JobLauncherEventTypes.escrow_canceled: - try: - validate_escrow( - webhook.chain_id, - webhook.escrow_address, - accepted_states=[EscrowStatus.Pending, EscrowStatus.Cancelled], - ) + validate_escrow( + webhook.chain_id, + webhook.escrow_address, + accepted_states=[EscrowStatus.Pending, EscrowStatus.Cancelled], + ) - projects = cvat_db_service.get_projects_by_escrow_address( - db_session, webhook.escrow_address, for_update=True, limit=None + projects = cvat_db_service.get_projects_by_escrow_address( + db_session, webhook.escrow_address, for_update=True, limit=None + ) + if not projects: + logger.error( + "Received escrow cancel event " + f"(escrow_address={webhook.escrow_address}). " + "The project doesn't exist, ignoring" ) - if not projects: + return + + for project in projects: + if project.status in [ + ProjectStatuses.canceled, + ProjectStatuses.recorded, + ]: logger.error( "Received escrow cancel event " f"(escrow_address={webhook.escrow_address}). " - "The project doesn't exist, ignoring" - ) - return - - for project in projects: - if project.status in [ - ProjectStatuses.canceled, - ProjectStatuses.recorded, - ]: - logger.error( - "Received escrow cancel event " - f"(escrow_address={webhook.escrow_address}). " - "The project is already finished, ignoring" - ) - continue - - logger.info( - f"Received escrow cancel event (escrow_address={webhook.escrow_address}). " - "Canceling the project" - ) - cvat_db_service.update_project_status( - db_session, project.id, ProjectStatuses.canceled + "The project is already finished, ignoring" ) + continue - cvat_db_service.finish_escrow_creations_by_escrow_address( - db_session, escrow_address=webhook.escrow_address, chain_id=webhook.chain_id + logger.info( + f"Received escrow cancel event (escrow_address={webhook.escrow_address}). " + "Canceling the project" + ) + cvat_db_service.update_project_status( + db_session, project.id, ProjectStatuses.canceled ) - except Exception as ex: - raise - + cvat_db_service.finish_escrow_creations_by_escrow_address( + db_session, escrow_address=webhook.escrow_address, chain_id=webhook.chain_id + ) case _: - assert False, f"Unknown job launcher event {webhook.event_type}" + raise AssertionError(f"Unknown job launcher event {webhook.event_type}") def process_outgoing_job_launcher_webhooks(): diff --git a/packages/examples/cvat/exchange-oracle/src/crons/process_recording_oracle_webhooks.py b/packages/examples/cvat/exchange-oracle/src/crons/process_recording_oracle_webhooks.py index b0293a89be..c49fac60b9 100644 --- a/packages/examples/cvat/exchange-oracle/src/crons/process_recording_oracle_webhooks.py +++ b/packages/examples/cvat/exchange-oracle/src/crons/process_recording_oracle_webhooks.py @@ -68,10 +68,8 @@ def handle_recording_oracle_event(webhook: Webhook, *, db_session: Session, logg ) if not project_ids: logger.error( - "Unexpected event {} received for an unknown project, " - "ignoring (escrow_address={})".format( - webhook.event_type, webhook.escrow_address - ) + f"Unexpected event {webhook.event_type} received for an unknown project, " + f"ignoring (escrow_address={webhook.escrow_address})" ) return @@ -148,7 +146,7 @@ def handle_recording_oracle_event(webhook: Webhook, *, db_session: Session, logg cvat_db_service.update_project_status(db_session, project.id, new_status) case _: - assert False, f"Unknown recording oracle event {webhook.event_type}" + raise AssertionError(f"Unknown recording oracle event {webhook.event_type}") def process_outgoing_recording_oracle_webhooks(): diff --git a/packages/examples/cvat/exchange-oracle/src/crons/state_trackers.py b/packages/examples/cvat/exchange-oracle/src/crons/state_trackers.py index 09745773e1..868569b522 100644 --- a/packages/examples/cvat/exchange-oracle/src/crons/state_trackers.py +++ b/packages/examples/cvat/exchange-oracle/src/crons/state_trackers.py @@ -1,5 +1,3 @@ -from typing import List - from sqlalchemy import exc as sa_errors import src.cvat.api_calls as cvat_api @@ -221,8 +219,8 @@ def track_task_creation() -> None: ) ) - completed: List[cvat_models.DataUpload] = [] - failed: List[cvat_models.DataUpload] = [] + completed: list[cvat_models.DataUpload] = [] + failed: list[cvat_models.DataUpload] = [] for upload in uploads: status, reason = cvat_api.get_task_upload_status(upload.task_id) project = upload.task.project @@ -313,7 +311,7 @@ def track_escrow_creation() -> None: ) ) - finished: List[cvat_models.EscrowCreation] = [] + finished: list[cvat_models.EscrowCreation] = [] for creation in creations: created_jobs_count = cvat_service.count_jobs_by_escrow_address( session, diff --git a/packages/examples/cvat/exchange-oracle/src/cvat/__init__.py b/packages/examples/cvat/exchange-oracle/src/cvat/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/packages/examples/cvat/exchange-oracle/src/cvat/api_calls.py b/packages/examples/cvat/exchange-oracle/src/cvat/api_calls.py index 3b5a7b4a8a..75c4ad5f31 100644 --- a/packages/examples/cvat/exchange-oracle/src/cvat/api_calls.py +++ b/packages/examples/cvat/exchange-oracle/src/cvat/api_calls.py @@ -2,6 +2,7 @@ import json import logging import zipfile +from collections.abc import Generator from contextlib import contextmanager from contextvars import ContextVar from datetime import timedelta @@ -9,7 +10,7 @@ from http import HTTPStatus from io import BytesIO from time import sleep -from typing import Any, Dict, Generator, List, Optional, Tuple +from typing import Any from cvat_sdk.api_client import ApiClient, Configuration, exceptions, models from cvat_sdk.api_client.api_client import Endpoint @@ -50,7 +51,7 @@ def _get_annotations( cvat_id: int, format_name: str, attempt_interval: int = 5, - timeout: Optional[int] = _NOTSET, + timeout: int | None = _NOTSET, ) -> io.RawIOBase: """ Downloads annotations. @@ -125,13 +126,13 @@ def create_cloudstorage( provider: str, bucket_name: str, *, - credentials: Optional[Dict[str, Any]] = None, - bucket_host: Optional[str] = None, + credentials: dict[str, Any] | None = None, + bucket_host: str | None = None, ) -> models.CloudStorageRead: # credentials: access_key | secret_key | service_account_key # CVAT credentials: key | secret_key | key_file - def _to_cvat_credentials(credentials: Dict[str, Any]) -> Dict: - cvat_credentials = dict() + def _to_cvat_credentials(credentials: dict[str, Any]) -> dict: + cvat_credentials = {} for cvat_field, field in { "key": "access_key", "secret_key": "secret_key", @@ -147,7 +148,7 @@ def _to_cvat_credentials(credentials: Dict[str, Any]) -> Dict: cvat_credentials[cvat_field] = value return cvat_credentials - request_kwargs = dict() + request_kwargs = {} if credentials: request_kwargs.update(_to_cvat_credentials(credentials)) @@ -186,7 +187,7 @@ def _to_cvat_credentials(credentials: Dict[str, Any]) -> Dict: def create_project( - name: str, *, labels: Optional[list] = None, user_guide: str = "" + name: str, *, labels: list | None = None, user_guide: str = "" ) -> models.ProjectRead: logger = logging.getLogger("app") with get_api_client() as api_client: @@ -239,7 +240,7 @@ def request_project_annotations(cvat_id: int, format_name: str) -> bool: def get_project_annotations( - cvat_id: int, format_name: str, *, timeout: Optional[int] = _NOTSET + cvat_id: int, format_name: str, *, timeout: int | None = _NOTSET ) -> io.RawIOBase: """ Downloads annotations. @@ -314,7 +315,7 @@ def create_task(project_id: int, name: str) -> models.TaskRead: raise -def get_cloudstorage_contents(cloudstorage_id: int) -> List[str]: +def get_cloudstorage_contents(cloudstorage_id: int) -> list[str]: logger = logging.getLogger("app") with get_api_client() as api_client: try: @@ -332,7 +333,7 @@ def put_task_data( task_id: int, cloudstorage_id: int, *, - filenames: Optional[list[str]] = None, + filenames: list[str] | None = None, sort_images: bool = True, ) -> None: logger = logging.getLogger("app") @@ -355,7 +356,7 @@ def put_task_data( ) try: (_, response) = api_client.tasks_api.create_data(task_id, data_request=data_request) - return None + return except exceptions.ApiException as e: logger.exception(f"Exception when calling ProjectsApi.put_task_data: {e}\n") @@ -388,7 +389,7 @@ def request_task_annotations(cvat_id: int, format_name: str) -> bool: def get_task_annotations( - cvat_id: int, format_name: str, *, timeout: Optional[int] = _NOTSET + cvat_id: int, format_name: str, *, timeout: int | None = _NOTSET ) -> io.RawIOBase: """ Downloads annotations. @@ -418,16 +419,15 @@ def get_task_annotations( raise -def fetch_task_jobs(task_id: int) -> List[models.JobRead]: +def fetch_task_jobs(task_id: int) -> list[models.JobRead]: logger = logging.getLogger("app") with get_api_client() as api_client: try: - data = get_paginated_collection( + return get_paginated_collection( api_client.jobs_api.list_endpoint, task_id=task_id, type="annotation", ) - return data except exceptions.ApiException as e: logger.exception(f"Exception when calling JobsApi.list: {e}\n") raise @@ -459,7 +459,7 @@ def request_job_annotations(cvat_id: int, format_name: str) -> bool: def get_job_annotations( - cvat_id: int, format_name: str, *, timeout: Optional[int] = _NOTSET + cvat_id: int, format_name: str, *, timeout: int | None = _NOTSET ) -> io.RawIOBase: """ Downloads annotations. @@ -509,13 +509,13 @@ def delete_cloudstorage(cvat_id: int) -> None: raise -def fetch_projects(assignee: str = "") -> List[models.ProjectRead]: +def fetch_projects(assignee: str = "") -> list[models.ProjectRead]: logger = logging.getLogger("app") with get_api_client() as api_client: try: return get_paginated_collection( api_client.projects_api.list_endpoint, - **(dict(assignee=assignee) if assignee else {}), + **({"assignee": assignee} if assignee else {}), ) except exceptions.ApiException as e: logger.exception(f"Exception when calling ProjectsApi.list(): {e}\n") @@ -529,7 +529,7 @@ class UploadStatus(str, Enum, metaclass=BetterEnumMeta): FAILED = "Failed" -def get_task_upload_status(cvat_id: int) -> Tuple[Optional[UploadStatus], str]: +def get_task_upload_status(cvat_id: int) -> tuple[UploadStatus | None, str]: logger = logging.getLogger("app") with get_api_client() as api_client: @@ -557,13 +557,13 @@ def clear_job_annotations(job_id: int) -> None: ) except exceptions.ApiException as e: if e.status == 404: - return None + return logger.exception(f"Exception when calling JobsApi.partial_update_annotations(): {e}\n") raise -def update_job_assignee(id: str, assignee_id: Optional[int]): +def update_job_assignee(id: str, assignee_id: int | None): logger = logging.getLogger("app") with get_api_client() as api_client: @@ -577,7 +577,7 @@ def update_job_assignee(id: str, assignee_id: Optional[int]): raise -def restart_job(id: str, *, assignee_id: Optional[int] = None): +def restart_job(id: str, *, assignee_id: int | None = None): logger = logging.getLogger("app") with get_api_client() as api_client: @@ -615,7 +615,7 @@ def remove_user_from_org(user_id: int): with get_api_client() as api_client: try: (page, _) = api_client.users_api.list( - filter='{"==":[{"var":"id"},"%s"]}' % (user_id,), + filter='{"==":[{"var":"id"},"%s"]}' % user_id, # noqa: UP031 org=Config.cvat_config.cvat_org_slug, ) if not page.results: diff --git a/packages/examples/cvat/exchange-oracle/src/db/__init__.py b/packages/examples/cvat/exchange-oracle/src/db/__init__.py index 6e9c85cded..68068770e3 100644 --- a/packages/examples/cvat/exchange-oracle/src/db/__init__.py +++ b/packages/examples/cvat/exchange-oracle/src/db/__init__.py @@ -8,7 +8,7 @@ engine = sqlalchemy.create_engine( DATABASE_URL, echo="debug" if Config.loglevel <= src.utils.logging.TRACE else False, - connect_args={"options": "-c lock_timeout={:d}".format(Config.postgres_config.lock_timeout)}, + connect_args={"options": f"-c lock_timeout={Config.postgres_config.lock_timeout:d}"}, ) SessionLocal = sessionmaker(autocommit=False, bind=engine) diff --git a/packages/examples/cvat/exchange-oracle/src/db/errors.py b/packages/examples/cvat/exchange-oracle/src/db/errors.py index 9eeb655405..fbd7606bca 100644 --- a/packages/examples/cvat/exchange-oracle/src/db/errors.py +++ b/packages/examples/cvat/exchange-oracle/src/db/errors.py @@ -3,6 +3,8 @@ if db_engine.driver != "psycopg2": raise NotImplementedError +__all__ = ["LockNotAvailable"] + from psycopg2.errors import LockNotAvailable # These errors can be found, e.g., in the .orig field of sqlalchemy errors diff --git a/packages/examples/cvat/exchange-oracle/src/db/utils.py b/packages/examples/cvat/exchange-oracle/src/db/utils.py index 24dfb41561..77651fa42a 100644 --- a/packages/examples/cvat/exchange-oracle/src/db/utils.py +++ b/packages/examples/cvat/exchange-oracle/src/db/utils.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import TypeVar, Union +from typing import TypeVar from sqlalchemy import Select from sqlalchemy.orm import Query @@ -14,14 +14,11 @@ class ForUpdateParams: T = TypeVar("T", Query, Select) -def maybe_for_update(query: T, enable: Union[bool, ForUpdateParams]) -> T: +def maybe_for_update(query: T, enable: bool | ForUpdateParams) -> T: if not enable: return query - if isinstance(enable, ForUpdateParams): - params = enable - else: - params = ForUpdateParams() + params = enable if isinstance(enable, ForUpdateParams) else ForUpdateParams() return query.with_for_update( skip_locked=params.skip_locked, diff --git a/packages/examples/cvat/exchange-oracle/src/endpoints/__init__.py b/packages/examples/cvat/exchange-oracle/src/endpoints/__init__.py index b4466bee8b..8845a4b52c 100644 --- a/packages/examples/cvat/exchange-oracle/src/endpoints/__init__.py +++ b/packages/examples/cvat/exchange-oracle/src/endpoints/__init__.py @@ -1,4 +1,5 @@ -""" API endpoints """ +"""API endpoints""" + from fastapi import APIRouter, FastAPI from src.core.config import Config @@ -24,11 +25,11 @@ def meta_route() -> MetaResponse: ] return MetaResponse.parse_obj( - dict( - message="Exchange Oracle API", - version="0.1.0", - supported_networks=networks_info, - ) + { + "message": "Exchange Oracle API", + "version": "0.1.0", + "supported_networks": networks_info, + } ) diff --git a/packages/examples/cvat/exchange-oracle/src/endpoints/exchange.py b/packages/examples/cvat/exchange-oracle/src/endpoints/exchange.py index fabeb3543f..0d3f3ab087 100644 --- a/packages/examples/cvat/exchange-oracle/src/endpoints/exchange.py +++ b/packages/examples/cvat/exchange-oracle/src/endpoints/exchange.py @@ -1,6 +1,5 @@ from contextlib import suppress from http import HTTPStatus -from typing import Optional from fastapi import APIRouter, Header, HTTPException, Path, Query @@ -16,7 +15,7 @@ @router.get("/tasks", description="Lists available tasks") async def list_tasks( - wallet_address: Optional[str] = Query(default=None), + wallet_address: str | None = Query(default=None), signature: str = Header(description="Calling service signature"), ) -> list[TaskResponse]: await validate_human_app_signature(signature) @@ -57,7 +56,7 @@ async def register( status_code=HTTPStatus.NOT_FOUND, detail="User with this email not found" ) from e - elif ( + if ( e.status == HTTPStatus.BAD_REQUEST and "The user is a member of the organization already." in e.body ): @@ -67,9 +66,7 @@ async def register( status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="User already exists" ) - elif ( - e.status == HTTPStatus.BAD_REQUEST and "Enter a valid email address." in e.body - ): + if e.status == HTTPStatus.BAD_REQUEST and "Enter a valid email address." in e.body: raise HTTPException( status_code=HTTPStatus.BAD_REQUEST, detail="Invalid email address" ) diff --git a/packages/examples/cvat/exchange-oracle/src/endpoints/middleware.py b/packages/examples/cvat/exchange-oracle/src/endpoints/middleware.py index a346eddcc6..e111f1e2f3 100644 --- a/packages/examples/cvat/exchange-oracle/src/endpoints/middleware.py +++ b/packages/examples/cvat/exchange-oracle/src/endpoints/middleware.py @@ -1,6 +1,7 @@ import json import time -from typing import Any, Callable +from collections.abc import Callable +from typing import Any import fastapi import packaging.version as pv @@ -61,7 +62,7 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware): """ @staticmethod - async def _set_body(request: Request, body: bytes): + async def _set_body(request: Request, body: bytes) -> None: # Before FastAPI 0.108.0 infinite hang is expected, # if request body is awaited more than once. # It's not needed when using FastAPI >= 0.108.0. @@ -119,7 +120,7 @@ async def _log_request(self, request: Request) -> dict[str, Any]: try: body = await request.body() await self._set_body(request, body) - except Exception: + except Exception: # noqa: BLE001 body = None else: if body is not None: @@ -188,7 +189,7 @@ async def _execute_request(self, call_next: Callable, request: Request) -> Respo except Exception as e: self.logger.exception({"path": request.url.path, "method": request.method, "reason": e}) - raise e + raise else: return response diff --git a/packages/examples/cvat/exchange-oracle/src/endpoints/webhook.py b/packages/examples/cvat/exchange-oracle/src/endpoints/webhook.py index a770ce7c1f..8d8d34c432 100644 --- a/packages/examples/cvat/exchange-oracle/src/endpoints/webhook.py +++ b/packages/examples/cvat/exchange-oracle/src/endpoints/webhook.py @@ -1,5 +1,3 @@ -from typing import Union - from fastapi import APIRouter, Header, HTTPException, Request import src.services.webhook as oracle_db_service @@ -14,7 +12,7 @@ async def receive_oracle_webhook( webhook: OracleWebhook, request: Request, - human_signature: Union[str, None] = Header(default=None), + human_signature: str | None = Header(default=None), ) -> OracleWebhookResponse: try: sender_type = await validate_oracle_webhook_signature(request, human_signature, webhook) diff --git a/packages/examples/cvat/exchange-oracle/src/handlers/__init__.py b/packages/examples/cvat/exchange-oracle/src/handlers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/packages/examples/cvat/exchange-oracle/src/handlers/completed_escrows.py b/packages/examples/cvat/exchange-oracle/src/handlers/completed_escrows.py index 7e031f20a7..346d8eda29 100644 --- a/packages/examples/cvat/exchange-oracle/src/handlers/completed_escrows.py +++ b/packages/examples/cvat/exchange-oracle/src/handlers/completed_escrows.py @@ -2,8 +2,9 @@ import itertools import logging from collections import Counter +from collections.abc import Callable from functools import partial -from typing import Any, Callable, Dict, List, Optional +from typing import Any from datumaro.util import take_by from sqlalchemy import exc as sa_errors @@ -43,7 +44,7 @@ class _CompletedEscrowsHandler: 4. Prepares a webhook to recording oracle """ - def __init__(self, logger: Optional[logging.Logger]) -> None: + def __init__(self, logger: logging.Logger | None) -> None: self.logger = logger or NullLogger() def _download_with_retries( @@ -51,7 +52,7 @@ def _download_with_retries( download_callback: Callable[[], io.RawIOBase], retry_callback: Callable[[], Any], *, - max_attempts: Optional[int] = None, + max_attempts: int | None = None, ) -> io.RawIOBase: """ Sometimes CVAT downloading can fail with the 500 error. @@ -74,11 +75,12 @@ def _download_with_retries( retry_callback() else: raise + return None def _process_plain_escrows(self): logger = self.logger - plain_task_types = [t for t in TaskTypes if not t == TaskTypes.image_skeletons_from_boxes] + plain_task_types = [t for t in TaskTypes if t != TaskTypes.image_skeletons_from_boxes] with SessionLocal.begin() as session: completed_projects = cvat_service.get_projects_by_status( session, @@ -95,7 +97,7 @@ def _process_plain_escrows(self): # TODO: such escrows can fill all the queried completed projects # need to improve handling for such projects # (e.g. cancel depending on the escrow status) - logger.error( + logger.exception( "Failed to handle completed project id {} for escrow {}: {}".format( project.cvat_id, project.escrow_address, e ) @@ -113,7 +115,7 @@ def _process_plain_escrows(self): jobs = cvat_service.get_jobs_by_cvat_project_id(session, project.cvat_id) annotation_format = CVAT_EXPORT_FORMAT_MAPPING[project.job_type] - job_annotations: Dict[int, FileDescriptor] = {} + job_annotations: dict[int, FileDescriptor] = {} for jobs_batch in take_by( jobs, count=CronConfig.track_completed_escrows_jobs_downloading_batch_size @@ -170,7 +172,7 @@ def _process_plain_escrows(self): file=project_annotations_file, ) - annotation_files: List[FileDescriptor] = [] + annotation_files: list[FileDescriptor] = [] annotation_files.append(project_annotations_file_desc) annotation_metafile = prepare_annotation_metafile( @@ -253,10 +255,8 @@ def _process_skeletons_from_boxes_escrows(self): # TODO: such escrows can fill all the queried completed projects # need to improve handling for such projects # (e.g. cancel depending on the escrow status) - logger.error( - "Failed to handle completed projects for escrow {}: {}".format( - escrow_address, e - ) + logger.exception( + f"Failed to handle completed projects for escrow {escrow_address}: {e}" ) continue @@ -297,7 +297,7 @@ def _process_skeletons_from_boxes_escrows(self): f"Downloading results for the escrow (escrow_address={escrow_address})" ) - jobs: List[cvat_models.Job] = list( + jobs: list[cvat_models.Job] = list( itertools.chain.from_iterable( cvat_service.get_jobs_by_cvat_project_id(session, p.cvat_id) for p in escrow_projects @@ -305,7 +305,7 @@ def _process_skeletons_from_boxes_escrows(self): ) annotation_format = CVAT_EXPORT_FORMAT_MAPPING[manifest.annotation.type] - job_annotations: Dict[int, FileDescriptor] = {} + job_annotations: dict[int, FileDescriptor] = {} # Collect raw annotations from CVAT, validate and convert them # into a recording oracle suitable format @@ -347,7 +347,7 @@ def _process_skeletons_from_boxes_escrows(self): file=None, ) - annotation_files: List[FileDescriptor] = [] + annotation_files: list[FileDescriptor] = [] annotation_files.append(resulting_annotations_file_desc) annotation_metafile = prepare_annotation_metafile( diff --git a/packages/examples/cvat/exchange-oracle/src/handlers/cvat_events.py b/packages/examples/cvat/exchange-oracle/src/handlers/cvat_events.py index 96353138da..7076bc225f 100644 --- a/packages/examples/cvat/exchange-oracle/src/handlers/cvat_events.py +++ b/packages/examples/cvat/exchange-oracle/src/handlers/cvat_events.py @@ -1,5 +1,3 @@ -from typing import List - from dateutil.parser import parse as parse_aware_datetime from sqlalchemy import exc as sa_errors @@ -42,7 +40,7 @@ def handle_update_job_event(payload: dict) -> None: webhook_time = parse_aware_datetime(payload.job["updated_date"]) webhook_assignee_id = (payload.job["assignee"] or {}).get("id") - job_assignments: List[models.Assignment] = sorted( + job_assignments: list[models.Assignment] = sorted( job_assignments, key=lambda a: a.created_at, reverse=True ) latest_assignment = job.assignments[0] diff --git a/packages/examples/cvat/exchange-oracle/src/handlers/error_handlers.py b/packages/examples/cvat/exchange-oracle/src/handlers/error_handlers.py index 86d410aabb..d476e9af6a 100644 --- a/packages/examples/cvat/exchange-oracle/src/handlers/error_handlers.py +++ b/packages/examples/cvat/exchange-oracle/src/handlers/error_handlers.py @@ -1,4 +1,5 @@ -""" Custom error handlers for the FastAPI""" +"""Custom error handlers for the FastAPI""" + from fastapi import FastAPI from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse diff --git a/packages/examples/cvat/exchange-oracle/src/handlers/job_creation.py b/packages/examples/cvat/exchange-oracle/src/handlers/job_creation.py index e95435baa6..93532aba4b 100644 --- a/packages/examples/cvat/exchange-oracle/src/handlers/job_creation.py +++ b/packages/examples/cvat/exchange-oracle/src/handlers/job_creation.py @@ -6,10 +6,9 @@ from contextlib import ExitStack from dataclasses import dataclass, field from itertools import chain, groupby -from logging import Logger from math import ceil from tempfile import TemporaryDirectory -from typing import Dict, List, Optional, Sequence, Tuple, TypeVar, Union, cast +from typing import TYPE_CHECKING, TypeVar, Union, cast import cv2 import datumaro as dm @@ -26,7 +25,6 @@ import src.services.cvat as db_service from src.chain.escrow import get_escrow_manifest from src.core.config import Config -from src.core.manifest import TaskManifest from src.core.storage import compose_data_bucket_filename from src.core.types import CvatLabelTypes, TaskStatuses, TaskTypes from src.db import SessionLocal @@ -37,6 +35,12 @@ from src.utils.assignments import parse_manifest from src.utils.logging import NullLogger, get_function_logger +if TYPE_CHECKING: + from collections.abc import Sequence + from logging import Logger + + from src.core.manifest import TaskManifest + module_logger = f"{ROOT_LOGGER_NAME}.cron.cvat" LABEL_TYPE_MAPPING = { @@ -103,7 +107,7 @@ def __bool__(self) -> bool: _unset = _Undefined() -_MaybeUnset = Union[T, _Undefined] +_MaybeUnset = T | _Undefined @dataclass @@ -115,7 +119,7 @@ class _ExcludedAnnotationInfo: @dataclass class _ExcludedAnnotationsInfo: - messages: List[_ExcludedAnnotationInfo] = field(default_factory=list) + messages: list[_ExcludedAnnotationInfo] = field(default_factory=list) excluded_count: int = 0 "The number of excluded annotations. Can be different from len(messages)" @@ -135,7 +139,7 @@ class SimpleTaskBuilder: Handles task creation for IMAGE_POINTS and IMAGE_BOXES task types """ - def __init__(self, manifest: TaskManifest, escrow_address: str, chain_id: int): + def __init__(self, manifest: TaskManifest, escrow_address: str, chain_id: int) -> None: self.exit_stack = ExitStack() self.manifest = manifest self.escrow_address = escrow_address @@ -163,7 +167,7 @@ def set_logger(self, logger: Logger): return self def _format_list( - self, items: Sequence[str], *, max_items: int = None, separator: str = ", " + self, items: Sequence[str], *, max_items: int | None = None, separator: str = ", " ) -> str: if max_items is None: max_items = self.list_display_threshold @@ -198,7 +202,7 @@ def _upload_task_meta(self, gt_dataset: dm.Dataset): ) def _parse_gt_dataset( - self, gt_file_data: bytes, *, add_prefix: Optional[str] = None + self, gt_file_data: bytes, *, add_prefix: str | None = None ) -> dm.Dataset: with TemporaryDirectory() as gt_temp_dir: gt_filename = os.path.join(gt_temp_dir, "gt_annotations.json") @@ -222,8 +226,8 @@ def _parse_gt_dataset( return gt_dataset def _get_gt_filenames( - self, gt_dataset: dm.Dataset, data_filenames: List[str], *, manifest: TaskManifest - ) -> List[str]: + self, gt_dataset: dm.Dataset, data_filenames: list[str], *, manifest: TaskManifest + ) -> list[str]: gt_filenames = set(s.id + s.media.ext for s in gt_dataset) known_data_filenames = set(data_filenames) matched_gt_filenames = gt_filenames.intersection(known_data_filenames) @@ -246,14 +250,14 @@ def _get_gt_filenames( def _make_job_configuration( self, - data_filenames: List[str], - gt_filenames: List[str], + data_filenames: list[str], + gt_filenames: list[str], *, manifest: TaskManifest, - ) -> List[List[str]]: + ) -> list[list[str]]: # Make job layouts wrt. manifest params, 1 job per task (CVAT can't repeat images in jobs) gt_filenames_index = set(gt_filenames) - data_filenames = [fn for fn in data_filenames if not fn in gt_filenames_index] + data_filenames = [fn for fn in data_filenames if fn not in gt_filenames_index] random.shuffle(data_filenames) job_layout = [] @@ -357,7 +361,7 @@ def build(self): class BoxesFromPointsTaskBuilder: - def __init__(self, manifest: TaskManifest, escrow_address: str, chain_id: int): + def __init__(self, manifest: TaskManifest, escrow_address: str, chain_id: int) -> None: self.exit_stack = ExitStack() self.manifest = manifest self.escrow_address = escrow_address @@ -376,7 +380,7 @@ def __init__(self, manifest: TaskManifest, escrow_address: str, chain_id: int): self._bbox_point_mapping: _MaybeUnset[boxes_from_points_task.BboxPointMapping] = _unset "bbox_id -> point_id" - self._roi_size_estimations: _MaybeUnset[Dict[int, Tuple[float, float]]] = _unset + self._roi_size_estimations: _MaybeUnset[dict[int, tuple[float, float]]] = _unset "label_id -> (rel. w, rel. h)" self._rois: _MaybeUnset[boxes_from_points_task.RoiInfos] = _unset @@ -613,7 +617,7 @@ def _validate_gt(self): self._validate_gt_annotations() def _format_list( - self, items: Sequence[str], *, max_items: int = None, separator: str = ", " + self, items: Sequence[str], *, max_items: int | None = None, separator: str = ", " ) -> str: if max_items is None: max_items = self.list_display_threshold @@ -683,7 +687,7 @@ def _validate_skeleton(skeleton: dm.Skeleton, *, sample_bbox: dm.Bbox): if len(skeleton.elements) != 1: raise DatasetValidationError( - "invalid points count ({}), expected 1".format(len(skeleton.elements)) + f"invalid points count ({len(skeleton.elements)}), expected 1" ) point = skeleton.elements[0] @@ -776,9 +780,9 @@ def _is_point_in_bbox(px: float, py: float, bbox: dm.Bbox) -> bool: def _prepare_gt(self): def _find_unambiguous_matches( - input_skeletons: List[dm.Skeleton], - gt_boxes: List[dm.Bbox], - ) -> List[Tuple[dm.Skeleton, dm.Bbox]]: + input_skeletons: list[dm.Skeleton], + gt_boxes: list[dm.Bbox], + ) -> list[tuple[dm.Skeleton, dm.Bbox]]: matches = [ [ (input_skeleton.label == gt_bbox.label) @@ -795,7 +799,7 @@ def _find_unambiguous_matches( ambiguous_boxes: list[int] = set() ambiguous_skeletons: list[int] = set() for skeleton_idx, input_skeleton in enumerate(input_skeletons): - matched_boxes: List[dm.Bbox] = [ + matched_boxes: list[dm.Bbox] = [ gt_boxes[j] for j in range(len(gt_boxes)) if matches[skeleton_idx][j] ] @@ -818,7 +822,7 @@ def _find_unambiguous_matches( continue for gt_idx, gt_bbox in enumerate(gt_boxes): - matched_skeletons: List[dm.Skeleton] = [ + matched_skeletons: list[dm.Skeleton] = [ input_skeletons[i] for i in range(len(input_skeletons)) if matches[i][gt_idx] ] @@ -839,7 +843,7 @@ def _find_unambiguous_matches( ambiguous_boxes.add(gt_bbox.id) ambiguous_skeletons.update(a.id for a in matched_skeletons) continue - elif not matched_skeletons: + if not matched_skeletons: # Handle unmatched skeletons excluded_gt_info.add_message( "Sample '{}': GT bbox #{} ({}) skipped - " @@ -854,7 +858,7 @@ def _find_unambiguous_matches( excluded_gt_info.excluded_count += 1 # an error continue - unambiguous_matches: List[Tuple[dm.Bbox, dm.Skeleton]] = [] + unambiguous_matches: list[tuple[dm.Bbox, dm.Skeleton]] = [] for skeleton_idx, input_skeleton in enumerate(input_skeletons): if input_skeleton.id in ambiguous_skeletons: continue @@ -874,9 +878,9 @@ def _find_unambiguous_matches( return unambiguous_matches def _find_good_gt_boxes( - input_skeletons: List[dm.Skeleton], - gt_boxes: List[dm.Bbox], - ) -> List[dm.Bbox]: + input_skeletons: list[dm.Skeleton], + gt_boxes: list[dm.Bbox], + ) -> list[dm.Bbox]: matches = _find_unambiguous_matches(input_skeletons, gt_boxes) matched_boxes = [] @@ -1030,7 +1034,7 @@ def _estimate_roi_sizes(self): if classes_with_default_roi: label_cat = self._gt_dataset.categories()[dm.AnnotationType.label] labels_by_reason = { - g_reason: list(v[0] for v in g_items) + g_reason: [v[0] for v in g_items] for g_reason, g_items in groupby( sorted(classes_with_default_roi.items(), key=lambda v: v[1]), key=lambda v: v[1] ) @@ -1054,7 +1058,7 @@ def _prepare_roi_info(self): assert self._roi_size_estimations is not _unset assert self._points_dataset is not _unset - rois: List[boxes_from_points_task.RoiInfo] = [] + rois: list[boxes_from_points_task.RoiInfo] = [] for sample in self._points_dataset: for skeleton in sample.annotations: if not isinstance(skeleton, dm.Skeleton): @@ -1133,9 +1137,9 @@ def _prepare_job_layout(self): data_filenames = [ fn for point_id, fn in self._roi_filenames.items() - if not point_id in gt_point_ids - if not original_image_id_to_filename[point_id_to_original_image_id[point_id]] - in input_gt_filenames + if point_id not in gt_point_ids + if original_image_id_to_filename[point_id_to_original_image_id[point_id]] + not in input_gt_filenames ] random.shuffle(data_filenames) @@ -1234,7 +1238,7 @@ def _draw_roi_point( (255, 255, 255), cv2.FILLED, ) - roi_pixels = cv2.circle( + return cv2.circle( roi_pixels, center, point_size, @@ -1242,10 +1246,9 @@ def _draw_roi_point( cv2.FILLED, ) - return roi_pixels - def _extract_and_upload_rois(self): - # TODO: maybe optimize via splitting into separate threads (downloading, uploading, processing) + # TODO: maybe optimize via splitting into separate + # threads (downloading, uploading, processing) # Watch for the memory used, as the whole dataset can be quite big (gigabytes, terabytes) # Consider also packing RoIs cut into archives @@ -1267,8 +1270,10 @@ def _extract_and_upload_rois(self): filename_to_sample = {sample.image.path: sample for sample in self._points_dataset} - _roi_key = lambda e: e.original_image_key - rois_by_image: Dict[str, Sequence[boxes_from_points_task.RoiInfo]] = { + def _roi_key(e): + return e.original_image_key + + rois_by_image: dict[str, Sequence[boxes_from_points_task.RoiInfo]] = { image_id_to_filename[image_id]: list(g) for image_id, g in groupby(sorted(self._rois, key=_roi_key), key=_roi_key) } @@ -1420,9 +1425,9 @@ class SkeletonsFromBoxesTaskBuilder: @dataclass class _JobParams: label_id: int - roi_ids: List[int] + roi_ids: list[int] - def __init__(self, manifest: TaskManifest, escrow_address: str, chain_id: int): + def __init__(self, manifest: TaskManifest, escrow_address: str, chain_id: int) -> None: self.exit_stack = ExitStack() self.manifest = manifest self.escrow_address = escrow_address @@ -1438,12 +1443,12 @@ def __init__(self, manifest: TaskManifest, escrow_address: str, chain_id: int): self._gt_dataset: _MaybeUnset[dm.Dataset] = _unset self._boxes_dataset: _MaybeUnset[dm.Dataset] = _unset - self._skeleton_bbox_mapping: _MaybeUnset[ - skeletons_from_boxes_task.SkeletonBboxMapping - ] = _unset + self._skeleton_bbox_mapping: _MaybeUnset[skeletons_from_boxes_task.SkeletonBboxMapping] = ( + _unset + ) self._roi_infos: _MaybeUnset[skeletons_from_boxes_task.RoiInfos] = _unset - self._roi_filenames: _MaybeUnset[Dict[int, str]] = _unset - self._job_params: _MaybeUnset[List[self._JobParams]] = _unset + self._roi_filenames: _MaybeUnset[dict[int, str]] = _unset + self._job_params: _MaybeUnset[list[self._JobParams]] = _unset self._excluded_gt_info: _MaybeUnset[_ExcludedAnnotationsInfo] = _unset self._excluded_boxes_info: _MaybeUnset[_ExcludedAnnotationsInfo] = _unset @@ -1603,7 +1608,8 @@ def _validate_skeleton(skeleton: dm.Skeleton, *, sample_bbox: dm.Bbox): for element in skeleton.elements: # This is what Datumaro is expected to parse - assert len(element.points) == 2 and len(element.visibility) == 1 + assert len(element.points) == 2 + assert len(element.visibility) == 1 if element.visibility[0] == dm.Points.Visibility.absent: continue @@ -1806,7 +1812,7 @@ def _validate_boxes(self): self._validate_boxes_annotations() def _format_list( - self, items: Sequence[str], *, max_items: int = None, separator: str = ", " + self, items: Sequence[str], *, max_items: int | None = None, separator: str = ", " ) -> str: if max_items is None: max_items = self.list_display_threshold @@ -1849,11 +1855,11 @@ def _get_skeleton_bbox( def _prepare_gt(self): def _find_unambiguous_matches( - input_boxes: List[dm.Bbox], - gt_skeletons: List[dm.Skeleton], + input_boxes: list[dm.Bbox], + gt_skeletons: list[dm.Skeleton], *, - gt_annotations: List[dm.Annotation], - ) -> List[Tuple[dm.Bbox, dm.Skeleton]]: + gt_annotations: list[dm.Annotation], + ) -> list[tuple[dm.Bbox, dm.Skeleton]]: matches = [ [ (input_bbox.label == gt_skeleton.label) @@ -1871,7 +1877,7 @@ def _find_unambiguous_matches( ambiguous_boxes: list[int] = set() ambiguous_skeletons: list[int] = set() for bbox_idx, input_bbox in enumerate(input_boxes): - matched_skeletons: List[dm.Skeleton] = [ + matched_skeletons: list[dm.Skeleton] = [ gt_skeletons[j] for j in range(len(gt_skeletons)) if matches[bbox_idx][j] ] @@ -1894,7 +1900,7 @@ def _find_unambiguous_matches( continue for skeleton_idx, gt_skeleton in enumerate(gt_skeletons): - matched_boxes: List[dm.Bbox] = [ + matched_boxes: list[dm.Bbox] = [ input_boxes[i] for i in range(len(input_boxes)) if matches[i][skeleton_idx] ] @@ -1915,7 +1921,7 @@ def _find_unambiguous_matches( ambiguous_skeletons.add(gt_skeleton.id) ambiguous_boxes.update(b.id for b in matched_boxes) continue - elif not matched_boxes: + if not matched_boxes: # Handle unmatched skeletons excluded_gt_info.add_message( "Sample '{}': GT skeleton #{} ({}) skipped - " @@ -1930,7 +1936,7 @@ def _find_unambiguous_matches( excluded_gt_info.excluded_count += 1 # an error continue - unambiguous_matches: List[Tuple[dm.Bbox, dm.Skeleton]] = [] + unambiguous_matches: list[tuple[dm.Bbox, dm.Skeleton]] = [] for bbox_idx, input_bbox in enumerate(input_boxes): if input_bbox.id in ambiguous_boxes: continue @@ -1950,11 +1956,11 @@ def _find_unambiguous_matches( return unambiguous_matches def _find_good_gt_skeletons( - input_boxes: List[dm.Bbox], - gt_skeletons: List[dm.Skeleton], + input_boxes: list[dm.Bbox], + gt_skeletons: list[dm.Skeleton], *, - gt_annotations: List[dm.Annotation], - ) -> List[dm.Bbox]: + gt_annotations: list[dm.Annotation], + ) -> list[dm.Bbox]: matches = _find_unambiguous_matches( input_boxes, gt_skeletons, gt_annotations=gt_annotations ) @@ -2080,7 +2086,7 @@ def _prepare_roi_infos(self): assert self._gt_dataset is not _unset assert self._boxes_dataset is not _unset - rois: List[skeletons_from_boxes_task.RoiInfo] = [] + rois: list[skeletons_from_boxes_task.RoiInfo] = [] for sample in self._boxes_dataset: for bbox in sample.annotations: if not isinstance(bbox, dm.Bbox): @@ -2150,7 +2156,7 @@ def _prepare_job_params(self): gt_ratio = self.manifest.validation.val_size / (self.manifest.annotation.job_size or 1) job_size_mult = self.job_size_mult - job_params: List[self._JobParams] = [] + job_params: list[self._JobParams] = [] roi_info_by_id = {roi_info.bbox_id: roi_info for roi_info in self._roi_infos} for label_id, _ in enumerate(self.manifest.annotation.labels): @@ -2289,8 +2295,10 @@ def _extract_and_upload_rois(self): filename_to_sample = {sample.image.path: sample for sample in self._boxes_dataset} - _roi_info_key = lambda e: e.original_image_key - roi_info_by_image: Dict[str, Sequence[skeletons_from_boxes_task.RoiInfo]] = { + def _roi_info_key(e): + return e.original_image_key + + roi_info_by_image: dict[str, Sequence[skeletons_from_boxes_task.RoiInfo]] = { image_id_to_filename[image_id]: list(g) for image_id, g in groupby( sorted(self._roi_infos, key=_roi_info_key), key=_roi_info_key @@ -2339,7 +2347,9 @@ def _create_on_cvat(self): assert self._job_params is not _unset assert self.point_labels is not _unset - _job_params_label_key = lambda ts: ts.label_id + def _job_params_label_key(ts): + return ts.label_id + jobs_by_skeleton_label = { skeleton_label_id: list(g) for skeleton_label_id, g in groupby( @@ -2386,7 +2396,7 @@ def _create_on_cvat(self): # Each skeleton point uses the same file layout in jobs skeleton_label_filenames = [] for skeleton_label_job in skeleton_label_jobs: - skeleton_label_filenames.append( + skeleton_label_filenames.append( # noqa: PERF401 [ compose_data_bucket_filename( self.escrow_address, self.chain_id, self._roi_filenames[roi_id] @@ -2444,8 +2454,10 @@ def _create_on_cvat(self): ) db_service.get_task_by_id(session, task_id, for_update=True) # lock the row - # Actual task creation in CVAT takes some time, so it's done in an async process. - # The task is fully created once 'update:task' or 'update:job' webhook is received. + # Actual task creation in CVAT takes some time, + # so it's done in an async process. + # The task is fully created once 'update:task' or 'update:job' + # webhook is received. cvat_api.put_task_data( cvat_task.id, cvat_cloud_storage.id, @@ -2485,15 +2497,15 @@ def is_image(path: str) -> bool: return trunk and ext.lower() in IMAGE_EXTENSIONS -def filter_image_files(data_filenames: List[str]) -> List[str]: - return list(fn for fn in data_filenames if is_image(fn)) +def filter_image_files(data_filenames: list[str]) -> list[str]: + return [fn for fn in data_filenames if is_image(fn)] -def strip_bucket_prefix(data_filenames: List[str], prefix: str) -> List[str]: - return list(os.path.relpath(fn, prefix) for fn in data_filenames) +def strip_bucket_prefix(data_filenames: list[str], prefix: str) -> list[str]: + return [os.path.relpath(fn, prefix) for fn in data_filenames] -def make_label_configuration(manifest: TaskManifest) -> List[dict]: +def make_label_configuration(manifest: TaskManifest) -> list[dict]: return [ { "name": label.name, @@ -2503,7 +2515,7 @@ def make_label_configuration(manifest: TaskManifest) -> List[dict]: ] -def _make_cvat_cloud_storage_params(bucket_info: BucketAccessInfo) -> Dict: +def _make_cvat_cloud_storage_params(bucket_info: BucketAccessInfo) -> dict: CLOUD_PROVIDER_TO_CVAT_CLOUD_PROVIDER = { CloudProviders.aws: "AWS_S3_BUCKET", CloudProviders.gcs: "GOOGLE_CLOUD_STORAGE", diff --git a/packages/examples/cvat/exchange-oracle/src/handlers/job_export.py b/packages/examples/cvat/exchange-oracle/src/handlers/job_export.py index 8a1df6cac1..a175d58bbd 100644 --- a/packages/examples/cvat/exchange-oracle/src/handlers/job_export.py +++ b/packages/examples/cvat/exchange-oracle/src/handlers/job_export.py @@ -3,7 +3,6 @@ import zipfile from dataclasses import dataclass from tempfile import TemporaryDirectory -from typing import Dict, List, Optional, Type import datumaro as dm from datumaro.components.dataset import Dataset @@ -39,11 +38,11 @@ @dataclass class FileDescriptor: filename: str - file: Optional[io.RawIOBase] + file: io.RawIOBase | None def prepare_annotation_metafile( - jobs: List[Job], job_annotations: Dict[int, FileDescriptor] + jobs: list[Job], job_annotations: dict[int, FileDescriptor] ) -> FileDescriptor: """ Prepares a task/project annotation descriptor file with annotator mapping. @@ -69,12 +68,12 @@ def __init__( self, escrow_address: str, chain_id: int, - annotations: List[FileDescriptor], + annotations: list[FileDescriptor], merged_annotation: FileDescriptor, *, manifest: TaskManifest, - project_images: List[Image], - ): + project_images: list[Image], + ) -> None: self.escrow_address = escrow_address self.chain_id = chain_id self.annotation_files = annotations @@ -123,7 +122,7 @@ def _process_annotation_file( output_dataset = self._process_dataset(input_dataset, ann_descriptor=ann_descriptor) self._export_dataset(output_dataset, output_dir) - def _parse_dataset(self, ann_descriptor: FileDescriptor, dataset_dir: str) -> dm.Dataset: + def _parse_dataset(self, ann_descriptor: FileDescriptor, dataset_dir: str) -> dm.Dataset: # noqa: ARG002 return dm.Dataset.import_from(dataset_dir, self.input_format) def _export_dataset(self, dataset: dm.Dataset, output_dir: str): @@ -176,7 +175,7 @@ def _process_dataset(self, dataset: Dataset, *, ann_descriptor: FileDescriptor) class _BoxesFromPointsTaskProcessor(_TaskProcessor): - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) roi_filenames, roi_infos, points_dataset = self._download_task_meta() @@ -186,7 +185,7 @@ def __init__(self, *args, **kwargs): roi_info_by_id = {roi_info.point_id: roi_info for roi_info in roi_infos} - self.roi_name_to_roi_info: Dict[str, boxes_from_points_task.RoiInfo] = { + self.roi_name_to_roi_info: dict[str, boxes_from_points_task.RoiInfo] = { os.path.splitext(roi_filename)[0]: roi_info_by_id[roi_id] for roi_id, roi_filename in roi_filenames.items() } @@ -271,7 +270,7 @@ def _process_merged_dataset(self, input_dataset: Dataset) -> Dataset: class _SkeletonsFromBoxesTaskProcessor(_TaskProcessor): - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) roi_filenames, roi_infos, boxes_dataset, job_label_mapping = self._download_task_meta() @@ -281,7 +280,7 @@ def __init__(self, *args, **kwargs): roi_info_by_id = {roi_info.bbox_id: roi_info for roi_info in roi_infos} - self.roi_name_to_roi_info: Dict[str, skeletons_from_boxes_task.RoiInfo] = { + self.roi_name_to_roi_info: dict[str, skeletons_from_boxes_task.RoiInfo] = { os.path.splitext(roi_filename)[0]: roi_info_by_id[roi_id] for roi_id, roi_filename in roi_filenames.items() } @@ -503,7 +502,7 @@ def _process_dataset(self, dataset: Dataset, *, ann_descriptor: FileDescriptor) skeleton_bbox.group = skeleton_group skeleton_bbox.label = converted_skeleton.label converted_job_dataset.put( - converted_sample.wrap(annotations=converted_sample.annotations + [skeleton_bbox]) + converted_sample.wrap(annotations=[*converted_sample.annotations, skeleton_bbox]) ) # Rename the job skeleton and point to the original names @@ -572,16 +571,16 @@ def process(self): def postprocess_annotations( escrow_address: str, chain_id: int, - annotations: List[FileDescriptor], + annotations: list[FileDescriptor], merged_annotation: FileDescriptor, *, manifest: TaskManifest, - project_images: List[Image], + project_images: list[Image], ) -> None: """ Processes annotations and updates the files list inplace """ - processor_classes: Dict[TaskTypes, Type[_TaskProcessor]] = { + processor_classes: dict[TaskTypes, type[_TaskProcessor]] = { TaskTypes.image_label_binary: _LabelsTaskProcessor, TaskTypes.image_boxes: _BoxesTaskProcessor, TaskTypes.image_points: _PointsTaskProcessor, diff --git a/packages/examples/cvat/exchange-oracle/src/log.py b/packages/examples/cvat/exchange-oracle/src/log.py index 5f5ab8983a..09349ef6f4 100644 --- a/packages/examples/cvat/exchange-oracle/src/log.py +++ b/packages/examples/cvat/exchange-oracle/src/log.py @@ -1,4 +1,5 @@ -""" Config for the application logger""" +"""Config for the application logger""" + import logging from logging.config import dictConfig diff --git a/packages/examples/cvat/exchange-oracle/src/models/__init__.py b/packages/examples/cvat/exchange-oracle/src/models/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/packages/examples/cvat/exchange-oracle/src/models/cvat.py b/packages/examples/cvat/exchange-oracle/src/models/cvat.py index 71678b520d..ad4c6769bf 100644 --- a/packages/examples/cvat/exchange-oracle/src/models/cvat.py +++ b/packages/examples/cvat/exchange-oracle/src/models/cvat.py @@ -1,8 +1,6 @@ # pylint: disable=too-few-public-methods from __future__ import annotations -from typing import List, Optional - from sqlalchemy import Column, DateTime, Enum, ForeignKey, Integer, String, UniqueConstraint from sqlalchemy.orm import Mapped, relationship from sqlalchemy.sql import func @@ -35,23 +33,23 @@ class Project(Base): updated_at = Column(DateTime(timezone=True), onupdate=func.now()) cvat_webhook_id = Column(Integer, nullable=True) - images: Mapped[List["Image"]] = relationship( + images: Mapped[list[Image]] = relationship( back_populates="project", cascade="all, delete", passive_deletes=True ) - tasks: Mapped[List["Task"]] = relationship( + tasks: Mapped[list[Task]] = relationship( back_populates="project", cascade="all, delete", passive_deletes=True, ) - jobs: Mapped[List["Job"]] = relationship( + jobs: Mapped[list[Job]] = relationship( back_populates="project", cascade="all, delete", passive_deletes=True, ) - escrow_creation: Mapped["EscrowCreation"] = relationship( + escrow_creation: Mapped[EscrowCreation] = relationship( back_populates="projects", passive_deletes=True, # A custom join is used because the foreign keys do not actually reference any objects @@ -64,7 +62,7 @@ class Project(Base): foreign_keys=[escrow_address, chain_id], ) - def __repr__(self): + def __repr__(self) -> str: return f"Project. id={self.id}" @@ -81,17 +79,17 @@ class Task(Base): created_at = Column(DateTime(timezone=True), server_default=func.now()) updated_at = Column(DateTime(timezone=True), onupdate=func.now()) - project: Mapped["Project"] = relationship(back_populates="tasks") - jobs: Mapped[List["Job"]] = relationship( + project: Mapped[Project] = relationship(back_populates="tasks") + jobs: Mapped[list[Job]] = relationship( back_populates="task", cascade="all, delete", passive_deletes=True, ) - data_upload: Mapped["DataUpload"] = relationship( + data_upload: Mapped[DataUpload] = relationship( back_populates="task", cascade="all, delete", passive_deletes=True ) - def __repr__(self): + def __repr__(self) -> str: return f"Task. id={self.id}" @@ -108,7 +106,7 @@ class EscrowCreation(Base): total_jobs = Column(Integer, nullable=False) - projects: Mapped[List["Project"]] = relationship( + projects: Mapped[list[Project]] = relationship( back_populates="escrow_creation", # A custom join is used because the foreign keys do not actually reference any objects primaryjoin=( @@ -120,7 +118,7 @@ class EscrowCreation(Base): foreign_keys=[Project.escrow_address, Project.chain_id], ) - def __repr__(self): + def __repr__(self) -> str: return f"EscrowCreation. id={self.id} escrow={self.escrow_address}" @@ -135,9 +133,9 @@ class DataUpload(Base): nullable=False, ) - task: Mapped["Task"] = relationship(back_populates="data_upload") + task: Mapped[Task] = relationship(back_populates="data_upload") - def __repr__(self): + def __repr__(self) -> str: return f"DataUpload. id={self.id} task={self.task_id}" @@ -155,9 +153,9 @@ class Job(Base): created_at = Column(DateTime(timezone=True), server_default=func.now()) updated_at = Column(DateTime(timezone=True), onupdate=func.now()) - task: Mapped["Task"] = relationship(back_populates="jobs") - project: Mapped["Project"] = relationship(back_populates="jobs") - assignments: Mapped[List["Assignment"]] = relationship( + task: Mapped[Task] = relationship(back_populates="jobs") + project: Mapped[Project] = relationship(back_populates="jobs") + assignments: Mapped[list[Assignment]] = relationship( back_populates="job", cascade="all, delete", passive_deletes=True, @@ -165,11 +163,11 @@ class Job(Base): ) @property - def latest_assignment(self) -> Optional[Assignment]: + def latest_assignment(self) -> Assignment | None: assignments = self.assignments return assignments[0] if assignments else None - def __repr__(self): + def __repr__(self) -> str: return f"Job. id={self.id}" @@ -179,11 +177,11 @@ class User(Base): cvat_email = Column(String, unique=True, index=True, nullable=True) cvat_id = Column(Integer, unique=True, index=True, nullable=True) - assignments: Mapped[List["Assignment"]] = relationship( + assignments: Mapped[list[Assignment]] = relationship( back_populates="user", cascade="all, delete", passive_deletes=True ) - def __repr__(self): + def __repr__(self) -> str: return f"User. wallet_address={self.wallet_address} cvat_id={self.cvat_id}" @@ -206,8 +204,8 @@ class Assignment(Base): nullable=False, ) - user: Mapped["User"] = relationship(back_populates="assignments") - job: Mapped["Job"] = relationship(back_populates="assignments") + user: Mapped[User] = relationship(back_populates="assignments") + job: Mapped[Job] = relationship(back_populates="assignments") @property def is_finished(self) -> bool: @@ -217,7 +215,7 @@ def is_finished(self) -> bool: or self.status != AssignmentStatuses.created ) - def __repr__(self): + def __repr__(self) -> str: return f"Assignment. id={self.id} user={self.user.cvat_id} job={self.job.cvat_id}" @@ -231,11 +229,11 @@ class Image(Base): ) filename = Column(String, nullable=False) - project: Mapped["Project"] = relationship(back_populates="images") + project: Mapped[Project] = relationship(back_populates="images") __table_args__ = (UniqueConstraint("cvat_project_id", "filename", name="_project_filename_uc"),) - def __repr__(self): + def __repr__(self) -> str: return ( f"Image. id={self.id} cvat_project_id={self.cvat_project_id} filename={self.filename}" ) diff --git a/packages/examples/cvat/exchange-oracle/src/models/webhook.py b/packages/examples/cvat/exchange-oracle/src/models/webhook.py index 10c4bec69c..64ca2fb0e7 100644 --- a/packages/examples/cvat/exchange-oracle/src/models/webhook.py +++ b/packages/examples/cvat/exchange-oracle/src/models/webhook.py @@ -26,5 +26,5 @@ class Webhook(Base): event_data = Column(JSON, nullable=True, server_default=None) direction = Column(String, nullable=False) - def __repr__(self): + def __repr__(self) -> str: return f"Webhook. id={self.id} type={self.type}.{self.event_type}" diff --git a/packages/examples/cvat/exchange-oracle/src/schemas/__init__.py b/packages/examples/cvat/exchange-oracle/src/schemas/__init__.py index d3d08c0dd9..e26cbee23e 100644 --- a/packages/examples/cvat/exchange-oracle/src/schemas/__init__.py +++ b/packages/examples/cvat/exchange-oracle/src/schemas/__init__.py @@ -1,8 +1,4 @@ -# pylint: disable=too-few-public-methods - -""" Schema for API input&output""" - -from typing import List, Optional +"""Schema for API input&output""" from pydantic import BaseModel @@ -28,7 +24,7 @@ class ResponseError(BaseModel): class SupportedNetwork(BaseModel): chain_id: int - addr: Optional[str] + addr: str | None class MetaResponse(BaseModel): @@ -36,4 +32,4 @@ class MetaResponse(BaseModel): message: str version: str - supported_networks: List[SupportedNetwork] + supported_networks: list[SupportedNetwork] diff --git a/packages/examples/cvat/exchange-oracle/src/schemas/cvat.py b/packages/examples/cvat/exchange-oracle/src/schemas/cvat.py index e7b7252802..36cc946948 100644 --- a/packages/examples/cvat/exchange-oracle/src/schemas/cvat.py +++ b/packages/examples/cvat/exchange-oracle/src/schemas/cvat.py @@ -1,10 +1,8 @@ -from typing import Optional - from pydantic import BaseModel class CvatWebhook(BaseModel): event: str - job: Optional[dict] - task: Optional[dict] - before_update: Optional[dict] + job: dict | None + task: dict | None + before_update: dict | None diff --git a/packages/examples/cvat/exchange-oracle/src/schemas/exchange.py b/packages/examples/cvat/exchange-oracle/src/schemas/exchange.py index db156e7f8c..12047ec05f 100644 --- a/packages/examples/cvat/exchange-oracle/src/schemas/exchange.py +++ b/packages/examples/cvat/exchange-oracle/src/schemas/exchange.py @@ -1,5 +1,4 @@ from datetime import datetime -from typing import Optional from pydantic import AnyUrl, BaseModel, Field @@ -22,7 +21,7 @@ class TaskResponse(BaseModel): job_size: int job_time_limit: int job_type: TaskTypes - assignment: Optional[AssignmentResponse] = None + assignment: AssignmentResponse | None = None status: ProjectStatuses diff --git a/packages/examples/cvat/exchange-oracle/src/schemas/webhook.py b/packages/examples/cvat/exchange-oracle/src/schemas/webhook.py index 5893d6ee42..7e8a169c71 100644 --- a/packages/examples/cvat/exchange-oracle/src/schemas/webhook.py +++ b/packages/examples/cvat/exchange-oracle/src/schemas/webhook.py @@ -1,5 +1,4 @@ from datetime import datetime -from typing import Optional from pydantic import BaseModel, validator @@ -11,8 +10,8 @@ class OracleWebhook(BaseModel): escrow_address: str chain_id: Networks event_type: str - event_data: Optional[dict] = None - timestamp: Optional[datetime] = None # TODO: remove optional + event_data: dict | None = None + timestamp: datetime | None = None # TODO: remove optional @validator("escrow_address", allow_reuse=True) def validate_escrow_(cls, value): diff --git a/packages/examples/cvat/exchange-oracle/src/services/__init__.py b/packages/examples/cvat/exchange-oracle/src/services/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/packages/examples/cvat/exchange-oracle/src/services/cloud/client.py b/packages/examples/cvat/exchange-oracle/src/services/cloud/client.py index 5bb92d77e3..53cf13dcc0 100644 --- a/packages/examples/cvat/exchange-oracle/src/services/cloud/client.py +++ b/packages/examples/cvat/exchange-oracle/src/services/cloud/client.py @@ -1,37 +1,29 @@ from abc import ABCMeta, abstractmethod -from typing import List, Optional from urllib.parse import unquote class StorageClient(metaclass=ABCMeta): def __init__( self, - bucket: Optional[str] = None, + bucket: str | None = None, ) -> None: self._bucket = unquote(bucket) if bucket else None @abstractmethod - def create_file(self, key: str, data: bytes = b"", *, bucket: Optional[str] = None): - ... + def create_file(self, key: str, data: bytes = b"", *, bucket: str | None = None): ... @abstractmethod - def remove_file(self, key: str, *, bucket: Optional[str] = None): - ... + def remove_file(self, key: str, *, bucket: str | None = None): ... @abstractmethod - def file_exists(self, key: str, *, bucket: Optional[str] = None) -> bool: - ... + def file_exists(self, key: str, *, bucket: str | None = None) -> bool: ... @abstractmethod - def download_file(self, key: str, *, bucket: Optional[str] = None) -> bytes: - ... + def download_file(self, key: str, *, bucket: str | None = None) -> bytes: ... @abstractmethod - def list_files( - self, *, bucket: Optional[str] = None, prefix: Optional[str] = None - ) -> List[str]: - ... + def list_files(self, *, bucket: str | None = None, prefix: str | None = None) -> list[str]: ... @staticmethod - def normalize_prefix(prefix: Optional[str]) -> Optional[str]: + def normalize_prefix(prefix: str | None) -> str | None: return unquote(prefix).strip("/\\") + "/" if prefix else prefix diff --git a/packages/examples/cvat/exchange-oracle/src/services/cloud/gcs.py b/packages/examples/cvat/exchange-oracle/src/services/cloud/gcs.py index 36611b363f..014be4ce1a 100644 --- a/packages/examples/cvat/exchange-oracle/src/services/cloud/gcs.py +++ b/packages/examples/cvat/exchange-oracle/src/services/cloud/gcs.py @@ -1,5 +1,4 @@ from io import BytesIO -from typing import Dict, List, Optional from urllib.parse import unquote from google.cloud import storage @@ -13,8 +12,8 @@ class GcsClient(StorageClient): def __init__( self, *, - bucket: Optional[str] = None, - service_account_key: Optional[Dict] = None, + bucket: str | None = None, + service_account_key: dict | None = None, ) -> None: super().__init__(bucket) @@ -23,22 +22,22 @@ def __init__( else: self.client = storage.Client.create_anonymous_client() - def create_file(self, key: str, data: bytes = b"", *, bucket: Optional[str] = None) -> None: + def create_file(self, key: str, data: bytes = b"", *, bucket: str | None = None) -> None: bucket = unquote(bucket) if bucket else self._bucket bucket_client = self.client.get_bucket(bucket) bucket_client.blob(unquote(key)).upload_from_string(data) - def remove_file(self, key: str, *, bucket: Optional[str] = None) -> None: + def remove_file(self, key: str, *, bucket: str | None = None) -> None: bucket = unquote(bucket) if bucket else self._bucket bucket_client = self.client.get_bucket(bucket) bucket_client.delete_blob(unquote(key)) - def file_exists(self, key: str, *, bucket: Optional[str] = None) -> bool: + def file_exists(self, key: str, *, bucket: str | None = None) -> bool: bucket = unquote(bucket) if bucket else self._bucket bucket_client = self.client.get_bucket(bucket) return bucket_client.blob(unquote(key)).exists() - def download_file(self, key: str, *, bucket: Optional[str] = None) -> bytes: + def download_file(self, key: str, *, bucket: str | None = None) -> bytes: bucket = unquote(bucket) if bucket else self._bucket bucket_client = self.client.get_bucket(bucket) blob = bucket_client.blob(unquote(key)) @@ -47,9 +46,7 @@ def download_file(self, key: str, *, bucket: Optional[str] = None) -> bytes: self.client.download_blob_to_file(blob, data) return data.getvalue() - def list_files( - self, *, bucket: Optional[str] = None, prefix: Optional[str] = None - ) -> List[str]: + def list_files(self, *, bucket: str | None = None, prefix: str | None = None) -> list[str]: bucket = unquote(bucket) if bucket else self._bucket prefix = self.normalize_prefix(prefix) diff --git a/packages/examples/cvat/exchange-oracle/src/services/cloud/s3.py b/packages/examples/cvat/exchange-oracle/src/services/cloud/s3.py index e8e608ce99..ccb5557e7d 100644 --- a/packages/examples/cvat/exchange-oracle/src/services/cloud/s3.py +++ b/packages/examples/cvat/exchange-oracle/src/services/cloud/s3.py @@ -1,5 +1,4 @@ from io import BytesIO -from typing import List, Optional from urllib.parse import unquote import boto3 @@ -15,15 +14,15 @@ class S3Client(StorageClient): def __init__( self, *, - bucket: Optional[str] = None, - access_key: Optional[str] = None, - secret_key: Optional[str] = None, - endpoint_url: Optional[str] = None, + bucket: str | None = None, + access_key: str | None = None, + secret_key: str | None = None, + endpoint_url: str | None = None, ) -> None: super().__init__(bucket) session = boto3.Session( - **(dict(aws_access_key_id=access_key) if access_key else {}), - **(dict(aws_secret_access_key=secret_key) if secret_key else {}), + **({"aws_access_key_id": access_key} if access_key else {}), + **({"aws_secret_access_key": secret_key} if secret_key else {}), ) s3 = session.resource( "s3", **({"endpoint_url": unquote(endpoint_url)} if endpoint_url else {}) @@ -34,15 +33,15 @@ def __init__( if not access_key and not secret_key: self.client.meta.events.register("choose-signer.s3.*", disable_signing) - def create_file(self, key: str, data: bytes = b"", *, bucket: Optional[str] = None): + def create_file(self, key: str, data: bytes = b"", *, bucket: str | None = None): bucket = unquote(bucket) if bucket else self._bucket self.client.put_object(Body=data, Bucket=bucket, Key=unquote(key)) - def remove_file(self, key: str, *, bucket: Optional[str] = None): + def remove_file(self, key: str, *, bucket: str | None = None): bucket = unquote(bucket) if bucket else self._bucket self.client.delete_object(Bucket=bucket, Key=unquote(key)) - def file_exists(self, key: str, *, bucket: Optional[str] = None) -> bool: + def file_exists(self, key: str, *, bucket: str | None = None) -> bool: bucket = unquote(bucket) if bucket else self._bucket try: self.client.head_object(Bucket=bucket, Key=unquote(key)) @@ -53,19 +52,14 @@ def file_exists(self, key: str, *, bucket: Optional[str] = None) -> bool: else: raise - def download_file(self, key: str, *, bucket: Optional[str] = None) -> bytes: + def download_file(self, key: str, *, bucket: str | None = None) -> bytes: bucket = unquote(bucket) if bucket else self._bucket with BytesIO() as data: self.client.download_fileobj(Bucket=bucket, Key=unquote(key), Fileobj=data) return data.getvalue() - def list_files( - self, *, bucket: Optional[str] = None, prefix: Optional[str] = None - ) -> List[str]: + def list_files(self, *, bucket: str | None = None, prefix: str | None = None) -> list[str]: bucket = unquote(bucket) if bucket else self._bucket objects = self.resource.Bucket(bucket).objects - if prefix: - objects = objects.filter(Prefix=self.normalize_prefix(prefix)) - else: - objects = objects.all() + objects = objects.filter(Prefix=self.normalize_prefix(prefix)) if prefix else objects.all() return [file_info.key for file_info in objects] diff --git a/packages/examples/cvat/exchange-oracle/src/services/cloud/types.py b/packages/examples/cvat/exchange-oracle/src/services/cloud/types.py index 5f5b0db3ad..45e0dbffec 100644 --- a/packages/examples/cvat/exchange-oracle/src/services/cloud/types.py +++ b/packages/examples/cvat/exchange-oracle/src/services/cloud/types.py @@ -4,7 +4,6 @@ from dataclasses import asdict, dataclass, is_dataclass from enum import Enum, auto from inspect import isclass -from typing import Dict, Optional, Type, Union from urllib.parse import urlparse from src.core import manifest @@ -31,14 +30,14 @@ def from_str(cls, provider: str) -> CloudProviders: class BucketCredentials: - def to_dict(self) -> Dict: + def to_dict(self) -> dict: if not is_dataclass(self): raise NotImplementedError return asdict(self) @classmethod - def from_storage_config(cls, config: Type[StorageConfig]) -> Optional[BucketCredentials]: + def from_storage_config(cls, config: type[StorageConfig]) -> BucketCredentials | None: credentials = None if (config.access_key or config.secret_key) and config.provider.lower() != "aws": @@ -46,9 +45,8 @@ def from_storage_config(cls, config: Type[StorageConfig]) -> Optional[BucketCred "Invalid storage configuration. The access_key/secret_key pair" f"cannot be specified with {config.provider} provider" ) - elif ( - bool(config.access_key) ^ bool(config.secret_key) - ) and config.provider.lower() == "aws": + + if (bool(config.access_key) ^ bool(config.secret_key)) and config.provider.lower() == "aws": raise ValueError( "Invalid storage configuration. " "Either none or both access_key and secret_key must be specified for an AWS storage" @@ -71,7 +69,7 @@ def from_storage_config(cls, config: Type[StorageConfig]) -> Optional[BucketCred @dataclass class GcsBucketCredentials(BucketCredentials): - service_account_key: Dict + service_account_key: dict @dataclass @@ -85,8 +83,8 @@ class BucketAccessInfo: provider: CloudProviders host_url: str bucket_name: str - path: Optional[str] = None - credentials: Optional[BucketCredentials] = None + path: str | None = None + credentials: BucketCredentials | None = None @classmethod def from_url(cls, url: str) -> BucketAccessInfo: @@ -129,7 +127,7 @@ def from_url(cls, url: str) -> BucketAccessInfo: raise ValueError(f"{parsed_url.netloc} cloud provider is not supported.") @classmethod - def _from_dict(cls, data: Dict) -> BucketAccessInfo: + def _from_dict(cls, data: dict) -> BucketAccessInfo: for required_field in ( "provider", "bucket_name", @@ -159,7 +157,7 @@ def _from_dict(cls, data: Dict) -> BucketAccessInfo: return BucketAccessInfo(**data) @classmethod - def from_storage_config(cls, config: Type[StorageConfig]) -> BucketAccessInfo: + def from_storage_config(cls, config: type[StorageConfig]) -> BucketAccessInfo: credentials = BucketCredentials.from_storage_config(config) return BucketAccessInfo( @@ -174,9 +172,7 @@ def from_bucket_url(cls, bucket_url: manifest.BucketUrl) -> BucketAccessInfo: return cls._from_dict(bucket_url.dict()) @classmethod - def parse_obj( - cls, data: Union[str, Type[StorageConfig], manifest.BucketUrl] - ) -> BucketAccessInfo: + def parse_obj(cls, data: str | type[StorageConfig] | manifest.BucketUrl) -> BucketAccessInfo: if isinstance(data, manifest.BucketUrlBase): return cls.from_bucket_url(data) elif isinstance(data, str): diff --git a/packages/examples/cvat/exchange-oracle/src/services/cloud/utils.py b/packages/examples/cvat/exchange-oracle/src/services/cloud/utils.py index a9f821d174..bfc23305c7 100644 --- a/packages/examples/cvat/exchange-oracle/src/services/cloud/utils.py +++ b/packages/examples/cvat/exchange-oracle/src/services/cloud/utils.py @@ -1,5 +1,3 @@ -from typing import Optional - from src.services.cloud.client import StorageClient from src.services.cloud.gcs import DEFAULT_GCS_HOST, GcsClient from src.services.cloud.s3 import DEFAULT_S3_HOST, S3Client @@ -7,7 +5,7 @@ def compose_bucket_url( - bucket_name: str, provider: CloudProviders, *, bucket_host: Optional[str] = None + bucket_name: str, provider: CloudProviders, *, bucket_host: str | None = None ) -> str: match provider: case CloudProviders.aws: diff --git a/packages/examples/cvat/exchange-oracle/src/services/cvat.py b/packages/examples/cvat/exchange-oracle/src/services/cvat.py index a57791e8ef..ccd6fd9ddd 100644 --- a/packages/examples/cvat/exchange-oracle/src/services/cvat.py +++ b/packages/examples/cvat/exchange-oracle/src/services/cvat.py @@ -1,7 +1,7 @@ import itertools import uuid +from collections.abc import Sequence from datetime import datetime -from typing import List, Optional, Sequence, Union from sqlalchemy import delete, insert, update from sqlalchemy.orm import Session @@ -22,7 +22,7 @@ def create_project( escrow_address: str, chain_id: int, bucket_url: str, - cvat_webhook_id: Optional[int] = None, + cvat_webhook_id: int | None = None, status: ProjectStatuses = ProjectStatuses.creation, ) -> str: """ @@ -50,13 +50,10 @@ def get_project_by_id( session: Session, project_id: str, *, - for_update: Union[bool, ForUpdateParams] = False, - status_in: Optional[List[ProjectStatuses]] = None, -) -> Optional[Project]: - if status_in: - status_filter_arg = [Project.status.in_(s.value for s in status_in)] - else: - status_filter_arg = [] + for_update: bool | ForUpdateParams = False, + status_in: list[ProjectStatuses] | None = None, +) -> Project | None: + status_filter_arg = [Project.status.in_(s.value for s in status_in)] if status_in else [] return ( _maybe_for_update(session.query(Project), enable=for_update) @@ -69,14 +66,11 @@ def get_projects_by_cvat_ids( session: Session, project_cvat_ids: Sequence[int], *, - for_update: Union[bool, ForUpdateParams] = False, - status_in: Optional[List[ProjectStatuses]] = None, + for_update: bool | ForUpdateParams = False, + status_in: list[ProjectStatuses] | None = None, limit: int = 5, -) -> List[Project]: - if status_in: - status_filter_arg = [Project.status.in_(s.value for s in status_in)] - else: - status_filter_arg = [] +) -> list[Project]: + status_filter_arg = [Project.status.in_(s.value for s in status_in)] if status_in else [] return ( _maybe_for_update(session.query(Project), enable=for_update) @@ -87,8 +81,8 @@ def get_projects_by_cvat_ids( def get_project_by_escrow_address( - session: Session, escrow_address: str, *, for_update: Union[bool, ForUpdateParams] = False -) -> Optional[Project]: + session: Session, escrow_address: str, *, for_update: bool | ForUpdateParams = False +) -> Project | None: return ( _maybe_for_update(session.query(Project), enable=for_update) .where(Project.escrow_address == escrow_address) @@ -100,9 +94,9 @@ def get_projects_by_escrow_address( session: Session, escrow_address: str, *, - for_update: Union[bool, ForUpdateParams] = False, - limit: Optional[int] = 5, -) -> List[Project]: + for_update: bool | ForUpdateParams = False, + limit: int | None = 5, +) -> list[Project]: projects = _maybe_for_update(session.query(Project), enable=for_update).where( Project.escrow_address == escrow_address ) @@ -116,7 +110,7 @@ def get_projects_by_escrow_address( def get_project_cvat_ids_by_escrow_address( session: Session, escrow_address: str, -) -> List[int]: +) -> list[int]: projects = session.query(Project).where(Project.escrow_address == escrow_address) return list(itertools.chain.from_iterable(projects.values(Project.cvat_id))) @@ -126,11 +120,11 @@ def get_projects_by_status( session: Session, status: ProjectStatuses, *, - included_types: Optional[Sequence[TaskTypes]] = None, - task_status: Optional[TaskStatuses] = None, + included_types: Sequence[TaskTypes] | None = None, + task_status: TaskStatuses | None = None, limit: int = 5, - for_update: Union[bool, ForUpdateParams] = False, -) -> List[Project]: + for_update: bool | ForUpdateParams = False, +) -> list[Project]: projects = _maybe_for_update(session.query(Project), enable=for_update).where( Project.status == status.value ) @@ -141,18 +135,16 @@ def get_projects_by_status( if included_types is not None: projects = projects.where(Project.job_type.in_([t.value for t in included_types])) - projects = projects.limit(limit).all() - - return projects + return projects.limit(limit).all() def get_escrows_by_project_status( session: Session, project_status: ProjectStatuses, *, - included_types: Optional[Sequence[TaskTypes]] = None, + included_types: Sequence[TaskTypes] | None = None, limit: int = 5, -) -> List[tuple[str, int]]: +) -> list[tuple[str, int]]: escrows = ( session.query(Project.escrow_address, Project.chain_id) .group_by(Project.escrow_address, Project.chain_id) @@ -162,12 +154,10 @@ def get_escrows_by_project_status( if included_types: escrows = escrows.where(Project.job_type.in_([t.value for t in included_types])) - escrows = escrows.limit(limit).all() - - return escrows + return escrows.limit(limit).all() -def get_available_projects(session: Session, *, limit: int = 10) -> List[Project]: +def get_available_projects(session: Session, *, limit: int = 10) -> list[Project]: return ( session.query(Project) .where( @@ -185,11 +175,11 @@ def get_available_projects(session: Session, *, limit: int = 10) -> List[Project def get_projects_by_assignee( session: Session, - wallet_address: Optional[str] = None, + wallet_address: str | None = None, *, limit: int = 10, - for_update: Union[bool, ForUpdateParams] = False, -) -> List[Project]: + for_update: bool | ForUpdateParams = False, +) -> list[Project]: return ( _maybe_for_update(session.query(Project), enable=for_update) .where( @@ -237,10 +227,7 @@ def delete_project(session: Session, project_id: str) -> None: def is_project_completed(session: Session, project_id: str) -> bool: project = get_project_by_id(session, project_id) jobs = get_jobs_by_cvat_project_id(session, project.cvat_id) - if len(jobs) > 0 and all(job.status == JobStatuses.completed.value for job in jobs): - return True - else: - return False + return bool(len(jobs) > 0 and all(job.status == JobStatuses.completed.value for job in jobs)) # EscrowCreation @@ -271,8 +258,8 @@ def get_escrow_creation_by_id( session: Session, escrow_creation_id: str, *, - for_update: Union[bool, ForUpdateParams] = False, -) -> Optional[EscrowCreation]: + for_update: bool | ForUpdateParams = False, +) -> EscrowCreation | None: return ( _maybe_for_update(session.query(EscrowCreation), enable=for_update) .where(EscrowCreation.id == escrow_creation_id, EscrowCreation.finished_at.is_(None)) @@ -285,8 +272,8 @@ def get_escrow_creation_by_escrow_address( escrow_address: str, chain_id: int, *, - for_update: Union[bool, ForUpdateParams] = False, -) -> Optional[EscrowCreation]: + for_update: bool | ForUpdateParams = False, +) -> EscrowCreation | None: return ( _maybe_for_update(session.query(EscrowCreation), enable=for_update) .where( @@ -299,8 +286,8 @@ def get_escrow_creation_by_escrow_address( def get_active_escrow_creations( - session: Session, *, limit: int = 10, for_update: Union[bool, ForUpdateParams] = False -) -> List[EscrowCreation]: + session: Session, *, limit: int = 10, for_update: bool | ForUpdateParams = False +) -> list[EscrowCreation]: return ( _maybe_for_update(session.query(EscrowCreation), enable=for_update) .where(EscrowCreation.finished_at.is_(None)) @@ -309,7 +296,7 @@ def get_active_escrow_creations( ) -def finish_escrow_creations(session: Session, escrow_creations: List[EscrowCreation]) -> None: +def finish_escrow_creations(session: Session, escrow_creations: list[EscrowCreation]) -> None: statement = ( update(EscrowCreation) .where(EscrowCreation.id.in_(c.id for c in escrow_creations)) @@ -348,16 +335,16 @@ def create_task(session: Session, cvat_id: int, cvat_project_id: int, status: Ta def get_task_by_id( - session: Session, task_id: str, *, for_update: Union[bool, ForUpdateParams] = False -) -> Optional[Task]: + session: Session, task_id: str, *, for_update: bool | ForUpdateParams = False +) -> Task | None: return ( _maybe_for_update(session.query(Task), enable=for_update).where(Task.id == task_id).first() ) def get_tasks_by_cvat_id( - session: Session, task_ids: List[int], *, for_update: Union[bool, ForUpdateParams] = False -) -> List[Task]: + session: Session, task_ids: list[int], *, for_update: bool | ForUpdateParams = False +) -> list[Task]: return ( _maybe_for_update(session.query(Task), enable=for_update) .where(Task.cvat_id.in_(task_ids)) @@ -369,11 +356,11 @@ def get_tasks_by_status( session: Session, status: TaskStatuses, *, - job_status: Optional[JobStatuses] = None, - project_status: Optional[ProjectStatuses] = None, - for_update: Union[bool, ForUpdateParams] = False, - limit: Optional[int] = 20, -) -> List[Task]: + job_status: JobStatuses | None = None, + project_status: ProjectStatuses | None = None, + for_update: bool | ForUpdateParams = False, + limit: int | None = 20, +) -> list[Task]: query = _maybe_for_update(session.query(Task), enable=for_update).where( Task.status == status.value ) @@ -396,8 +383,8 @@ def update_task_status(session: Session, task_id: int, status: TaskStatuses) -> def get_tasks_by_cvat_project_id( - session: Session, cvat_project_id: int, *, for_update: Union[bool, ForUpdateParams] = False -) -> List[Task]: + session: Session, cvat_project_id: int, *, for_update: bool | ForUpdateParams = False +) -> list[Task]: return ( _maybe_for_update(session.query(Task), enable=for_update) .where(Task.cvat_project_id == cvat_project_id) @@ -421,8 +408,8 @@ def create_data_upload( def get_active_task_uploads_by_task_id( - session: Session, task_ids: List[int], *, for_update: Union[bool, ForUpdateParams] = False -) -> List[DataUpload]: + session: Session, task_ids: list[int], *, for_update: bool | ForUpdateParams = False +) -> list[DataUpload]: return ( _maybe_for_update(session.query(DataUpload), enable=for_update) .where(DataUpload.task_id.in_(task_ids)) @@ -431,8 +418,8 @@ def get_active_task_uploads_by_task_id( def get_active_task_uploads( - session: Session, *, limit: int = 10, for_update: Union[bool, ForUpdateParams] = False -) -> List[DataUpload]: + session: Session, *, limit: int = 10, for_update: bool | ForUpdateParams = False +) -> list[DataUpload]: return _maybe_for_update(session.query(DataUpload), enable=for_update).limit(limit).all() @@ -469,14 +456,14 @@ def create_job( def get_job_by_id( - session: Session, job_id: str, *, for_update: Union[bool, ForUpdateParams] = False -) -> Optional[Job]: + session: Session, job_id: str, *, for_update: bool | ForUpdateParams = False +) -> Job | None: return _maybe_for_update(session.query(Job), enable=for_update).where(Job.id == job_id).first() def get_jobs_by_cvat_id( - session: Session, cvat_ids: List[int], *, for_update: Union[bool, ForUpdateParams] = False -) -> List[Job]: + session: Session, cvat_ids: list[int], *, for_update: bool | ForUpdateParams = False +) -> list[Job]: return ( _maybe_for_update(session.query(Job), enable=for_update) .where(Job.cvat_id.in_(cvat_ids)) @@ -490,8 +477,8 @@ def update_job_status(session: Session, job_id: int, status: JobStatuses) -> Non def get_jobs_by_cvat_task_id( - session: Session, cvat_task_id: int, *, for_update: Union[bool, ForUpdateParams] = False -) -> List[Job]: + session: Session, cvat_task_id: int, *, for_update: bool | ForUpdateParams = False +) -> list[Job]: return ( _maybe_for_update(session.query(Job), enable=for_update) .where(Job.cvat_task_id == cvat_task_id) @@ -500,8 +487,8 @@ def get_jobs_by_cvat_task_id( def get_jobs_by_cvat_project_id( - session: Session, cvat_project_id: int, *, for_update: Union[bool, ForUpdateParams] = False -) -> List[Job]: + session: Session, cvat_project_id: int, *, for_update: bool | ForUpdateParams = False +) -> list[Job]: return ( _maybe_for_update(session.query(Job), enable=for_update) .where(Job.cvat_project_id == cvat_project_id) @@ -526,10 +513,10 @@ def count_jobs_by_escrow_address( def get_free_job( session: Session, - cvat_projects: List[int], + cvat_projects: list[int], *, - for_update: Union[bool, ForUpdateParams] = False, -) -> Optional[Job]: + for_update: bool | ForUpdateParams = False, +) -> Job | None: return ( _maybe_for_update(session.query(Job), enable=for_update) .where( @@ -564,8 +551,8 @@ def put_user(session: Session, wallet_address: str, cvat_email: str, cvat_id: in def get_user_by_id( - session: Session, wallet_address: str, *, for_update: Union[bool, ForUpdateParams] = False -) -> Optional[User]: + session: Session, wallet_address: str, *, for_update: bool | ForUpdateParams = False +) -> User | None: return ( _maybe_for_update(session.query(User), enable=for_update) .where(User.wallet_address == wallet_address) @@ -574,8 +561,8 @@ def get_user_by_id( def get_user_by_email( - session: Session, email: str, *, for_update: Union[bool, ForUpdateParams] = False -) -> Optional[User]: + session: Session, email: str, *, for_update: bool | ForUpdateParams = False +) -> User | None: return ( _maybe_for_update(session.query(User), enable=for_update) .where(User.cvat_email == email) @@ -606,8 +593,8 @@ def create_assignment( def get_assignments_by_id( - session: Session, ids: List[str], *, for_update: Union[bool, ForUpdateParams] = False -) -> List[Assignment]: + session: Session, ids: list[str], *, for_update: bool | ForUpdateParams = False +) -> list[Assignment]: return ( _maybe_for_update(session.query(Assignment), enable=for_update) .where(Assignment.id.in_(ids)) @@ -616,8 +603,8 @@ def get_assignments_by_id( def get_latest_assignment_by_cvat_job_id( - session: Session, cvat_job_id: int, *, for_update: Union[bool, ForUpdateParams] = False -) -> Optional[Assignment]: + session: Session, cvat_job_id: int, *, for_update: bool | ForUpdateParams = False +) -> Assignment | None: return ( _maybe_for_update(session.query(Assignment), enable=for_update) .where(Assignment.cvat_job_id == cvat_job_id) @@ -627,8 +614,8 @@ def get_latest_assignment_by_cvat_job_id( def get_unprocessed_expired_assignments( - session: Session, *, limit: int = 10, for_update: Union[bool, ForUpdateParams] = False -) -> List[Assignment]: + session: Session, *, limit: int = 10, for_update: bool | ForUpdateParams = False +) -> list[Assignment]: return ( _maybe_for_update(session.query(Assignment), enable=for_update) .where( @@ -642,8 +629,8 @@ def get_unprocessed_expired_assignments( def get_active_assignments( - session: Session, *, limit: int = 10, for_update: Union[bool, ForUpdateParams] = False -) -> List[Assignment]: + session: Session, *, limit: int = 10, for_update: bool | ForUpdateParams = False +) -> list[Assignment]: return ( _maybe_for_update(session.query(Assignment), enable=for_update) .where( @@ -661,7 +648,7 @@ def update_assignment( id: str, *, status: AssignmentStatuses, - completed_at: Optional[datetime] = None, + completed_at: datetime | None = None, ): statement = ( update(Assignment) @@ -691,10 +678,10 @@ def complete_assignment(session: Session, assignment_id: str, completed_at: date def get_user_assignments_in_cvat_projects( session: Session, wallet_address: int, - cvat_projects: List[int], + cvat_projects: list[int], *, - for_update: Union[bool, ForUpdateParams] = False, -) -> List[Assignment]: + for_update: bool | ForUpdateParams = False, +) -> list[Assignment]: return ( _maybe_for_update(session.query(Assignment), enable=for_update) .where( @@ -708,7 +695,7 @@ def get_user_assignments_in_cvat_projects( def count_active_user_assignments( session: Session, wallet_address: int, - cvat_projects: List[int], + cvat_projects: list[int], ) -> int: return ( session.query(Assignment) @@ -724,23 +711,23 @@ def count_active_user_assignments( # Image -def add_project_images(session: Session, cvat_project_id: int, filenames: List[str]) -> None: +def add_project_images(session: Session, cvat_project_id: int, filenames: list[str]) -> None: session.execute( insert(Image), [ - dict( - id=str(uuid.uuid4()), - cvat_project_id=cvat_project_id, - filename=fn, - ) + { + "id": str(uuid.uuid4()), + "cvat_project_id": cvat_project_id, + "filename": fn, + } for fn in filenames ], ) def get_project_images( - session: Session, cvat_project_id: int, *, for_update: Union[bool, ForUpdateParams] = False -) -> List[Image]: + session: Session, cvat_project_id: int, *, for_update: bool | ForUpdateParams = False +) -> list[Image]: return ( _maybe_for_update(session.query(Image), enable=for_update) .where(Image.cvat_project_id == cvat_project_id) diff --git a/packages/examples/cvat/exchange-oracle/src/services/exchange.py b/packages/examples/cvat/exchange-oracle/src/services/exchange.py index b651180654..d072e5f53a 100644 --- a/packages/examples/cvat/exchange-oracle/src/services/exchange.py +++ b/packages/examples/cvat/exchange-oracle/src/services/exchange.py @@ -1,5 +1,4 @@ from datetime import timedelta -from typing import Optional import src.cvat.api_calls as cvat_api import src.services.cvat as cvat_service @@ -18,7 +17,7 @@ def serialize_task( - project_id: str, *, assignment_id: Optional[str] = None + project_id: str, *, assignment_id: str | None = None ) -> service_api.TaskResponse: with SessionLocal.begin() as session: project = cvat_service.get_project_by_id(session, project_id) @@ -57,19 +56,14 @@ def serialize_task( def get_available_tasks() -> list[service_api.TaskResponse]: - results = [] - with SessionLocal.begin() as session: cvat_projects = cvat_service.get_available_projects(session) - for project in cvat_projects: - results.append(serialize_task(project.id)) - - return results + return [serialize_task(project.id) for project in cvat_projects] def get_tasks_by_assignee( - wallet_address: Optional[str] = None, + wallet_address: str | None = None, ) -> list[service_api.TaskResponse]: results = [] @@ -104,7 +98,7 @@ class UserHasUnfinishedAssignmentError(Exception): pass -def create_assignment(project_id: int, wallet_address: str) -> Optional[str]: +def create_assignment(project_id: int, wallet_address: str) -> str | None: with SessionLocal.begin() as session: user = get_or_404( cvat_service.get_user_by_id(session, wallet_address, for_update=True), diff --git a/packages/examples/cvat/exchange-oracle/src/services/webhook.py b/packages/examples/cvat/exchange-oracle/src/services/webhook.py index 7610cf9dab..e2bc9c7fbe 100644 --- a/packages/examples/cvat/exchange-oracle/src/services/webhook.py +++ b/packages/examples/cvat/exchange-oracle/src/services/webhook.py @@ -1,7 +1,6 @@ import datetime import uuid from enum import Enum -from typing import List, Optional from attrs import define from sqlalchemy import case, update @@ -25,7 +24,7 @@ class OracleWebhookDirectionTags(str, Enum, metaclass=BetterEnumMeta): @define class OracleWebhookQueue: direction: OracleWebhookDirectionTags - default_sender: Optional[OracleWebhookTypes] = None + default_sender: OracleWebhookTypes | None = None def create_webhook( self, @@ -33,10 +32,10 @@ def create_webhook( escrow_address: str, chain_id: int, type: OracleWebhookTypes, - signature: Optional[str] = None, - event_type: Optional[str] = None, - event_data: Optional[dict] = None, - event: Optional[OracleEvent] = None, + signature: str | None = None, + event_type: str | None = None, + event_data: dict | None = None, + event: OracleEvent | None = None, ) -> str: """ Creates a webhook in a database @@ -44,7 +43,7 @@ def create_webhook( assert not event_data or event_type, "'event_data' requires 'event_type'" assert bool(event) ^ bool( event_type - ), f"'event' and 'event_type' cannot be used together. Please use only one of the fields" + ), "'event' and 'event_type' cannot be used together. Please use only one of the fields" if event_type: if self.direction == OracleWebhookDirectionTags.incoming: @@ -59,7 +58,7 @@ def create_webhook( if self.direction == OracleWebhookDirectionTags.incoming and not signature: raise ValueError("Webhook signature must be specified for incoming events") - elif self.direction == OracleWebhookDirectionTags.outgoing and signature: + if self.direction == OracleWebhookDirectionTags.outgoing and signature: raise ValueError("Webhook signature must not be specified for outgoing events") if signature: @@ -93,8 +92,8 @@ def get_pending_webhooks( *, limit: int = 10, for_update: bool = False, - ) -> List[Webhook]: - webhooks = ( + ) -> list[Webhook]: + return ( _maybe_for_update(session.query(Webhook), enable=for_update) .where( Webhook.direction == self.direction.value, @@ -105,7 +104,6 @@ def get_pending_webhooks( .limit(limit) .all() ) - return webhooks def update_webhook_status( self, session: Session, webhook_id: str, status: OracleWebhookStatuses diff --git a/packages/examples/cvat/exchange-oracle/src/utils/__init__.py b/packages/examples/cvat/exchange-oracle/src/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/packages/examples/cvat/exchange-oracle/src/utils/annotations.py b/packages/examples/cvat/exchange-oracle/src/utils/annotations.py index 5deed7efc4..075ce7d035 100644 --- a/packages/examples/cvat/exchange-oracle/src/utils/annotations.py +++ b/packages/examples/cvat/exchange-oracle/src/utils/annotations.py @@ -1,16 +1,18 @@ import os +from argparse import ArgumentParser +from collections.abc import Iterable, Sequence from copy import deepcopy from glob import glob -from typing import Dict, Iterable, List, Optional, Sequence, Tuple, TypeVar, Union +from typing import TypeVar import datumaro as dm import numpy as np from datumaro.util import filter_dict, mask_tools from datumaro.util.annotation_util import find_group_leader, find_instances, max_bbox -from defusedxml import ElementTree as ET +from defusedxml import ElementTree -def flatten_points(input_points: Sequence[dm.Points]) -> List[dm.Points]: +def flatten_points(input_points: Sequence[dm.Points]) -> list[dm.Points]: results = [] for pts in input_points: @@ -48,7 +50,7 @@ def prepare_cvat_annotations_for_dm(dataset_root: str): for annotation_filename in glob(os.path.join(dataset_root, "**/*.xml"), recursive=True): with open(annotation_filename, "rb+") as f: - doc = ET.parse(f) + doc = ElementTree.parse(f) doc_root = doc.getroot() if doc_root.find("meta/project"): @@ -98,7 +100,7 @@ def _get_skeleton_label(original_label: str) -> str: media_type=dm.Image, ) - label_id_map: Dict[int, int] = { + label_id_map: dict[int, int] = { original_id: new_label_cat.find(label.name, parent=_get_skeleton_label(label.name))[0] for original_id, label in enumerate(dataset.categories()[dm.AnnotationType.label]) } # old id -> new id @@ -173,7 +175,7 @@ def shift_ann(ann: T, offset_x: float, offset_y: float, *, img_w: int, img_h: in ] ) else: - assert False, f"Unsupported annotation type '{ann.type}'" + raise TypeError(f"Unsupported annotation type '{ann.type}'") return shifted_ann @@ -200,7 +202,7 @@ class ProjectLabels(dm.ItemTransform): """ @classmethod - def build_cmdline_parser(cls, **kwargs): + def build_cmdline_parser(cls, **kwargs) -> ArgumentParser: parser = super().build_cmdline_parser(**kwargs) parser.add_argument( "-l", @@ -211,19 +213,19 @@ def build_cmdline_parser(cls, **kwargs): ) return parser - def __init__( + def __init__( # noqa: PLR0912 self, extractor: dm.IExtractor, - dst_labels: Union[Iterable[Union[str, Tuple[str, str]]], dm.LabelCategories], - ): + dst_labels: Iterable[str | tuple[str, str]] | dm.LabelCategories, + ) -> None: super().__init__(extractor) self._categories = {} src_categories = self._extractor.categories() - src_label_cat: Optional[dm.LabelCategories] = src_categories.get(dm.AnnotationType.label) - src_point_cat: Optional[dm.PointsCategories] = src_categories.get(dm.AnnotationType.points) + src_label_cat: dm.LabelCategories | None = src_categories.get(dm.AnnotationType.label) + src_point_cat: dm.PointsCategories | None = src_categories.get(dm.AnnotationType.points) if isinstance(dst_labels, dm.LabelCategories): dst_label_cat = deepcopy(dst_labels) @@ -234,7 +236,7 @@ def __init__( dst_label_cat = dm.LabelCategories(attributes=deepcopy(src_label_cat.attributes)) for dst_label in dst_labels: - assert isinstance(dst_label, str) or isinstance(dst_label, tuple) + assert isinstance(dst_label, str | tuple) dst_parent = "" if isinstance(dst_label, tuple): @@ -316,7 +318,7 @@ def _make_label_id_map(self, src_label_cat, dst_label_cat): src_id: dst_label_cat.find(src_label_cat[src_id].name, src_label_cat[src_id].parent)[0] for src_id in range(len(src_label_cat or ())) } - self._map_id = lambda src_id: id_mapping.get(src_id, None) + self._map_id = lambda src_id: id_mapping.get(src_id) def categories(self): return self._categories diff --git a/packages/examples/cvat/exchange-oracle/src/utils/enums.py b/packages/examples/cvat/exchange-oracle/src/utils/enums.py index 4f3d688251..d4c133b0e5 100644 --- a/packages/examples/cvat/exchange-oracle/src/utils/enums.py +++ b/packages/examples/cvat/exchange-oracle/src/utils/enums.py @@ -6,5 +6,5 @@ class BetterEnumMeta(EnumMeta): Extends the default enum metaclass with extra methods for better usability """ - def __contains__(cls, item): + def __contains__(cls, item) -> bool: return isinstance(item, cls) or item in [v.value for v in cls.__members__.values()] diff --git a/packages/examples/cvat/exchange-oracle/src/utils/logging.py b/packages/examples/cvat/exchange-oracle/src/utils/logging.py index be2c5feba3..e7660eb0d7 100644 --- a/packages/examples/cvat/exchange-oracle/src/utils/logging.py +++ b/packages/examples/cvat/exchange-oracle/src/utils/logging.py @@ -1,5 +1,5 @@ import logging -from typing import NewType, Optional, Union +from typing import NewType from src.utils.stack import current_function_name @@ -15,7 +15,7 @@ def parse_log_level(level: str) -> LogLevel: def get_function_logger( - parent_logger: Optional[Union[str, logging.Logger]] = None, + parent_logger: str | logging.Logger | None = None, ) -> logging.Logger: if isinstance(parent_logger, str): parent_logger = logging.getLogger(parent_logger) diff --git a/packages/examples/cvat/exchange-oracle/src/utils/net.py b/packages/examples/cvat/exchange-oracle/src/utils/net.py index 360dafaebb..04ad799510 100644 --- a/packages/examples/cvat/exchange-oracle/src/utils/net.py +++ b/packages/examples/cvat/exchange-oracle/src/utils/net.py @@ -1,7 +1,7 @@ import ipaddress -def is_ipv4(addr: str, allow_port: bool = True) -> bool: +def is_ipv4(addr: str, *, allow_port: bool = True) -> bool: try: if allow_port: addr = addr.split(":", maxsplit=1)[0] diff --git a/packages/examples/cvat/exchange-oracle/src/utils/requests.py b/packages/examples/cvat/exchange-oracle/src/utils/requests.py index 785c2cfc87..ef2174f9b9 100644 --- a/packages/examples/cvat/exchange-oracle/src/utils/requests.py +++ b/packages/examples/cvat/exchange-oracle/src/utils/requests.py @@ -1,4 +1,4 @@ -from typing import Optional, TypeVar +from typing import TypeVar from fastapi import HTTPException @@ -7,11 +7,11 @@ def get_or_404( - obj: Optional[T], + obj: T | None, object_id: V, object_type_name: str, *, - reason: Optional[str] = None, + reason: str | None = None, ) -> T: if obj is None: raise HTTPException( diff --git a/packages/examples/cvat/exchange-oracle/src/utils/webhooks.py b/packages/examples/cvat/exchange-oracle/src/utils/webhooks.py index acb534c265..9e1b59bf90 100644 --- a/packages/examples/cvat/exchange-oracle/src/utils/webhooks.py +++ b/packages/examples/cvat/exchange-oracle/src/utils/webhooks.py @@ -1,5 +1,4 @@ from datetime import datetime -from typing import Dict, Optional, Tuple from src.chain.web3 import sign_message from src.core.oracle_events import parse_event @@ -11,8 +10,8 @@ def prepare_outgoing_webhook_body( chain_id: Networks, event_type: str, event_data: dict, - timestamp: Optional[datetime], -) -> Dict: + timestamp: datetime | None, +) -> dict: body = {"escrow_address": escrow_address, "chain_id": chain_id} if timestamp: @@ -29,11 +28,11 @@ def prepare_outgoing_webhook_body( def prepare_signed_message( - escrow_address: str, + escrow_address: str, # noqa: ARG001 chain_id: Networks, - message: Optional[str] = None, - body: Optional[dict] = None, -) -> Tuple[str, str]: + message: str | None = None, + body: dict | None = None, +) -> tuple[str, str]: """ Sign the message with the service identity. Optionally, can serialize the input structure. diff --git a/packages/examples/cvat/exchange-oracle/src/validators/__init__.py b/packages/examples/cvat/exchange-oracle/src/validators/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/packages/examples/cvat/exchange-oracle/src/validators/validation.py b/packages/examples/cvat/exchange-oracle/src/validators/validation.py index c8ce69d5f2..4d9d0279c8 100644 --- a/packages/examples/cvat/exchange-oracle/src/validators/validation.py +++ b/packages/examples/cvat/exchange-oracle/src/validators/validation.py @@ -1,4 +1,4 @@ -""" Validation utils""" +"""Validation utils""" class ValidationResult: @@ -7,7 +7,7 @@ class ValidationResult: It encapsulates validation logic and helping during generating response body """ - def __init__(self): + def __init__(self) -> None: self.is_valid = True self.errors = [] diff --git a/packages/examples/cvat/exchange-oracle/tests/api/__init__.py b/packages/examples/cvat/exchange-oracle/tests/api/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/packages/examples/cvat/exchange-oracle/tests/api/test_exchange_api.py b/packages/examples/cvat/exchange-oracle/tests/api/test_exchange_api.py index 9eba05e767..b910addd60 100644 --- a/packages/examples/cvat/exchange-oracle/tests/api/test_exchange_api.py +++ b/packages/examples/cvat/exchange-oracle/tests/api/test_exchange_api.py @@ -38,7 +38,7 @@ def test_empty_list_tasks_200_without_address(client: TestClient) -> None: def test_list_tasks_200_with_address(client: TestClient) -> None: - with (SessionLocal.begin() as session,): + with SessionLocal.begin() as session: _, _, cvat_job_1 = create_project_task_and_job( session, "0x86e83d346041E8806e352681f3F14549C0d2BC67", 1 ) @@ -95,7 +95,7 @@ def test_list_tasks_200_with_address(client: TestClient) -> None: def test_list_tasks_200_without_address(client: TestClient) -> None: - with (SessionLocal.begin() as session,): + with SessionLocal.begin() as session: _, _, cvat_job_1 = create_project_task_and_job( session, "0x86e83d346041E8806e352681f3F14549C0d2BC67", 1 ) diff --git a/packages/examples/cvat/exchange-oracle/tests/api/test_greet.py b/packages/examples/cvat/exchange-oracle/tests/api/test_greet.py index e1f46b0df0..e721439fd8 100644 --- a/packages/examples/cvat/exchange-oracle/tests/api/test_greet.py +++ b/packages/examples/cvat/exchange-oracle/tests/api/test_greet.py @@ -2,5 +2,5 @@ def test_greet_route(client: TestClient) -> None: - response = client.get(f"/") + response = client.get("/") assert response.status_code == 200 diff --git a/packages/examples/cvat/exchange-oracle/tests/conftest.py b/packages/examples/cvat/exchange-oracle/tests/conftest.py index e027999628..f39f40a8ca 100644 --- a/packages/examples/cvat/exchange-oracle/tests/conftest.py +++ b/packages/examples/cvat/exchange-oracle/tests/conftest.py @@ -1,4 +1,4 @@ -from typing import Generator +from collections.abc import Generator import pytest from fastapi.testclient import TestClient diff --git a/packages/examples/cvat/exchange-oracle/tests/integration/chain/__init__.py b/packages/examples/cvat/exchange-oracle/tests/integration/chain/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/packages/examples/cvat/exchange-oracle/tests/integration/chain/test_escrow.py b/packages/examples/cvat/exchange-oracle/tests/integration/chain/test_escrow.py index cb3245f716..d65097e64c 100644 --- a/packages/examples/cvat/exchange-oracle/tests/integration/chain/test_escrow.py +++ b/packages/examples/cvat/exchange-oracle/tests/integration/chain/test_escrow.py @@ -2,6 +2,7 @@ import unittest from unittest.mock import patch +import pytest from human_protocol_sdk.constants import ChainId, Status from human_protocol_sdk.encryption import EncryptionUtils from human_protocol_sdk.escrow import EscrowClientError, EscrowData @@ -53,45 +54,40 @@ def test_validate_escrow(self): with patch("src.chain.escrow.EscrowUtils.get_escrow") as mock_function: mock_function.return_value = self.escrow_data validation = validate_escrow(chain_id, escrow_address) - self.assertIsNone(validation) + assert validation is None def test_validate_escrow_invalid_address(self): - with self.assertRaises(EscrowClientError) as error: + with pytest.raises(EscrowClientError, match="Invalid escrow address: invalid_address"): validate_escrow(chain_id, "invalid_address") - self.assertEqual(f"Invalid escrow address: invalid_address", str(error.exception)) - def test_validate_escrow_invalid_status(self): with patch("src.chain.escrow.EscrowUtils.get_escrow") as mock_function: self.escrow_data.status = Status.Launched.name mock_function.return_value = self.escrow_data - with self.assertRaises(ValueError) as error: + with pytest.raises( + ValueError, + match=rf"Escrow is not in any of the accepted states \(Pending\). " + f"Current state: {self.escrow_data.status}", + ): validate_escrow(chain_id, escrow_address) - self.assertEqual( - f"Escrow is not in any of the accepted states (Pending). Current state: {self.escrow_data.status}", - str(error.exception), - ) def test_validate_escrow_without_funds(self): with patch("src.chain.escrow.EscrowUtils.get_escrow") as mock_function: self.escrow_data.balance = "0" mock_function.return_value = self.escrow_data - with self.assertRaises(ValueError) as error: + with pytest.raises(ValueError, match="Escrow doesn't have funds"): validate_escrow(chain_id, escrow_address) - self.assertEqual( - f"Escrow doesn't have funds", - str(error.exception), - ) def test_get_escrow_manifest(self): - with patch("src.chain.escrow.EscrowUtils.get_escrow") as mock_function, patch( - "src.chain.escrow.StorageUtils.download_file_from_url" - ) as mock_download: + with ( + patch("src.chain.escrow.EscrowUtils.get_escrow") as mock_function, + patch("src.chain.escrow.StorageUtils.download_file_from_url") as mock_download, + ): mock_download.return_value = json.dumps({"title": "test"}).encode() mock_function.return_value = self.escrow_data manifest = get_escrow_manifest(chain_id, escrow_address) - self.assertIsInstance(manifest, dict) - self.assertIsNotNone(manifest) + assert isinstance(manifest, dict) + assert manifest is not None def test_get_encrypted_escrow_manifest(self): with ( @@ -112,62 +108,55 @@ def test_get_encrypted_escrow_manifest(self): encrypted_manifest = EncryptionUtils.encrypt( original_manifest, public_keys=[PGP_PUBLIC_KEY1, PGP_PUBLIC_KEY2] ) - self.assertNotEqual(encrypted_manifest, original_manifest) + assert encrypted_manifest != original_manifest mock_download.return_value = encrypted_manifest.encode() downloaded_manifest_content = get_escrow_manifest(chain_id, escrow_address) - self.assertDictEqual(downloaded_manifest_content, original_manifest_content) + assert downloaded_manifest_content == original_manifest_content def test_get_escrow_manifest_invalid_address(self): - with self.assertRaises(EscrowClientError) as error: + with pytest.raises(EscrowClientError, match="Invalid escrow address: invalid_address"): get_escrow_manifest(chain_id, "invalid_address") - self.assertEqual(f"Invalid escrow address: invalid_address", str(error.exception)) def test_get_job_launcher_address(self): with patch("src.chain.escrow.EscrowUtils.get_escrow") as mock_function: mock_function.return_value = self.escrow_data job_launcher_address = get_job_launcher_address(chain_id, escrow_address) - self.assertIsInstance(job_launcher_address, str) - self.assertEqual(job_launcher_address, JOB_LAUNCHER_ADDRESS) + assert isinstance(job_launcher_address, str) + assert job_launcher_address == JOB_LAUNCHER_ADDRESS def test_get_job_launcher_address_invalid_address(self): - with self.assertRaises(EscrowClientError) as error: + with pytest.raises(EscrowClientError, match="Invalid escrow address: invalid_address"): get_job_launcher_address(chain_id, "invalid_address") - self.assertEqual(f"Invalid escrow address: invalid_address", str(error.exception)) def test_get_job_launcher_address_invalid_chain_id(self): - with self.assertRaises(ValueError) as error: + with pytest.raises(ValueError, match="123 is not a valid ChainId"): get_job_launcher_address(123, escrow_address) - self.assertEqual(f"123 is not a valid ChainId", str(error.exception)) def test_get_job_launcher_address_empty_escrow(self): with patch("src.chain.escrow.EscrowUtils.get_escrow") as mock_function: mock_function.return_value = None - with self.assertRaises(Exception) as error: + with pytest.raises(Exception, match=f"Can't find escrow {ESCROW_ADDRESS}"): get_job_launcher_address(chain_id, escrow_address) - self.assertEqual(f"Can't find escrow {ESCROW_ADDRESS}", str(error.exception)) def test_get_recording_oracle_address(self): with patch("src.chain.escrow.EscrowUtils.get_escrow") as mock_function: self.escrow_data.recording_oracle = RECORDING_ORACLE_ADDRESS mock_function.return_value = self.escrow_data recording_oracle_address = get_recording_oracle_address(chain_id, escrow_address) - self.assertIsInstance(recording_oracle_address, str) - self.assertEqual(recording_oracle_address, RECORDING_ORACLE_ADDRESS) + assert isinstance(recording_oracle_address, str) + assert recording_oracle_address == RECORDING_ORACLE_ADDRESS def test_get_recording_oracle_address_invalid_address(self): - with self.assertRaises(EscrowClientError) as error: + with pytest.raises(EscrowClientError, match="Invalid escrow address: invalid_address"): get_recording_oracle_address(chain_id, "invalid_address") - self.assertEqual(f"Invalid escrow address: invalid_address", str(error.exception)) def test_get_recording_oracle_address_invalid_chain_id(self): - with self.assertRaises(ValueError) as error: + with pytest.raises(ValueError, match="123 is not a valid ChainId"): get_recording_oracle_address(123, escrow_address) - self.assertEqual(f"123 is not a valid ChainId", str(error.exception)) def test_get_recording_oracle_address_empty_escrow(self): with patch("src.chain.escrow.EscrowUtils.get_escrow") as mock_function: mock_function.return_value = None - with self.assertRaises(Exception) as error: + with pytest.raises(Exception, match=f"Can't find escrow {ESCROW_ADDRESS}"): get_recording_oracle_address(chain_id, escrow_address) - self.assertEqual(f"Can't find escrow {ESCROW_ADDRESS}", str(error.exception)) diff --git a/packages/examples/cvat/exchange-oracle/tests/integration/chain/test_kvstore.py b/packages/examples/cvat/exchange-oracle/tests/integration/chain/test_kvstore.py index 32dbe9599d..b2c26160f1 100644 --- a/packages/examples/cvat/exchange-oracle/tests/integration/chain/test_kvstore.py +++ b/packages/examples/cvat/exchange-oracle/tests/integration/chain/test_kvstore.py @@ -1,6 +1,7 @@ import unittest from unittest.mock import MagicMock, Mock, patch +import pytest from human_protocol_sdk.constants import ChainId, Status from human_protocol_sdk.escrow import EscrowClientError, EscrowData from human_protocol_sdk.kvstore import KVStoreClient, KVStoreClientError @@ -42,52 +43,54 @@ def setUp(self): ) def test_get_job_launcher_url(self): - with patch("src.chain.kvstore.get_escrow") as mock_escrow, patch( - "src.chain.kvstore.OperatorUtils.get_leader" - ) as mock_leader: + with ( + patch("src.chain.kvstore.get_escrow") as mock_escrow, + patch("src.chain.kvstore.OperatorUtils.get_leader") as mock_leader, + ): mock_escrow.return_value = self.escrow_data mock_leader.return_value = MagicMock(webhook_url=DEFAULT_MANIFEST_URL) recording_url = get_job_launcher_url(self.w3.eth.chain_id, escrow_address) - self.assertEqual(recording_url, DEFAULT_MANIFEST_URL) + assert recording_url == DEFAULT_MANIFEST_URL def test_get_job_launcher_url_invalid_escrow(self): - with self.assertRaises(EscrowClientError) as error: + with pytest.raises(EscrowClientError, match="Invalid escrow address: invalid_address"): get_job_launcher_url(self.w3.eth.chain_id, "invalid_address") - self.assertEqual(f"Invalid escrow address: invalid_address", str(error.exception)) def test_get_job_launcher_url_invalid_recording_address(self): - with patch("src.chain.kvstore.get_escrow") as mock_escrow, patch( - "src.chain.kvstore.OperatorUtils.get_leader" - ) as mock_leader: + with ( + patch("src.chain.kvstore.get_escrow") as mock_escrow, + patch("src.chain.kvstore.OperatorUtils.get_leader") as mock_leader, + ): mock_escrow.return_value = self.escrow_data mock_leader.return_value = MagicMock(webhook_url="") recording_url = get_job_launcher_url(self.w3.eth.chain_id, escrow_address) - self.assertEqual(recording_url, "") + assert recording_url == "" def test_get_recording_oracle_url(self): - with patch("src.chain.kvstore.get_escrow") as mock_escrow, patch( - "src.chain.kvstore.OperatorUtils.get_leader" - ) as mock_leader: + with ( + patch("src.chain.kvstore.get_escrow") as mock_escrow, + patch("src.chain.kvstore.OperatorUtils.get_leader") as mock_leader, + ): self.escrow_data.recording_oracle = RECORDING_ORACLE_ADDRESS mock_escrow.return_value = self.escrow_data mock_leader.return_value = MagicMock(webhook_url=DEFAULT_MANIFEST_URL) recording_url = get_recording_oracle_url(self.w3.eth.chain_id, escrow_address) - self.assertEqual(recording_url, DEFAULT_MANIFEST_URL) + assert recording_url == DEFAULT_MANIFEST_URL def test_get_recording_oracle_url_invalid_escrow(self): - with self.assertRaises(EscrowClientError) as error: + with pytest.raises(EscrowClientError, match="Invalid escrow address: invalid_address"): get_recording_oracle_url(self.w3.eth.chain_id, "invalid_address") - self.assertEqual(f"Invalid escrow address: invalid_address", str(error.exception)) def test_get_recording_oracle_url_invalid_recording_address(self): - with patch("src.chain.kvstore.get_escrow") as mock_escrow, patch( - "src.chain.kvstore.OperatorUtils.get_leader" - ) as mock_leader: + with ( + patch("src.chain.kvstore.get_escrow") as mock_escrow, + patch("src.chain.kvstore.OperatorUtils.get_leader") as mock_leader, + ): self.escrow_data.recording_oracle = RECORDING_ORACLE_ADDRESS mock_escrow.return_value = self.escrow_data mock_leader.return_value = MagicMock(webhook_url="") recording_url = get_recording_oracle_url(self.w3.eth.chain_id, escrow_address) - self.assertEqual(recording_url, "") + assert recording_url == "" def test_store_public_key(self): PGP_PUBLIC_KEY_URL_1 = "http://pgp-public-key-url-1" @@ -107,7 +110,7 @@ def get_file_url_and_verify_hash(*args, **kwargs): hash_ = store["public_key_hash"] if hash_ != hash(public_key): - raise KVStoreClientError(f"Invalid hash") + raise KVStoreClientError("Invalid hash") return public_key @@ -134,7 +137,7 @@ def set_file_url_and_hash(url: str, key: str): mock_web3.return_value = self.w3 kvstore_client = KVStoreClient(self.w3) - self.assertIsNone(kvstore_client.get_file_url_and_verify_hash(LocalhostConfig.addr)) + assert kvstore_client.get_file_url_and_verify_hash(LocalhostConfig.addr) is None # check that public key will be set to KVStore at first time with patch( @@ -143,9 +146,9 @@ def set_file_url_and_hash(url: str, key: str): mock_set_file_url_and_hash.side_effect = set_file_url_and_hash register_in_kvstore() mock_set_file_url_and_hash.assert_called_once() - self.assertEquals( - kvstore_client.get_file_url_and_verify_hash(LocalhostConfig.addr), - PGP_PUBLIC_KEY_URL_1, + assert ( + kvstore_client.get_file_url_and_verify_hash(LocalhostConfig.addr) + == PGP_PUBLIC_KEY_URL_1 ) # check that the same public key URL is not written to KVStore a second time @@ -156,7 +159,8 @@ def set_file_url_and_hash(url: str, key: str): register_in_kvstore() mock_set_file_url_and_hash.assert_not_called() - # check that public key URL and hash will be updated in KVStore if previous hash is outdated/corrupted + # check that public key URL and hash will be updated in KVStore + # if previous hash is outdated/corrupted with patch( "human_protocol_sdk.kvstore.KVStoreClient.set_file_url_and_hash", Mock() ) as mock_set_file_url_and_hash: @@ -164,9 +168,10 @@ def set_file_url_and_hash(url: str, key: str): store["public_key_hash"] = "corrupted_hash" register_in_kvstore() mock_set_file_url_and_hash.assert_called_once() - self.assertNotEquals(store["public_key_hash"], "corrupted_hash") + assert store["public_key_hash"] != "corrupted_hash" - # check that a new public key URL will be written to KVStore when an outdated URL is stored there + # check that a new public key URL will be written to KVStore when + # an outdated URL is stored there with ( patch( "src.core.config.Config.encryption_config.pgp_public_key_url", @@ -179,7 +184,7 @@ def set_file_url_and_hash(url: str, key: str): mock_set_file_url_and_hash.side_effect = set_file_url_and_hash register_in_kvstore() mock_set_file_url_and_hash.assert_called_once() - self.assertEquals( - kvstore_client.get_file_url_and_verify_hash(LocalhostConfig.addr), - PGP_PUBLIC_KEY_URL_2, + assert ( + kvstore_client.get_file_url_and_verify_hash(LocalhostConfig.addr) + == PGP_PUBLIC_KEY_URL_2 ) diff --git a/packages/examples/cvat/exchange-oracle/tests/integration/chain/test_web3.py b/packages/examples/cvat/exchange-oracle/tests/integration/chain/test_web3.py index 9c7da3e4f2..22f430e478 100644 --- a/packages/examples/cvat/exchange-oracle/tests/integration/chain/test_web3.py +++ b/packages/examples/cvat/exchange-oracle/tests/integration/chain/test_web3.py @@ -2,6 +2,7 @@ import unittest from unittest.mock import patch +import pytest from human_protocol_sdk.constants import ChainId from web3 import HTTPProvider, Web3 from web3.middleware import construct_sign_and_send_raw_middleware @@ -32,9 +33,9 @@ class PolygonMainnetConfig: with patch("src.chain.web3.Config.polygon_mainnet", PolygonMainnetConfig): w3 = get_web3(ChainId.POLYGON.value) - self.assertIsInstance(w3, Web3) - self.assertEqual(w3.eth.default_account, DEFAULT_GAS_PAYER) - self.assertEqual(w3.manager._provider.endpoint_uri, PolygonMainnetConfig.rpc_api) + assert isinstance(w3, Web3) + assert w3.eth.default_account == DEFAULT_GAS_PAYER + assert w3.manager._provider.endpoint_uri == PolygonMainnetConfig.rpc_api def test_get_web3_amoy(self): class PolygonAmoyConfig: @@ -44,23 +45,19 @@ class PolygonAmoyConfig: with patch("src.chain.web3.Config.polygon_amoy", PolygonAmoyConfig): w3 = get_web3(ChainId.POLYGON_AMOY.value) - self.assertIsInstance(w3, Web3) - self.assertEqual(w3.eth.default_account, DEFAULT_GAS_PAYER) - self.assertEqual(w3.manager._provider.endpoint_uri, PolygonAmoyConfig.rpc_api) + assert isinstance(w3, Web3) + assert w3.eth.default_account == DEFAULT_GAS_PAYER + assert w3.manager._provider.endpoint_uri == PolygonAmoyConfig.rpc_api def test_get_web3_localhost(self): w3 = get_web3(ChainId.LOCALHOST.value) - self.assertIsInstance(w3, Web3) - self.assertEqual(w3.eth.default_account, DEFAULT_GAS_PAYER) - self.assertEqual(w3.manager._provider.endpoint_uri, LocalhostConfig.rpc_api) + assert isinstance(w3, Web3) + assert w3.eth.default_account == DEFAULT_GAS_PAYER + assert w3.manager._provider.endpoint_uri == LocalhostConfig.rpc_api def test_get_web3_invalid_chain_id(self): - with self.assertRaises(ValueError) as error: - w3 = get_web3(1234) - self.assertEqual( - "1234 is not in available list of networks.", - str(error.exception), - ) + with pytest.raises(ValueError, match="1234 is not in available list of networks."): + get_web3(1234) def test_sign_message_polygon(self): with patch("src.chain.web3.get_web3") as mock_function: @@ -70,8 +67,8 @@ def test_sign_message_polygon(self): ): mock_function.return_value = self.w3 signature, serialized_message = sign_message(ChainId.POLYGON.value, "message") - self.assertEqual(signature, SIGNATURE) - self.assertEqual(serialized_message, json.dumps("message")) + assert signature == SIGNATURE + assert serialized_message == json.dumps("message") def test_sign_message_amoy(self): with patch("src.chain.web3.get_web3") as mock_function: @@ -81,37 +78,29 @@ def test_sign_message_amoy(self): ): mock_function.return_value = self.w3 signature, serialized_message = sign_message(ChainId.POLYGON_AMOY.value, "message") - self.assertEqual(signature, SIGNATURE) - self.assertEqual(serialized_message, json.dumps("message")) + assert signature == SIGNATURE + assert serialized_message == json.dumps("message") def test_sign_message_invalid_chain_id(self): - with self.assertRaises(ValueError) as error: + with pytest.raises(ValueError, match="1234 is not in available list of networks."): sign_message(1234, "message") - self.assertEqual( - "1234 is not in available list of networks.", - str(error.exception), - ) def test_recover_signer(self): with patch("src.chain.web3.get_web3") as mock_function: mock_function.return_value = self.w3 signer = recover_signer(ChainId.POLYGON.value, "message", SIGNATURE) - self.assertEqual(signer, DEFAULT_GAS_PAYER) + assert signer == DEFAULT_GAS_PAYER def test_recover_signer_invalid_signature(self): with patch("src.chain.web3.get_web3") as mock_function: mock_function.return_value = self.w3 signer = recover_signer(ChainId.POLYGON.value, "test", SIGNATURE) - self.assertNotEqual(signer, DEFAULT_GAS_PAYER) + assert signer != DEFAULT_GAS_PAYER def test_validate_address(self): address = validate_address(DEFAULT_GAS_PAYER) - self.assertEqual(address, DEFAULT_GAS_PAYER) + assert address == DEFAULT_GAS_PAYER def test_validate_address_invalid_address(self): - with self.assertRaises(ValueError) as error: + with pytest.raises(ValueError, match="invalid_address is not a correct Web3 address"): validate_address("invalid_address") - self.assertEqual( - f"invalid_address is not a correct Web3 address", - str(error.exception), - ) diff --git a/packages/examples/cvat/exchange-oracle/tests/integration/cron/__init__.py b/packages/examples/cvat/exchange-oracle/tests/integration/cron/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/__init__.py b/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_assignments.py b/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_assignments.py index 446482bfae..2a0008ce37 100644 --- a/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_assignments.py +++ b/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_assignments.py @@ -3,20 +3,16 @@ from datetime import datetime, timedelta from unittest.mock import patch +import pytest from sqlalchemy import update -from sqlalchemy.sql import select from src.core.types import ( AssignmentStatuses, - JobStatuses, - Networks, ProjectStatuses, - TaskStatuses, - TaskTypes, ) from src.crons.state_trackers import track_assignments from src.db import SessionLocal -from src.models.cvat import Assignment, Job, Project, Task, User +from src.models.cvat import Assignment, Project, User from tests.utils.db_helper import create_project_task_and_job @@ -67,8 +63,8 @@ def test_track_expired_assignments(self): db_assignments = sorted( self.session.query(Assignment).all(), key=lambda assignment: assignment.user.cvat_id ) - self.assertEqual(db_assignments[0].status, AssignmentStatuses.created.value) - self.assertEqual(db_assignments[1].status, AssignmentStatuses.created.value) + assert db_assignments[0].status == AssignmentStatuses.created.value + assert db_assignments[1].status == AssignmentStatuses.created.value with patch("src.crons.state_trackers.cvat_api.update_job_assignee") as mock_cvat_api: track_assignments() @@ -79,74 +75,77 @@ def test_track_expired_assignments(self): db_assignments = sorted( self.session.query(Assignment).all(), key=lambda assignment: assignment.user.cvat_id ) - self.assertEqual(db_assignments[0].status, AssignmentStatuses.created.value) - self.assertEqual(db_assignments[1].status, AssignmentStatuses.expired.value) - - # TODO: - # Fix src/crons/state_trackers.py - # Where in `cvat_service.get_active_assignments()` return value will be empty - # because it actually looking for the expired assignments - - # def test_track_canceled_assignments(self): - # (_, _, cvat_job) = create_project_task_and_job( - # self.session, "0x86e83d346041E8806e352681f3F14549C0d2BC67", 1 - # ) - # (cvat_project_2, _, cvat_job_2) = create_project_task_and_job( - # self.session, "0x86e83d346041E8806e352681f3F14549C0d2BC68", 2 - # ) - # wallet_address_1 = "0x86e83d346041E8806e352681f3F14549C0d2BC67" - # user = User( - # wallet_address=wallet_address_1, - # cvat_email="test@hmt.ai", - # cvat_id=1, - # ) - # self.session.add(user) - - # wallet_address_2 = "0x86e83d346041E8806e352681f3F14549C0d2BC68" - # user = User( - # wallet_address=wallet_address_2, - # cvat_email="test2@hmt.ai", - # cvat_id=2, - # ) - # self.session.add(user) - # assignment = Assignment( - # id=str(uuid.uuid4()), - # user_wallet_address=wallet_address_1, - # cvat_job_id=cvat_job.cvat_id, - # expires_at=datetime.now() + timedelta(days=1), - # ) - # assignment_2 = Assignment( - # id=str(uuid.uuid4()), - # user_wallet_address=wallet_address_2, - # cvat_job_id=cvat_job_2.cvat_id, - # expires_at=datetime.now() + timedelta(days=1), - # created_at=datetime.now() + timedelta(hours=1), - # ) - # self.session.add(assignment) - # self.session.add(assignment_2) - - # self.session.execute( - # update(Project) - # .where(Project.id == cvat_project_2.id) - # .values(status=ProjectStatuses.completed.value) - # ) - - # self.session.commit() - - # db_assignments = sorted( - # self.session.query(Assignment).all(), key=lambda assignment: assignment.user.cvat_id - # ) - # self.assertEqual(db_assignments[0].status, AssignmentStatus.created.value) - # self.assertEqual(db_assignments[1].status, AssignmentStatus.created.value) - - # with patch("src.crons.state_trackers.cvat_api.update_job_assignee") as mock_cvat_api: - # track_assignments() - # mock_cvat_api.assert_called_once_with(assignment_2.cvat_job_id, assignee_id=None) - - # self.session.commit() - - # db_assignments = sorted( - # self.session.query(Assignment).all(), key=lambda assignment: assignment.user.cvat_id - # ) - # self.assertEqual(db_assignments[0].status, AssignmentStatus.created.value) - # self.assertEqual(db_assignments[1].status, AssignmentStatus.canceled.value) + assert db_assignments[0].status == AssignmentStatuses.created.value + assert db_assignments[1].status == AssignmentStatuses.expired.value + + @pytest.mark.xfail( + strict=True, + reason=""" +Fix src/crons/state_trackers.py +Where in `cvat_service.get_active_assignments()` return value will be empty +because it actually looking for the expired assignments +""", + ) + def test_track_canceled_assignments(self): + (_, _, cvat_job) = create_project_task_and_job( + self.session, "0x86e83d346041E8806e352681f3F14549C0d2BC67", 1 + ) + (cvat_project_2, _, cvat_job_2) = create_project_task_and_job( + self.session, "0x86e83d346041E8806e352681f3F14549C0d2BC68", 2 + ) + wallet_address_1 = "0x86e83d346041E8806e352681f3F14549C0d2BC67" + user = User( + wallet_address=wallet_address_1, + cvat_email="test@hmt.ai", + cvat_id=1, + ) + self.session.add(user) + + wallet_address_2 = "0x86e83d346041E8806e352681f3F14549C0d2BC68" + user = User( + wallet_address=wallet_address_2, + cvat_email="test2@hmt.ai", + cvat_id=2, + ) + self.session.add(user) + assignment = Assignment( + id=str(uuid.uuid4()), + user_wallet_address=wallet_address_1, + cvat_job_id=cvat_job.cvat_id, + expires_at=datetime.now() + timedelta(days=1), + ) + assignment_2 = Assignment( + id=str(uuid.uuid4()), + user_wallet_address=wallet_address_2, + cvat_job_id=cvat_job_2.cvat_id, + expires_at=datetime.now() + timedelta(days=1), + created_at=datetime.now() + timedelta(hours=1), + ) + self.session.add(assignment) + self.session.add(assignment_2) + + self.session.execute( + update(Project) + .where(Project.id == cvat_project_2.id) + .values(status=ProjectStatuses.completed.value) + ) + + self.session.commit() + + db_assignments = sorted( + self.session.query(Assignment).all(), key=lambda assignment: assignment.user.cvat_id + ) + assert db_assignments[0].status == AssignmentStatuses.created.value + assert db_assignments[1].status == AssignmentStatuses.created.value + + with patch("src.crons.state_trackers.cvat_api.update_job_assignee") as mock_cvat_api: + track_assignments() + mock_cvat_api.assert_called_once_with(assignment_2.cvat_job_id, assignee_id=None) + + self.session.commit() + + db_assignments = sorted( + self.session.query(Assignment).all(), key=lambda assignment: assignment.user.cvat_id + ) + assert db_assignments[0].status == AssignmentStatuses.created.value + assert db_assignments[1].status == AssignmentStatuses.canceled.value diff --git a/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_completed_escrows.py b/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_completed_escrows.py index 0844eb5560..dfcf5025f0 100644 --- a/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_completed_escrows.py +++ b/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_completed_escrows.py @@ -141,14 +141,13 @@ def test_retrieve_annotations(self): .filter_by(escrow_address=escrow_address, chain_id=Networks.localhost.value) .first() ) - self.assertIsNotNone(webhook) - self.assertEqual(webhook.event_type, ExchangeOracleEventTypes.task_finished) + assert webhook is not None + assert webhook.event_type == ExchangeOracleEventTypes.task_finished db_project = self.session.query(Project).filter_by(id=project_id).first() - self.assertEqual(db_project.status, ProjectStatuses.validation) + assert db_project.status == ProjectStatuses.validation - @patch("src.cvat.api_calls.get_job_annotations") - def test_retrieve_annotations_error_getting_annotations(self, mock_annotations): + def test_retrieve_annotations_error_getting_annotations(self): cvat_project_id = 1 escrow_address = "0x86e83d346041E8806e352681f3F14549C0d2BC67" project_id = str(uuid.uuid4()) @@ -232,11 +231,11 @@ def test_retrieve_annotations_error_getting_annotations(self, mock_annotations): .filter_by(escrow_address=escrow_address, chain_id=Networks.localhost.value) .first() ) - self.assertIsNone(webhook) + assert webhook is None db_project = self.session.query(Project).filter_by(id=project_id).first() - self.assertEqual(db_project.status, ProjectStatuses.completed.value) + assert db_project.status == ProjectStatuses.completed.value def test_retrieve_annotations_error_uploading_files(self): cvat_project_id = 1 @@ -311,11 +310,11 @@ def test_retrieve_annotations_error_uploading_files(self): .filter_by(escrow_address=escrow_address, chain_id=Networks.localhost.value) .first() ) - self.assertIsNone(webhook) + assert webhook is None db_project = self.session.query(Project).filter_by(id=project_id).first() - self.assertEqual(db_project.status, ProjectStatuses.completed.value) + assert db_project.status == ProjectStatuses.completed.value def test_retrieve_annotations_multiple_projects_per_escrow_all_completed(self): escrow_address = "0x86e83d346041E8806e352681f3F14549C0d2BC67" @@ -394,9 +393,10 @@ def test_retrieve_annotations_multiple_projects_per_escrow_all_completed(self): def _fake_get_annotations(*args, **kwargs): dummy_zip_file = io.BytesIO() - with zipfile.ZipFile( - dummy_zip_file, "w" - ) as archive, TemporaryDirectory() as tempdir: + with ( + zipfile.ZipFile(dummy_zip_file, "w") as archive, + TemporaryDirectory() as tempdir, + ): mock_dataset = dm.Dataset( media_type=dm.Image, categories={ @@ -443,17 +443,17 @@ def _fake_postprocess_annotations( .filter_by(escrow_address=escrow_address, chain_id=Networks.localhost.value) .first() ) - self.assertIsNotNone(webhook) - self.assertEqual(webhook.event_type, ExchangeOracleEventTypes.task_finished) + assert webhook is not None + assert webhook.event_type == ExchangeOracleEventTypes.task_finished db_projects = ( self.session.query(Project) .where(Project.id.in_([project1.id, project2.id, project3.id])) .all() ) - self.assertEqual(len(db_projects), 3) + assert len(db_projects) == 3 for db_project in db_projects: - self.assertEqual(db_project.status, ProjectStatuses.validation) + assert db_project.status == ProjectStatuses.validation def test_retrieve_annotations_multiple_projects_per_escrow_some_completed_some_validation(self): escrow_address = "0x86e83d346041E8806e352681f3F14549C0d2BC67" @@ -532,9 +532,10 @@ def test_retrieve_annotations_multiple_projects_per_escrow_some_completed_some_v def _fake_get_annotations(*args, **kwargs): dummy_zip_file = io.BytesIO() - with zipfile.ZipFile( - dummy_zip_file, "w" - ) as archive, TemporaryDirectory() as tempdir: + with ( + zipfile.ZipFile(dummy_zip_file, "w") as archive, + TemporaryDirectory() as tempdir, + ): mock_dataset = dm.Dataset( media_type=dm.Image, categories={ @@ -581,17 +582,17 @@ def _fake_postprocess_annotations( .filter_by(escrow_address=escrow_address, chain_id=Networks.localhost.value) .first() ) - self.assertIsNotNone(webhook) - self.assertEqual(webhook.event_type, ExchangeOracleEventTypes.task_finished) + assert webhook is not None + assert webhook.event_type == ExchangeOracleEventTypes.task_finished db_projects = ( self.session.query(Project) .where(Project.id.in_([project1.id, project2.id, project3.id])) .all() ) - self.assertEqual(len(db_projects), 3) + assert len(db_projects) == 3 for db_project in db_projects: - self.assertEqual(db_project.status, ProjectStatuses.validation) + assert db_project.status == ProjectStatuses.validation def test_retrieve_annotations_multiple_projects_per_escrow_all_validation(self): escrow_address = "0x86e83d346041E8806e352681f3F14549C0d2BC67" @@ -671,13 +672,13 @@ def test_retrieve_annotations_multiple_projects_per_escrow_all_validation(self): mock_postprocess_annotations.assert_not_called() - self.assertEqual(self.session.query(Webhook).count(), 0) + assert self.session.query(Webhook).count() == 0 db_projects = ( self.session.query(Project) .where(Project.id.in_([project1.id, project2.id, project3.id])) .all() ) - self.assertEqual(len(db_projects), 3) + assert len(db_projects) == 3 for db_project in db_projects: - self.assertEqual(db_project.status, ProjectStatuses.validation) + assert db_project.status == ProjectStatuses.validation diff --git a/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_completed_projects.py b/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_completed_projects.py index 489215e0db..36de489508 100644 --- a/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_completed_projects.py +++ b/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_completed_projects.py @@ -46,7 +46,7 @@ def test_track_completed_projects(self): self.session.execute(select(Project).where(Project.id == project_id)).scalars().first() ) - self.assertEqual(updated_project.status, ProjectStatuses.completed.value) + assert updated_project.status == ProjectStatuses.completed.value def test_track_completed_projects_with_unfinished_task(self): project_id = str(uuid.uuid4()) @@ -85,4 +85,4 @@ def test_track_completed_projects_with_unfinished_task(self): self.session.execute(select(Project).where(Project.id == project_id)).scalars().first() ) - self.assertEqual(updated_project.status, ProjectStatuses.annotation.value) + assert updated_project.status == ProjectStatuses.annotation.value diff --git a/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_completed_tasks.py b/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_completed_tasks.py index 4619c5b7cf..614afd20d6 100644 --- a/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_completed_tasks.py +++ b/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_completed_tasks.py @@ -55,7 +55,7 @@ def test_track_completed_tasks(self): self.session.execute(select(Task).where(Task.id == task_id)).scalars().first() ) - self.assertEqual(updated_task.status, TaskStatuses.completed.value) + assert updated_task.status == TaskStatuses.completed.value def test_track_completed_tasks_with_unfinished_job(self): project = Project( @@ -104,4 +104,4 @@ def test_track_completed_tasks_with_unfinished_job(self): self.session.execute(select(Task).where(Task.id == task_id)).scalars().first() ) - self.assertEqual(updated_task.status, TaskStatuses.annotation.value) + assert updated_task.status == TaskStatuses.annotation.value diff --git a/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_escrow_creation.py b/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_escrow_creation.py index 46e1dd9a23..b1cc99f939 100644 --- a/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_escrow_creation.py +++ b/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_escrow_creation.py @@ -46,7 +46,7 @@ def test_track_track_successful_escrow_creation(self): updated_projects = ( self.session.query(Project).where(Project.cvat_id.in_(cvat_project_ids)).all() ) - self.assertEqual( - [p.status for p in updated_projects], - [ProjectStatuses.annotation, ProjectStatuses.annotation], - ) + assert [p.status for p in updated_projects] == [ + ProjectStatuses.annotation, + ProjectStatuses.annotation, + ] diff --git a/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_task_creation.py b/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_task_creation.py index a140355d7d..c04c1181a5 100644 --- a/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_task_creation.py +++ b/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_task_creation.py @@ -38,9 +38,9 @@ def test_track_track_failed_task_creation(self): track_task_creation() webhook = self.session.query(Webhook).filter_by(escrow_address=escrow_address).first() - self.assertIsNotNone(webhook) + assert webhook is not None data_upload = self.session.query(DataUpload).filter_by(id=upload_id).first() - self.assertIsNone(data_upload) + assert data_upload is None def test_track_track_completed_task_creation(self): escrow_address = "0x86e83d346041E8806e352681f3F14549C0d2BC67" @@ -75,11 +75,11 @@ def test_track_track_completed_task_creation(self): self.session.commit() jobs = self.session.query(Job).all() - self.assertIsNotNone(jobs) - self.assertEqual(len(jobs), 2) - self.assertTrue(any(job.cvat_id == 2 for job in jobs)) + assert jobs is not None + assert len(jobs) == 2 + assert any(job.cvat_id == 2 for job in jobs) data_upload = self.session.query(DataUpload).filter_by(id=upload_id).first() - self.assertIsNone(data_upload) + assert data_upload is None def test_track_track_completed_task_creation_error(self): escrow_address = "0x86e83d346041E8806e352681f3F14549C0d2BC67" @@ -107,5 +107,5 @@ def test_track_track_completed_task_creation_error(self): self.session.commit() webhook = self.session.query(Webhook).filter_by(escrow_address=escrow_address).first() - self.assertIsNotNone(webhook) - self.assertEqual(webhook.event_type, ExchangeOracleEventTypes.task_creation_failed) + assert webhook is not None + assert webhook.event_type == ExchangeOracleEventTypes.task_creation_failed diff --git a/packages/examples/cvat/exchange-oracle/tests/integration/cron/test_process_job_launcher_webhooks.py b/packages/examples/cvat/exchange-oracle/tests/integration/cron/test_process_job_launcher_webhooks.py index cddac828f7..1fbb1b1c35 100644 --- a/packages/examples/cvat/exchange-oracle/tests/integration/cron/test_process_job_launcher_webhooks.py +++ b/packages/examples/cvat/exchange-oracle/tests/integration/cron/test_process_job_launcher_webhooks.py @@ -92,23 +92,23 @@ def test_process_incoming_job_launcher_webhooks_escrow_created_type(self): self.session.execute(select(Webhook).where(Webhook.id == webhok_id)).scalars().first() ) - self.assertEqual(updated_webhook.status, OracleWebhookStatuses.completed.value) - self.assertEqual(updated_webhook.attempts, 1) + assert updated_webhook.status == OracleWebhookStatuses.completed.value + assert updated_webhook.attempts == 1 db_project = ( self.session.query(Project) .filter_by(escrow_address=escrow_address, chain_id=chain_id) .first() ) - self.assertEqual(db_project.status, ProjectStatuses.creation.value) + assert db_project.status == ProjectStatuses.creation.value db_escrow_creation_tracker = ( self.session.query(EscrowCreation) .filter_by(escrow_address=escrow_address, chain_id=chain_id) .first() ) - self.assertListEqual(db_escrow_creation_tracker.projects, [db_project]) - self.assertEqual(db_escrow_creation_tracker.total_jobs, 1) + assert db_escrow_creation_tracker.projects == [db_project] + assert db_escrow_creation_tracker.total_jobs == 1 def test_process_incoming_job_launcher_webhooks_escrow_created_type_invalid_escrow_status( self, @@ -138,8 +138,8 @@ def test_process_incoming_job_launcher_webhooks_escrow_created_type_invalid_escr updated_webhook = self.session.query(Webhook).filter_by(id=webhok_id).first() - self.assertEqual(updated_webhook.status, OracleWebhookStatuses.pending.value) - self.assertEqual(updated_webhook.attempts, 1) + assert updated_webhook.status == OracleWebhookStatuses.pending.value + assert updated_webhook.attempts == 1 def test_process_incoming_job_launcher_webhooks_escrow_created_type_exceed_max_retries( self, @@ -170,8 +170,8 @@ def test_process_incoming_job_launcher_webhooks_escrow_created_type_exceed_max_r updated_webhook = self.session.query(Webhook).filter_by(id=webhok_id).first() - self.assertEqual(updated_webhook.status, OracleWebhookStatuses.failed.value) - self.assertEqual(updated_webhook.attempts, 6) + assert updated_webhook.status == OracleWebhookStatuses.failed.value + assert updated_webhook.attempts == 6 new_webhook = ( self.session.query(Webhook) @@ -183,9 +183,9 @@ def test_process_incoming_job_launcher_webhooks_escrow_created_type_exceed_max_r .first() ) - self.assertEqual(new_webhook.status, OracleWebhookStatuses.pending.value) - self.assertEqual(new_webhook.event_type, ExchangeOracleEventTypes.task_creation_failed) - self.assertEqual(new_webhook.attempts, 0) + assert new_webhook.status == OracleWebhookStatuses.pending.value + assert new_webhook.event_type == ExchangeOracleEventTypes.task_creation_failed + assert new_webhook.attempts == 0 def test_process_incoming_job_launcher_webhooks_escrow_created_type_remove_when_error( self, @@ -240,15 +240,15 @@ def test_process_incoming_job_launcher_webhooks_escrow_created_type_remove_when_ self.session.execute(select(Webhook).where(Webhook.id == webhok_id)).scalars().first() ) - self.assertEqual(updated_webhook.status, OracleWebhookStatuses.pending.value) - self.assertEqual(updated_webhook.attempts, 1) + assert updated_webhook.status == OracleWebhookStatuses.pending.value + assert updated_webhook.attempts == 1 db_project = ( self.session.query(Project) .filter_by(escrow_address=escrow_address, chain_id=chain_id) .first() ) - self.assertIsNone(db_project) + assert db_project is None def test_process_incoming_job_launcher_webhooks_escrow_canceled_type(self): project_id = str(uuid.uuid4()) @@ -290,17 +290,17 @@ def test_process_incoming_job_launcher_webhooks_escrow_canceled_type(self): self.session.execute(select(Webhook).where(Webhook.id == webhok_id)).scalars().first() ) - self.assertEqual(updated_webhook.status, OracleWebhookStatuses.completed.value) - self.assertEqual(updated_webhook.attempts, 1) + assert updated_webhook.status == OracleWebhookStatuses.completed.value + assert updated_webhook.attempts == 1 db_project = ( self.session.query(Project) .filter_by(escrow_address=escrow_address, chain_id=chain_id) .first() ) - self.assertEqual(db_project.status, ProjectStatuses.canceled.value) + assert db_project.status == ProjectStatuses.canceled.value - def test_process_incoming_job_launcher_webhooks_escrow_canceled_type_with_multiple_creating_projects( + def test_process_incoming_job_launcher_webhooks_escrow_canceled_type_with_multiple_creating_projects( # noqa: E501 self, ): project_ids = [] @@ -353,22 +353,22 @@ def test_process_incoming_job_launcher_webhooks_escrow_canceled_type_with_multip self.session.execute(select(Webhook).where(Webhook.id == webhok_id)).scalars().first() ) - self.assertEqual(updated_webhook.status, OracleWebhookStatuses.completed.value) - self.assertEqual(updated_webhook.attempts, 1) + assert updated_webhook.status == OracleWebhookStatuses.completed.value + assert updated_webhook.attempts == 1 db_project = ( self.session.query(Project) .filter_by(escrow_address=escrow_address, chain_id=chain_id) .first() ) - self.assertEqual(db_project.status, ProjectStatuses.canceled.value) + assert db_project.status == ProjectStatuses.canceled.value db_escrow_creation_tracker = ( self.session.query(EscrowCreation) .filter_by(escrow_address=escrow_address, chain_id=chain_id) .first() ) - self.assertTrue(bool(db_escrow_creation_tracker.finished_at)) + assert bool(db_escrow_creation_tracker.finished_at) def test_process_incoming_job_launcher_webhooks_escrow_canceled_type_invalid_status( self, @@ -411,15 +411,15 @@ def test_process_incoming_job_launcher_webhooks_escrow_canceled_type_invalid_sta self.session.execute(select(Webhook).where(Webhook.id == webhok_id)).scalars().first() ) - self.assertEqual(updated_webhook.status, OracleWebhookStatuses.pending.value) - self.assertEqual(updated_webhook.attempts, 1) + assert updated_webhook.status == OracleWebhookStatuses.pending.value + assert updated_webhook.attempts == 1 db_project = ( self.session.query(Project) .filter_by(escrow_address=escrow_address, chain_id=chain_id) .first() ) - self.assertEqual(db_project.status, ProjectStatuses.annotation.value) + assert db_project.status == ProjectStatuses.annotation.value def test_process_incoming_job_launcher_webhooks_escrow_canceled_type_invalid_balance( self, @@ -463,15 +463,15 @@ def test_process_incoming_job_launcher_webhooks_escrow_canceled_type_invalid_bal self.session.execute(select(Webhook).where(Webhook.id == webhok_id)).scalars().first() ) - self.assertEqual(updated_webhook.status, OracleWebhookStatuses.pending.value) - self.assertEqual(updated_webhook.attempts, 1) + assert updated_webhook.status == OracleWebhookStatuses.pending.value + assert updated_webhook.attempts == 1 db_project = ( self.session.query(Project) .filter_by(escrow_address=escrow_address, chain_id=chain_id) .first() ) - self.assertEqual(db_project.status, ProjectStatuses.annotation.value) + assert db_project.status == ProjectStatuses.annotation.value def test_process_outgoing_job_launcher_webhooks(self): chain_id = Networks.localhost.value @@ -511,8 +511,8 @@ def test_process_outgoing_job_launcher_webhooks(self): self.session.execute(select(Webhook).where(Webhook.id == webhok_id)).scalars().first() ) - self.assertEqual(updated_webhook.status, OracleWebhookStatuses.completed.value) - self.assertEqual(updated_webhook.attempts, 1) + assert updated_webhook.status == OracleWebhookStatuses.completed.value + assert updated_webhook.attempts == 1 mock_httpx_post.assert_called_once() def test_process_outgoing_job_launcher_webhooks_invalid_type(self): @@ -539,5 +539,5 @@ def test_process_outgoing_job_launcher_webhooks_invalid_type(self): self.session.execute(select(Webhook).where(Webhook.id == webhok_id)).scalars().first() ) - self.assertEqual(updated_webhook.status, OracleWebhookStatuses.pending.value) - self.assertEqual(updated_webhook.attempts, 1) + assert updated_webhook.status == OracleWebhookStatuses.pending.value + assert updated_webhook.attempts == 1 diff --git a/packages/examples/cvat/exchange-oracle/tests/integration/cron/test_process_recording_oracle_webhooks.py b/packages/examples/cvat/exchange-oracle/tests/integration/cron/test_process_recording_oracle_webhooks.py index 9f3b6a5880..af027d81a9 100644 --- a/packages/examples/cvat/exchange-oracle/tests/integration/cron/test_process_recording_oracle_webhooks.py +++ b/packages/examples/cvat/exchange-oracle/tests/integration/cron/test_process_recording_oracle_webhooks.py @@ -73,11 +73,11 @@ def test_process_incoming_recording_oracle_webhooks_task_completed_type(self): self.session.execute(select(Webhook).where(Webhook.id == webhok_id)).scalars().first() ) - self.assertEqual(updated_webhook.status, OracleWebhookStatuses.completed.value) - self.assertEqual(updated_webhook.attempts, 1) + assert updated_webhook.status == OracleWebhookStatuses.completed.value + assert updated_webhook.attempts == 1 db_project = self.session.query(Project).filter_by(id=project_id).first() - self.assertEqual(db_project.status, ProjectStatuses.recorded.value) + assert db_project.status == ProjectStatuses.recorded.value def test_process_incoming_recording_oracle_webhooks_task_completed_type_invalid_project_status( self, @@ -116,12 +116,12 @@ def test_process_incoming_recording_oracle_webhooks_task_completed_type_invalid_ self.session.execute(select(Webhook).where(Webhook.id == webhok_id)).scalars().first() ) - self.assertEqual(updated_webhook.status, OracleWebhookStatuses.completed.value) - self.assertEqual(updated_webhook.attempts, 1) + assert updated_webhook.status == OracleWebhookStatuses.completed.value + assert updated_webhook.attempts == 1 db_project = self.session.query(Project).filter_by(id=project_id).first() - self.assertEqual(db_project.status, ProjectStatuses.completed.value) + assert db_project.status == ProjectStatuses.completed.value def test_process_incoming_recording_oracle_webhooks_task_task_rejected_type(self): cvat_id = 1 @@ -179,21 +179,21 @@ def test_process_incoming_recording_oracle_webhooks_task_task_rejected_type(self self.session.execute(select(Webhook).where(Webhook.id == webhok_id)).scalars().first() ) - self.assertEqual(updated_webhook.status, OracleWebhookStatuses.completed.value) - self.assertEqual(updated_webhook.attempts, 1) + assert updated_webhook.status == OracleWebhookStatuses.completed.value + assert updated_webhook.attempts == 1 db_project = self.session.query(Project).filter_by(id=project_id).first() - self.assertEqual(db_project.status, ProjectStatuses.annotation.value) + assert db_project.status == ProjectStatuses.annotation.value db_task = self.session.query(Task).filter_by(id=task_id).first() - self.assertEqual(db_task.status, TaskStatuses.annotation.value) + assert db_task.status == TaskStatuses.annotation.value db_job = self.session.query(Job).filter_by(id=job_id).first() - self.assertEqual(db_job.status, JobStatuses.new.value) + assert db_job.status == JobStatuses.new.value - def test_process_incoming_recording_oracle_webhooks_task_task_rejected_type_invalid_project_status( + def test_process_incoming_recording_oracle_webhooks_task_task_rejected_type_invalid_project_status( # noqa: E501 self, ): cvat_id = 1 @@ -233,12 +233,12 @@ def test_process_incoming_recording_oracle_webhooks_task_task_rejected_type_inva self.session.execute(select(Webhook).where(Webhook.id == webhok_id)).scalars().first() ) - self.assertEqual(updated_webhook.status, OracleWebhookStatuses.completed.value) - self.assertEqual(updated_webhook.attempts, 1) + assert updated_webhook.status == OracleWebhookStatuses.completed.value + assert updated_webhook.attempts == 1 db_project = self.session.query(Project).filter_by(id=project_id).first() - self.assertEqual(db_project.status, ProjectStatuses.completed.value) + assert db_project.status == ProjectStatuses.completed.value def test_process_outgoing_recording_oracle_webhooks(self): chain_id = Networks.localhost.value @@ -278,8 +278,8 @@ def test_process_outgoing_recording_oracle_webhooks(self): self.session.execute(select(Webhook).where(Webhook.id == webhok_id)).scalars().first() ) - self.assertEqual(updated_webhook.status, OracleWebhookStatuses.completed.value) - self.assertEqual(updated_webhook.attempts, 1) + assert updated_webhook.status == OracleWebhookStatuses.completed.value + assert updated_webhook.attempts == 1 mock_httpx_post.assert_called_once() def test_process_outgoing_recording_oracle_webhooks_invalid_type(self): @@ -306,5 +306,5 @@ def test_process_outgoing_recording_oracle_webhooks_invalid_type(self): self.session.execute(select(Webhook).where(Webhook.id == webhok_id)).scalars().first() ) - self.assertEqual(updated_webhook.status, OracleWebhookStatuses.pending.value) - self.assertEqual(updated_webhook.attempts, 1) + assert updated_webhook.status == OracleWebhookStatuses.pending.value + assert updated_webhook.attempts == 1 diff --git a/packages/examples/cvat/exchange-oracle/tests/integration/services/__init__.py b/packages/examples/cvat/exchange-oracle/tests/integration/services/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/packages/examples/cvat/exchange-oracle/tests/integration/services/test_cvat.py b/packages/examples/cvat/exchange-oracle/tests/integration/services/test_cvat.py index 0a2bf196af..3a652aa20b 100644 --- a/packages/examples/cvat/exchange-oracle/tests/integration/services/test_cvat.py +++ b/packages/examples/cvat/exchange-oracle/tests/integration/services/test_cvat.py @@ -2,6 +2,7 @@ import uuid from datetime import datetime, timedelta +import pytest from sqlalchemy.exc import IntegrityError from sqlalchemy.orm.exc import UnmappedInstanceError @@ -50,14 +51,14 @@ def test_create_project(self): project = self.session.query(Project).filter_by(id=p_id).first() - self.assertIsNotNone(project) - self.assertEqual(project.id, p_id) - self.assertEqual(project.cvat_id, cvat_id) - self.assertEqual(project.status, ProjectStatuses.creation.value) - self.assertEqual(project.job_type, job_type) - self.assertEqual(project.escrow_address, escrow_address) - self.assertEqual(project.chain_id, chain_id) - self.assertEqual(project.bucket_url, bucket_url) + assert project is not None + assert project.id == p_id + assert project.cvat_id == cvat_id + assert project.status == ProjectStatuses.creation.value + assert project.job_type == job_type + assert project.escrow_address == escrow_address + assert project.chain_id == chain_id + assert project.bucket_url == bucket_url def test_create_duplicated_project(self): cvat_id = 1 @@ -86,7 +87,7 @@ def test_create_duplicated_project(self): chain_id, bucket_url, ) - with self.assertRaises(IntegrityError): + with pytest.raises(IntegrityError): self.session.commit() def test_create_project_none_cvat_id(self): @@ -104,7 +105,7 @@ def test_create_project_none_cvat_id(self): chain_id, bucket_url, ) - with self.assertRaises(IntegrityError): + with pytest.raises(IntegrityError): self.session.commit() def test_create_project_none_cvat_cloudstorage_id(self): @@ -122,7 +123,7 @@ def test_create_project_none_cvat_cloudstorage_id(self): chain_id, bucket_url, ) - with self.assertRaises(IntegrityError): + with pytest.raises(IntegrityError): self.session.commit() def test_create_project_none_job_type(self): @@ -140,7 +141,7 @@ def test_create_project_none_job_type(self): chain_id, bucket_url, ) - with self.assertRaises(IntegrityError): + with pytest.raises(IntegrityError): self.session.commit() def test_create_project_none_escrow_address(self): @@ -158,7 +159,7 @@ def test_create_project_none_escrow_address(self): chain_id, bucket_url, ) - with self.assertRaises(IntegrityError): + with pytest.raises(IntegrityError): self.session.commit() def test_create_project_none_chain_id(self): @@ -177,7 +178,7 @@ def test_create_project_none_chain_id(self): None, bucket_url, ) - with self.assertRaises(IntegrityError): + with pytest.raises(IntegrityError): self.session.commit() def test_create_project_none_bucket_url(self): @@ -195,7 +196,7 @@ def test_create_project_none_bucket_url(self): chain_id, None, ) - with self.assertRaises(IntegrityError): + with pytest.raises(IntegrityError): self.session.commit() def test_get_project_by_id(self): @@ -218,35 +219,35 @@ def test_get_project_by_id(self): project = cvat_service.get_project_by_id(self.session, p_id) - self.assertIsNotNone(project) - self.assertEqual(project.id, p_id) - self.assertEqual(project.cvat_id, cvat_id) - self.assertEqual(project.status, ProjectStatuses.annotation.value) - self.assertEqual(project.job_type, job_type) - self.assertEqual(project.escrow_address, escrow_address) - self.assertEqual(project.bucket_url, bucket_url) + assert project is not None + assert project.id == p_id + assert project.cvat_id == cvat_id + assert project.status == ProjectStatuses.annotation.value + assert project.job_type == job_type + assert project.escrow_address == escrow_address + assert project.bucket_url == bucket_url project = cvat_service.get_project_by_id(self.session, "dummy_id") - self.assertIsNone(project) + assert project is None project = cvat_service.get_project_by_id( self.session, p_id, status_in=[ProjectStatuses.annotation] ) - self.assertIsNotNone(project) - self.assertEqual(project.id, p_id) - self.assertEqual(project.cvat_id, cvat_id) - self.assertEqual(project.status, ProjectStatuses.annotation.value) - self.assertEqual(project.job_type, job_type) - self.assertEqual(project.escrow_address, escrow_address) - self.assertEqual(project.bucket_url, bucket_url) + assert project is not None + assert project.id == p_id + assert project.cvat_id == cvat_id + assert project.status == ProjectStatuses.annotation.value + assert project.job_type == job_type + assert project.escrow_address == escrow_address + assert project.bucket_url == bucket_url project = cvat_service.get_project_by_id( self.session, p_id, status_in=[ProjectStatuses.canceled] ) - self.assertIsNone(project) + assert project is None def test_get_project_by_escrow_address(self): cvat_id = 1 @@ -268,18 +269,18 @@ def test_get_project_by_escrow_address(self): project = cvat_service.get_project_by_escrow_address(self.session, escrow_address) - self.assertIsNotNone(project) - self.assertEqual(project.id, p_id) - self.assertEqual(project.cvat_id, cvat_id) - self.assertEqual(project.status, ProjectStatuses.annotation.value) - self.assertEqual(project.job_type, job_type) - self.assertEqual(project.escrow_address, escrow_address) - self.assertEqual(project.chain_id, chain_id) - self.assertEqual(project.bucket_url, bucket_url) + assert project is not None + assert project.id == p_id + assert project.cvat_id == cvat_id + assert project.status == ProjectStatuses.annotation.value + assert project.job_type == job_type + assert project.escrow_address == escrow_address + assert project.chain_id == chain_id + assert project.bucket_url == bucket_url project = cvat_service.get_project_by_escrow_address(self.session, "invalid escrow address") - self.assertIsNone(project) + assert project is None def test_get_projects_by_status(self): cvat_id = 1 @@ -333,17 +334,17 @@ def test_get_projects_by_status(self): projects = cvat_service.get_projects_by_status(self.session, ProjectStatuses.annotation) - self.assertEqual(len(projects), 3) + assert len(projects) == 3 cvat_service.update_project_status(self.session, p_id, ProjectStatuses.completed) projects = cvat_service.get_projects_by_status(self.session, ProjectStatuses.annotation) - self.assertEqual(len(projects), 2) + assert len(projects) == 2 projects = cvat_service.get_projects_by_status(self.session, ProjectStatuses.completed) - self.assertEqual(len(projects), 1) + assert len(projects) == 1 def test_get_available_projects(self): cvat_id_1 = 456 @@ -353,7 +354,7 @@ def test_get_available_projects(self): projects = cvat_service.get_available_projects(self.session) - self.assertEqual(len(projects), 1) + assert len(projects) == 1 cvat_id_2 = 457 (cvat_project, cvat_task) = create_project_and_task( @@ -377,9 +378,9 @@ def test_get_available_projects(self): ) projects = cvat_service.get_available_projects(self.session) - self.assertEqual(len(projects), 2) - self.assertTrue(any(project.cvat_id == cvat_id_1 for project in projects)) - self.assertTrue(any(project.cvat_id == cvat_id_3 for project in projects)) + assert len(projects) == 2 + assert any(project.cvat_id == cvat_id_1 for project in projects) + assert any(project.cvat_id == cvat_id_3 for project in projects) def test_get_projects_by_assignee(self): wallet_address_1 = "0x86e83d346041E8806e352681f3F14549C0d2BC60" @@ -418,13 +419,13 @@ def test_get_projects_by_assignee(self): projects = cvat_service.get_projects_by_assignee(self.session, wallet_address_1) - self.assertEqual(len(projects), 1) - self.assertEqual(projects[0].cvat_id, cvat_id_1) + assert len(projects) == 1 + assert projects[0].cvat_id == cvat_id_1 projects = cvat_service.get_projects_by_assignee(self.session, wallet_address_2) - self.assertEqual( - len(projects), 0 + assert ( + len(projects) == 0 ) # expired should not be shown, https://github.com/humanprotocol/human-protocol/pull/1879 def test_update_project_status(self): @@ -448,14 +449,14 @@ def test_update_project_status(self): project = cvat_service.get_project_by_id(self.session, p_id) - self.assertIsNotNone(project) - self.assertEqual(project.id, p_id) - self.assertEqual(project.cvat_id, cvat_id) - self.assertEqual(project.status, ProjectStatuses.completed.value) - self.assertEqual(project.job_type, job_type) - self.assertEqual(project.escrow_address, escrow_address) - self.assertEqual(project.chain_id, chain_id) - self.assertEqual(project.bucket_url, bucket_url) + assert project is not None + assert project.id == p_id + assert project.cvat_id == cvat_id + assert project.status == ProjectStatuses.completed.value + assert project.job_type == job_type + assert project.escrow_address == escrow_address + assert project.chain_id == chain_id + assert project.bucket_url == bucket_url def test_delete_project(self): cvat_id_1 = 456 @@ -465,12 +466,12 @@ def test_delete_project(self): ) projects = self.session.query(Project).all() - self.assertEqual(len(projects), 1) + assert len(projects) == 1 cvat_service.delete_project(self.session, project.id) projects = self.session.query(Project).all() - self.assertEqual(len(projects), 0) + assert len(projects) == 0 (cvat_project, cvat_task) = create_project_and_task( self.session, "0x86e83d346041E8806e352681f3F14549C0d2BC67", cvat_id_1 @@ -495,13 +496,13 @@ def test_delete_project(self): cvat_task_db = cvat_service.get_task_by_id(self.session, cvat_task.id) jobs = cvat_service.get_jobs_by_cvat_task_id(self.session, cvat_task_id=cvat_task.cvat_id) - self.assertIsNotNone(cvat_project) - self.assertEqual(cvat_project_db.id, cvat_project.id) + assert cvat_project is not None + assert cvat_project_db.id == cvat_project.id - self.assertIsNotNone(cvat_task_db) - self.assertEqual(cvat_task_db.id, cvat_task.id) + assert cvat_task_db is not None + assert cvat_task_db.id == cvat_task.id - self.assertEqual(len(jobs), 2) + assert len(jobs) == 2 cvat_service.delete_project(self.session, cvat_project_db.id) @@ -509,9 +510,9 @@ def test_delete_project(self): cvat_task_db = cvat_service.get_task_by_id(self.session, cvat_task.id) jobs = cvat_service.get_jobs_by_cvat_task_id(self.session, cvat_task_id=cvat_task.cvat_id) - self.assertIsNone(cvat_project_db) - self.assertIsNone(cvat_task_db) - self.assertEqual(len(jobs), 0) + assert cvat_project_db is None + assert cvat_task_db is None + assert len(jobs) == 0 def test_delete_project_wrong_project_id(self): cvat_id = 456 @@ -519,8 +520,8 @@ def test_delete_project_wrong_project_id(self): create_project(self.session, "0x86e83d346041E8806e352681f3F14549C0d2BC67", cvat_id) projects = self.session.query(Project).all() - self.assertEqual(len(projects), 1) - with self.assertRaises(UnmappedInstanceError): + assert len(projects) == 1 + with pytest.raises(UnmappedInstanceError): cvat_service.delete_project(self.session, "project_id") def test_create_task(self): @@ -535,11 +536,11 @@ def test_create_task(self): task = self.session.query(Task).filter_by(id=task_id).first() - self.assertIsNotNone(task) - self.assertEqual(task.id, task_id) - self.assertEqual(task.cvat_id, cvat_id) - self.assertEqual(task.cvat_project_id, cvat_project.cvat_id) - self.assertEqual(task.status, status) + assert task is not None + assert task.id == task_id + assert task.cvat_id == cvat_id + assert task.cvat_project_id == cvat_project.cvat_id + assert task.status == status def test_create_task_duplicated_cvat_id(self): cvat_id = 1 @@ -553,22 +554,22 @@ def test_create_task_duplicated_cvat_id(self): self.session.commit() cvat_service.create_task(self.session, cvat_id, cvat_project.cvat_id, status) - with self.assertRaises(IntegrityError): + with pytest.raises(IntegrityError): self.session.commit() def test_create_tas_without_project(self): cvat_service.create_task(self.session, 123, 123, TaskStatuses.annotation) - with self.assertRaises(IntegrityError): + with pytest.raises(IntegrityError): self.session.commit() def test_create_task_none_cvat_id(self): cvat_service.create_task(self.session, None, 123, TaskStatuses.annotation) - with self.assertRaises(IntegrityError): + with pytest.raises(IntegrityError): self.session.commit() def test_create_task_none_cvat_project_id(self): cvat_service.create_task(self.session, 123, None, TaskStatuses.annotation) - with self.assertRaises(IntegrityError): + with pytest.raises(IntegrityError): self.session.commit() def test_get_task_by_id(self): @@ -579,15 +580,15 @@ def test_get_task_by_id(self): ) task = cvat_service.get_task_by_id(self.session, task_id) - self.assertIsNotNone(task) - self.assertEqual(task.id, task_id) - self.assertEqual(task.cvat_id, cvat_project.cvat_id) - self.assertEqual(task.cvat_project_id, cvat_project.cvat_id) - self.assertEqual(task.status, TaskStatuses.annotation.value) + assert task is not None + assert task.id == task_id + assert task.cvat_id == cvat_project.cvat_id + assert task.cvat_project_id == cvat_project.cvat_id + assert task.status == TaskStatuses.annotation.value task = cvat_service.get_task_by_id(self.session, "dummy_id") - self.assertIsNone(task) + assert task is None def test_get_tasks_by_cvat_id(self): cvat_project = create_project(self.session, "0x86e83d346041E8806e352681f3F14549C0d2BC67", 1) @@ -598,15 +599,15 @@ def test_get_tasks_by_cvat_id(self): tasks = cvat_service.get_tasks_by_cvat_id(self.session, [1, 2]) - self.assertEqual(len(tasks), 2) + assert len(tasks) == 2 tasks = cvat_service.get_tasks_by_cvat_id(self.session, [3]) - self.assertEqual(len(tasks), 1) + assert len(tasks) == 1 tasks = cvat_service.get_tasks_by_cvat_id(self.session, [999]) - self.assertEqual(len(tasks), 0) + assert len(tasks) == 0 def test_get_tasks_by_status(self): cvat_project = create_project(self.session, "0x86e83d346041E8806e352681f3F14549C0d2BC67", 1) @@ -617,11 +618,11 @@ def test_get_tasks_by_status(self): tasks = cvat_service.get_tasks_by_status(self.session, TaskStatuses.annotation) - self.assertEqual(len(tasks), 2) + assert len(tasks) == 2 tasks = cvat_service.get_tasks_by_status(self.session, TaskStatuses.completed) - self.assertEqual(len(tasks), 1) + assert len(tasks) == 1 def test_update_task_status(self): cvat_project = create_project(self.session, "0x86e83d346041E8806e352681f3F14549C0d2BC67", 1) @@ -634,11 +635,11 @@ def test_update_task_status(self): task = cvat_service.get_task_by_id(self.session, task_id) - self.assertIsNotNone(task) - self.assertEqual(task.id, task_id) - self.assertEqual(task.cvat_id, cvat_project.cvat_id) - self.assertEqual(task.cvat_project_id, cvat_project.cvat_id) - self.assertEqual(task.status, TaskStatuses.completed.value) + assert task is not None + assert task.id == task_id + assert task.cvat_id == cvat_project.cvat_id + assert task.cvat_project_id == cvat_project.cvat_id + assert task.status == TaskStatuses.completed.value def test_get_tasks_by_cvat_project_id(self): cvat_project = create_project(self.session, "0x86e83d346041E8806e352681f3F14549C0d2BC67", 1) @@ -655,15 +656,15 @@ def test_get_tasks_by_cvat_project_id(self): tasks = cvat_service.get_tasks_by_cvat_project_id(self.session, cvat_project.cvat_id) - self.assertEqual(len(tasks), 3) + assert len(tasks) == 3 tasks = cvat_service.get_tasks_by_cvat_project_id(self.session, cvat_project_2.cvat_id) - self.assertEqual(len(tasks), 1) + assert len(tasks) == 1 tasks = cvat_service.get_tasks_by_cvat_project_id(self.session, 123) - self.assertEqual(len(tasks), 0) + assert len(tasks) == 0 def test_create_data_upload(self): cvat_id = 1 @@ -675,9 +676,9 @@ def test_create_data_upload(self): data_upload = self.session.query(DataUpload).filter_by(task_id=cvat_task.cvat_id).first() - self.assertIsNotNone(data_upload) - self.assertEqual(data_upload.id, data_upload_id) - self.assertEqual(data_upload.task_id, cvat_id) + assert data_upload is not None + assert data_upload.id == data_upload_id + assert data_upload.task_id == cvat_id def test_get_active_task_uploads_by_task_id(self): cvat_id_1 = 1 @@ -695,16 +696,16 @@ def test_get_active_task_uploads_by_task_id(self): data_uploads = cvat_service.get_active_task_uploads_by_task_id( self.session, [cvat_id_1, cvat_id_2] ) - self.assertEqual(len(data_uploads), 2) + assert len(data_uploads) == 2 data_uploads = cvat_service.get_active_task_uploads_by_task_id(self.session, [cvat_id_1]) - self.assertEqual(len(data_uploads), 1) + assert len(data_uploads) == 1 data_uploads = cvat_service.get_active_task_uploads_by_task_id(self.session, [cvat_id_2]) - self.assertEqual(len(data_uploads), 1) + assert len(data_uploads) == 1 data_uploads = cvat_service.get_active_task_uploads_by_task_id(self.session, []) - self.assertEqual(len(data_uploads), 0) + assert len(data_uploads) == 0 def test_get_active_task_uploads(self): cvat_id_1 = 1 @@ -720,10 +721,10 @@ def test_get_active_task_uploads(self): cvat_service.create_data_upload(self.session, cvat_task_2.cvat_id) data_uploads = cvat_service.get_active_task_uploads(self.session) - self.assertEqual(len(data_uploads), 2) + assert len(data_uploads) == 2 data_uploads = cvat_service.get_active_task_uploads(self.session, limit=1) - self.assertEqual(len(data_uploads), 1) + assert len(data_uploads) == 1 def test_get_active_task_uploads(self): cvat_id_1 = 1 @@ -739,17 +740,17 @@ def test_get_active_task_uploads(self): cvat_service.create_data_upload(self.session, cvat_task_2.cvat_id) data_uploads = self.session.query(DataUpload).all() - self.assertEqual(len(data_uploads), 2) + assert len(data_uploads) == 2 cvat_service.finish_data_uploads(self.session, [data_uploads[0]]) data_uploads = self.session.query(DataUpload).all() - self.assertEqual(len(data_uploads), 1) + assert len(data_uploads) == 1 cvat_service.finish_data_uploads(self.session, [data_uploads[0]]) data_uploads = self.session.query(DataUpload).all() - self.assertEqual(len(data_uploads), 0) + assert len(data_uploads) == 0 def test_create_job(self): (cvat_project, cvat_task) = create_project_and_task( @@ -769,14 +770,14 @@ def test_create_job(self): ) job_count = self.session.query(Job).count() - self.assertEqual(job_count, 1) + assert job_count == 1 job = self.session.query(Job).filter_by(id=job_id).first() - self.assertEqual(job.cvat_id, cvat_job_id) - self.assertEqual(job.cvat_task_id, cvat_task_id) - self.assertEqual(job.cvat_project_id, cvat_project_id) - self.assertEqual(job.status, status) + assert job.cvat_id == cvat_job_id + assert job.cvat_task_id == cvat_task_id + assert job.cvat_project_id == cvat_project_id + assert job.status == status def test_create_job_invalid_cvat_id(self): (cvat_project, cvat_task) = create_project_and_task( @@ -790,7 +791,7 @@ def test_create_job_invalid_cvat_id(self): cvat_project_id=cvat_project.cvat_id, status=JobStatuses.new, ) - with self.assertRaises(IntegrityError): + with pytest.raises(IntegrityError): self.session.commit() def test_create_job_without_task(self): @@ -803,7 +804,7 @@ def test_create_job_without_task(self): cvat_task_id=None, status=JobStatuses.new, ) - with self.assertRaises(IntegrityError): + with pytest.raises(IntegrityError): self.session.commit() def test_create_job_invalid_task_reference(self): @@ -818,7 +819,7 @@ def test_create_job_invalid_task_reference(self): cvat_project_id=cvat_project.cvat_id, status=JobStatuses.new, ) - with self.assertRaises(IntegrityError): + with pytest.raises(IntegrityError): self.session.commit() def test_create_job_invalid_project_reference(self): @@ -833,7 +834,7 @@ def test_create_job_invalid_project_reference(self): cvat_project_id=122, status=JobStatuses.new, ) - with self.assertRaises(IntegrityError): + with pytest.raises(IntegrityError): self.session.commit() def test_create_job_duplicated_cvat_id(self): @@ -856,7 +857,7 @@ def test_create_job_duplicated_cvat_id(self): cvat_project_id=cvat_project.cvat_id, status=JobStatuses.new, ) - with self.assertRaises(IntegrityError): + with pytest.raises(IntegrityError): self.session.commit() def test_get_job_by_id(self): @@ -878,15 +879,15 @@ def test_get_job_by_id(self): ) job = cvat_service.get_job_by_id(self.session, job_id) - self.assertIsNotNone(job) - self.assertEqual(job.id, job_id) - self.assertEqual(job.cvat_id, cvat_id) - self.assertEqual(job.cvat_task_id, cvat_task.cvat_id) - self.assertEqual(job.cvat_project_id, cvat_project.cvat_id) + assert job is not None + assert job.id == job_id + assert job.cvat_id == cvat_id + assert job.cvat_task_id == cvat_task.cvat_id + assert job.cvat_project_id == cvat_project.cvat_id job = cvat_service.get_job_by_id(self.session, "Dummy id") - self.assertIsNone(job) + assert job is None def test_get_jobs_by_cvat_id(self): (cvat_project, cvat_task) = create_project_and_task( @@ -907,15 +908,15 @@ def test_get_jobs_by_cvat_id(self): jobs = cvat_service.get_jobs_by_cvat_id(self.session, [cvat_id]) - self.assertIsNotNone(jobs) - self.assertEqual(jobs[0].id, job_id) - self.assertEqual(jobs[0].cvat_id, cvat_id) - self.assertEqual(jobs[0].cvat_task_id, cvat_task.cvat_id) - self.assertEqual(jobs[0].cvat_project_id, cvat_project.cvat_id) + assert jobs is not None + assert jobs[0].id == job_id + assert jobs[0].cvat_id == cvat_id + assert jobs[0].cvat_task_id == cvat_task.cvat_id + assert jobs[0].cvat_project_id == cvat_project.cvat_id def test_get_jobs_by_cvat_id_wrong_cvat_id(self): job = cvat_service.get_jobs_by_cvat_id(self.session, [457]) - self.assertEqual(job, []) + assert job == [] def test_update_job_status(self): (cvat_project, cvat_task) = create_project_and_task( @@ -936,10 +937,10 @@ def test_update_job_status(self): ) job = cvat_service.get_job_by_id(self.session, job_id) - self.assertEqual(job.cvat_id, cvat_id) - self.assertEqual(job.cvat_task_id, cvat_task.cvat_id) - self.assertEqual(job.cvat_project_id, cvat_project.cvat_id) - self.assertEqual(job.status, status) + assert job.cvat_id == cvat_id + assert job.cvat_task_id == cvat_task.cvat_id + assert job.cvat_project_id == cvat_project.cvat_id + assert job.status == status new_status = JobStatuses.completed @@ -947,12 +948,12 @@ def test_update_job_status(self): job = self.session.query(Job).filter_by(id=job_id).first() - self.assertIsNotNone(job) - self.assertEqual(job.id, job_id) - self.assertEqual(job.cvat_id, cvat_id) - self.assertEqual(job.cvat_task_id, cvat_task.cvat_id) - self.assertEqual(job.cvat_project_id, cvat_project.cvat_id) - self.assertEqual(job.status, new_status) + assert job is not None + assert job.id == job_id + assert job.cvat_id == cvat_id + assert job.cvat_task_id == cvat_task.cvat_id + assert job.cvat_project_id == cvat_project.cvat_id + assert job.status == new_status def test_get_jobs_by_cvat_task_id(self): (cvat_project, cvat_task) = create_project_and_task( @@ -995,7 +996,7 @@ def test_get_jobs_by_cvat_task_id(self): jobs = cvat_service.get_jobs_by_cvat_task_id(self.session, cvat_task_id=cvat_task.cvat_id) - self.assertEqual(len(jobs), 3) + assert len(jobs) == 3 def test_get_jobs_by_cvat_project_id(self): (cvat_project, cvat_task) = create_project_and_task( @@ -1038,7 +1039,7 @@ def test_get_jobs_by_cvat_project_id(self): jobs = cvat_service.get_jobs_by_cvat_project_id(self.session, cvat_project.cvat_id) - self.assertEqual(len(jobs), 3) + assert len(jobs) == 3 def test_put_user(self): wallet_address = "0x86e83d346041E8806e352681f3F14549C0d2BC67" @@ -1053,12 +1054,12 @@ def test_put_user(self): db_user = self.session.query(User).filter_by(cvat_id=cvat_id).first() - self.assertIsNotNone(user) - self.assertEqual(user, db_user) - self.assertEqual(user.id, db_user.id) - self.assertEqual(user.cvat_id, cvat_id) - self.assertEqual(user.cvat_email, cvat_email) - self.assertEqual(user.wallet_address, wallet_address) + assert user is not None + assert user == db_user + assert user.id == db_user.id + assert user.cvat_id == cvat_id + assert user.cvat_email == cvat_email + assert user.wallet_address == wallet_address def test_put_user(self): wallet_address = "0x86e83d346041E8806e352681f3F14549C0d2BC67" @@ -1073,13 +1074,13 @@ def test_put_user(self): db_user = self.session.query(User).filter_by(cvat_id=cvat_id).first() - self.assertIsNotNone(user) - self.assertEqual(user.cvat_id, cvat_id) - self.assertEqual(user.cvat_email, cvat_email) - self.assertEqual(user.wallet_address, wallet_address) - self.assertEqual(db_user.cvat_id, cvat_id) - self.assertEqual(db_user.cvat_email, cvat_email) - self.assertEqual(db_user.wallet_address, wallet_address) + assert user is not None + assert user.cvat_id == cvat_id + assert user.cvat_email == cvat_email + assert user.wallet_address == wallet_address + assert db_user.cvat_id == cvat_id + assert db_user.cvat_email == cvat_email + assert db_user.wallet_address == wallet_address def test_put_user_duplicated_address(self): wallet_address = "0x86e83d346041E8806e352681f3F14549C0d2BC67" @@ -1092,9 +1093,9 @@ def test_put_user_duplicated_address(self): self.session.commit() db_users = self.session.query(User).filter_by(wallet_address=wallet_address).all() - self.assertEqual(len(db_users), 1) - self.assertEqual(db_users[0].cvat_id, 1) - self.assertEqual(db_users[0].cvat_email, "test@hmt.ai") + assert len(db_users) == 1 + assert db_users[0].cvat_id == 1 + assert db_users[0].cvat_email == "test@hmt.ai" cvat_service.put_user( self.session, @@ -1105,9 +1106,9 @@ def test_put_user_duplicated_address(self): self.session.commit() db_users = self.session.query(User).filter_by(wallet_address=wallet_address).all() - self.assertEqual(len(db_users), 1) - self.assertEqual(db_users[0].cvat_id, 2) - self.assertEqual(db_users[0].cvat_email, "test2@hmt.ai") + assert len(db_users) == 1 + assert db_users[0].cvat_id == 2 + assert db_users[0].cvat_email == "test2@hmt.ai" def test_put_user_duplicated_email(self): email = "test@hmt.ai" @@ -1124,7 +1125,7 @@ def test_put_user_duplicated_email(self): email, 2, ) - with self.assertRaises(IntegrityError): + with pytest.raises(IntegrityError): self.session.commit() def test_put_user_duplicated_id(self): @@ -1142,7 +1143,7 @@ def test_put_user_duplicated_id(self): "test2@hmt.ai", cvat_id, ) - with self.assertRaises(IntegrityError): + with pytest.raises(IntegrityError): self.session.commit() def test_get_user_by_id(self): @@ -1164,19 +1165,19 @@ def test_get_user_by_id(self): user_1 = cvat_service.get_user_by_id( self.session, "0x86e83d346041E8806e352681f3F14549C0d2BC67" ) - self.assertEqual(user_1.cvat_id, 1) - self.assertEqual(user_1.cvat_email, "test@hmt.ai") + assert user_1.cvat_id == 1 + assert user_1.cvat_email == "test@hmt.ai" user_2 = cvat_service.get_user_by_id( self.session, "0x86e83d346041E8806e352681f3F14549C0d2BC68" ) - self.assertEqual(user_2.cvat_id, 2) - self.assertEqual(user_2.cvat_email, "test2@hmt.ai") + assert user_2.cvat_id == 2 + assert user_2.cvat_email == "test2@hmt.ai" user_3 = cvat_service.get_user_by_id( self.session, "0x86e83d346041E8806e352681f3F14549C0d2BC69" ) - self.assertIsNone(user_3) + assert user_3 is None def test_create_assignment(self): (_, _, cvat_job) = create_project_task_and_job( @@ -1199,14 +1200,14 @@ def test_create_assignment(self): ) assignment_count = self.session.query(Assignment).count() - self.assertEqual(assignment_count, 1) + assert assignment_count == 1 assignment = self.session.query(Assignment).filter_by(id=assignment_id).first() - self.assertIsNotNone(assignment) - self.assertEqual(assignment.user_wallet_address, wallet_address) - self.assertEqual(assignment.cvat_job_id, cvat_job.cvat_id) - self.assertEqual(assignment.status, AssignmentStatuses.created.value) + assert assignment is not None + assert assignment.user_wallet_address == wallet_address + assert assignment.cvat_job_id == cvat_job.cvat_id + assert assignment.status == AssignmentStatuses.created.value def test_create_assignment_invalid_address(self): (_, _, cvat_job) = create_project_task_and_job( @@ -1219,7 +1220,7 @@ def test_create_assignment_invalid_address(self): cvat_job_id=cvat_job.cvat_id, expires_at=datetime.now(), ) - with self.assertRaises(IntegrityError): + with pytest.raises(IntegrityError): self.session.commit() def test_create_assignment_invalid_address(self): @@ -1237,7 +1238,7 @@ def test_create_assignment_invalid_address(self): cvat_job_id=0, expires_at=datetime.now(), ) - with self.assertRaises(IntegrityError): + with pytest.raises(IntegrityError): self.session.commit() def test_get_assignments_by_id(self): @@ -1274,17 +1275,17 @@ def test_get_assignments_by_id(self): self.session.commit() assignments = cvat_service.get_assignments_by_id(self.session, [assignment, assignment_2]) - self.assertEqual(len(assignments), 2) + assert len(assignments) == 2 assignments = cvat_service.get_assignments_by_id(self.session, [assignment]) - self.assertEqual(len(assignments), 1) - self.assertEqual(assignments[0].id, assignment) - self.assertEqual(assignments[0].user_wallet_address, wallet_address_1) + assert len(assignments) == 1 + assert assignments[0].id == assignment + assert assignments[0].user_wallet_address == wallet_address_1 assignments = cvat_service.get_assignments_by_id(self.session, [assignment_2]) - self.assertEqual(len(assignments), 1) - self.assertEqual(assignments[0].id, assignment_2) - self.assertEqual(assignments[0].user_wallet_address, wallet_address_2) + assert len(assignments) == 1 + assert assignments[0].id == assignment_2 + assert assignments[0].user_wallet_address == wallet_address_2 def test_get_latest_assignment_by_cvat_job_id(self): (_, _, cvat_job) = create_project_task_and_job( @@ -1325,9 +1326,9 @@ def test_get_latest_assignment_by_cvat_job_id(self): received_assignment = cvat_service.get_latest_assignment_by_cvat_job_id( self.session, cvat_job.cvat_id ) - self.assertEqual(received_assignment.id, assignment_2.id) - self.assertNotEqual(received_assignment.id, assignment.id) - self.assertEqual(received_assignment.user_wallet_address, wallet_address_2) + assert received_assignment.id == assignment_2.id + assert received_assignment.id != assignment.id + assert received_assignment.user_wallet_address == wallet_address_2 def test_get_unprocessed_expired_assignments(self): (_, _, cvat_job) = create_project_task_and_job( @@ -1365,10 +1366,10 @@ def test_get_unprocessed_expired_assignments(self): self.session.commit() assignments = cvat_service.get_unprocessed_expired_assignments(self.session) - self.assertEqual(len(assignments), 1) - self.assertEqual(assignments[0].id, assignment_2.id) - self.assertNotEqual(assignments[0].id, assignment.id) - self.assertEqual(assignments[0].user_wallet_address, wallet_address_2) + assert len(assignments) == 1 + assert assignments[0].id == assignment_2.id + assert assignments[0].id != assignment.id + assert assignments[0].user_wallet_address == wallet_address_2 def test_update_assignment(self): (_, _, cvat_job) = create_project_task_and_job( @@ -1397,8 +1398,8 @@ def test_update_assignment(self): db_assignment = self.session.query(Assignment).filter_by(id=assignment.id).first() - self.assertEqual(db_assignment.id, assignment.id) - self.assertEqual(db_assignment.status, AssignmentStatuses.completed) + assert db_assignment.id == assignment.id + assert db_assignment.status == AssignmentStatuses.completed def test_cancel_assignment(self): (_, _, cvat_job) = create_project_task_and_job( @@ -1425,8 +1426,8 @@ def test_cancel_assignment(self): db_assignment = self.session.query(Assignment).filter_by(id=assignment.id).first() - self.assertEqual(db_assignment.id, assignment.id) - self.assertEqual(db_assignment.status, AssignmentStatuses.canceled) + assert db_assignment.id == assignment.id + assert db_assignment.status == AssignmentStatuses.canceled def test_expire_assignment(self): (_, _, cvat_job) = create_project_task_and_job( @@ -1453,8 +1454,8 @@ def test_expire_assignment(self): db_assignment = self.session.query(Assignment).filter_by(id=assignment.id).first() - self.assertEqual(db_assignment.id, assignment.id) - self.assertEqual(db_assignment.status, AssignmentStatuses.expired) + assert db_assignment.id == assignment.id + assert db_assignment.status == AssignmentStatuses.expired def test_complete_assignment(self): (_, _, cvat_job) = create_project_task_and_job( @@ -1481,9 +1482,9 @@ def test_complete_assignment(self): db_assignment = self.session.query(Assignment).filter_by(id=assignment.id).first() - self.assertEqual(db_assignment.id, assignment.id) - self.assertEqual(db_assignment.status, AssignmentStatuses.completed) - self.assertEqual(db_assignment.completed_at, completed_date) + assert db_assignment.id == assignment.id + assert db_assignment.status == AssignmentStatuses.completed + assert db_assignment.completed_at == completed_date def test_test_add_project_images(self): (_, _, cvat_job) = create_project_task_and_job( @@ -1523,10 +1524,10 @@ def test_test_add_project_images(self): assignments = cvat_service.get_user_assignments_in_cvat_projects( self.session, wallet_address_1, [cvat_job.cvat_id] ) - self.assertEqual(len(assignments), 1) - self.assertEqual(assignments[0].id, assignment.id) - self.assertNotEqual(assignments[0].id, assignment_2.id) - self.assertEqual(assignments[0].user_wallet_address, wallet_address_1) + assert len(assignments) == 1 + assert assignments[0].id == assignment.id + assert assignments[0].id != assignment_2.id + assert assignments[0].user_wallet_address == wallet_address_1 def test_add_project_images(self): cvat_project = create_project(self.session, "0x86e83d346041E8806e352681f3F14549C0d2BC67", 1) @@ -1543,16 +1544,16 @@ def test_add_project_images(self): self.session.query(Image).where(Image.cvat_project_id == cvat_project.cvat_id).all() ) - self.assertEqual(len(images), 2) - self.assertEqual(images[0].filename, filenames[0]) - self.assertEqual(images[1].filename, filenames[1]) + assert len(images) == 2 + assert images[0].filename == filenames[0] + assert images[1].filename == filenames[1] def test_add_project_images_wrong_project_id(self): filenames = [ "image1.jpg", "image2.jpg", ] - with self.assertRaises(IntegrityError): + with pytest.raises(IntegrityError): cvat_service.add_project_images(self.session, cvat_project_id=1, filenames=filenames) def test_add_project_images(self): @@ -1575,10 +1576,10 @@ def test_add_project_images(self): images = cvat_service.get_project_images(self.session, cvat_project.cvat_id) - self.assertEqual(len(images), 2) - self.assertEqual(images[0].filename, filenames[0]) - self.assertEqual(images[1].filename, filenames[1]) + assert len(images) == 2 + assert images[0].filename == filenames[0] + assert images[1].filename == filenames[1] images = cvat_service.get_project_images(self.session, 2) - self.assertEqual(len(images), 0) + assert len(images) == 0 diff --git a/packages/examples/cvat/exchange-oracle/tests/integration/services/test_exchange.py b/packages/examples/cvat/exchange-oracle/tests/integration/services/test_exchange.py index 7741790418..6bec628590 100644 --- a/packages/examples/cvat/exchange-oracle/tests/integration/services/test_exchange.py +++ b/packages/examples/cvat/exchange-oracle/tests/integration/services/test_exchange.py @@ -4,6 +4,7 @@ from datetime import datetime, timedelta from unittest.mock import patch +import pytest from fastapi import HTTPException from pydantic import ValidationError @@ -49,19 +50,19 @@ def test_serialize_task(self): mock_get_manifest.return_value = manifest data = serialize_task(cvat_project.id) - self.assertEqual(data.id, cvat_project.id) - self.assertEqual(data.escrow_address, escrow_address) - self.assertIn("Task ", data.title) - self.assertTrue(len(data.title.split("Task ")[1]) <= 10) - self.assertIsInstance(data.description, str) - self.assertIsInstance(data.job_bounty, str) - self.assertIsInstance(data.job_time_limit, int) - self.assertIsInstance(data.job_size, int) - self.assertEqual(data.job_type, cvat_project.job_type) - self.assertEqual(data.platform, PlatformTypes.CVAT) - self.assertEqual(data.status, cvat_project.status) - self.assertIsNone(data.assignment) - self.assertIsInstance(data, service_api.TaskResponse) + assert data.id == cvat_project.id + assert data.escrow_address == escrow_address + assert "Task " in data.title + assert len(data.title.split("Task ")[1]) <= 10 + assert isinstance(data.description, str) + assert isinstance(data.job_bounty, str) + assert isinstance(data.job_time_limit, int) + assert isinstance(data.job_size, int) + assert data.job_type == cvat_project.job_type + assert data.platform == PlatformTypes.CVAT + assert data.status == cvat_project.status + assert data.assignment is None + assert isinstance(data, service_api.TaskResponse) def test_serialize_task_with_assignment(self): cvat_id = 1 @@ -96,25 +97,25 @@ def test_serialize_task_with_assignment(self): mock_get_manifest.return_value = manifest data = serialize_task(project_id=cvat_project.id, assignment_id=assignment.id) - self.assertEqual(data.id, cvat_project.id) - self.assertEqual(data.escrow_address, escrow_address) - self.assertIn("Task ", data.title) - self.assertTrue(len(data.title.split("Task ")[1]) <= 10) - self.assertIsInstance(data.description, str) - self.assertIsInstance(data.job_bounty, str) - self.assertIsInstance(data.job_time_limit, int) - self.assertIsInstance(data.job_size, int) - self.assertEqual(data.job_type, cvat_project.job_type) - self.assertEqual(data.platform, PlatformTypes.CVAT) - self.assertEqual(data.status, cvat_project.status) - self.assertIsNotNone(data.assignment) - self.assertIsInstance(data.assignment.assignment_url, str) - self.assertEqual(data.assignment.started_at, assignment.created_at) - self.assertEqual(data.assignment.finishes_at, assignment.expires_at) - self.assertIsInstance(data, service_api.TaskResponse) + assert data.id == cvat_project.id + assert data.escrow_address == escrow_address + assert "Task " in data.title + assert len(data.title.split("Task ")[1]) <= 10 + assert isinstance(data.description, str) + assert isinstance(data.job_bounty, str) + assert isinstance(data.job_time_limit, int) + assert isinstance(data.job_size, int) + assert data.job_type == cvat_project.job_type + assert data.platform == PlatformTypes.CVAT + assert data.status == cvat_project.status + assert data.assignment is not None + assert isinstance(data.assignment.assignment_url, str) + assert data.assignment.started_at == assignment.created_at + assert data.assignment.finishes_at == assignment.expires_at + assert isinstance(data, service_api.TaskResponse) def test_serialize_task_invalid_project(self): - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): serialize_task(project_id=str(uuid.uuid4())) def test_serialize_task_invalid_manifest(self): @@ -126,7 +127,7 @@ def test_serialize_task_invalid_manifest(self): with patch("src.services.exchange.get_escrow_manifest") as mock_get_manifest: mock_get_manifest.return_value = None - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): serialize_task(project_id=cvat_project.id) def test_get_available_tasks(self): @@ -146,11 +147,11 @@ def test_get_available_tasks(self): mock_get_manifest.return_value = manifest tasks = get_available_tasks() - self.assertEqual(len(tasks), 2) - self.assertIsInstance(tasks[0], service_api.TaskResponse) - self.assertIsInstance(tasks[1], service_api.TaskResponse) - self.assertTrue(any(task.id == cvat_project_1.id for task in tasks)) - self.assertTrue(any(task.id == cvat_project_2.id for task in tasks)) + assert len(tasks) == 2 + assert isinstance(tasks[0], service_api.TaskResponse) + assert isinstance(tasks[1], service_api.TaskResponse) + assert any(task.id == cvat_project_1.id for task in tasks) + assert any(task.id == cvat_project_2.id for task in tasks) cvat_service.update_project_status( self.session, cvat_project_2.id, ProjectStatuses.completed @@ -158,9 +159,9 @@ def test_get_available_tasks(self): self.session.commit() tasks = get_available_tasks() - self.assertEqual(len(tasks), 1) - self.assertIsInstance(tasks[0], service_api.TaskResponse) - self.assertEqual(tasks[0].id, cvat_project_1.id) + assert len(tasks) == 1 + assert isinstance(tasks[0], service_api.TaskResponse) + assert tasks[0].id == cvat_project_1.id def test_get_tasks_by_assignee(self): cvat_project_1, _, cvat_job_1 = create_project_task_and_job( @@ -195,14 +196,14 @@ def test_get_tasks_by_assignee(self): mock_get_manifest.return_value = manifest tasks = get_tasks_by_assignee(user_address) - self.assertEqual(len(tasks), 1) - self.assertIsInstance(tasks[0], service_api.TaskResponse) - self.assertEqual(tasks[0].id, cvat_project_1.id) - self.assertIsNotNone(tasks[0].assignment) + assert len(tasks) == 1 + assert isinstance(tasks[0], service_api.TaskResponse) + assert tasks[0].id == cvat_project_1.id + assert tasks[0].assignment is not None def test_get_tasks_by_assignee_invalid_address(self): tasks = get_tasks_by_assignee("invalid_address") - self.assertEqual(len(tasks), 0) + assert len(tasks) == 0 def test_create_assignment(self): cvat_project_1, _, cvat_job_1 = create_project_task_and_job( @@ -228,9 +229,9 @@ def test_create_assignment(self): assignment = self.session.query(Assignment).filter_by(id=assignment_id).first() - self.assertEqual(assignment.cvat_job_id, cvat_job_1.cvat_id) - self.assertEqual(assignment.user_wallet_address, user_address) - self.assertEqual(assignment.status, AssignmentStatuses.created) + assert assignment.cvat_job_id == cvat_job_1.cvat_id + assert assignment.user_wallet_address == user_address + assert assignment.status == AssignmentStatuses.created def test_create_assignment_many_jobs_1_completed(self): cvat_project, _, cvat_job_1 = create_project_task_and_job( @@ -274,9 +275,9 @@ def test_create_assignment_many_jobs_1_completed(self): assignment = self.session.query(Assignment).filter_by(id=assignment_id).first() - self.assertEqual(assignment.cvat_job_id, cvat_job_2.cvat_id) - self.assertEqual(assignment.user_wallet_address, user_address) - self.assertEqual(assignment.status, AssignmentStatuses.created) + assert assignment.cvat_job_id == cvat_job_2.cvat_id + assert assignment.user_wallet_address == user_address + assert assignment.status == AssignmentStatuses.created def test_create_assignment_invalid_user_address(self): cvat_project_1, _, _ = create_project_task_and_job( @@ -284,7 +285,7 @@ def test_create_assignment_invalid_user_address(self): ) self.session.commit() - with self.assertRaises(HTTPException): + with pytest.raises(HTTPException): create_assignment(cvat_project_1.id, "invalid_address") def test_create_assignment_invalid_project(self): @@ -297,7 +298,7 @@ def test_create_assignment_invalid_project(self): self.session.add(user) self.session.commit() - with self.assertRaises(HTTPException): + with pytest.raises(HTTPException): create_assignment("1", user_address) def test_create_assignment_unfinished_assignment(self): @@ -329,7 +330,7 @@ def test_create_assignment_unfinished_assignment(self): manifest = json.load(data) mock_get_manifest.return_value = manifest - with self.assertRaises(HTTPException): + with pytest.raises(HTTPException): create_assignment("1", user_address) def test_create_assignment_no_available_jobs_completed_assignment(self): @@ -377,7 +378,7 @@ def test_create_assignment_no_available_jobs_completed_assignment(self): mock_get_manifest.return_value = manifest assignment_id = create_assignment(cvat_project.id, user_address2) - self.assertEqual(assignment_id, None) + assert assignment_id == None def test_create_assignment_no_available_jobs_active_foreign_assignment(self): cvat_project, _, cvat_job_1 = create_project_task_and_job( @@ -419,7 +420,7 @@ def test_create_assignment_no_available_jobs_active_foreign_assignment(self): mock_get_manifest.return_value = manifest assignment_id = create_assignment(cvat_project.id, user_address2) - self.assertEqual(assignment_id, None) + assert assignment_id == None def test_create_assignment_in_validated_and_rejected_job(self): cvat_project_1, _, cvat_job_1 = create_project_task_and_job( @@ -460,6 +461,6 @@ def test_create_assignment_in_validated_and_rejected_job(self): assignment = self.session.query(Assignment).filter_by(id=assignment_id).first() - self.assertEqual(assignment.cvat_job_id, cvat_job_1.cvat_id) - self.assertEqual(assignment.user_wallet_address, user_address) - self.assertEqual(assignment.status, AssignmentStatuses.created) + assert assignment.cvat_job_id == cvat_job_1.cvat_id + assert assignment.user_wallet_address == user_address + assert assignment.status == AssignmentStatuses.created diff --git a/packages/examples/cvat/exchange-oracle/tests/integration/services/test_webhook.py b/packages/examples/cvat/exchange-oracle/tests/integration/services/test_webhook.py index baa0a7cd6b..caaf24a612 100644 --- a/packages/examples/cvat/exchange-oracle/tests/integration/services/test_webhook.py +++ b/packages/examples/cvat/exchange-oracle/tests/integration/services/test_webhook.py @@ -1,6 +1,7 @@ import unittest import uuid +import pytest from sqlalchemy.exc import IntegrityError import src.services.webhook as webhook_service @@ -40,14 +41,14 @@ def test_create_incoming_webhook(self): webhook = self.session.query(Webhook).filter_by(id=webhook_id).first() - self.assertEqual(webhook.escrow_address, escrow_address) - self.assertEqual(webhook.chain_id, chain_id) - self.assertEqual(webhook.signature, signature) - self.assertEqual(webhook.attempts, 0) - self.assertEqual(webhook.type, OracleWebhookTypes.job_launcher.value) - self.assertEqual(webhook.event_type, JobLauncherEventTypes.escrow_created.value) - self.assertEqual(webhook.event_data, None) - self.assertEqual(webhook.status, OracleWebhookStatuses.pending.value) + assert webhook.escrow_address == escrow_address + assert webhook.chain_id == chain_id + assert webhook.signature == signature + assert webhook.attempts == 0 + assert webhook.type == OracleWebhookTypes.job_launcher.value + assert webhook.event_type == JobLauncherEventTypes.escrow_created.value + assert webhook.event_data == None + assert webhook.status == OracleWebhookStatuses.pending.value def test_create_incoming_webhook_none_escrow_address(self): chain_id = Networks.localhost.value @@ -60,7 +61,7 @@ def test_create_incoming_webhook_none_escrow_address(self): type=OracleWebhookTypes.job_launcher, event_type=JobLauncherEventTypes.escrow_created.value, ) - with self.assertRaises(IntegrityError): + with pytest.raises(IntegrityError): self.session.commit() def test_create_incoming_webhook_none_chain_id(self): @@ -74,13 +75,19 @@ def test_create_incoming_webhook_none_chain_id(self): type=OracleWebhookTypes.job_launcher, event_type=JobLauncherEventTypes.escrow_created.value, ) - with self.assertRaises(IntegrityError): + with pytest.raises(IntegrityError): self.session.commit() def test_create_incoming_webhook_none_event_type(self): escrow_address = "0x1234567890123456789012345678901234567890" signature = "signature" - with self.assertRaises(AssertionError) as error: + with pytest.raises( + AssertionError, + match=( + "'event' and 'event_type' cannot be used together. " + "Please use only one of the fields" + ), + ): webhook_service.inbox.create_webhook( self.session, escrow_address=escrow_address, @@ -88,16 +95,14 @@ def test_create_incoming_webhook_none_event_type(self): signature=signature, type=OracleWebhookTypes.job_launcher, ) - self.assertEqual( - str(error.exception), - "'event' and 'event_type' cannot be used together. Please use only one of the fields", - ) def test_create_incoming_webhook_none_signature(self): escrow_address = "0x1234567890123456789012345678901234567890" chain_id = Networks.localhost.value - with self.assertRaises(ValueError) as error: + with pytest.raises( + ValueError, match="Webhook signature must be specified for incoming events" + ): webhook_service.inbox.create_webhook( self.session, escrow_address=escrow_address, @@ -105,9 +110,6 @@ def test_create_incoming_webhook_none_signature(self): type=OracleWebhookTypes.job_launcher, event_type=JobLauncherEventTypes.escrow_created.value, ) - self.assertEqual( - str(error.exception), "Webhook signature must be specified for incoming events" - ) def test_create_outgoing_webhook(self): escrow_address = "0x1234567890123456789012345678901234567890" @@ -123,13 +125,13 @@ def test_create_outgoing_webhook(self): webhook = self.session.query(Webhook).filter_by(id=webhook_id).first() - self.assertEqual(webhook.escrow_address, escrow_address) - self.assertEqual(webhook.chain_id, chain_id) - self.assertEqual(webhook.attempts, 0) - self.assertEqual(webhook.type, OracleWebhookTypes.exchange_oracle.value) - self.assertEqual(webhook.event_type, ExchangeOracleEventTypes.task_finished.value) - self.assertEqual(webhook.event_data, {}) - self.assertEqual(webhook.status, OracleWebhookStatuses.pending.value) + assert webhook.escrow_address == escrow_address + assert webhook.chain_id == chain_id + assert webhook.attempts == 0 + assert webhook.type == OracleWebhookTypes.exchange_oracle.value + assert webhook.event_type == ExchangeOracleEventTypes.task_finished.value + assert webhook.event_data == {} + assert webhook.status == OracleWebhookStatuses.pending.value def test_create_outgoing_webhook_none_escrow_address(self): chain_id = Networks.localhost.value @@ -140,7 +142,7 @@ def test_create_outgoing_webhook_none_escrow_address(self): type=OracleWebhookTypes.exchange_oracle, event=ExchangeOracleEvent_TaskFinished(), ) - with self.assertRaises(IntegrityError): + with pytest.raises(IntegrityError): self.session.commit() def test_create_outgoing_webhook_none_chain_id(self): @@ -152,29 +154,33 @@ def test_create_outgoing_webhook_none_chain_id(self): type=OracleWebhookTypes.exchange_oracle, event=ExchangeOracleEvent_TaskFinished(), ) - with self.assertRaises(IntegrityError) as error: + with pytest.raises(IntegrityError): self.session.commit() def test_create_outgoing_webhook_none_event_type(self): escrow_address = "0x1234567890123456789012345678901234567890" - with self.assertRaises(AssertionError) as error: + with pytest.raises( + AssertionError, + match=( + "'event' and 'event_type' cannot be used together. " + "Please use only one of the fields" + ), + ): webhook_service.outbox.create_webhook( self.session, escrow_address=escrow_address, chain_id=None, type=OracleWebhookTypes.exchange_oracle, ) - self.assertEqual( - str(error.exception), - "'event' and 'event_type' cannot be used together. Please use only one of the fields", - ) def test_create_outgoing_webhook_with_signature(self): escrow_address = "0x1234567890123456789012345678901234567890" chain_id = Networks.localhost.value signature = "signature" - with self.assertRaises(ValueError) as error: + with pytest.raises( + ValueError, match="Webhook signature must not be specified for outgoing events" + ): webhook_service.outbox.create_webhook( self.session, escrow_address=escrow_address, @@ -183,9 +189,6 @@ def test_create_outgoing_webhook_with_signature(self): event=ExchangeOracleEvent_TaskFinished(), signature=signature, ) - self.assertEqual( - str(error.exception), "Webhook signature must not be specified for outgoing events" - ) def test_get_pending_webhooks(self): chain_id = Networks.localhost.value @@ -256,21 +259,21 @@ def test_get_pending_webhooks(self): pending_webhooks = webhook_service.inbox.get_pending_webhooks( self.session, type=OracleWebhookTypes.job_launcher, limit=10 ) - self.assertEqual(len(pending_webhooks), 2) - self.assertEqual(pending_webhooks[0].id, webhook1_id) - self.assertEqual(pending_webhooks[1].id, webhook2_id) + assert len(pending_webhooks) == 2 + assert pending_webhooks[0].id == webhook1_id + assert pending_webhooks[1].id == webhook2_id pending_webhooks = webhook_service.inbox.get_pending_webhooks( self.session, type=OracleWebhookTypes.recording_oracle, limit=10 ) - self.assertEqual(len(pending_webhooks), 1) - self.assertEqual(pending_webhooks[0].id, webhook4_id) + assert len(pending_webhooks) == 1 + assert pending_webhooks[0].id == webhook4_id pending_webhooks = webhook_service.outbox.get_pending_webhooks( self.session, type=OracleWebhookTypes.job_launcher, limit=10 ) - self.assertEqual(len(pending_webhooks), 1) - self.assertEqual(pending_webhooks[0].id, webhook5_id) + assert len(pending_webhooks) == 1 + assert pending_webhooks[0].id == webhook5_id def test_update_webhook_status(self): escrow_address = "0x1234567890123456789012345678901234567890" @@ -292,12 +295,12 @@ def test_update_webhook_status(self): webhook = self.session.query(Webhook).filter_by(id=webhook_id).first() - self.assertEqual(webhook.escrow_address, escrow_address) - self.assertEqual(webhook.chain_id, chain_id) - self.assertEqual(webhook.attempts, 0) - self.assertEqual(webhook.signature, signature) - self.assertEqual(webhook.type, OracleWebhookTypes.job_launcher.value) - self.assertEqual(webhook.status, OracleWebhookStatuses.completed.value) + assert webhook.escrow_address == escrow_address + assert webhook.chain_id == chain_id + assert webhook.attempts == 0 + assert webhook.signature == signature + assert webhook.type == OracleWebhookTypes.job_launcher.value + assert webhook.status == OracleWebhookStatuses.completed.value def test_handle_webhook_success(self): escrow_address = "0x1234567890123456789012345678901234567890" @@ -317,12 +320,12 @@ def test_handle_webhook_success(self): webhook = self.session.query(Webhook).filter_by(id=webhook_id).first() - self.assertEqual(webhook.escrow_address, escrow_address) - self.assertEqual(webhook.chain_id, chain_id) - self.assertEqual(webhook.attempts, 1) - self.assertEqual(webhook.signature, signature) - self.assertEqual(webhook.type, OracleWebhookTypes.job_launcher.value) - self.assertEqual(webhook.status, OracleWebhookStatuses.completed.value) + assert webhook.escrow_address == escrow_address + assert webhook.chain_id == chain_id + assert webhook.attempts == 1 + assert webhook.signature == signature + assert webhook.type == OracleWebhookTypes.job_launcher.value + assert webhook.status == OracleWebhookStatuses.completed.value def test_handle_webhook_fail(self): escrow_address = "0x1234567890123456789012345678901234567890" @@ -342,21 +345,21 @@ def test_handle_webhook_fail(self): webhook = self.session.query(Webhook).filter_by(id=webhook_id).first() - self.assertEqual(webhook.escrow_address, escrow_address) - self.assertEqual(webhook.chain_id, chain_id) - self.assertEqual(webhook.attempts, 1) - self.assertEqual(webhook.signature, signature) - self.assertEqual(webhook.type, OracleWebhookTypes.job_launcher.value) - self.assertEqual(webhook.status, OracleWebhookStatuses.pending.value) + assert webhook.escrow_address == escrow_address + assert webhook.chain_id == chain_id + assert webhook.attempts == 1 + assert webhook.signature == signature + assert webhook.type == OracleWebhookTypes.job_launcher.value + assert webhook.status == OracleWebhookStatuses.pending.value - for i in range(4): + for _i in range(4): webhook_service.inbox.handle_webhook_fail(self.session, webhook_id) webhook = self.session.query(Webhook).filter_by(id=webhook_id).first() - self.assertEqual(webhook.escrow_address, escrow_address) - self.assertEqual(webhook.chain_id, chain_id) - self.assertEqual(webhook.attempts, 5) - self.assertEqual(webhook.signature, signature) - self.assertEqual(webhook.type, OracleWebhookTypes.job_launcher.value) - self.assertEqual(webhook.status, OracleWebhookStatuses.failed.value) + assert webhook.escrow_address == escrow_address + assert webhook.chain_id == chain_id + assert webhook.attempts == 5 + assert webhook.signature == signature + assert webhook.type == OracleWebhookTypes.job_launcher.value + assert webhook.status == OracleWebhookStatuses.failed.value diff --git a/packages/examples/cvat/exchange-oracle/tests/unit/__init__.py b/packages/examples/cvat/exchange-oracle/tests/unit/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/packages/examples/cvat/exchange-oracle/tests/unit/helpers/__init__.py b/packages/examples/cvat/exchange-oracle/tests/unit/helpers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/packages/examples/cvat/exchange-oracle/tests/utils/constants.py b/packages/examples/cvat/exchange-oracle/tests/utils/constants.py index 0e7e580002..2a5082c255 100644 --- a/packages/examples/cvat/exchange-oracle/tests/utils/constants.py +++ b/packages/examples/cvat/exchange-oracle/tests/utils/constants.py @@ -23,7 +23,10 @@ DEFAULT_MANIFEST_URL = "http://host.docker.internal:9000/manifests/manifest.json" DEFAULT_HASH = "test" -SIGNATURE = "0xa0c5626301e3c198cb91356e492890c0c28db8c37044846134939246911a693c4d7116d04aa4bc40a41077493868b8dd533d30980f6addb28d1b3610a84cb4091c" +SIGNATURE = ( + "0xa0c5626301e3c198cb91356e492890c0c28db8c37044846134939246911a693c" + "4d7116d04aa4bc40a41077493868b8dd533d30980f6addb28d1b3610a84cb4091c" +) WEBHOOK_MESSAGE = { "escrow_address": "0x12E66A452f95bff49eD5a30b0d06Ebc37C5A94B6", @@ -32,7 +35,10 @@ "event_data": {}, } -WEBHOOK_MESSAGE_SIGNED = "0x82d5c5845da8456226baf58862c1cefd964c884464f73b66abed938475bbd7e810bf99a10f2ad68e7febb7460112788c060e16e25d0e7c4e2e2dc7aafd9b81861c" +WEBHOOK_MESSAGE_SIGNED = ( + "0x82d5c5845da8456226baf58862c1cefd964c884464f73b66abed938475bbd7e8" + "10bf99a10f2ad68e7febb7460112788c060e16e25d0e7c4e2e2dc7aafd9b81861c" +) PGP_PASSPHRASE = "passphrase" PGP_PRIVATE_KEY1 = dedent( diff --git a/packages/examples/cvat/exchange-oracle/tests/utils/setup_cvat.py b/packages/examples/cvat/exchange-oracle/tests/utils/setup_cvat.py index 25d7baaf8c..3a87d3a57f 100644 --- a/packages/examples/cvat/exchange-oracle/tests/utils/setup_cvat.py +++ b/packages/examples/cvat/exchange-oracle/tests/utils/setup_cvat.py @@ -14,7 +14,7 @@ def generate_cvat_signature(data: dict): b_data = json.dumps(data).encode("utf-8") - signature = ( + return ( "sha256=" + hmac.new( CvatConfig.cvat_webhook_secret.encode("utf-8"), @@ -23,8 +23,6 @@ def generate_cvat_signature(data: dict): ).hexdigest() ) - return signature - def add_cvat_project_to_db(cvat_id: int) -> str: with SessionLocal.begin() as session: