diff --git a/spanner/google/cloud/spanner_v1/pool.py b/spanner/google/cloud/spanner_v1/pool.py index 4ef5aee9baab..3cd57f056d41 100644 --- a/spanner/google/cloud/spanner_v1/pool.py +++ b/spanner/google/cloud/spanner_v1/pool.py @@ -503,7 +503,7 @@ def put(self, session): raise queue.Full txn = session._transaction - if txn is None or txn.committed() or txn._rolled_back: + if txn is None or txn.committed or txn._rolled_back: session.transaction() self._pending_sessions.put(session) else: diff --git a/spanner/tests/unit/test_pool.py b/spanner/tests/unit/test_pool.py index eded02ea4e6d..6704c2476486 100644 --- a/spanner/tests/unit/test_pool.py +++ b/spanner/tests/unit/test_pool.py @@ -655,7 +655,7 @@ def test_bind(self): for session in SESSIONS: session.create.assert_not_called() txn = session._transaction - self.assertTrue(txn._begun) + txn.begin.assert_called_once_with() self.assertTrue(pool._pending_sessions.empty()) @@ -684,7 +684,7 @@ def test_bind_w_timestamp_race(self): for session in SESSIONS: session.create.assert_not_called() txn = session._transaction - self.assertTrue(txn._begun) + txn.begin.assert_called_once_with() self.assertTrue(pool._pending_sessions.empty()) @@ -717,7 +717,7 @@ def test_put_non_full_w_active_txn(self): self.assertIs(queued, session) self.assertEqual(len(pending._items), 0) - self.assertFalse(txn._begun) + txn.begin.assert_not_called() def test_put_non_full_w_committed_txn(self): pool = self._make_one(size=1) @@ -726,7 +726,7 @@ def test_put_non_full_w_committed_txn(self): database = _Database("name") session = _Session(database) committed = session.transaction() - committed._committed = True + committed.committed = True pool.put(session) @@ -735,7 +735,7 @@ def test_put_non_full_w_committed_txn(self): self.assertEqual(len(pending._items), 1) self.assertIs(pending._items[0], session) self.assertIsNot(session._transaction, committed) - self.assertFalse(session._transaction._begun) + session._transaction.begin.assert_not_called() def test_put_non_full(self): pool = self._make_one(size=1) @@ -761,7 +761,7 @@ def test_begin_pending_transactions_non_empty(self): pool._sessions = _Queue() database = _Database("name") - TRANSACTIONS = [_Transaction()] + TRANSACTIONS = [_make_transaction(object())] PENDING_SESSIONS = [_Session(database, transaction=txn) for txn in TRANSACTIONS] pending = pool._pending_sessions = _Queue(*PENDING_SESSIONS) @@ -770,7 +770,7 @@ def test_begin_pending_transactions_non_empty(self): pool.begin_pending_transactions() # no raise for txn in TRANSACTIONS: - self.assertTrue(txn._begun) + txn.begin.assert_called_once_with() self.assertTrue(pending.empty()) @@ -831,17 +831,13 @@ def test_context_manager_w_kwargs(self): self.assertEqual(pool._got, {"foo": "bar"}) -class _Transaction(object): +def _make_transaction(*args, **kw): + from google.cloud.spanner_v1.transaction import Transaction - _begun = False - _committed = False - _rolled_back = False - - def begin(self): - self._begun = True - - def committed(self): - return self._committed + txn = mock.create_autospec(Transaction)(*args, **kw) + txn.committed = None + txn._rolled_back = False + return txn @total_ordering @@ -872,7 +868,7 @@ def delete(self): raise NotFound("unknown session") def transaction(self): - txn = self._transaction = _Transaction() + txn = self._transaction = _make_transaction(self) return txn