2626import google .api_core .gapic_v1 .method
2727import google .api_core .gapic_v1 .routing_header
2828import google .api_core .grpc_helpers
29+ import google .api_core .operation
30+ import google .api_core .operations_v1
2931import google .api_core .path_template
3032import grpc
3133
3234from google .cloud .automl_v1 .gapic import enums
3335from google .cloud .automl_v1 .gapic import prediction_service_client_config
3436from google .cloud .automl_v1 .gapic .transports import prediction_service_grpc_transport
37+ from google .cloud .automl_v1 .proto import annotation_spec_pb2
3538from google .cloud .automl_v1 .proto import data_items_pb2
3639from google .cloud .automl_v1 .proto import dataset_pb2
40+ from google .cloud .automl_v1 .proto import image_pb2
3741from google .cloud .automl_v1 .proto import io_pb2
3842from google .cloud .automl_v1 .proto import model_evaluation_pb2
3943from google .cloud .automl_v1 .proto import model_pb2
@@ -222,8 +226,18 @@ def predict(
222226 returned in the response. Available for following ML problems, and their
223227 expected request payloads:
224228
229+ - Image Classification - Image in .JPEG, .GIF or .PNG format,
230+ image\_bytes up to 30MB.
231+ - Image Object Detection - Image in .JPEG, .GIF or .PNG format,
232+ image\_bytes up to 30MB.
233+ - Text Classification - TextSnippet, content up to 60,000 characters,
234+ UTF-8 encoded.
235+ - Text Extraction - TextSnippet, content up to 30,000 characters, UTF-8
236+ NFC encoded.
225237 - Translation - TextSnippet, content up to 25,000 characters, UTF-8
226238 encoded.
239+ - Text Sentiment - TextSnippet, content up 500 characters, UTF-8
240+ encoded.
227241
228242 Example:
229243 >>> from google.cloud import automl_v1
@@ -246,6 +260,19 @@ def predict(
246260 message :class:`~google.cloud.automl_v1.types.ExamplePayload`
247261 params (dict[str -> str]): Additional domain-specific parameters, any string must be up to 25000
248262 characters long.
263+
264+ - For Image Classification:
265+
266+ ``score_threshold`` - (float) A value from 0.0 to 1.0. When the model
267+ makes predictions for an image, it will only produce results that
268+ have at least this confidence score. The default is 0.5.
269+
270+ - For Image Object Detection: ``score_threshold`` - (float) When Model
271+ detects objects on the image, it will only produce bounding boxes
272+ which have at least this confidence score. Value in 0 to 1 range,
273+ default is 0.5. ``max_bounding_box_count`` - (int64) No more than
274+ this number of bounding boxes will be returned in the response.
275+ Default is 100, the requested value may be limited by server.
249276 retry (Optional[google.api_core.retry.Retry]): A retry object used
250277 to retry requests. If ``None`` is specified, requests will
251278 be retried using a default configuration.
@@ -295,3 +322,142 @@ def predict(
295322 return self ._inner_api_calls ["predict" ](
296323 request , retry = retry , timeout = timeout , metadata = metadata
297324 )
325+
326+ def batch_predict (
327+ self ,
328+ name ,
329+ input_config ,
330+ output_config ,
331+ params = None ,
332+ retry = google .api_core .gapic_v1 .method .DEFAULT ,
333+ timeout = google .api_core .gapic_v1 .method .DEFAULT ,
334+ metadata = None ,
335+ ):
336+ """
337+ Perform a batch prediction. Unlike the online ``Predict``, batch
338+ prediction result won't be immediately available in the response.
339+ Instead, a long running operation object is returned. User can poll the
340+ operation result via ``GetOperation`` method. Once the operation is
341+ done, ``BatchPredictResult`` is returned in the ``response`` field.
342+ Available for following ML problems:
343+
344+ - Image Classification
345+ - Image Object Detection
346+ - Text Extraction
347+
348+ Example:
349+ >>> from google.cloud import automl_v1
350+ >>>
351+ >>> client = automl_v1.PredictionServiceClient()
352+ >>>
353+ >>> name = client.model_path('[PROJECT]', '[LOCATION]', '[MODEL]')
354+ >>>
355+ >>> # TODO: Initialize `input_config`:
356+ >>> input_config = {}
357+ >>>
358+ >>> # TODO: Initialize `output_config`:
359+ >>> output_config = {}
360+ >>>
361+ >>> response = client.batch_predict(name, input_config, output_config)
362+ >>>
363+ >>> def callback(operation_future):
364+ ... # Handle result.
365+ ... result = operation_future.result()
366+ >>>
367+ >>> response.add_done_callback(callback)
368+ >>>
369+ >>> # Handle metadata.
370+ >>> metadata = response.metadata()
371+
372+ Args:
373+ name (str): Name of the model requested to serve the batch prediction.
374+ input_config (Union[dict, ~google.cloud.automl_v1.types.BatchPredictInputConfig]): Required. The input configuration for batch prediction.
375+
376+ If a dict is provided, it must be of the same form as the protobuf
377+ message :class:`~google.cloud.automl_v1.types.BatchPredictInputConfig`
378+ output_config (Union[dict, ~google.cloud.automl_v1.types.BatchPredictOutputConfig]): Required. The Configuration specifying where output predictions should
379+ be written.
380+
381+ If a dict is provided, it must be of the same form as the protobuf
382+ message :class:`~google.cloud.automl_v1.types.BatchPredictOutputConfig`
383+ params (dict[str -> str]): Additional domain-specific parameters for the predictions, any string
384+ must be up to 25000 characters long.
385+
386+ - For Text Classification:
387+
388+ ``score_threshold`` - (float) A value from 0.0 to 1.0. When the model
389+ makes predictions for a text snippet, it will only produce results
390+ that have at least this confidence score. The default is 0.5.
391+
392+ - For Image Classification:
393+
394+ ``score_threshold`` - (float) A value from 0.0 to 1.0. When the model
395+ makes predictions for an image, it will only produce results that
396+ have at least this confidence score. The default is 0.5.
397+
398+ - For Image Object Detection:
399+
400+ ``score_threshold`` - (float) When Model detects objects on the
401+ image, it will only produce bounding boxes which have at least this
402+ confidence score. Value in 0 to 1 range, default is 0.5.
403+ ``max_bounding_box_count`` - (int64) No more than this number of
404+ bounding boxes will be produced per image. Default is 100, the
405+ requested value may be limited by server.
406+ retry (Optional[google.api_core.retry.Retry]): A retry object used
407+ to retry requests. If ``None`` is specified, requests will
408+ be retried using a default configuration.
409+ timeout (Optional[float]): The amount of time, in seconds, to wait
410+ for the request to complete. Note that if ``retry`` is
411+ specified, the timeout applies to each individual attempt.
412+ metadata (Optional[Sequence[Tuple[str, str]]]): Additional metadata
413+ that is provided to the method.
414+
415+ Returns:
416+ A :class:`~google.cloud.automl_v1.types._OperationFuture` instance.
417+
418+ Raises:
419+ google.api_core.exceptions.GoogleAPICallError: If the request
420+ failed for any reason.
421+ google.api_core.exceptions.RetryError: If the request failed due
422+ to a retryable error and retry attempts failed.
423+ ValueError: If the parameters are invalid.
424+ """
425+ # Wrap the transport method to add retry and timeout logic.
426+ if "batch_predict" not in self ._inner_api_calls :
427+ self ._inner_api_calls [
428+ "batch_predict"
429+ ] = google .api_core .gapic_v1 .method .wrap_method (
430+ self .transport .batch_predict ,
431+ default_retry = self ._method_configs ["BatchPredict" ].retry ,
432+ default_timeout = self ._method_configs ["BatchPredict" ].timeout ,
433+ client_info = self ._client_info ,
434+ )
435+
436+ request = prediction_service_pb2 .BatchPredictRequest (
437+ name = name ,
438+ input_config = input_config ,
439+ output_config = output_config ,
440+ params = params ,
441+ )
442+ if metadata is None :
443+ metadata = []
444+ metadata = list (metadata )
445+ try :
446+ routing_header = [("name" , name )]
447+ except AttributeError :
448+ pass
449+ else :
450+ routing_metadata = google .api_core .gapic_v1 .routing_header .to_grpc_metadata (
451+ routing_header
452+ )
453+ metadata .append (routing_metadata )
454+
455+ operation = self ._inner_api_calls ["batch_predict" ](
456+ request , retry = retry , timeout = timeout , metadata = metadata
457+ )
458+ return google .api_core .operation .from_gapic (
459+ operation ,
460+ self .transport ._operations_client ,
461+ prediction_service_pb2 .BatchPredictResult ,
462+ metadata_type = proto_operations_pb2 .OperationMetadata ,
463+ )
0 commit comments