Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions pymongo/change_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from bson import _bson_to_dict
from bson.raw_bson import RawBSONDocument
from bson.timestamp import Timestamp
from pymongo import common
from pymongo import _csot, common
from pymongo.aggregation import (
_CollectionAggregationCommand,
_DatabaseAggregationCommand,
Expand Down Expand Up @@ -128,6 +128,8 @@ def __init__(
self._start_at_operation_time = start_at_operation_time
self._session = session
self._comment = comment
self._closed = False
self._timeout = self._target._timeout
# Initialize cursor.
self._cursor = self._create_cursor()

Expand Down Expand Up @@ -234,6 +236,7 @@ def _resume(self):

def close(self) -> None:
"""Close this ChangeStream."""
self._closed = True
self._cursor.close()

def __iter__(self) -> "ChangeStream[_DocumentType]":
Expand All @@ -248,6 +251,7 @@ def resume_token(self) -> Optional[Mapping[str, Any]]:
"""
return copy.deepcopy(self._resume_token)

@_csot.apply
def next(self) -> _DocumentType:
"""Advance the cursor.

Expand Down Expand Up @@ -298,8 +302,9 @@ def alive(self) -> bool:

.. versionadded:: 3.8
"""
return self._cursor.alive
return not self._closed

@_csot.apply
def try_next(self) -> Optional[_DocumentType]:
"""Advance the cursor without blocking indefinitely.

Expand Down Expand Up @@ -332,6 +337,9 @@ def try_next(self) -> Optional[_DocumentType]:

.. versionadded:: 3.8
"""
if not self._closed and not self._cursor.alive:
self._resume()

# Attempt to get the next change with at most one getMore and at most
# one resume attempt.
try:
Expand All @@ -350,6 +358,10 @@ def try_next(self) -> Optional[_DocumentType]:
self._resume()
change = self._cursor._try_next(False)

# Check if the cursor was invalidated.
if not self._cursor.alive:
self._closed = True

# If no changes are available.
if change is None:
# We have either iterated over all documents in the cursor,
Expand Down
14 changes: 6 additions & 8 deletions test/test_change_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,16 +486,15 @@ def _get_expected_resume_token(self, stream, listener, previous_change=None):
return response["cursor"]["postBatchResumeToken"]

@no_type_check
def _test_raises_error_on_missing_id(self, expected_exception):
def _test_raises_error_on_missing_id(self, expected_exception, expected_exception2):
"""ChangeStream will raise an exception if the server response is
missing the resume token.
"""
with self.change_stream([{"$project": {"_id": 0}}]) as change_stream:
self.watched_collection().insert_one({})
with self.assertRaises(expected_exception):
next(change_stream)
# The cursor should now be closed.
with self.assertRaises(StopIteration):
with self.assertRaises(expected_exception2):
next(change_stream)

@no_type_check
Expand Down Expand Up @@ -525,17 +524,16 @@ def test_update_resume_token_legacy(self):
self._test_update_resume_token(self._get_expected_resume_token_legacy)

# Prose test no. 2
@client_context.require_version_max(4, 3, 3) # PYTHON-2120
@client_context.require_version_min(4, 1, 8)
def test_raises_error_on_missing_id_418plus(self):
# Server returns an error on 4.1.8+
self._test_raises_error_on_missing_id(OperationFailure)
# Server returns an error on 4.1.8+, subsequent next() resumes and gets the same error.
self._test_raises_error_on_missing_id(OperationFailure, OperationFailure)

# Prose test no. 2
@client_context.require_version_max(4, 1, 8)
def test_raises_error_on_missing_id_418minus(self):
# PyMongo raises an error
self._test_raises_error_on_missing_id(InvalidOperation)
# PyMongo raises an error, closes the cursor, subsequent next() raises StopIteration.
self._test_raises_error_on_missing_id(InvalidOperation, StopIteration)

# Prose test no. 3
@no_type_check
Expand Down
33 changes: 32 additions & 1 deletion test/test_csot.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@

sys.path[0:0] = [""]

from test import IntegrationTest, unittest
from test import IntegrationTest, client_context, unittest
from test.unified_format import generate_test_classes

import pymongo
from pymongo import _csot
from pymongo.errors import PyMongoError

# Location of JSON test specifications.
TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "csot")
Expand Down Expand Up @@ -72,6 +73,36 @@ def test_timeout_nested(self):
self.assertEqual(_csot.get_deadline(), float("inf"))
self.assertEqual(_csot.get_rtt(), 0.0)

@client_context.require_version_min(3, 6)
@client_context.require_no_mmap
@client_context.require_no_standalone
def test_change_stream_can_resume_after_timeouts(self):
coll = self.db.test
with coll.watch(max_await_time_ms=150) as stream:
with pymongo.timeout(0.1):
with self.assertRaises(PyMongoError) as ctx:
stream.try_next()
self.assertTrue(ctx.exception.timeout)
self.assertTrue(stream.alive)
with self.assertRaises(PyMongoError) as ctx:
stream.try_next()
self.assertTrue(ctx.exception.timeout)
self.assertTrue(stream.alive)
# Resume before the insert on 3.6 because 4.0 is required to avoid skipping documents
if client_context.version < (4, 0):
stream.try_next()
coll.insert_one({})
with pymongo.timeout(10):
self.assertTrue(stream.next())
self.assertTrue(stream.alive)
# Timeout applies to entire next() call, not only individual commands.
with pymongo.timeout(0.5):
with self.assertRaises(PyMongoError) as ctx:
stream.next()
self.assertTrue(ctx.exception.timeout)
self.assertTrue(stream.alive)
self.assertFalse(stream.alive)


if __name__ == "__main__":
unittest.main()
9 changes: 4 additions & 5 deletions test/unified_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -1076,10 +1076,6 @@ def _sessionOperation_startTransaction(self, target, *args, **kwargs):
self.__raise_if_unsupported("startTransaction", target, ClientSession)
return target.start_transaction(*args, **kwargs)

def _cursor_iterateOnce(self, target, *args, **kwargs):
self.__raise_if_unsupported("iterateOnce", target, NonLazyCursor, ChangeStream)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm confused. Isn't this change and the change below combined a no-op?

Copy link
Member Author

@ShaneHarvey ShaneHarvey Jul 20, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a no-op. The old code only worked with find() cursors while the new code implements iterateOnce for find() cursors AND change stream cursors.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

iterate_once is just another name for "try_next" so I think it's fine. This would be the first time we apply a name translation to the op name but we already do similar name transformations when parsing spec args.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The alternative would be:

def _cursor_iterateOnce(self, target, *args, **kwargs):
    self.__raise_if_unsupported("iterateOnce", target, NonLazyCursor, ChangeStream)
    return target.try_next()

def _changeStreamOperation_iterateOnce(self, target, *args, **kwargs):
    self.__raise_if_unsupported("iterateOnce", target, NonLazyCursor, ChangeStream)
    return target.try_next()

Which seems subpar to me.

return target.try_next()

def _changeStreamOperation_iterateUntilDocumentOrError(self, target, *args, **kwargs):
self.__raise_if_unsupported("iterateUntilDocumentOrError", target, ChangeStream)
return next(target)
Expand Down Expand Up @@ -1202,8 +1198,11 @@ def run_entity_operation(self, spec):
try:
method = getattr(self, method_name)
except AttributeError:
target_opname = camel_to_snake(opname)
if target_opname == "iterate_once":
target_opname = "try_next"
try:
cmd = getattr(target, camel_to_snake(opname))
cmd = getattr(target, target_opname)
except AttributeError:
self.fail("Unsupported operation %s on entity %s" % (opname, target))
else:
Expand Down