diff --git a/spanner/google/cloud/spanner_v1/pool.py b/spanner/google/cloud/spanner_v1/pool.py index 823681fbc864..e98becbb3a19 100644 --- a/spanner/google/cloud/spanner_v1/pool.py +++ b/spanner/google/cloud/spanner_v1/pool.py @@ -156,7 +156,7 @@ def __init__(self, size=DEFAULT_SIZE, default_timeout=DEFAULT_TIMEOUT, labels=No super(FixedSizePool, self).__init__(labels=labels) self.size = size self.default_timeout = default_timeout - self._sessions = queue.Queue(size) + self._sessions = queue.LifoQueue(size) def bind(self, database): """Associate the pool with a database. @@ -242,7 +242,7 @@ def __init__(self, target_size=10, labels=None): super(BurstyPool, self).__init__(labels=labels) self.target_size = target_size self._database = None - self._sessions = queue.Queue(target_size) + self._sessions = queue.LifoQueue(target_size) def bind(self, database): """Associate the pool with a database. diff --git a/spanner/tests/unit/test_pool.py b/spanner/tests/unit/test_pool.py index 549044b1f423..40b2cb49749f 100644 --- a/spanner/tests/unit/test_pool.py +++ b/spanner/tests/unit/test_pool.py @@ -162,15 +162,16 @@ def test_bind(self): def test_get_non_expired(self): pool = self._make_one(size=4) database = _Database("name") - SESSIONS = [_Session(database)] * 4 + SESSIONS = sorted([_Session(database) for i in range(0, 4)]) database._sessions.extend(SESSIONS) pool.bind(database) - session = pool.get() - - self.assertIs(session, SESSIONS[0]) - self.assertTrue(session._exists_checked) - self.assertFalse(pool._sessions.full()) + # check if sessions returned in LIFO order + for i in (3, 2, 1, 0): + session = pool.get() + self.assertIs(session, SESSIONS[i]) + self.assertTrue(session._exists_checked) + self.assertFalse(pool._sessions.full()) def test_get_expired(self): pool = self._make_one(size=4) @@ -875,7 +876,10 @@ def __init__(self, name): self._sessions = [] def session(self): - return self._sessions.pop() + # always return first session in the list + # to avoid reversing the order of putting + # sessions into pool (important for order tests) + return self._sessions.pop(0) class _Queue(object):