|
18 | 18 | from google.cloud.spanner_v1 import CommitRequest |
19 | 19 | from google.cloud.spanner_v1 import Mutation |
20 | 20 | from google.cloud.spanner_v1 import TransactionOptions |
| 21 | +from google.cloud.spanner_v1 import BatchWriteRequest |
21 | 22 |
|
22 | 23 | from google.cloud.spanner_v1._helpers import _SessionWrapper |
23 | 24 | from google.cloud.spanner_v1._helpers import _make_list_value_pbs |
@@ -215,6 +216,99 @@ def __exit__(self, exc_type, exc_val, exc_tb): |
215 | 216 | self.commit() |
216 | 217 |
|
217 | 218 |
|
| 219 | +class MutationGroup(_BatchBase): |
| 220 | + """A container for mutations. |
| 221 | +
|
| 222 | + Clients should use :class:`~google.cloud.spanner_v1.MutationGroups` to |
| 223 | + obtain instances instead of directly creating instances. |
| 224 | +
|
| 225 | + :type session: :class:`~google.cloud.spanner_v1.session.Session` |
| 226 | + :param session: The session used to perform the commit. |
| 227 | +
|
| 228 | + :type mutations: list |
| 229 | + :param mutations: The list into which mutations are to be accumulated. |
| 230 | + """ |
| 231 | + |
| 232 | + def __init__(self, session, mutations=[]): |
| 233 | + super(MutationGroup, self).__init__(session) |
| 234 | + self._mutations = mutations |
| 235 | + |
| 236 | + |
| 237 | +class MutationGroups(_SessionWrapper): |
| 238 | + """Accumulate mutation groups for transmission during :meth:`batch_write`. |
| 239 | +
|
| 240 | + :type session: :class:`~google.cloud.spanner_v1.session.Session` |
| 241 | + :param session: the session used to perform the commit |
| 242 | + """ |
| 243 | + |
| 244 | + committed = None |
| 245 | + |
| 246 | + def __init__(self, session): |
| 247 | + super(MutationGroups, self).__init__(session) |
| 248 | + self._mutation_groups = [] |
| 249 | + |
| 250 | + def _check_state(self): |
| 251 | + """Checks if the object's state is valid for making API requests. |
| 252 | +
|
| 253 | + :raises: :exc:`ValueError` if the object's state is invalid for making |
| 254 | + API requests. |
| 255 | + """ |
| 256 | + if self.committed is not None: |
| 257 | + raise ValueError("MutationGroups already committed") |
| 258 | + |
| 259 | + def group(self): |
| 260 | + """Returns a new `MutationGroup` to which mutations can be added.""" |
| 261 | + mutation_group = BatchWriteRequest.MutationGroup() |
| 262 | + self._mutation_groups.append(mutation_group) |
| 263 | + return MutationGroup(self._session, mutation_group.mutations) |
| 264 | + |
| 265 | + def batch_write(self, request_options=None): |
| 266 | + """Executes batch_write. |
| 267 | +
|
| 268 | + :type request_options: |
| 269 | + :class:`google.cloud.spanner_v1.types.RequestOptions` |
| 270 | + :param request_options: |
| 271 | + (Optional) Common options for this request. |
| 272 | + If a dict is provided, it must be of the same form as the protobuf |
| 273 | + message :class:`~google.cloud.spanner_v1.types.RequestOptions`. |
| 274 | +
|
| 275 | + :rtype: :class:`Iterable[google.cloud.spanner_v1.types.BatchWriteResponse]` |
| 276 | + :returns: a sequence of responses for each batch. |
| 277 | + """ |
| 278 | + self._check_state() |
| 279 | + |
| 280 | + database = self._session._database |
| 281 | + api = database.spanner_api |
| 282 | + metadata = _metadata_with_prefix(database.name) |
| 283 | + if database._route_to_leader_enabled: |
| 284 | + metadata.append( |
| 285 | + _metadata_with_leader_aware_routing(database._route_to_leader_enabled) |
| 286 | + ) |
| 287 | + trace_attributes = {"num_mutation_groups": len(self._mutation_groups)} |
| 288 | + if request_options is None: |
| 289 | + request_options = RequestOptions() |
| 290 | + elif type(request_options) is dict: |
| 291 | + request_options = RequestOptions(request_options) |
| 292 | + |
| 293 | + request = BatchWriteRequest( |
| 294 | + session=self._session.name, |
| 295 | + mutation_groups=self._mutation_groups, |
| 296 | + request_options=request_options, |
| 297 | + ) |
| 298 | + with trace_call("CloudSpanner.BatchWrite", self._session, trace_attributes): |
| 299 | + method = functools.partial( |
| 300 | + api.batch_write, |
| 301 | + request=request, |
| 302 | + metadata=metadata, |
| 303 | + ) |
| 304 | + response = _retry( |
| 305 | + method, |
| 306 | + allowed_exceptions={InternalServerError: _check_rst_stream_error}, |
| 307 | + ) |
| 308 | + self.committed = True |
| 309 | + return response |
| 310 | + |
| 311 | + |
218 | 312 | def _make_write_pb(table, columns, values): |
219 | 313 | """Helper for :meth:`Batch.insert` et al. |
220 | 314 |
|
|
0 commit comments