diff --git a/.github/workflows/test-python.yml b/.github/workflows/test-python.yml index 414eef7a1b..8dad68ab20 100644 --- a/.github/workflows/test-python.yml +++ b/.github/workflows/test-python.yml @@ -67,7 +67,18 @@ jobs: mypy --install-types --non-interactive bson/codec_options.py mypy --install-types --non-interactive --disable-error-code var-annotated --disable-error-code attr-defined --disable-error-code union-attr --disable-error-code assignment --disable-error-code no-redef --disable-error-code index --allow-redefinition --allow-untyped-globals --exclude "test/mypy_fails/*.*" test python -m pip install -U typing_extensions - mypy --install-types --non-interactive test/test_mypy.py + mypy --install-types --non-interactive test/test_typing.py test/test_typing_strict.py + - name: Run mypy strict + run: | + mypy --strict test/test_typing_strict.py + - name: Run pyright + run: | + python -m pip install -U pip pyright==1.1.290 + pyright test/test_typing.py test/test_typing_strict.py + - name: Run pyright strict + run: | + echo '{"strict": ["test/test_typing_strict.py"]}' >> pyrightconfig.json + pyright test/test_typing_strict.py linkcheck: name: Check Links diff --git a/bson/__init__.py b/bson/__init__.py index c6a81d97ec..2fe4aa173e 100644 --- a/bson/__init__.py +++ b/bson/__init__.py @@ -101,7 +101,6 @@ DEFAULT_CODEC_OPTIONS, CodecOptions, DatetimeConversion, - _DocumentType, _raw_document_class, ) from bson.datetime_ms import ( @@ -125,8 +124,7 @@ # Import some modules for type-checking only. 
if TYPE_CHECKING: - from array import array - from mmap import mmap + from bson.typings import _DocumentIn, _DocumentType, _ReadableBuffer try: from bson import _cbson # type: ignore[attr-defined] @@ -986,12 +984,8 @@ def _dict_to_bson(doc: Any, check_keys: bool, opts: CodecOptions, top_level: boo _CODEC_OPTIONS_TYPE_ERROR = TypeError("codec_options must be an instance of CodecOptions") -_DocumentIn = Mapping[str, Any] -_ReadableBuffer = Union[bytes, memoryview, "mmap", "array"] - - def encode( - document: _DocumentIn, + document: "_DocumentIn", check_keys: bool = False, codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS, ) -> bytes: @@ -1022,8 +1016,8 @@ def encode( def decode( - data: _ReadableBuffer, codec_options: "Optional[CodecOptions[_DocumentType]]" = None -) -> _DocumentType: + data: "_ReadableBuffer", codec_options: "Optional[CodecOptions[_DocumentType]]" = None +) -> "_DocumentType": """Decode BSON to a document. By default, returns a BSON document represented as a Python @@ -1056,11 +1050,13 @@ def decode( return _bson_to_dict(data, opts) -def _decode_all(data: _ReadableBuffer, opts: "CodecOptions[_DocumentType]") -> List[_DocumentType]: +def _decode_all( + data: "_ReadableBuffer", opts: "CodecOptions[_DocumentType]" +) -> "List[_DocumentType]": """Decode a BSON data to multiple documents.""" data, view = get_data_and_view(data) data_len = len(data) - docs: List[_DocumentType] = [] + docs: "List[_DocumentType]" = [] position = 0 end = data_len - 1 use_raw = _raw_document_class(opts.document_class) @@ -1091,8 +1087,8 @@ def _decode_all(data: _ReadableBuffer, opts: "CodecOptions[_DocumentType]") -> L def decode_all( - data: _ReadableBuffer, codec_options: "Optional[CodecOptions[_DocumentType]]" = None -) -> List[_DocumentType]: + data: "_ReadableBuffer", codec_options: "Optional[CodecOptions[_DocumentType]]" = None +) -> "List[_DocumentType]": """Decode BSON data to multiple documents. 
`data` must be a bytes-like object implementing the buffer protocol that @@ -1213,7 +1209,7 @@ def _decode_all_selective(data: Any, codec_options: CodecOptions, fields: Any) - # Decode documents for internal use. from bson.raw_bson import RawBSONDocument - internal_codec_options = codec_options.with_options( + internal_codec_options: CodecOptions[RawBSONDocument] = codec_options.with_options( document_class=RawBSONDocument, type_registry=None ) _doc = _bson_to_dict(data, internal_codec_options) @@ -1228,7 +1224,7 @@ def _decode_all_selective(data: Any, codec_options: CodecOptions, fields: Any) - def decode_iter( data: bytes, codec_options: "Optional[CodecOptions[_DocumentType]]" = None -) -> Iterator[_DocumentType]: +) -> "Iterator[_DocumentType]": """Decode BSON data to multiple documents as a generator. Works similarly to the decode_all function, but yields one document at a @@ -1264,7 +1260,7 @@ def decode_iter( def decode_file_iter( file_obj: Union[BinaryIO, IO], codec_options: "Optional[CodecOptions[_DocumentType]]" = None -) -> Iterator[_DocumentType]: +) -> "Iterator[_DocumentType]": """Decode bson data from a file to multiple documents as a generator. Works similarly to the decode_all function, but reads from the file object @@ -1325,7 +1321,7 @@ class BSON(bytes): @classmethod def encode( cls: Type["BSON"], - document: _DocumentIn, + document: "_DocumentIn", check_keys: bool = False, codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS, ) -> "BSON": @@ -1352,7 +1348,7 @@ def encode( """ return cls(encode(document, check_keys, codec_options)) - def decode(self, codec_options: "CodecOptions[_DocumentType]" = DEFAULT_CODEC_OPTIONS) -> _DocumentType: # type: ignore[override,assignment] + def decode(self, codec_options: "CodecOptions[_DocumentType]" = DEFAULT_CODEC_OPTIONS) -> "_DocumentType": # type: ignore[override,assignment] """Decode this BSON data. 
By default, returns a BSON document represented as a Python diff --git a/bson/codec_options.pyi b/bson/codec_options.pyi index 2424516f08..8242bd4cb2 100644 --- a/bson/codec_options.pyi +++ b/bson/codec_options.pyi @@ -22,7 +22,8 @@ you get the error: "TypeError: 'type' object is not subscriptable". import datetime import abc import enum -from typing import Tuple, Generic, Optional, Mapping, Any, TypeVar, Type, Dict, Iterable, Tuple, MutableMapping, Callable, Union +from typing import Tuple, Generic, Optional, Mapping, Any, Type, Dict, Iterable, Tuple, Callable, Union +from bson.typings import _DocumentType, _DocumentTypeArg class TypeEncoder(abc.ABC, metaclass=abc.ABCMeta): @@ -52,9 +53,6 @@ class TypeRegistry: def __init__(self, type_codecs: Optional[Iterable[Codec]] = ..., fallback_encoder: Optional[Fallback] = ...) -> None: ... def __eq__(self, other: Any) -> Any: ... - -_DocumentType = TypeVar("_DocumentType", bound=Mapping[str, Any]) - class DatetimeConversion(int, enum.Enum): DATETIME = ... DATETIME_CLAMP = ... @@ -82,7 +80,7 @@ class CodecOptions(Tuple, Generic[_DocumentType]): ) -> CodecOptions[_DocumentType]: ... # CodecOptions API - def with_options(self, **kwargs: Any) -> CodecOptions[_DocumentType]: ... + def with_options(self, **kwargs: Any) -> CodecOptions[_DocumentTypeArg]: ... def _arguments_repr(self) -> str: ... @@ -100,7 +98,7 @@ class CodecOptions(Tuple, Generic[_DocumentType]): _fields: Tuple[str] -DEFAULT_CODEC_OPTIONS: CodecOptions[MutableMapping[str, Any]] +DEFAULT_CODEC_OPTIONS: "CodecOptions[Mapping[str, Any]]" _RAW_BSON_DOCUMENT_MARKER: int def _raw_document_class(document_class: Any) -> bool: ... diff --git a/bson/typings.py b/bson/typings.py new file mode 100644 index 0000000000..14a8131f69 --- /dev/null +++ b/bson/typings.py @@ -0,0 +1,30 @@ +# Copyright 2023-Present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Type aliases used by bson""" +from typing import TYPE_CHECKING, Any, Mapping, MutableMapping, TypeVar, Union + +if TYPE_CHECKING: + from array import array + from mmap import mmap + + from bson.raw_bson import RawBSONDocument + + +# Common Shared Types. +_DocumentIn = Union[MutableMapping[str, Any], "RawBSONDocument"] +_DocumentOut = _DocumentIn +_DocumentType = TypeVar("_DocumentType", bound=Mapping[str, Any]) +_DocumentTypeArg = TypeVar("_DocumentTypeArg", bound=Mapping[str, Any]) +_ReadableBuffer = Union[bytes, memoryview, "mmap", "array"] diff --git a/doc/examples/type_hints.rst b/doc/examples/type_hints.rst index b413ad7b24..e5ad3338e1 100644 --- a/doc/examples/type_hints.rst +++ b/doc/examples/type_hints.rst @@ -20,7 +20,7 @@ type of document object returned when decoding BSON documents. Due to `limitations in mypy`_, the default values for generic document types are not yet provided (they will eventually be ``Dict[str, any]``). -For a larger set of examples that use types, see the PyMongo `test_mypy module`_. +For a larger set of examples that use types, see the PyMongo `test_typing module`_. If you would like to opt out of using the provided types, add the following to your `mypy config`_: :: @@ -326,5 +326,5 @@ Another example is trying to set a value on a :class:`~bson.raw_bson.RawBSONDocu .. _mypy: https://mypy.readthedocs.io/en/stable/cheat_sheet_py3.html .. _limitations in mypy: https://github.com/python/mypy/issues/3737 .. _mypy config: https://mypy.readthedocs.io/en/stable/config_file.html -.. 
_test_mypy module: https://github.com/mongodb/mongo-python-driver/blob/master/test/test_mypy.py +.. _test_typing module: https://github.com/mongodb/mongo-python-driver/blob/master/test/test_typing.py .. _schema validation: https://www.mongodb.com/docs/manual/core/schema-validation/#when-to-use-schema-validation diff --git a/mypy.ini b/mypy.ini index 2562177ab1..d0e6ab5ff9 100644 --- a/mypy.ini +++ b/mypy.ini @@ -32,7 +32,7 @@ ignore_missing_imports = True [mypy-snappy.*] ignore_missing_imports = True -[mypy-test.test_mypy] +[mypy-test.test_typing] warn_unused_ignores = True [mypy-winkerberos.*] diff --git a/pymongo/collection.py b/pymongo/collection.py index 77f154f5e7..4cb3fa79c9 100644 --- a/pymongo/collection.py +++ b/pymongo/collection.py @@ -72,7 +72,7 @@ InsertOneResult, UpdateResult, ) -from pymongo.typings import _CollationIn, _DocumentType, _Pipeline +from pymongo.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline from pymongo.write_concern import WriteConcern _FIND_AND_MODIFY_DOC_FIELDS = {"value": 1} @@ -103,6 +103,7 @@ class ReturnDocument(object): if TYPE_CHECKING: + import bson from pymongo.client_session import ClientSession from pymongo.database import Database from pymongo.read_concern import ReadConcern @@ -116,7 +117,7 @@ def __init__( database: "Database[_DocumentType]", name: str, create: Optional[bool] = False, - codec_options: Optional[CodecOptions] = None, + codec_options: Optional["CodecOptions[_DocumentTypeArg]"] = None, read_preference: Optional[_ServerMode] = None, write_concern: Optional[WriteConcern] = None, read_concern: Optional["ReadConcern"] = None, @@ -394,7 +395,7 @@ def database(self) -> "Database[_DocumentType]": def with_options( self, - codec_options: Optional[CodecOptions] = None, + codec_options: Optional["bson.CodecOptions[_DocumentTypeArg]"] = None, read_preference: Optional[_ServerMode] = None, write_concern: Optional[WriteConcern] = None, read_concern: Optional["ReadConcern"] = None, diff --git 
a/pymongo/database.py b/pymongo/database.py index 259c22d558..86754b2c05 100644 --- a/pymongo/database.py +++ b/pymongo/database.py @@ -29,7 +29,7 @@ cast, ) -from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions +from bson.codec_options import DEFAULT_CODEC_OPTIONS from bson.dbref import DBRef from bson.son import SON from bson.timestamp import Timestamp @@ -41,7 +41,7 @@ from pymongo.common import _ecc_coll_name, _ecoc_coll_name, _esc_coll_name from pymongo.errors import CollectionInvalid, InvalidName from pymongo.read_preferences import ReadPreference, _ServerMode -from pymongo.typings import _CollationIn, _DocumentType, _Pipeline +from pymongo.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline def _check_name(name): @@ -55,6 +55,7 @@ def _check_name(name): if TYPE_CHECKING: + import bson import bson.codec_options from pymongo.client_session import ClientSession from pymongo.mongo_client import MongoClient @@ -72,7 +73,7 @@ def __init__( self, client: "MongoClient[_DocumentType]", name: str, - codec_options: Optional[CodecOptions] = None, + codec_options: Optional["bson.CodecOptions[_DocumentTypeArg]"] = None, read_preference: Optional[_ServerMode] = None, write_concern: Optional["WriteConcern"] = None, read_concern: Optional["ReadConcern"] = None, @@ -152,7 +153,7 @@ def name(self) -> str: def with_options( self, - codec_options: Optional[CodecOptions] = None, + codec_options: Optional["bson.CodecOptions[_DocumentTypeArg]"] = None, read_preference: Optional[_ServerMode] = None, write_concern: Optional["WriteConcern"] = None, read_concern: Optional["ReadConcern"] = None, @@ -239,7 +240,7 @@ def __getitem__(self, name: str) -> "Collection[_DocumentType]": def get_collection( self, name: str, - codec_options: Optional[CodecOptions] = None, + codec_options: Optional["bson.CodecOptions[_DocumentTypeArg]"] = None, read_preference: Optional[_ServerMode] = None, write_concern: Optional["WriteConcern"] = None, read_concern: 
Optional["ReadConcern"] = None, @@ -295,7 +296,7 @@ def get_collection( def create_collection( self, name: str, - codec_options: Optional[CodecOptions] = None, + codec_options: Optional["bson.CodecOptions[_DocumentTypeArg]"] = None, read_preference: Optional[_ServerMode] = None, write_concern: Optional["WriteConcern"] = None, read_concern: Optional["ReadConcern"] = None, @@ -976,7 +977,7 @@ def _drop_helper(self, name, session=None, comment=None): @_csot.apply def drop_collection( self, - name_or_collection: Union[str, Collection], + name_or_collection: Union[str, Collection[_DocumentTypeArg]], session: Optional["ClientSession"] = None, comment: Optional[Any] = None, encrypted_fields: Optional[Mapping[str, Any]] = None, @@ -1068,7 +1069,7 @@ def drop_collection( def validate_collection( self, - name_or_collection: Union[str, Collection], + name_or_collection: Union[str, Collection[_DocumentTypeArg]], scandata: bool = False, full: bool = False, session: Optional["ClientSession"] = None, diff --git a/pymongo/message.py b/pymongo/message.py index 960832cb9e..9fa64a875a 100644 --- a/pymongo/message.py +++ b/pymongo/message.py @@ -24,7 +24,7 @@ import random import struct from io import BytesIO as _BytesIO -from typing import Any, Dict, NoReturn +from typing import Any, Mapping, NoReturn import bson from bson import CodecOptions, _decode_selective, _dict_to_bson, _make_c_string, encode @@ -81,7 +81,7 @@ } _FIELD_MAP = {"insert": "documents", "update": "updates", "delete": "deletes"} -_UNICODE_REPLACE_CODEC_OPTIONS: "CodecOptions[Dict[str, Any]]" = CodecOptions( +_UNICODE_REPLACE_CODEC_OPTIONS: "CodecOptions[Mapping[str, Any]]" = CodecOptions( unicode_decode_error_handler="replace" ) diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py index dccd4bb6b1..ab0c749889 100644 --- a/pymongo/mongo_client.py +++ b/pymongo/mongo_client.py @@ -53,7 +53,8 @@ cast, ) -from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, TypeRegistry +import bson +from 
bson.codec_options import DEFAULT_CODEC_OPTIONS, TypeRegistry from bson.son import SON from bson.timestamp import Timestamp from pymongo import ( @@ -90,7 +91,13 @@ from pymongo.settings import TopologySettings from pymongo.topology import Topology, _ErrorContext from pymongo.topology_description import TOPOLOGY_TYPE, TopologyDescription -from pymongo.typings import _Address, _CollationIn, _DocumentType, _Pipeline +from pymongo.typings import ( + _Address, + _CollationIn, + _DocumentType, + _DocumentTypeArg, + _Pipeline, +) from pymongo.uri_parser import ( _check_options, _handle_option_deprecations, @@ -1875,7 +1882,7 @@ def list_database_names( @_csot.apply def drop_database( self, - name_or_database: Union[str, database.Database], + name_or_database: Union[str, database.Database[_DocumentTypeArg]], session: Optional[client_session.ClientSession] = None, comment: Optional[Any] = None, ) -> None: @@ -1928,7 +1935,7 @@ def drop_database( def get_default_database( self, default: Optional[str] = None, - codec_options: Optional[CodecOptions] = None, + codec_options: Optional["bson.CodecOptions[_DocumentTypeArg]"] = None, read_preference: Optional[_ServerMode] = None, write_concern: Optional[WriteConcern] = None, read_concern: Optional["ReadConcern"] = None, @@ -1989,7 +1996,7 @@ def get_default_database( def get_database( self, name: Optional[str] = None, - codec_options: Optional[CodecOptions] = None, + codec_options: Optional["bson.CodecOptions[_DocumentTypeArg]"] = None, read_preference: Optional[_ServerMode] = None, write_concern: Optional[WriteConcern] = None, read_concern: Optional["ReadConcern"] = None, diff --git a/pymongo/typings.py b/pymongo/typings.py index fe0e8bd523..32cd980c97 100644 --- a/pymongo/typings.py +++ b/pymongo/typings.py @@ -13,30 +13,18 @@ # limitations under the License. 
"""Type aliases used by PyMongo""" -from typing import ( - TYPE_CHECKING, - Any, - Mapping, - MutableMapping, - Optional, - Sequence, - Tuple, - TypeVar, - Union, -) +from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, Tuple, Union + +from bson.typings import _DocumentIn, _DocumentOut, _DocumentType, _DocumentTypeArg if TYPE_CHECKING: - from bson.raw_bson import RawBSONDocument from pymongo.collation import Collation # Common Shared Types. _Address = Tuple[str, Optional[int]] _CollationIn = Union[Mapping[str, Any], "Collation"] -_DocumentIn = Union[MutableMapping[str, Any], "RawBSONDocument"] _Pipeline = Sequence[Mapping[str, Any]] -_DocumentOut = _DocumentIn -_DocumentType = TypeVar("_DocumentType", bound=Mapping[str, Any]) def strip_optional(elem): @@ -44,3 +32,15 @@ def strip_optional(elem): while inside a list comprehension.""" assert elem is not None return elem + + +__all__ = [ + "_DocumentIn", + "_DocumentOut", + "_DocumentType", + "_DocumentTypeArg", + "_Address", + "_CollationIn", + "_Pipeline", + "strip_optional", +] diff --git a/test/test_database.py b/test/test_database.py index 53af4912e4..b6be380aab 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -435,7 +435,7 @@ def test_id_ordering(self): db.test.insert_one(SON([("hello", "world"), ("_id", 5)])) db = self.client.get_database( - "pymongo_test", codec_options=CodecOptions(document_class=SON) + "pymongo_test", codec_options=CodecOptions(document_class=SON[str, Any]) ) cursor = db.test.find() for x in cursor: @@ -469,7 +469,7 @@ def test_deref_kwargs(self): db.test.insert_one({"_id": 4, "foo": "bar"}) db = self.client.get_database( - "pymongo_test", codec_options=CodecOptions(document_class=SON) + "pymongo_test", codec_options=CodecOptions(document_class=SON[str, Any]) ) self.assertEqual( SON([("foo", "bar")]), db.dereference(DBRef("test", 4), projection={"_id": False}) diff --git a/test/test_mypy.py b/test/test_typing.py similarity index 98% rename from 
test/test_mypy.py rename to test/test_typing.py index 3b29bbf20e..8fc0f5a23e 100644 --- a/test/test_mypy.py +++ b/test/test_typing.py @@ -422,7 +422,8 @@ def test_typeddict_not_required_document_type(self) -> None: assert out is not None # This should fail because the output is a Movie. assert out["foo"] # type:ignore[typeddict-item] - assert out["_id"] + # pyright gives reportTypedDictNotRequiredAccess for the following: + assert out["_id"] # type:ignore @only_type_check def test_typeddict_empty_document_type(self) -> None: @@ -442,7 +443,8 @@ def test_typeddict_find_notrequired(self): coll.insert_one(ImplicitMovie(name="THX-1138", year=1971)) out = coll.find_one({}) assert out is not None - assert out["_id"] + # pyright gives reportTypedDictNotRequiredAccess for the following: + assert out["_id"] # type:ignore @only_type_check def test_raw_bson_document_type(self) -> None: diff --git a/test/test_typing_strict.py b/test/test_typing_strict.py new file mode 100644 index 0000000000..55cb1454bc --- /dev/null +++ b/test/test_typing_strict.py @@ -0,0 +1,38 @@ +# Copyright 2023-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Test typings in strict mode.""" +import unittest +from typing import TYPE_CHECKING, Any, Dict + +import pymongo +from pymongo.collection import Collection +from pymongo.database import Database + + +def test_generic_arguments() -> None: + """Ensure known usages of generic arguments pass strict typing""" + if not TYPE_CHECKING: + raise unittest.SkipTest("Used for Type Checking Only") + mongo_client: pymongo.MongoClient[Dict[str, Any]] = pymongo.MongoClient() + mongo_client.drop_database("foo") + mongo_client.get_default_database() + db = mongo_client.get_database("test_db") + db = Database(mongo_client, "test_db") + db.with_options() + db.validate_collection("py_test") + col = db.get_collection("py_test") + col.insert_one({"abc": 123}) + col = Collection(db, "py_test") + col.with_options()