diff --git a/pymongo/topology.py b/pymongo/topology.py index c4eb947b64..a101ec9c1c 100644 --- a/pymongo/topology.py +++ b/pymongo/topology.py @@ -429,20 +429,30 @@ def request_check_all(self, wait_time=5): self._request_check_all() self._condition.wait(wait_time) + def data_bearing_servers(self): + """Return a list of all data-bearing servers. + + This includes any server that might be selected for an operation. + """ + if self._description.topology_type == TOPOLOGY_TYPE.Single: + return self._description.known_servers + return self._description.readable_servers + def update_pool(self, all_credentials): # Remove any stale sockets and add new sockets if pool is too small. servers = [] with self._lock: - for server in self._servers.values(): - servers.append((server, server._pool.generation)) + # Only update pools for data-bearing servers. + for sd in self.data_bearing_servers(): + server = self._servers[sd.address] + servers.append((server, server.pool.generation)) for server, generation in servers: - pool = server._pool try: - pool.remove_stale_sockets(generation, all_credentials) + server.pool.remove_stale_sockets(generation, all_credentials) except PyMongoError as exc: ctx = _ErrorContext(exc, 0, generation, False) - self.handle_error(pool.address, ctx) + self.handle_error(server.description.address, ctx) raise def close(self): diff --git a/test/pymongo_mocks.py b/test/pymongo_mocks.py index 7520f2bfb1..8a28284bf5 100644 --- a/test/pymongo_mocks.py +++ b/test/pymongo_mocks.py @@ -106,13 +106,13 @@ def _check_once(self): class MockClient(MongoClient): def __init__( self, standalones, members, mongoses, ismaster_hosts=None, - *args, **kwargs): + arbiters=None, down_hosts=None, *args, **kwargs): """A MongoClient connected to the default server, with a mock topology. - standalones, members, mongoses determine the configuration of the - topology. They are formatted like ['a:1', 'b:2']. ismaster_hosts - provides an alternative host list for the server's mocked ismaster - response; see test_connect_with_internal_ips. + standalones, members, mongoses, arbiters, and down_hosts determine the + configuration of the topology. They are formatted like ['a:1', 'b:2']. + ismaster_hosts provides an alternative host list for the server's + mocked ismaster response; see test_connect_with_internal_ips. """ self.mock_standalones = standalones[:] self.mock_members = members[:] @@ -122,6 +122,9 @@ def __init__( else: self.mock_primary = None + # Hosts that should be considered an arbiter. + self.mock_arbiters = arbiters[:] if arbiters else [] + if ismaster_hosts is not None: self.mock_ismaster_hosts = ismaster_hosts else: @@ -130,7 +133,7 @@ def __init__( self.mock_mongoses = mongoses[:] # Hosts that should raise socket errors. - self.mock_down_hosts = [] + self.mock_down_hosts = down_hosts[:] if down_hosts else [] # Hostname -> (min wire version, max wire version) self.mock_wire_versions = {} @@ -203,6 +206,10 @@ def mock_is_master(self, host): if self.mock_primary: response['primary'] = self.mock_primary + + if host in self.mock_arbiters: + response['arbiterOnly'] = True + response['secondary'] = False elif host in self.mock_mongoses: response = { 'ok': 1, diff --git a/test/test_client.py b/test/test_client.py index dd5e37392a..b55db450ec 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -35,7 +35,7 @@ from bson.son import SON from bson.tz_util import utc import pymongo -from pymongo import message +from pymongo import message, monitoring from pymongo.common import CONNECT_TIMEOUT, _UUID_REPRESENTATIONS from pymongo.command_cursor import CommandCursor from pymongo.compression_support import _HAVE_SNAPPY, _HAVE_ZSTD @@ -57,7 +57,7 @@ from pymongo.pool import SocketInfo, _METADATA from pymongo.read_preferences import ReadPreference from pymongo.server_description import ServerDescription -from pymongo.server_selectors import (any_server_selector, +from pymongo.server_selectors import (readable_server_selector, writable_server_selector) from pymongo.server_type import SERVER_TYPE from pymongo.settings import TOPOLOGY_TYPE @@ -77,6 +77,7 @@ from test.pymongo_mocks import MockClient from test.utils import (assertRaisesExactly, connected, + CMAPListener, delay, FunctionCallRecorder, get_pool, @@ -448,21 +449,25 @@ def test_uri_security_options(self): class TestClient(IntegrationTest): - def test_max_idle_time_reaper(self): + def test_max_idle_time_reaper_default(self): with client_knobs(kill_cursor_frequency=0.1): # Assert reaper doesn't remove sockets when maxIdleTimeMS not set client = rs_or_single_client() - server = client._get_topology().select_server(any_server_selector) + server = client._get_topology().select_server( + readable_server_selector) with server._pool.get_socket({}) as sock_info: pass self.assertEqual(1, len(server._pool.sockets)) self.assertTrue(sock_info in server._pool.sockets) client.close() + def test_max_idle_time_reaper_removes_stale_minPoolSize(self): + with client_knobs(kill_cursor_frequency=0.1): # Assert reaper removes idle socket and replaces it with a new one client = rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1) - server = client._get_topology().select_server(any_server_selector) + server = client._get_topology().select_server( + readable_server_selector) with server._pool.get_socket({}) as sock_info: pass # When the reaper runs at the same time as the get_socket, two @@ -474,11 +479,14 @@ def test_max_idle_time_reaper(self): "replace stale socket") client.close() + def test_max_idle_time_reaper_does_not_exceed_maxPoolSize(self): + with client_knobs(kill_cursor_frequency=0.1): # Assert reaper respects maxPoolSize when adding new sockets. client = rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1, maxPoolSize=1) - server = client._get_topology().select_server(any_server_selector) + server = client._get_topology().select_server( + readable_server_selector) with server._pool.get_socket({}) as sock_info: pass # When the reaper runs at the same time as the get_socket, @@ -490,9 +498,12 @@ def test_max_idle_time_reaper(self): "replace stale socket") client.close() + def test_max_idle_time_reaper_removes_stale(self): + with client_knobs(kill_cursor_frequency=0.1): # Assert reaper has removed idle socket and NOT replaced it client = rs_or_single_client(maxIdleTimeMS=500) - server = client._get_topology().select_server(any_server_selector) + server = client._get_topology().select_server( + readable_server_selector) with server._pool.get_socket({}) as sock_info_one: pass # Assert that the pool does not close sockets prematurely. @@ -508,12 +519,14 @@ def test_max_idle_time_reaper(self): def test_min_pool_size(self): with client_knobs(kill_cursor_frequency=.1): client = rs_or_single_client() - server = client._get_topology().select_server(any_server_selector) + server = client._get_topology().select_server( + readable_server_selector) self.assertEqual(0, len(server._pool.sockets)) # Assert that pool started up at minPoolSize client = rs_or_single_client(minPoolSize=10) - server = client._get_topology().select_server(any_server_selector) + server = client._get_topology().select_server( + readable_server_selector) wait_until(lambda: 10 == len(server._pool.sockets), "pool initialized with 10 sockets") @@ -528,7 +541,8 @@ def test_max_idle_time_checkout(self): # Use high frequency to test _get_socket_no_auth. with client_knobs(kill_cursor_frequency=99999999): client = rs_or_single_client(maxIdleTimeMS=500) - server = client._get_topology().select_server(any_server_selector) + server = client._get_topology().select_server( + readable_server_selector) with server._pool.get_socket({}) as sock_info: pass self.assertEqual(1, len(server._pool.sockets)) @@ -542,7 +556,8 @@ def test_max_idle_time_checkout(self): # Test that sockets are reused if maxIdleTimeMS is not set. client = rs_or_single_client() - server = client._get_topology().select_server(any_server_selector) + server = client._get_topology().select_server( + readable_server_selector) with server._pool.get_socket({}) as sock_info: pass self.assertEqual(1, len(server._pool.sockets)) @@ -1944,5 +1959,61 @@ def timeout_task(): self.assertIsNone(ct.get()) +class TestClientPool(MockClientTest): + + def test_rs_client_does_not_maintain_pool_to_arbiters(self): + listener = CMAPListener() + c = MockClient( + standalones=[], + members=['a:1', 'b:2', 'c:3', 'd:4'], + mongoses=[], + arbiters=['c:3'], # c:3 is an arbiter. + down_hosts=['d:4'], # d:4 is unreachable. + host=['a:1', 'b:2', 'c:3', 'd:4'], + replicaSet='rs', + minPoolSize=1, # minPoolSize + event_listeners=[listener], + ) + self.addCleanup(c.close) + + wait_until(lambda: len(c.nodes) == 3, 'connect') + self.assertEqual(c.address, ('a', 1)) + self.assertEqual(c.arbiters, set([('c', 3)])) + # Assert that we create 2 and only 2 pooled connections. + listener.wait_for_event(monitoring.ConnectionReadyEvent, 2) + self.assertEqual( + listener.event_count(monitoring.ConnectionCreatedEvent), 2) + # Assert that we do not create connections to arbiters. + arbiter = c._topology.get_server_by_address(('c', 3)) + self.assertFalse(arbiter.pool.sockets) + # Assert that we do not create connections to unknown servers. + arbiter = c._topology.get_server_by_address(('d', 4)) + self.assertFalse(arbiter.pool.sockets) + + def test_direct_client_maintains_pool_to_arbiter(self): + listener = CMAPListener() + c = MockClient( + standalones=[], + members=['a:1', 'b:2', 'c:3'], + mongoses=[], + arbiters=['c:3'], # c:3 is an arbiter. + host='c:3', + directConnection=True, + minPoolSize=1, # minPoolSize + event_listeners=[listener], + ) + self.addCleanup(c.close) + + print(c.topology_description) + wait_until(lambda: len(c.nodes) == 1, 'connect') + self.assertEqual(c.address, ('c', 3)) + # Assert that we create 1 pooled connection. + listener.wait_for_event(monitoring.ConnectionReadyEvent, 1) + self.assertEqual( + listener.event_count(monitoring.ConnectionCreatedEvent), 1) + arbiter = c._topology.get_server_by_address(('c', 3)) + self.assertEqual(len(arbiter.pool.sockets), 1) + + if __name__ == "__main__": unittest.main() diff --git a/test/test_cmap.py b/test/test_cmap.py index bf5328fcda..b4a14bb97c 100644 --- a/test/test_cmap.py +++ b/test/test_cmap.py @@ -41,6 +41,7 @@ PoolClosedEvent) from pymongo.read_preferences import ReadPreference from pymongo.pool import _PoolClosedError, PoolState +from pymongo.topology_description import updated_topology_description from test import (client_knobs, IntegrationTest, @@ -226,12 +227,23 @@ def run_scenario(self, scenario_def, test): opts = test['poolOptions'].copy() opts['event_listeners'] = [self.listener] opts['_monitor_class'] = DummyMonitor + opts['connect'] = False with client_knobs(kill_cursor_frequency=.05, min_heartbeat_interval=.05): client = single_client(**opts) + # Update the SD to a known type because the DummyMonitor will not. + # Note we cannot simply call topology.on_change because that would + # internally call pool.ready() which introduces unexpected + # PoolReadyEvents. Instead, update the initial state before + # opening the Topology. + td = client_context.client._topology.description + sd = td.server_descriptions()[(client_context.host, + client_context.port)] + client._topology._description = updated_topology_description( + client._topology._description, sd) + client._get_topology() self.addCleanup(client.close) - # self.pool = get_pools(client)[0] - self.pool = list(client._get_topology()._servers.values())[0].pool + self.pool = list(client._topology._servers.values())[0].pool # Map of target names to Thread objects. self.targets = dict()