diff --git a/pymongo/change_stream.py b/pymongo/change_stream.py index ef3573022d..80820dff91 100644 --- a/pymongo/change_stream.py +++ b/pymongo/change_stream.py @@ -20,7 +20,7 @@ from bson import _bson_to_dict from bson.raw_bson import RawBSONDocument from bson.timestamp import Timestamp -from pymongo import common +from pymongo import _csot, common from pymongo.aggregation import ( _CollectionAggregationCommand, _DatabaseAggregationCommand, @@ -128,6 +128,8 @@ def __init__( self._start_at_operation_time = start_at_operation_time self._session = session self._comment = comment + self._closed = False + self._timeout = self._target._timeout # Initialize cursor. self._cursor = self._create_cursor() @@ -234,6 +236,7 @@ def _resume(self): def close(self) -> None: """Close this ChangeStream.""" + self._closed = True self._cursor.close() def __iter__(self) -> "ChangeStream[_DocumentType]": @@ -248,6 +251,7 @@ def resume_token(self) -> Optional[Mapping[str, Any]]: """ return copy.deepcopy(self._resume_token) + @_csot.apply def next(self) -> _DocumentType: """Advance the cursor. @@ -298,8 +302,9 @@ def alive(self) -> bool: .. versionadded:: 3.8 """ - return self._cursor.alive + return not self._closed + @_csot.apply def try_next(self) -> Optional[_DocumentType]: """Advance the cursor without blocking indefinitely. @@ -332,6 +337,9 @@ def try_next(self) -> Optional[_DocumentType]: .. versionadded:: 3.8 """ + if not self._closed and not self._cursor.alive: + self._resume() + # Attempt to get the next change with at most one getMore and at most # one resume attempt. try: @@ -350,6 +358,10 @@ def try_next(self) -> Optional[_DocumentType]: self._resume() change = self._cursor._try_next(False) + # Check if the cursor was invalidated. + if not self._cursor.alive: + self._closed = True + # If no changes are available. if change is None: # We have either iterated over all documents in the cursor, diff --git a/test/test_change_stream.py b/test/test_change_stream.py index f3f206d965..11ed2895ac 100644 --- a/test/test_change_stream.py +++ b/test/test_change_stream.py @@ -486,7 +486,7 @@ def _get_expected_resume_token(self, stream, listener, previous_change=None): return response["cursor"]["postBatchResumeToken"] @no_type_check - def _test_raises_error_on_missing_id(self, expected_exception): + def _test_raises_error_on_missing_id(self, expected_exception, expected_exception2): """ChangeStream will raise an exception if the server response is missing the resume token. """ @@ -494,8 +494,7 @@ def _test_raises_error_on_missing_id(self, expected_exception): self.watched_collection().insert_one({}) with self.assertRaises(expected_exception): next(change_stream) - # The cursor should now be closed. - with self.assertRaises(StopIteration): + with self.assertRaises(expected_exception2): next(change_stream) @no_type_check @@ -525,17 +524,16 @@ def test_update_resume_token_legacy(self): self._test_update_resume_token(self._get_expected_resume_token_legacy) # Prose test no. 2 - @client_context.require_version_max(4, 3, 3) # PYTHON-2120 @client_context.require_version_min(4, 1, 8) def test_raises_error_on_missing_id_418plus(self): - # Server returns an error on 4.1.8+ - self._test_raises_error_on_missing_id(OperationFailure) + # Server returns an error on 4.1.8+, subsequent next() resumes and gets the same error. + self._test_raises_error_on_missing_id(OperationFailure, OperationFailure) # Prose test no. 2 @client_context.require_version_max(4, 1, 8) def test_raises_error_on_missing_id_418minus(self): - # PyMongo raises an error - self._test_raises_error_on_missing_id(InvalidOperation) + # PyMongo raises an error, closes the cursor, subsequent next() raises StopIteration. + self._test_raises_error_on_missing_id(InvalidOperation, StopIteration) # Prose test no. 3 @no_type_check diff --git a/test/test_csot.py b/test/test_csot.py index 4d71973320..7b82a49caf 100644 --- a/test/test_csot.py +++ b/test/test_csot.py @@ -19,11 +19,12 @@ sys.path[0:0] = [""] -from test import IntegrationTest, unittest +from test import IntegrationTest, client_context, unittest from test.unified_format import generate_test_classes import pymongo from pymongo import _csot +from pymongo.errors import PyMongoError # Location of JSON test specifications. TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "csot") @@ -72,6 +73,36 @@ def test_timeout_nested(self): self.assertEqual(_csot.get_deadline(), float("inf")) self.assertEqual(_csot.get_rtt(), 0.0) + @client_context.require_version_min(3, 6) + @client_context.require_no_mmap + @client_context.require_no_standalone + def test_change_stream_can_resume_after_timeouts(self): + coll = self.db.test + with coll.watch(max_await_time_ms=150) as stream: + with pymongo.timeout(0.1): + with self.assertRaises(PyMongoError) as ctx: + stream.try_next() + self.assertTrue(ctx.exception.timeout) + self.assertTrue(stream.alive) + with self.assertRaises(PyMongoError) as ctx: + stream.try_next() + self.assertTrue(ctx.exception.timeout) + self.assertTrue(stream.alive) + # Resume before the insert on 3.6 because 4.0 is required to avoid skipping documents + if client_context.version < (4, 0): + stream.try_next() + coll.insert_one({}) + with pymongo.timeout(10): + self.assertTrue(stream.next()) + self.assertTrue(stream.alive) + # Timeout applies to entire next() call, not only individual commands. + with pymongo.timeout(0.5): + with self.assertRaises(PyMongoError) as ctx: + stream.next() + self.assertTrue(ctx.exception.timeout) + self.assertTrue(stream.alive) + self.assertFalse(stream.alive) + if __name__ == "__main__": unittest.main() diff --git a/test/unified_format.py b/test/unified_format.py index d36b5d0a48..260c8187b6 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -1076,10 +1076,6 @@ def _sessionOperation_startTransaction(self, target, *args, **kwargs): self.__raise_if_unsupported("startTransaction", target, ClientSession) return target.start_transaction(*args, **kwargs) - def _cursor_iterateOnce(self, target, *args, **kwargs): - self.__raise_if_unsupported("iterateOnce", target, NonLazyCursor, ChangeStream) - return target.try_next() - def _changeStreamOperation_iterateUntilDocumentOrError(self, target, *args, **kwargs): self.__raise_if_unsupported("iterateUntilDocumentOrError", target, ChangeStream) return next(target) @@ -1202,8 +1198,11 @@ def run_entity_operation(self, spec): try: method = getattr(self, method_name) except AttributeError: + target_opname = camel_to_snake(opname) + if target_opname == "iterate_once": + target_opname = "try_next" try: - cmd = getattr(target, camel_to_snake(opname)) + cmd = getattr(target, target_opname) except AttributeError: self.fail("Unsupported operation %s on entity %s" % (opname, target)) else: