Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion spanner/google/cloud/spanner_v1/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ def put(self, session):
raise queue.Full

txn = session._transaction
if txn is None or txn.committed() or txn._rolled_back:
if txn is None or txn.committed or txn._rolled_back:
session.transaction()
self._pending_sessions.put(session)
else:
Expand Down
32 changes: 14 additions & 18 deletions spanner/tests/unit/test_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,7 +655,7 @@ def test_bind(self):
for session in SESSIONS:
session.create.assert_not_called()
txn = session._transaction
self.assertTrue(txn._begun)
txn.begin.assert_called_once_with()

self.assertTrue(pool._pending_sessions.empty())

Expand Down Expand Up @@ -684,7 +684,7 @@ def test_bind_w_timestamp_race(self):
for session in SESSIONS:
session.create.assert_not_called()
txn = session._transaction
self.assertTrue(txn._begun)
txn.begin.assert_called_once_with()

self.assertTrue(pool._pending_sessions.empty())

Expand Down Expand Up @@ -717,7 +717,7 @@ def test_put_non_full_w_active_txn(self):
self.assertIs(queued, session)

self.assertEqual(len(pending._items), 0)
self.assertFalse(txn._begun)
txn.begin.assert_not_called()

def test_put_non_full_w_committed_txn(self):
pool = self._make_one(size=1)
Expand All @@ -726,7 +726,7 @@ def test_put_non_full_w_committed_txn(self):
database = _Database("name")
session = _Session(database)
committed = session.transaction()
committed._committed = True
committed.committed = True

pool.put(session)

Expand All @@ -735,7 +735,7 @@ def test_put_non_full_w_committed_txn(self):
self.assertEqual(len(pending._items), 1)
self.assertIs(pending._items[0], session)
self.assertIsNot(session._transaction, committed)
self.assertFalse(session._transaction._begun)
session._transaction.begin.assert_not_called()

def test_put_non_full(self):
pool = self._make_one(size=1)
Expand All @@ -761,7 +761,7 @@ def test_begin_pending_transactions_non_empty(self):
pool._sessions = _Queue()

database = _Database("name")
TRANSACTIONS = [_Transaction()]
TRANSACTIONS = [_make_transaction(object())]
PENDING_SESSIONS = [_Session(database, transaction=txn) for txn in TRANSACTIONS]

pending = pool._pending_sessions = _Queue(*PENDING_SESSIONS)
Expand All @@ -770,7 +770,7 @@ def test_begin_pending_transactions_non_empty(self):
pool.begin_pending_transactions() # no raise

for txn in TRANSACTIONS:
self.assertTrue(txn._begun)
txn.begin.assert_called_once_with()

self.assertTrue(pending.empty())

Expand Down Expand Up @@ -831,17 +831,13 @@ def test_context_manager_w_kwargs(self):
self.assertEqual(pool._got, {"foo": "bar"})


class _Transaction(object):
def _make_transaction(*args, **kw):
from google.cloud.spanner_v1.transaction import Transaction

_begun = False
_committed = False
_rolled_back = False

def begin(self):
self._begun = True

def committed(self):
return self._committed
txn = mock.create_autospec(Transaction)(*args, **kw)
txn.committed = None
txn._rolled_back = False
return txn


@total_ordering
Expand Down Expand Up @@ -872,7 +868,7 @@ def delete(self):
raise NotFound("unknown session")

def transaction(self):
txn = self._transaction = _Transaction()
txn = self._transaction = _make_transaction(self)
return txn


Expand Down