diff --git a/spanner/tests/system/test_system.py b/spanner/tests/system/test_system.py index 730190444edf..d6ce4e622580 100644 --- a/spanner/tests/system/test_system.py +++ b/spanner/tests/system/test_system.py @@ -24,9 +24,11 @@ import uuid import pytest +from google.rpc import code_pb2 from google.api_core import exceptions from google.api_core.datetime_helpers import DatetimeWithNanoseconds + from google.cloud.spanner_v1 import param_types from google.cloud.spanner_v1.proto.type_pb2 import ARRAY from google.cloud.spanner_v1.proto.type_pb2 import BOOL @@ -776,6 +778,11 @@ def test_transaction_execute_update_then_insert_commit(self): # [END spanner_test_dml_update] # [END spanner_test_dml_with_mutation] + @staticmethod + def _check_batch_status(status_code): + if status_code != code_pb2.OK: + raise exceptions.from_grpc_status(status_code, "batch_update failed") + def test_transaction_batch_update_success(self): # [START spanner_test_dml_with_mutation] # [START spanner_test_dml_update] @@ -808,7 +815,7 @@ def unit_of_work(transaction, self): status, row_counts = transaction.batch_update( [insert_statement, update_statement, delete_statement] ) - self.assertEqual(status.code, 0) # XXX: where are values defined? + self._check_batch_status(status.code) self.assertEqual(len(row_counts), 3) for row_count in row_counts: self.assertEqual(row_count, 1) @@ -849,7 +856,7 @@ def unit_of_work(transaction, self): status, row_counts = transaction.batch_update( insert_statements + update_statements ) - self.assertEqual(status.code, 0) # XXX: where are values defined? + self._check_batch_status(status.code) self.assertEqual(len(row_counts), len(insert_statements) + 1) for row_count in row_counts: self.assertEqual(row_count, 1)