@@ -156,8 +156,10 @@ def test_bind(self):
156156 self .assertEqual (pool .default_timeout , 10 )
157157 self .assertTrue (pool ._sessions .full ())
158158
159+ api = database .spanner_api
160+ self .assertEqual (api .batch_create_sessions .call_count , 5 )
159161 for session in SESSIONS :
160- self . assertTrue ( session ._created )
162+ session .create . assert_not_called ( )
161163
162164 def test_get_non_expired (self ):
163165 pool = self ._make_one (size = 4 )
@@ -183,7 +185,7 @@ def test_get_expired(self):
183185 session = pool .get ()
184186
185187 self .assertIs (session , SESSIONS [4 ])
186- self . assertTrue ( session ._created )
188+ session .create . assert_called ( )
187189 self .assertTrue (SESSIONS [0 ]._exists_checked )
188190 self .assertFalse (pool ._sessions .full ())
189191
@@ -243,8 +245,10 @@ def test_clear(self):
243245 pool .bind (database )
244246 self .assertTrue (pool ._sessions .full ())
245247
248+ api = database .spanner_api
249+ self .assertEqual (api .batch_create_sessions .call_count , 5 )
246250 for session in SESSIONS :
247- self . assertTrue ( session ._created )
251+ session .create . assert_not_called ( )
248252
249253 pool .clear ()
250254
@@ -286,7 +290,7 @@ def test_get_empty(self):
286290
287291 self .assertIsInstance (session , _Session )
288292 self .assertIs (session ._database , database )
289- self . assertTrue ( session ._created )
293+ session .create . assert_called ( )
290294 self .assertTrue (pool ._sessions .empty ())
291295
292296 def test_get_non_empty_session_exists (self ):
@@ -299,7 +303,7 @@ def test_get_non_empty_session_exists(self):
299303 session = pool .get ()
300304
301305 self .assertIs (session , previous )
302- self . assertFalse ( session ._created )
306+ session .create . assert_not_called ( )
303307 self .assertTrue (session ._exists_checked )
304308 self .assertTrue (pool ._sessions .empty ())
305309
@@ -316,7 +320,7 @@ def test_get_non_empty_session_expired(self):
316320
317321 self .assertTrue (previous ._exists_checked )
318322 self .assertIs (session , newborn )
319- self . assertTrue ( session ._created )
323+ session .create . assert_called ( )
320324 self .assertFalse (session ._exists_checked )
321325 self .assertTrue (pool ._sessions .empty ())
322326
@@ -405,7 +409,6 @@ def test_bind(self):
405409 database = _Database ("name" )
406410 SESSIONS = [_Session (database )] * 10
407411 database ._sessions .extend (SESSIONS )
408-
409412 pool .bind (database )
410413
411414 self .assertIs (pool ._database , database )
@@ -414,8 +417,10 @@ def test_bind(self):
414417 self .assertEqual (pool ._delta .seconds , 3000 )
415418 self .assertTrue (pool ._sessions .full ())
416419
420+ api = database .spanner_api
421+ self .assertEqual (api .batch_create_sessions .call_count , 5 )
417422 for session in SESSIONS :
418- self . assertTrue ( session ._created )
423+ session .create . assert_not_called ( )
419424
420425 def test_get_hit_no_ping (self ):
421426 pool = self ._make_one (size = 4 )
@@ -470,7 +475,7 @@ def test_get_hit_w_ping_expired(self):
470475 session = pool .get ()
471476
472477 self .assertIs (session , SESSIONS [4 ])
473- self . assertTrue ( session ._created )
478+ session .create . assert_called ( )
474479 self .assertTrue (SESSIONS [0 ]._exists_checked )
475480 self .assertFalse (pool ._sessions .full ())
476481
@@ -538,8 +543,10 @@ def test_clear(self):
538543 pool .bind (database )
539544 self .assertTrue (pool ._sessions .full ())
540545
546+ api = database .spanner_api
547+ self .assertEqual (api .batch_create_sessions .call_count , 5 )
541548 for session in SESSIONS :
542- self . assertTrue ( session ._created )
549+ session .create . assert_not_called ( )
543550
544551 pool .clear ()
545552
@@ -595,7 +602,7 @@ def test_ping_oldest_stale_and_not_exists(self):
595602 pool .ping ()
596603
597604 self .assertTrue (SESSIONS [0 ]._exists_checked )
598- self . assertTrue ( SESSIONS [1 ]._created )
605+ SESSIONS [1 ].create . assert_called ( )
599606
600607
601608class TestTransactionPingingPool (unittest .TestCase ):
@@ -635,7 +642,6 @@ def test_bind(self):
635642 database = _Database ("name" )
636643 SESSIONS = [_Session (database ) for _ in range (10 )]
637644 database ._sessions .extend (SESSIONS )
638-
639645 pool .bind (database )
640646
641647 self .assertIs (pool ._database , database )
@@ -644,8 +650,10 @@ def test_bind(self):
644650 self .assertEqual (pool ._delta .seconds , 3000 )
645651 self .assertTrue (pool ._sessions .full ())
646652
653+ api = database .spanner_api
654+ self .assertEqual (api .batch_create_sessions .call_count , 5 )
647655 for session in SESSIONS :
648- self . assertTrue ( session ._created )
656+ session .create . assert_not_called ( )
649657 txn = session ._transaction
650658 self .assertTrue (txn ._begun )
651659
@@ -671,8 +679,10 @@ def test_bind_w_timestamp_race(self):
671679 self .assertEqual (pool ._delta .seconds , 3000 )
672680 self .assertTrue (pool ._sessions .full ())
673681
682+ api = database .spanner_api
683+ self .assertEqual (api .batch_create_sessions .call_count , 5 )
674684 for session in SESSIONS :
675- self . assertTrue ( session ._created )
685+ session .create . assert_not_called ( )
676686 txn = session ._transaction
677687 self .assertTrue (txn ._begun )
678688
@@ -843,16 +853,13 @@ def __init__(self, database, exists=True, transaction=None):
843853 self ._database = database
844854 self ._exists = exists
845855 self ._exists_checked = False
846- self ._created = False
856+ self .create = mock . Mock ()
847857 self ._deleted = False
848858 self ._transaction = transaction
849859
850860 def __lt__ (self , other ):
851861 return id (self ) < id (other )
852862
853- def create (self ):
854- self ._created = True
855-
856863 def exists (self ):
857864 self ._exists_checked = True
858865 return self ._exists
@@ -874,6 +881,22 @@ def __init__(self, name):
874881 self .name = name
875882 self ._sessions = []
876883
884+ def mock_batch_create_sessions (db , session_count = 10 , timeout = 10 , metadata = []):
885+ from google .cloud .spanner_v1 .proto import spanner_pb2
886+
887+ response = spanner_pb2 .BatchCreateSessionsResponse ()
888+ if session_count < 2 :
889+ response .session .add ()
890+ else :
891+ response .session .add ()
892+ response .session .add ()
893+ return response
894+
895+ from google .cloud .spanner_v1 .gapic .spanner_client import SpannerClient
896+
897+ self .spanner_api = mock .create_autospec (SpannerClient , instance = True )
898+ self .spanner_api .batch_create_sessions .side_effect = mock_batch_create_sessions
899+
877900 def session (self ):
878901 return self ._sessions .pop ()
879902
0 commit comments