diff --git a/.github/workflows/mypy.yml b/.github/workflows/mypy.yml new file mode 100644 index 00000000..a24bfec1 --- /dev/null +++ b/.github/workflows/mypy.yml @@ -0,0 +1,17 @@ +name: Check mypy + +on: + - pull_request +jobs: + check: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.10" + - name: Install deps + run: pip install -r requirements.txt -r requirements-dev.txt + - name: Run mypy + run: mypy diff --git a/docs/installation/docker.rst b/docs/installation/docker.rst index b2e76bbc..5b1112c3 100644 --- a/docs/installation/docker.rst +++ b/docs/installation/docker.rst @@ -120,11 +120,12 @@ TAXII 2 instance with Compose Checkout the configuration at: :github-file:`examples/docker-compose-taxii2.yml `. -To add dummy data, you can execute: - .. code-block:: shell - # while the compose project is running + # Start + docker compose -f examples/docker-compose-taxii2.yml up + + # To add dummy data, run this while the compose project is running docker exec -i examples-opentaxii-1 bash < examples/taxii2/data-setup.sh Full Example with Compose diff --git a/opentaxii/auth/__init__.py b/opentaxii/auth/__init__.py index fa78ca89..7586727f 100644 --- a/opentaxii/auth/__init__.py +++ b/opentaxii/auth/__init__.py @@ -1,3 +1,8 @@ # flake8: noqa from .api import OpenTAXIIAuthAPI from .manager import AuthManager + +__all__ = ( + "AuthManager", + "OpenTAXIIAuthAPI", +) diff --git a/opentaxii/auth/sqldb/api.py b/opentaxii/auth/sqldb/api.py index 78da33fd..28fe3207 100644 --- a/opentaxii/auth/sqldb/api.py +++ b/opentaxii/auth/sqldb/api.py @@ -1,4 +1,5 @@ from datetime import datetime, timedelta +from typing import Optional import jwt import structlog @@ -10,40 +11,41 @@ from .models import Account, Base -__all__ = ['SQLDatabaseAPI'] +__all__ = ["SQLDatabaseAPI"] log = structlog.getLogger(__name__) class SQLDatabaseAPI(BaseSQLDatabaseAPI, OpenTAXIIAuthAPI): - """Naive SQL database implementation of OpenTAXII Auth API. - - Implementation will work with any DB supported by SQLAlchemy package. - - :param str db_connection: a string that indicates database dialect and - connection arguments that will be passed directly - to :func:`~sqlalchemy.engine.create_engine` method. - :param bool create_tables=False: if True, tables will be created in the DB. - :param str secret: secret string used for token generation - :param int token_ttl_secs: TTL for JWT token, in seconds. - :param engine_parameters=None: if defined, these arguments would be passed to sqlalchemy.create_engine - """ BASEMODEL = Base def __init__( self, - db_connection, - create_tables=False, - secret=None, - token_ttl_secs=None, + db_connection: str, + create_tables: bool = False, + secret: Optional[str] = None, + token_ttl_secs: Optional[int] = None, **engine_parameters, ): + """Naive SQL database implementation of OpenTAXII Auth API. + + Implementation will work with any DB supported by SQLAlchemy package. + + :param db_connection: a string that indicates database dialect and + connection arguments that will be passed directly to + :func:`~sqlalchemy.engine.create_engine` method. + :param create_tables=False: if True, tables will be created in the DB. + :param secret: secret string used for token generation + :param token_ttl_secs: TTL for JWT token, in seconds. + :param engine_parameters=None: if defined, these arguments would be passed + to sqlalchemy.create_engine + """ super().__init__(db_connection, create_tables, **engine_parameters) if not secret: raise ValueError( - 'Secret is not defined for %s.%s' + "Secret is not defined for %s.%s" % (self.__module__, self.__class__.__name__) ) self.secret = secret @@ -59,7 +61,9 @@ def authenticate(self, username, password): return self._generate_token(account.id, ttl=self.token_ttl_secs) def create_account(self, username, password, is_admin=False): - account = Account(username=username, is_admin=is_admin, permissions={}) + account = Account( # type: ignore[misc] + username=username, is_admin=is_admin, permissions={} + ) account.set_password(password) self.db.session.add(account) self.db.session.commit() diff --git a/opentaxii/auth/sqldb/models.py b/opentaxii/auth/sqldb/models.py index a296ee92..db39a30a 100644 --- a/opentaxii/auth/sqldb/models.py +++ b/opentaxii/auth/sqldb/models.py @@ -28,6 +28,7 @@ def set_password(self, password): self.password_hash = generate_password_hash(password) def is_password_valid(self, password): + assert self.password_hash is not None return check_password_hash(self.password_hash, password) @property diff --git a/opentaxii/common/sqldb.py b/opentaxii/common/sqldb.py index 2b66b33e..a77cb872 100644 --- a/opentaxii/common/sqldb.py +++ b/opentaxii/common/sqldb.py @@ -3,7 +3,7 @@ from opentaxii.sqldb_helper import SQLAlchemyDB try: - from sqlalchemy.orm import DeclarativeMeta + from sqlalchemy.orm import DeclarativeMeta # type: ignore[attr-defined] except ImportError: from sqlalchemy.ext.declarative import DeclarativeMeta diff --git a/opentaxii/config.py b/opentaxii/config.py index 5e257cf1..3e5357bd 100644 --- a/opentaxii/config.py +++ b/opentaxii/config.py @@ -112,7 +112,7 @@ def _get_env_config(env=os.environ, optional_env_var=None): @classmethod def _load_configs(cls, *configs): - result = dict() + result: dict = dict() for config in configs: # read content from path-like object if not isinstance(config, dict): diff --git a/opentaxii/middleware.py b/opentaxii/middleware.py index bd5636c5..16c6067e 100644 --- a/opentaxii/middleware.py +++ b/opentaxii/middleware.py @@ -25,7 +25,7 @@ def create_app(server): """ app = Flask(__name__) - app.taxii_server = server + app.taxii_server = server # type: ignore[attr-defined] server.init_app(app) diff --git a/opentaxii/persistence/__init__.py b/opentaxii/persistence/__init__.py index 1bcf8433..ad7157a0 100644 --- a/opentaxii/persistence/__init__.py +++ b/opentaxii/persistence/__init__.py @@ -1,7 +1,10 @@ # flake8: noqa from .api import OpenTAXII2PersistenceAPI, OpenTAXIIPersistenceAPI -from .manager import ( - BasePersistenceManager, - Taxii1PersistenceManager, - Taxii2PersistenceManager, +from .manager import Taxii1PersistenceManager, Taxii2PersistenceManager + +__all__ = ( + "OpenTAXII2PersistenceAPI", + "OpenTAXIIPersistenceAPI", + "Taxii1PersistenceManager", + "Taxii2PersistenceManager", ) diff --git a/opentaxii/persistence/api.py b/opentaxii/persistence/api.py index 0fc37af3..c9d869c5 100644 --- a/opentaxii/persistence/api.py +++ b/opentaxii/persistence/api.py @@ -1,4 +1,5 @@ import datetime +import uuid from typing import Dict, List, Optional, Tuple from opentaxii.taxii2.entities import ( @@ -34,6 +35,18 @@ def create_service(self, service_entity): """ raise NotImplementedError() + def update_service(self, obj): + """Update service. To implement in subclass""" + raise NotImplementedError() + + def delete_service(self, service_id): + """Delete service. To implement in subclass""" + raise NotImplementedError() + + def set_collection_services(self, collection_id, service_ids): + """Set collection's services. To implement in subclass""" + raise NotImplementedError() + def create_collection(self, collection_entity): """Create a collection. @@ -273,8 +286,11 @@ class OpenTAXII2PersistenceAPI: Stub, pending implementation. """ + def init_app(self, app): + pass + @staticmethod - def get_next_param(self, kwargs: Dict) -> str: + def get_next_param(kwargs: Dict) -> str: """ Get value for `next` based on :class:`Dict` instance. @@ -298,23 +314,25 @@ def get_api_roots(self) -> List[ApiRoot]: """ raise NotImplementedError - def get_api_root(self, api_root_id: str) -> Optional[ApiRoot]: + def get_api_root(self, api_root_id: uuid.UUID) -> Optional[ApiRoot]: raise NotImplementedError - def get_job_and_details(self, api_root_id: str, job_id: str) -> Optional[Job]: + def get_job_and_details( + self, api_root_id: uuid.UUID, job_id: uuid.UUID + ) -> Optional[Job]: raise NotImplementedError - def get_collections(self, api_root_id: str) -> List[Collection]: + def get_collections(self, api_root_id: uuid.UUID) -> List[Collection]: raise NotImplementedError def get_collection( - self, api_root_id: str, collection_id_or_alias: str + self, api_root_id: uuid.UUID, collection_id_or_alias: str ) -> Optional[Collection]: raise NotImplementedError def get_manifest( self, - collection_id: str, + collection_id: uuid.UUID, limit: Optional[int] = None, added_after: Optional[datetime.datetime] = None, next_kwargs: Optional[Dict] = None, @@ -327,7 +345,7 @@ def get_manifest( def get_objects( self, - collection_id: str, + collection_id: uuid.UUID, limit: Optional[int] = None, added_after: Optional[datetime.datetime] = None, next_kwargs: Optional[Dict] = None, @@ -339,13 +357,13 @@ def get_objects( raise NotImplementedError def add_objects( - self, api_root_id: str, collection_id: str, objects: List[Dict] + self, api_root_id: uuid.UUID, collection_id: uuid.UUID, objects: List[Dict] ) -> Job: raise NotImplementedError def get_object( self, - collection_id: str, + collection_id: uuid.UUID, object_id: str, limit: Optional[int] = None, added_after: Optional[datetime.datetime] = None, @@ -362,7 +380,7 @@ def get_object( def delete_object( self, - collection_id: str, + collection_id: uuid.UUID, object_id: str, match_version: Optional[List[str]] = None, match_spec_version: Optional[List[str]] = None, @@ -371,13 +389,13 @@ def delete_object( def get_versions( self, - collection_id: str, + collection_id: uuid.UUID, object_id: str, limit: Optional[int] = None, added_after: Optional[datetime.datetime] = None, next_kwargs: Optional[Dict] = None, match_spec_version: Optional[List[str]] = None, - ) -> Tuple[List[VersionRecord], bool]: + ) -> Tuple[Optional[List[VersionRecord]], bool]: """ Get all versions of single object from database. diff --git a/opentaxii/persistence/manager.py b/opentaxii/persistence/manager.py index 49dfe36a..67428a85 100644 --- a/opentaxii/persistence/manager.py +++ b/opentaxii/persistence/manager.py @@ -1,5 +1,6 @@ import datetime -from typing import Dict, List, Optional, Tuple +import uuid +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple import structlog @@ -26,12 +27,15 @@ log = structlog.getLogger(__name__) +if TYPE_CHECKING: + from opentaxii.persistence.api import ( + OpenTAXII2PersistenceAPI, + OpenTAXIIPersistenceAPI, + ) + from opentaxii.server import TAXII1Server, TAXII2Server -class BasePersistenceManager: - pass - -class Taxii1PersistenceManager(BasePersistenceManager): +class Taxii1PersistenceManager: """Manager responsible for persisting and retrieving data. Manager uses API instance ``api`` for basic data CRUD operations and @@ -41,7 +45,7 @@ class Taxii1PersistenceManager(BasePersistenceManager): instance of persistence API class """ - def __init__(self, server, api): + def __init__(self, server: "TAXII1Server", api: "OpenTAXIIPersistenceAPI"): self.server = server self.api = api @@ -395,7 +399,7 @@ def delete_content_blocks( return count -class Taxii2PersistenceManager(BasePersistenceManager): +class Taxii2PersistenceManager: """Manager responsible for persisting and retrieving data. Manager uses API instance ``api`` for basic data CRUD operations and @@ -405,7 +409,7 @@ class Taxii2PersistenceManager(BasePersistenceManager): instance of persistence API class """ - def __init__(self, server, api): + def __init__(self, server: "TAXII2Server", api: "OpenTAXII2PersistenceAPI"): self.server = server self.api = api @@ -425,23 +429,23 @@ def get_api_roots(self) -> Tuple[Optional[ApiRoot], List[ApiRoot]]: break return (default_api_root, api_roots) - def get_api_root(self, api_root_id: str) -> ApiRoot: + def get_api_root(self, api_root_id: uuid.UUID) -> ApiRoot: api_root = self.api.get_api_root(api_root_id=api_root_id) if api_root is None: raise DoesNotExistError() return api_root - def get_job_and_details(self, api_root_id: str, job_id: str) -> Job: + def get_job_and_details(self, api_root_id: uuid.UUID, job_id: uuid.UUID) -> Job: job = self.api.get_job_and_details(api_root_id=api_root_id, job_id=job_id) if job is None: raise DoesNotExistError() return job - def get_collections(self, api_root_id: str) -> List[Collection]: + def get_collections(self, api_root_id: uuid.UUID) -> List[Collection]: return self.api.get_collections(api_root_id=api_root_id) def get_collection( - self, api_root_id: str, collection_id_or_alias: str + self, api_root_id: uuid.UUID, collection_id_or_alias: str ) -> Collection: collection = self.api.get_collection( api_root_id=api_root_id, collection_id_or_alias=collection_id_or_alias @@ -452,7 +456,7 @@ def get_collection( def get_manifest( self, - api_root_id: str, + api_root_id: uuid.UUID, collection_id_or_alias: str, limit: Optional[int] = None, added_after: Optional[datetime.datetime] = None, @@ -480,7 +484,7 @@ def get_manifest( def get_objects( self, - api_root_id: str, + api_root_id: uuid.UUID, collection_id_or_alias: str, limit: Optional[int] = None, added_after: Optional[datetime.datetime] = None, @@ -508,7 +512,7 @@ def get_objects( def add_objects( self, - api_root_id: str, + api_root_id: uuid.UUID, collection_id_or_alias: str, data: Dict, ) -> Job: @@ -526,7 +530,7 @@ def add_objects( def get_object( self, - api_root_id: str, + api_root_id: uuid.UUID, collection_id_or_alias: str, object_id: str, limit: Optional[int] = None, @@ -555,7 +559,7 @@ def get_object( def delete_object( self, - api_root_id: str, + api_root_id: uuid.UUID, collection_id_or_alias: str, object_id: str, match_version: Optional[List[str]] = None, @@ -583,7 +587,7 @@ def delete_object( def get_versions( self, - api_root_id: str, + api_root_id: uuid.UUID, collection_id_or_alias: str, object_id: str, limit: Optional[int] = None, diff --git a/opentaxii/persistence/sqldb/api.py b/opentaxii/persistence/sqldb/api.py index d41f6515..afefac33 100644 --- a/opentaxii/persistence/sqldb/api.py +++ b/opentaxii/persistence/sqldb/api.py @@ -3,7 +3,7 @@ import json import uuid from functools import reduce -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, no_type_check import six import structlog @@ -48,7 +48,8 @@ class SQLDatabaseAPI(BaseSQLDatabaseAPI, OpenTAXIIPersistenceAPI): :param bool create_tables=False: if True, tables will be created in the DB. - :param engine_parameters=None: if defined, these arguments would be passed to sqlalchemy.create_engine + :param engine_parameters=None: if defined, these arguments would be passed + to sqlalchemy.create_engine """ BASEMODEL = Base @@ -70,7 +71,9 @@ def update_service(self, obj): service.type = obj.type service.properties = obj.properties else: - service = Service(id=obj.id, type=obj.type, properties=obj.properties) + service = Service( # type: ignore[misc] + id=obj.id, type=obj.type, properties=obj.properties + ) self.db.session.add(service) self.db.session.commit() return conv.to_service_entity(service) @@ -441,7 +444,7 @@ def delete_content_blocks( content_blocks_query = ( self.db.session.query(ContentBlock.id) - .join(DataCollection.content_blocks) + .join(DataCollection.content_blocks) # type: ignore[attr-defined] .filter(DataCollection.id == collection.id) .filter(ContentBlock.timestamp_label > start_time) ) @@ -532,7 +535,7 @@ def get_api_roots(self) -> List[entities.ApiRoot]: for obj in query.all() ] - def get_api_root(self, api_root_id: str) -> Optional[entities.ApiRoot]: + def get_api_root(self, api_root_id: uuid.UUID) -> Optional[entities.ApiRoot]: api_root = ( self.db.session.query(taxii2models.ApiRoot) .filter(taxii2models.ApiRoot.id == api_root_id) @@ -553,7 +556,7 @@ def add_api_root( self, title: str, description: Optional[str] = None, - default: Optional[bool] = False, + default: bool = False, is_public: bool = False, api_root_id: Optional[str] = None, ) -> entities.ApiRoot: @@ -588,6 +591,7 @@ def add_api_root( is_public=is_public, ) + @no_type_check # taxii2models.Job has too many None allowance def _job_and_details_to_entity( self, job: taxii2models.Job, job_details: List[taxii2models.JobDetail] ) -> entities.Job: @@ -616,7 +620,7 @@ def _job_and_details_to_entity( return job_entity def get_job_and_details( - self, api_root_id: str, job_id: str + self, api_root_id: uuid.UUID, job_id: uuid.UUID ) -> Optional[entities.Job]: job = ( self.db.session.query(taxii2models.Job) @@ -647,7 +651,8 @@ def job_cleanup(self) -> int: """ return taxii2models.Job.cleanup(self.db.session) - def get_collections(self, api_root_id: str) -> List[entities.Collection]: + def get_collections(self, api_root_id: uuid.UUID) -> List[entities.Collection]: + """Get a list of collections from the database""" query = ( self.db.session.query(taxii2models.Collection) .filter(taxii2models.Collection.api_root_id == api_root_id) @@ -667,8 +672,9 @@ def get_collections(self, api_root_id: str) -> List[entities.Collection]: ] def get_collection( - self, api_root_id: str, collection_id_or_alias: str + self, api_root_id: uuid.UUID, collection_id_or_alias: str ) -> Optional[entities.Collection]: + """Get a collection from the database""" id_or_alias_filter = taxii2models.Collection.alias == collection_id_or_alias try: uuid.UUID(collection_id_or_alias) @@ -708,12 +714,14 @@ def add_collection( """ Add a new collection. - :param str api_root_id: ID of the api root the new collection is part of - :param str title: Title of the new collection - :param str description: [Optional] Description of the new collection - :param str alias: [Optional] Alias of the new collection - :param bool is_public: [Optional] Whether collection should be publicly readable - :param bool is_public_write: [Optional] Whether collection should be publicly writable + :param api_root_id: ID of the api root the new collection is part of + :param title: Title of the new collection + :param description: [Optional] Description of the new collection + :param alias: [Optional] Alias of the new collection + :param is_public: [Optional] Whether collection should be publicly + readable + :param is_public_write: [Optional] Whether collection should be + publicly writable :return: The added Collection entity. """ @@ -730,15 +738,15 @@ def add_collection( return entities.Collection( id=collection.id, - api_root_id=collection.api_root_id, + api_root_id=collection.api_root_id, # type: ignore[arg-type] title=collection.title, - description=collection.description, + description=collection.description, # type: ignore[arg-type] alias=collection.alias, is_public=collection.is_public, is_public_write=collection.is_public_write, ) - def _objects_query(self, collection_id: str, ordered: bool) -> Query: + def _objects_query(self, collection_id: uuid.UUID, ordered: bool) -> Query: query = self.db.session.query(taxii2models.STIXObject).filter( taxii2models.STIXObject.collection_id == collection_id, ) @@ -785,7 +793,7 @@ def _apply_match_type( def _apply_match_version( self, query: Query, - collection_id: str, + collection_id: uuid.UUID, match_version: Optional[List[str]] = None, ) -> Query: if match_version is None: @@ -874,7 +882,7 @@ def _apply_limit( def _filtered_objects_query( self, - collection_id: str, + collection_id: uuid.UUID, limit: Optional[int] = None, added_after: Optional[datetime.datetime] = None, next_kwargs: Optional[Dict] = None, @@ -882,7 +890,7 @@ def _filtered_objects_query( match_type: Optional[List[str]] = None, match_version: Optional[List[str]] = None, match_spec_version: Optional[List[str]] = None, - ordered: Optional[bool] = True, + ordered: bool = True, ) -> Tuple[Query, bool]: query = self._objects_query(collection_id, ordered) query = self._apply_added_after(query, added_after) @@ -896,7 +904,7 @@ def _filtered_objects_query( def get_manifest( self, - collection_id: str, + collection_id: uuid.UUID, limit: Optional[int] = None, added_after: Optional[datetime.datetime] = None, next_kwargs: Optional[Dict] = None, @@ -938,7 +946,7 @@ def get_manifest( def get_objects( self, - collection_id: str, + collection_id: uuid.UUID, limit: Optional[int] = None, added_after: Optional[datetime.datetime] = None, next_kwargs: Optional[Dict] = None, @@ -982,7 +990,7 @@ def get_objects( ) def add_objects( - self, api_root_id: str, collection_id: str, objects: List[Dict] + self, api_root_id: uuid.UUID, collection_id: uuid.UUID, objects: List[Dict] ) -> entities.Job: job = taxii2models.Job( api_root_id=api_root_id, @@ -1037,8 +1045,8 @@ def add_objects( ) job_details.append(job_detail) self.db.session.add(job_detail) - job.total_count += 1 - job.success_count += 1 + job.total_count += 1 # type: ignore[operator] + job.success_count += 1 # type: ignore[operator] job.status = "complete" job.completed_timestamp = datetime.datetime.now(datetime.timezone.utc) self.db.session.commit() @@ -1047,7 +1055,7 @@ def add_objects( def get_object( self, - collection_id: str, + collection_id: uuid.UUID, object_id: str, limit: Optional[int] = None, added_after: Optional[datetime.datetime] = None, @@ -1108,7 +1116,7 @@ def get_object( def delete_object( self, - collection_id: str, + collection_id: uuid.UUID, object_id: str, match_version: Optional[List[str]] = None, match_spec_version: Optional[List[str]] = None, @@ -1126,13 +1134,13 @@ def delete_object( def get_versions( self, - collection_id: str, + collection_id: uuid.UUID, object_id: str, limit: Optional[int] = None, added_after: Optional[datetime.datetime] = None, next_kwargs: Optional[Dict] = None, match_spec_version: Optional[List[str]] = None, - ) -> Tuple[List[entities.VersionRecord], bool]: + ) -> Tuple[Optional[List[entities.VersionRecord]], bool]: """ Get all versions of single object from database. diff --git a/opentaxii/server.py b/opentaxii/server.py index 87827e5b..e5b43cf9 100644 --- a/opentaxii/server.py +++ b/opentaxii/server.py @@ -1,13 +1,14 @@ import functools import importlib import json +import uuid try: from re import Pattern except ImportError: - from typing.re import Pattern + from typing.re import Pattern # type: ignore[no-redef] -from typing import Callable, ClassVar, NamedTuple, Optional, Tuple, Type +from typing import ClassVar, NamedTuple, Optional, Protocol, Tuple, Type, Union import structlog from flask import Flask, Response, request @@ -42,11 +43,7 @@ from .entities import Account from .exceptions import UnauthorizedException from .local import context -from .persistence import ( - BasePersistenceManager, - Taxii1PersistenceManager, - Taxii2PersistenceManager, -) +from .persistence import Taxii1PersistenceManager, Taxii2PersistenceManager from .taxii2.http import make_taxii2_response from .taxii.bindings import ALL_PROTOCOL_BINDINGS, MESSAGE_BINDINGS, SERVICE_BINDINGS from .taxii.exceptions import FailureStatus, StatusMessageException, raise_failure @@ -76,27 +73,39 @@ anonymous_full_access = Account(id=None, username=None, permissions={}, is_admin=True) +class EndpointFunc(Protocol): + registered_url_re: Pattern + registered_valid_methods: Tuple[str, ...] + registered_valid_accept_mimetypes: Tuple[str, ...] + registered_valid_content_types: Tuple[str, ...] + handles_own_auth: bool + + def __call__(self, **kwargs) -> Response: ... + + +class Endpoint(Protocol): + """The result of functools.partial""" + + func: EndpointFunc + server: "BaseTAXIIServer" + + def __call__(self) -> Response: ... + + class BaseTAXIIServer: """ Base class for common functionality in taxii* servers. """ - PERSISTENCE_MANAGER_CLASS: ClassVar[Type[BasePersistenceManager]] - ENDPOINT_MAPPING: Tuple[(Pattern, Callable[[], Response])] + ENDPOINT_MAPPING: Tuple[Tuple[Pattern, EndpointFunc], ...] app: Flask config: dict - - def __init__(self, config: dict): - self.config = config - self.persistence = self.PERSISTENCE_MANAGER_CLASS( - server=self, api=initialize_api(config["persistence_api"]) - ) - self.setup_endpoint_mapping() + persistence: Union[Taxii1PersistenceManager, Taxii2PersistenceManager] def setup_endpoint_mapping(self): mapping = [] for attr_name in self.__dir__(): - attr = getattr(self, attr_name) + attr: EndpointFunc = getattr(self, attr_name) if hasattr(attr, "registered_url_re"): mapping.append((attr.registered_url_re, attr)) if mapping: @@ -107,15 +116,9 @@ def init_app(self, app: Flask): self.app = app self.persistence.api.init_app(app) - def get_domain(self, service_id): - """Get domain either from request handler or config.""" - dynamic_domain = self.persistence.get_domain(service_id) - domain = dynamic_domain or self.config.get("domain") - return domain - - def get_endpoint(self, relative_path: str) -> Optional[Callable[[], Response]]: + def get_endpoint(self, relative_path: str) -> Optional[Endpoint]: """Get first endpoint matching relative_path.""" - return + raise NotImplementedError def handle_internal_error(self, error): """ @@ -123,7 +126,7 @@ def handle_internal_error(self, error): Placeholder for subclasses to implement. """ - return + raise NotImplementedError def handle_status_exception(self, error): """ @@ -131,7 +134,7 @@ def handle_status_exception(self, error): Placeholder for subclasses to implement. """ - return + raise NotImplementedError def handle_http_exception(self, error): return error.get_response() @@ -142,7 +145,7 @@ def handle_validation_exception(self, error): Placeholder for subclasses to implement. """ - return + raise NotImplementedError def raise_unauthorized(self): """ @@ -169,10 +172,14 @@ class TAXII1Server(BaseTAXIIServer): "collection_management": CollectionManagementService, "poll": PollService, } - PERSISTENCE_MANAGER_CLASS = Taxii1PersistenceManager + persistence: Taxii1PersistenceManager def __init__(self, config: dict): - super().__init__(config) + self.config = config + self.persistence = Taxii1PersistenceManager( + server=self, api=initialize_api(config["persistence_api"]) + ) + self.setup_endpoint_mapping() signal_hooks = config["hooks"] if signal_hooks: importlib.import_module(signal_hooks) @@ -218,11 +225,21 @@ def check_allowed_methods(self): if request.method not in valid_methods: raise MethodNotAllowed(valid_methods=valid_methods) - def get_endpoint(self, relative_path: str) -> Optional[Callable[[], Response]]: + def get_endpoint(self, relative_path: str) -> Optional[Endpoint]: """Get first endpoint matching relative_path.""" for endpoint in self.get_services(): if endpoint.path == relative_path: - return functools.partial(self.handle_request, endpoint=endpoint) + return functools.partial( # type: ignore[return-value] + self.handle_request, endpoint=endpoint + ) + + return None + + def get_domain(self, service_id): + """Get domain either from request handler or config.""" + dynamic_domain = self.persistence.get_domain(service_id) + domain = dynamic_domain or self.config.get("domain") + return domain def get_services(self, service_ids=None): """Get services registered with this TAXII server instance. @@ -285,11 +302,12 @@ def get_services_for_collection(self, collection, service_type): # Sync services for collection with registered services for this server return self.get_services(ids_for_type) - def handle_request(self, endpoint: TAXIIService): + def handle_request(self, endpoint: TAXIIService) -> Response: """ Handle request and return appropriate response. - Process :class:`TAXIIService` with either :meth:`_process_with_service` or :meth:`_process_options_request`. + Process :class:`TAXIIService` with either :meth:`_process_with_service` + or :meth:`_process_options_request`. """ self.check_allowed_methods() if endpoint.authentication_required and context.account is None: @@ -306,7 +324,7 @@ def handle_request(self, endpoint: TAXIIService): if request.method == "POST": return self._process_with_service(endpoint) - if request.method == "OPTIONS": + else: # OPTIONS return self._process_options_request(endpoint) @staticmethod @@ -332,7 +350,7 @@ def _process_with_service(cls, service) -> Response: if "application/xml" not in request.accept_mimetypes: raise_failure( "The specified values of Accept is not supported: {}".format( - ", ".join((request.accept_mimetypes or [])) + ", ".join((request.accept_mimetypes or [])) # type: ignore[arg-type] ) ) @@ -409,7 +427,14 @@ class TAXII2Server(BaseTAXIIServer): Stub, implementation pending. """ - PERSISTENCE_MANAGER_CLASS = Taxii2PersistenceManager + persistence: Taxii2PersistenceManager + + def __init__(self, config: dict): + self.config = config + self.persistence = Taxii2PersistenceManager( + server=self, api=initialize_api(config["persistence_api"]) + ) + self.setup_endpoint_mapping() def handle_http_exception(self, error): """Return JSON instead of HTML for HTTP errors.""" @@ -443,7 +468,7 @@ def raise_unauthorized(self): """ raise Unauthorized() - def get_endpoint(self, relative_path: str) -> Optional[Callable[[], Response]]: + def get_endpoint(self, relative_path: str) -> Optional[Endpoint]: endpoint = None for regex, handler in self.ENDPOINT_MAPPING: match = regex.match(relative_path) @@ -451,9 +476,13 @@ def get_endpoint(self, relative_path: str) -> Optional[Callable[[], Response]]: endpoint = functools.partial(handler, **match.groupdict()) break if endpoint: - return functools.partial(self.handle_request, endpoint) + return functools.partial( # type: ignore[return-value] + self.handle_request, endpoint # type: ignore[arg-type] + ) + + return None - def check_authentication(self, endpoint: Callable[[], Response]): + def check_authentication(self, endpoint: Endpoint): """Check if account is authenticated, unless endpoint handles that itself.""" if endpoint.func.handles_own_auth: # Endpoint will handle auth checks itself @@ -469,11 +498,11 @@ def check_content_length(self): ]: # untestable with flask raise RequestEntityTooLarge() - def check_headers(self, endpoint: Callable[[], Response]): + def check_headers(self, endpoint: Endpoint): if not any( [ - valid_accept_mimetype in request.accept_mimetypes - for valid_accept_mimetype in endpoint.func.registered_valid_accept_mimetypes + accept_mimetype in request.accept_mimetypes + for accept_mimetype in endpoint.func.registered_valid_accept_mimetypes ] ): raise NotAcceptable() @@ -483,11 +512,11 @@ def check_headers(self, endpoint: Callable[[], Response]): ): raise UnsupportedMediaType() - def check_allowed_methods(self, endpoint: Callable[[], Response]): + def check_allowed_methods(self, endpoint: Endpoint): if request.method not in endpoint.func.registered_valid_methods: raise MethodNotAllowed(valid_methods=endpoint.func.registered_valid_methods) - def handle_request(self, endpoint: Callable[[], Response]): + def handle_request(self, endpoint: Endpoint) -> Response: self.check_authentication(endpoint) self.check_content_length() self.check_allowed_methods(endpoint) @@ -511,9 +540,9 @@ def discovery_handler(self): return make_taxii2_response(response) @register_handler(r"^/taxii2/(?P[^/]+)/$", handles_own_auth=True) - def api_root_handler(self, api_root_id): + def api_root_handler(self, api_root_id: str): try: - api_root = self.persistence.get_api_root(api_root_id=api_root_id) + api_root = self.persistence.get_api_root(api_root_id=uuid.UUID(api_root_id)) except DoesNotExistError: if context.account is None: raise Unauthorized() @@ -533,9 +562,10 @@ def api_root_handler(self, api_root_id): r"^/taxii2/(?P[^/]+)/status/(?P[^/]+)/$", handles_own_auth=True, ) - def job_handler(self, api_root_id, job_id): + def job_handler(self, api_root_id: str, job_id: str): + api_root_uuid = uuid.UUID(api_root_id) try: - api_root = self.persistence.get_api_root(api_root_id=api_root_id) + api_root = self.persistence.get_api_root(api_root_id=api_root_uuid) except DoesNotExistError: if context.account is None: raise Unauthorized() @@ -544,7 +574,7 @@ def job_handler(self, api_root_id, job_id): raise Unauthorized() try: job = self.persistence.get_job_and_details( - api_root_id=api_root_id, job_id=job_id + api_root_id=api_root_uuid, job_id=uuid.UUID(job_id) ) except DoesNotExistError: raise NotFound() @@ -554,17 +584,18 @@ def job_handler(self, api_root_id, job_id): @register_handler( r"^/taxii2/(?P[^/]+)/collections/$", handles_own_auth=True ) - def collections_handler(self, api_root_id): + def collections_handler(self, api_root_id: str): + api_root_uuid = uuid.UUID(api_root_id) try: - api_root = self.persistence.get_api_root(api_root_id=api_root_id) + api_root = self.persistence.get_api_root(api_root_id=api_root_uuid) except DoesNotExistError: if context.account is None: raise Unauthorized() raise NotFound() if context.account is None and not api_root.is_public: raise Unauthorized() - collections = self.persistence.get_collections(api_root_id=api_root_id) - response = {} + collections = self.persistence.get_collections(api_root_id=api_root_uuid) + response: dict = {} if collections: response["collections"] = [] for collection in collections: @@ -583,13 +614,15 @@ def collections_handler(self, api_root_id): return make_taxii2_response(response) @register_handler( - r"^/taxii2/(?P[^/]+)/collections/(?P[^/]+)/$", + r"^/taxii2/(?P[^/]+)" + r"/collections/(?P[^/]+)/$", handles_own_auth=True, ) - def collection_handler(self, api_root_id, collection_id_or_alias): + def collection_handler(self, api_root_id: str, collection_id_or_alias: str): try: collection = self.persistence.get_collection( - api_root_id=api_root_id, collection_id_or_alias=collection_id_or_alias + api_root_id=uuid.UUID(api_root_id), + collection_id_or_alias=collection_id_or_alias, ) except DoesNotExistError: if context.account is None: @@ -614,14 +647,15 @@ def collection_handler(self, api_root_id, collection_id_or_alias): return make_taxii2_response(response) @register_handler( - r"^/taxii2/(?P[^/]+)/collections/(?P[^/]+)/manifest/$", + r"^/taxii2/(?P[^/]+)" + r"/collections/(?P[^/]+)/manifest/$", handles_own_auth=True, ) - def manifest_handler(self, api_root_id, collection_id_or_alias): + def manifest_handler(self, api_root_id: str, collection_id_or_alias: str): filter_params = validate_list_filter_params(request.args, self.persistence.api) try: manifest, more = self.persistence.get_manifest( - api_root_id=api_root_id, + api_root_id=uuid.UUID(api_root_id), collection_id_or_alias=collection_id_or_alias, **filter_params, ) @@ -630,14 +664,15 @@ def manifest_handler(self, api_root_id, collection_id_or_alias): raise Unauthorized() raise NotFound() if manifest: - response = { + response: dict = { "more": more, "objects": [ { "id": obj.id, "date_added": taxii2_datetimeformat(obj.date_added), "version": taxii2_datetimeformat(obj.version), - "media_type": f"application/stix+json;version={obj.spec_version}", + "media_type": "application/stix+json;version=" + + obj.spec_version, } for obj in manifest ], @@ -659,18 +694,20 @@ def manifest_handler(self, api_root_id, collection_id_or_alias): ) @register_handler( - r"^/taxii2/(?P[^/]+)/collections/(?P[^/]+)/objects/$", + r"^/taxii2/(?P[^/]+)" + r"/collections/(?P[^/]+)/objects/$", ("GET", "POST"), valid_content_types=("application/taxii+json;version=2.1",), handles_own_auth=True, ) - def objects_handler(self, api_root_id, collection_id_or_alias): + def objects_handler(self, api_root_id: str, collection_id_or_alias: str): + api_root_uuid = uuid.UUID(api_root_id) if request.method == "GET": - return self.objects_get_handler(api_root_id, collection_id_or_alias) + return self.objects_get_handler(api_root_uuid, collection_id_or_alias) if request.method == "POST": - return self.objects_post_handler(api_root_id, collection_id_or_alias) + return self.objects_post_handler(api_root_uuid, collection_id_or_alias) - def objects_get_handler(self, api_root_id, collection_id_or_alias): + def objects_get_handler(self, api_root_id: uuid.UUID, collection_id_or_alias: str): filter_params = validate_list_filter_params(request.args, self.persistence.api) try: objects, more, next_param = self.persistence.get_objects( @@ -713,7 +750,7 @@ def objects_get_handler(self, api_root_id, collection_id_or_alias): extra_headers=headers, ) - def objects_post_handler(self, api_root_id, collection_id_or_alias): + def objects_post_handler(self, api_root_id: uuid.UUID, collection_id_or_alias: str): validate_envelope( request.data, allow_custom=self.config.get("allow_custom_properties", True) ) @@ -728,7 +765,7 @@ def objects_post_handler(self, api_root_id, collection_id_or_alias): raise Unauthorized() raise NotFound() response = job.as_taxii2_dict() - headers = {} + headers: dict = {} return make_taxii2_response( response, 202, @@ -736,21 +773,28 @@ def objects_post_handler(self, api_root_id, collection_id_or_alias): ) @register_handler( - r"^/taxii2/(?P[^/]+)/collections/(?P[^/]+)/objects/(?P[^/]+)/$", + r"^/taxii2/(?P[^/]+)" + r"/collections/(?P[^/]+)" + r"/objects/(?P[^/]+)/$", ("GET", "DELETE"), handles_own_auth=True, ) - def object_handler(self, api_root_id, collection_id_or_alias, object_id): + def object_handler( + self, api_root_id: str, collection_id_or_alias: str, object_id: str + ): + api_root_uuid = uuid.UUID(api_root_id) if request.method == "GET": return self.object_get_handler( - api_root_id, collection_id_or_alias, object_id + api_root_uuid, collection_id_or_alias, object_id ) if request.method == "DELETE": return self.object_delete_handler( - api_root_id, collection_id_or_alias, object_id + api_root_uuid, collection_id_or_alias, object_id ) - def object_get_handler(self, api_root_id, collection_id_or_alias, object_id): + def object_get_handler( + self, api_root_id: uuid.UUID, collection_id_or_alias: str, object_id: str + ): filter_params = validate_object_filter_params( request.args, self.persistence.api ) @@ -796,7 +840,9 @@ def object_get_handler(self, api_root_id, collection_id_or_alias, object_id): extra_headers=headers, ) - def object_delete_handler(self, api_root_id, collection_id_or_alias, object_id): + def object_delete_handler( + self, api_root_id: uuid.UUID, collection_id_or_alias: str, object_id: str + ): filter_params = validate_delete_filter_params(request.args) try: self.persistence.delete_object( @@ -817,18 +863,21 @@ def object_delete_handler(self, api_root_id, collection_id_or_alias, object_id): @register_handler( ( - r"^/taxii2/(?P[^/]+)/collections/(?P[^/]+)" + r"^/taxii2/(?P[^/]+)" + r"/collections/(?P[^/]+)" r"/objects/(?P[^/]+)/versions/$" ), handles_own_auth=True, ) - def versions_handler(self, api_root_id, collection_id_or_alias, object_id): + def versions_handler( + self, api_root_id: str, collection_id_or_alias: str, object_id: str + ): filter_params = validate_versions_filter_params( request.args, self.persistence.api ) try: versions, more = self.persistence.get_versions( - api_root_id=api_root_id, + api_root_id=uuid.UUID(api_root_id), collection_id_or_alias=collection_id_or_alias, object_id=object_id, **filter_params, @@ -882,7 +931,7 @@ class TAXIIServer: def __init__(self, config: ServerConfig): self.config = config - servers_kwargs = { + servers_kwargs: dict = { "taxii1": None, "taxii2": None, } @@ -925,7 +974,7 @@ def is_basic_auth_supported(self): """Check if basic auth is a supported feature.""" return self.config.get("support_basic_auth", False) - def get_endpoint(self, relative_path: str) -> Optional[Callable[[], Response]]: + def get_endpoint(self, relative_path: str) -> Optional[Endpoint]: """Get first endpoint matching relative_path.""" for server in self.real_servers: endpoint = server.get_endpoint(relative_path) @@ -933,6 +982,8 @@ def get_endpoint(self, relative_path: str) -> Optional[Callable[[], Response]]: endpoint.server = server return endpoint + return None + def handle_request(self, relative_path: str) -> Response: """Dispatch request to appropriate taxii* server.""" relative_path = "/" + relative_path @@ -972,6 +1023,7 @@ def raise_unauthorized(self): if endpoint: server = endpoint.server else: + assert self.servers.taxii1 is not None server = self.servers.taxii1 context.taxiiserver = server return server.raise_unauthorized() diff --git a/opentaxii/taxii/services/__init__.py b/opentaxii/taxii/services/__init__.py index c8128ae9..860c87ac 100644 --- a/opentaxii/taxii/services/__init__.py +++ b/opentaxii/taxii/services/__init__.py @@ -4,3 +4,10 @@ from .discovery import DiscoveryService from .inbox import InboxService from .poll import PollService + +__all__ = [ + "InboxService", + "DiscoveryService", + "CollectionManagementService", + "PollService", +] diff --git a/opentaxii/taxii2/entities.py b/opentaxii/taxii2/entities.py index 2da976c1..3b379726 100644 --- a/opentaxii/taxii2/entities.py +++ b/opentaxii/taxii2/entities.py @@ -1,5 +1,6 @@ """Taxii2 entities.""" +import uuid from datetime import datetime from typing import List, NamedTuple, Optional @@ -12,15 +13,20 @@ class ApiRoot(Entity): """ TAXII2 API Root entity. - :param str id: id of this API root - :param bool default: indicator of default api root, should only be True once - :param str title: human readable plain text name used to identify this API Root - :param str description: human readable plain text description for this API Root - :param bool is_public: whether this is a publicly readable API root + :param id: id of this API root + :param default: indicator of default api root, should only be True once + :param title: human readable plain text name used to identify this API Root + :param description: human readable plain text description for this API Root + :param is_public: whether this is a publicly readable API root """ def __init__( - self, id: str, default: bool, title: str, description: str, is_public: bool + self, + id: uuid.UUID, + default: bool, + title: str, + description: Optional[str], + is_public: bool, ): """Initialize ApiRoot.""" self.id = id @@ -34,22 +40,25 @@ class Collection(Entity): """ TAXII2 Collection entity. - :param str id: id of this collection - :param str api_root_id: id of the :class:`ApiRoot` this collection belongs to - :param str title: human readable plain text name used to identify this collection - :param str description: human readable plain text description for this collection - :param str alias: human readable collection name that can be used on systems to alias a collection id - :param bool is_public: whether this is a publicly readable collection - :param bool is_public_write: whether this is a publicly writable collection + :param id: id of this collection + :param api_root_id: id of the :class:`ApiRoot` this collection belongs to + :param title: human readable plain text name used to identify this + collection + :param description: human readable plain text description for this + collection + :param alias: human readable collection name that can be used on systems to + alias a collection id + :param is_public: whether this is a publicly readable collection + :param is_public_write: whether this is a publicly writable collection """ def __init__( self, - id: str, - api_root_id: str, + id: uuid.UUID, + api_root_id: uuid.UUID, title: str, description: str, - alias: str, + alias: Optional[str], is_public: bool, is_public_write: bool, ): @@ -85,19 +94,19 @@ class STIXObject(Entity): """ TAXII2 STIXObject entity. - :param str id: id of this stix object - :param str collection_id: id of the :class:`Collection` this stix object belongs to - :param str type: type of this stix object - :param str spec_version: stix version this object matches - :param datetime date_added: the date and time this object was added - :param datetime version: the version of this object - :param dict serialized_data: the payload of this object + :param id: id of this stix object + :param collection_id: id of the :class:`Collection` this stix object belongs to + :param type: type of this stix object + :param spec_version: stix version this object matches + :param date_added: the date and time this object was added + :param version: the version of this object + :param serialized_data: the payload of this object """ def __init__( self, id: str, - collection_id: str, + collection_id: uuid.UUID, type: str, spec_version: str, date_added: datetime, @@ -120,10 +129,10 @@ class ManifestRecord(Entity): This is a cut-down version of :class:`STIXObject`, for efficiency. - :param str id: id of this stix object - :param datetime date_added: the date and time this object was added - :param datetime version: the version of this object - :param str spec_version: stix version this object matches + :param id: id of this stix object + :param date_added: the date and time this object was added + :param version: the version of this object + :param spec_version: stix version this object matches """ def __init__( @@ -164,19 +173,20 @@ class JobDetail(Entity): """ TAXII2 JobDetail entity, part of "status resource" in taxii2 docs. - :param str id: id of this job detail - :param str job_id: id of the job this detail belongs to - :param str stix_id: id of the :class:`STIXObject` this detail tracks - :param datetime version: the version of this object - :param str message: message indicating more information about the object being created, - its pending state, or why the object failed to be created. - :param str status: status of this job + :param id: id of this job detail + :param job_id: id of the job this detail belongs to + :param stix_id: id of the :class:`STIXObject` this detail tracks + :param version: the version of this object + :param message: message indicating more information about the object + being created, its pending state, or why the object failed to be + created. + :param status: status of this job """ def __init__( self, - id: str, - job_id: str, + id: uuid.UUID, + job_id: uuid.UUID, stix_id: str, version: datetime, message: str, @@ -208,22 +218,24 @@ class Job(Entity): """ TAXII2 Job entity, called a "status resource" in taxii2 docs. - :param str id: id of this job - :param str api_root_id: id of the :class:`ApiRoot` this collection belongs to - :param str status: status of this job - :param datetime request_timestamp: the datetime of the request that this status resource is monitoring - :param datetime completed_timestamp: the datetime of the completion of this job (used for cleanup) - :param int total_count: the total number of stix objects in this job - :param int success_count: the number of successful stix objects in this job - :param int failure_count: the number of failed stix objects in this job - :param int pending_count: the number of pending stix objects in this job - :param dict details: the details per status of this job + :param id: id of this job + :param api_root_id: id of the :class:`ApiRoot` this collection belongs to + :param status: status of this job + :param request_timestamp: the datetime of the request that this status + resource is monitoring + :param completed_timestamp: the datetime of the completion of this job (used + for cleanup) + :param total_count: the total number of stix objects in this job + :param success_count: the number of successful stix objects in this job + :param failure_count: the number of failed stix objects in this job + :param pending_count: the number of pending stix objects in this job + :param details: the details per status of this job """ def __init__( self, - id: str, - api_root_id: str, + id: uuid.UUID, + api_root_id: uuid.UUID, status: str, request_timestamp: datetime, completed_timestamp: Optional[datetime] = None, diff --git a/opentaxii/taxii2/http.py b/opentaxii/taxii2/http.py index eb0a82ba..5af5d060 100644 --- a/opentaxii/taxii2/http.py +++ b/opentaxii/taxii2/http.py @@ -1,17 +1,25 @@ """Taxii2 http helper functions.""" import json +import uuid from typing import Dict, Optional from flask import Response, make_response +class UuidJSONEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, uuid.UUID): + return str(obj) + return super().default(obj) + + def make_taxii2_response( data, status: Optional[int] = 200, extra_headers: Optional[Dict] = None ) -> Response: """Turn input data into valid taxii2 response.""" if not isinstance(data, str): - data = json.dumps(data) + data = json.dumps(data, cls=UuidJSONEncoder) response = make_response((data, status)) response.content_type = "application/taxii+json;version=2.1" response.headers.update(extra_headers or {}) diff --git a/opentaxii/taxii2/validation.py b/opentaxii/taxii2/validation.py index 4c7c6903..36d81ef2 100644 --- a/opentaxii/taxii2/validation.py +++ b/opentaxii/taxii2/validation.py @@ -2,23 +2,23 @@ import datetime import json +from typing import Mapping, Union from marshmallow import Schema, fields from stix2 import parse from stix2.exceptions import STIXError -from werkzeug.datastructures import ImmutableMultiDict from opentaxii.persistence.api import OpenTAXII2PersistenceAPI from opentaxii.taxii2.exceptions import ValidationError from opentaxii.taxii2.utils import DATETIMEFORMAT -def validate_envelope(json_data: str, allow_custom: bool = False) -> None: +def validate_envelope(json_data: Union[str, bytes], allow_custom: bool = False) -> None: """ Validate if ``json_data`` is a valid taxii2 envelope. - :param str json_data: the data to check - :param bool allow_custom: if true, allow non-standard stix types + :param json_data: the data to check + :param allow_custom: if true, allow non-standard stix types """ if not json_data: raise ValidationError("No data") @@ -53,7 +53,11 @@ class Taxii2Next(fields.Field): def _deserialize(self, value, attr, data, **kwargs): value = super()._deserialize(value, attr, data, **kwargs) try: - value = self.parent.persistence_api.parse_next_param(value) + value = ( + self.parent.persistence_api.parse_next_param( # type:ignore[union-attr] + value + ) + ) except: # noqa raise ValidationError("Not a valid value.") return value @@ -132,7 +136,7 @@ class DeleteFilterParamsSchema(Schema): def validate_object_filter_params( - filter_params: ImmutableMultiDict, persistence_api: OpenTAXII2PersistenceAPI + filter_params: Mapping, persistence_api: OpenTAXII2PersistenceAPI ) -> dict: """Validate and load filter params for the object endpoint.""" parsed_params = ObjectFilterParamsSchema(persistence_api).load(filter_params) @@ -140,7 +144,7 @@ def validate_object_filter_params( def validate_list_filter_params( - filter_params: ImmutableMultiDict, persistence_api: OpenTAXII2PersistenceAPI + filter_params: Mapping, persistence_api: OpenTAXII2PersistenceAPI ) -> dict: """Validate and load filter params for the list endpoint.""" parsed_params = ListFilterParamsSchema(persistence_api).load(filter_params) @@ -148,14 +152,14 @@ def validate_list_filter_params( def validate_versions_filter_params( - filter_params: ImmutableMultiDict, persistence_api: OpenTAXII2PersistenceAPI + filter_params: Mapping, persistence_api: OpenTAXII2PersistenceAPI ) -> dict: """Validate and load filter params for the versions endpoint.""" parsed_params = VersionFilterParamsSchema(persistence_api).load(filter_params) return parsed_params -def validate_delete_filter_params(filter_params: ImmutableMultiDict) -> dict: +def validate_delete_filter_params(filter_params: Mapping) -> dict: """Validate and load filter params for the delete endpoint.""" parsed_params = DeleteFilterParamsSchema().load(filter_params) return parsed_params diff --git a/opentaxii/utils.py b/opentaxii/utils.py index 43e2e45a..dc0ea7e8 100644 --- a/opentaxii/utils.py +++ b/opentaxii/utils.py @@ -313,20 +313,21 @@ def sync_accounts(server, accounts): def register_handler( url_re: str, - valid_methods: Optional[Tuple[str]] = None, - valid_accept_mimetypes: Optional[Tuple[str]] = None, - valid_content_types: Optional[Tuple[str]] = None, + valid_methods: Optional[Tuple[str, ...]] = None, + valid_accept_mimetypes: Optional[Tuple[str, ...]] = None, + valid_content_types: Optional[Tuple[str, ...]] = None, handles_own_auth: bool = False, ): """ Register decorated method as handler function for `url_re`. - :param str url_re: The regex to trigger the handler on - :param list valid_methods: The list of methods to accept for this handler, defaults to ("GET",) - :param list valid_accept_mimetypes: + :param url_re: The regex to trigger the handler on + :param valid_methods: The list of methods to accept for this handler, + defaults to ("GET",) + :param valid_accept_mimetypes: The list of accepted mimetypes to accept for this handler, defaults to ("application/taxii+json;version=2.1",) - :param list valid_content_types: + :param valid_content_types: The list of content types to accept for this handler, defaults to ("application/json",) """ diff --git a/pyproject.toml b/pyproject.toml index 439585a9..f3e0451b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,31 @@ omit-covered-files = true verbose = 0 exclude = ["tests"] +[tool.mypy] +plugins = ["sqlmypy"] + +ignore_missing_imports = true +no_implicit_optional = true +show_error_codes = true +warn_redundant_casts = true +warn_unused_configs = true +warn_unused_ignores = true +check_untyped_defs = true +follow_imports = "silent" + +files = [ + "opentaxii/", + "tests/", +] + +exclude = """ +(?x)( + ^opentaxii/taxii/.+ + | ^opentaxii/utils\\.py$ + | ^opentaxii/common/entities\\.py$ +) +""" + [tool.black] line-length = 88 skip_string_normalization = true diff --git a/requirements-dev-mysql.txt b/requirements-dev-mysql.txt deleted file mode 100644 index ae2018da..00000000 --- a/requirements-dev-mysql.txt +++ /dev/null @@ -1,2 +0,0 @@ --r requirements-dev.txt -mysqlclient>=2.0.3 diff --git a/requirements-dev-postgres-pypy.txt b/requirements-dev-postgres-pypy.txt deleted file mode 100644 index 58317d38..00000000 --- a/requirements-dev-postgres-pypy.txt +++ /dev/null @@ -1,2 +0,0 @@ --r requirements-dev.txt -psycopg2cffi>=2.9.0 diff --git a/requirements-dev-postgres.txt b/requirements-dev-postgres.txt deleted file mode 100644 index 452418fc..00000000 --- a/requirements-dev-postgres.txt +++ /dev/null @@ -1,2 +0,0 @@ --r requirements-dev.txt -psycopg2-binary>=2.9.1 diff --git a/requirements-dev.txt b/requirements-dev.txt index f4f574a3..6444a695 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,9 +1,12 @@ -r requirements.txt -pytest-cov -pytest>=4.6 -pytest-pythonpath +-r requirements-test.txt flake8 ipdb +mypy==1.19.0 +types-six==1.17.0.20251009 +sqlalchemy-stubs==0.4 +types-PyYAML==6.0.12.20250915 +types-pytz==2025.2.0.20251108 factory-boy>=3.2.1 black==25.11.0 isort==7.0.0 diff --git a/requirements-mysql.txt b/requirements-mysql.txt new file mode 100644 index 00000000..c9dc3f58 --- /dev/null +++ b/requirements-mysql.txt @@ -0,0 +1 @@ +mysqlclient>=2.0.3 diff --git a/requirements-postgres-pypy.txt b/requirements-postgres-pypy.txt new file mode 100644 index 00000000..102e1160 --- /dev/null +++ b/requirements-postgres-pypy.txt @@ -0,0 +1 @@ +psycopg2cffi>=2.9.0 diff --git a/requirements-postgres.txt b/requirements-postgres.txt new file mode 100644 index 00000000..1197662d --- /dev/null +++ b/requirements-postgres.txt @@ -0,0 +1 @@ +psycopg2-binary>=2.9.1 diff --git a/requirements-test.txt b/requirements-test.txt new file mode 100644 index 00000000..3c94f223 --- /dev/null +++ b/requirements-test.txt @@ -0,0 +1,4 @@ +pytest-cov +pytest>=4.6 +pytest-pythonpath +factory-boy>=3.2.1 diff --git a/requirements.txt b/requirements.txt index 7cdd2555..391a2b83 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ libtaxii>=1.1.111 lxml>=4.3.1 pyyaml>=3.11 flask>=0.10.1 -sqlalchemy>=1.1.2 +sqlalchemy>=1.4 structlog>=18.1.0 blinker>=1.4 pyjwt>=1.4.0 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/conftest.py b/tests/conftest.py index e9ef3054..3f155ab5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -232,8 +232,8 @@ def truncate_app(dbconn): app = create_app(context.server) app.config["TESTING"] = True yield app - taxiiserver.servers.taxii1.persistence.api.db.engine.dispose() - taxiiserver.servers.taxii2.persistence.api.db.engine.dispose() + taxiiserver.servers.taxii1.persistence.api.db.engine.dispose() # type: ignore[union-attr] + taxiiserver.servers.taxii2.persistence.api.db.engine.dispose() # type: ignore[union-attr] @pytest.fixture() diff --git a/tests/fixtures.py b/tests/fixtures.py index c0246caa..17958341 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -10,7 +10,7 @@ CUSTOM_CONTENT_BINDING = 'custom:content:binding' INVALID_CONTENT_BINDING = 'invalid:content:binding' -INBOX_A = dict( +INBOX_A: dict = dict( id='inbox-A', type='inbox', description='inbox-A description', @@ -20,7 +20,7 @@ protocol_bindings=PROTOCOL_BINDINGS, ) -INBOX_B = dict( +INBOX_B: dict = dict( id='inbox-B', type='inbox', description='inbox-B description', @@ -30,7 +30,7 @@ protocol_bindings=PROTOCOL_BINDINGS, ) -DISCOVERY_A = dict( +DISCOVERY_A: dict = dict( id='discovery-A', type='discovery', description='discovery-A description', @@ -46,7 +46,7 @@ protocol_bindings=PROTOCOL_BINDINGS, ) -DISCOVERY_B = dict( +DISCOVERY_B: dict = dict( id='discovery-B', type='discovery', description='External discovery-B service', @@ -56,7 +56,7 @@ SUBSCRIPTION_MESSAGE = 'message about subscription' -COLLECTION_MANAGEMENT = dict( +COLLECTION_MANAGEMENT: dict = dict( id='collection-management-A', type='collection_management', description='Collection management description', @@ -68,7 +68,7 @@ POLL_RESULT_SIZE = 20 POLL_MAX_COUNT = 15 -POLL = dict( +POLL: dict = dict( id='poll-A', type='poll', description='Poll service description', diff --git a/tests/services/__init__.py b/tests/services/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/services/test_collection_management.py b/tests/services/test_collection_management.py index 0ebb065b..22323fc4 100644 --- a/tests/services/test_collection_management.py +++ b/tests/services/test_collection_management.py @@ -1,5 +1,8 @@ import pytest -from fixtures import ( + +from opentaxii.taxii import entities + +from ..fixtures import ( COLLECTION_DISABLED, COLLECTION_ONLY_STIX, COLLECTION_OPEN, @@ -8,9 +11,7 @@ MESSAGE_ID, SERVICES, ) -from utils import as_tm, persist_content, prepare_headers - -from opentaxii.taxii import entities +from ..utils import as_tm, persist_content, prepare_headers ASSIGNED_SERVICES = ['collection-management-A', 'inbox-A', 'inbox-B', 'poll-A'] diff --git a/tests/services/test_discovery.py b/tests/services/test_discovery.py index ee472bda..f09b8c41 100644 --- a/tests/services/test_discovery.py +++ b/tests/services/test_discovery.py @@ -1,7 +1,8 @@ import pytest -from fixtures import INBOX_A, INBOX_B, INSTANCES_CONFIGURED, MESSAGE_ID from libtaxii.constants import SVC_INBOX -from utils import as_tm, prepare_headers + +from ..fixtures import INBOX_A, INBOX_B, INSTANCES_CONFIGURED, MESSAGE_ID +from ..utils import as_tm, prepare_headers @pytest.fixture(autouse=True) diff --git a/tests/services/test_inbox.py b/tests/services/test_inbox.py index 3370c53c..e042251a 100644 --- a/tests/services/test_inbox.py +++ b/tests/services/test_inbox.py @@ -1,5 +1,11 @@ import pytest -from fixtures import ( +from libtaxii import messages_10 as tm10 +from libtaxii import messages_11 as tm11 +from libtaxii.constants import CB_STIX_XML_111, ST_SUCCESS + +from opentaxii.taxii import exceptions + +from ..fixtures import ( COLLECTION_ONLY_STIX, COLLECTION_OPEN, COLLECTIONS_A, @@ -10,12 +16,7 @@ INVALID_CONTENT_BINDING, MESSAGE_ID, ) -from libtaxii import messages_10 as tm10 -from libtaxii import messages_11 as tm11 -from libtaxii.constants import CB_STIX_XML_111, ST_SUCCESS -from utils import as_tm, prepare_headers - -from opentaxii.taxii import exceptions +from ..utils import as_tm, prepare_headers def make_content( @@ -70,7 +71,7 @@ def prepare_server(server, services): .filter_by(name=coll.name) .one() ) - service_ids = {s.id for s in coll.services} | {service} + service_ids = {s.id for s in coll.services} | {service} # type:ignore server.servers.taxii1.persistence.set_collection_services( coll.id, service_ids=service_ids diff --git a/tests/services/test_poll.py b/tests/services/test_poll.py index 38a620c5..6f7bb59a 100644 --- a/tests/services/test_poll.py +++ b/tests/services/test_poll.py @@ -1,5 +1,11 @@ import pytest -from fixtures import ( +from libtaxii import messages_10 as tm10 +from libtaxii import messages_11 as tm11 +from libtaxii.constants import ACT_SUBSCRIBE, CB_STIX_XML_111, RT_COUNT_ONLY, RT_FULL + +from opentaxii.taxii import exceptions + +from ..fixtures import ( COLLECTION_DISABLED, COLLECTION_ONLY_STIX, COLLECTION_OPEN, @@ -10,12 +16,12 @@ POLL_MAX_COUNT, POLL_RESULT_SIZE, ) -from libtaxii import messages_10 as tm10 -from libtaxii import messages_11 as tm11 -from libtaxii.constants import ACT_SUBSCRIBE, CB_STIX_XML_111, RT_COUNT_ONLY, RT_FULL -from utils import as_tm, persist_content, prepare_headers, prepare_subscription_request - -from opentaxii.taxii import exceptions +from ..utils import ( + as_tm, + persist_content, + prepare_headers, + prepare_subscription_request, +) @pytest.fixture(autouse=True) diff --git a/tests/services/test_subscription_management.py b/tests/services/test_subscription_management.py index de2c8643..fa493e86 100644 --- a/tests/services/test_subscription_management.py +++ b/tests/services/test_subscription_management.py @@ -1,10 +1,4 @@ import pytest -from fixtures import ( - COLLECTION_OPEN, - COLLECTIONS_B, - CUSTOM_CONTENT_BINDING, - SUBSCRIPTION_MESSAGE, -) from libtaxii.constants import ( ACT_PAUSE, ACT_RESUME, @@ -16,11 +10,18 @@ SS_PAUSED, SS_UNSUBSCRIBED, ) -from utils import as_tm, prepare_headers -from utils import prepare_subscription_request as prepare_request from opentaxii.taxii import exceptions +from ..fixtures import ( + COLLECTION_OPEN, + COLLECTIONS_B, + CUSTOM_CONTENT_BINDING, + SUBSCRIPTION_MESSAGE, +) +from ..utils import as_tm, prepare_headers +from ..utils import prepare_subscription_request as prepare_request + ASSIGNED_SERVICES = ['collection-management-A', 'poll-A'] diff --git a/tests/taxii2/__init__.py b/tests/taxii2/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/taxii2/test_taxii2_collection.py b/tests/taxii2/test_taxii2_collection.py index fef193b9..eba5c7d2 100644 --- a/tests/taxii2/test_taxii2_collection.py +++ b/tests/taxii2/test_taxii2_collection.py @@ -24,7 +24,7 @@ 200, {"Content-Type": "application/taxii+json;version=2.1"}, { - "id": COLLECTIONS[0].id, + "id": str(COLLECTIONS[0].id), "title": "0Read only", "description": "Read only description", "can_read": True, @@ -41,7 +41,7 @@ 200, {"Content-Type": "application/taxii+json;version=2.1"}, { - "id": COLLECTIONS[4].id, + "id": str(COLLECTIONS[4].id), "title": "4No description", "can_read": True, "can_write": True, @@ -57,7 +57,7 @@ 200, {"Content-Type": "application/taxii+json;version=2.1"}, { - "id": COLLECTIONS[5].id, + "id": str(COLLECTIONS[5].id), "title": "5With alias", "description": "With alias description", "alias": "this-is-an-alias", @@ -75,7 +75,7 @@ 200, {"Content-Type": "application/taxii+json;version=2.1"}, { - "id": COLLECTIONS[5].id, + "id": str(COLLECTIONS[5].id), "title": "5With alias", "description": "With alias description", "alias": "this-is-an-alias", @@ -314,7 +314,7 @@ def test_add_collection( is_public_write=is_public_write, ) assert collection.id is not None - assert str(collection.api_root_id) == api_root_id + assert collection.api_root_id == api_root_id assert collection.title == title assert collection.description == description assert collection.alias == alias @@ -327,7 +327,7 @@ def test_add_collection( .filter(taxii2models.Collection.id == collection.id) .one() ) - assert str(db_collection.api_root_id) == api_root_id + assert db_collection.api_root_id == api_root_id assert db_collection.title == title assert db_collection.description == description assert db_collection.alias == alias diff --git a/tests/taxii2/test_taxii2_collections.py b/tests/taxii2/test_taxii2_collections.py index 37f422ab..929737d9 100644 --- a/tests/taxii2/test_taxii2_collections.py +++ b/tests/taxii2/test_taxii2_collections.py @@ -38,7 +38,7 @@ { "collections": [ { - "id": COLLECTIONS[0].id, + "id": str(COLLECTIONS[0].id), "title": "0Read only", "description": "Read only description", "can_read": True, @@ -46,7 +46,7 @@ "media_types": ["application/stix+json;version=2.1"], }, { - "id": COLLECTIONS[1].id, + "id": str(COLLECTIONS[1].id), "title": "1Write only", "description": "Write only description", "can_read": False, @@ -54,7 +54,7 @@ "media_types": ["application/stix+json;version=2.1"], }, { - "id": COLLECTIONS[2].id, + "id": str(COLLECTIONS[2].id), "title": "2Read/Write", "description": "Read/Write description", "can_read": True, @@ -62,7 +62,7 @@ "media_types": ["application/stix+json;version=2.1"], }, { - "id": COLLECTIONS[3].id, + "id": str(COLLECTIONS[3].id), "title": "3No permissions", "description": "No permissions description", "can_read": False, @@ -70,14 +70,14 @@ "media_types": ["application/stix+json;version=2.1"], }, { - "id": COLLECTIONS[4].id, + "id": str(COLLECTIONS[4].id), "title": "4No description", "can_read": True, "can_write": True, "media_types": ["application/stix+json;version=2.1"], }, { - "id": COLLECTIONS[5].id, + "id": str(COLLECTIONS[5].id), "title": "5With alias", "description": "With alias description", "alias": "this-is-an-alias", @@ -86,7 +86,7 @@ "media_types": ["application/stix+json;version=2.1"], }, { - "id": COLLECTIONS[6].id, + "id": str(COLLECTIONS[6].id), "title": "6Public", "description": "public description", "can_read": True, @@ -94,7 +94,7 @@ "media_types": ["application/stix+json;version=2.1"], }, { - "id": COLLECTIONS[7].id, + "id": str(COLLECTIONS[7].id), "title": "7Publicwrite", "description": "public write description", "can_read": False, diff --git a/tests/taxii2/test_taxii2_objects.py b/tests/taxii2/test_taxii2_objects.py index 4206a289..a4fce5ab 100644 --- a/tests/taxii2/test_taxii2_objects.py +++ b/tests/taxii2/test_taxii2_objects.py @@ -786,7 +786,7 @@ 202, {"Content-Type": "application/taxii+json;version=2.1"}, { - "id": JOBS[0].id, + "id": str(JOBS[0].id), "status": JOBS[0].status, "request_timestamp": taxii2_datetimeformat(JOBS[0].request_timestamp), "total_count": 4, @@ -1306,7 +1306,7 @@ def test_objects_unauthenticated( side_effect=ADD_OBJECTS_MOCK, ) as add_objects_mock, ): - kwargs = { + kwargs: dict = { "headers": { "Accept": "application/taxii+json;version=2.1", "Content-Type": "application/taxii+json;version=2.1", diff --git a/tests/taxii2/test_taxii2_sqldb.py b/tests/taxii2/test_taxii2_sqldb.py index a4247bcf..91b9900b 100644 --- a/tests/taxii2/test_taxii2_sqldb.py +++ b/tests/taxii2/test_taxii2_sqldb.py @@ -3,6 +3,7 @@ import pytest +from opentaxii.persistence.sqldb.api import Taxii2SQLDatabaseAPI from opentaxii.persistence.sqldb.taxii2models import Job, JobDetail, STIXObject from opentaxii.taxii2 import entities from opentaxii.taxii2.utils import DATETIMEFORMAT @@ -39,7 +40,7 @@ ], indirect=["db_api_roots"], ) -def test_get_api_roots(taxii2_sqldb_api, db_api_roots): +def test_get_api_roots(taxii2_sqldb_api: Taxii2SQLDatabaseAPI, db_api_roots): response = taxii2_sqldb_api.get_api_roots() assert response == [api_root for api_root in db_api_roots] @@ -61,7 +62,9 @@ def test_get_api_roots(taxii2_sqldb_api, db_api_roots): ), ], ) -def test_get_api_root(taxii2_sqldb_api, db_api_roots, api_root_id): +def test_get_api_root( + taxii2_sqldb_api: Taxii2SQLDatabaseAPI, db_api_roots, api_root_id +): response = taxii2_sqldb_api.get_api_root(api_root_id) assert response == GET_API_ROOT_MOCK(api_root_id) @@ -96,7 +99,9 @@ def test_get_api_root(taxii2_sqldb_api, db_api_roots, api_root_id): ), ], ) -def test_get_job_and_details(taxii2_sqldb_api, db_jobs, api_root_id, job_id): +def test_get_job_and_details( + taxii2_sqldb_api: Taxii2SQLDatabaseAPI, db_jobs, api_root_id, job_id +): response = taxii2_sqldb_api.get_job_and_details(api_root_id, job_id) assert response == GET_JOB_AND_DETAILS_MOCK(api_root_id, job_id) @@ -118,7 +123,9 @@ def test_get_job_and_details(taxii2_sqldb_api, db_jobs, api_root_id, job_id): ), ], ) -def test_get_collections(taxii2_sqldb_api, db_collections, api_root_id): +def test_get_collections( + taxii2_sqldb_api: Taxii2SQLDatabaseAPI, db_collections, api_root_id +): response = taxii2_sqldb_api.get_collections(api_root_id) assert response == GET_COLLECTIONS_MOCK(api_root_id) @@ -147,22 +154,25 @@ def test_get_collections(taxii2_sqldb_api, db_collections, api_root_id): id="wrong api root", ), pytest.param( - str(uuid4()), + uuid4(), COLLECTIONS[0].id, id="unknown api root", ), pytest.param( API_ROOTS[0].id, - str(uuid4()), + uuid4(), id="unknown collection id", ), ], ) def test_get_collection( - taxii2_sqldb_api, db_collections, api_root_id, collection_id_or_alias + taxii2_sqldb_api: Taxii2SQLDatabaseAPI, + db_collections, + api_root_id, + collection_id_or_alias, ): - response = taxii2_sqldb_api.get_collection(api_root_id, collection_id_or_alias) - assert response == GET_COLLECTION_MOCK(api_root_id, collection_id_or_alias) + response = taxii2_sqldb_api.get_collection(api_root_id, str(collection_id_or_alias)) + assert response == GET_COLLECTION_MOCK(api_root_id, str(collection_id_or_alias)) @pytest.mark.parametrize( @@ -428,7 +438,7 @@ def test_get_collection( ], ) def test_get_manifest( - taxii2_sqldb_api, + taxii2_sqldb_api: Taxii2SQLDatabaseAPI, db_stix_objects, collection_id, limit, @@ -724,7 +734,7 @@ def test_get_manifest( ], ) def test_get_objects( - taxii2_sqldb_api, + taxii2_sqldb_api: Taxii2SQLDatabaseAPI, db_stix_objects, collection_id, limit, @@ -831,7 +841,7 @@ def test_get_objects( ], ) def test_add_objects( - taxii2_sqldb_api, + taxii2_sqldb_api: Taxii2SQLDatabaseAPI, db_stix_objects, api_root_id, collection_id, @@ -875,7 +885,7 @@ def test_add_objects( assert isinstance(job.completed_timestamp, datetime.datetime) # Check database state db_job = taxii2_sqldb_api.db.session.query(Job).one() - assert str(db_job.api_root_id) == api_root_id + assert db_job.api_root_id == api_root_id assert db_job.status == "complete" assert isinstance(db_job.request_timestamp, datetime.datetime) assert isinstance(db_job.completed_timestamp, datetime.datetime) @@ -892,7 +902,7 @@ def test_add_objects( .one() ) assert db_obj.id == obj["id"] - assert str(db_obj.collection_id) == collection_id + assert db_obj.collection_id == collection_id assert db_obj.type == obj["type"] assert db_obj.spec_version == obj["spec_version"] assert isinstance(db_obj.date_added, datetime.datetime) @@ -1171,7 +1181,7 @@ def test_add_objects( ], ) def test_get_object( - taxii2_sqldb_api, + taxii2_sqldb_api: Taxii2SQLDatabaseAPI, db_stix_objects, collection_id, object_id, @@ -1245,7 +1255,7 @@ def test_get_object( ], ) def test_delete_object( - taxii2_sqldb_api, + taxii2_sqldb_api: Taxii2SQLDatabaseAPI, db_stix_objects, collection_id, object_id, @@ -1260,7 +1270,7 @@ def test_delete_object( match_spec_version=match_spec_version, ) assert set( - (str(db_obj.collection_id), db_obj.id, db_obj.version) + (db_obj.collection_id, db_obj.id, db_obj.version) for db_obj in taxii2_sqldb_api.db.session.query(STIXObject).all() ) == set((obj.collection_id, obj.id, obj.version) for obj in expected_objects) @@ -1374,7 +1384,7 @@ def test_delete_object( ], ) def test_get_versions( - taxii2_sqldb_api, + taxii2_sqldb_api: Taxii2SQLDatabaseAPI, db_stix_objects, collection_id, object_id, @@ -1417,7 +1427,9 @@ def test_get_versions( ), ], ) -def test_next_param(taxii2_sqldb_api, stix_id, date_added, next_param): +def test_next_param( + taxii2_sqldb_api: Taxii2SQLDatabaseAPI, stix_id, date_added, next_param +): assert ( taxii2_sqldb_api.get_next_param({"id": stix_id, "date_added": date_added}) == next_param diff --git a/tests/taxii2/test_taxii2_status.py b/tests/taxii2/test_taxii2_status.py index f698d698..3cbacbc8 100644 --- a/tests/taxii2/test_taxii2_status.py +++ b/tests/taxii2/test_taxii2_status.py @@ -40,7 +40,7 @@ 200, {"Content-Type": "application/taxii+json;version=2.1"}, { - "id": JOBS[0].id, + "id": str(JOBS[0].id), "status": JOBS[0].status, "request_timestamp": taxii2_datetimeformat(JOBS[0].request_timestamp), "total_count": 4, @@ -83,7 +83,7 @@ 200, {"Content-Type": "application/taxii+json;version=2.1"}, { - "id": JOBS[3].id, + "id": str(JOBS[3].id), "status": JOBS[3].status, "request_timestamp": taxii2_datetimeformat(JOBS[3].request_timestamp), "total_count": 0, @@ -103,7 +103,7 @@ 200, {"Content-Type": "application/taxii+json;version=2.1"}, { - "id": JOBS[6].id, + "id": str(JOBS[6].id), "status": JOBS[6].status, "request_timestamp": taxii2_datetimeformat(JOBS[6].request_timestamp), "total_count": 6, diff --git a/tests/taxii2/utils.py b/tests/taxii2/utils.py index 322b1a03..733b1c1a 100644 --- a/tests/taxii2/utils.py +++ b/tests/taxii2/utils.py @@ -1,7 +1,7 @@ import base64 import datetime -from typing import Dict, List, Optional -from uuid import uuid4 +from typing import Dict, List, Optional, Tuple +from uuid import UUID, uuid4 from opentaxii.server import ServerMapping from opentaxii.taxii2.entities import ( @@ -16,28 +16,28 @@ from opentaxii.taxii2.utils import DATETIMEFORMAT, taxii2_datetimeformat API_ROOTS_WITH_DEFAULT = ( - ApiRoot(str(uuid4()), True, "first title", "first description", False), - ApiRoot(str(uuid4()), False, "second title", "second description", True), + ApiRoot(uuid4(), True, "first title", "first description", False), + ApiRoot(uuid4(), False, "second title", "second description", True), ) API_ROOTS_WITHOUT_DEFAULT = ( - ApiRoot(str(uuid4()), False, "first title", "first description", False), - ApiRoot(str(uuid4()), False, "second title", "second description", True), - ApiRoot(str(uuid4()), False, "third title", None, False), + ApiRoot(uuid4(), False, "first title", "first description", False), + ApiRoot(uuid4(), False, "second title", "second description", True), + ApiRoot(uuid4(), False, "third title", None, False), ) API_ROOTS = API_ROOTS_WITHOUT_DEFAULT NOW = datetime.datetime.now(datetime.timezone.utc) -JOBS = tuple() +JOBS: Tuple[Job, ...] = tuple() for api_root in API_ROOTS: JOBS = JOBS + ( Job( - str(uuid4()), + uuid4(), api_root.id, "complete", NOW, NOW - datetime.timedelta(hours=24, minutes=1), ), Job( - str(uuid4()), + uuid4(), api_root.id, "pending", NOW, @@ -46,7 +46,7 @@ ) JOBS = JOBS + ( Job( - str(uuid4()), + uuid4(), API_ROOTS[0].id, "pending", NOW, @@ -61,7 +61,7 @@ JOBS[0].details.success.extend( [ JobDetail( - id=str(uuid4()), + id=uuid4(), job_id=JOBS[0].id, stix_id="indicator--c410e480-e42b-47d1-9476-85307c12bcbf", version=datetime.datetime.strptime( @@ -76,7 +76,7 @@ JOBS[0].details.failure.extend( [ JobDetail( - id=str(uuid4()), + id=uuid4(), job_id=JOBS[0].id, stix_id="malware--664fa29d-bf65-4f28-a667-bdb76f29ec98", version=datetime.datetime.strptime( @@ -91,7 +91,7 @@ JOBS[0].details.pending.extend( [ JobDetail( - id=str(uuid4()), + id=uuid4(), job_id=JOBS[0].id, stix_id="indicator--252c7c11-daf2-42bd-843b-be65edca9f61", version=datetime.datetime.strptime( @@ -101,7 +101,7 @@ status="pending", ), JobDetail( - id=str(uuid4()), + id=uuid4(), job_id=JOBS[0].id, stix_id="relationship--045585ad-a22f-4333-af33-bfd503a683b5", version=datetime.datetime.strptime( @@ -117,7 +117,7 @@ COLLECTIONS = ( Collection( - str(uuid4()), + uuid4(), API_ROOTS[0].id, "0Read only", "Read only description", @@ -126,7 +126,7 @@ False, ), Collection( - str(uuid4()), + uuid4(), API_ROOTS[0].id, "1Write only", "Write only description", @@ -135,7 +135,7 @@ False, ), Collection( - str(uuid4()), + uuid4(), API_ROOTS[0].id, "2Read/Write", "Read/Write description", @@ -144,7 +144,7 @@ False, ), Collection( - str(uuid4()), + uuid4(), API_ROOTS[0].id, "3No permissions", "No permissions description", @@ -152,11 +152,9 @@ False, False, ), + Collection(uuid4(), API_ROOTS[0].id, "4No description", "", None, False, False), Collection( - str(uuid4()), API_ROOTS[0].id, "4No description", "", None, False, False - ), - Collection( - str(uuid4()), + uuid4(), API_ROOTS[0].id, "5With alias", "With alias description", @@ -165,7 +163,7 @@ False, ), Collection( - str(uuid4()), + uuid4(), API_ROOTS[0].id, "6Public", "public description", @@ -174,7 +172,7 @@ False, ), Collection( - str(uuid4()), + uuid4(), API_ROOTS[0].id, "7Publicwrite", "public write description", @@ -183,7 +181,7 @@ True, ), ) -STIX_OBJECTS = ( +STIX_OBJECTS: tuple[STIXObject, ...] = ( STIXObject( f"indicator--{str(uuid4())}", COLLECTIONS[5].id, @@ -288,7 +286,7 @@ def process_match_version(match_version): if match_version is None: match_version = ["last"] - versions_per_id = {} + versions_per_id: dict = {} for stix_obj in STIX_OBJECTS: if stix_obj.id not in versions_per_id: versions_per_id[stix_obj.id] = [] @@ -311,14 +309,14 @@ def process_match_version(match_version): return id_version_combos -def GET_API_ROOT_MOCK(api_root_id): +def GET_API_ROOT_MOCK(api_root_id: UUID): for api_root in API_ROOTS: if api_root.id == api_root_id: return api_root return None -def GET_JOB_AND_DETAILS_MOCK(api_root_id, job_id): +def GET_JOB_AND_DETAILS_MOCK(api_root_id: UUID, job_id: UUID): job_response = None for job in JOBS: if job.api_root_id == api_root_id and job.id == job_id: @@ -327,7 +325,7 @@ def GET_JOB_AND_DETAILS_MOCK(api_root_id, job_id): return job_response -def GET_COLLECTIONS_MOCK(api_root_id): +def GET_COLLECTIONS_MOCK(api_root_id: UUID): response = [] for collection in COLLECTIONS: if collection.api_root_id == api_root_id: @@ -335,10 +333,10 @@ def GET_COLLECTIONS_MOCK(api_root_id): return response -def GET_COLLECTION_MOCK(api_root_id, collection_id_or_alias): +def GET_COLLECTION_MOCK(api_root_id: UUID, collection_id_or_alias: str): for collection in COLLECTIONS: if collection.api_root_id == api_root_id and ( - collection.id == collection_id_or_alias + str(collection.id) == collection_id_or_alias or collection.alias == collection_id_or_alias ): return collection @@ -352,7 +350,7 @@ def STIX_OBJECT_FROM_MANIFEST(stix_id): def GET_MANIFEST_MOCK( - collection_id: str, + collection_id: UUID, limit: Optional[int] = None, added_after: Optional[datetime.datetime] = None, next_kwargs: Optional[Dict] = None, @@ -385,7 +383,7 @@ def GET_NEXT_PARAM(kwargs: Dict) -> str: def GET_OBJECTS_MOCK( - collection_id: str, + collection_id: UUID, limit: Optional[int] = None, added_after: Optional[datetime.datetime] = None, next_kwargs: Optional[Dict] = None, @@ -395,7 +393,7 @@ def GET_OBJECTS_MOCK( match_spec_version: Optional[List[str]] = None, ): id_version_combos = process_match_version(match_version) - response = [] + response: List = [] more = False for stix_object in STIX_OBJECTS: if ( @@ -434,7 +432,7 @@ def GET_OBJECTS_MOCK( def GET_OBJECT_MOCK( - collection_id: str, + collection_id: UUID, object_id: str, limit: Optional[int] = None, added_after: Optional[datetime.datetime] = None, @@ -443,7 +441,7 @@ def GET_OBJECT_MOCK( match_spec_version: Optional[List[str]] = None, ): id_version_combos = process_match_version(match_version) - response = [] + response: List = [] more = False at_least_one = False for stix_object in STIX_OBJECTS: @@ -479,16 +477,16 @@ def GET_OBJECT_MOCK( else: next_param = None if not at_least_one: - response = None + response = None # type: ignore[assignment] return response, more, next_param -def ADD_OBJECTS_MOCK(api_root_id: str, collection_id: str, objects: List[Dict]): +def ADD_OBJECTS_MOCK(api_root_id: UUID, collection_id: str, objects: List[Dict]): return JOBS[0] def DELETE_OBJECT_MOCK( - collection_id: str, + collection_id: UUID, object_id: str, match_version: Optional[List[str]] = None, match_spec_version: Optional[List[str]] = None, @@ -497,7 +495,7 @@ def DELETE_OBJECT_MOCK( def GET_VERSIONS_MOCK( - collection_id: str, + collection_id: UUID, object_id: str, limit: Optional[int] = None, added_after: Optional[datetime.datetime] = None, diff --git a/tests/test_auth.py b/tests/test_auth.py index 9b0336e7..81a4bfb8 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -2,15 +2,16 @@ import json import pytest -from fixtures import VID_TAXII_HTTP_10 from libtaxii import messages_10 as tm10 from libtaxii import messages_11 as tm11 from libtaxii.constants import CB_STIX_XML_111, RT_FULL, ST_BAD_MESSAGE, ST_UNAUTHORIZED -from utils import as_tm, is_headers_valid, prepare_headers from opentaxii.taxii.http import HTTP_AUTHORIZATION from opentaxii.utils import sync_conf_dict_into_db +from .fixtures import VID_TAXII_HTTP_10 +from .utils import as_tm, is_headers_valid, prepare_headers + INBOX_OPEN = dict( id='inbox-A', type='inbox', diff --git a/tests/test_cli.py b/tests/test_cli.py index dcb40721..16b9bbf2 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -405,13 +405,13 @@ def test_add_api_root( ["argv", "raises", "message", "stdout", "stderr", "expected_call"], [ pytest.param( - ["-r", API_ROOTS[0].id, "-t", "my new collection"], # argv + ["-r", str(API_ROOTS[0].id), "-t", "my new collection"], # argv False, # raises None, # message "", # stdout "", # stderr { - "api_root_id": API_ROOTS[0].id, + "api_root_id": str(API_ROOTS[0].id), "title": "my new collection", "description": None, "alias": None, @@ -423,7 +423,7 @@ def test_add_api_root( pytest.param( [ "-r", - API_ROOTS[0].id, + str(API_ROOTS[0].id), "-t", "my new collection", "-d", @@ -434,7 +434,7 @@ def test_add_api_root( "", # stdout "", # stderr { - "api_root_id": API_ROOTS[0].id, + "api_root_id": str(API_ROOTS[0].id), "title": "my new collection", "description": "my description", "alias": None, @@ -446,7 +446,7 @@ def test_add_api_root( pytest.param( [ "-r", - API_ROOTS[0].id, + str(API_ROOTS[0].id), "-t", "my new collection", "-d", @@ -459,7 +459,7 @@ def test_add_api_root( "", # stdout "", # stderr { - "api_root_id": API_ROOTS[0].id, + "api_root_id": str(API_ROOTS[0].id), "title": "my new collection", "description": "my description", "alias": "my-alias", @@ -469,13 +469,13 @@ def test_add_api_root( id="rootid, title, description, alias", ), pytest.param( - ["-r", API_ROOTS[0].id, "-t", "my new collection", "--public"], # argv + ["-r", str(API_ROOTS[0].id), "-t", "my new collection", "--public"], # argv False, # raises None, # message "", # stdout "", # stderr { - "api_root_id": API_ROOTS[0].id, + "api_root_id": str(API_ROOTS[0].id), "title": "my new collection", "description": None, "alias": None, @@ -487,7 +487,7 @@ def test_add_api_root( pytest.param( [ "-r", - API_ROOTS[0].id, + str(API_ROOTS[0].id), "-t", "my new collection", "--public-write", @@ -497,7 +497,7 @@ def test_add_api_root( "", # stdout "", # stderr { - "api_root_id": API_ROOTS[0].id, + "api_root_id": str(API_ROOTS[0].id), "title": "my new collection", "description": None, "alias": None, @@ -560,7 +560,7 @@ def test_add_collection( ) stderr = stderr.replace( "ROOTIDS", - ",".join([api_root.id for api_root in db_api_roots]), + ",".join([str(api_root.id) for api_root in db_api_roots]), ) with ( mock.patch("opentaxii.cli.persistence.app", app), diff --git a/tests/test_config.py b/tests/test_config.py index 54095f34..33945ed4 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,5 +1,6 @@ import os import tempfile +from typing import Tuple import pytest @@ -169,7 +170,7 @@ def test_custom_config_file(config_file_name_expected_value): deprecation_warning, taxii2_only_warning, ) = config_file_name_expected_value - warning_classes = (UserWarning,) + warning_classes: Tuple = (UserWarning,) if deprecation_warning or taxii2_only_warning: warning_classes += (DeprecationWarning,) expected_warnings = {"Ignoring invalid configuration item 'dummy'."} diff --git a/tests/test_delete_content_blocks.py b/tests/test_delete_content_blocks.py index 19aca42d..0538ea85 100644 --- a/tests/test_delete_content_blocks.py +++ b/tests/test_delete_content_blocks.py @@ -1,7 +1,8 @@ import datetime import pytest -from fixtures import COLLECTION_OPEN, COLLECTIONS_A + +from .fixtures import COLLECTION_OPEN, COLLECTIONS_A @pytest.mark.parametrize("with_messages", [True, False]) diff --git a/tests/test_http.py b/tests/test_http.py index 35f01b58..57682198 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -1,10 +1,11 @@ import pytest from libtaxii.constants import ST_BAD_MESSAGE, ST_FAILURE -from utils import as_tm, is_headers_valid, prepare_headers from opentaxii.taxii.converters import dict_to_service_entity from opentaxii.taxii.http import HTTP_X_TAXII_SERVICES +from .utils import as_tm, is_headers_valid, prepare_headers + INBOX = dict( id='inbox-A', type='inbox', @@ -38,7 +39,7 @@ ) SERVICES = [INBOX, DISCOVERY, DISCOVERY_NOT_AVAILABLE] -INSTANCES_CONFIGURED = sum(len(s['protocol_bindings']) for s in SERVICES) +INSTANCES_CONFIGURED = sum(len(s['protocol_bindings']) for s in SERVICES) # type: ignore MESSAGE_ID = '123' diff --git a/tests/test_server.py b/tests/test_server.py index 31095456..86a19f22 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,13 +1,14 @@ import concurrent.futures import pytest -from fixtures import DOMAIN from opentaxii.persistence import OpenTAXII2PersistenceAPI, Taxii2PersistenceManager from opentaxii.persistence.sqldb import Taxii2SQLDatabaseAPI from opentaxii.server import TAXII2Server from opentaxii.taxii.converters import dict_to_service_entity +from .fixtures import DOMAIN + INBOX = dict( id='inbox-A', type='inbox', diff --git a/tests/utils.py b/tests/utils.py index 34b639e8..da3e5dc8 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,7 +1,6 @@ import re import pytest -from fixtures import CB_STIX_XML_111, CONTENT, MESSAGE, MESSAGE_ID from libtaxii import messages_10 as tm10 from libtaxii import messages_11 as tm11 @@ -16,6 +15,8 @@ ) from opentaxii.taxii.utils import get_utc_now +from .fixtures import CB_STIX_XML_111, CONTENT, MESSAGE, MESSAGE_ID + JWT_RE = re.compile(r'[A-Za-z0-9-_=]+\.[A-Za-z0-9-_=]+\.?[A-Za-z0-9-_.+/=]*') diff --git a/tox.ini b/tox.ini index 058bbe73..49c718da 100644 --- a/tox.ini +++ b/tox.ini @@ -12,10 +12,14 @@ python = commands = py.test --cov {envsitepackagesdir}/opentaxii {posargs} deps = - sqlite: -rrequirements-dev.txt - mysql,mariadb: -rrequirements-dev-mysql.txt - postgres-!pypy3: -rrequirements-dev-postgres.txt - postgres-pypy3: -rrequirements-dev-postgres-pypy.txt + !pypy3: -rrequirements-dev.txt + # does not support recent mypy + # https://github.com/python/mypy/issues/20329 + pypy3: -rrequirements-test.txt + pypy3: -rrequirements.txt + mysql,mariadb: -rrequirements-mysql.txt + postgres-!pypy3: -rrequirements-postgres.txt + postgres-pypy3: -rrequirements-postgres-pypy.txt sqlalchemy14: sqlalchemy>=1.4,<1.5 werkzeuglt21: werkzeug<2.1 werkzeuggte21: werkzeug>=2.1