Skip to content

Commit 55a920e

Browse files
authored
feat(spanner): add batch_create_session calls to session pools (#9488)
* Update session pools to use batch_create_sessions * Update session pool unit tests to handle batch_create_session calls * Fix where PingingPool sessions are added to in batch_create_session * Remove unnecessary variable from FixedSizePool bind() * Remove unused import * Apply lint formatting to test_pool.py * Update 'batch_create_sessions' to remove session_count keyword
1 parent 4644da1 commit 55a920e

File tree

2 files changed

+70
-27
lines changed

2 files changed

+70
-27
lines changed

spanner/google/cloud/spanner_v1/pool.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717
import datetime
1818

1919
from six.moves import queue
20-
from six.moves import xrange
2120

2221
from google.cloud.exceptions import NotFound
22+
from google.cloud.spanner_v1._helpers import _metadata_with_prefix
2323

2424

2525
_NOW = datetime.datetime.utcnow # unit tests may replace
@@ -166,11 +166,20 @@ def bind(self, database):
166166
when needed.
167167
"""
168168
self._database = database
169+
api = database.spanner_api
170+
metadata = _metadata_with_prefix(database.name)
169171

170172
while not self._sessions.full():
171-
session = self._new_session()
172-
session.create()
173-
self._sessions.put(session)
173+
resp = api.batch_create_sessions(
174+
database.name,
175+
self.size - self._sessions.qsize(),
176+
timeout=self.default_timeout,
177+
metadata=metadata,
178+
)
179+
for session_pb in resp.session:
180+
session = self._new_session()
181+
session._session_id = session_pb.name.split("/")[-1]
182+
self._sessions.put(session)
174183

175184
def get(self, timeout=None): # pylint: disable=arguments-differ
176185
"""Check a session out from the pool.
@@ -350,11 +359,22 @@ def bind(self, database):
350359
when needed.
351360
"""
352361
self._database = database
353-
354-
for _ in xrange(self.size):
355-
session = self._new_session()
356-
session.create()
357-
self.put(session)
362+
api = database.spanner_api
363+
metadata = _metadata_with_prefix(database.name)
364+
created_session_count = 0
365+
366+
while created_session_count < self.size:
367+
resp = api.batch_create_sessions(
368+
database.name,
369+
self.size - created_session_count,
370+
timeout=self.default_timeout,
371+
metadata=metadata,
372+
)
373+
for session_pb in resp.session:
374+
session = self._new_session()
375+
session._session_id = session_pb.name.split("/")[-1]
376+
self.put(session)
377+
created_session_count += len(resp.session)
358378

359379
def get(self, timeout=None): # pylint: disable=arguments-differ
360380
"""Check a session out from the pool.

spanner/tests/unit/test_pool.py

Lines changed: 41 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,10 @@ def test_bind(self):
156156
self.assertEqual(pool.default_timeout, 10)
157157
self.assertTrue(pool._sessions.full())
158158

159+
api = database.spanner_api
160+
self.assertEqual(api.batch_create_sessions.call_count, 5)
159161
for session in SESSIONS:
160-
self.assertTrue(session._created)
162+
session.create.assert_not_called()
161163

162164
def test_get_non_expired(self):
163165
pool = self._make_one(size=4)
@@ -183,7 +185,7 @@ def test_get_expired(self):
183185
session = pool.get()
184186

185187
self.assertIs(session, SESSIONS[4])
186-
self.assertTrue(session._created)
188+
session.create.assert_called()
187189
self.assertTrue(SESSIONS[0]._exists_checked)
188190
self.assertFalse(pool._sessions.full())
189191

@@ -243,8 +245,10 @@ def test_clear(self):
243245
pool.bind(database)
244246
self.assertTrue(pool._sessions.full())
245247

248+
api = database.spanner_api
249+
self.assertEqual(api.batch_create_sessions.call_count, 5)
246250
for session in SESSIONS:
247-
self.assertTrue(session._created)
251+
session.create.assert_not_called()
248252

249253
pool.clear()
250254

@@ -286,7 +290,7 @@ def test_get_empty(self):
286290

287291
self.assertIsInstance(session, _Session)
288292
self.assertIs(session._database, database)
289-
self.assertTrue(session._created)
293+
session.create.assert_called()
290294
self.assertTrue(pool._sessions.empty())
291295

292296
def test_get_non_empty_session_exists(self):
@@ -299,7 +303,7 @@ def test_get_non_empty_session_exists(self):
299303
session = pool.get()
300304

301305
self.assertIs(session, previous)
302-
self.assertFalse(session._created)
306+
session.create.assert_not_called()
303307
self.assertTrue(session._exists_checked)
304308
self.assertTrue(pool._sessions.empty())
305309

@@ -316,7 +320,7 @@ def test_get_non_empty_session_expired(self):
316320

317321
self.assertTrue(previous._exists_checked)
318322
self.assertIs(session, newborn)
319-
self.assertTrue(session._created)
323+
session.create.assert_called()
320324
self.assertFalse(session._exists_checked)
321325
self.assertTrue(pool._sessions.empty())
322326

@@ -405,7 +409,6 @@ def test_bind(self):
405409
database = _Database("name")
406410
SESSIONS = [_Session(database)] * 10
407411
database._sessions.extend(SESSIONS)
408-
409412
pool.bind(database)
410413

411414
self.assertIs(pool._database, database)
@@ -414,8 +417,10 @@ def test_bind(self):
414417
self.assertEqual(pool._delta.seconds, 3000)
415418
self.assertTrue(pool._sessions.full())
416419

420+
api = database.spanner_api
421+
self.assertEqual(api.batch_create_sessions.call_count, 5)
417422
for session in SESSIONS:
418-
self.assertTrue(session._created)
423+
session.create.assert_not_called()
419424

420425
def test_get_hit_no_ping(self):
421426
pool = self._make_one(size=4)
@@ -470,7 +475,7 @@ def test_get_hit_w_ping_expired(self):
470475
session = pool.get()
471476

472477
self.assertIs(session, SESSIONS[4])
473-
self.assertTrue(session._created)
478+
session.create.assert_called()
474479
self.assertTrue(SESSIONS[0]._exists_checked)
475480
self.assertFalse(pool._sessions.full())
476481

@@ -538,8 +543,10 @@ def test_clear(self):
538543
pool.bind(database)
539544
self.assertTrue(pool._sessions.full())
540545

546+
api = database.spanner_api
547+
self.assertEqual(api.batch_create_sessions.call_count, 5)
541548
for session in SESSIONS:
542-
self.assertTrue(session._created)
549+
session.create.assert_not_called()
543550

544551
pool.clear()
545552

@@ -595,7 +602,7 @@ def test_ping_oldest_stale_and_not_exists(self):
595602
pool.ping()
596603

597604
self.assertTrue(SESSIONS[0]._exists_checked)
598-
self.assertTrue(SESSIONS[1]._created)
605+
SESSIONS[1].create.assert_called()
599606

600607

601608
class TestTransactionPingingPool(unittest.TestCase):
@@ -635,7 +642,6 @@ def test_bind(self):
635642
database = _Database("name")
636643
SESSIONS = [_Session(database) for _ in range(10)]
637644
database._sessions.extend(SESSIONS)
638-
639645
pool.bind(database)
640646

641647
self.assertIs(pool._database, database)
@@ -644,8 +650,10 @@ def test_bind(self):
644650
self.assertEqual(pool._delta.seconds, 3000)
645651
self.assertTrue(pool._sessions.full())
646652

653+
api = database.spanner_api
654+
self.assertEqual(api.batch_create_sessions.call_count, 5)
647655
for session in SESSIONS:
648-
self.assertTrue(session._created)
656+
session.create.assert_not_called()
649657
txn = session._transaction
650658
self.assertTrue(txn._begun)
651659

@@ -671,8 +679,10 @@ def test_bind_w_timestamp_race(self):
671679
self.assertEqual(pool._delta.seconds, 3000)
672680
self.assertTrue(pool._sessions.full())
673681

682+
api = database.spanner_api
683+
self.assertEqual(api.batch_create_sessions.call_count, 5)
674684
for session in SESSIONS:
675-
self.assertTrue(session._created)
685+
session.create.assert_not_called()
676686
txn = session._transaction
677687
self.assertTrue(txn._begun)
678688

@@ -843,16 +853,13 @@ def __init__(self, database, exists=True, transaction=None):
843853
self._database = database
844854
self._exists = exists
845855
self._exists_checked = False
846-
self._created = False
856+
self.create = mock.Mock()
847857
self._deleted = False
848858
self._transaction = transaction
849859

850860
def __lt__(self, other):
851861
return id(self) < id(other)
852862

853-
def create(self):
854-
self._created = True
855-
856863
def exists(self):
857864
self._exists_checked = True
858865
return self._exists
@@ -874,6 +881,22 @@ def __init__(self, name):
874881
self.name = name
875882
self._sessions = []
876883

884+
def mock_batch_create_sessions(db, session_count=10, timeout=10, metadata=[]):
885+
from google.cloud.spanner_v1.proto import spanner_pb2
886+
887+
response = spanner_pb2.BatchCreateSessionsResponse()
888+
if session_count < 2:
889+
response.session.add()
890+
else:
891+
response.session.add()
892+
response.session.add()
893+
return response
894+
895+
from google.cloud.spanner_v1.gapic.spanner_client import SpannerClient
896+
897+
self.spanner_api = mock.create_autospec(SpannerClient, instance=True)
898+
self.spanner_api.batch_create_sessions.side_effect = mock_batch_create_sessions
899+
877900
def session(self):
878901
return self._sessions.pop()
879902

0 commit comments

Comments
 (0)