From 162fcbc43f8908228090040f68eae26e12b0a34b Mon Sep 17 00:00:00 2001
From: Lars Wander
Date: Thu, 27 Jun 2019 17:43:22 -0400
Subject: [PATCH 01/11] Checking in staged client helper code

Additional test, docs & proposed cleanup needs to happen on top of this.
---
 automl/docs/gapic/v1beta1/helper.rst            |    5 +
 automl/docs/index.rst                           |    1 +
 .../google/cloud/automl_v1beta1/__init__.py     |    1 +
 .../cloud/automl_v1beta1/helper/__init__.py     |    0
 .../cloud/automl_v1beta1/helper/tables.py       | 1526 +++++++++++++++++
 5 files changed, 1533 insertions(+)
 create mode 100644 automl/docs/gapic/v1beta1/helper.rst
 create mode 100644 automl/google/cloud/automl_v1beta1/helper/__init__.py
 create mode 100644 automl/google/cloud/automl_v1beta1/helper/tables.py

diff --git a/automl/docs/gapic/v1beta1/helper.rst b/automl/docs/gapic/v1beta1/helper.rst
new file mode 100644
index 000000000000..0ccf49a35f40
--- /dev/null
+++ b/automl/docs/gapic/v1beta1/helper.rst
@@ -0,0 +1,5 @@
+Helper clients for Cloud AutoML API
+===================================
+
+.. automodule:: google.cloud.automl_v1beta1.helper.tables
+    :members:
diff --git a/automl/docs/index.rst b/automl/docs/index.rst
index cc1d290e2b55..dd0d74fe6846 100644
--- a/automl/docs/index.rst
+++ b/automl/docs/index.rst
@@ -8,6 +8,7 @@ Api Reference
 
     gapic/v1beta1/api
     gapic/v1beta1/types
+    gapic/v1beta1/helper
 
 
 Changelog
diff --git a/automl/google/cloud/automl_v1beta1/__init__.py b/automl/google/cloud/automl_v1beta1/__init__.py
index 2bc4b2a9f5a8..b5c6c5bedb40 100644
--- a/automl/google/cloud/automl_v1beta1/__init__.py
+++ b/automl/google/cloud/automl_v1beta1/__init__.py
@@ -18,6 +18,7 @@
 from __future__ import absolute_import
 
 from google.cloud.automl_v1beta1 import types
+from google.cloud.automl_v1beta1.helper import tables
 from google.cloud.automl_v1beta1.gapic import auto_ml_client
 from google.cloud.automl_v1beta1.gapic import enums
 from google.cloud.automl_v1beta1.gapic import prediction_service_client
diff --git a/automl/google/cloud/automl_v1beta1/helper/__init__.py b/automl/google/cloud/automl_v1beta1/helper/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/automl/google/cloud/automl_v1beta1/helper/tables.py b/automl/google/cloud/automl_v1beta1/helper/tables.py
new file mode 100644
index 000000000000..d7f45b71b8f3
--- /dev/null
+++ b/automl/google/cloud/automl_v1beta1/helper/tables.py
@@ -0,0 +1,1526 @@
+# Copyright 2019 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A helper for the google.cloud.automl_v1beta1 AutoML Tables API"""
+
+from google.cloud.automl_v1beta1.proto import data_types_pb2
+
+
+class ClientHelper(object):
+    """
+    AutoML Server API helper.
+
+    This is intended to simplify usage of the auto-generated python client,
+    in particular for the `AutoML Tables product
+    <https://cloud.google.com/automl-tables/>`_.
+    """
+    def __init__(self, client=None, prediction_client=None, project=None,
+                 region='us-central1'):
+        """Constructor.
+
+        Example:
+            >>> from google.cloud import automl_v1beta1
+            >>>
+            >>> client = automl_v1beta1.tables.ClientHelper(
+            ...     client=automl_v1beta1.AutoMlClient(),
+            ...     prediction_client=automl_v1beta1.PredictionServiceClient(),
+            ...     project='my-project', region='us-central1')
+            ...
+
+        Args:
+            client (Optional[google.cloud.automl.v1beta1.AutoMlClient]): An
+                initialized AutoMlClient instance. This parameter is optional;
+                however, if you expect to make CRUD operations on either
+                models or datasets, this parameter needs to be set.
+                Additionally, if you want to make online predictions without
+                supplying a column schema, this client is used during
+                prediction.
+            prediction_client (Optional[google.cloud.automl.v1beta1.PredictionServiceClient]):
+                An initialized PredictionServiceClient instance. This
+                parameter is optional; however, if you expect to make
+                predictions, this parameter needs to be set.
+            project (Optional[string]): The project all future calls will
+                default to. Most methods take `project` as an optional
+                parameter, and can override your choice of `project` supplied
+                here.
+            region (Optional[string]): The region all future calls will
+                default to. Most methods take `region` as an optional
+                parameter, and can override your choice of `region` supplied
+                here. Note that only `us-central1` is supported to date.
+        """
+        self.client = client
+        self.prediction_client = prediction_client
+        self.project = project
+        self.region = region
+
+    def __location_path(self, project=None, region=None):
+        if project is None:
+            if self.project is None:
+                raise ValueError('Either initialize your client with a value '
+                                 'for \'project\', or provide \'project\' as '
+                                 'a parameter for this method.')
+            project = self.project
+
+        if region is None:
+            if self.region is None:
+                raise ValueError('Either initialize your client with a value '
+                                 'for \'region\', or provide \'region\' as a '
+                                 'parameter for this method.')
+            region = self.region
+
+        return self.client.location_path(project, region)
+
+    def __dataset_name_from_args(self, dataset=None, dataset_display_name=None,
+                                 dataset_name=None, project=None, region=None):
+        if (dataset is None
+                and dataset_display_name is None
+                and dataset_name is None):
+            raise ValueError('One of \'dataset\', \'dataset_name\' or '
+                             '\'dataset_display_name\' must be set.')
+
+        if dataset_name is None:
+            if dataset is None:
+                dataset = self.get_dataset(
+                    dataset_display_name=dataset_display_name,
+                    project=project,
+                    region=region
+                )
+            dataset_name = dataset.name
+        return dataset_name
+
+    def __model_name_from_args(self, model=None, model_display_name=None,
+                               model_name=None, project=None, region=None):
+        if (model is None
+                and model_display_name is None
+                and model_name is None):
+            raise ValueError('One of \'model\', \'model_name\' or '
+                             '\'model_display_name\' must be set.')
+
+        if model_name is None:
+            if model is None:
+                model = self.get_model(
+                    model_display_name=model_display_name,
+                    project=project,
+                    region=region
+                )
+            model_name = model.name
+        return model_name
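These private resolvers are what let every public method below accept a resource in three interchangeable forms: a fully-qualified resource name, a human-readable display name, or an already-fetched object. As a rough sketch of the equivalence, assuming a helper initialized as in the docstrings and a hypothetical dataset with AutoML-assigned ID `TBL123`:

    # Any of these identify the same dataset; the display-name and object
    # forms cost an extra lookup through get_dataset()/list_datasets().
    client.delete_dataset(
        dataset_name='projects/my-project/locations/us-central1/datasets/TBL123')
    client.delete_dataset(dataset_display_name='my_dataset')
    client.delete_dataset(
        dataset=client.get_dataset(dataset_display_name='my_dataset'))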
+
+    def __column_spec_name_from_args(self, dataset=None,
+            dataset_display_name=None, dataset_name=None,
+            table_spec_name=None, table_spec_index=0,
+            column_spec_name=None, column_spec_display_name=None,
+            project=None, region=None):
+        if column_spec_name is None:
+            if column_spec_display_name is None:
+                raise ValueError('Either supply \'column_spec_name\' or '
+                                 '\'column_spec_display_name\' for the '
+                                 'column to update')
+            column_specs = {s.display_name: s for s in
+                            self.list_column_specs(
+                                dataset=dataset,
+                                dataset_display_name=dataset_display_name,
+                                dataset_name=dataset_name,
+                                table_spec_name=table_spec_name,
+                                table_spec_index=table_spec_index,
+                                project=project,
+                                region=region)}
+            column_spec_name = column_specs[column_spec_display_name].name
+        return column_spec_name
+
+    ## TODO(lwander): what other type codes are there?
+    ## https://github.com/googleapis/google-cloud-python/blob/master/automl/google/cloud/automl_v1beta1/proto/data_types_pb2.py#L87-L92
+    def __type_code_to_value_type(self, type_code):
+        if type_code == data_types_pb2.FLOAT64:
+            return 'number_value'
+        elif type_code == data_types_pb2.CATEGORY:
+            return 'string_value'
+        else:
+            raise ValueError('Unknown type_code: {}'.format(type_code))
+
+    def list_datasets(self, project=None, region=None):
+        """List all datasets in a particular project and region.
+
+        Example:
+            >>> from google.cloud import automl_v1beta1
+            >>>
+            >>> client = automl_v1beta1.tables.ClientHelper(
+            ...     client=automl_v1beta1.AutoMlClient(),
+            ...     project='my-project', region='us-central1')
+            ...
+            >>> ds = client.list_datasets()
+            >>>
+            >>> for d in ds:
+            ...     # do something
+            ...     pass
+            ...
+
+        Args:
+            project (Optional[string]):
+                If you have initialized the client with a value for `project`
+                it will be used if this parameter is not supplied. Keep in
+                mind, the service account this client was initialized with
+                must have access to this project.
+            region (Optional[string]):
+                If you have initialized the client with a value for `region`
+                it will be used if this parameter is not supplied.
+
+        Returns:
+            A :class:`~google.api_core.page_iterator.PageIterator` instance.
+            An iterable of :class:`~google.cloud.automl_v1beta1.types.Dataset`
+            instances. You can also iterate over the pages of the response
+            using its `pages` property.
+
+        Raises:
+            google.api_core.exceptions.GoogleAPICallError: If the request
+                failed for any reason.
+            google.api_core.exceptions.RetryError: If the request failed due
+                to a retryable error and retry attempts failed.
+            ValueError: If required parameters are missing.
+        """
+        return self.client.list_datasets(
+            self.__location_path(project=project, region=region)
+        )
+
+    def get_dataset(self, project=None, region=None,
+                    dataset_name=None, dataset_display_name=None):
+        """Gets a single dataset in a particular project and region.
+
+        Example:
+            >>> from google.cloud import automl_v1beta1
+            >>>
+            >>> client = automl_v1beta1.tables.ClientHelper(
+            ...     client=automl_v1beta1.AutoMlClient(),
+            ...     project='my-project', region='us-central1')
+            ...
+            >>> d = client.get_dataset(dataset_display_name='my_dataset')
+            >>>
+
+        Args:
+            project (Optional[string]):
+                If you have initialized the client with a value for `project`
+                it will be used if this parameter is not supplied. Keep in
+                mind, the service account this client was initialized with
+                must have access to this project.
+            region (Optional[string]):
+                If you have initialized the client with a value for `region`
+                it will be used if this parameter is not supplied.
+            dataset_name (Optional[string]):
+                This is the fully-qualified name generated by the AutoML API
+                for this dataset. This is not to be confused with the
+                human-assigned `dataset_display_name` that is provided when
+                creating a dataset. Either `dataset_name` or
+                `dataset_display_name` must be provided.
+            dataset_display_name (Optional[string]):
+                This is the name you provided for the dataset when first
+                creating it. Either `dataset_name` or `dataset_display_name`
+                must be provided.
+
+        Returns:
+            A :class:`~google.cloud.automl_v1beta1.types.Dataset` instance.
+
+        Raises:
+            google.api_core.exceptions.GoogleAPICallError: If the request
+                failed for any reason.
+            google.api_core.exceptions.RetryError: If the request failed due
+                to a retryable error and retry attempts failed.
+            ValueError: If required parameters are missing.
+        """
+        if dataset_name is None and dataset_display_name is None:
+            raise ValueError('One of \'dataset_name\' or '
+                             '\'dataset_display_name\' must be set.')
+
+        if dataset_name is not None:
+            return self.client.get_dataset(dataset_name)
+
+        return next(d for d in self.list_datasets(project, region)
+                    if d.display_name == dataset_display_name)
+
+    ## TODO(lwander): is metadata needed here?
+    def create_dataset(self, dataset_display_name, metadata={}, project=None,
+                       region=None):
+        """Create a dataset. Keep in mind, importing data is a separate step.
+
+        Example:
+            >>> from google.cloud import automl_v1beta1
+            >>>
+            >>> client = automl_v1beta1.tables.ClientHelper(
+            ...     client=automl_v1beta1.AutoMlClient(),
+            ...     project='my-project', region='us-central1')
+            ...
+            >>> d = client.create_dataset(dataset_display_name='my_dataset')
+            >>>
+
+        Args:
+            project (Optional[string]):
+                If you have initialized the client with a value for `project`
+                it will be used if this parameter is not supplied. Keep in
+                mind, the service account this client was initialized with
+                must have access to this project.
+            region (Optional[string]):
+                If you have initialized the client with a value for `region`
+                it will be used if this parameter is not supplied.
+            dataset_display_name (string):
+                A human-readable name to refer to this dataset by.
+
+        Returns:
+            A :class:`~google.cloud.automl_v1beta1.types.Dataset` instance.
+
+        Raises:
+            google.api_core.exceptions.GoogleAPICallError: If the request
+                failed for any reason.
+            google.api_core.exceptions.RetryError: If the request failed due
+                to a retryable error and retry attempts failed.
+            ValueError: If required parameters are missing.
+        """
+        return self.client.create_dataset(
+            self.__location_path(project, region),
+            {
+                'display_name': dataset_display_name,
+                'tables_dataset_metadata': metadata
+            }
+        )
+
+    def delete_dataset(self, dataset=None, dataset_display_name=None,
+                       dataset_name=None, project=None, region=None):
+        """Deletes a dataset. This does not delete any models trained on
+        this dataset.
+
+        Example:
+            >>> from google.cloud import automl_v1beta1
+            >>>
+            >>> client = automl_v1beta1.tables.ClientHelper(
+            ...     client=automl_v1beta1.AutoMlClient(),
+            ...     project='my-project', region='us-central1')
+            ...
+            >>> op = client.delete_dataset(dataset_display_name='my_dataset')
+            >>>
+            >>> op.result() # blocks on delete request
+            >>>
+
+        Args:
+            project (Optional[string]):
+                If you have initialized the client with a value for `project`
+                it will be used if this parameter is not supplied. Keep in
+                mind, the service account this client was initialized with
+                must have access to this project.
+            region (Optional[string]):
+                If you have initialized the client with a value for `region`
+                it will be used if this parameter is not supplied.
+            dataset_display_name (Optional[string]):
+                The human-readable name given to the dataset you want to
+                delete. This must be supplied if `dataset` or `dataset_name`
+                are not supplied.
+            dataset_name (Optional[string]):
+                The AutoML-assigned name given to the dataset you want to
+                delete. This must be supplied if `dataset_display_name` or
+                `dataset` are not supplied.
+            dataset (Optional[Dataset]):
+                The `Dataset` instance you want to delete. This must be
+                supplied if `dataset_display_name` or `dataset_name` are not
+                supplied.
+
+        Returns:
+            A :class:`~google.cloud.automl_v1beta1.types._OperationFuture`
+            instance.
+
+        Raises:
+            google.api_core.exceptions.GoogleAPICallError: If the request
+                failed for any reason.
+            google.api_core.exceptions.RetryError: If the request failed due
+                to a retryable error and retry attempts failed.
+            ValueError: If required parameters are missing.
+        """
+        dataset_name = self.__dataset_name_from_args(dataset=dataset,
+                dataset_name=dataset_name,
+                dataset_display_name=dataset_display_name,
+                project=project,
+                region=region)
+
+        return self.client.delete_dataset(dataset_name)
+
+    ## TODO(lwander): why multiple input GCS files? why not bq?
+    def import_data(self, dataset=None, dataset_display_name=None,
+                    dataset_name=None, gcs_input_uris=None,
+                    bigquery_input_uri=None, project=None, region=None):
+        """Imports data into a dataset.
+
+        Example:
+            >>> from google.cloud import automl_v1beta1
+            >>>
+            >>> client = automl_v1beta1.tables.ClientHelper(
+            ...     client=automl_v1beta1.AutoMlClient(),
+            ...     project='my-project', region='us-central1')
+            ...
+            >>> d = client.create_dataset(dataset_display_name='my_dataset')
+            >>>
+            >>> response = client.import_data(dataset=d,
+            ...     gcs_input_uris='gs://cloud-ml-tables-data/bank-marketing.csv')
+            ...
+            >>> def callback(operation_future):
+            ...     result = operation_future.result()
+            ...
+            >>> response.add_done_callback(callback)
+            >>>
+
+        Args:
+            project (Optional[string]):
+                If you have initialized the client with a value for `project`
+                it will be used if this parameter is not supplied. Keep in
+                mind, the service account this client was initialized with
+                must have access to this project.
+            region (Optional[string]):
+                If you have initialized the client with a value for `region`
+                it will be used if this parameter is not supplied.
+            dataset_display_name (Optional[string]):
+                The human-readable name given to the dataset you want to
+                import data into. This must be supplied if `dataset` or
+                `dataset_name` are not supplied.
+            dataset_name (Optional[string]):
+                The AutoML-assigned name given to the dataset you want to
+                import data into. This must be supplied if
+                `dataset_display_name` or `dataset` are not supplied.
+            dataset (Optional[Dataset]):
+                The `Dataset` instance you want to import data into. This
+                must be supplied if `dataset_display_name` or `dataset_name`
+                are not supplied.
+            gcs_input_uris (Optional[Union[string, Sequence[string]]]):
+                Either a single `gs://..` prefixed URI, or a list of URIs
+                referring to GCS-hosted CSV files containing the data to
+                import. This must be supplied if `bigquery_input_uri` is not.
+            bigquery_input_uri (Optional[string]):
+                A URI pointing to the BigQuery table containing the data to
+                import. This must be supplied if `gcs_input_uris` is not.
+
+        Returns:
+            A :class:`~google.cloud.automl_v1beta1.types._OperationFuture`
+            instance.
+
+        Raises:
+            google.api_core.exceptions.GoogleAPICallError: If the request
+                failed for any reason.
+            google.api_core.exceptions.RetryError: If the request failed due
+                to a retryable error and retry attempts failed.
+            ValueError: If required parameters are missing.
+ """ + dataset_name = self.__dataset_name_from_args(dataset=dataset, + dataset_name=dataset_name, + dataset_display_name=dataset_display_name, + project=project, + region=region) + + request = {} + if gcs_input_uris is not None: + if type(gcs_input_uris) != list: + gcs_input_uris = [gcs_input_uris] + request = { + 'gcs_source': { + 'input_uris': gcs_input_uris + } + } + elif bigquery_input_uri is not None: + request = { + 'bigquery_source': { + 'input_uri': bigquery_input_uri + } + } + else: + raise ValueError('One of \'gcs_input_uris\', or ' + '\'bigquery_input_uri\' must be set.') + + return self.client.import_data(dataset_name, request) + + def list_table_specs(self, dataset=None, dataset_display_name=None, + dataset_name=None, project=None, region=None): + """Lists table specs. + + Example: + >>> from google.cloud import automl_v1beta1 + >>> + >>> client = automl_v1beta1.tables.ClientHelper( + ... client=automl_v1beta1.AutoMlClient(), + ... project='my-project', region='us-central1') + ... + >>> for s in client.list_table_specs(dataset_display_name='my_dataset') + ... # process the spec + ... pass + ... + + Args: + project (Optional[string]): + If you have initialized the client with a value for `project` + it will be used if this parameter is not supplied. Keep in + mind, the service account this client was initialized with must + have access to this project. + region (Optional[string]): + If you have initialized the client with a value for `region` it + will be used if this parameter is not supplied. + dataset_display_name (Optional[string]): + The human-readable name given to the dataset you want to read + specs from. This must be supplied if `dataset` or + `dataset_name` are not supplied. + dataset_name (Optional[string]): + The AutoML-assigned name given to the dataset you want to read + specs from. This must be supplied if `dataset_display_name` or + `dataset` are not supplied. + dataset (Optional[Dataset]): + The `Dataset` instance you want to read specs from. This must + be supplied if `dataset_display_name` or `dataset_name` are not + supplied. + + Returns: + A :class:`~google.api_core.page_iterator.PageIterator` instance. + An iterable of + :class:`~google.cloud.automl_v1beta1.types.TableSpec` instances. + You can also iterate over the pages of the response using its + `pages` property. + + Raises: + google.api_core.exceptions.GoogleAPICallError: If the request + failed for any reason. + google.api_core.exceptions.RetryError: If the request failed due + to a retryable error and retry attempts failed. + ValueError: If required parameters are missing. + """ + dataset_name = self.__dataset_name_from_args(dataset=dataset, + dataset_name=dataset_name, + dataset_display_name=dataset_display_name, + project=project, + region=region) + + return self.client.list_table_specs(dataset_name) + + def list_column_specs(self, dataset=None, dataset_display_name=None, + dataset_name=None, table_spec_name=None, table_spec_index=0, + project=None, region=None): + """Lists column specs. + + Example: + >>> from google.cloud import automl_v1beta1 + >>> + >>> client = automl_v1beta1.tables.ClientHelper( + ... client=automl_v1beta1.AutoMlClient(), + ... project='my-project', region='us-central1') + ... + >>> for s in client.list_column_specs(dataset_display_name='my_dataset') + ... # process the spec + ... pass + ... + + Args: + project (Optional[string]): + If you have initialized the client with a value for `project` + it will be used if this parameter is not supplied. 
+                Keep in mind, the service account this client was initialized
+                with must have access to this project.
+            region (Optional[string]):
+                If you have initialized the client with a value for `region`
+                it will be used if this parameter is not supplied.
+            table_spec_name (Optional[string]):
+                The AutoML-assigned name for the table whose specs you want
+                to read. If not supplied, the client can determine this name
+                from a source `Dataset` object.
+            table_spec_index (Optional[int]):
+                If no `table_spec_name` was provided, we use this index to
+                determine which table to read column specs from.
+            dataset_display_name (Optional[string]):
+                The human-readable name given to the dataset you want to read
+                specs from. If no `table_spec_name` is supplied, this will be
+                used together with `table_spec_index` to infer the name of
+                the table to read specs from. This must be supplied if
+                `table_spec_name`, `dataset` or `dataset_name` are not
+                supplied.
+            dataset_name (Optional[string]):
+                The AutoML-assigned name given to the dataset you want to
+                read specs from. If no `table_spec_name` is supplied, this
+                will be used together with `table_spec_index` to infer the
+                name of the table to read specs from. This must be supplied
+                if `table_spec_name`, `dataset` or `dataset_display_name` are
+                not supplied.
+            dataset (Optional[Dataset]):
+                The `Dataset` instance you want to read specs from. If no
+                `table_spec_name` is supplied, this will be used together
+                with `table_spec_index` to infer the name of the table to
+                read specs from. This must be supplied if `table_spec_name`,
+                `dataset_name` or `dataset_display_name` are not supplied.
+
+        Returns:
+            A :class:`~google.api_core.page_iterator.PageIterator` instance.
+            An iterable of
+            :class:`~google.cloud.automl_v1beta1.types.ColumnSpec` instances.
+            You can also iterate over the pages of the response using its
+            `pages` property.
+
+        Raises:
+            google.api_core.exceptions.GoogleAPICallError: If the request
+                failed for any reason.
+            google.api_core.exceptions.RetryError: If the request failed due
+                to a retryable error and retry attempts failed.
+            ValueError: If required parameters are missing.
+        """
+        if table_spec_name is None:
+            table_specs = [t for t in self.list_table_specs(dataset=dataset,
+                    dataset_display_name=dataset_display_name,
+                    dataset_name=dataset_name,
+                    project=project,
+                    region=region)]
+
+            table_spec_name = table_specs[table_spec_index].name
+
+        return self.client.list_column_specs(table_spec_name)
+
+    def update_column_spec(self, dataset=None, dataset_display_name=None,
+            dataset_name=None, table_spec_name=None, table_spec_index=0,
+            column_spec_name=None, column_spec_display_name=None,
+            type_code=None, nullable=None, project=None, region=None):
+        """Updates a column's specs.
+
+        Example:
+            >>> from google.cloud import automl_v1beta1
+            >>>
+            >>> client = automl_v1beta1.tables.ClientHelper(
+            ...     client=automl_v1beta1.AutoMlClient(),
+            ...     project='my-project', region='us-central1')
+            ...
+            >>> client.update_column_spec(dataset_display_name='my_dataset',
+            ...     column_spec_display_name='Outcome', type_code='CATEGORY')
+            ...
+
+        Args:
+            project (Optional[string]):
+                If you have initialized the client with a value for `project`
+                it will be used if this parameter is not supplied. Keep in
+                mind, the service account this client was initialized with
+                must have access to this project.
+            region (Optional[string]):
+                If you have initialized the client with a value for `region`
+                it will be used if this parameter is not supplied.
+            type_code (Optional[string]):
+                The new type code for this column. If not supplied, the
+                column's existing type code is kept.
+            nullable (Optional[bool]):
+                Whether this column may take null values. Left unchanged if
+                not supplied.
+            column_spec_name (Optional[string]):
+                The AutoML-assigned name for the column you want to update.
+            column_spec_display_name (Optional[string]):
+                The human-readable name of the column you want to update. If
+                this is supplied in place of `column_spec_name`, you also
+                need to provide either a way to lookup the source dataset
+                (using one of the `dataset*` kwargs), or the
+                `table_spec_name` of the table this column belongs to.
+            table_spec_name (Optional[string]):
+                The AutoML-assigned name for the table whose specs you want
+                to update. If not supplied, the client can determine this
+                name from a source `Dataset` object.
+            table_spec_index (Optional[int]):
+                If no `table_spec_name` was provided, we use this index to
+                determine which table to update column specs on.
+            dataset_display_name (Optional[string]):
+                The human-readable name given to the dataset you want to
+                update specs on. If no `table_spec_name` is supplied, this
+                will be used together with `table_spec_index` to infer the
+                name of the table to update specs on. This must be supplied
+                if `table_spec_name`, `dataset` or `dataset_name` are not
+                supplied.
+            dataset_name (Optional[string]):
+                The AutoML-assigned name given to the dataset you want to
+                update specs on. If no `table_spec_name` is supplied, this
+                will be used together with `table_spec_index` to infer the
+                name of the table to update specs on. This must be supplied
+                if `table_spec_name`, `dataset` or `dataset_display_name` are
+                not supplied.
+            dataset (Optional[Dataset]):
+                The `Dataset` instance you want to update specs on. If no
+                `table_spec_name` is supplied, this will be used together
+                with `table_spec_index` to infer the name of the table to
+                update specs on. This must be supplied if `table_spec_name`,
+                `dataset_name` or `dataset_display_name` are not supplied.
+
+        Returns:
+            A :class:`~google.cloud.automl_v1beta1.types.ColumnSpec` instance.
+
+        Raises:
+            google.api_core.exceptions.GoogleAPICallError: If the request
+                failed for any reason.
+            google.api_core.exceptions.RetryError: If the request failed due
+                to a retryable error and retry attempts failed.
+            ValueError: If required parameters are missing.
+        """
+        column_spec_name = self.__column_spec_name_from_args(
+            dataset=dataset,
+            dataset_display_name=dataset_display_name,
+            dataset_name=dataset_name,
+            table_spec_name=table_spec_name,
+            table_spec_index=table_spec_index,
+            column_spec_name=column_spec_name,
+            column_spec_display_name=column_spec_display_name,
+            project=project,
+            region=region
+        )
+
+        # type code must always be set
+        if type_code is None:
+            type_code = {s.name: s for s in self.list_column_specs(
+                dataset=dataset,
+                dataset_display_name=dataset_display_name,
+                dataset_name=dataset_name,
+                table_spec_name=table_spec_name,
+                table_spec_index=table_spec_index,
+                project=project,
+                region=region)
+            }[column_spec_name].data_type.type_code
+
+        data_type = {}
+        if nullable is not None:
+            data_type['nullable'] = nullable
+
+        data_type['type_code'] = type_code
+
+        request = {
+            'name': column_spec_name,
+            'data_type': data_type
+        }
+
+        return self.client.update_column_spec(request)
+
+    def set_target_column(self, dataset=None, dataset_display_name=None,
+            dataset_name=None, table_spec_name=None, table_spec_index=0,
+            column_spec_name=None, column_spec_display_name=None,
+            project=None, region=None):
+        """Sets the target column for a given table.
+
+        Example:
+            >>> from google.cloud import automl_v1beta1
+            >>>
+            >>> client = automl_v1beta1.tables.ClientHelper(
+            ...     client=automl_v1beta1.AutoMlClient(),
+            ...     project='my-project', region='us-central1')
+            ...
+            >>> client.set_target_column(dataset_display_name='my_dataset',
+            ...     column_spec_display_name='Income')
+            ...
+
+        Args:
+            project (Optional[string]):
+                If you have initialized the client with a value for `project`
+                it will be used if this parameter is not supplied. Keep in
+                mind, the service account this client was initialized with
+                must have access to this project.
+            region (Optional[string]):
+                If you have initialized the client with a value for `region`
+                it will be used if this parameter is not supplied.
+            column_spec_name (Optional[string]):
+                The AutoML-assigned name for the column you want to set as
+                the target column.
+            column_spec_display_name (Optional[string]):
+                The human-readable name of the column you want to set as the
+                target column. If this is supplied in place of
+                `column_spec_name`, you also need to provide either a way to
+                lookup the source dataset (using one of the `dataset*`
+                kwargs), or the `table_spec_name` of the table this column
+                belongs to.
+            table_spec_name (Optional[string]):
+                The AutoML-assigned name for the table whose target column
+                you want to set. If not supplied, the client can determine
+                this name from a source `Dataset` object.
+            table_spec_index (Optional[int]):
+                If no `table_spec_name` or `column_spec_name` was provided,
+                we use this index to determine which table to set the target
+                column on.
+            dataset_display_name (Optional[string]):
+                The human-readable name given to the dataset you want to
+                update the target column of. If no `table_spec_name` is
+                supplied, this will be used together with `table_spec_index`
+                to infer the name of the table to update the target column
+                of. This must be supplied if `table_spec_name`, `dataset` or
+                `dataset_name` are not supplied.
+            dataset_name (Optional[string]):
+                The AutoML-assigned name given to the dataset you want to
+                update the target column of. If no `table_spec_name` is
+                supplied, this will be used together with `table_spec_index`
+                to infer the name of the table to update the target column
+                of. This must be supplied if `table_spec_name`, `dataset` or
+                `dataset_display_name` are not supplied.
+            dataset (Optional[Dataset]):
+                The `Dataset` instance you want to update the target column
+                of. If no `table_spec_name` is supplied, this will be used
+                together with `table_spec_index` to infer the name of the
+                table to update the target column of. This must be supplied
+                if `table_spec_name`, `dataset_name` or
+                `dataset_display_name` are not supplied.
+
+        Returns:
+            A :class:`~google.cloud.automl_v1beta1.types.Dataset` instance.
+
+        Raises:
+            google.api_core.exceptions.GoogleAPICallError: If the request
+                failed for any reason.
+            google.api_core.exceptions.RetryError: If the request failed due
+                to a retryable error and retry attempts failed.
+            ValueError: If required parameters are missing.
+ """ + column_spec_name = self.__column_spec_name_from_args( + dataset=dataset, + dataset_display_name=dataset_display_name, + dataset_name=dataset_name, + table_spec_name=table_spec_name, + table_spec_index=table_spec_index, + column_spec_name=column_spec_name, + column_spec_display_name=column_spec_display_name, + project=project, + region=region + ) + column_spec_id = column_spec_name.rsplit('/', 1)[-1] + + dataset_name = self.__dataset_name_from_args(dataset=dataset, + dataset_name=dataset_name, + dataset_display_name=dataset_display_name, + project=project, + region=region) + + request = { + 'name': dataset_name, + 'tables_dataset_metadata': { + 'target_column_spec_id': column_spec_id + } + } + + return self.client.update_dataset(request) + + def set_weight_column(self, dataset=None, dataset_display_name=None, + dataset_name=None, table_spec_name=None, table_spec_index=0, + column_spec_name=None, column_spec_display_name=None, + project=None, region=None): + """Sets the weight column for a given table. + + Example: + >>> from google.cloud import automl_v1beta1 + >>> + >>> client = automl_v1beta1.tables.ClientHelper( + ... client=automl_v1beta1.AutoMlClient(), + ... project='my-project', region='us-central1') + ... + >>> client.set_weight_column(dataset_display_name='my_dataset', + ... column_spec_display_name='Income') + ... + + Args: + project (Optional[string]): + If you have initialized the client with a value for `project` + it will be used if this parameter is not supplied. Keep in + mind, the service account this client was initialized with must + have access to this project. + region (Optional[string]): + If you have initialized the client with a value for `region` it + will be used if this parameter is not supplied. + column_spec_name (Optional[string]): + The name AutoML-assigned name for the column you want to + set as the weight column. + column_spec_display_name (Optional[string]): + The human-readable name of the column you want to set as the + weight column. If this is supplied in place of + `column_spec_name`, you also need to provide either a way to + lookup the source dataset (using one of the `dataset*` kwargs), + or the `table_spec_name` of the table this column belongs to. + table_spec_name (Optional[string]): + The AutoML-assigned name for the table whose weight column you + want to set . If not supplied, the client can determine this + name from a source `Dataset` object. + table_spec_index (Optional[int]): + If no `table_spec_name` or `column_spec_name` was provided, we + use this index to determine which table to set the weight + column on. + dataset_display_name (Optional[string]): + The human-readable name given to the dataset you want to update + the weight column of. If no `table_spec_name` is supplied, this + will be used together with `table_spec_index` to infer the name + of table to update the weight column of. This must be supplied + if `table_spec_name`, `dataset` or `dataset_name` are not + supplied. + dataset_name (Optional[string]): + The AutoML-assigned name given to the dataset you want to + update the weight column of. If no `table_spec_name` is + supplied, this will be used together with `table_spec_index` to + infer the name of table to update the weight column of. This + must be supplied if `table_spec_name`, `dataset` or + `dataset_display_name` are not supplied. + dataset (Optional[Dataset]): + The `Dataset` instance you want to update the weight column of. 
+                If no `table_spec_name` is supplied, this will be used
+                together with `table_spec_index` to infer the name of the
+                table to update the weight column of. This must be supplied
+                if `table_spec_name`, `dataset_name` or
+                `dataset_display_name` are not supplied.
+
+        Returns:
+            A :class:`~google.cloud.automl_v1beta1.types.Dataset` instance.
+
+        Raises:
+            google.api_core.exceptions.GoogleAPICallError: If the request
+                failed for any reason.
+            google.api_core.exceptions.RetryError: If the request failed due
+                to a retryable error and retry attempts failed.
+            ValueError: If required parameters are missing.
+        """
+        column_spec_name = self.__column_spec_name_from_args(
+            dataset=dataset,
+            dataset_display_name=dataset_display_name,
+            dataset_name=dataset_name,
+            table_spec_name=table_spec_name,
+            table_spec_index=table_spec_index,
+            column_spec_name=column_spec_name,
+            column_spec_display_name=column_spec_display_name,
+            project=project,
+            region=region
+        )
+        column_spec_id = column_spec_name.rsplit('/', 1)[-1]
+
+        dataset_name = self.__dataset_name_from_args(dataset=dataset,
+                dataset_name=dataset_name,
+                dataset_display_name=dataset_display_name,
+                project=project,
+                region=region)
+
+        request = {
+            'name': dataset_name,
+            'tables_dataset_metadata': {
+                'weight_column_spec_id': column_spec_id
+            }
+        }
+
+        return self.client.update_dataset(request)
+
+    def set_test_train_column(self, dataset=None, dataset_display_name=None,
+            dataset_name=None, table_spec_name=None, table_spec_index=0,
+            column_spec_name=None, column_spec_display_name=None,
+            project=None, region=None):
+        """Sets the test/train (ml_use) column which designates which data
+        belongs to the test and train sets. This column must be categorical.
+
+        Example:
+            >>> from google.cloud import automl_v1beta1
+            >>>
+            >>> client = automl_v1beta1.tables.ClientHelper(
+            ...     client=automl_v1beta1.AutoMlClient(),
+            ...     project='my-project', region='us-central1')
+            ...
+            >>> client.set_test_train_column(dataset_display_name='my_dataset',
+            ...     column_spec_display_name='TestSplit')
+            ...
+
+        Args:
+            project (Optional[string]):
+                If you have initialized the client with a value for `project`
+                it will be used if this parameter is not supplied. Keep in
+                mind, the service account this client was initialized with
+                must have access to this project.
+            region (Optional[string]):
+                If you have initialized the client with a value for `region`
+                it will be used if this parameter is not supplied.
+            column_spec_name (Optional[string]):
+                The AutoML-assigned name for the column you want to set as
+                the test/train column.
+            column_spec_display_name (Optional[string]):
+                The human-readable name of the column you want to set as the
+                test/train column. If this is supplied in place of
+                `column_spec_name`, you also need to provide either a way to
+                lookup the source dataset (using one of the `dataset*`
+                kwargs), or the `table_spec_name` of the table this column
+                belongs to.
+            table_spec_name (Optional[string]):
+                The AutoML-assigned name for the table whose test/train
+                column you want to set. If not supplied, the client can
+                determine this name from a source `Dataset` object.
+            table_spec_index (Optional[int]):
+                If no `table_spec_name` or `column_spec_name` was provided,
+                we use this index to determine which table to set the
+                test/train column on.
+            dataset_display_name (Optional[string]):
+                The human-readable name given to the dataset you want to
+                update the test/train column of.
+                If no `table_spec_name` is supplied, this will be used
+                together with `table_spec_index` to infer the name of the
+                table to update the test/train column of. This must be
+                supplied if `table_spec_name`, `dataset` or `dataset_name`
+                are not supplied.
+            dataset_name (Optional[string]):
+                The AutoML-assigned name given to the dataset you want to
+                update the test/train column of. If no `table_spec_name` is
+                supplied, this will be used together with `table_spec_index`
+                to infer the name of the table to update the test/train
+                column of. This must be supplied if `table_spec_name`,
+                `dataset` or `dataset_display_name` are not supplied.
+            dataset (Optional[Dataset]):
+                The `Dataset` instance you want to update the test/train
+                column of. If no `table_spec_name` is supplied, this will be
+                used together with `table_spec_index` to infer the name of
+                the table to update the test/train column of. This must be
+                supplied if `table_spec_name`, `dataset_name` or
+                `dataset_display_name` are not supplied.
+
+        Returns:
+            A :class:`~google.cloud.automl_v1beta1.types.Dataset` instance.
+
+        Raises:
+            google.api_core.exceptions.GoogleAPICallError: If the request
+                failed for any reason.
+            google.api_core.exceptions.RetryError: If the request failed due
+                to a retryable error and retry attempts failed.
+            ValueError: If required parameters are missing.
+        """
+        column_spec_name = self.__column_spec_name_from_args(
+            dataset=dataset,
+            dataset_display_name=dataset_display_name,
+            dataset_name=dataset_name,
+            table_spec_name=table_spec_name,
+            table_spec_index=table_spec_index,
+            column_spec_name=column_spec_name,
+            column_spec_display_name=column_spec_display_name,
+            project=project,
+            region=region
+        )
+        column_spec_id = column_spec_name.rsplit('/', 1)[-1]
+
+        dataset_name = self.__dataset_name_from_args(dataset=dataset,
+                dataset_name=dataset_name,
+                dataset_display_name=dataset_display_name,
+                project=project,
+                region=region)
+
+        request = {
+            'name': dataset_name,
+            'tables_dataset_metadata': {
+                'ml_use_column_spec_id': column_spec_id
+            }
+        }
+
+        return self.client.update_dataset(request)
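Taken together, the column and dataset mutators above cover the schema-configuration step AutoML Tables requires before training. A condensed sketch of that flow, with hypothetical dataset and column names and the `type_code` passed as a string as in the docstring examples:

    client = automl_v1beta1.tables.ClientHelper(
        client=automl_v1beta1.AutoMlClient(),
        project='my-project', region='us-central1')

    # Correct a column the service mis-typed, then mark the special columns.
    client.update_column_spec(dataset_display_name='my_dataset',
                              column_spec_display_name='Outcome',
                              type_code='CATEGORY', nullable=False)
    client.set_target_column(dataset_display_name='my_dataset',
                             column_spec_display_name='Outcome')
    client.set_weight_column(dataset_display_name='my_dataset',
                             column_spec_display_name='RowWeight')
    client.set_test_train_column(dataset_display_name='my_dataset',
                                 column_spec_display_name='TestSplit')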
+ """ + return self.client.list_models( + self.__location_path(project=project, region=region) + ) + + def create_model(self, model_display_name, dataset=None, + dataset_display_name=None, dataset_name=None, + train_budget_milli_node_hours=None, project=None, + region=None): + """Create a model. This will train your model on the given dataset. + + Example: + >>> from google.cloud import automl_v1beta1 + >>> + >>> client = automl_v1beta1.tables.ClientHelper( + ... client=automl_v1beta1.AutoMlClient(), + ... project='my-project', region='us-central1') + ... + >>> m = client.create_model('my_model', dataset_display_name='my_dataset') + >>> + >>> m.result() # blocks on result + >>> + + Args: + project (Optional[string]): + If you have initialized the client with a value for `project` + it will be used if this parameter is not supplied. Keep in + mind, the service account this client was initialized with must + have access to this project. + region (Optional[string]): + If you have initialized the client with a value for `region` it + will be used if this parameter is not supplied. + model_display_name (string): + A human-readable name to refer to this model by. + train_budget_milli_node_hours (int): + The amount of time (in thousandths of an hour) to spend + training. This value must be between 1,000 and 72,000 inclusive + (between 1 and 72 hours). + dataset_display_name (Optional[string]): + The human-readable name given to the dataset you want to train + your model on. This must be supplied if `dataset` or + `dataset_name` are not supplied. + dataset_name (Optional[string]): + The AutoML-assigned name given to the dataset you want to train + your model on. This must be supplied if `dataset_display_name` + or `dataset` are not supplied. + dataset (Optional[Dataset]): + The `Dataset` instance you want to train your model on. This + must be supplied if `dataset_display_name` or `dataset_name` + are not supplied. + + Returns: + A :class:`~google.cloud.automl_v1beta1.types._OperationFuture` + instance. + + Raises: + google.api_core.exceptions.GoogleAPICallError: If the request + failed for any reason. + google.api_core.exceptions.RetryError: If the request failed due + to a retryable error and retry attempts failed. + ValueError: If required parameters are missing. + """ + if train_budget_milli_node_hours is None: + raise ValueError('\'train_budget_milli_node_hours\' must be a ' + 'value between 1,000 and 72,000 inclusive') + + dataset_name = self.__dataset_name_from_args(dataset=dataset, + dataset_name=dataset_name, + dataset_display_name=dataset_display_name, + project=project, + region=region) + + dataset_id = dataset_name.rsplit('/', 1)[-1] + request = { + 'display_name': model_display_name, + 'dataset_id': dataset_id, + 'tables_model_metadata': { + 'train_budget_milli_node_hours': train_budget_milli_node_hours + } + } + + return self.client.create_model( + self.__location_path(project=project, region=region), + request + ) + + def delete_model(self, model=None, model_display_name=None, + model_name=None, project=None, region=None): + """Deletes a model. Note this will not delete any datasets associated + with this model. + + Example: + >>> from google.cloud import automl_v1beta1 + >>> + >>> client = automl_v1beta1.tables.ClientHelper( + ... client=automl_v1beta1.AutoMlClient(), + ... project='my-project', region='us-central1') + ... 
+ >>> op = client.delete_model(model_display_name='my_model') + >>> + >>> op.result() # blocks on delete request + >>> + + Args: + project (Optional[string]): + If you have initialized the client with a value for `project` + it will be used if this parameter is not supplied. Keep in + mind, the service account this client was initialized with must + have access to this project. + region (Optional[string]): + If you have initialized the client with a value for `region` it + will be used if this parameter is not supplied. + model_display_name (Optional[string]): + The human-readable name given to the model you want to + delete. This must be supplied if `model` or `model_name` + are not supplied. + model_name (Optional[string]): + The AutoML-assigned name given to the model you want to + delete. This must be supplied if `model_display_name` or + `model` are not supplied. + model (Optional[model]): + The `model` instance you want to delete. This must be + supplied if `model_display_name` or `model_name` are not + supplied. + + Returns: + A :class:`~google.cloud.automl_v1beta1.types._OperationFuture` + instance. + + Raises: + google.api_core.exceptions.GoogleAPICallError: If the request + failed for any reason. + google.api_core.exceptions.RetryError: If the request failed due + to a retryable error and retry attempts failed. + ValueError: If required parameters are missing. + """ + model_name = self.__model_name_from_args(model=model, + model_name=model_name, + model_display_name=model_display_name, + project=project, + region=region) + + return self.client.delete_model(model_name) + + def get_model(self, project=None, region=None, + model_name=None, model_display_name=None): + """Gets a single model in a particular project and region. + + Example: + >>> from google.cloud import automl_v1beta1 + >>> + >>> client = automl_v1beta1.tables.ClientHelper( + ... client=automl_v1beta1.AutoMlClient(), + ... project='my-project', region='us-central1') + ... + >>> d = client.get_model(model_display_name='my_model') + >>> + + Args: + project (Optional[string]): + If you have initialized the client with a value for `project` + it will be used if this parameter is not supplied. Keep in + mind, the service account this client was initialized with must + have access to this project. + region (Optional[string]): + If you have initialized the client with a value for `region` it + will be used if this parameter is not supplied. + model_name (Optional[string]): + This is the fully-qualified name generated by the AutoML API + for this model. This is not to be confused with the + human-assigned `model_display_name` that is provided when + creating a model. Either `model_name` or + `model_display_name` must be provided. + model_display_name (Optional[string]): + This is the name you provided for the model when first + creating it. Either `model_name` or `model_display_name` + must be provided. + + Returns: + A :class:`~google.cloud.automl_v1beta1.types.Model` instance. + + Raises: + google.api_core.exceptions.GoogleAPICallError: If the request + failed for any reason. + google.api_core.exceptions.RetryError: If the request failed due + to a retryable error and retry attempts failed. + ValueError: If required parameters are missing. 
+ """ + if model_name is None and model_display_name is None: + raise ValueError('One of \'model_name\' or ' + '\'model_display_name\' must be set.') + + return next(m for m in self.list_models(project, region) + if m.name == model_name + or m.display_name == model_display_name) + + #TODO(jonathanskim): allow deployment from just model ID + def deploy_model(self, model=None, model_name=None, + model_display_name=None, project=None, region=None): + """Deploys a model. This allows you make online predictions using the + model you've deployed. + + Example: + >>> from google.cloud import automl_v1beta1 + >>> + >>> client = automl_v1beta1.tables.ClientHelper( + ... client=automl_v1beta1.AutoMlClient(), + ... project='my-project', region='us-central1') + ... + >>> op = client.deploy_model(model_display_name='my_model') + >>> + >>> op.result() # blocks on deploy request + >>> + + Args: + project (Optional[string]): + If you have initialized the client with a value for `project` + it will be used if this parameter is not supplied. Keep in + mind, the service account this client was initialized with must + have access to this project. + region (Optional[string]): + If you have initialized the client with a value for `region` it + will be used if this parameter is not supplied. + model_display_name (Optional[string]): + The human-readable name given to the model you want to + deploy. This must be supplied if `model` or `model_name` + are not supplied. + model_name (Optional[string]): + The AutoML-assigned name given to the model you want to + deploy. This must be supplied if `model_display_name` or + `model` are not supplied. + model (Optional[model]): + The `model` instance you want to deploy. This must be + supplied if `model_display_name` or `model_name` are not + supplied. + + Returns: + A :class:`~google.cloud.automl_v1beta1.types._OperationFuture` + instance. + + Raises: + google.api_core.exceptions.GoogleAPICallError: If the request + failed for any reason. + google.api_core.exceptions.RetryError: If the request failed due + to a retryable error and retry attempts failed. + ValueError: If required parameters are missing. + """ + model_name = self.__model_name_from_args( + model=model, + model_name=model_name, + model_display_name=model_display_name, + project=project, + region=region + ) + + return self.client.deploy_model(model_name) + + def undeploy_model(self, model=None, model_name=None, + model_display_name=None, project=None, region=None): + """Undeploys a model. + + Example: + >>> from google.cloud import automl_v1beta1 + >>> + >>> client = automl_v1beta1.tables.ClientHelper( + ... client=automl_v1beta1.AutoMlClient(), + ... project='my-project', region='us-central1') + ... + >>> op = client.undeploy_model(model_display_name='my_model') + >>> + >>> op.result() # blocks on undeploy request + >>> + + Args: + project (Optional[string]): + If you have initialized the client with a value for `project` + it will be used if this parameter is not supplied. Keep in + mind, the service account this client was initialized with must + have access to this project. + region (Optional[string]): + If you have initialized the client with a value for `region` it + will be used if this parameter is not supplied. + model_display_name (Optional[string]): + The human-readable name given to the model you want to + undeploy. This must be supplied if `model` or `model_name` + are not supplied. + model_name (Optional[string]): + The AutoML-assigned name given to the model you want to + undeploy. 
+                This must be supplied if `model_display_name` or `model` are
+                not supplied.
+            model (Optional[model]):
+                The `model` instance you want to undeploy. This must be
+                supplied if `model_display_name` or `model_name` are not
+                supplied.
+
+        Returns:
+            A :class:`~google.cloud.automl_v1beta1.types._OperationFuture`
+            instance.
+
+        Raises:
+            google.api_core.exceptions.GoogleAPICallError: If the request
+                failed for any reason.
+            google.api_core.exceptions.RetryError: If the request failed due
+                to a retryable error and retry attempts failed.
+            ValueError: If required parameters are missing.
+        """
+        model_name = self.__model_name_from_args(
+            model=model,
+            model_name=model_name,
+            model_display_name=model_display_name,
+            project=project,
+            region=region
+        )
+
+        return self.client.undeploy_model(model_name)
+
+    ## TODO(lwander): support pandas DataFrame as input type
+    def predict(self, inputs, model=None, model_name=None,
+                model_display_name=None, project=None, region=None):
+        """Makes a prediction on a deployed model. This will fail if the
+        model was not deployed.
+
+        Example:
+            >>> from google.cloud import automl_v1beta1
+            >>>
+            >>> client = automl_v1beta1.tables.ClientHelper(
+            ...     client=automl_v1beta1.AutoMlClient(),
+            ...     prediction_client=automl_v1beta1.PredictionServiceClient(),
+            ...     project='my-project', region='us-central1')
+            ...
+            >>> client.predict(inputs={'Age': 30, 'Income': 12, 'Category': 'A'},
+            ...     model_display_name='my_model')
+            ...
+            >>> client.predict([30, 12, 'A'], model_display_name='my_model')
+            >>>
+
+        Args:
+            project (Optional[string]):
+                If you have initialized the client with a value for `project`
+                it will be used if this parameter is not supplied. Keep in
+                mind, the service account this client was initialized with
+                must have access to this project.
+            region (Optional[string]):
+                If you have initialized the client with a value for `region`
+                it will be used if this parameter is not supplied.
+            inputs (Union[List[string], Dict[string, string]]):
+                Either the sorted list of column values to predict with, or a
+                key-value map of column display name to value to predict
+                with.
+            model_display_name (Optional[string]):
+                The human-readable name given to the model you want to
+                predict with. This must be supplied if `model` or
+                `model_name` are not supplied.
+            model_name (Optional[string]):
+                The AutoML-assigned name given to the model you want to
+                predict with. This must be supplied if `model_display_name`
+                or `model` are not supplied.
+            model (Optional[model]):
+                The `model` instance you want to predict with. This must be
+                supplied if `model_display_name` or `model_name` are not
+                supplied.
+
+        Returns:
+            A :class:`~google.cloud.automl_v1beta1.types.PredictResponse`
+            instance.
+
+        Raises:
+            google.api_core.exceptions.GoogleAPICallError: If the request
+                failed for any reason.
+            google.api_core.exceptions.RetryError: If the request failed due
+                to a retryable error and retry attempts failed.
+            ValueError: If required parameters are missing.
+ """ + if model is None: + model = self.get_model( + model_name=model_name, + model_display_name=model_display_name, + project=project, + region=region + ) + + column_specs = model.tables_model_metadata.input_feature_column_specs + if type(inputs) == dict: + inputs = [inputs.get(c.display_name, None) for c in column_specs] + + if len(inputs) != len(column_specs): + raise ValueError(('Dimension mismatch, the number of provided ' + 'inputs ({}) does not match that of the model ' + '({})').format( + len(inputs), len(column_specs))) + + values = [] + for i, c in zip(inputs, column_specs): + value_type = self.__type_code_to_value_type(c.data_type.type_code) + values.append({value_type: i}) + + request = { + 'row': { + 'values': values + } + } + + return self.prediction_client.predict(model.name, request) + + def batch_predict(self, gcs_input_uris, gcs_output_uri_prefix, + model=None, model_name=None, model_display_name=None, project=None, + region=None, inputs=None): + """Makes a batch prediction on a model. This does _not_ require the + model to be deployed. + + Example: + >>> from google.cloud import automl_v1beta1 + >>> + >>> client = automl_v1beta1.tables.ClientHelper( + ... prediction_client=automl_v1beta1.PredictionServiceClient(), + ... project='my-project', region='us-central1') + ... + >>> client.batch_predict( + ... gcs_input_uris='gs://inputs/input.csv', + ... gcs_output_uri_prefix='gs://outputs/', + ... model_display_name='my_model' + ... ).result() + ... + + Args: + project (Optional[string]): + If you have initialized the client with a value for `project` + it will be used if this parameter is not supplied. Keep in + mind, the service account this client was initialized with must + have access to this project. + region (Optional[string]): + If you have initialized the client with a value for `region` it + will be used if this parameter is not supplied. + gcs_input_uris (Union[List[string], string]) + Either a list of or a single GCS URI containing the data you + want to predict off of. + gcs_output_uri_prefix (string) + The folder in GCS you want to write output to. + model_display_name (Optional[string]): + The human-readable name given to the model you want to predict + with. This must be supplied if `model` or `model_name` are not + supplied. + model_name (Optional[string]): + The AutoML-assigned name given to the model you want to predict + with. This must be supplied if `model_display_name` or `model` + are not supplied. + model (Optional[model]): + The `model` instance you want to predict with . This must be + supplied if `model_display_name` or `model_name` are not + supplied. + + Returns: + A :class:`~google.cloud.automl_v1beta1.types._OperationFuture` + instance. + + Raises: + google.api_core.exceptions.GoogleAPICallError: If the request + failed for any reason. + google.api_core.exceptions.RetryError: If the request failed due + to a retryable error and retry attempts failed. + ValueError: If required parameters are missing. 
+        """
+        if gcs_input_uris is None or gcs_output_uri_prefix is None:
+            raise ValueError('Both \'gcs_input_uris\' and '
+                    '\'gcs_output_uri_prefix\' must be set.')
+
+        model_name = self.__model_name_from_args(
+            model=model,
+            model_name=model_name,
+            model_display_name=model_display_name,
+            project=project,
+            region=region
+        )
+
+        if type(gcs_input_uris) != list:
+            gcs_input_uris = [gcs_input_uris]
+
+        input_request = {
+            'gcs_source': {
+                'input_uris': gcs_input_uris
+            }
+        }
+
+        output_request = {
+            'gcs_destination': {
+                'output_uri_prefix': gcs_output_uri_prefix
+            }
+        }
+
+        return self.prediction_client.batch_predict(model_name, input_request,
+                output_request)

From c45eea14fcb802f923c32e0d5c0f3ba2b601b150 Mon Sep 17 00:00:00 2001
From: jonathan1920
Date: Mon, 8 Jul 2019 18:20:31 -0700
Subject: [PATCH 02/11] update create_model to allow user to specify included
 or excluded col… (#16)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* update create_model to allow user to specify included or excluded columns

* made minor changes stylistically and with added ValueError outputs
---
 .../cloud/automl_v1beta1/helper/tables.py          | 51 +++++++++++++++----
 1 file changed, 42 insertions(+), 9 deletions(-)

diff --git a/automl/google/cloud/automl_v1beta1/helper/tables.py b/automl/google/cloud/automl_v1beta1/helper/tables.py
index d7f45b71b8f3..d4ccc43a4c56 100644
--- a/automl/google/cloud/automl_v1beta1/helper/tables.py
+++ b/automl/google/cloud/automl_v1beta1/helper/tables.py
@@ -1043,8 +1043,8 @@ def list_models(self, project=None, region=None):
     def create_model(self, model_display_name, dataset=None,
             dataset_display_name=None, dataset_name=None,
             train_budget_milli_node_hours=None, project=None,
-            region=None):
-        """Create a model. This will train your model on the given dataset.
+            region=None, input_feature_column_specs_included=None, input_feature_column_specs_excluded=None):
+        """Create a model. This will train your model on the given dataset.
 
         Example:
             >>> from google.cloud import automl_v1beta1
@@ -1057,7 +1057,6 @@ def create_model(self, model_display_name, dataset=None,
             >>>
             >>> m.result()  # blocks on result
             >>>
-
         Args:
             project (Optional[string]):
                 If you have initialized the client with a value for `project`
@@ -1085,11 +1084,15 @@ def create_model(self, model_display_name, dataset=None,
                 The `Dataset` instance you want to train your model on. This
                 must be supplied if `dataset_display_name` or `dataset_name`
                 are not supplied.
-
+            input_feature_column_specs_included (Optional[List[string]]):
+                The list of the display names of the columns you want to
+                train your model on.
+            input_feature_column_specs_excluded (Optional[List[string]]):
+                The list of the display names of the columns you want to
+                exclude from training.
         Returns:
             A :class:`~google.cloud.automl_v1beta1.types._OperationFuture`
             instance.
-
         Raises:
            google.api_core.exceptions.GoogleAPICallError: If the request
                failed for any reason.
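
A usage sketch for the column-selection arguments this hunk documents, assuming a
`ClientHelper` named `client` built as in the docstring examples above and a dataset
whose columns include 'Age', 'Income', and 'Category' (names are placeholders):

    # Train only on the 'Age' and 'Income' columns; all other columns are ignored.
    m = client.create_model(
        'my_model',
        dataset_display_name='my_dataset',
        train_budget_milli_node_hours=1000,
        input_feature_column_specs_included=['Age', 'Income'],
    )
    m.result()  # block until training completes

Passing `input_feature_column_specs_excluded` instead trains on every column except
those listed; supplying both arguments raises a `ValueError`, as the hunk below shows.
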
@@ -1101,26 +1104,56 @@ def create_model(self, model_display_name, dataset=None,
             raise ValueError('\'train_budget_milli_node_hours\' must be a '
                     'value between 1,000 and 72,000 inclusive')
 
+        if input_feature_column_specs_excluded not in [None, []] and input_feature_column_specs_included not in [None, []]:
+            raise ValueError('Cannot set both \'input_feature_column_specs_excluded\' '
+                    'and \'input_feature_column_specs_included\'')
+
+
         dataset_name = self.__dataset_name_from_args(dataset=dataset,
                 dataset_name=dataset_name,
                 dataset_display_name=dataset_display_name,
                 project=project,
                 region=region)
-
+        tables_model_metadata = {
+            'train_budget_milli_node_hours': train_budget_milli_node_hours
+        }
         dataset_id = dataset_name.rsplit('/', 1)[-1]
+        columns = [s for s in self.list_column_specs(dataset=dataset, dataset_name=dataset_name, dataset_display_name=dataset_display_name)]
+
+        final_columns = []
+        if input_feature_column_specs_included:
+            column_names = [a.display_name for a in columns]
+            if not all(name in column_names for name in input_feature_column_specs_included):
+                raise ValueError('Invalid name in the list \'input_feature_column_specs_included\'')
+            for a in columns:
+                if a.display_name in input_feature_column_specs_included:
+                    final_columns.append(a)
+
+            tables_model_metadata.update(
+                {'input_feature_column_specs': final_columns}
+            )
+        elif input_feature_column_specs_excluded:
+            for a in columns:
+                if a.display_name not in input_feature_column_specs_excluded:
+                    final_columns.append(a)
+
+            tables_model_metadata.update(
+                {'input_feature_column_specs': final_columns}
+            )
+
         request = {
             'display_name': model_display_name,
             'dataset_id': dataset_id,
-            'tables_model_metadata': {
-                'train_budget_milli_node_hours': train_budget_milli_node_hours
-            }
+            'tables_model_metadata': tables_model_metadata
         }
+
         return self.client.create_model(
                 self.__location_path(project=project, region=region),
                 request
         )
+
     def delete_model(self, model=None, model_display_name=None,
             model_name=None, project=None, region=None):
         """Deletes a model. Note this will not delete any datasets associated

From be43449aef20aee15d78cc37baee17dd6e866859 Mon Sep 17 00:00:00 2001
From: Lars Wander
Date: Mon, 15 Jul 2019 13:48:55 -0400
Subject: [PATCH 03/11] Update doc gen & module structure. Add unit & system
 tests

---
 automl/README.rst                                  |   18 +
 automl/docs/gapic/v1beta1/helper.rst               |    5 -
 automl/docs/gapic/v1beta1/tables.rst               |    5 +
 automl/docs/index.rst                              |    2 +-
 automl/google/cloud/automl_v1beta1/__init__.py     |    9 +-
 .../{helper => tables}/__init__.py                 |    0
 .../tables.py => tables/tables_client.py}          |  522 +++++---
 .../v1beta1/test_tables_client_v1beta1.py          | 1055 +++++++++++++++++
 8 files changed, 1446 insertions(+), 170 deletions(-)
 delete mode 100644 automl/docs/gapic/v1beta1/helper.rst
 create mode 100644 automl/docs/gapic/v1beta1/tables.rst
 rename automl/google/cloud/automl_v1beta1/{helper => tables}/__init__.py (100%)
 rename automl/google/cloud/automl_v1beta1/{helper/tables.py => tables/tables_client.py} (79%)
 create mode 100644 automl/tests/unit/gapic/v1beta1/test_tables_client_v1beta1.py

diff --git a/automl/README.rst b/automl/README.rst
index 08ba90ae85b0..bfcb555689e4 100644
--- a/automl/README.rst
+++ b/automl/README.rst
@@ -104,3 +104,21 @@ Next Steps
    API to see other available methods on the client.
-  Read the `Product documentation`_ to learn more about the product and see
   How-to Guides.
+ +Making & Testing Local Changes +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +If you want to make changes to this library, here is how to set up your +development environment: + +1. Make sure you have `virtualenv`_ installed and activated as shown above. +2. Run the following one-time setup (it will be persisted in your virtualenv): + +.. code-block:: console + + pip install -r ../docs/requirements.txt + pip install -U mock pytest + +3. To test any changes you've made, run `pytest` in this directory. + + diff --git a/automl/docs/gapic/v1beta1/helper.rst b/automl/docs/gapic/v1beta1/helper.rst deleted file mode 100644 index 0ccf49a35f40..000000000000 --- a/automl/docs/gapic/v1beta1/helper.rst +++ /dev/null @@ -1,5 +0,0 @@ -Helper clients for Cloud AutoML API -=================================== - -.. automodule:: google.cloud.automl_v1beta1.helper.tables - :members: diff --git a/automl/docs/gapic/v1beta1/tables.rst b/automl/docs/gapic/v1beta1/tables.rst new file mode 100644 index 000000000000..54ed6a203805 --- /dev/null +++ b/automl/docs/gapic/v1beta1/tables.rst @@ -0,0 +1,5 @@ +A tables-specific client for AutoML +=================================== + +.. automodule:: google.cloud.automl_v1beta1.tables.tables_client + :members: diff --git a/automl/docs/index.rst b/automl/docs/index.rst index dd0d74fe6846..01f577642cb1 100644 --- a/automl/docs/index.rst +++ b/automl/docs/index.rst @@ -8,7 +8,7 @@ Api Reference gapic/v1beta1/api gapic/v1beta1/types - gapic/v1beta1/helper + gapic/v1beta1/tables Changelog diff --git a/automl/google/cloud/automl_v1beta1/__init__.py b/automl/google/cloud/automl_v1beta1/__init__.py index b5c6c5bedb40..ae08470889ef 100644 --- a/automl/google/cloud/automl_v1beta1/__init__.py +++ b/automl/google/cloud/automl_v1beta1/__init__.py @@ -18,10 +18,10 @@ from __future__ import absolute_import from google.cloud.automl_v1beta1 import types -from google.cloud.automl_v1beta1.helper import tables from google.cloud.automl_v1beta1.gapic import auto_ml_client from google.cloud.automl_v1beta1.gapic import enums from google.cloud.automl_v1beta1.gapic import prediction_service_client +from google.cloud.automl_v1beta1.tables import tables_client class AutoMlClient(auto_ml_client.AutoMlClient): @@ -34,4 +34,9 @@ class PredictionServiceClient(prediction_service_client.PredictionServiceClient) enums = enums -__all__ = ("enums", "types", "AutoMlClient", "PredictionServiceClient") +class TablesClient(tables_client.TablesClient): + __doc__ = tables_client.TablesClient.__doc__ + + +__all__ = ("enums", "types", "AutoMlClient", "PredictionServiceClient", + "TablesClient") diff --git a/automl/google/cloud/automl_v1beta1/helper/__init__.py b/automl/google/cloud/automl_v1beta1/tables/__init__.py similarity index 100% rename from automl/google/cloud/automl_v1beta1/helper/__init__.py rename to automl/google/cloud/automl_v1beta1/tables/__init__.py diff --git a/automl/google/cloud/automl_v1beta1/helper/tables.py b/automl/google/cloud/automl_v1beta1/tables/tables_client.py similarity index 79% rename from automl/google/cloud/automl_v1beta1/helper/tables.py rename to automl/google/cloud/automl_v1beta1/tables/tables_client.py index d4ccc43a4c56..617cbd62d51a 100644 --- a/automl/google/cloud/automl_v1beta1/helper/tables.py +++ b/automl/google/cloud/automl_v1beta1/tables/tables_client.py @@ -1,10 +1,12 @@ -# Copyright 2019 Google Inc. All Rights Reserved. 
+# -*- coding: utf-8 -*-
+#
+# Copyright 2019 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
-#     http://www.apache.org/licenses/LICENSE-2.0
+#     https://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
@@ -12,53 +14,99 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""A helper for the google.cloud.automl_v1beta1 AutoML Tables API"""
+"""A tables helper for the google.cloud.automl_v1beta1 AutoML API"""
+
+import pkg_resources
 
+from google.api_core.gapic_v1 import client_info
+from google.api_core import exceptions
+from google.cloud import automl_v1beta1
 from google.cloud.automl_v1beta1.proto import data_types_pb2
 
-class ClientHelper(object):
+_GAPIC_LIBRARY_VERSION = pkg_resources.get_distribution("google-cloud-automl").version
+
+class TablesClient(object):
     """
-    AutoML Server API helper.
+    AutoML Tables API helper.
 
     This is intended to simplify usage of the auto-generated python client,
     in particular for the `AutoML Tables product
    <https://cloud.google.com/automl-tables/>`_.
     """
-    def __init__(self, client=None, prediction_client=None, project=None,
-            region='us-central1'):
+    def __init__(self, project=None, region='us-central1', client=None,
+            prediction_client=None, **kwargs):
         """Constructor.
 
         Example:
             >>> from google.cloud import automl_v1beta1
             >>>
-            >>> client = automl_v1beta1.tables.ClientHelper(
-            ...     client=automl_v1beta1.AutoMlClient(),
-            ...     prediction_client=automl_v1beta1.PredictionServiceClient(),
+            >>> from google.oauth2 import service_account
+            >>>
+            >>> client = automl_v1beta1.TablesClient(
+            ...     credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json'),
             ...     project='my-project', region='us-central1')
             ...
 
         Args:
-            client (Optional[google.cloud.automl.v1beta1.AutoMlClient]): An
-                initialized AutoMLClient instance. This parameter is optional;
-                however, if you expect to make CRUD operations on either
-                models or datasets, this parameter needs to be set.
-                Additionally, if you want to make online predictions without
-                supplying a column schema, this client is during prediction.
-            prediction_client (Optional[google.cloud.automl.v1beta1.PredictionServiceClient]):
-                An initialized PredicitonServiceClient instance. This parameter
-                is optional; however, if you expect to make predictions, this
-                parameter needs to be set.
             project (Optional[string]): The project all future calls will
                 default to. Most methods take `project` as an optional
                 parameter, and can override your choice of `project` supplied
                 here.
-            region (Optional[string]): The reigon all future calls will
+            region (Optional[string]): The region all future calls will
                 default to. Most methods take `region` as an optional
                 parameter, and can override your choice of `region` supplied
                 here. Note, only `us-central1` is supported to-date.
+            transport (Union[~.AutoMlGrpcTransport, Callable[[~.Credentials, type], ~.AutoMlGrpcTransport]):
+                A transport instance, responsible for actually making the API
+                calls. The default transport uses the gRPC protocol. This
+                argument may also be a callable which returns a transport
+                instance. Callables will be sent the credentials as the first
+                argument and the default transport class as the second
+                argument.
+            channel (grpc.Channel): DEPRECATED.
A ``Channel`` instance + through which to make calls. This argument is mutually exclusive + with ``credentials``; providing both will raise an exception. + credentials (google.auth.credentials.Credentials): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is mutually exclusive with providing a + transport instance to ``transport``; doing so will raise + an exception. + client_config (dict): DEPRECATED. A dictionary of call options for + each method. If not specified, the default configuration is used. + client_options (Union[dict, google.api_core.client_options.ClientOptions]): + Client options used to set user options on the client. API Endpoint + should be set through client_options. """ - self.client = client - self.prediction_client = prediction_client + version = _GAPIC_LIBRARY_VERSION + user_agent = 'automl-tables-wrapper/{}'.format(version) + + client_info_ = kwargs.get('client_info') + if client_info_ is None: + client_info_ = client_info.ClientInfo( + user_agent=user_agent, + gapic_version=version + ) + else: + client_info_.user_agent = user_agent + client_info_.gapic_version = version + + if client is None: + self.client = automl_v1beta1.AutoMlClient(client_info=client_info_, + **kwargs) + else: + self.client = client + + if prediction_client is None: + self.prediction_client = automl_v1beta1.PredictionServiceClient( + client_info=client_info_, + **kwargs + ) + else: + self.prediction_client = prediction_client + self.project = project self.region = region @@ -79,6 +127,55 @@ def __location_path(self, project=None, region=None): return self.client.location_path(project, region) + # the returned metadata object doesn't allow for updating fields, so + # we need to manually copy user-updated fields over + def __update_metadata(self, metadata, k, v): + new_metadata = {} + new_metadata['ml_use_column_spec_id'] = metadata.ml_use_column_spec_id + new_metadata['weight_column_spec_id'] = metadata.weight_column_spec_id + new_metadata['target_column_spec_id'] = metadata.target_column_spec_id + new_metadata[k] = v + + return new_metadata + + def __dataset_from_args(self, dataset=None, dataset_display_name=None, + dataset_name=None, project=None, region=None): + if (dataset is None + and dataset_display_name is None + and dataset_name is None): + raise ValueError('One of \'dataset\', \'dataset_name\' or ' + '\'dataset_display_name\' must be set.') + # we prefer to make a live call here in the case that the + # dataset object is out-of-date + if dataset is not None: + dataset_name = dataset.name + + return self.get_dataset( + dataset_display_name=dataset_display_name, + dataset_name=dataset_name, + project=project, + region=region + ) + + def __model_from_args(self, model=None, model_display_name=None, + model_name=None, project=None, region=None): + if (model is None + and model_display_name is None + and model_name is None): + raise ValueError('One of \'model\', \'model_name\' or ' + '\'model_display_name\' must be set.') + # we prefer to make a live call here in the case that the + # model object is out-of-date + if model is not None: + model_name = model.name + + return self.get_model( + model_display_name=model_display_name, + model_name=model_name, + project=project, + region=region + ) + def __dataset_name_from_args(self, dataset=None, dataset_display_name=None, dataset_name=None, project=None, 
region=None): if (dataset is None @@ -94,7 +191,15 @@ def __dataset_name_from_args(self, dataset=None, dataset_display_name=None, project=project, region=region ) + dataset_name = dataset.name + else: + # we do this to force a NotFound error when needed + self.get_dataset( + dataset_name=dataset_name, + project=project, + region=region + ) return dataset_name def __model_name_from_args(self, model=None, model_display_name=None, @@ -108,39 +213,64 @@ def __model_name_from_args(self, model=None, model_display_name=None, if model_name is None: if model is None: model = self.get_model( - model_display_name=dataset_display_name, + model_display_name=model_display_name, project=project, region=region ) model_name = model.name + else: + # we do this to force a NotFound error when needed + self.get_model( + model_name=model_name, + project=project, + region=region + ) return model_name def __column_spec_name_from_args(self, dataset=None, dataset_display_name=None, dataset_name=None, table_spec_name=None, table_spec_index=0, column_spec_name=None, column_spec_display_name=None, project=None, region=None): - if column_spec_name is None: - if column_spec_display_name is None: - raise ValueError('Either supply \'column_spec_name\' or ' - '\'column_spec_display_name\' for the column to update') - column_specs = {s.display_name: s for s in - self.list_column_specs(dataset=dataset, - dataset_display_name=dataset_display_name, - dataset_name=dataset_name, - table_spec_name=table_spec_name, - table_spec_index=table_spec_index, - project=project, - region=region) - } + column_specs = self.list_column_specs(dataset=dataset, + dataset_display_name=dataset_display_name, + dataset_name=dataset_name, + table_spec_name=table_spec_name, + table_spec_index=table_spec_index, + project=project, + region=region) + if column_spec_display_name is not None: + column_specs = {s.display_name: s for s in column_specs} + if column_specs.get(column_spec_display_name) is None: + raise exceptions.NotFound('No column with ' + + 'column_spec_display_name: \'{}\' found'.format( + column_spec_display_name + )) column_spec_name = column_specs[column_spec_display_name].name + elif column_spec_name is not None: + column_specs = {s.name: s for s in column_specs} + if column_specs.get(column_spec_name) is None: + raise exceptions.NotFound('No column with ' + + 'column_spec_name: \'{}\' found'.format( + column_spec_name + )) + else: + raise ValueError('Either supply \'column_spec_name\' or ' + '\'column_spec_display_name\' for the column to update') + return column_spec_name - ## TODO(lwander): what other type codes are there? - ## https://github.com/googleapis/google-cloud-python/blob/master/automl/google/cloud/automl_v1beta1/proto/data_types_pb2.py#L87-L92 def __type_code_to_value_type(self, type_code): if type_code == data_types_pb2.FLOAT64: return 'number_value' - if type_code == data_types_pb2.CATEGORY: + elif type_code == data_types_pb2.TIMESTAMP: + return 'string_value' + elif type_code == data_types_pb2.STRING: + return 'string_value' + elif type_code == data_types_pb2.ARRAY: + return 'list_value' + elif type_code == data_types_pb2.STRUCT: + return 'struct_value' + elif type_code == data_types_pb2.CATEGORY: return 'string_value' else: raise ValueError('Unknown type_code: {}'.format(type_code)) @@ -151,8 +281,10 @@ def list_datasets(self, project=None, region=None): Example: >>> from google.cloud import automl_v1beta1 >>> - >>> client = automl_v1beta1.tables.ClientHelper( - ... 
client=automl_v1beta1.AutoMlClient(),
+            >>> from google.oauth2 import service_account
+            >>>
+            >>> client = automl_v1beta1.TablesClient(
+            ...     credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json'),
             ...     project='my-project', region='us-central1')
             ...
             >>> ds = client.list_datasets()
@@ -196,8 +328,10 @@ def get_dataset(self, project=None, region=None,
         Example:
             >>> from google.cloud import automl_v1beta1
             >>>
-            >>> client = automl_v1beta1.tables.ClientHelper(
-            ...     client=automl_v1beta1.AutoMlClient(),
+            >>> from google.oauth2 import service_account
+            >>>
+            >>> client = automl_v1beta1.TablesClient(
+            ...     credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json'),
             ...     project='my-project', region='us-central1')
             ...
             >>> d = client.get_dataset(dataset_display_name='my_dataset')
@@ -224,7 +358,8 @@ def get_dataset(self, project=None, region=None,
                 must be provided.
 
         Returns:
-            A :class:`~google.cloud.automl_v1beta1.types.Dataset` instance.
+            A :class:`~google.cloud.automl_v1beta1.types.Dataset` instance if
+            found; raises `google.api_core.exceptions.NotFound` otherwise.
 
         Raises:
             google.api_core.exceptions.GoogleAPICallError: If the request
@@ -238,12 +373,17 @@ def get_dataset(self, project=None, region=None,
                 '\'dataset_display_name\' must be set.')
 
         if dataset_name is not None:
-            return client.get_dataset(dataset_name)
+            return self.client.get_dataset(dataset_name)
+
+        result = next((d for d in self.list_datasets(project, region)
+                if d.display_name == dataset_display_name), None)
+
+        if result is None:
+            raise exceptions.NotFound(('Dataset with display_name: \'{}\' ' +
+                'not found').format(dataset_display_name))
 
-        return next(d for d in self.list_datasets(project, region)
-                if d.display_name == dataset_display_name)
+        return result
 
-    ## TODO(lwander): is metadata needed here?
     def create_dataset(self, dataset_display_name, metadata={}, project=None,
             region=None):
         """Create a dataset. Keep in mind, importing data is a separate step.
@@ -251,8 +391,10 @@ def create_dataset(self, dataset_display_name, metadata={}, project=None,
         Example:
             >>> from google.cloud import automl_v1beta1
             >>>
-            >>> client = automl_v1beta1.tables.ClientHelper(
-            ...     client=automl_v1beta1.AutoMlClient(),
+            >>> from google.oauth2 import service_account
+            >>>
+            >>> client = automl_v1beta1.TablesClient(
+            ...     credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json'),
             ...     project='my-project', region='us-central1')
             ...
             >>> d = client.create_dataset(dataset_display_name='my_dataset')
@@ -296,8 +438,10 @@ def delete_dataset(self, dataset=None, dataset_display_name=None,
         Example:
             >>> from google.cloud import automl_v1beta1
             >>>
-            >>> client = automl_v1beta1.tables.ClientHelper(
-            ...     client=automl_v1beta1.AutoMlClient(),
+            >>> from google.oauth2 import service_account
+            >>>
+            >>> client = automl_v1beta1.TablesClient(
+            ...     credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json'),
             ...     project='my-project', region='us-central1')
             ...
             >>> op = client.delete_dataset(dataset_display_name='my_dataset')
@@ -338,15 +482,18 @@ def delete_dataset(self, dataset=None, dataset_display_name=None,
                 to a retryable error and retry attempts failed.
             ValueError: If required parameters are missing.
""" - dataset_name = self.__dataset_name_from_args(dataset=dataset, - dataset_name=dataset_name, - dataset_display_name=dataset_display_name, - project=project, - region=region) + try: + dataset_name = self.__dataset_name_from_args(dataset=dataset, + dataset_name=dataset_name, + dataset_display_name=dataset_display_name, + project=project, + region=region) + # delete is idempotent + except exceptions.NotFound: + return None return self.client.delete_dataset(dataset_name) - ## TODO(lwander): why multiple input GCS files? why not bq? def import_data(self, dataset=None, dataset_display_name=None, dataset_name=None, gcs_input_uris=None, bigquery_input_uri=None, project=None, region=None): @@ -355,8 +502,10 @@ def import_data(self, dataset=None, dataset_display_name=None, Example: >>> from google.cloud import automl_v1beta1 >>> - >>> client = automl_v1beta1.tables.ClientHelper( - ... client=automl_v1beta1.AutoMlClient(), + >>> from google.oauth2 import service_account + >>> + >>> client = automl_v1beta1.TablesClient( + ... credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json') ... project='my-project', region='us-central1') ... >>> d = client.create_dataset(dataset_display_name='my_dataset') @@ -444,8 +593,10 @@ def list_table_specs(self, dataset=None, dataset_display_name=None, Example: >>> from google.cloud import automl_v1beta1 >>> - >>> client = automl_v1beta1.tables.ClientHelper( - ... client=automl_v1beta1.AutoMlClient(), + >>> from google.oauth2 import service_account + >>> + >>> client = automl_v1beta1.TablesClient( + ... credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json') ... project='my-project', region='us-central1') ... >>> for s in client.list_table_specs(dataset_display_name='my_dataset') @@ -505,8 +656,10 @@ def list_column_specs(self, dataset=None, dataset_display_name=None, Example: >>> from google.cloud import automl_v1beta1 >>> - >>> client = automl_v1beta1.tables.ClientHelper( - ... client=automl_v1beta1.AutoMlClient(), + >>> from google.oauth2 import service_account + >>> + >>> client = automl_v1beta1.TablesClient( + ... credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json') ... project='my-project', region='us-central1') ... >>> for s in client.list_column_specs(dataset_display_name='my_dataset') @@ -585,8 +738,10 @@ def update_column_spec(self, dataset=None, dataset_display_name=None, Example: >>> from google.cloud import automl_v1beta1 >>> - >>> client = automl_v1beta1.tables.ClientHelper( - ... client=automl_v1beta1.AutoMlClient(), + >>> from google.oauth2 import service_account + >>> + >>> client = automl_v1beta1.TablesClient( + ... credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json') ... project='my-project', region='us-central1') ... >>> client.update_column_specs(dataset_display_name='my_dataset', @@ -663,6 +818,8 @@ def update_column_spec(self, dataset=None, dataset_display_name=None, # type code must always be set if type_code is None: + # this index is safe, we would have already thrown a NotFound + # had the column_spec_name not existed type_code = {s.name: s for s in self.list_column_specs( dataset=dataset, dataset_display_name=dataset_display_name, @@ -695,8 +852,10 @@ def set_target_column(self, dataset=None, dataset_display_name=None, Example: >>> from google.cloud import automl_v1beta1 >>> - >>> client = automl_v1beta1.tables.ClientHelper( - ... 
client=automl_v1beta1.AutoMlClient(), + >>> from google.oauth2 import service_account + >>> + >>> client = automl_v1beta1.TablesClient( + ... credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json') ... project='my-project', region='us-central1') ... >>> client.set_target_column(dataset_display_name='my_dataset', @@ -774,17 +933,19 @@ def set_target_column(self, dataset=None, dataset_display_name=None, ) column_spec_id = column_spec_name.rsplit('/', 1)[-1] - dataset_name = self.__dataset_name_from_args(dataset=dataset, + dataset = self.__dataset_from_args(dataset=dataset, dataset_name=dataset_name, dataset_display_name=dataset_display_name, project=project, region=region) + metadata = dataset.tables_dataset_metadata + metadata = self.__update_metadata(metadata, + 'target_column_spec_id', + column_spec_id) request = { - 'name': dataset_name, - 'tables_dataset_metadata': { - 'target_column_spec_id': column_spec_id - } + 'name': dataset.name, + 'tables_dataset_metadata': metadata, } return self.client.update_dataset(request) @@ -798,8 +959,10 @@ def set_weight_column(self, dataset=None, dataset_display_name=None, Example: >>> from google.cloud import automl_v1beta1 >>> - >>> client = automl_v1beta1.tables.ClientHelper( - ... client=automl_v1beta1.AutoMlClient(), + >>> from google.oauth2 import service_account + >>> + >>> client = automl_v1beta1.TablesClient( + ... credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json') ... project='my-project', region='us-central1') ... >>> client.set_weight_column(dataset_display_name='my_dataset', @@ -877,17 +1040,19 @@ def set_weight_column(self, dataset=None, dataset_display_name=None, ) column_spec_id = column_spec_name.rsplit('/', 1)[-1] - dataset_name = self.__dataset_name_from_args(dataset=dataset, + dataset = self.__dataset_from_args(dataset=dataset, dataset_name=dataset_name, dataset_display_name=dataset_display_name, project=project, region=region) + metadata = dataset.tables_dataset_metadata + metadata = self.__update_metadata(metadata, + 'weight_column_spec_id', + column_spec_id) request = { - 'name': dataset_name, - 'tables_dataset_metadata': { - 'weight_column_spec_id': column_spec_id - } + 'name': dataset.name, + 'tables_dataset_metadata': metadata, } return self.client.update_dataset(request) @@ -902,8 +1067,10 @@ def set_test_train_column(self, dataset=None, dataset_display_name=None, Example: >>> from google.cloud import automl_v1beta1 >>> - >>> client = automl_v1beta1.tables.ClientHelper( - ... client=automl_v1beta1.AutoMlClient(), + >>> from google.oauth2 import service_account + >>> + >>> client = automl_v1beta1.TablesClient( + ... credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json') ... project='my-project', region='us-central1') ... 
             >>> client.set_test_train_column(dataset_display_name='my_dataset',
@@ -981,17 +1148,17 @@ def set_test_train_column(self, dataset=None, dataset_display_name=None,
         )
         column_spec_id = column_spec_name.rsplit('/', 1)[-1]
 
-        dataset_name = self.__dataset_name_from_args(dataset=dataset,
+        dataset = self.__dataset_from_args(dataset=dataset,
                 dataset_name=dataset_name,
                 dataset_display_name=dataset_display_name,
                 project=project,
                 region=region)
+        metadata = dataset.tables_dataset_metadata
+        metadata = self.__update_metadata(metadata, 'ml_use_column_spec_id', column_spec_id)
 
         request = {
-            'name': dataset_name,
-            'tables_dataset_metadata': {
-                'ml_use_column_spec_id': column_spec_id
-            }
+            'name': dataset.name,
+            'tables_dataset_metadata': metadata,
         }
 
         return self.client.update_dataset(request)
@@ -1002,8 +1169,10 @@ def list_models(self, project=None, region=None):
         Example:
             >>> from google.cloud import automl_v1beta1
             >>>
-            >>> client = automl_v1beta1.tables.ClientHelper(
-            ...     client=automl_v1beta1.AutoMlClient(),
+            >>> from google.oauth2 import service_account
+            >>>
+            >>> client = automl_v1beta1.TablesClient(
+            ...     credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json'),
             ...     project='my-project', region='us-central1')
             ...
             >>> ms = client.list_models()
@@ -1043,20 +1212,25 @@ def list_models(self, project=None, region=None):
     def create_model(self, model_display_name, dataset=None,
             dataset_display_name=None, dataset_name=None,
             train_budget_milli_node_hours=None, project=None,
-            region=None, input_feature_column_specs_included=None, input_feature_column_specs_excluded=None):
-        """Create a model. This will train your model on the given dataset.
+            region=None, model_metadata={},
+            include_column_spec_names=None,
+            exclude_column_spec_names=None):
+        """Create a model. This will train your model on the given dataset.
 
         Example:
             >>> from google.cloud import automl_v1beta1
             >>>
-            >>> client = automl_v1beta1.tables.ClientHelper(
-            ...     client=automl_v1beta1.AutoMlClient(),
+            >>> from google.oauth2 import service_account
+            >>>
+            >>> client = automl_v1beta1.TablesClient(
+            ...     credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json'),
             ...     project='my-project', region='us-central1')
             ...
             >>> m = client.create_model('my_model', dataset_display_name='my_dataset')
             >>>
             >>> m.result()  # blocks on result
             >>>
+
         Args:
             project (Optional[string]):
@@ -1084,12 +1258,14 @@ def create_model(self, model_display_name, dataset=None,
                 The `Dataset` instance you want to train your model on. This
                 must be supplied if `dataset_display_name` or `dataset_name`
                 are not supplied.
-            input_feature_column_specs_included (Optional[List[string]]):
-                The list of the display names of the columns you want to
-                train your model on.
-            input_feature_column_specs_excluded (Optional[List[string]]):
-                The list of the display names of the columns you want to
-                exclude from training.
+            model_metadata (Optional[Dict]):
+                Optional model metadata to supply to the client.
+            include_column_spec_names (Optional[List[string]]):
+                The list of the display names of the columns you want to
+                train your model on.
+            exclude_column_spec_names (Optional[List[string]]):
+                The list of the display names of the columns you want to
+                exclude from training.
         Returns:
             A :class:`~google.cloud.automl_v1beta1.types._OperationFuture`
             instance.
@@ -1100,60 +1276,56 @@ def create_model(self, model_display_name, dataset=None,
                 to a retryable error and retry attempts failed.
             ValueError: If required parameters are missing.
         """
-        if train_budget_milli_node_hours is None:
+        if (train_budget_milli_node_hours is None or
+                train_budget_milli_node_hours < 1000 or
+                train_budget_milli_node_hours > 72000):
             raise ValueError('\'train_budget_milli_node_hours\' must be a '
                     'value between 1,000 and 72,000 inclusive')
 
-        if input_feature_column_specs_excluded not in [None, []] and input_feature_column_specs_included not in [None, []]:
-            raise ValueError('Cannot set both \'input_feature_column_specs_excluded\' '
-                    'and \'input_feature_column_specs_included\'')
-
+        if (exclude_column_spec_names not in [None, []] and
+                include_column_spec_names not in [None, []]):
+            raise ValueError('Cannot set both '
+                    '\'exclude_column_spec_names\' and '
+                    '\'include_column_spec_names\'')
 
         dataset_name = self.__dataset_name_from_args(dataset=dataset,
-                dataset_name=dataset_name,
-                dataset_display_name=dataset_display_name,
-                project=project,
-                region=region)
-        tables_model_metadata = {
-            'train_budget_milli_node_hours': train_budget_milli_node_hours
-        }
+                dataset_name=dataset_name,
+                dataset_display_name=dataset_display_name,
+                project=project,
+                region=region)
+
+        model_metadata = dict(model_metadata, train_budget_milli_node_hours=train_budget_milli_node_hours)
+
         dataset_id = dataset_name.rsplit('/', 1)[-1]
-        columns = [s for s in self.list_column_specs(dataset=dataset, dataset_name=dataset_name, dataset_display_name=dataset_display_name)]
+        columns = [s for s in self.list_column_specs(dataset=dataset,
+                dataset_name=dataset_name,
+                dataset_display_name=dataset_display_name)]
 
         final_columns = []
-        if input_feature_column_specs_included:
-            column_names = [a.display_name for a in columns]
-            if not all(name in column_names for name in input_feature_column_specs_included):
-                raise ValueError('Invalid name in the list \'input_feature_column_specs_included\'')
-            for a in columns:
-                if a.display_name in input_feature_column_specs_included:
-                    final_columns.append(a)
+        if include_column_spec_names:
+            for c in columns:
+                if c.display_name in include_column_spec_names:
+                    final_columns.append(c)
 
-            tables_model_metadata.update(
-                {'input_feature_column_specs': final_columns}
-            )
-        elif input_feature_column_specs_excluded:
+            model_metadata['input_feature_column_specs'] = final_columns
+        elif exclude_column_spec_names:
             for a in columns:
-                if a.display_name not in input_feature_column_specs_excluded:
+                if a.display_name not in exclude_column_spec_names:
                     final_columns.append(a)
 
-            tables_model_metadata.update(
-                {'input_feature_column_specs': final_columns}
-            )
+            model_metadata['input_feature_column_specs'] = final_columns
 
         request = {
-            'display_name': model_display_name,
-            'dataset_id': dataset_id,
-            'tables_model_metadata': tables_model_metadata
+            'display_name': model_display_name,
+            'dataset_id': dataset_id,
+            'tables_model_metadata': model_metadata
         }
-
         return self.client.create_model(
-                self.__location_path(project=project, region=region),
-                request
+            self.__location_path(project=project, region=region),
+            request
         )
-
     def delete_model(self, model=None, model_display_name=None,
             model_name=None, project=None, region=None):
         """Deletes a model. Note this will not delete any datasets associated
@@ -1162,8 +1334,10 @@ def delete_model(self, model=None, model_display_name=None,
         Example:
             >>> from google.cloud import automl_v1beta1
             >>>
-            >>> client = automl_v1beta1.tables.ClientHelper(
-            ...     client=automl_v1beta1.AutoMlClient(),
+            >>> from google.oauth2 import service_account
+            >>>
+            >>> client = automl_v1beta1.TablesClient(
+            ...     credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json'),
             ...     project='my-project', region='us-central1')
             ...
             >>> op = client.delete_model(model_display_name='my_model')
@@ -1204,11 +1378,15 @@ def delete_model(self, model=None, model_display_name=None,
                 to a retryable error and retry attempts failed.
             ValueError: If required parameters are missing.
         """
-        model_name = self.__model_name_from_args(model=model,
-                model_name=model_name,
-                model_display_name=model_display_name,
-                project=project,
-                region=region)
+        try:
+            model_name = self.__model_name_from_args(model=model,
+                    model_name=model_name,
+                    model_display_name=model_display_name,
+                    project=project,
+                    region=region)
+        # delete is idempotent
+        except exceptions.NotFound:
+            return None
 
         return self.client.delete_model(model_name)
 
@@ -1219,8 +1397,10 @@ def get_model(self, project=None, region=None,
         Example:
             >>> from google.cloud import automl_v1beta1
             >>>
-            >>> client = automl_v1beta1.tables.ClientHelper(
-            ...     client=automl_v1beta1.AutoMlClient(),
+            >>> from google.oauth2 import service_account
+            >>>
+            >>> client = automl_v1beta1.TablesClient(
+            ...     credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json'),
             ...     project='my-project', region='us-central1')
             ...
             >>> d = client.get_model(model_display_name='my_model')
@@ -1260,9 +1440,17 @@ def get_model(self, project=None, region=None,
             raise ValueError('One of \'model_name\' or '
                     '\'model_display_name\' must be set.')
 
-        return next(m for m in self.list_models(project, region)
-                if m.name == model_name
-                or m.display_name == model_display_name)
+        if model_name is not None:
+            return self.client.get_model(model_name)
+
+        model = next((d for d in self.list_models(project, region)
+                if d.display_name == model_display_name), None)
+
+        if model is None:
+            raise exceptions.NotFound('No model with model_display_name: ' +
+                '\'{}\' found'.format(model_display_name))
+
+        return model
 
     #TODO(jonathanskim): allow deployment from just model ID
     def deploy_model(self, model=None, model_name=None,
@@ -1273,8 +1461,10 @@ def deploy_model(self, model=None, model_name=None,
         Example:
             >>> from google.cloud import automl_v1beta1
             >>>
-            >>> client = automl_v1beta1.tables.ClientHelper(
-            ...     client=automl_v1beta1.AutoMlClient(),
+            >>> from google.oauth2 import service_account
+            >>>
+            >>> client = automl_v1beta1.TablesClient(
+            ...     credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json'),
             ...     project='my-project', region='us-central1')
             ...
             >>> op = client.deploy_model(model_display_name='my_model')
@@ -1332,8 +1522,10 @@ def undeploy_model(self, model=None, model_name=None,
         Example:
             >>> from google.cloud import automl_v1beta1
            >>>
-            >>> client = automl_v1beta1.tables.ClientHelper(
-            ...     client=automl_v1beta1.AutoMlClient(),
+            >>> from google.oauth2 import service_account
+            >>>
+            >>> client = automl_v1beta1.TablesClient(
+            ...     credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json'),
             ...     project='my-project', region='us-central1')
             ...
             >>> op = client.undeploy_model(model_display_name='my_model')
@@ -1393,8 +1585,10 @@ def predict(self, inputs, model=None, model_name=None,
         Example:
             >>> from google.cloud import automl_v1beta1
             >>>
-            >>> client = automl_v1beta1.tables.ClientHelper(
-            ...     prediction_client=automl_v1beta1.PredictionServiceClient(),
+            >>> from google.oauth2 import service_account
+            >>>
+            >>> client = automl_v1beta1.TablesClient(
+            ...     credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json'),
             ...     project='my-project', region='us-central1')
             ...
             >>> client.predict(inputs={'Age': 30, 'Income': 12, 'Category': 'A'},
             ...     model_display_name='my_model')
@@ -1439,13 +1633,13 @@ def predict(self, inputs, model=None, model_name=None,
                 to a retryable error and retry attempts failed.
             ValueError: If required parameters are missing.
         """
-        if model is None:
-            model = self.get_model(
-                model_name=model_name,
-                model_display_name=model_display_name,
-                project=project,
-                region=region
-            )
+        model = self.__model_from_args(
+            model=model,
+            model_name=model_name,
+            model_display_name=model_display_name,
+            project=project,
+            region=region
+        )
 
         column_specs = model.tables_model_metadata.input_feature_column_specs
         if type(inputs) == dict:
@@ -1479,8 +1675,10 @@ def batch_predict(self, gcs_input_uris, gcs_output_uri_prefix,
         Example:
             >>> from google.cloud import automl_v1beta1
             >>>
-            >>> client = automl_v1beta1.tables.ClientHelper(
-            ...     prediction_client=automl_v1beta1.PredictionServiceClient(),
+            >>> from google.oauth2 import service_account
+            >>>
+            >>> client = automl_v1beta1.TablesClient(
+            ...     credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json'),
             ...     project='my-project', region='us-central1')
             ...
             >>> client.batch_predict(
diff --git a/automl/tests/unit/gapic/v1beta1/test_tables_client_v1beta1.py b/automl/tests/unit/gapic/v1beta1/test_tables_client_v1beta1.py
new file mode 100644
index 000000000000..5e07184bc906
--- /dev/null
+++ b/automl/tests/unit/gapic/v1beta1/test_tables_client_v1beta1.py
@@ -0,0 +1,1055 @@
+# -*- coding: utf-8 -*-
+#
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
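+
+# The tests below construct the wrapper around mock.Mock stand-ins for the
+# GAPIC AutoMlClient and PredictionServiceClient, so no RPCs are issued; the
+# assertions verify the request payloads the wrapper builds for each call.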
+ +"""Unit tests.""" + +import mock +import pytest + +from google.cloud import automl_v1beta1 +from google.api_core import exceptions +from google.cloud.automl_v1beta1.proto import data_types_pb2 + +PROJECT='project' +REGION='region' +LOCATION_PATH='projects/{}/locations/{}'.format(PROJECT, REGION) + +class TestTablesClient(object): + + def tables_client(self, client_attrs={}, + prediction_client_attrs={}): + client_mock = mock.Mock(**client_attrs) + prediction_client_mock = mock.Mock(**prediction_client_attrs) + return automl_v1beta1.TablesClient(client=client_mock, + prediction_client=prediction_client_mock, + project=PROJECT, region=REGION) + + def test_list_datasets_empty(self): + client = self.tables_client({ + 'list_datasets.return_value': [], + 'location_path.return_value': LOCATION_PATH, + }, {}) + ds = client.list_datasets() + client.client.location_path.assert_called_with(PROJECT, REGION) + client.client.list_datasets.assert_called_with(LOCATION_PATH) + assert ds == [] + + def test_list_datasets_not_empty(self): + datasets = ['some_dataset'] + client = self.tables_client({ + 'list_datasets.return_value': datasets, + 'location_path.return_value': LOCATION_PATH, + }, {}) + ds = client.list_datasets() + client.client.location_path.assert_called_with(PROJECT, REGION) + client.client.list_datasets.assert_called_with(LOCATION_PATH) + assert len(ds) == 1 + assert ds[0] == 'some_dataset' + + def test_get_dataset_no_value(self): + dataset_actual = 'dataset' + client = self.tables_client({}, {}) + error = None + try: + dataset = client.get_dataset() + except ValueError as e: + error = e + assert error is not None + client.client.get_dataset.assert_not_called() + + def test_get_dataset_name(self): + dataset_actual = 'dataset' + client = self.tables_client({ + 'get_dataset.return_value': dataset_actual + }, {}) + dataset = client.get_dataset(dataset_name='my_dataset') + client.client.get_dataset.assert_called_with('my_dataset') + assert dataset == dataset_actual + + def test_get_no_dataset(self): + client = self.tables_client({ + 'get_dataset.side_effect': exceptions.NotFound('err') + }, {}) + error = None + try: + client.get_dataset(dataset_name='my_dataset') + except exceptions.NotFound as e: + error = e + assert error is not None + client.client.get_dataset.assert_called_with('my_dataset') + + def test_get_dataset_from_empty_list(self): + client = self.tables_client({'list_datasets.return_value': []}, {}) + error = None + try: + client.get_dataset(dataset_display_name='my_dataset') + except exceptions.NotFound as e: + error = e + assert error is not None + + def test_get_dataset_from_list_not_found(self): + client = self.tables_client({ + 'list_datasets.return_value': [mock.Mock(display_name='not_it')] + }, {}) + error = None + try: + client.get_dataset(dataset_display_name='my_dataset') + except exceptions.NotFound as e: + error = e + assert error is not None + + def test_get_dataset_from_list(self): + client = self.tables_client({ + 'list_datasets.return_value': [ + mock.Mock(display_name='not_it'), + mock.Mock(display_name='my_dataset'), + ] + }, {}) + dataset = client.get_dataset(dataset_display_name='my_dataset') + assert dataset.display_name == 'my_dataset' + + def test_create_dataset(self): + client = self.tables_client({ + 'location_path.return_value': LOCATION_PATH, + 'create_dataset.return_value': mock.Mock(display_name='name'), + }, {}) + metadata = {'metadata': 'values'} + dataset = client.create_dataset('name', metadata=metadata) + 
client.client.location_path.assert_called_with(PROJECT, REGION) + client.client.create_dataset.assert_called_with( + LOCATION_PATH, + {'display_name': 'name', 'tables_dataset_metadata': metadata} + ) + assert dataset.display_name == 'name' + + def test_delete_dataset(self): + dataset = mock.Mock() + dataset.configure_mock(name='name') + client = self.tables_client({ + 'delete_dataset.return_value': None, + }, {}) + client.delete_dataset(dataset=dataset) + client.client.delete_dataset.assert_called_with('name') + + def test_delete_dataset_not_found(self): + client = self.tables_client({ + 'list_datasets.return_value': [], + }, {}) + client.delete_dataset(dataset_display_name='not_found') + client.client.delete_dataset.assert_not_called() + + def test_delete_dataset_name(self): + client = self.tables_client({ + 'delete_dataset.return_value': None, + }, {}) + client.delete_dataset(dataset_name='name') + client.client.delete_dataset.assert_called_with('name') + + def test_import_not_found(self): + client = self.tables_client({ + 'list_datasets.return_value': [], + }, {}) + error = None + try: + client.import_data(dataset_display_name='name', gcs_input_uris='uri') + except exceptions.NotFound as e: + error = e + assert error is not None + + client.client.import_data.assert_not_called() + + def test_import_gcs_uri(self): + client = self.tables_client({ + 'import_data.return_value': None, + }, {}) + client.import_data(dataset_name='name', gcs_input_uris='uri') + client.client.import_data.assert_called_with('name', { + 'gcs_source': {'input_uris': ['uri']} + }) + + def test_import_gcs_uris(self): + client = self.tables_client({ + 'import_data.return_value': None, + }, {}) + client.import_data(dataset_name='name', + gcs_input_uris=['uri', 'uri']) + client.client.import_data.assert_called_with('name', { + 'gcs_source': {'input_uris': ['uri', 'uri']} + }) + + def test_import_bq_uri(self): + client = self.tables_client({ + 'import_data.return_value': None, + }, {}) + client.import_data(dataset_name='name', + bigquery_input_uri='uri') + client.client.import_data.assert_called_with('name', { + 'bigquery_source': {'input_uri': 'uri'} + }) + + def test_list_table_specs(self): + client = self.tables_client({ + 'list_table_specs.return_value': None, + }, {}) + client.list_table_specs(dataset_name='name') + client.client.list_table_specs.assert_called_with('name') + + def test_list_table_specs_not_found(self): + client = self.tables_client({ + 'list_table_specs.side_effect': exceptions.NotFound('not found'), + }, {}) + error = None + try: + client.list_table_specs(dataset_name='name') + except exceptions.NotFound as e: + error = e + assert error is not None + client.client.list_table_specs.assert_called_with('name') + + def test_list_column_specs(self): + table_spec_mock = mock.Mock() + # name is reserved in use of __init__, needs to be passed here + table_spec_mock.configure_mock(name='table') + client = self.tables_client({ + 'list_table_specs.return_value': [table_spec_mock], + 'list_column_specs.return_value': [], + }, {}) + client.list_column_specs(dataset_name='name') + client.client.list_table_specs.assert_called_with('name') + client.client.list_column_specs.assert_called_with('table') + + def test_update_column_spec_not_found(self): + table_spec_mock = mock.Mock() + # name is reserved in use of __init__, needs to be passed here + table_spec_mock.configure_mock(name='table') + column_spec_mock = mock.Mock() + data_type_mock = mock.Mock(type_code='type_code') + 
column_spec_mock.configure_mock(name='column', display_name='column', + data_type=data_type_mock) + client = self.tables_client({ + 'list_table_specs.return_value': [table_spec_mock], + 'list_column_specs.return_value': [column_spec_mock], + }, {}) + error = None + try: + client.update_column_spec(dataset_name='name', + column_spec_name='column2') + except exceptions.NotFound as e: + error = e + assert error is not None + client.client.list_table_specs.assert_called_with('name') + client.client.list_column_specs.assert_called_with('table') + client.client.update_column_spec.assert_not_called() + + def test_update_column_spec_display_name_not_found(self): + table_spec_mock = mock.Mock() + # name is reserved in use of __init__, needs to be passed here + table_spec_mock.configure_mock(name='table') + column_spec_mock = mock.Mock() + data_type_mock = mock.Mock(type_code='type_code') + column_spec_mock.configure_mock(name='column', display_name='column', + data_type=data_type_mock) + client = self.tables_client({ + 'list_table_specs.return_value': [table_spec_mock], + 'list_column_specs.return_value': [column_spec_mock], + }, {}) + error = None + try: + client.update_column_spec(dataset_name='name', + column_spec_display_name='column2') + except exceptions.NotFound as e: + error = e + assert error is not None + client.client.list_table_specs.assert_called_with('name') + client.client.list_column_specs.assert_called_with('table') + client.client.update_column_spec.assert_not_called() + + def test_update_column_spec_name_no_args(self): + table_spec_mock = mock.Mock() + # name is reserved in use of __init__, needs to be passed here + table_spec_mock.configure_mock(name='table') + column_spec_mock = mock.Mock() + data_type_mock = mock.Mock(type_code='type_code') + column_spec_mock.configure_mock(name='column/2', display_name='column', + data_type=data_type_mock) + client = self.tables_client({ + 'list_table_specs.return_value': [table_spec_mock], + 'list_column_specs.return_value': [column_spec_mock], + }, {}) + client.update_column_spec(dataset_name='name', + column_spec_name='column/2') + client.client.list_table_specs.assert_called_with('name') + client.client.list_column_specs.assert_called_with('table') + client.client.update_column_spec.assert_called_with({ + 'name': 'column/2', + 'data_type': { + 'type_code': 'type_code', + } + }) + + def test_update_column_spec_no_args(self): + table_spec_mock = mock.Mock() + # name is reserved in use of __init__, needs to be passed here + table_spec_mock.configure_mock(name='table') + column_spec_mock = mock.Mock() + data_type_mock = mock.Mock(type_code='type_code') + column_spec_mock.configure_mock(name='column', display_name='column', + data_type=data_type_mock) + client = self.tables_client({ + 'list_table_specs.return_value': [table_spec_mock], + 'list_column_specs.return_value': [column_spec_mock], + }, {}) + client.update_column_spec(dataset_name='name', + column_spec_display_name='column') + client.client.list_table_specs.assert_called_with('name') + client.client.list_column_specs.assert_called_with('table') + client.client.update_column_spec.assert_called_with({ + 'name': 'column', + 'data_type': { + 'type_code': 'type_code', + } + }) + + def test_update_column_spec_nullable(self): + table_spec_mock = mock.Mock() + # name is reserved in use of __init__, needs to be passed here + table_spec_mock.configure_mock(name='table') + column_spec_mock = mock.Mock() + data_type_mock = mock.Mock(type_code='type_code') + 
column_spec_mock.configure_mock(name='column', display_name='column', + data_type=data_type_mock) + client = self.tables_client({ + 'list_table_specs.return_value': [table_spec_mock], + 'list_column_specs.return_value': [column_spec_mock], + }, {}) + client.update_column_spec(dataset_name='name', + column_spec_display_name='column', nullable=True) + client.client.list_table_specs.assert_called_with('name') + client.client.list_column_specs.assert_called_with('table') + client.client.update_column_spec.assert_called_with({ + 'name': 'column', + 'data_type': { + 'type_code': 'type_code', + 'nullable': True, + } + }) + + def test_update_column_spec_type_code(self): + table_spec_mock = mock.Mock() + # name is reserved in use of __init__, needs to be passed here + table_spec_mock.configure_mock(name='table') + column_spec_mock = mock.Mock() + data_type_mock = mock.Mock(type_code='type_code') + column_spec_mock.configure_mock(name='column', display_name='column', + data_type=data_type_mock) + client = self.tables_client({ + 'list_table_specs.return_value': [table_spec_mock], + 'list_column_specs.return_value': [column_spec_mock], + }, {}) + client.update_column_spec(dataset_name='name', + column_spec_display_name='column', type_code='type_code2') + client.client.list_table_specs.assert_called_with('name') + client.client.list_column_specs.assert_called_with('table') + client.client.update_column_spec.assert_called_with({ + 'name': 'column', + 'data_type': { + 'type_code': 'type_code2', + } + }) + + def test_update_column_spec_type_code_nullable(self): + table_spec_mock = mock.Mock() + # name is reserved in use of __init__, needs to be passed here + table_spec_mock.configure_mock(name='table') + column_spec_mock = mock.Mock() + data_type_mock = mock.Mock(type_code='type_code') + column_spec_mock.configure_mock(name='column', display_name='column', + data_type=data_type_mock) + client = self.tables_client({ + 'list_table_specs.return_value': [table_spec_mock], + 'list_column_specs.return_value': [column_spec_mock], + }, {}) + client.update_column_spec(dataset_name='name', nullable=True, + column_spec_display_name='column', type_code='type_code2') + client.client.list_table_specs.assert_called_with('name') + client.client.list_column_specs.assert_called_with('table') + client.client.update_column_spec.assert_called_with({ + 'name': 'column', + 'data_type': { + 'type_code': 'type_code2', + 'nullable': True, + } + }) + + def test_update_column_spec_type_code_nullable_false(self): + table_spec_mock = mock.Mock() + # name is reserved in use of __init__, needs to be passed here + table_spec_mock.configure_mock(name='table') + column_spec_mock = mock.Mock() + data_type_mock = mock.Mock(type_code='type_code') + column_spec_mock.configure_mock(name='column', display_name='column', + data_type=data_type_mock) + client = self.tables_client({ + 'list_table_specs.return_value': [table_spec_mock], + 'list_column_specs.return_value': [column_spec_mock], + }, {}) + client.update_column_spec(dataset_name='name', nullable=False, + column_spec_display_name='column', type_code='type_code2') + client.client.list_table_specs.assert_called_with('name') + client.client.list_column_specs.assert_called_with('table') + client.client.update_column_spec.assert_called_with({ + 'name': 'column', + 'data_type': { + 'type_code': 'type_code2', + 'nullable': False, + } + }) + + def test_set_target_column_table_not_found(self): + client = self.tables_client({ + 'list_table_specs.side_effect': exceptions.NotFound('err'), + }, {}) + 
error = None
+        try:
+            client.set_target_column(dataset_name='name',
+                    column_spec_display_name='column2')
+        except exceptions.NotFound as e:
+            error = e
+        assert error is not None
+        client.client.list_table_specs.assert_called_with('name')
+        client.client.list_column_specs.assert_not_called()
+        client.client.update_dataset.assert_not_called()
+
+    def test_set_target_column_not_found(self):
+        table_spec_mock = mock.Mock()
+        # name is reserved in use of __init__, needs to be passed here
+        table_spec_mock.configure_mock(name='table')
+        column_spec_mock = mock.Mock()
+        column_spec_mock.configure_mock(name='column/1', display_name='column')
+        client = self.tables_client({
+            'list_table_specs.return_value': [table_spec_mock],
+            'list_column_specs.return_value': [column_spec_mock],
+        }, {})
+        error = None
+        try:
+            client.set_target_column(dataset_name='name',
+                    column_spec_display_name='column2')
+        except exceptions.NotFound as e:
+            error = e
+        assert error is not None
+        client.client.list_table_specs.assert_called_with('name')
+        client.client.list_column_specs.assert_called_with('table')
+        client.client.update_dataset.assert_not_called()
+
+    def test_set_target_column(self):
+        table_spec_mock = mock.Mock()
+        # name is reserved in use of __init__, needs to be passed here
+        table_spec_mock.configure_mock(name='table')
+        column_spec_mock = mock.Mock()
+        column_spec_mock.configure_mock(name='column/1', display_name='column')
+        dataset_mock = mock.Mock()
+        tables_dataset_metadata_mock = mock.Mock()
+        tables_dataset_metadata_mock.configure_mock(target_column_spec_id='2',
+                weight_column_spec_id='2',
+                ml_use_column_spec_id='3')
+        dataset_mock.configure_mock(name='dataset',
+                tables_dataset_metadata=tables_dataset_metadata_mock)
+        client = self.tables_client({
+            'get_dataset.return_value': dataset_mock,
+            'list_table_specs.return_value': [table_spec_mock],
+            'list_column_specs.return_value': [column_spec_mock],
+        }, {})
+        client.set_target_column(dataset_name='name',
+                column_spec_display_name='column')
+        client.client.list_table_specs.assert_called_with('name')
+        client.client.list_column_specs.assert_called_with('table')
+        client.client.update_dataset.assert_called_with({
+            'name': 'dataset',
+            'tables_dataset_metadata': {
+                'target_column_spec_id': '1',
+                'weight_column_spec_id': '2',
+                'ml_use_column_spec_id': '3',
+            }
+        })
+
+    def test_set_weight_column_table_not_found(self):
+        client = self.tables_client({
+            'list_table_specs.side_effect': exceptions.NotFound('err'),
+        }, {})
+        error = None
+        try:
+            client.set_weight_column(dataset_name='name',
+                    column_spec_display_name='column2')
+        except exceptions.NotFound as e:
+            error = e
+        assert error is not None
+        client.client.list_table_specs.assert_called_with('name')
+        client.client.list_column_specs.assert_not_called()
+        client.client.update_dataset.assert_not_called()
+
+    def test_set_weight_column_not_found(self):
+        table_spec_mock = mock.Mock()
+        # name is reserved in use of __init__, needs to be passed here
+        table_spec_mock.configure_mock(name='table')
+        column_spec_mock = mock.Mock()
+        column_spec_mock.configure_mock(name='column/1', display_name='column')
+        client = self.tables_client({
+            'list_table_specs.return_value': [table_spec_mock],
+            'list_column_specs.return_value': [column_spec_mock],
+        }, {})
+        error = None
+        try:
+            client.set_weight_column(dataset_name='name',
+                    column_spec_display_name='column2')
+        except exceptions.NotFound as e:
+            error = e
+        assert error is not None
+        client.client.list_table_specs.assert_called_with('name')
+        
client.client.list_column_specs.assert_called_with('table') + client.client.update_dataset.assert_not_called() + + def test_set_weight_column(self): + table_spec_mock = mock.Mock() + # name is reserved in use of __init__, needs to be passed here + table_spec_mock.configure_mock(name='table') + column_spec_mock = mock.Mock() + column_spec_mock.configure_mock(name='column/2', display_name='column') + dataset_mock = mock.Mock() + tables_dataset_metadata_mock = mock.Mock() + tables_dataset_metadata_mock.configure_mock(target_column_spec_id='1', + weight_column_spec_id='1', + ml_use_column_spec_id='3') + dataset_mock.configure_mock(name='dataset', + tables_dataset_metadata=tables_dataset_metadata_mock) + client = self.tables_client({ + 'get_dataset.return_value': dataset_mock, + 'list_table_specs.return_value': [table_spec_mock], + 'list_column_specs.return_value': [column_spec_mock], + }, {}) + client.set_weight_column(dataset_name='name', + column_spec_display_name='column') + client.client.list_table_specs.assert_called_with('name') + client.client.list_column_specs.assert_called_with('table') + client.client.update_dataset.assert_called_with({ + 'name': 'dataset', + 'tables_dataset_metadata': { + 'target_column_spec_id': '1', + 'weight_column_spec_id': '2', + 'ml_use_column_spec_id': '3', + } + }) + + def test_set_test_train_column_table_not_found(self): + client = self.tables_client({ + 'list_table_specs.side_effect': exceptions.NotFound('err'), + }, {}) + error = None + try: + client.set_test_train_column(dataset_name='name', + column_spec_display_name='column2') + except exceptions.NotFound as e: + error = e + assert error is not None + client.client.list_table_specs.assert_called_with('name') + client.client.list_column_specs.assert_not_called() + client.client.update_dataset.assert_not_called() + + def test_set_test_train_column_not_found(self): + table_spec_mock = mock.Mock() + # name is reserved in use of __init__, needs to be passed here + table_spec_mock.configure_mock(name='table') + column_spec_mock = mock.Mock() + column_spec_mock.configure_mock(name='column/1', display_name='column') + client = self.tables_client({ + 'list_table_specs.return_value': [table_spec_mock], + 'list_column_specs.return_value': [column_spec_mock], + }, {}) + error = None + try: + client.set_test_train_column(dataset_name='name', + column_spec_display_name='column2') + except exceptions.NotFound as e: + error = e + assert error is not None + client.client.list_table_specs.assert_called_with('name') + client.client.list_column_specs.assert_called_with('table') + client.client.update_dataset.assert_not_called() + + def test_set_test_train_column(self): + table_spec_mock = mock.Mock() + # name is reserved in use of __init__, needs to be passed here + table_spec_mock.configure_mock(name='table') + column_spec_mock = mock.Mock() + column_spec_mock.configure_mock(name='column/3', display_name='column') + dataset_mock = mock.Mock() + tables_dataset_metadata_mock = mock.Mock() + tables_dataset_metadata_mock.configure_mock(target_column_spec_id='1', + weight_column_spec_id='2', + ml_use_column_spec_id='2') + dataset_mock.configure_mock(name='dataset', + tables_dataset_metadata=tables_dataset_metadata_mock) + client = self.tables_client({ + 'get_dataset.return_value': dataset_mock, + 'list_table_specs.return_value': [table_spec_mock], + 'list_column_specs.return_value': [column_spec_mock], + }, {}) + client.set_test_train_column(dataset_name='name', + column_spec_display_name='column') + 
client.client.list_table_specs.assert_called_with('name') + client.client.list_column_specs.assert_called_with('table') + client.client.update_dataset.assert_called_with({ + 'name': 'dataset', + 'tables_dataset_metadata': { + 'target_column_spec_id': '1', + 'weight_column_spec_id': '2', + 'ml_use_column_spec_id': '3', + } + }) + + def test_list_models_empty(self): + client = self.tables_client({ + 'list_models.return_value': [], + 'location_path.return_value': LOCATION_PATH, + }, {}) + ds = client.list_models() + client.client.location_path.assert_called_with(PROJECT, REGION) + client.client.list_models.assert_called_with(LOCATION_PATH) + assert ds == [] + + def test_list_models_not_empty(self): + models = ['some_model'] + client = self.tables_client({ + 'list_models.return_value': models, + 'location_path.return_value': LOCATION_PATH, + }, {}) + ds = client.list_models() + client.client.location_path.assert_called_with(PROJECT, REGION) + client.client.list_models.assert_called_with(LOCATION_PATH) + assert len(ds) == 1 + assert ds[0] == 'some_model' + + def test_get_model_name(self): + model_actual = 'model' + client = self.tables_client({ + 'get_model.return_value': model_actual + }, {}) + model = client.get_model(model_name='my_model') + client.client.get_model.assert_called_with('my_model') + assert model == model_actual + + def test_get_no_model(self): + client = self.tables_client({ + 'get_model.side_effect': exceptions.NotFound('err') + }, {}) + error = None + try: + client.get_model(model_name='my_model') + except exceptions.NotFound as e: + error = e + assert error is not None + client.client.get_model.assert_called_with('my_model') + + def test_get_model_from_empty_list(self): + client = self.tables_client({'list_models.return_value': []}, {}) + error = None + try: + client.get_model(model_display_name='my_model') + except exceptions.NotFound as e: + error = e + assert error is not None + + def test_get_model_from_list_not_found(self): + client = self.tables_client({ + 'list_models.return_value': [mock.Mock(display_name='not_it')] + }, {}) + error = None + try: + client.get_model(model_display_name='my_model') + except exceptions.NotFound as e: + error = e + assert error is not None + + def test_get_model_from_list(self): + client = self.tables_client({ + 'list_models.return_value': [ + mock.Mock(display_name='not_it'), + mock.Mock(display_name='my_model'), + ] + }, {}) + model = client.get_model(model_display_name='my_model') + assert model.display_name == 'my_model' + + def test_delete_model(self): + model = mock.Mock() + model.configure_mock(name='name') + client = self.tables_client({ + 'delete_model.return_value': None, + }, {}) + client.delete_model(model=model) + client.client.delete_model.assert_called_with('name') + + def test_delete_model_not_found(self): + client = self.tables_client({ + 'list_models.return_value': [], + }, {}) + client.delete_model(model_display_name='not_found') + client.client.delete_model.assert_not_called() + + def test_delete_model_name(self): + client = self.tables_client({ + 'delete_model.return_value': None, + }, {}) + client.delete_model(model_name='name') + client.client.delete_model.assert_called_with('name') + + def test_deploy_model(self): + client = self.tables_client({}, {}) + client.deploy_model(model_name='name') + client.client.deploy_model.assert_called_with('name') + + def test_deploy_model_not_found(self): + client = self.tables_client({ + 'list_models.return_value': [], + }, {}) + error = None + try: + 
client.deploy_model(model_display_name='name') + except exceptions.NotFound as e: + error = e + assert error is not None + client.client.deploy_model.assert_not_called() + + def test_undeploy_model(self): + client = self.tables_client({}, {}) + client.undeploy_model(model_name='name') + client.client.undeploy_model.assert_called_with('name') + + def test_undeploy_model_not_found(self): + client = self.tables_client({ + 'list_models.return_value': [], + }, {}) + error = None + try: + client.undeploy_model(model_display_name='name') + except exceptions.NotFound as e: + error = e + assert error is not None + client.client.undeploy_model.assert_not_called() + + def test_create_model(self): + table_spec_mock = mock.Mock() + # name is reserved in use of __init__, needs to be passed here + table_spec_mock.configure_mock(name='table') + column_spec_mock = mock.Mock() + column_spec_mock.configure_mock(name='column/2', display_name='column') + client = self.tables_client({ + 'list_table_specs.return_value': [table_spec_mock], + 'list_column_specs.return_value': [column_spec_mock], + 'location_path.return_value': LOCATION_PATH, + }, {}) + client.create_model('my_model', dataset_name='my_dataset', + train_budget_milli_node_hours=1000) + client.client.create_model.assert_called_with(LOCATION_PATH, { + 'display_name': 'my_model', + 'dataset_id': 'my_dataset', + 'tables_model_metadata': { + 'train_budget_milli_node_hours': 1000, + }, + }) + + def test_create_model_include_columns(self): + table_spec_mock = mock.Mock() + # name is reserved in use of __init__, needs to be passed here + table_spec_mock.configure_mock(name='table') + column_spec_mock1 = mock.Mock() + column_spec_mock1.configure_mock(name='column/1', + display_name='column1') + column_spec_mock2 = mock.Mock() + column_spec_mock2.configure_mock(name='column/2', + display_name='column2') + client = self.tables_client({ + 'list_table_specs.return_value': [table_spec_mock], + 'list_column_specs.return_value': [column_spec_mock1, + column_spec_mock2], + 'location_path.return_value': LOCATION_PATH, + }, {}) + client.create_model('my_model', dataset_name='my_dataset', + include_column_spec_names=['column1'], + train_budget_milli_node_hours=1000) + client.client.create_model.assert_called_with(LOCATION_PATH, { + 'display_name': 'my_model', + 'dataset_id': 'my_dataset', + 'tables_model_metadata': { + 'train_budget_milli_node_hours': 1000, + 'input_feature_column_specs': [column_spec_mock1] + }, + }) + + def test_create_model_exclude_columns(self): + table_spec_mock = mock.Mock() + # name is reserved in use of __init__, needs to be passed here + table_spec_mock.configure_mock(name='table') + column_spec_mock1 = mock.Mock() + column_spec_mock1.configure_mock(name='column/1', + display_name='column1') + column_spec_mock2 = mock.Mock() + column_spec_mock2.configure_mock(name='column/2', + display_name='column2') + client = self.tables_client({ + 'list_table_specs.return_value': [table_spec_mock], + 'list_column_specs.return_value': [column_spec_mock1, + column_spec_mock2], + 'location_path.return_value': LOCATION_PATH, + }, {}) + client.create_model('my_model', dataset_name='my_dataset', + exclude_column_spec_names=['column1'], + train_budget_milli_node_hours=1000) + client.client.create_model.assert_called_with(LOCATION_PATH, { + 'display_name': 'my_model', + 'dataset_id': 'my_dataset', + 'tables_model_metadata': { + 'train_budget_milli_node_hours': 1000, + 'input_feature_column_specs': [column_spec_mock2] + }, + }) + + def 
test_create_model_invalid_hours_small(self): + client = self.tables_client({}, {}) + error = None + try: + client.create_model('my_model', dataset_name='my_dataset', + train_budget_milli_node_hours=1) + except ValueError as e: + error = e + assert error is not None + client.client.create_model.assert_not_called() + + def test_create_model_invalid_hours_large(self): + client = self.tables_client({}, {}) + error = None + try: + client.create_model('my_model', dataset_name='my_dataset', + train_budget_milli_node_hours=1000000) + except ValueError as e: + error = e + assert error is not None + client.client.create_model.assert_not_called() + + def test_create_model_invalid_no_dataset(self): + client = self.tables_client({}, {}) + error = None + try: + client.create_model('my_model', + train_budget_milli_node_hours=1000) + except ValueError as e: + error = e + assert error is not None + client.client.get_dataset.assert_not_called() + client.client.create_model.assert_not_called() + + def test_create_model_invalid_include_exclude(self): + client = self.tables_client({}, {}) + error = None + try: + client.create_model('my_model', dataset_name='my_dataset', + include_column_spec_names=['a'], + exclude_column_spec_names=['b'], + train_budget_milli_node_hours=1000) + except ValueError as e: + error = e + assert error is not None + client.client.get_dataset.assert_not_called() + client.client.create_model.assert_not_called() + + def test_predict_from_array(self): + data_type = mock.Mock(type_code=data_types_pb2.CATEGORY) + column_spec = mock.Mock(display_name='a', data_type=data_type) + model_metadata = mock.Mock(input_feature_column_specs=[column_spec]) + model = mock.Mock() + model.configure_mock(tables_model_metadata=model_metadata, + name='my_model') + client = self.tables_client({ + 'get_model.return_value': model + }, {}) + client.predict(['1'], model_name='my_model') + client.prediction_client.predict.assert_called_with('my_model', { + 'row': { + 'values': [{'string_value': '1'}] + } + }) + + def test_predict_from_dict(self): + data_type = mock.Mock(type_code=data_types_pb2.CATEGORY) + column_spec_a = mock.Mock(display_name='a', data_type=data_type) + column_spec_b = mock.Mock(display_name='b', data_type=data_type) + model_metadata = mock.Mock(input_feature_column_specs=[ + column_spec_a, + column_spec_b, + ]) + model = mock.Mock() + model.configure_mock(tables_model_metadata=model_metadata, + name='my_model') + client = self.tables_client({ + 'get_model.return_value': model + }, {}) + client.predict({'a': '1', 'b': '2'}, model_name='my_model') + client.prediction_client.predict.assert_called_with('my_model', { + 'row': { + 'values': [{'string_value': '1'}, {'string_value': '2'}] + } + }) + + def test_predict_from_dict_missing(self): + data_type = mock.Mock(type_code=data_types_pb2.CATEGORY) + column_spec_a = mock.Mock(display_name='a', data_type=data_type) + column_spec_b = mock.Mock(display_name='b', data_type=data_type) + model_metadata = mock.Mock(input_feature_column_specs=[ + column_spec_a, + column_spec_b, + ]) + model = mock.Mock() + model.configure_mock(tables_model_metadata=model_metadata, + name='my_model') + client = self.tables_client({ + 'get_model.return_value': model + }, {}) + client.predict({'a': '1'}, model_name='my_model') + client.prediction_client.predict.assert_called_with('my_model', { + 'row': { + 'values': [{'string_value': '1'}, {'string_value': None}] + } + }) + + def test_predict_all_types(self): + float_type = mock.Mock(type_code=data_types_pb2.FLOAT64) + 
timestamp_type = mock.Mock(type_code=data_types_pb2.TIMESTAMP)
+        string_type = mock.Mock(type_code=data_types_pb2.STRING)
+        array_type = mock.Mock(type_code=data_types_pb2.ARRAY)
+        struct_type = mock.Mock(type_code=data_types_pb2.STRUCT)
+        category_type = mock.Mock(type_code=data_types_pb2.CATEGORY)
+        column_spec_float = mock.Mock(display_name='float',
+                data_type=float_type)
+        column_spec_timestamp = mock.Mock(display_name='timestamp',
+                data_type=timestamp_type)
+        column_spec_string = mock.Mock(display_name='string',
+                data_type=string_type)
+        column_spec_array = mock.Mock(display_name='array',
+                data_type=array_type)
+        column_spec_struct = mock.Mock(display_name='struct',
+                data_type=struct_type)
+        column_spec_category = mock.Mock(display_name='category',
+                data_type=category_type)
+        model_metadata = mock.Mock(input_feature_column_specs=[
+            column_spec_float,
+            column_spec_timestamp,
+            column_spec_string,
+            column_spec_array,
+            column_spec_struct,
+            column_spec_category,
+        ])
+        model = mock.Mock()
+        model.configure_mock(tables_model_metadata=model_metadata,
+                name='my_model')
+        client = self.tables_client({
+            'get_model.return_value': model
+        }, {})
+        client.predict({
+            'float': 1.0,
+            'timestamp': 'EST',
+            'string': 'text',
+            'array': [1],
+            'struct': {'a': 'b'},
+            'category': 'a',
+        } , model_name='my_model')
+        client.prediction_client.predict.assert_called_with('my_model', {
+            'row': {
+                'values': [
+                    {'number_value': 1.0},
+                    {'string_value': 'EST'},
+                    {'string_value': 'text'},
+                    {'list_value': [1]},
+                    {'struct_value': {'a': 'b'}},
+                    {'string_value': 'a'},
+                ],
+            }
+        })
+
+    def test_predict_from_array_missing(self):
+        data_type = mock.Mock(type_code=data_types_pb2.CATEGORY)
+        column_spec = mock.Mock(display_name='a', data_type=data_type)
+        model_metadata = mock.Mock(input_feature_column_specs=[column_spec])
+        model = mock.Mock()
+        model.configure_mock(tables_model_metadata=model_metadata,
+                name='my_model')
+        client = self.tables_client({
+            'get_model.return_value': model
+        }, {})
+        error = None
+        try:
+            client.predict([], model_name='my_model')
+        except ValueError as e:
+            error = e
+        assert error is not None
+        client.prediction_client.predict.assert_not_called()
+
+    def test_batch_predict(self):
+        client = self.tables_client({}, {})
+        client.batch_predict(model_name='my_model',
+                gcs_input_uris='gs://input',
+                gcs_output_uri_prefix='gs://output')
+        client.prediction_client.batch_predict.assert_called_with('my_model',
+            { 'gcs_source': {
+                'input_uris': ['gs://input'],
+            }}, { 'gcs_source': {
+                'output_uri_prefix': 'gs://output',
+            }},
+        )
+
+    def test_batch_predict_missing_input_gcs_uri(self):
+        client = self.tables_client({}, {})
+        error = None
+        try:
+            client.batch_predict(model_name='my_model',
+                    gcs_input_uris=None,
+                    gcs_output_uri_prefix='gs://output')
+        except ValueError as e:
+            error = e
+        assert error is not None
+        client.prediction_client.batch_predict.assert_not_called()
+
+    def test_batch_predict_missing_output_gcs_uri(self):
+        client = self.tables_client({}, {})
+        error = None
+        try:
+            client.batch_predict(model_name='my_model',
+                    gcs_input_uris='gs://input',
+                    gcs_output_uri_prefix=None)
+        except ValueError as e:
+            error = e
+        assert error is not None
+        client.prediction_client.batch_predict.assert_not_called()
+
+    def test_batch_predict_missing_model(self):
+        client = self.tables_client({
+            'list_models.return_value': [],
+        }, {})
+        error = None
+        try:
+            client.batch_predict(model_display_name='my_model',
+                    gcs_input_uris='gs://input',
+                    
gcs_output_uri_prefix='gs://output')
+        except exceptions.NotFound as e:
+            error = e
+        assert error is not None
+        client.prediction_client.batch_predict.assert_not_called()

From e19f141b54f146c04b1b760b44b6d2dd4543513e Mon Sep 17 00:00:00 2001
From: jonathan1920
Date: Fri, 19 Jul 2019 15:21:01 -0500
Subject: [PATCH 04/11] added two new func: set time, get table address (#23)

* added two new func: set time, get table address

* changed indentation
---
 .../automl_v1beta1/tables/tables_client.py    | 116 ++++++++++++++++++
 1 file changed, 116 insertions(+)

diff --git a/automl/google/cloud/automl_v1beta1/tables/tables_client.py b/automl/google/cloud/automl_v1beta1/tables/tables_client.py
index 617cbd62d51a..970fb720f76c 100644
--- a/automl/google/cloud/automl_v1beta1/tables/tables_client.py
+++ b/automl/google/cloud/automl_v1beta1/tables/tables_client.py
@@ -201,6 +201,21 @@ def __dataset_name_from_args(self, dataset=None, dataset_display_name=None,
             region=region
         )
         return dataset_name
+
+    def __table_spec_full_name_from_args(self, table_spec_index=0,
+            dataset=None, dataset_display_name=None,
+            dataset_name=None, project=None, region=None):
+
+        dataset_name = self.__dataset_name_from_args(dataset=dataset,
+                dataset_name=dataset_name,
+                dataset_display_name=dataset_display_name,
+                project=project,
+                region=region
+        )
+        table_specs = [t for t in self.list_table_specs(dataset_name=dataset_name)]
+        table_spec_full_id = table_specs[table_spec_index]
+        table_spec_full_id = table_specs[table_spec_index].name
+        return table_spec_full_id
 
     def __model_name_from_args(self, model=None, model_display_name=None,
             model_name=None, project=None, region=None):
@@ -950,6 +965,107 @@ def set_target_column(self, dataset=None, dataset_display_name=None,
 
         return self.client.update_dataset(request)
 
+    def set_time_column(self, dataset=None, dataset_display_name=None,
+            dataset_name=None, table_spec_name=None, table_spec_index=0,
+            column_spec_name=None, column_spec_display_name=None, 
+            project=None, region=None):
+        """Sets the time column which designates which data
+        will be of type timestamp and will be used for the timeseries data.
+        . This column must be of type timestamp.
+        Example:
+            >>> from google.cloud import automl_v1beta1
+            >>>
+            >>> client = automl_v1beta1.tables.ClientHelper(
+            ...     client=automl_v1beta1.AutoMlClient(),
+            ...     project='my-project', region='us-central1')
+            ...
+            >>> client.set_time_column(dataset_display_name='my_dataset',
+            ...     column_spec__name='Unix Time')
+            ...
+        Args:
+            project (Optional[string]):
+                If you have initialized the client with a value for `project`
+                it will be used if this parameter is not supplied. Keep in
+                mind, the service account this client was initialized with must
+                have access to this project.
+            region (Optional[string]):
+                If you have initialized the client with a value for `region` it
+                will be used if this parameter is not supplied.
+            column_spec_name (Optional[string]):
+                The AutoML-assigned name for the column you want to set as
+                the time column.
+            column_spec_display_name (Optional[string]):
+                The human-readable name of the column you want to set as the
+                time column. If this is supplied in place of
+                `column_spec_name`, you also need to provide either a way to
+                lookup the source dataset (using one of the `dataset*` kwargs),
+                or the `table_spec_name` of the table this column belongs to.
+            table_spec_name (Optional[string]):
+                The AutoML-assigned name for the table whose time column
+                you want to set. 
If not supplied, the client can determine + this name from a source `Dataset` object. + table_spec_index (Optional[int]): + If no `table_spec_name` or `column_spec_name` was provided, we + use this index to determine which table to set the time + column on. + dataset_display_name (Optional[string]): + The human-readable name given to the dataset you want to update + the time column of. If no `table_spec_name` is supplied, + this will be used together with `table_spec_index` to infer the + name of table to update the time column of. This must be + supplied if `table_spec_name`, `dataset` or `dataset_name` are + not supplied. + dataset_name (Optional[string]): + The AutoML-assigned name given to the dataset you want to + update the time column of. If no `table_spec_name` is + supplied, this will be used together with `table_spec_index` to + infer the name of table to update the time column of. + This must be supplied if `table_spec_name`, `dataset` or + `dataset_display_name` are not supplied. + dataset (Optional[Dataset]): + The `Dataset` instance you want to update the time column + of. If no `table_spec_name` is supplied, this will be used + together with `table_spec_index` to infer the name of table to + update the time column of. This must be supplied if + `table_spec_name`, `dataset_name` or `dataset_display_name` are + not supplied. + Returns: + A :class:`~google.cloud.automl_v1beta1.types.Dataset` instance. + Raises: + google.api_core.exceptions.GoogleAPICallError: If the request + failed for any reason. + google.api_core.exceptions.RetryError: If the request failed due + to a retryable error and retry attempts failed. + ValueError: If required parameters are missing. + """ + column_spec_name = self.__column_spec_name_from_args( + dataset=dataset, + dataset_display_name=dataset_display_name, + dataset_name=dataset_name, + table_spec_name=table_spec_name, + table_spec_index=table_spec_index, + column_spec_name=column_spec_name, + column_spec_display_name=column_spec_display_name, + project=project, + region=region + ) + column_spec_id = column_spec_name.rsplit('/', 1)[-1] + + dataset_name = self.__dataset_name_from_args(dataset=dataset, + dataset_name=dataset_name, + dataset_display_name=dataset_display_name, + project=project, + region=region) + + table_spec_full_id = self.__table_spec_full_name_from_args(dataset_name=dataset_name) + + my_table_spec = { + 'name': table_spec_full_id, + 'time_column_spec_id': column_spec_id + } + + response = self.client.update_table_spec(my_table_spec) + def set_weight_column(self, dataset=None, dataset_display_name=None, dataset_name=None, table_spec_name=None, table_spec_index=0, column_spec_name=None, column_spec_display_name=None, From a741c3f6617c83f98ebbc4ac119a26fa0dbd882d Mon Sep 17 00:00:00 2001 From: Lars Wander Date: Fri, 19 Jul 2019 14:18:06 -0400 Subject: [PATCH 05/11] Add system tests --- automl/README.rst | 15 +- .../automl_v1beta1/tables/tables_client.py | 285 ++++++++++++++++-- .../v1beta1/test_system_tables_client_v1.py | 223 ++++++++++++++ .../v1beta1/test_tables_client_v1beta1.py | 109 ++++++- 4 files changed, 600 insertions(+), 32 deletions(-) create mode 100644 automl/tests/system/gapic/v1beta1/test_system_tables_client_v1.py diff --git a/automl/README.rst b/automl/README.rst index bfcb555689e4..994cf397ef55 100644 --- a/automl/README.rst +++ b/automl/README.rst @@ -119,6 +119,17 @@ development environment: pip install -r ../docs/requirements.txt pip install -U mock pytest -3. 
To test any changes you've made, run `pytest` in this directory.
-
+3. If you want to run all tests, you will need a billing-enabled
+   `GCP project`_, and a `service account`_ with access to the AutoML APIs.
+   Note: the first time the tests run in a new project it will take a *long*
+   time, on the order of 2-3 hours. This is a one-time setup that will be
+   skipped in future runs.
+
+.. _service account: https://cloud.google.com/iam/docs/creating-managing-service-accounts
+.. _GCP project: https://cloud.google.com/resource-manager/docs/creating-managing-projects
+
+.. code-block:: console
+
+    export PROJECT_ID= GOOGLE_APPLICATION_CREDENTIALS=
+    pytest
diff --git a/automl/google/cloud/automl_v1beta1/tables/tables_client.py b/automl/google/cloud/automl_v1beta1/tables/tables_client.py
index 970fb720f76c..01e9f07d4ed5 100644
--- a/automl/google/cloud/automl_v1beta1/tables/tables_client.py
+++ b/automl/google/cloud/automl_v1beta1/tables/tables_client.py
@@ -201,19 +201,21 @@ def __dataset_name_from_args(self, dataset=None, dataset_display_name=None,
             region=region
         )
         return dataset_name
-
-    def __table_spec_full_name_from_args(self, table_spec_index=0,
+
+    def __table_spec_name_from_args(self, table_spec_index=0,
             dataset=None, dataset_display_name=None,
             dataset_name=None, project=None, region=None):
-
         dataset_name = self.__dataset_name_from_args(dataset=dataset,
                 dataset_name=dataset_name,
                 dataset_display_name=dataset_display_name,
                 project=project,
                 region=region
         )
-        table_specs = [t for t in self.list_table_specs(dataset_name=dataset_name)]
-        table_spec_full_id = table_specs[table_spec_index]
+
+        table_specs = [t for t in
+                self.list_table_specs(dataset_name=dataset_name)
+        ]
+
         table_spec_full_id = table_specs[table_spec_index].name
         return table_spec_full_id
 
@@ -274,19 +276,21 @@ def __column_spec_name_from_args(self, dataset=None, dataset_display_name=None,
 
         return column_spec_name
 
-    def __type_code_to_value_type(self, type_code):
-        if type_code == data_types_pb2.FLOAT64:
-            return 'number_value'
+    def __type_code_to_value_type(self, type_code, value):
+        if value is None:
+            return {'null_value': 0}
+        elif type_code == data_types_pb2.FLOAT64:
+            return {'number_value': value}
         elif type_code == data_types_pb2.TIMESTAMP:
-            return 'string_value'
+            return {'string_value': value}
         elif type_code == data_types_pb2.STRING:
-            return 'string_value'
+            return {'string_value': value}
         elif type_code == data_types_pb2.ARRAY:
-            return 'list_value'
+            return {'list_value': value}
         elif type_code == data_types_pb2.STRUCT:
-            return 'struct_value'
+            return {'struct_value': value}
         elif type_code == data_types_pb2.CATEGORY:
-            return 'string_value'
+            return {'string_value': value}
         else:
             raise ValueError('Unknown type_code: {}'.format(type_code))
 
@@ -967,11 +971,12 @@ def set_target_column(self, dataset=None, dataset_display_name=None,
 
     def set_time_column(self, dataset=None, dataset_display_name=None,
             dataset_name=None, table_spec_name=None, table_spec_index=0,
-            column_spec_name=None, column_spec_display_name=None, 
+            column_spec_name=None, column_spec_display_name=None,
             project=None, region=None):
-        """Sets the time column which designates which data
-        will be of type timestamp and will be used for the timeseries data.
-        . This column must be of type timestamp.
+        """Sets the time column which designates which data will be of type
+        timestamp and will be used for the timeseries data.
+        This column must be of type timestamp. 
+
         Example:
             >>> from google.cloud import automl_v1beta1
             >>>
@@ -980,8 +985,9 @@ def set_time_column(self, dataset=None, dataset_display_name=None,
             ...     project='my-project', region='us-central1')
             ...
             >>> client.set_time_column(dataset_display_name='my_dataset',
-            ...     column_spec__name='Unix Time')
+            ...     column_spec_name='Unix Time')
             ...
+
         Args:
             project (Optional[string]):
@@ -1056,15 +1062,90 @@ def set_time_column(self, dataset=None, dataset_display_name=None,
             dataset_display_name=dataset_display_name,
             project=project,
             region=region)
-
-        table_spec_full_id = self.__table_spec_full_name_from_args(dataset_name=dataset_name)
-
+
+        table_spec_full_id = self.__table_spec_name_from_args(
+            dataset_name=dataset_name)
+
         my_table_spec = {
             'name': table_spec_full_id,
             'time_column_spec_id': column_spec_id
         }
 
-        response = self.client.update_table_spec(my_table_spec)
+        self.client.update_table_spec(my_table_spec)
+        return self.get_dataset(dataset_name=dataset_name)
+
+    def clear_time_column(self, dataset=None, dataset_display_name=None,
+            dataset_name=None, project=None, region=None):
+        """Clears the time column which designates which data will be of type
+        timestamp and will be used for the timeseries data.
+
+        Example:
+            >>> from google.cloud import automl_v1beta1
+            >>>
+            >>> client = automl_v1beta1.tables.ClientHelper(
+            ...     client=automl_v1beta1.AutoMlClient(),
+            ...     project='my-project', region='us-central1')
+            ...
+            >>> client.clear_time_column(dataset_display_name='my_dataset')
+            >>>
+
+        Args:
+            project (Optional[string]):
+                If you have initialized the client with a value for `project`
+                it will be used if this parameter is not supplied. Keep in
+                mind, the service account this client was initialized with must
+                have access to this project.
+            region (Optional[string]):
+                If you have initialized the client with a value for `region` it
+                will be used if this parameter is not supplied.
+            dataset_display_name (Optional[string]):
+                The human-readable name given to the dataset you want to update
+                the time column of. If no `table_spec_name` is supplied,
+                this will be used together with `table_spec_index` to infer the
+                name of table to update the time column of. This must be
+                supplied if `table_spec_name`, `dataset` or `dataset_name` are
+                not supplied.
+            dataset_name (Optional[string]):
+                The AutoML-assigned name given to the dataset you want to
+                update the time column of. If no `table_spec_name` is
+                supplied, this will be used together with `table_spec_index` to
+                infer the name of table to update the time column of.
+                This must be supplied if `table_spec_name`, `dataset` or
+                `dataset_display_name` are not supplied.
+            dataset (Optional[Dataset]):
+                The `Dataset` instance you want to update the time column
+                of. If no `table_spec_name` is supplied, this will be used
+                together with `table_spec_index` to infer the name of table to
+                update the time column of. This must be supplied if
+                `table_spec_name`, `dataset_name` or `dataset_display_name` are
+                not supplied.
+
+        Returns:
+            A :class:`~google.cloud.automl_v1beta1.types.Dataset` instance.
+
+        Raises:
+            google.api_core.exceptions.GoogleAPICallError: If the request
+                failed for any reason.
+            google.api_core.exceptions.RetryError: If the request failed due
+                to a retryable error and retry attempts failed.
+            ValueError: If required parameters are missing. 
+        """
+        dataset_name = self.__dataset_name_from_args(dataset=dataset,
+                dataset_name=dataset_name,
+                dataset_display_name=dataset_display_name,
+                project=project,
+                region=region)
+
+        table_spec_full_id = self.__table_spec_name_from_args(
+            dataset_name=dataset_name)
+
+        my_table_spec = {
+            'name': table_spec_full_id,
+            'time_column_spec_id': None,
+        }
+
+        self.client.update_table_spec(my_table_spec)
+        return self.get_dataset(dataset_name=dataset_name)
 
     def set_weight_column(self, dataset=None, dataset_display_name=None,
             dataset_name=None, table_spec_name=None, table_spec_index=0,
             column_spec_name=None, column_spec_display_name=None,
@@ -1173,6 +1254,79 @@ def set_weight_column(self, dataset=None, dataset_display_name=None,
 
         return self.client.update_dataset(request)
 
+    def clear_weight_column(self, dataset=None, dataset_display_name=None,
+            dataset_name=None, project=None, region=None):
+        """Clears the weight column for a given dataset.
+
+        Example:
+            >>> from google.cloud import automl_v1beta1
+            >>>
+            >>> from google.oauth2 import service_account
+            >>>
+            >>> client = automl_v1beta1.TablesClient(
+            ...     credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json'),
+            ...     project='my-project', region='us-central1')
+            ...
+            >>> client.clear_weight_column(dataset_display_name='my_dataset')
+            >>>
+
+        Args:
+            project (Optional[string]):
+                If you have initialized the client with a value for `project`
+                it will be used if this parameter is not supplied. Keep in
+                mind, the service account this client was initialized with must
+                have access to this project.
+            region (Optional[string]):
+                If you have initialized the client with a value for `region` it
+                will be used if this parameter is not supplied.
+            dataset_display_name (Optional[string]):
+                The human-readable name given to the dataset you want to update
+                the weight column of. If no `table_spec_name` is supplied, this
+                will be used together with `table_spec_index` to infer the name
+                of table to update the weight column of. This must be supplied
+                if `table_spec_name`, `dataset` or `dataset_name` are not
+                supplied.
+            dataset_name (Optional[string]):
+                The AutoML-assigned name given to the dataset you want to
+                update the weight column of. If no `table_spec_name` is
+                supplied, this will be used together with `table_spec_index` to
+                infer the name of table to update the weight column of. This
+                must be supplied if `table_spec_name`, `dataset` or
+                `dataset_display_name` are not supplied.
+            dataset (Optional[Dataset]):
+                The `Dataset` instance you want to update the weight column of.
+                If no `table_spec_name` is supplied, this will be used together
+                with `table_spec_index` to infer the name of table to update
+                the weight column of. This must be supplied if
+                `table_spec_name`, `dataset_name` or `dataset_display_name` are
+                not supplied.
+
+        Returns:
+            A :class:`~google.cloud.automl_v1beta1.types.Dataset` instance.
+
+        Raises:
+            google.api_core.exceptions.GoogleAPICallError: If the request
+                failed for any reason.
+            google.api_core.exceptions.RetryError: If the request failed due
+                to a retryable error and retry attempts failed.
+            ValueError: If required parameters are missing. 
+        """
+        dataset = self.__dataset_from_args(dataset=dataset,
+                dataset_name=dataset_name,
+                dataset_display_name=dataset_display_name,
+                project=project,
+                region=region)
+        metadata = dataset.tables_dataset_metadata
+        metadata = self.__update_metadata(metadata, 'weight_column_spec_id',
+                None)
+
+        request = {
+            'name': dataset.name,
+            'tables_dataset_metadata': metadata,
+        }
+
+        return self.client.update_dataset(request)
+
     def set_test_train_column(self, dataset=None, dataset_display_name=None,
             dataset_name=None, table_spec_name=None, table_spec_index=0,
             column_spec_name=None, column_spec_display_name=None,
@@ -1279,6 +1433,80 @@ def set_test_train_column(self, dataset=None, dataset_display_name=None,
 
         return self.client.update_dataset(request)
 
+    def clear_test_train_column(self, dataset=None, dataset_display_name=None,
+            dataset_name=None, project=None, region=None):
+        """Clears the test/train (ml_use) column which designates which data
+        belongs to the test and train sets.
+
+        Example:
+            >>> from google.cloud import automl_v1beta1
+            >>>
+            >>> from google.oauth2 import service_account
+            >>>
+            >>> client = automl_v1beta1.TablesClient(
+            ...     credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json'),
+            ...     project='my-project', region='us-central1')
+            ...
+            >>> client.clear_test_train_column(dataset_display_name='my_dataset')
+            >>>
+
+        Args:
+            project (Optional[string]):
+                If you have initialized the client with a value for `project`
+                it will be used if this parameter is not supplied. Keep in
+                mind, the service account this client was initialized with must
+                have access to this project.
+            region (Optional[string]):
+                If you have initialized the client with a value for `region` it
+                will be used if this parameter is not supplied.
+            dataset_display_name (Optional[string]):
+                The human-readable name given to the dataset you want to update
+                the test/train column of. If no `table_spec_name` is supplied,
+                this will be used together with `table_spec_index` to infer the
+                name of table to update the test/train column of. This must be
+                supplied if `table_spec_name`, `dataset` or `dataset_name` are
+                not supplied.
+            dataset_name (Optional[string]):
+                The AutoML-assigned name given to the dataset you want to
+                update the test/train column of. If no `table_spec_name` is
+                supplied, this will be used together with `table_spec_index` to
+                infer the name of table to update the test/train column of.
+                This must be supplied if `table_spec_name`, `dataset` or
+                `dataset_display_name` are not supplied.
+            dataset (Optional[Dataset]):
+                The `Dataset` instance you want to update the test/train column
+                of. If no `table_spec_name` is supplied, this will be used
+                together with `table_spec_index` to infer the name of table to
+                update the test/train column of. This must be supplied if
+                `table_spec_name`, `dataset_name` or `dataset_display_name` are
+                not supplied.
+
+        Returns:
+            A :class:`~google.cloud.automl_v1beta1.types.Dataset` instance.
+
+        Raises:
+            google.api_core.exceptions.GoogleAPICallError: If the request
+                failed for any reason.
+            google.api_core.exceptions.RetryError: If the request failed due
+                to a retryable error and retry attempts failed.
+            ValueError: If required parameters are missing. 
+ """ + dataset = self.__dataset_from_args(dataset=dataset, + dataset_name=dataset_name, + dataset_display_name=dataset_display_name, + project=project, + region=region) + metadata = dataset.tables_dataset_metadata + metadata = self.__update_metadata(metadata, 'ml_use_column_spec_id', + None) + + request = { + 'name': dataset.name, + 'tables_dataset_metadata': metadata, + } + + return self.client.update_dataset(request) + def list_models(self, project=None, region=None): """List all models in a particular project and region. @@ -1756,8 +1984,6 @@ def predict(self, inputs, model=None, model_name=None, project=project, region=region ) - print(model) - print(model.tables_model_metadata) column_specs = model.tables_model_metadata.input_feature_column_specs if type(inputs) == dict: @@ -1766,19 +1992,20 @@ def predict(self, inputs, model=None, model_name=None, if len(inputs) != len(column_specs): raise ValueError(('Dimension mismatch, the number of provided ' 'inputs ({}) does not match that of the model ' - '({})').format( - len(inputs), len(column_specs))) + '({})').format(len(inputs), len(column_specs))) values = [] for i, c in zip(inputs, column_specs): - value_type = self.__type_code_to_value_type(c.data_type.type_code) - values.append({value_type: i}) + value_type = self.__type_code_to_value_type( + c.data_type.type_code, i + ) + values.append(value_type) request = { 'row': { 'values': values } - } + } return self.prediction_client.predict(model.name, request) diff --git a/automl/tests/system/gapic/v1beta1/test_system_tables_client_v1.py b/automl/tests/system/gapic/v1beta1/test_system_tables_client_v1.py new file mode 100644 index 000000000000..6c2a30386b2a --- /dev/null +++ b/automl/tests/system/gapic/v1beta1/test_system_tables_client_v1.py @@ -0,0 +1,223 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
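+
+# A rough sketch of how these system tests are meant to be invoked, following
+# the README change in this patch; the project id and key path below are
+# placeholders, not values assumed by the tests themselves:
+#
+#     export PROJECT_ID=my-project
+#     export GOOGLE_APPLICATION_CREDENTIALS=~/.gcp/account.json
+#     pytest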
+ +import os +import pytest +import random +import string +import time + +from google.cloud import automl_v1beta1 +from google.api_core import exceptions +from google.cloud.automl_v1beta1.gapic import enums + +PROJECT = os.environ['PROJECT_ID'] +REGION = 'us-central1' +MAX_WAIT_TIME_SECONDS = 30 +MAX_SLEEP_TIME_SECONDS = 5 +STATIC_DATASET='test_dataset_do_not_delete' +#STATIC_MODEL='test_model_do_not_delete' +STATIC_MODEL='test_online_model_do_not_delete' + +ID = '{rand}_{time}'.format( + rand=''.join([random.choice(string.ascii_letters + string.digits) + for n in range(4)]), + time=int(time.time()) +) + +def _id(name): + return '{}_{}'.format(name, ID) + +class TestSystemTablesClient(object): + def cancel_and_wait(self, op): + op.cancel() + start = time.time() + sleep_time = 1 + while time.time() - start < MAX_WAIT_TIME_SECONDS: + if op.cancelled(): + return + time.sleep(sleep_time) + sleep_time = min(sleep_time * 2, MAX_SLEEP_TIME_SECONDS) + assert op.cancelled() + + def test_list_datasets(self): + client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION) + # we need to unroll the iterator to actually make client calls + [d for d in client.list_datasets()] + + def test_list_models(self): + client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION) + # we need to unroll the iterator to actually make client calls + [m for m in client.list_models()] + + def test_create_delete_dataset(self): + client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION) + display_name = _id('t_cr_dl') + dataset = client.create_dataset(display_name) + assert dataset is not None + assert dataset.name == client.get_dataset( + dataset_display_name=display_name + ).name + client.delete_dataset(dataset=dataset) + + def test_import_data(self): + client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION) + display_name = _id('t_import') + dataset = client.create_dataset(display_name) + op = client.import_data(dataset=dataset, + gcs_input_uris='gs://cloud-ml-tables-data/bank-marketing.csv') + self.cancel_and_wait(op) + client.delete_dataset(dataset=dataset) + + def ensure_dataset_ready(self, client): + dataset = None + try: + dataset = client.get_dataset(dataset_display_name=STATIC_DATASET) + except exceptions.NotFound: + dataset = client.create_dataset(STATIC_DATASET) + + if dataset.example_count is None or dataset.example_count == 0: + op = client.import_data(dataset=dataset, + gcs_input_uris='gs://cloud-ml-tables-data/bank-marketing.csv') + op.result() + dataset = client.get_dataset(dataset_name=dataset.name) + + return dataset + + def test_list_column_specs(self): + client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION) + dataset = self.ensure_dataset_ready(client) + # we need to unroll the iterator to actually make client calls + [d for d in client.list_column_specs(dataset=dataset)] + + def test_list_table_specs(self): + client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION) + dataset = self.ensure_dataset_ready(client) + # we need to unroll the iterator to actually make client calls + [d for d in client.list_table_specs(dataset=dataset)] + + def test_set_column_nullable(self): + client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION) + dataset = self.ensure_dataset_ready(client) + client.update_column_spec(dataset=dataset, + column_spec_display_name='POutcome', nullable=True) + columns = {c.display_name: c + for c in client.list_column_specs(dataset=dataset)} + assert columns['POutcome'].data_type.nullable == True + + def 
test_set_target_column(self): + client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION) + dataset = self.ensure_dataset_ready(client) + client.set_target_column(dataset=dataset, + column_spec_display_name='Age') + columns = {c.display_name: c + for c in client.list_column_specs(dataset=dataset)} + dataset = client.get_dataset(dataset_name=dataset.name) + metadata = dataset.tables_dataset_metadata + assert columns['Age'].name.endswith( + '/{}'.format(metadata.target_column_spec_id) + ) + + def test_set_weight_column(self): + client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION) + dataset = self.ensure_dataset_ready(client) + client.set_weight_column(dataset=dataset, + column_spec_display_name='Duration') + columns = {c.display_name: c + for c in client.list_column_specs(dataset=dataset)} + dataset = client.get_dataset(dataset_name=dataset.name) + metadata = dataset.tables_dataset_metadata + assert columns['Duration'].name.endswith( + '/{}'.format(metadata.weight_column_spec_id) + ) + + def test_set_weight_and_target_column(self): + client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION) + dataset = self.ensure_dataset_ready(client) + client.set_weight_column(dataset=dataset, + column_spec_display_name='Day') + client.set_target_column(dataset=dataset, + column_spec_display_name='Campaign') + columns = {c.display_name: c + for c in client.list_column_specs(dataset=dataset)} + dataset = client.get_dataset(dataset_name=dataset.name) + metadata = dataset.tables_dataset_metadata + assert columns['Day'].name.endswith( + '/{}'.format(metadata.weight_column_spec_id) + ) + assert columns['Campaign'].name.endswith( + '/{}'.format(metadata.target_column_spec_id) + ) + + def test_create_delete_model(self): + client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION) + dataset = self.ensure_dataset_ready(client) + client.set_target_column(dataset=dataset, + column_spec_display_name='Deposit') + display_name = _id('t_cr_dl') + op = client.create_model(display_name, dataset=dataset, + train_budget_milli_node_hours=1000) + self.cancel_and_wait(op) + client.delete_model(model_display_name=display_name) + + def test_online_predict(self): + client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION) + model = self.ensure_model_online(client) + result = client.predict(inputs={ + 'Age': 31, + 'Balance': 200, + 'Campaign': 2, + 'Contact': 'cellular', + 'Day': 4, + 'Default': 'no', + 'Duration': 12, + 'Education': 'primary', + 'Housing': 'yes', + 'Job': 'blue-collar', + 'Loan': 'no', + 'MaritalStatus': 'divorced', + 'Month': 'jul', + 'PDays': 4, + 'POutcome': None, + 'Previous': 12 + }, model=model) + assert result is not None + + def ensure_model_online(self, client): + model = self.ensure_model_ready(client) + if model.deployment_state != enums.Model.DeploymentState.DEPLOYED: + client.deploy_model(model=model).result() + + return client.get_model(model_name=model.name) + + def ensure_model_ready(self, client): + try: + return client.get_model(model_display_name=STATIC_MODEL) + except exceptions.NotFound: + pass + + dataset = self.ensure_dataset_ready(client) + client.set_target_column(dataset=dataset, + column_spec_display_name='Deposit') + client.clear_weight_column(dataset=dataset) + client.clear_test_train_column(dataset=dataset) + client.update_column_spec(dataset=dataset, + column_spec_display_name='POutcome', nullable=True) + op = client.create_model(STATIC_MODEL, dataset=dataset, + train_budget_milli_node_hours=1000) + op.result() + return 
client.get_model(model_display_name=STATIC_MODEL) + diff --git a/automl/tests/unit/gapic/v1beta1/test_tables_client_v1beta1.py b/automl/tests/unit/gapic/v1beta1/test_tables_client_v1beta1.py index 5e07184bc906..1784d5274477 100644 --- a/automl/tests/unit/gapic/v1beta1/test_tables_client_v1beta1.py +++ b/automl/tests/unit/gapic/v1beta1/test_tables_client_v1beta1.py @@ -550,6 +550,27 @@ def test_set_weight_column(self): } }) + def test_clear_weight_column(self): + dataset_mock = mock.Mock() + tables_dataset_metadata_mock = mock.Mock() + tables_dataset_metadata_mock.configure_mock(target_column_spec_id='1', + weight_column_spec_id='2', + ml_use_column_spec_id='3') + dataset_mock.configure_mock(name='dataset', + tables_dataset_metadata=tables_dataset_metadata_mock) + client = self.tables_client({ + 'get_dataset.return_value': dataset_mock, + }, {}) + client.clear_weight_column(dataset_name='name') + client.client.update_dataset.assert_called_with({ + 'name': 'dataset', + 'tables_dataset_metadata': { + 'target_column_spec_id': '1', + 'weight_column_spec_id': None, + 'ml_use_column_spec_id': '3', + } + }) + def test_set_test_train_column_table_not_found(self): client = self.tables_client({ 'list_table_specs.side_effect': exceptions.NotFound('err'), @@ -617,6 +638,65 @@ def test_set_test_train_column(self): } }) + def test_clear_test_train_column(self): + dataset_mock = mock.Mock() + tables_dataset_metadata_mock = mock.Mock() + tables_dataset_metadata_mock.configure_mock(target_column_spec_id='1', + weight_column_spec_id='2', + ml_use_column_spec_id='2') + dataset_mock.configure_mock(name='dataset', + tables_dataset_metadata=tables_dataset_metadata_mock) + client = self.tables_client({ + 'get_dataset.return_value': dataset_mock, + }, {}) + client.clear_test_train_column(dataset_name='name') + client.client.update_dataset.assert_called_with({ + 'name': 'dataset', + 'tables_dataset_metadata': { + 'target_column_spec_id': '1', + 'weight_column_spec_id': '2', + 'ml_use_column_spec_id': None, + } + }) + + def test_set_time_column(self): + table_spec_mock = mock.Mock() + # name is reserved in use of __init__, needs to be passed here + table_spec_mock.configure_mock(name='table') + column_spec_mock = mock.Mock() + column_spec_mock.configure_mock(name='column/3', display_name='column') + dataset_mock = mock.Mock() + dataset_mock.configure_mock(name='dataset') + client = self.tables_client({ + 'get_dataset.return_value': dataset_mock, + 'list_table_specs.return_value': [table_spec_mock], + 'list_column_specs.return_value': [column_spec_mock], + }, {}) + client.set_time_column(dataset_name='name', + column_spec_display_name='column') + client.client.list_table_specs.assert_called_with('name') + client.client.list_column_specs.assert_called_with('table') + client.client.update_table_spec.assert_called_with({ + 'name': 'table', + 'time_column_spec_id': '3', + }) + + def test_clear_time_column(self): + table_spec_mock = mock.Mock() + # name is reserved in use of __init__, needs to be passed here + table_spec_mock.configure_mock(name='table') + dataset_mock = mock.Mock() + dataset_mock.configure_mock(name='dataset') + client = self.tables_client({ + 'get_dataset.return_value': dataset_mock, + 'list_table_specs.return_value': [table_spec_mock], + }, {}) + client.clear_time_column(dataset_name='name') + client.client.update_table_spec.assert_called_with({ + 'name': 'table', + 'time_column_spec_id': None, + }) + def test_list_models_empty(self): client = self.tables_client({ 'list_models.return_value': [], @@ 
-713,6 +793,16 @@ def test_delete_model_name(self): client.delete_model(model_name='name') client.client.delete_model.assert_called_with('name') + def test_deploy_model_no_args(self): + client = self.tables_client({}, {}) + error = None + try: + client.deploy_model() + except ValueError as e: + error = e + assert error is not None + client.client.deploy_model.assert_not_called() + def test_deploy_model(self): client = self.tables_client({}, {}) client.deploy_model(model_name='name') @@ -927,7 +1017,7 @@ def test_predict_from_dict_missing(self): client.predict({'a': '1'}, model_name='my_model') client.prediction_client.predict.assert_called_with('my_model', { 'row': { - 'values': [{'string_value': '1'}, {'string_value': None}] + 'values': [{'string_value': '1'}, {'null_value': 0}] } }) @@ -950,6 +1040,8 @@ def test_predict_all_types(self): data_type=struct_type) column_spec_category = mock.Mock(display_name='category', data_type=category_type) + column_spec_null = mock.Mock(display_name='null', + data_type=category_type) model_metadata = mock.Mock(input_feature_column_specs=[ column_spec_float, column_spec_timestamp, @@ -957,6 +1049,7 @@ def test_predict_all_types(self): column_spec_array, column_spec_struct, column_spec_category, + column_spec_null, ]) model = mock.Mock() model.configure_mock(tables_model_metadata=model_metadata, @@ -971,6 +1064,7 @@ def test_predict_all_types(self): 'array': [1], 'struct': {'a': 'b'}, 'category': 'a', + 'null': None, } , model_name='my_model') client.prediction_client.predict.assert_called_with('my_model', { 'row': { @@ -981,6 +1075,7 @@ def test_predict_all_types(self): {'list_value': [1]}, {'struct_value': {'a': 'b'}}, {'string_value': 'a'}, + {'null_value': 0}, ], } }) @@ -1053,3 +1148,15 @@ def test_batch_predict_missing_model(self): error = e assert error is not None client.prediction_client.batch_predict.assert_not_called() + + def test_batch_predict_no_model(self): + client = self.tables_client({}, {}) + error = None + try: + client.batch_predict(gcs_input_uris='gs://input', + gcs_output_uri_prefix='gs://output') + except ValueError as e: + error = e + assert error is not None + client.client.list_models.assert_not_called() + client.prediction_client.batch_predict.assert_not_called() From b7d951ec1524c7336cda2ff02cddd1330c71ddc4 Mon Sep 17 00:00:00 2001 From: Lars Wander Date: Mon, 22 Jul 2019 09:16:35 -0400 Subject: [PATCH 06/11] Address linter & python2.7 import errors --- .../google/cloud/automl_v1beta1/__init__.py | 3 +- .../automl_v1beta1/tables/tables_client.py | 1059 +++++++----- .../v1beta1/test_system_tables_client_v1.py | 164 +- .../v1beta1/test_tables_client_v1beta1.py | 1439 +++++++++-------- 4 files changed, 1475 insertions(+), 1190 deletions(-) diff --git a/automl/google/cloud/automl_v1beta1/__init__.py b/automl/google/cloud/automl_v1beta1/__init__.py index ae08470889ef..20055ef19a50 100644 --- a/automl/google/cloud/automl_v1beta1/__init__.py +++ b/automl/google/cloud/automl_v1beta1/__init__.py @@ -38,5 +38,4 @@ class TablesClient(tables_client.TablesClient): __doc__ = tables_client.TablesClient.__doc__ -__all__ = ("enums", "types", "AutoMlClient", "PredictionServiceClient", - "TablesClient") +__all__ = ("enums", "types", "AutoMlClient", "PredictionServiceClient", "TablesClient") diff --git a/automl/google/cloud/automl_v1beta1/tables/tables_client.py b/automl/google/cloud/automl_v1beta1/tables/tables_client.py index 01e9f07d4ed5..e243a4533813 100644 --- a/automl/google/cloud/automl_v1beta1/tables/tables_client.py +++ 
b/automl/google/cloud/automl_v1beta1/tables/tables_client.py @@ -20,11 +20,12 @@ from google.api_core.gapic_v1 import client_info from google.api_core import exceptions -from google.cloud import automl_v1beta1 +from google.cloud.automl_v1beta1 import gapic from google.cloud.automl_v1beta1.proto import data_types_pb2 _GAPIC_LIBRARY_VERSION = pkg_resources.get_distribution("google-cloud-automl").version + class TablesClient(object): """ AutoML Tables API helper. @@ -33,8 +34,15 @@ class TablesClient(object): in particular for the `AutoML Tables product `_. """ - def __init__(self, project=None, region='us-central1', client=None, - prediction_client=None, **kwargs): + + def __init__( + self, + project=None, + region="us-central1", + client=None, + prediction_client=None, + **kwargs + ): """Constructor. Example: @@ -81,28 +89,27 @@ def __init__(self, project=None, region='us-central1', client=None, should be set through client_options. """ version = _GAPIC_LIBRARY_VERSION - user_agent = 'automl-tables-wrapper/{}'.format(version) + user_agent = "automl-tables-wrapper/{}".format(version) - client_info_ = kwargs.get('client_info') + client_info_ = kwargs.get("client_info") if client_info_ is None: client_info_ = client_info.ClientInfo( - user_agent=user_agent, - gapic_version=version + user_agent=user_agent, gapic_version=version ) else: client_info_.user_agent = user_agent client_info_.gapic_version = version if client is None: - self.client = automl_v1beta1.AutoMlClient(client_info=client_info_, - **kwargs) + self.client = gapic.auto_ml_client.AutoMlClient( + client_info=client_info_, **kwargs + ) else: self.client = client if prediction_client is None: - self.prediction_client = automl_v1beta1.PredictionServiceClient( - client_info=client_info_, - **kwargs + self.prediction_client = gapic.prediction_service_client.PredictionServiceClient( + client_info=client_info_, **kwargs ) else: self.prediction_client = prediction_client @@ -113,16 +120,20 @@ def __init__(self, project=None, region='us-central1', client=None, def __location_path(self, project=None, region=None): if project is None: if self.project is None: - raise ValueError('Either initialize your client with a value ' - 'for \'project\', or provide \'project\' as a ' - 'parameter for this method.') + raise ValueError( + "Either initialize your client with a value " + "for 'project', or provide 'project' as a " + "parameter for this method." + ) project = self.project if region is None: if self.region is None: - raise ValueError('Either initialize your client with a value ' - 'for \'region\', or provide \'region\' as a ' - 'parameter for this method.') + raise ValueError( + "Either initialize your client with a value " + "for 'region', or provide 'region' as a " + "parameter for this method." 
+ ) region = self.region return self.client.location_path(project, region) @@ -131,168 +142,201 @@ def __location_path(self, project=None, region=None): # we need to manually copy user-updated fields over def __update_metadata(self, metadata, k, v): new_metadata = {} - new_metadata['ml_use_column_spec_id'] = metadata.ml_use_column_spec_id - new_metadata['weight_column_spec_id'] = metadata.weight_column_spec_id - new_metadata['target_column_spec_id'] = metadata.target_column_spec_id + new_metadata["ml_use_column_spec_id"] = metadata.ml_use_column_spec_id + new_metadata["weight_column_spec_id"] = metadata.weight_column_spec_id + new_metadata["target_column_spec_id"] = metadata.target_column_spec_id new_metadata[k] = v return new_metadata - def __dataset_from_args(self, dataset=None, dataset_display_name=None, - dataset_name=None, project=None, region=None): - if (dataset is None - and dataset_display_name is None - and dataset_name is None): - raise ValueError('One of \'dataset\', \'dataset_name\' or ' - '\'dataset_display_name\' must be set.') + def __dataset_from_args( + self, + dataset=None, + dataset_display_name=None, + dataset_name=None, + project=None, + region=None, + ): + if dataset is None and dataset_display_name is None and dataset_name is None: + raise ValueError( + "One of 'dataset', 'dataset_name' or " + "'dataset_display_name' must be set." + ) # we prefer to make a live call here in the case that the # dataset object is out-of-date if dataset is not None: dataset_name = dataset.name return self.get_dataset( - dataset_display_name=dataset_display_name, - dataset_name=dataset_name, - project=project, - region=region + dataset_display_name=dataset_display_name, + dataset_name=dataset_name, + project=project, + region=region, ) - def __model_from_args(self, model=None, model_display_name=None, - model_name=None, project=None, region=None): - if (model is None - and model_display_name is None - and model_name is None): - raise ValueError('One of \'model\', \'model_name\' or ' - '\'model_display_name\' must be set.') + def __model_from_args( + self, + model=None, + model_display_name=None, + model_name=None, + project=None, + region=None, + ): + if model is None and model_display_name is None and model_name is None: + raise ValueError( + "One of 'model', 'model_name' or " "'model_display_name' must be set." + ) # we prefer to make a live call here in the case that the # model object is out-of-date if model is not None: model_name = model.name return self.get_model( - model_display_name=model_display_name, - model_name=model_name, - project=project, - region=region + model_display_name=model_display_name, + model_name=model_name, + project=project, + region=region, ) - def __dataset_name_from_args(self, dataset=None, dataset_display_name=None, - dataset_name=None, project=None, region=None): - if (dataset is None - and dataset_display_name is None - and dataset_name is None): - raise ValueError('One of \'dataset\', \'dataset_name\' or ' - '\'dataset_display_name\' must be set.') + def __dataset_name_from_args( + self, + dataset=None, + dataset_display_name=None, + dataset_name=None, + project=None, + region=None, + ): + if dataset is None and dataset_display_name is None and dataset_name is None: + raise ValueError( + "One of 'dataset', 'dataset_name' or " + "'dataset_display_name' must be set." 
+ ) if dataset_name is None: if dataset is None: dataset = self.get_dataset( - dataset_display_name=dataset_display_name, - project=project, - region=region + dataset_display_name=dataset_display_name, + project=project, + region=region, ) dataset_name = dataset.name else: # we do this to force a NotFound error when needed - self.get_dataset( - dataset_name=dataset_name, - project=project, - region=region - ) + self.get_dataset(dataset_name=dataset_name, project=project, region=region) return dataset_name - def __table_spec_name_from_args(self, table_spec_index=0, - dataset=None, dataset_display_name=None, - dataset_name=None, project=None, region=None): - dataset_name = self.__dataset_name_from_args(dataset=dataset, - dataset_name=dataset_name, - dataset_display_name=dataset_display_name, - project=project, - region=region + def __table_spec_name_from_args( + self, + table_spec_index=0, + dataset=None, + dataset_display_name=None, + dataset_name=None, + project=None, + region=None, + ): + dataset_name = self.__dataset_name_from_args( + dataset=dataset, + dataset_name=dataset_name, + dataset_display_name=dataset_display_name, + project=project, + region=region, ) - table_specs = [t for t in - self.list_table_specs(dataset_name=dataset_name) - ] + table_specs = [t for t in self.list_table_specs(dataset_name=dataset_name)] table_spec_full_id = table_specs[table_spec_index].name return table_spec_full_id - def __model_name_from_args(self, model=None, model_display_name=None, - model_name=None, project=None, region=None): - if (model is None - and model_display_name is None - and model_name is None): - raise ValueError('One of \'model\', \'model_name\' or ' - '\'model_display_name\' must be set.') + def __model_name_from_args( + self, + model=None, + model_display_name=None, + model_name=None, + project=None, + region=None, + ): + if model is None and model_display_name is None and model_name is None: + raise ValueError( + "One of 'model', 'model_name' or " "'model_display_name' must be set." 
+ ) if model_name is None: if model is None: model = self.get_model( - model_display_name=model_display_name, - project=project, - region=region + model_display_name=model_display_name, + project=project, + region=region, ) model_name = model.name else: # we do this to force a NotFound error when needed - self.get_model( - model_name=model_name, - project=project, - region=region - ) + self.get_model(model_name=model_name, project=project, region=region) return model_name - def __column_spec_name_from_args(self, dataset=None, dataset_display_name=None, - dataset_name=None, table_spec_name=None, table_spec_index=0, - column_spec_name=None, column_spec_display_name=None, - project=None, region=None): - column_specs = self.list_column_specs(dataset=dataset, - dataset_display_name=dataset_display_name, - dataset_name=dataset_name, - table_spec_name=table_spec_name, - table_spec_index=table_spec_index, - project=project, - region=region) + def __column_spec_name_from_args( + self, + dataset=None, + dataset_display_name=None, + dataset_name=None, + table_spec_name=None, + table_spec_index=0, + column_spec_name=None, + column_spec_display_name=None, + project=None, + region=None, + ): + column_specs = self.list_column_specs( + dataset=dataset, + dataset_display_name=dataset_display_name, + dataset_name=dataset_name, + table_spec_name=table_spec_name, + table_spec_index=table_spec_index, + project=project, + region=region, + ) if column_spec_display_name is not None: column_specs = {s.display_name: s for s in column_specs} if column_specs.get(column_spec_display_name) is None: - raise exceptions.NotFound('No column with ' + - 'column_spec_display_name: \'{}\' found'.format( - column_spec_display_name - )) + raise exceptions.NotFound( + "No column with " + + "column_spec_display_name: '{}' found".format( + column_spec_display_name + ) + ) column_spec_name = column_specs[column_spec_display_name].name elif column_spec_name is not None: column_specs = {s.name: s for s in column_specs} if column_specs.get(column_spec_name) is None: - raise exceptions.NotFound('No column with ' + - 'column_spec_name: \'{}\' found'.format( - column_spec_name - )) + raise exceptions.NotFound( + "No column with " + + "column_spec_name: '{}' found".format(column_spec_name) + ) else: - raise ValueError('Either supply \'column_spec_name\' or ' - '\'column_spec_display_name\' for the column to update') + raise ValueError( + "Either supply 'column_spec_name' or " + "'column_spec_display_name' for the column to update" + ) return column_spec_name def __type_code_to_value_type(self, type_code, value): if value is None: - return {'null_value': 0} + return {"null_value": 0} elif type_code == data_types_pb2.FLOAT64: - return {'number_value': value} + return {"number_value": value} elif type_code == data_types_pb2.TIMESTAMP: - return {'string_value': value} + return {"string_value": value} elif type_code == data_types_pb2.STRING: - return {'string_value': value} + return {"string_value": value} elif type_code == data_types_pb2.ARRAY: - return {'list_value': value} + return {"list_value": value} elif type_code == data_types_pb2.STRUCT: - return {'struct_value': value} + return {"struct_value": value} elif type_code == data_types_pb2.CATEGORY: - return {'string_value': value} + return {"string_value": value} else: - raise ValueError('Unknown type_code: {}'.format(type_code)) + raise ValueError("Unknown type_code: {}".format(type_code)) def list_datasets(self, project=None, region=None): """List all datasets in a particular project and 
region. @@ -337,11 +381,12 @@ def list_datasets(self, project=None, region=None): ValueError: If required parameters are missing. """ return self.client.list_datasets( - self.__location_path(project=project, region=region) - ) + self.__location_path(project=project, region=region) + ) - def get_dataset(self, project=None, region=None, - dataset_name=None, dataset_display_name=None): + def get_dataset( + self, project=None, region=None, dataset_name=None, dataset_display_name=None + ): """Gets a single dataset in a particular project and region. Example: @@ -388,23 +433,34 @@ def get_dataset(self, project=None, region=None, ValueError: If required parameters are missing. """ if dataset_name is None and dataset_display_name is None: - raise ValueError('One of \'dataset_name\' or ' - '\'dataset_display_name\' must be set.') + raise ValueError( + "One of 'dataset_name' or " "'dataset_display_name' must be set." + ) if dataset_name is not None: return self.client.get_dataset(dataset_name) - result = next((d for d in self.list_datasets(project, region) - if d.display_name == dataset_display_name), None) + result = next( + ( + d + for d in self.list_datasets(project, region) + if d.display_name == dataset_display_name + ), + None, + ) if result is None: - raise exceptions.NotFound(('Dataset with display_name: \'{}\' ' + - 'not found').format(dataset_display_name)) + raise exceptions.NotFound( + ("Dataset with display_name: '{}' " + "not found").format( + dataset_display_name + ) + ) return result - def create_dataset(self, dataset_display_name, metadata={}, project=None, - region=None): + def create_dataset( + self, dataset_display_name, metadata={}, project=None, region=None + ): """Create a dataset. Keep in mind, importing data is a separate step. Example: @@ -442,15 +498,18 @@ def create_dataset(self, dataset_display_name, metadata={}, project=None, ValueError: If required parameters are missing. """ return self.client.create_dataset( - self.__location_path(project, region), - { - 'display_name': dataset_display_name, - 'tables_dataset_metadata': metadata - } - ) + self.__location_path(project, region), + {"display_name": dataset_display_name, "tables_dataset_metadata": metadata}, + ) - def delete_dataset(self, dataset=None, dataset_display_name=None, - dataset_name=None, project=None, region=None): + def delete_dataset( + self, + dataset=None, + dataset_display_name=None, + dataset_name=None, + project=None, + region=None, + ): """Deletes a dataset. This does not delete any models trained on this dataset. @@ -502,20 +561,29 @@ def delete_dataset(self, dataset=None, dataset_display_name=None, ValueError: If required parameters are missing. """ try: - dataset_name = self.__dataset_name_from_args(dataset=dataset, - dataset_name=dataset_name, - dataset_display_name=dataset_display_name, - project=project, - region=region) + dataset_name = self.__dataset_name_from_args( + dataset=dataset, + dataset_name=dataset_name, + dataset_display_name=dataset_display_name, + project=project, + region=region, + ) # delete is idempotent except exceptions.NotFound: return None return self.client.delete_dataset(dataset_name) - def import_data(self, dataset=None, dataset_display_name=None, - dataset_name=None, gcs_input_uris=None, - bigquery_input_uri=None, project=None, region=None): + def import_data( + self, + dataset=None, + dataset_display_name=None, + dataset_name=None, + gcs_input_uris=None, + bigquery_input_uri=None, + project=None, + region=None, + ): """Imports data into a dataset. 
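        At least one input source must be supplied: `gcs_input_uris` accepts
        either a single URI string or a list of URIs, and `bigquery_input_uri`
        accepts a single BigQuery URI. If both are given, `gcs_input_uris`
        takes precedence; if neither is given, a `ValueError` is raised.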
Example: @@ -578,35 +646,36 @@ def import_data(self, dataset=None, dataset_display_name=None, to a retryable error and retry attempts failed. ValueError: If required parameters are missing. """ - dataset_name = self.__dataset_name_from_args(dataset=dataset, - dataset_name=dataset_name, - dataset_display_name=dataset_display_name, - project=project, - region=region) + dataset_name = self.__dataset_name_from_args( + dataset=dataset, + dataset_name=dataset_name, + dataset_display_name=dataset_display_name, + project=project, + region=region, + ) request = {} if gcs_input_uris is not None: if type(gcs_input_uris) != list: gcs_input_uris = [gcs_input_uris] - request = { - 'gcs_source': { - 'input_uris': gcs_input_uris - } - } + request = {"gcs_source": {"input_uris": gcs_input_uris}} elif bigquery_input_uri is not None: - request = { - 'bigquery_source': { - 'input_uri': bigquery_input_uri - } - } + request = {"bigquery_source": {"input_uri": bigquery_input_uri}} else: - raise ValueError('One of \'gcs_input_uris\', or ' - '\'bigquery_input_uri\' must be set.') + raise ValueError( + "One of 'gcs_input_uris', or " "'bigquery_input_uri' must be set." + ) return self.client.import_data(dataset_name, request) - def list_table_specs(self, dataset=None, dataset_display_name=None, - dataset_name=None, project=None, region=None): + def list_table_specs( + self, + dataset=None, + dataset_display_name=None, + dataset_name=None, + project=None, + region=None, + ): """Lists table specs. Example: @@ -659,17 +728,26 @@ def list_table_specs(self, dataset=None, dataset_display_name=None, to a retryable error and retry attempts failed. ValueError: If required parameters are missing. """ - dataset_name = self.__dataset_name_from_args(dataset=dataset, - dataset_name=dataset_name, - dataset_display_name=dataset_display_name, - project=project, - region=region) + dataset_name = self.__dataset_name_from_args( + dataset=dataset, + dataset_name=dataset_name, + dataset_display_name=dataset_display_name, + project=project, + region=region, + ) return self.client.list_table_specs(dataset_name) - def list_column_specs(self, dataset=None, dataset_display_name=None, - dataset_name=None, table_spec_name=None, table_spec_index=0, - project=None, region=None): + def list_column_specs( + self, + dataset=None, + dataset_display_name=None, + dataset_name=None, + table_spec_name=None, + table_spec_index=0, + project=None, + region=None, + ): """Lists column specs. Example: @@ -738,20 +816,35 @@ def list_column_specs(self, dataset=None, dataset_display_name=None, ValueError: If required parameters are missing. 
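        A short usage sketch (the dataset display name below is a placeholder,
        and `client` is the `TablesClient` constructed in the example above):

            >>> for spec in client.list_column_specs(dataset_display_name='my_dataset'):
            ...     print(spec.display_name, spec.data_type.type_code)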
""" if table_spec_name is None: - table_specs = [t for t in self.list_table_specs(dataset=dataset, + table_specs = [ + t + for t in self.list_table_specs( + dataset=dataset, dataset_display_name=dataset_display_name, dataset_name=dataset_name, project=project, - region=region)] + region=region, + ) + ] table_spec_name = table_specs[table_spec_index].name return self.client.list_column_specs(table_spec_name) - def update_column_spec(self, dataset=None, dataset_display_name=None, - dataset_name=None, table_spec_name=None, table_spec_index=0, - column_spec_name=None, column_spec_display_name=None, - type_code=None, nullable=None, project=None, region=None): + def update_column_spec( + self, + dataset=None, + dataset_display_name=None, + dataset_name=None, + table_spec_name=None, + table_spec_index=0, + column_spec_name=None, + column_spec_display_name=None, + type_code=None, + nullable=None, + project=None, + region=None, + ): """Updates a column's specs. Example: @@ -824,48 +917,56 @@ def update_column_spec(self, dataset=None, dataset_display_name=None, ValueError: If required parameters are missing. """ column_spec_name = self.__column_spec_name_from_args( - dataset=dataset, - dataset_display_name=dataset_display_name, - dataset_name=dataset_name, - table_spec_name=table_spec_name, - table_spec_index=table_spec_index, - column_spec_name=column_spec_name, - column_spec_display_name=column_spec_display_name, - project=project, - region=region + dataset=dataset, + dataset_display_name=dataset_display_name, + dataset_name=dataset_name, + table_spec_name=table_spec_name, + table_spec_index=table_spec_index, + column_spec_name=column_spec_name, + column_spec_display_name=column_spec_display_name, + project=project, + region=region, ) # type code must always be set if type_code is None: # this index is safe, we would have already thrown a NotFound # had the column_spec_name not existed - type_code = {s.name: s for s in self.list_column_specs( + type_code = { + s.name: s + for s in self.list_column_specs( dataset=dataset, dataset_display_name=dataset_display_name, dataset_name=dataset_name, table_spec_name=table_spec_name, table_spec_index=table_spec_index, project=project, - region=region) + region=region, + ) }[column_spec_name].data_type.type_code data_type = {} if nullable is not None: - data_type['nullable'] = nullable + data_type["nullable"] = nullable - data_type['type_code'] = type_code + data_type["type_code"] = type_code - request = { - 'name': column_spec_name, - 'data_type': data_type - } + request = {"name": column_spec_name, "data_type": data_type} return self.client.update_column_spec(request) - def set_target_column(self, dataset=None, dataset_display_name=None, - dataset_name=None, table_spec_name=None, table_spec_index=0, - column_spec_name=None, column_spec_display_name=None, - project=None, region=None): + def set_target_column( + self, + dataset=None, + dataset_display_name=None, + dataset_name=None, + table_spec_name=None, + table_spec_index=0, + column_spec_name=None, + column_spec_display_name=None, + project=None, + region=None, + ): """Sets the target column for a given table. Example: @@ -940,39 +1041,46 @@ def set_target_column(self, dataset=None, dataset_display_name=None, ValueError: If required parameters are missing. 
""" column_spec_name = self.__column_spec_name_from_args( - dataset=dataset, - dataset_display_name=dataset_display_name, - dataset_name=dataset_name, - table_spec_name=table_spec_name, - table_spec_index=table_spec_index, - column_spec_name=column_spec_name, - column_spec_display_name=column_spec_display_name, - project=project, - region=region + dataset=dataset, + dataset_display_name=dataset_display_name, + dataset_name=dataset_name, + table_spec_name=table_spec_name, + table_spec_index=table_spec_index, + column_spec_name=column_spec_name, + column_spec_display_name=column_spec_display_name, + project=project, + region=region, ) - column_spec_id = column_spec_name.rsplit('/', 1)[-1] + column_spec_id = column_spec_name.rsplit("/", 1)[-1] - dataset = self.__dataset_from_args(dataset=dataset, - dataset_name=dataset_name, - dataset_display_name=dataset_display_name, - project=project, - region=region) + dataset = self.__dataset_from_args( + dataset=dataset, + dataset_name=dataset_name, + dataset_display_name=dataset_display_name, + project=project, + region=region, + ) metadata = dataset.tables_dataset_metadata - metadata = self.__update_metadata(metadata, - 'target_column_spec_id', - column_spec_id) + metadata = self.__update_metadata( + metadata, "target_column_spec_id", column_spec_id + ) - request = { - 'name': dataset.name, - 'tables_dataset_metadata': metadata, - } + request = {"name": dataset.name, "tables_dataset_metadata": metadata} return self.client.update_dataset(request) - def set_time_column(self, dataset=None, dataset_display_name=None, - dataset_name=None, table_spec_name=None, table_spec_index=0, - column_spec_name=None, column_spec_display_name=None, - project=None, region=None): + def set_time_column( + self, + dataset=None, + dataset_display_name=None, + dataset_name=None, + table_spec_name=None, + table_spec_index=0, + column_spec_name=None, + column_spec_display_name=None, + project=None, + region=None, + ): """Sets the time column which designates which data will be of type timestamp and will be used for the timeseries data. This column must be of type timestamp. @@ -1045,37 +1153,44 @@ def set_time_column(self, dataset=None, dataset_display_name=None, ValueError: If required parameters are missing. 
""" column_spec_name = self.__column_spec_name_from_args( - dataset=dataset, - dataset_display_name=dataset_display_name, - dataset_name=dataset_name, - table_spec_name=table_spec_name, - table_spec_index=table_spec_index, - column_spec_name=column_spec_name, - column_spec_display_name=column_spec_display_name, - project=project, - region=region + dataset=dataset, + dataset_display_name=dataset_display_name, + dataset_name=dataset_name, + table_spec_name=table_spec_name, + table_spec_index=table_spec_index, + column_spec_name=column_spec_name, + column_spec_display_name=column_spec_display_name, + project=project, + region=region, ) - column_spec_id = column_spec_name.rsplit('/', 1)[-1] + column_spec_id = column_spec_name.rsplit("/", 1)[-1] - dataset_name = self.__dataset_name_from_args(dataset=dataset, - dataset_name=dataset_name, - dataset_display_name=dataset_display_name, - project=project, - region=region) + dataset_name = self.__dataset_name_from_args( + dataset=dataset, + dataset_name=dataset_name, + dataset_display_name=dataset_display_name, + project=project, + region=region, + ) - table_spec_full_id = self.__table_spec_name_from_args( - dataset_name=dataset_name) + table_spec_full_id = self.__table_spec_name_from_args(dataset_name=dataset_name) my_table_spec = { - 'name': table_spec_full_id, - 'time_column_spec_id': column_spec_id + "name": table_spec_full_id, + "time_column_spec_id": column_spec_id, } self.client.update_table_spec(my_table_spec) return self.get_dataset(dataset_name=dataset_name) - def clear_time_column(self, dataset=None, dataset_display_name=None, - dataset_name=None, project=None, region=None): + def clear_time_column( + self, + dataset=None, + dataset_display_name=None, + dataset_name=None, + project=None, + region=None, + ): """Clears the time column which designates which data will be of type timestamp and will be used for the timeseries data. @@ -1130,27 +1245,33 @@ def clear_time_column(self, dataset=None, dataset_display_name=None, to a retryable error and retry attempts failed. ValueError: If required parameters are missing. """ - dataset_name = self.__dataset_name_from_args(dataset=dataset, - dataset_name=dataset_name, - dataset_display_name=dataset_display_name, - project=project, - region=region) + dataset_name = self.__dataset_name_from_args( + dataset=dataset, + dataset_name=dataset_name, + dataset_display_name=dataset_display_name, + project=project, + region=region, + ) - table_spec_full_id = self.__table_spec_name_from_args( - dataset_name=dataset_name) + table_spec_full_id = self.__table_spec_name_from_args(dataset_name=dataset_name) - my_table_spec = { - 'name': table_spec_full_id, - 'time_column_spec_id': None, - } + my_table_spec = {"name": table_spec_full_id, "time_column_spec_id": None} response = self.client.update_table_spec(my_table_spec) return self.get_dataset(dataset_name=dataset_name) - def set_weight_column(self, dataset=None, dataset_display_name=None, - dataset_name=None, table_spec_name=None, table_spec_index=0, - column_spec_name=None, column_spec_display_name=None, - project=None, region=None): + def set_weight_column( + self, + dataset=None, + dataset_display_name=None, + dataset_name=None, + table_spec_name=None, + table_spec_index=0, + column_spec_name=None, + column_spec_display_name=None, + project=None, + region=None, + ): """Sets the weight column for a given table. Example: @@ -1225,37 +1346,42 @@ def set_weight_column(self, dataset=None, dataset_display_name=None, ValueError: If required parameters are missing. 
""" column_spec_name = self.__column_spec_name_from_args( - dataset=dataset, - dataset_display_name=dataset_display_name, - dataset_name=dataset_name, - table_spec_name=table_spec_name, - table_spec_index=table_spec_index, - column_spec_name=column_spec_name, - column_spec_display_name=column_spec_display_name, - project=project, - region=region + dataset=dataset, + dataset_display_name=dataset_display_name, + dataset_name=dataset_name, + table_spec_name=table_spec_name, + table_spec_index=table_spec_index, + column_spec_name=column_spec_name, + column_spec_display_name=column_spec_display_name, + project=project, + region=region, ) - column_spec_id = column_spec_name.rsplit('/', 1)[-1] + column_spec_id = column_spec_name.rsplit("/", 1)[-1] - dataset = self.__dataset_from_args(dataset=dataset, - dataset_name=dataset_name, - dataset_display_name=dataset_display_name, - project=project, - region=region) + dataset = self.__dataset_from_args( + dataset=dataset, + dataset_name=dataset_name, + dataset_display_name=dataset_display_name, + project=project, + region=region, + ) metadata = dataset.tables_dataset_metadata - metadata = self.__update_metadata(metadata, - 'weight_column_spec_id', - column_spec_id) + metadata = self.__update_metadata( + metadata, "weight_column_spec_id", column_spec_id + ) - request = { - 'name': dataset.name, - 'tables_dataset_metadata': metadata, - } + request = {"name": dataset.name, "tables_dataset_metadata": metadata} return self.client.update_dataset(request) - def clear_weight_column(self, dataset=None, dataset_display_name=None, - dataset_name=None, project=None, region=None): + def clear_weight_column( + self, + dataset=None, + dataset_display_name=None, + dataset_name=None, + project=None, + region=None, + ): """Clears the weight column for a given dataset. Example: @@ -1311,26 +1437,32 @@ def clear_weight_column(self, dataset=None, dataset_display_name=None, to a retryable error and retry attempts failed. ValueError: If required parameters are missing. """ - dataset = self.__dataset_from_args(dataset=dataset, - dataset_name=dataset_name, - dataset_display_name=dataset_display_name, - project=project, - region=region) + dataset = self.__dataset_from_args( + dataset=dataset, + dataset_name=dataset_name, + dataset_display_name=dataset_display_name, + project=project, + region=region, + ) metadata = dataset.tables_dataset_metadata - metadata = self.__update_metadata(metadata, 'weight_column_spec_id', - None) + metadata = self.__update_metadata(metadata, "weight_column_spec_id", None) - request = { - 'name': dataset.name, - 'tables_dataset_metadata': metadata, - } + request = {"name": dataset.name, "tables_dataset_metadata": metadata} return self.client.update_dataset(request) - def set_test_train_column(self, dataset=None, dataset_display_name=None, - dataset_name=None, table_spec_name=None, table_spec_index=0, - column_spec_name=None, column_spec_display_name=None, - project=None, region=None): + def set_test_train_column( + self, + dataset=None, + dataset_display_name=None, + dataset_name=None, + table_spec_name=None, + table_spec_index=0, + column_spec_name=None, + column_spec_display_name=None, + project=None, + region=None, + ): """Sets the test/train (ml_use) column which designates which data belongs to the test and train sets. This column must be categorical. @@ -1406,35 +1538,42 @@ def set_test_train_column(self, dataset=None, dataset_display_name=None, ValueError: If required parameters are missing. 
""" column_spec_name = self.__column_spec_name_from_args( - dataset=dataset, - dataset_display_name=dataset_display_name, - dataset_name=dataset_name, - table_spec_name=table_spec_name, - table_spec_index=table_spec_index, - column_spec_name=column_spec_name, - column_spec_display_name=column_spec_display_name, - project=project, - region=region + dataset=dataset, + dataset_display_name=dataset_display_name, + dataset_name=dataset_name, + table_spec_name=table_spec_name, + table_spec_index=table_spec_index, + column_spec_name=column_spec_name, + column_spec_display_name=column_spec_display_name, + project=project, + region=region, ) - column_spec_id = column_spec_name.rsplit('/', 1)[-1] + column_spec_id = column_spec_name.rsplit("/", 1)[-1] - dataset = self.__dataset_from_args(dataset=dataset, - dataset_name=dataset_name, - dataset_display_name=dataset_display_name, - project=project, - region=region) + dataset = self.__dataset_from_args( + dataset=dataset, + dataset_name=dataset_name, + dataset_display_name=dataset_display_name, + project=project, + region=region, + ) metadata = dataset.tables_dataset_metadata - metadata = self.__update_metadata(metadata, 'ml_use_column_spec_id', column_spec_id) + metadata = self.__update_metadata( + metadata, "ml_use_column_spec_id", column_spec_id + ) - request = { - 'name': dataset.name, - 'tables_dataset_metadata': metadata, - } + request = {"name": dataset.name, "tables_dataset_metadata": metadata} return self.client.update_dataset(request) - def clear_test_train_column(self, dataset=None, dataset_display_name=None, - dataset_name=None, project=None, region=None): + def clear_test_train_column( + self, + dataset=None, + dataset_display_name=None, + dataset_name=None, + project=None, + region=None, + ): """Clears the test/train (ml_use) column which designates which data belongs to the test and train sets. @@ -1491,19 +1630,17 @@ def clear_test_train_column(self, dataset=None, dataset_display_name=None, to a retryable error and retry attempts failed. ValueError: If required parameters are missing. """ - dataset = self.__dataset_from_args(dataset=dataset, - dataset_name=dataset_name, - dataset_display_name=dataset_display_name, - project=project, - region=region) + dataset = self.__dataset_from_args( + dataset=dataset, + dataset_name=dataset_name, + dataset_display_name=dataset_display_name, + project=project, + region=region, + ) metadata = dataset.tables_dataset_metadata - metadata = self.__update_metadata(metadata, 'ml_use_column_spec_id', - None) + metadata = self.__update_metadata(metadata, "ml_use_column_spec_id", None) - request = { - 'name': dataset.name, - 'tables_dataset_metadata': metadata, - } + request = {"name": dataset.name, "tables_dataset_metadata": metadata} return self.client.update_dataset(request) @@ -1550,15 +1687,22 @@ def list_models(self, project=None, region=None): ValueError: If required parameters are missing. 
""" return self.client.list_models( - self.__location_path(project=project, region=region) + self.__location_path(project=project, region=region) ) - def create_model(self, model_display_name, dataset=None, - dataset_display_name=None, dataset_name=None, - train_budget_milli_node_hours=None, project=None, - region=None, model_metadata={}, - include_column_spec_names=None, - exclude_column_spec_names=None): + def create_model( + self, + model_display_name, + dataset=None, + dataset_display_name=None, + dataset_name=None, + train_budget_milli_node_hours=None, + project=None, + region=None, + model_metadata={}, + include_column_spec_names=None, + exclude_column_spec_names=None, + ): """Create a model. This will train your model on the given dataset. Example: @@ -1620,30 +1764,45 @@ def create_model(self, model_display_name, dataset=None, to a retryable error and retry attempts failed. ValueError: If required parameters are missing. """ - if (train_budget_milli_node_hours is None or - train_budget_milli_node_hours < 1000 or - train_budget_milli_node_hours > 72000): - raise ValueError('\'train_budget_milli_node_hours\' must be a ' - 'value between 1,000 and 72,000 inclusive') - - if (exclude_column_spec_names not in [None, []] and - include_column_spec_names not in [None, []]): - raise ValueError('Cannot set both ' - '\'exclude_column_spec_names\' and ' - '\'include_column_spec_names\'') - - dataset_name = self.__dataset_name_from_args(dataset=dataset, + if ( + train_budget_milli_node_hours is None + or train_budget_milli_node_hours < 1000 + or train_budget_milli_node_hours > 72000 + ): + raise ValueError( + "'train_budget_milli_node_hours' must be a " + "value between 1,000 and 72,000 inclusive" + ) + + if exclude_column_spec_names not in [ + None, + [], + ] and include_column_spec_names not in [None, []]: + raise ValueError( + "Cannot set both " + "'exclude_column_spec_names' and " + "'include_column_spec_names'" + ) + + dataset_name = self.__dataset_name_from_args( + dataset=dataset, dataset_name=dataset_name, dataset_display_name=dataset_display_name, project=project, - region=region) + region=region, + ) - model_metadata['train_budget_milli_node_hours'] = train_budget_milli_node_hours + model_metadata["train_budget_milli_node_hours"] = train_budget_milli_node_hours - dataset_id = dataset_name.rsplit('/', 1)[-1] - columns = [s for s in self.list_column_specs(dataset=dataset, - dataset_name=dataset_name, - dataset_display_name=dataset_display_name)] + dataset_id = dataset_name.rsplit("/", 1)[-1] + columns = [ + s + for s in self.list_column_specs( + dataset=dataset, + dataset_name=dataset_name, + dataset_display_name=dataset_display_name, + ) + ] final_columns = [] if include_column_spec_names: @@ -1651,27 +1810,32 @@ def create_model(self, model_display_name, dataset=None, if c.display_name in include_column_spec_names: final_columns.append(c) - model_metadata['input_feature_column_specs'] = final_columns + model_metadata["input_feature_column_specs"] = final_columns elif exclude_column_spec_names: for a in columns: if a.display_name not in exclude_column_spec_names: final_columns.append(a) - model_metadata['input_feature_column_specs'] = final_columns + model_metadata["input_feature_column_specs"] = final_columns request = { - 'display_name': model_display_name, - 'dataset_id': dataset_id, - 'tables_model_metadata': model_metadata + "display_name": model_display_name, + "dataset_id": dataset_id, + "tables_model_metadata": model_metadata, } return self.client.create_model( - 
self.__location_path(project=project, region=region),
-            request
+            self.__location_path(project=project, region=region), request
         )
 
-    def delete_model(self, model=None, model_display_name=None,
-            model_name=None, project=None, region=None):
+    def delete_model(
+        self,
+        model=None,
+        model_display_name=None,
+        model_name=None,
+        project=None,
+        region=None,
+    ):
         """Deletes a model. Note this will not delete any datasets associated
         with this model.
 
@@ -1723,19 +1887,22 @@ def delete_model(self, model=None, model_display_name=None,
             ValueError: If required parameters are missing.
         """
         try:
-            model_name = self.__model_name_from_args(model=model,
-                    model_name=model_name,
-                    model_display_name=model_display_name,
-                    project=project,
-                    region=region)
+            model_name = self.__model_name_from_args(
+                model=model,
+                model_name=model_name,
+                model_display_name=model_display_name,
+                project=project,
+                region=region,
+            )
         # delete is idempotent
         except exceptions.NotFound:
             return None
 
         return self.client.delete_model(model_name)
 
-    def get_model(self, project=None, region=None,
-            model_name=None, model_display_name=None):
+    def get_model(
+        self, project=None, region=None, model_name=None, model_display_name=None
+    ):
         """Gets a single model in a particular project and region.
 
         Example:
@@ -1781,24 +1948,39 @@ def get_model(self, project=None, region=None,
             ValueError: If required parameters are missing.
         """
         if model_name is None and model_display_name is None:
-            raise ValueError('One of \'model_name\' or '
-                    '\'model_display_name\' must be set.')
+            raise ValueError(
+                "One of 'model_name' or " "'model_display_name' must be set."
+            )
 
         if model_name is not None:
             return self.client.get_model(model_name)
 
-        model = next((d for d in self.list_models(project, region)
-            if d.display_name == model_display_name), None)
+        model = next(
+            (
+                d
+                for d in self.list_models(project, region)
+                if d.display_name == model_display_name
+            ),
+            None,
+        )
 
         if model is None:
-            raise exceptions.NotFound('No model with model_diplay_name: ' +
-                    '\'{}\' found'.format(model_display_name))
+            raise exceptions.NotFound(
+                "No model with model_display_name: "
+                + "'{}' found".format(model_display_name)
+            )
 
         return model
 
-    #TODO(jonathanskim): allow deployment from just model ID
-    def deploy_model(self, model=None, model_name=None,
-            model_display_name=None, project=None, region=None):
+    # TODO(jonathanskim): allow deployment from just model ID
+    def deploy_model(
+        self,
+        model=None,
+        model_name=None,
+        model_display_name=None,
+        project=None,
+        region=None,
+    ):
         """Deploys a model. This allows you to make online predictions using
         the model you've deployed.
 
         Example:
@@ -1850,17 +2032,23 @@ def deploy_model(self, model=None, model_name=None,
             ValueError: If required parameters are missing.
         """
         model_name = self.__model_name_from_args(
-                model=model,
-                model_name=model_name,
-                model_display_name=model_display_name,
-                project=project,
-                region=region
+            model=model,
+            model_name=model_name,
+            model_display_name=model_display_name,
+            project=project,
+            region=region,
        )
 
         return self.client.deploy_model(model_name)
 
-    def undeploy_model(self, model=None, model_name=None,
-            model_display_name=None, project=None, region=None):
+    def undeploy_model(
+        self,
+        model=None,
+        model_name=None,
+        model_display_name=None,
+        project=None,
+        region=None,
+    ):
         """Undeploys a model.
 
         Example:
@@ -1911,18 +2099,25 @@ def undeploy_model(self, model=None, model_name=None,
             ValueError: If required parameters are missing.
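        A minimal sketch (the model display name is a placeholder; `client`
        as constructed in the example above):

            >>> client.undeploy_model(model_display_name='my_model')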
""" model_name = self.__model_name_from_args( - model=model, - model_name=model_name, - model_display_name=model_display_name, - project=project, - region=region + model=model, + model_name=model_name, + model_display_name=model_display_name, + project=project, + region=region, ) return self.client.undeploy_model(model_name) ## TODO(lwander): support pandas DataFrame as input type - def predict(self, inputs, model=None, model_name=None, - model_display_name=None, project=None, region=None): + def predict( + self, + inputs, + model=None, + model_name=None, + model_display_name=None, + project=None, + region=None, + ): """Makes a prediction on a deployed model. This will fail if the model was not deployed. @@ -1978,11 +2173,11 @@ def predict(self, inputs, model=None, model_name=None, ValueError: If required parameters are missing. """ model = self.__model_from_args( - model=model, - model_name=model_name, - model_display_name=model_display_name, - project=project, - region=region + model=model, + model_name=model_name, + model_display_name=model_display_name, + project=project, + region=region, ) column_specs = model.tables_model_metadata.input_feature_column_specs @@ -1990,28 +2185,34 @@ def predict(self, inputs, model=None, model_name=None, inputs = [inputs.get(c.display_name, None) for c in column_specs] if len(inputs) != len(column_specs): - raise ValueError(('Dimension mismatch, the number of provided ' - 'inputs ({}) does not match that of the model ' - '({})').format(len(inputs), len(column_specs))) + raise ValueError( + ( + "Dimension mismatch, the number of provided " + "inputs ({}) does not match that of the model " + "({})" + ).format(len(inputs), len(column_specs)) + ) values = [] for i, c in zip(inputs, column_specs): - value_type = self.__type_code_to_value_type( - c.data_type.type_code, i - ) + value_type = self.__type_code_to_value_type(c.data_type.type_code, i) values.append(value_type) - request = { - 'row': { - 'values': values - } - } + request = {"row": {"values": values}} return self.prediction_client.predict(model.name, request) - def batch_predict(self, gcs_input_uris, gcs_output_uri_prefix, - model=None, model_name=None, model_display_name=None, project=None, - region=None, inputs=None): + def batch_predict( + self, + gcs_input_uris, + gcs_output_uri_prefix, + model=None, + model_name=None, + model_display_name=None, + project=None, + region=None, + inputs=None, + ): """Makes a batch prediction on a model. This does _not_ require the model to be deployed. @@ -2070,31 +2271,25 @@ def batch_predict(self, gcs_input_uris, gcs_output_uri_prefix, ValueError: If required parameters are missing. """ if gcs_input_uris is None or gcs_output_uri_prefix is None: - raise ValueError('Both \'gcs_input_uris\' and ' - '\'gcs_output_uri_prefix\' must be set.') + raise ValueError( + "Both 'gcs_input_uris' and " "'gcs_output_uri_prefix' must be set." 
+            )
 
         model_name = self.__model_name_from_args(
-                model=model,
-                model_name=model_name,
-                model_display_name=model_display_name,
-                project=project,
-                region=region
+            model=model,
+            model_name=model_name,
+            model_display_name=model_display_name,
+            project=project,
+            region=region,
         )
 
         if type(gcs_input_uris) != list:
             gcs_input_uris = [gcs_input_uris]
-        input_request = {
-            'gcs_source': {
-                'input_uris': gcs_input_uris
-            }
-        }
+        input_request = {"gcs_source": {"input_uris": gcs_input_uris}}
 
-        output_request = {
-            'gcs_source': {
-                'output_uri_prefix': gcs_output_uri_prefix
-            }
-        }
+        output_request = {
+            "gcs_destination": {"output_uri_prefix": gcs_output_uri_prefix}
+        }
 
-        return self.prediction_client.batch_predict(model_name, input_request,
-                output_request)
+        return self.prediction_client.batch_predict(
+            model_name, input_request, output_request
+        )
diff --git a/automl/tests/system/gapic/v1beta1/test_system_tables_client_v1.py b/automl/tests/system/gapic/v1beta1/test_system_tables_client_v1.py
index 6c2a30386b2a..9ff6c2c7b79a 100644
--- a/automl/tests/system/gapic/v1beta1/test_system_tables_client_v1.py
+++ b/automl/tests/system/gapic/v1beta1/test_system_tables_client_v1.py
@@ -24,22 +24,24 @@
 from google.api_core import exceptions
 from google.cloud.automl_v1beta1.gapic import enums
 
-PROJECT = os.environ['PROJECT_ID']
-REGION = 'us-central1'
+PROJECT = os.environ["PROJECT_ID"]
+REGION = "us-central1"
 MAX_WAIT_TIME_SECONDS = 30
 MAX_SLEEP_TIME_SECONDS = 5
-STATIC_DATASET='test_dataset_do_not_delete'
-#STATIC_MODEL='test_model_do_not_delete'
-STATIC_MODEL='test_online_model_do_not_delete'
-
-ID = '{rand}_{time}'.format(
-    rand=''.join([random.choice(string.ascii_letters + string.digits)
-                  for n in range(4)]),
-    time=int(time.time())
+STATIC_DATASET = "test_dataset_do_not_delete"
+STATIC_MODEL = "test_model_do_not_delete"
+
+ID = "{rand}_{time}".format(
+    rand="".join(
+        [random.choice(string.ascii_letters + string.digits) for n in range(4)]
+    ),
+    time=int(time.time()),
 )
 
+
 def _id(name):
-    return '{}_{}'.format(name, ID)
+    return "{}_{}".format(name, ID)
+
 
 class TestSystemTablesClient(object):
     def cancel_and_wait(self, op):
@@ -65,20 +67,22 @@ def test_list_models(self):
 
     def test_create_delete_dataset(self):
         client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION)
-        display_name = _id('t_cr_dl')
+        display_name = _id("t_cr_dl")
         dataset = client.create_dataset(display_name)
         assert dataset is not None
-        assert dataset.name == client.get_dataset(
-            dataset_display_name=display_name
-        ).name
+        assert (
+            dataset.name == client.get_dataset(dataset_display_name=display_name).name
+        )
         client.delete_dataset(dataset=dataset)
 
     def test_import_data(self):
         client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION)
-        display_name = _id('t_import')
+        display_name = _id("t_import")
         dataset = client.create_dataset(display_name)
-        op = client.import_data(dataset=dataset,
-            gcs_input_uris='gs://cloud-ml-tables-data/bank-marketing.csv')
+        op = client.import_data(
+            dataset=dataset,
+            gcs_input_uris="gs://cloud-ml-tables-data/bank-marketing.csv",
+        )
         self.cancel_and_wait(op)
         client.delete_dataset(dataset=dataset)
 
@@ -90,8 +94,10 @@ def ensure_dataset_ready(self, client):
             dataset = client.create_dataset(STATIC_DATASET)
 
         if dataset.example_count is None or dataset.example_count == 0:
-            op = client.import_data(dataset=dataset,
-                gcs_input_uris='gs://cloud-ml-tables-data/bank-marketing.csv')
+            op = client.import_data(
+                dataset=dataset,
+                gcs_input_uris="gs://cloud-ml-tables-data/bank-marketing.csv",
+            )
op.result() dataset = client.get_dataset(dataset_name=dataset.name) @@ -112,88 +118,84 @@ def test_list_table_specs(self): def test_set_column_nullable(self): client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION) dataset = self.ensure_dataset_ready(client) - client.update_column_spec(dataset=dataset, - column_spec_display_name='POutcome', nullable=True) - columns = {c.display_name: c - for c in client.list_column_specs(dataset=dataset)} - assert columns['POutcome'].data_type.nullable == True + client.update_column_spec( + dataset=dataset, column_spec_display_name="POutcome", nullable=True + ) + columns = {c.display_name: c for c in client.list_column_specs(dataset=dataset)} + assert columns["POutcome"].data_type.nullable == True def test_set_target_column(self): client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION) dataset = self.ensure_dataset_ready(client) - client.set_target_column(dataset=dataset, - column_spec_display_name='Age') - columns = {c.display_name: c - for c in client.list_column_specs(dataset=dataset)} + client.set_target_column(dataset=dataset, column_spec_display_name="Age") + columns = {c.display_name: c for c in client.list_column_specs(dataset=dataset)} dataset = client.get_dataset(dataset_name=dataset.name) metadata = dataset.tables_dataset_metadata - assert columns['Age'].name.endswith( - '/{}'.format(metadata.target_column_spec_id) - ) + assert columns["Age"].name.endswith( + "/{}".format(metadata.target_column_spec_id) + ) def test_set_weight_column(self): client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION) dataset = self.ensure_dataset_ready(client) - client.set_weight_column(dataset=dataset, - column_spec_display_name='Duration') - columns = {c.display_name: c - for c in client.list_column_specs(dataset=dataset)} + client.set_weight_column(dataset=dataset, column_spec_display_name="Duration") + columns = {c.display_name: c for c in client.list_column_specs(dataset=dataset)} dataset = client.get_dataset(dataset_name=dataset.name) metadata = dataset.tables_dataset_metadata - assert columns['Duration'].name.endswith( - '/{}'.format(metadata.weight_column_spec_id) - ) + assert columns["Duration"].name.endswith( + "/{}".format(metadata.weight_column_spec_id) + ) def test_set_weight_and_target_column(self): client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION) dataset = self.ensure_dataset_ready(client) - client.set_weight_column(dataset=dataset, - column_spec_display_name='Day') - client.set_target_column(dataset=dataset, - column_spec_display_name='Campaign') - columns = {c.display_name: c - for c in client.list_column_specs(dataset=dataset)} + client.set_weight_column(dataset=dataset, column_spec_display_name="Day") + client.set_target_column(dataset=dataset, column_spec_display_name="Campaign") + columns = {c.display_name: c for c in client.list_column_specs(dataset=dataset)} dataset = client.get_dataset(dataset_name=dataset.name) metadata = dataset.tables_dataset_metadata - assert columns['Day'].name.endswith( - '/{}'.format(metadata.weight_column_spec_id) - ) - assert columns['Campaign'].name.endswith( - '/{}'.format(metadata.target_column_spec_id) - ) + assert columns["Day"].name.endswith( + "/{}".format(metadata.weight_column_spec_id) + ) + assert columns["Campaign"].name.endswith( + "/{}".format(metadata.target_column_spec_id) + ) def test_create_delete_model(self): client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION) dataset = self.ensure_dataset_ready(client) - 
client.set_target_column(dataset=dataset, - column_spec_display_name='Deposit') - display_name = _id('t_cr_dl') - op = client.create_model(display_name, dataset=dataset, - train_budget_milli_node_hours=1000) + client.set_target_column(dataset=dataset, column_spec_display_name="Deposit") + display_name = _id("t_cr_dl") + op = client.create_model( + display_name, dataset=dataset, train_budget_milli_node_hours=1000 + ) self.cancel_and_wait(op) client.delete_model(model_display_name=display_name) def test_online_predict(self): client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION) model = self.ensure_model_online(client) - result = client.predict(inputs={ - 'Age': 31, - 'Balance': 200, - 'Campaign': 2, - 'Contact': 'cellular', - 'Day': 4, - 'Default': 'no', - 'Duration': 12, - 'Education': 'primary', - 'Housing': 'yes', - 'Job': 'blue-collar', - 'Loan': 'no', - 'MaritalStatus': 'divorced', - 'Month': 'jul', - 'PDays': 4, - 'POutcome': None, - 'Previous': 12 - }, model=model) + result = client.predict( + inputs={ + "Age": 31, + "Balance": 200, + "Campaign": 2, + "Contact": "cellular", + "Day": 4, + "Default": "no", + "Duration": 12, + "Education": "primary", + "Housing": "yes", + "Job": "blue-collar", + "Loan": "no", + "MaritalStatus": "divorced", + "Month": "jul", + "PDays": 4, + "POutcome": None, + "Previous": 12, + }, + model=model, + ) assert result is not None def ensure_model_online(self, client): @@ -210,14 +212,14 @@ def ensure_model_ready(self, client): pass dataset = self.ensure_dataset_ready(client) - client.set_target_column(dataset=dataset, - column_spec_display_name='Deposit') + client.set_target_column(dataset=dataset, column_spec_display_name="Deposit") client.clear_weight_column(dataset=dataset) client.clear_test_train_column(dataset=dataset) - client.update_column_spec(dataset=dataset, - column_spec_display_name='POutcome', nullable=True) - op = client.create_model(STATIC_MODEL, dataset=dataset, - train_budget_milli_node_hours=1000) + client.update_column_spec( + dataset=dataset, column_spec_display_name="POutcome", nullable=True + ) + op = client.create_model( + STATIC_MODEL, dataset=dataset, train_budget_milli_node_hours=1000 + ) op.result() return client.get_model(model_display_name=STATIC_MODEL) - diff --git a/automl/tests/unit/gapic/v1beta1/test_tables_client_v1beta1.py b/automl/tests/unit/gapic/v1beta1/test_tables_client_v1beta1.py index 1784d5274477..cbdbff97d2e3 100644 --- a/automl/tests/unit/gapic/v1beta1/test_tables_client_v1beta1.py +++ b/automl/tests/unit/gapic/v1beta1/test_tables_client_v1beta1.py @@ -23,44 +23,52 @@ from google.api_core import exceptions from google.cloud.automl_v1beta1.proto import data_types_pb2 -PROJECT='project' -REGION='region' -LOCATION_PATH='projects/{}/locations/{}'.format(PROJECT, REGION) +PROJECT = "project" +REGION = "region" +LOCATION_PATH = "projects/{}/locations/{}".format(PROJECT, REGION) -class TestTablesClient(object): - def tables_client(self, client_attrs={}, - prediction_client_attrs={}): +class TestTablesClient(object): + def tables_client(self, client_attrs={}, prediction_client_attrs={}): client_mock = mock.Mock(**client_attrs) prediction_client_mock = mock.Mock(**prediction_client_attrs) - return automl_v1beta1.TablesClient(client=client_mock, - prediction_client=prediction_client_mock, - project=PROJECT, region=REGION) + return automl_v1beta1.TablesClient( + client=client_mock, + prediction_client=prediction_client_mock, + project=PROJECT, + region=REGION, + ) def test_list_datasets_empty(self): - client = 
self.tables_client({ - 'list_datasets.return_value': [], - 'location_path.return_value': LOCATION_PATH, - }, {}) + client = self.tables_client( + { + "list_datasets.return_value": [], + "location_path.return_value": LOCATION_PATH, + }, + {}, + ) ds = client.list_datasets() client.client.location_path.assert_called_with(PROJECT, REGION) client.client.list_datasets.assert_called_with(LOCATION_PATH) assert ds == [] def test_list_datasets_not_empty(self): - datasets = ['some_dataset'] - client = self.tables_client({ - 'list_datasets.return_value': datasets, - 'location_path.return_value': LOCATION_PATH, - }, {}) + datasets = ["some_dataset"] + client = self.tables_client( + { + "list_datasets.return_value": datasets, + "location_path.return_value": LOCATION_PATH, + }, + {}, + ) ds = client.list_datasets() client.client.location_path.assert_called_with(PROJECT, REGION) client.client.list_datasets.assert_called_with(LOCATION_PATH) assert len(ds) == 1 - assert ds[0] == 'some_dataset' + assert ds[0] == "some_dataset" def test_get_dataset_no_value(self): - dataset_actual = 'dataset' + dataset_actual = "dataset" client = self.tables_client({}, {}) error = None try: @@ -71,100 +79,95 @@ def test_get_dataset_no_value(self): client.client.get_dataset.assert_not_called() def test_get_dataset_name(self): - dataset_actual = 'dataset' - client = self.tables_client({ - 'get_dataset.return_value': dataset_actual - }, {}) - dataset = client.get_dataset(dataset_name='my_dataset') - client.client.get_dataset.assert_called_with('my_dataset') + dataset_actual = "dataset" + client = self.tables_client({"get_dataset.return_value": dataset_actual}, {}) + dataset = client.get_dataset(dataset_name="my_dataset") + client.client.get_dataset.assert_called_with("my_dataset") assert dataset == dataset_actual def test_get_no_dataset(self): - client = self.tables_client({ - 'get_dataset.side_effect': exceptions.NotFound('err') - }, {}) + client = self.tables_client( + {"get_dataset.side_effect": exceptions.NotFound("err")}, {} + ) error = None try: - client.get_dataset(dataset_name='my_dataset') + client.get_dataset(dataset_name="my_dataset") except exceptions.NotFound as e: error = e assert error is not None - client.client.get_dataset.assert_called_with('my_dataset') + client.client.get_dataset.assert_called_with("my_dataset") def test_get_dataset_from_empty_list(self): - client = self.tables_client({'list_datasets.return_value': []}, {}) + client = self.tables_client({"list_datasets.return_value": []}, {}) error = None try: - client.get_dataset(dataset_display_name='my_dataset') + client.get_dataset(dataset_display_name="my_dataset") except exceptions.NotFound as e: error = e assert error is not None def test_get_dataset_from_list_not_found(self): - client = self.tables_client({ - 'list_datasets.return_value': [mock.Mock(display_name='not_it')] - }, {}) + client = self.tables_client( + {"list_datasets.return_value": [mock.Mock(display_name="not_it")]}, {} + ) error = None try: - client.get_dataset(dataset_display_name='my_dataset') + client.get_dataset(dataset_display_name="my_dataset") except exceptions.NotFound as e: error = e assert error is not None def test_get_dataset_from_list(self): - client = self.tables_client({ - 'list_datasets.return_value': [ - mock.Mock(display_name='not_it'), - mock.Mock(display_name='my_dataset'), - ] - }, {}) - dataset = client.get_dataset(dataset_display_name='my_dataset') - assert dataset.display_name == 'my_dataset' + client = self.tables_client( + { + "list_datasets.return_value": [ + 
mock.Mock(display_name="not_it"), + mock.Mock(display_name="my_dataset"), + ] + }, + {}, + ) + dataset = client.get_dataset(dataset_display_name="my_dataset") + assert dataset.display_name == "my_dataset" def test_create_dataset(self): - client = self.tables_client({ - 'location_path.return_value': LOCATION_PATH, - 'create_dataset.return_value': mock.Mock(display_name='name'), - }, {}) - metadata = {'metadata': 'values'} - dataset = client.create_dataset('name', metadata=metadata) + client = self.tables_client( + { + "location_path.return_value": LOCATION_PATH, + "create_dataset.return_value": mock.Mock(display_name="name"), + }, + {}, + ) + metadata = {"metadata": "values"} + dataset = client.create_dataset("name", metadata=metadata) client.client.location_path.assert_called_with(PROJECT, REGION) client.client.create_dataset.assert_called_with( - LOCATION_PATH, - {'display_name': 'name', 'tables_dataset_metadata': metadata} + LOCATION_PATH, {"display_name": "name", "tables_dataset_metadata": metadata} ) - assert dataset.display_name == 'name' + assert dataset.display_name == "name" def test_delete_dataset(self): dataset = mock.Mock() - dataset.configure_mock(name='name') - client = self.tables_client({ - 'delete_dataset.return_value': None, - }, {}) + dataset.configure_mock(name="name") + client = self.tables_client({"delete_dataset.return_value": None}, {}) client.delete_dataset(dataset=dataset) - client.client.delete_dataset.assert_called_with('name') + client.client.delete_dataset.assert_called_with("name") def test_delete_dataset_not_found(self): - client = self.tables_client({ - 'list_datasets.return_value': [], - }, {}) - client.delete_dataset(dataset_display_name='not_found') + client = self.tables_client({"list_datasets.return_value": []}, {}) + client.delete_dataset(dataset_display_name="not_found") client.client.delete_dataset.assert_not_called() def test_delete_dataset_name(self): - client = self.tables_client({ - 'delete_dataset.return_value': None, - }, {}) - client.delete_dataset(dataset_name='name') - client.client.delete_dataset.assert_called_with('name') + client = self.tables_client({"delete_dataset.return_value": None}, {}) + client.delete_dataset(dataset_name="name") + client.client.delete_dataset.assert_called_with("name") def test_import_not_found(self): - client = self.tables_client({ - 'list_datasets.return_value': [], - }, {}) + client = self.tables_client({"list_datasets.return_value": []}, {}) error = None try: - client.import_data(dataset_display_name='name', gcs_input_uris='uri') + client.import_data(dataset_display_name="name", gcs_input_uris="uri") except exceptions.NotFound as e: error = e assert error is not None @@ -172,626 +175,699 @@ def test_import_not_found(self): client.client.import_data.assert_not_called() def test_import_gcs_uri(self): - client = self.tables_client({ - 'import_data.return_value': None, - }, {}) - client.import_data(dataset_name='name', gcs_input_uris='uri') - client.client.import_data.assert_called_with('name', { - 'gcs_source': {'input_uris': ['uri']} - }) + client = self.tables_client({"import_data.return_value": None}, {}) + client.import_data(dataset_name="name", gcs_input_uris="uri") + client.client.import_data.assert_called_with( + "name", {"gcs_source": {"input_uris": ["uri"]}} + ) def test_import_gcs_uris(self): - client = self.tables_client({ - 'import_data.return_value': None, - }, {}) - client.import_data(dataset_name='name', - gcs_input_uris=['uri', 'uri']) - client.client.import_data.assert_called_with('name', { - 
'gcs_source': {'input_uris': ['uri', 'uri']} - }) + client = self.tables_client({"import_data.return_value": None}, {}) + client.import_data(dataset_name="name", gcs_input_uris=["uri", "uri"]) + client.client.import_data.assert_called_with( + "name", {"gcs_source": {"input_uris": ["uri", "uri"]}} + ) def test_import_bq_uri(self): - client = self.tables_client({ - 'import_data.return_value': None, - }, {}) - client.import_data(dataset_name='name', - bigquery_input_uri='uri') - client.client.import_data.assert_called_with('name', { - 'bigquery_source': {'input_uri': 'uri'} - }) + client = self.tables_client({"import_data.return_value": None}, {}) + client.import_data(dataset_name="name", bigquery_input_uri="uri") + client.client.import_data.assert_called_with( + "name", {"bigquery_source": {"input_uri": "uri"}} + ) def test_list_table_specs(self): - client = self.tables_client({ - 'list_table_specs.return_value': None, - }, {}) - client.list_table_specs(dataset_name='name') - client.client.list_table_specs.assert_called_with('name') + client = self.tables_client({"list_table_specs.return_value": None}, {}) + client.list_table_specs(dataset_name="name") + client.client.list_table_specs.assert_called_with("name") def test_list_table_specs_not_found(self): - client = self.tables_client({ - 'list_table_specs.side_effect': exceptions.NotFound('not found'), - }, {}) + client = self.tables_client( + {"list_table_specs.side_effect": exceptions.NotFound("not found")}, {} + ) error = None try: - client.list_table_specs(dataset_name='name') + client.list_table_specs(dataset_name="name") except exceptions.NotFound as e: error = e assert error is not None - client.client.list_table_specs.assert_called_with('name') + client.client.list_table_specs.assert_called_with("name") def test_list_column_specs(self): table_spec_mock = mock.Mock() # name is reserved in use of __init__, needs to be passed here - table_spec_mock.configure_mock(name='table') - client = self.tables_client({ - 'list_table_specs.return_value': [table_spec_mock], - 'list_column_specs.return_value': [], - }, {}) - client.list_column_specs(dataset_name='name') - client.client.list_table_specs.assert_called_with('name') - client.client.list_column_specs.assert_called_with('table') + table_spec_mock.configure_mock(name="table") + client = self.tables_client( + { + "list_table_specs.return_value": [table_spec_mock], + "list_column_specs.return_value": [], + }, + {}, + ) + client.list_column_specs(dataset_name="name") + client.client.list_table_specs.assert_called_with("name") + client.client.list_column_specs.assert_called_with("table") def test_update_column_spec_not_found(self): table_spec_mock = mock.Mock() # name is reserved in use of __init__, needs to be passed here - table_spec_mock.configure_mock(name='table') + table_spec_mock.configure_mock(name="table") column_spec_mock = mock.Mock() - data_type_mock = mock.Mock(type_code='type_code') - column_spec_mock.configure_mock(name='column', display_name='column', - data_type=data_type_mock) - client = self.tables_client({ - 'list_table_specs.return_value': [table_spec_mock], - 'list_column_specs.return_value': [column_spec_mock], - }, {}) + data_type_mock = mock.Mock(type_code="type_code") + column_spec_mock.configure_mock( + name="column", display_name="column", data_type=data_type_mock + ) + client = self.tables_client( + { + "list_table_specs.return_value": [table_spec_mock], + "list_column_specs.return_value": [column_spec_mock], + }, + {}, + ) error = None try: - 
client.update_column_spec(dataset_name='name', - column_spec_name='column2') + client.update_column_spec(dataset_name="name", column_spec_name="column2") except exceptions.NotFound as e: error = e assert error is not None - client.client.list_table_specs.assert_called_with('name') - client.client.list_column_specs.assert_called_with('table') + client.client.list_table_specs.assert_called_with("name") + client.client.list_column_specs.assert_called_with("table") client.client.update_column_spec.assert_not_called() def test_update_column_spec_display_name_not_found(self): table_spec_mock = mock.Mock() # name is reserved in use of __init__, needs to be passed here - table_spec_mock.configure_mock(name='table') + table_spec_mock.configure_mock(name="table") column_spec_mock = mock.Mock() - data_type_mock = mock.Mock(type_code='type_code') - column_spec_mock.configure_mock(name='column', display_name='column', - data_type=data_type_mock) - client = self.tables_client({ - 'list_table_specs.return_value': [table_spec_mock], - 'list_column_specs.return_value': [column_spec_mock], - }, {}) + data_type_mock = mock.Mock(type_code="type_code") + column_spec_mock.configure_mock( + name="column", display_name="column", data_type=data_type_mock + ) + client = self.tables_client( + { + "list_table_specs.return_value": [table_spec_mock], + "list_column_specs.return_value": [column_spec_mock], + }, + {}, + ) error = None try: - client.update_column_spec(dataset_name='name', - column_spec_display_name='column2') + client.update_column_spec( + dataset_name="name", column_spec_display_name="column2" + ) except exceptions.NotFound as e: error = e assert error is not None - client.client.list_table_specs.assert_called_with('name') - client.client.list_column_specs.assert_called_with('table') + client.client.list_table_specs.assert_called_with("name") + client.client.list_column_specs.assert_called_with("table") client.client.update_column_spec.assert_not_called() def test_update_column_spec_name_no_args(self): table_spec_mock = mock.Mock() # name is reserved in use of __init__, needs to be passed here - table_spec_mock.configure_mock(name='table') + table_spec_mock.configure_mock(name="table") column_spec_mock = mock.Mock() - data_type_mock = mock.Mock(type_code='type_code') - column_spec_mock.configure_mock(name='column/2', display_name='column', - data_type=data_type_mock) - client = self.tables_client({ - 'list_table_specs.return_value': [table_spec_mock], - 'list_column_specs.return_value': [column_spec_mock], - }, {}) - client.update_column_spec(dataset_name='name', - column_spec_name='column/2') - client.client.list_table_specs.assert_called_with('name') - client.client.list_column_specs.assert_called_with('table') - client.client.update_column_spec.assert_called_with({ - 'name': 'column/2', - 'data_type': { - 'type_code': 'type_code', - } - }) + data_type_mock = mock.Mock(type_code="type_code") + column_spec_mock.configure_mock( + name="column/2", display_name="column", data_type=data_type_mock + ) + client = self.tables_client( + { + "list_table_specs.return_value": [table_spec_mock], + "list_column_specs.return_value": [column_spec_mock], + }, + {}, + ) + client.update_column_spec(dataset_name="name", column_spec_name="column/2") + client.client.list_table_specs.assert_called_with("name") + client.client.list_column_specs.assert_called_with("table") + client.client.update_column_spec.assert_called_with( + {"name": "column/2", "data_type": {"type_code": "type_code"}} + ) def 
test_update_column_spec_no_args(self): table_spec_mock = mock.Mock() # name is reserved in use of __init__, needs to be passed here - table_spec_mock.configure_mock(name='table') + table_spec_mock.configure_mock(name="table") column_spec_mock = mock.Mock() - data_type_mock = mock.Mock(type_code='type_code') - column_spec_mock.configure_mock(name='column', display_name='column', - data_type=data_type_mock) - client = self.tables_client({ - 'list_table_specs.return_value': [table_spec_mock], - 'list_column_specs.return_value': [column_spec_mock], - }, {}) - client.update_column_spec(dataset_name='name', - column_spec_display_name='column') - client.client.list_table_specs.assert_called_with('name') - client.client.list_column_specs.assert_called_with('table') - client.client.update_column_spec.assert_called_with({ - 'name': 'column', - 'data_type': { - 'type_code': 'type_code', - } - }) + data_type_mock = mock.Mock(type_code="type_code") + column_spec_mock.configure_mock( + name="column", display_name="column", data_type=data_type_mock + ) + client = self.tables_client( + { + "list_table_specs.return_value": [table_spec_mock], + "list_column_specs.return_value": [column_spec_mock], + }, + {}, + ) + client.update_column_spec( + dataset_name="name", column_spec_display_name="column" + ) + client.client.list_table_specs.assert_called_with("name") + client.client.list_column_specs.assert_called_with("table") + client.client.update_column_spec.assert_called_with( + {"name": "column", "data_type": {"type_code": "type_code"}} + ) def test_update_column_spec_nullable(self): table_spec_mock = mock.Mock() # name is reserved in use of __init__, needs to be passed here - table_spec_mock.configure_mock(name='table') + table_spec_mock.configure_mock(name="table") column_spec_mock = mock.Mock() - data_type_mock = mock.Mock(type_code='type_code') - column_spec_mock.configure_mock(name='column', display_name='column', - data_type=data_type_mock) - client = self.tables_client({ - 'list_table_specs.return_value': [table_spec_mock], - 'list_column_specs.return_value': [column_spec_mock], - }, {}) - client.update_column_spec(dataset_name='name', - column_spec_display_name='column', nullable=True) - client.client.list_table_specs.assert_called_with('name') - client.client.list_column_specs.assert_called_with('table') - client.client.update_column_spec.assert_called_with({ - 'name': 'column', - 'data_type': { - 'type_code': 'type_code', - 'nullable': True, + data_type_mock = mock.Mock(type_code="type_code") + column_spec_mock.configure_mock( + name="column", display_name="column", data_type=data_type_mock + ) + client = self.tables_client( + { + "list_table_specs.return_value": [table_spec_mock], + "list_column_specs.return_value": [column_spec_mock], + }, + {}, + ) + client.update_column_spec( + dataset_name="name", column_spec_display_name="column", nullable=True + ) + client.client.list_table_specs.assert_called_with("name") + client.client.list_column_specs.assert_called_with("table") + client.client.update_column_spec.assert_called_with( + { + "name": "column", + "data_type": {"type_code": "type_code", "nullable": True}, } - }) + ) def test_update_column_spec_type_code(self): table_spec_mock = mock.Mock() # name is reserved in use of __init__, needs to be passed here - table_spec_mock.configure_mock(name='table') + table_spec_mock.configure_mock(name="table") column_spec_mock = mock.Mock() - data_type_mock = mock.Mock(type_code='type_code') - column_spec_mock.configure_mock(name='column', 
display_name='column', - data_type=data_type_mock) - client = self.tables_client({ - 'list_table_specs.return_value': [table_spec_mock], - 'list_column_specs.return_value': [column_spec_mock], - }, {}) - client.update_column_spec(dataset_name='name', - column_spec_display_name='column', type_code='type_code2') - client.client.list_table_specs.assert_called_with('name') - client.client.list_column_specs.assert_called_with('table') - client.client.update_column_spec.assert_called_with({ - 'name': 'column', - 'data_type': { - 'type_code': 'type_code2', - } - }) + data_type_mock = mock.Mock(type_code="type_code") + column_spec_mock.configure_mock( + name="column", display_name="column", data_type=data_type_mock + ) + client = self.tables_client( + { + "list_table_specs.return_value": [table_spec_mock], + "list_column_specs.return_value": [column_spec_mock], + }, + {}, + ) + client.update_column_spec( + dataset_name="name", + column_spec_display_name="column", + type_code="type_code2", + ) + client.client.list_table_specs.assert_called_with("name") + client.client.list_column_specs.assert_called_with("table") + client.client.update_column_spec.assert_called_with( + {"name": "column", "data_type": {"type_code": "type_code2"}} + ) def test_update_column_spec_type_code_nullable(self): table_spec_mock = mock.Mock() # name is reserved in use of __init__, needs to be passed here - table_spec_mock.configure_mock(name='table') + table_spec_mock.configure_mock(name="table") column_spec_mock = mock.Mock() - data_type_mock = mock.Mock(type_code='type_code') - column_spec_mock.configure_mock(name='column', display_name='column', - data_type=data_type_mock) - client = self.tables_client({ - 'list_table_specs.return_value': [table_spec_mock], - 'list_column_specs.return_value': [column_spec_mock], - }, {}) - client.update_column_spec(dataset_name='name', nullable=True, - column_spec_display_name='column', type_code='type_code2') - client.client.list_table_specs.assert_called_with('name') - client.client.list_column_specs.assert_called_with('table') - client.client.update_column_spec.assert_called_with({ - 'name': 'column', - 'data_type': { - 'type_code': 'type_code2', - 'nullable': True, + data_type_mock = mock.Mock(type_code="type_code") + column_spec_mock.configure_mock( + name="column", display_name="column", data_type=data_type_mock + ) + client = self.tables_client( + { + "list_table_specs.return_value": [table_spec_mock], + "list_column_specs.return_value": [column_spec_mock], + }, + {}, + ) + client.update_column_spec( + dataset_name="name", + nullable=True, + column_spec_display_name="column", + type_code="type_code2", + ) + client.client.list_table_specs.assert_called_with("name") + client.client.list_column_specs.assert_called_with("table") + client.client.update_column_spec.assert_called_with( + { + "name": "column", + "data_type": {"type_code": "type_code2", "nullable": True}, } - }) + ) def test_update_column_spec_type_code_nullable_false(self): table_spec_mock = mock.Mock() # name is reserved in use of __init__, needs to be passed here - table_spec_mock.configure_mock(name='table') + table_spec_mock.configure_mock(name="table") column_spec_mock = mock.Mock() - data_type_mock = mock.Mock(type_code='type_code') - column_spec_mock.configure_mock(name='column', display_name='column', - data_type=data_type_mock) - client = self.tables_client({ - 'list_table_specs.return_value': [table_spec_mock], - 'list_column_specs.return_value': [column_spec_mock], - }, {}) - 
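# The update_column_spec tests around this point all reduce to one request
# shape: the column's resource name plus a data_type carrying the effective
# type_code and, only when the caller supplied it, nullable. A sketch of the
# merge those assertions imply (hedged; the real logic lives in
# tables_client.py, and build_update_request is an illustrative name):
def build_update_request(column_spec, type_code=None, nullable=None):
    effective = type_code if type_code is not None else column_spec.data_type.type_code
    data_type = {"type_code": effective}
    if nullable is not None:
        data_type["nullable"] = nullable
    return {"name": column_spec.name, "data_type": data_type}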
client.update_column_spec(dataset_name='name', nullable=False, - column_spec_display_name='column', type_code='type_code2') - client.client.list_table_specs.assert_called_with('name') - client.client.list_column_specs.assert_called_with('table') - client.client.update_column_spec.assert_called_with({ - 'name': 'column', - 'data_type': { - 'type_code': 'type_code2', - 'nullable': False, + data_type_mock = mock.Mock(type_code="type_code") + column_spec_mock.configure_mock( + name="column", display_name="column", data_type=data_type_mock + ) + client = self.tables_client( + { + "list_table_specs.return_value": [table_spec_mock], + "list_column_specs.return_value": [column_spec_mock], + }, + {}, + ) + client.update_column_spec( + dataset_name="name", + nullable=False, + column_spec_display_name="column", + type_code="type_code2", + ) + client.client.list_table_specs.assert_called_with("name") + client.client.list_column_specs.assert_called_with("table") + client.client.update_column_spec.assert_called_with( + { + "name": "column", + "data_type": {"type_code": "type_code2", "nullable": False}, } - }) + ) def test_set_target_column_table_not_found(self): - client = self.tables_client({ - 'list_table_specs.side_effect': exceptions.NotFound('err'), - }, {}) + client = self.tables_client( + {"list_table_specs.side_effect": exceptions.NotFound("err")}, {} + ) error = None try: - client.set_target_column(dataset_name='name', - column_spec_display_name='column2') + client.set_target_column( + dataset_name="name", column_spec_display_name="column2" + ) except exceptions.NotFound as e: error = e assert error is not None - client.client.list_table_specs.assert_called_with('name') + client.client.list_table_specs.assert_called_with("name") client.client.list_column_specs.assert_not_called() client.client.update_dataset.assert_not_called() def test_set_target_column_not_found(self): table_spec_mock = mock.Mock() # name is reserved in use of __init__, needs to be passed here - table_spec_mock.configure_mock(name='table') + table_spec_mock.configure_mock(name="table") column_spec_mock = mock.Mock() - column_spec_mock.configure_mock(name='column/1', display_name='column') - client = self.tables_client({ - 'list_table_specs.return_value': [table_spec_mock], - 'list_column_specs.return_value': [column_spec_mock], - }, {}) + column_spec_mock.configure_mock(name="column/1", display_name="column") + client = self.tables_client( + { + "list_table_specs.return_value": [table_spec_mock], + "list_column_specs.return_value": [column_spec_mock], + }, + {}, + ) error = None try: - client.set_target_column(dataset_name='name', - column_spec_display_name='column2') + client.set_target_column( + dataset_name="name", column_spec_display_name="column2" + ) except exceptions.NotFound as e: error = e assert error is not None - client.client.list_table_specs.assert_called_with('name') - client.client.list_column_specs.assert_called_with('table') + client.client.list_table_specs.assert_called_with("name") + client.client.list_column_specs.assert_called_with("table") client.client.update_dataset.assert_not_called() def test_set_target_column(self): table_spec_mock = mock.Mock() # name is reserved in use of __init__, needs to be passed here - table_spec_mock.configure_mock(name='table') + table_spec_mock.configure_mock(name="table") column_spec_mock = mock.Mock() - column_spec_mock.configure_mock(name='column/1', display_name='column') + column_spec_mock.configure_mock(name="column/1", display_name="column") dataset_mock = 
mock.Mock() tables_dataset_metadata_mock = mock.Mock() - tables_dataset_metadata_mock.configure_mock(target_column_spec_id='2', - weight_column_spec_id='2', - ml_use_column_spec_id='3') - dataset_mock.configure_mock(name='dataset', - tables_dataset_metadata=tables_dataset_metadata_mock) - client = self.tables_client({ - 'get_dataset.return_value': dataset_mock, - 'list_table_specs.return_value': [table_spec_mock], - 'list_column_specs.return_value': [column_spec_mock], - }, {}) - client.set_target_column(dataset_name='name', - column_spec_display_name='column') - client.client.list_table_specs.assert_called_with('name') - client.client.list_column_specs.assert_called_with('table') - client.client.update_dataset.assert_called_with({ - 'name': 'dataset', - 'tables_dataset_metadata': { - 'target_column_spec_id': '1', - 'weight_column_spec_id': '2', - 'ml_use_column_spec_id': '3', + tables_dataset_metadata_mock.configure_mock( + target_column_spec_id="2", + weight_column_spec_id="2", + ml_use_column_spec_id="3", + ) + dataset_mock.configure_mock( + name="dataset", tables_dataset_metadata=tables_dataset_metadata_mock + ) + client = self.tables_client( + { + "get_dataset.return_value": dataset_mock, + "list_table_specs.return_value": [table_spec_mock], + "list_column_specs.return_value": [column_spec_mock], + }, + {}, + ) + client.set_target_column(dataset_name="name", column_spec_display_name="column") + client.client.list_table_specs.assert_called_with("name") + client.client.list_column_specs.assert_called_with("table") + client.client.update_dataset.assert_called_with( + { + "name": "dataset", + "tables_dataset_metadata": { + "target_column_spec_id": "1", + "weight_column_spec_id": "2", + "ml_use_column_spec_id": "3", + }, } - }) + ) def test_set_weight_column_table_not_found(self): - client = self.tables_client({ - 'list_table_specs.side_effect': exceptions.NotFound('err'), - }, {}) + client = self.tables_client( + {"list_table_specs.side_effect": exceptions.NotFound("err")}, {} + ) try: - client.set_weight_column(dataset_name='name', - column_spec_display_name='column2') + client.set_weight_column( + dataset_name="name", column_spec_display_name="column2" + ) except exceptions.NotFound: pass - client.client.list_table_specs.assert_called_with('name') + client.client.list_table_specs.assert_called_with("name") client.client.list_column_specs.assert_not_called() client.client.update_dataset.assert_not_called() def test_set_weight_column_not_found(self): table_spec_mock = mock.Mock() # name is reserved in use of __init__, needs to be passed here - table_spec_mock.configure_mock(name='table') + table_spec_mock.configure_mock(name="table") column_spec_mock = mock.Mock() - column_spec_mock.configure_mock(name='column/1', display_name='column') - client = self.tables_client({ - 'list_table_specs.return_value': [table_spec_mock], - 'list_column_specs.return_value': [column_spec_mock], - }, {}) + column_spec_mock.configure_mock(name="column/1", display_name="column") + client = self.tables_client( + { + "list_table_specs.return_value": [table_spec_mock], + "list_column_specs.return_value": [column_spec_mock], + }, + {}, + ) error = None try: - client.set_weight_column(dataset_name='name', - column_spec_display_name='column2') + client.set_weight_column( + dataset_name="name", column_spec_display_name="column2" + ) except exceptions.NotFound as e: error = e assert error is not None - client.client.list_table_specs.assert_called_with('name') - 
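# One detail the set_target_column/set_weight_column tests rely on: the
# helper derives each *_column_spec_id from the trailing segment of the
# column spec's resource name (the mocks use shortened names such as
# "column/1"). A sketch of that derivation:
column_spec_id = "column/1".rsplit("/", 1)[-1]
assert column_spec_id == "1"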
client.client.list_column_specs.assert_called_with('table') + client.client.list_table_specs.assert_called_with("name") + client.client.list_column_specs.assert_called_with("table") client.client.update_dataset.assert_not_called() def test_set_weight_column(self): table_spec_mock = mock.Mock() # name is reserved in use of __init__, needs to be passed here - table_spec_mock.configure_mock(name='table') + table_spec_mock.configure_mock(name="table") column_spec_mock = mock.Mock() - column_spec_mock.configure_mock(name='column/2', display_name='column') + column_spec_mock.configure_mock(name="column/2", display_name="column") dataset_mock = mock.Mock() tables_dataset_metadata_mock = mock.Mock() - tables_dataset_metadata_mock.configure_mock(target_column_spec_id='1', - weight_column_spec_id='1', - ml_use_column_spec_id='3') - dataset_mock.configure_mock(name='dataset', - tables_dataset_metadata=tables_dataset_metadata_mock) - client = self.tables_client({ - 'get_dataset.return_value': dataset_mock, - 'list_table_specs.return_value': [table_spec_mock], - 'list_column_specs.return_value': [column_spec_mock], - }, {}) - client.set_weight_column(dataset_name='name', - column_spec_display_name='column') - client.client.list_table_specs.assert_called_with('name') - client.client.list_column_specs.assert_called_with('table') - client.client.update_dataset.assert_called_with({ - 'name': 'dataset', - 'tables_dataset_metadata': { - 'target_column_spec_id': '1', - 'weight_column_spec_id': '2', - 'ml_use_column_spec_id': '3', + tables_dataset_metadata_mock.configure_mock( + target_column_spec_id="1", + weight_column_spec_id="1", + ml_use_column_spec_id="3", + ) + dataset_mock.configure_mock( + name="dataset", tables_dataset_metadata=tables_dataset_metadata_mock + ) + client = self.tables_client( + { + "get_dataset.return_value": dataset_mock, + "list_table_specs.return_value": [table_spec_mock], + "list_column_specs.return_value": [column_spec_mock], + }, + {}, + ) + client.set_weight_column(dataset_name="name", column_spec_display_name="column") + client.client.list_table_specs.assert_called_with("name") + client.client.list_column_specs.assert_called_with("table") + client.client.update_dataset.assert_called_with( + { + "name": "dataset", + "tables_dataset_metadata": { + "target_column_spec_id": "1", + "weight_column_spec_id": "2", + "ml_use_column_spec_id": "3", + }, } - }) + ) def test_clear_weight_column(self): dataset_mock = mock.Mock() tables_dataset_metadata_mock = mock.Mock() - tables_dataset_metadata_mock.configure_mock(target_column_spec_id='1', - weight_column_spec_id='2', - ml_use_column_spec_id='3') - dataset_mock.configure_mock(name='dataset', - tables_dataset_metadata=tables_dataset_metadata_mock) - client = self.tables_client({ - 'get_dataset.return_value': dataset_mock, - }, {}) - client.clear_weight_column(dataset_name='name') - client.client.update_dataset.assert_called_with({ - 'name': 'dataset', - 'tables_dataset_metadata': { - 'target_column_spec_id': '1', - 'weight_column_spec_id': None, - 'ml_use_column_spec_id': '3', + tables_dataset_metadata_mock.configure_mock( + target_column_spec_id="1", + weight_column_spec_id="2", + ml_use_column_spec_id="3", + ) + dataset_mock.configure_mock( + name="dataset", tables_dataset_metadata=tables_dataset_metadata_mock + ) + client = self.tables_client({"get_dataset.return_value": dataset_mock}, {}) + client.clear_weight_column(dataset_name="name") + client.client.update_dataset.assert_called_with( + { + "name": "dataset", + 
"tables_dataset_metadata": { + "target_column_spec_id": "1", + "weight_column_spec_id": None, + "ml_use_column_spec_id": "3", + }, } - }) + ) def test_set_test_train_column_table_not_found(self): - client = self.tables_client({ - 'list_table_specs.side_effect': exceptions.NotFound('err'), - }, {}) + client = self.tables_client( + {"list_table_specs.side_effect": exceptions.NotFound("err")}, {} + ) error = None try: - client.set_test_train_column(dataset_name='name', - column_spec_display_name='column2') + client.set_test_train_column( + dataset_name="name", column_spec_display_name="column2" + ) except exceptions.NotFound as e: error = e assert error is not None - client.client.list_table_specs.assert_called_with('name') + client.client.list_table_specs.assert_called_with("name") client.client.list_column_specs.assert_not_called() client.client.update_dataset.assert_not_called() def test_set_test_train_column_not_found(self): table_spec_mock = mock.Mock() # name is reserved in use of __init__, needs to be passed here - table_spec_mock.configure_mock(name='table') + table_spec_mock.configure_mock(name="table") column_spec_mock = mock.Mock() - column_spec_mock.configure_mock(name='column/1', display_name='column') - client = self.tables_client({ - 'list_table_specs.return_value': [table_spec_mock], - 'list_column_specs.return_value': [column_spec_mock], - }, {}) + column_spec_mock.configure_mock(name="column/1", display_name="column") + client = self.tables_client( + { + "list_table_specs.return_value": [table_spec_mock], + "list_column_specs.return_value": [column_spec_mock], + }, + {}, + ) error = None try: - client.set_test_train_column(dataset_name='name', - column_spec_display_name='column2') + client.set_test_train_column( + dataset_name="name", column_spec_display_name="column2" + ) except exceptions.NotFound as e: error = e assert error is not None - client.client.list_table_specs.assert_called_with('name') - client.client.list_column_specs.assert_called_with('table') + client.client.list_table_specs.assert_called_with("name") + client.client.list_column_specs.assert_called_with("table") client.client.update_dataset.assert_not_called() def test_set_test_train_column(self): table_spec_mock = mock.Mock() # name is reserved in use of __init__, needs to be passed here - table_spec_mock.configure_mock(name='table') + table_spec_mock.configure_mock(name="table") column_spec_mock = mock.Mock() - column_spec_mock.configure_mock(name='column/3', display_name='column') + column_spec_mock.configure_mock(name="column/3", display_name="column") dataset_mock = mock.Mock() tables_dataset_metadata_mock = mock.Mock() - tables_dataset_metadata_mock.configure_mock(target_column_spec_id='1', - weight_column_spec_id='2', - ml_use_column_spec_id='2') - dataset_mock.configure_mock(name='dataset', - tables_dataset_metadata=tables_dataset_metadata_mock) - client = self.tables_client({ - 'get_dataset.return_value': dataset_mock, - 'list_table_specs.return_value': [table_spec_mock], - 'list_column_specs.return_value': [column_spec_mock], - }, {}) - client.set_test_train_column(dataset_name='name', - column_spec_display_name='column') - client.client.list_table_specs.assert_called_with('name') - client.client.list_column_specs.assert_called_with('table') - client.client.update_dataset.assert_called_with({ - 'name': 'dataset', - 'tables_dataset_metadata': { - 'target_column_spec_id': '1', - 'weight_column_spec_id': '2', - 'ml_use_column_spec_id': '3', + tables_dataset_metadata_mock.configure_mock( + 
target_column_spec_id="1", + weight_column_spec_id="2", + ml_use_column_spec_id="2", + ) + dataset_mock.configure_mock( + name="dataset", tables_dataset_metadata=tables_dataset_metadata_mock + ) + client = self.tables_client( + { + "get_dataset.return_value": dataset_mock, + "list_table_specs.return_value": [table_spec_mock], + "list_column_specs.return_value": [column_spec_mock], + }, + {}, + ) + client.set_test_train_column( + dataset_name="name", column_spec_display_name="column" + ) + client.client.list_table_specs.assert_called_with("name") + client.client.list_column_specs.assert_called_with("table") + client.client.update_dataset.assert_called_with( + { + "name": "dataset", + "tables_dataset_metadata": { + "target_column_spec_id": "1", + "weight_column_spec_id": "2", + "ml_use_column_spec_id": "3", + }, } - }) + ) def test_clear_test_train_column(self): dataset_mock = mock.Mock() tables_dataset_metadata_mock = mock.Mock() - tables_dataset_metadata_mock.configure_mock(target_column_spec_id='1', - weight_column_spec_id='2', - ml_use_column_spec_id='2') - dataset_mock.configure_mock(name='dataset', - tables_dataset_metadata=tables_dataset_metadata_mock) - client = self.tables_client({ - 'get_dataset.return_value': dataset_mock, - }, {}) - client.clear_test_train_column(dataset_name='name') - client.client.update_dataset.assert_called_with({ - 'name': 'dataset', - 'tables_dataset_metadata': { - 'target_column_spec_id': '1', - 'weight_column_spec_id': '2', - 'ml_use_column_spec_id': None, + tables_dataset_metadata_mock.configure_mock( + target_column_spec_id="1", + weight_column_spec_id="2", + ml_use_column_spec_id="2", + ) + dataset_mock.configure_mock( + name="dataset", tables_dataset_metadata=tables_dataset_metadata_mock + ) + client = self.tables_client({"get_dataset.return_value": dataset_mock}, {}) + client.clear_test_train_column(dataset_name="name") + client.client.update_dataset.assert_called_with( + { + "name": "dataset", + "tables_dataset_metadata": { + "target_column_spec_id": "1", + "weight_column_spec_id": "2", + "ml_use_column_spec_id": None, + }, } - }) + ) def test_set_time_column(self): table_spec_mock = mock.Mock() # name is reserved in use of __init__, needs to be passed here - table_spec_mock.configure_mock(name='table') + table_spec_mock.configure_mock(name="table") column_spec_mock = mock.Mock() - column_spec_mock.configure_mock(name='column/3', display_name='column') + column_spec_mock.configure_mock(name="column/3", display_name="column") dataset_mock = mock.Mock() - dataset_mock.configure_mock(name='dataset') - client = self.tables_client({ - 'get_dataset.return_value': dataset_mock, - 'list_table_specs.return_value': [table_spec_mock], - 'list_column_specs.return_value': [column_spec_mock], - }, {}) - client.set_time_column(dataset_name='name', - column_spec_display_name='column') - client.client.list_table_specs.assert_called_with('name') - client.client.list_column_specs.assert_called_with('table') - client.client.update_table_spec.assert_called_with({ - 'name': 'table', - 'time_column_spec_id': '3', - }) + dataset_mock.configure_mock(name="dataset") + client = self.tables_client( + { + "get_dataset.return_value": dataset_mock, + "list_table_specs.return_value": [table_spec_mock], + "list_column_specs.return_value": [column_spec_mock], + }, + {}, + ) + client.set_time_column(dataset_name="name", column_spec_display_name="column") + client.client.list_table_specs.assert_called_with("name") + client.client.list_column_specs.assert_called_with("table") + 
client.client.update_table_spec.assert_called_with( + {"name": "table", "time_column_spec_id": "3"} + ) def test_clear_time_column(self): table_spec_mock = mock.Mock() # name is reserved in use of __init__, needs to be passed here - table_spec_mock.configure_mock(name='table') + table_spec_mock.configure_mock(name="table") dataset_mock = mock.Mock() - dataset_mock.configure_mock(name='dataset') - client = self.tables_client({ - 'get_dataset.return_value': dataset_mock, - 'list_table_specs.return_value': [table_spec_mock], - }, {}) - client.clear_time_column(dataset_name='name') - client.client.update_table_spec.assert_called_with({ - 'name': 'table', - 'time_column_spec_id': None, - }) + dataset_mock.configure_mock(name="dataset") + client = self.tables_client( + { + "get_dataset.return_value": dataset_mock, + "list_table_specs.return_value": [table_spec_mock], + }, + {}, + ) + client.clear_time_column(dataset_name="name") + client.client.update_table_spec.assert_called_with( + {"name": "table", "time_column_spec_id": None} + ) def test_list_models_empty(self): - client = self.tables_client({ - 'list_models.return_value': [], - 'location_path.return_value': LOCATION_PATH, - }, {}) + client = self.tables_client( + { + "list_models.return_value": [], + "location_path.return_value": LOCATION_PATH, + }, + {}, + ) ds = client.list_models() client.client.location_path.assert_called_with(PROJECT, REGION) client.client.list_models.assert_called_with(LOCATION_PATH) assert ds == [] def test_list_models_not_empty(self): - models = ['some_model'] - client = self.tables_client({ - 'list_models.return_value': models, - 'location_path.return_value': LOCATION_PATH, - }, {}) + models = ["some_model"] + client = self.tables_client( + { + "list_models.return_value": models, + "location_path.return_value": LOCATION_PATH, + }, + {}, + ) ds = client.list_models() client.client.location_path.assert_called_with(PROJECT, REGION) client.client.list_models.assert_called_with(LOCATION_PATH) assert len(ds) == 1 - assert ds[0] == 'some_model' + assert ds[0] == "some_model" def test_get_model_name(self): - model_actual = 'model' - client = self.tables_client({ - 'get_model.return_value': model_actual - }, {}) - model = client.get_model(model_name='my_model') - client.client.get_model.assert_called_with('my_model') + model_actual = "model" + client = self.tables_client({"get_model.return_value": model_actual}, {}) + model = client.get_model(model_name="my_model") + client.client.get_model.assert_called_with("my_model") assert model == model_actual def test_get_no_model(self): - client = self.tables_client({ - 'get_model.side_effect': exceptions.NotFound('err') - }, {}) + client = self.tables_client( + {"get_model.side_effect": exceptions.NotFound("err")}, {} + ) error = None try: - client.get_model(model_name='my_model') + client.get_model(model_name="my_model") except exceptions.NotFound as e: error = e assert error is not None - client.client.get_model.assert_called_with('my_model') + client.client.get_model.assert_called_with("my_model") def test_get_model_from_empty_list(self): - client = self.tables_client({'list_models.return_value': []}, {}) + client = self.tables_client({"list_models.return_value": []}, {}) error = None try: - client.get_model(model_display_name='my_model') + client.get_model(model_display_name="my_model") except exceptions.NotFound as e: error = e assert error is not None def test_get_model_from_list_not_found(self): - client = self.tables_client({ - 'list_models.return_value': 
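# Note the asymmetry the surrounding tests capture: the ML-use (test/train)
# column is stored on the dataset's tables_dataset_metadata and changed via
# update_dataset, while the time column is stored on the table spec itself
# and changed via update_table_spec. A hedged usage sketch (display names
# are placeholders):
client.set_test_train_column(dataset_name="name", column_spec_display_name="split")
client.set_time_column(dataset_name="name", column_spec_display_name="timestamp")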
[mock.Mock(display_name='not_it')] - }, {}) + client = self.tables_client( + {"list_models.return_value": [mock.Mock(display_name="not_it")]}, {} + ) error = None try: - client.get_model(model_display_name='my_model') + client.get_model(model_display_name="my_model") except exceptions.NotFound as e: error = e assert error is not None def test_get_model_from_list(self): - client = self.tables_client({ - 'list_models.return_value': [ - mock.Mock(display_name='not_it'), - mock.Mock(display_name='my_model'), - ] - }, {}) - model = client.get_model(model_display_name='my_model') - assert model.display_name == 'my_model' + client = self.tables_client( + { + "list_models.return_value": [ + mock.Mock(display_name="not_it"), + mock.Mock(display_name="my_model"), + ] + }, + {}, + ) + model = client.get_model(model_display_name="my_model") + assert model.display_name == "my_model" def test_delete_model(self): model = mock.Mock() - model.configure_mock(name='name') - client = self.tables_client({ - 'delete_model.return_value': None, - }, {}) + model.configure_mock(name="name") + client = self.tables_client({"delete_model.return_value": None}, {}) client.delete_model(model=model) - client.client.delete_model.assert_called_with('name') + client.client.delete_model.assert_called_with("name") def test_delete_model_not_found(self): - client = self.tables_client({ - 'list_models.return_value': [], - }, {}) - client.delete_model(model_display_name='not_found') + client = self.tables_client({"list_models.return_value": []}, {}) + client.delete_model(model_display_name="not_found") client.client.delete_model.assert_not_called() def test_delete_model_name(self): - client = self.tables_client({ - 'delete_model.return_value': None, - }, {}) - client.delete_model(model_name='name') - client.client.delete_model.assert_called_with('name') + client = self.tables_client({"delete_model.return_value": None}, {}) + client.delete_model(model_name="name") + client.client.delete_model.assert_called_with("name") def test_deploy_model_no_args(self): client = self.tables_client({}, {}) @@ -805,16 +881,14 @@ def test_deploy_model_no_args(self): def test_deploy_model(self): client = self.tables_client({}, {}) - client.deploy_model(model_name='name') - client.client.deploy_model.assert_called_with('name') + client.deploy_model(model_name="name") + client.client.deploy_model.assert_called_with("name") def test_deploy_model_not_found(self): - client = self.tables_client({ - 'list_models.return_value': [], - }, {}) + client = self.tables_client({"list_models.return_value": []}, {}) error = None try: - client.deploy_model(model_display_name='name') + client.deploy_model(model_display_name="name") except exceptions.NotFound as e: error = e assert error is not None @@ -822,16 +896,14 @@ def test_deploy_model_not_found(self): def test_undeploy_model(self): client = self.tables_client({}, {}) - client.undeploy_model(model_name='name') - client.client.undeploy_model.assert_called_with('name') + client.undeploy_model(model_name="name") + client.client.undeploy_model.assert_called_with("name") def test_undeploy_model_not_found(self): - client = self.tables_client({ - 'list_models.return_value': [], - }, {}) + client = self.tables_client({"list_models.return_value": []}, {}) error = None try: - client.undeploy_model(model_display_name='name') + client.undeploy_model(model_display_name="name") except exceptions.NotFound as e: error = e assert error is not None @@ -840,86 +912,110 @@ def test_undeploy_model_not_found(self): def 
test_create_model(self): table_spec_mock = mock.Mock() # name is reserved in use of __init__, needs to be passed here - table_spec_mock.configure_mock(name='table') + table_spec_mock.configure_mock(name="table") column_spec_mock = mock.Mock() - column_spec_mock.configure_mock(name='column/2', display_name='column') - client = self.tables_client({ - 'list_table_specs.return_value': [table_spec_mock], - 'list_column_specs.return_value': [column_spec_mock], - 'location_path.return_value': LOCATION_PATH, - }, {}) - client.create_model('my_model', dataset_name='my_dataset', - train_budget_milli_node_hours=1000) - client.client.create_model.assert_called_with(LOCATION_PATH, { - 'display_name': 'my_model', - 'dataset_id': 'my_dataset', - 'tables_model_metadata': { - 'train_budget_milli_node_hours': 1000, + column_spec_mock.configure_mock(name="column/2", display_name="column") + client = self.tables_client( + { + "list_table_specs.return_value": [table_spec_mock], + "list_column_specs.return_value": [column_spec_mock], + "location_path.return_value": LOCATION_PATH, }, - }) + {}, + ) + client.create_model( + "my_model", dataset_name="my_dataset", train_budget_milli_node_hours=1000 + ) + client.client.create_model.assert_called_with( + LOCATION_PATH, + { + "display_name": "my_model", + "dataset_id": "my_dataset", + "tables_model_metadata": {"train_budget_milli_node_hours": 1000}, + }, + ) def test_create_model_include_columns(self): table_spec_mock = mock.Mock() # name is reserved in use of __init__, needs to be passed here - table_spec_mock.configure_mock(name='table') + table_spec_mock.configure_mock(name="table") column_spec_mock1 = mock.Mock() - column_spec_mock1.configure_mock(name='column/1', - display_name='column1') + column_spec_mock1.configure_mock(name="column/1", display_name="column1") column_spec_mock2 = mock.Mock() - column_spec_mock2.configure_mock(name='column/2', - display_name='column2') - client = self.tables_client({ - 'list_table_specs.return_value': [table_spec_mock], - 'list_column_specs.return_value': [column_spec_mock1, - column_spec_mock2], - 'location_path.return_value': LOCATION_PATH, - }, {}) - client.create_model('my_model', dataset_name='my_dataset', - include_column_spec_names=['column1'], - train_budget_milli_node_hours=1000) - client.client.create_model.assert_called_with(LOCATION_PATH, { - 'display_name': 'my_model', - 'dataset_id': 'my_dataset', - 'tables_model_metadata': { - 'train_budget_milli_node_hours': 1000, - 'input_feature_column_specs': [column_spec_mock1] + column_spec_mock2.configure_mock(name="column/2", display_name="column2") + client = self.tables_client( + { + "list_table_specs.return_value": [table_spec_mock], + "list_column_specs.return_value": [ + column_spec_mock1, + column_spec_mock2, + ], + "location_path.return_value": LOCATION_PATH, + }, + {}, + ) + client.create_model( + "my_model", + dataset_name="my_dataset", + include_column_spec_names=["column1"], + train_budget_milli_node_hours=1000, + ) + client.client.create_model.assert_called_with( + LOCATION_PATH, + { + "display_name": "my_model", + "dataset_id": "my_dataset", + "tables_model_metadata": { + "train_budget_milli_node_hours": 1000, + "input_feature_column_specs": [column_spec_mock1], + }, }, - }) + ) def test_create_model_exclude_columns(self): table_spec_mock = mock.Mock() # name is reserved in use of __init__, needs to be passed here - table_spec_mock.configure_mock(name='table') + table_spec_mock.configure_mock(name="table") column_spec_mock1 = mock.Mock() - 
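# include_column_spec_names and exclude_column_spec_names (exercised in this
# test and the next) both funnel into tables_model_metadata's
# input_feature_column_specs. A sketch of the filtering the assertions
# imply, with illustrative variable names (hedged):
specs = {s.display_name: s for s in column_specs}
if include_names is not None:
    feature_specs = [specs[name] for name in include_names]
elif exclude_names is not None:
    feature_specs = [s for name, s in specs.items() if name not in exclude_names]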
column_spec_mock1.configure_mock(name='column/1', - display_name='column1') + column_spec_mock1.configure_mock(name="column/1", display_name="column1") column_spec_mock2 = mock.Mock() - column_spec_mock2.configure_mock(name='column/2', - display_name='column2') - client = self.tables_client({ - 'list_table_specs.return_value': [table_spec_mock], - 'list_column_specs.return_value': [column_spec_mock1, - column_spec_mock2], - 'location_path.return_value': LOCATION_PATH, - }, {}) - client.create_model('my_model', dataset_name='my_dataset', - exclude_column_spec_names=['column1'], - train_budget_milli_node_hours=1000) - client.client.create_model.assert_called_with(LOCATION_PATH, { - 'display_name': 'my_model', - 'dataset_id': 'my_dataset', - 'tables_model_metadata': { - 'train_budget_milli_node_hours': 1000, - 'input_feature_column_specs': [column_spec_mock2] + column_spec_mock2.configure_mock(name="column/2", display_name="column2") + client = self.tables_client( + { + "list_table_specs.return_value": [table_spec_mock], + "list_column_specs.return_value": [ + column_spec_mock1, + column_spec_mock2, + ], + "location_path.return_value": LOCATION_PATH, + }, + {}, + ) + client.create_model( + "my_model", + dataset_name="my_dataset", + exclude_column_spec_names=["column1"], + train_budget_milli_node_hours=1000, + ) + client.client.create_model.assert_called_with( + LOCATION_PATH, + { + "display_name": "my_model", + "dataset_id": "my_dataset", + "tables_model_metadata": { + "train_budget_milli_node_hours": 1000, + "input_feature_column_specs": [column_spec_mock2], + }, }, - }) + ) def test_create_model_invalid_hours_small(self): client = self.tables_client({}, {}) error = None try: - client.create_model('my_model', dataset_name='my_dataset', - train_budget_milli_node_hours=1) + client.create_model( + "my_model", dataset_name="my_dataset", train_budget_milli_node_hours=1 + ) except ValueError as e: error = e assert error is not None @@ -929,8 +1025,11 @@ def test_create_model_invalid_hours_large(self): client = self.tables_client({}, {}) error = None try: - client.create_model('my_model', dataset_name='my_dataset', - train_budget_milli_node_hours=1000000) + client.create_model( + "my_model", + dataset_name="my_dataset", + train_budget_milli_node_hours=1000000, + ) except ValueError as e: error = e assert error is not None @@ -940,8 +1039,7 @@ def test_create_model_invalid_no_dataset(self): client = self.tables_client({}, {}) error = None try: - client.create_model('my_model', - train_budget_milli_node_hours=1000) + client.create_model("my_model", train_budget_milli_node_hours=1000) except ValueError as e: error = e assert error is not None @@ -952,10 +1050,13 @@ def test_create_model_invalid_include_exclude(self): client = self.tables_client({}, {}) error = None try: - client.create_model('my_model', dataset_name='my_dataset', - include_column_spec_names=['a'], - exclude_column_spec_names=['b'], - train_budget_milli_node_hours=1000) + client.create_model( + "my_model", + dataset_name="my_dataset", + include_column_spec_names=["a"], + exclude_column_spec_names=["b"], + train_budget_milli_node_hours=1000, + ) except ValueError as e: error = e assert error is not None @@ -964,62 +1065,46 @@ def test_create_model_invalid_include_exclude(self): def test_predict_from_array(self): data_type = mock.Mock(type_code=data_types_pb2.CATEGORY) - column_spec = mock.Mock(display_name='a', data_type=data_type) + column_spec = mock.Mock(display_name="a", data_type=data_type) model_metadata = 
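# The invalid-hours tests above bracket the accepted training budget.
# AutoML Tables expresses the budget in milli node hours, and the service's
# documented window is 1,000 to 72,000 (1 to 72 node hours), consistent
# with 1 and 1,000,000 both being rejected; treat the exact bounds as the
# service's, not this sketch's:
if not 1000 <= train_budget_milli_node_hours <= 72000:
    raise ValueError("train_budget_milli_node_hours must be within [1000, 72000]")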
mock.Mock(input_feature_column_specs=[column_spec]) model = mock.Mock() - model.configure_mock(tables_model_metadata=model_metadata, - name='my_model') - client = self.tables_client({ - 'get_model.return_value': model - }, {}) - client.predict(['1'], model_name='my_model') - client.prediction_client.predict.assert_called_with('my_model', { - 'row': { - 'values': [{'string_value': '1'}] - } - }) + model.configure_mock(tables_model_metadata=model_metadata, name="my_model") + client = self.tables_client({"get_model.return_value": model}, {}) + client.predict(["1"], model_name="my_model") + client.prediction_client.predict.assert_called_with( + "my_model", {"row": {"values": [{"string_value": "1"}]}} + ) def test_predict_from_dict(self): data_type = mock.Mock(type_code=data_types_pb2.CATEGORY) - column_spec_a = mock.Mock(display_name='a', data_type=data_type) - column_spec_b = mock.Mock(display_name='b', data_type=data_type) - model_metadata = mock.Mock(input_feature_column_specs=[ - column_spec_a, - column_spec_b, - ]) + column_spec_a = mock.Mock(display_name="a", data_type=data_type) + column_spec_b = mock.Mock(display_name="b", data_type=data_type) + model_metadata = mock.Mock( + input_feature_column_specs=[column_spec_a, column_spec_b] + ) model = mock.Mock() - model.configure_mock(tables_model_metadata=model_metadata, - name='my_model') - client = self.tables_client({ - 'get_model.return_value': model - }, {}) - client.predict({'a': '1', 'b': '2'}, model_name='my_model') - client.prediction_client.predict.assert_called_with('my_model', { - 'row': { - 'values': [{'string_value': '1'}, {'string_value': '2'}] - } - }) + model.configure_mock(tables_model_metadata=model_metadata, name="my_model") + client = self.tables_client({"get_model.return_value": model}, {}) + client.predict({"a": "1", "b": "2"}, model_name="my_model") + client.prediction_client.predict.assert_called_with( + "my_model", + {"row": {"values": [{"string_value": "1"}, {"string_value": "2"}]}}, + ) def test_predict_from_dict_missing(self): data_type = mock.Mock(type_code=data_types_pb2.CATEGORY) - column_spec_a = mock.Mock(display_name='a', data_type=data_type) - column_spec_b = mock.Mock(display_name='b', data_type=data_type) - model_metadata = mock.Mock(input_feature_column_specs=[ - column_spec_a, - column_spec_b, - ]) + column_spec_a = mock.Mock(display_name="a", data_type=data_type) + column_spec_b = mock.Mock(display_name="b", data_type=data_type) + model_metadata = mock.Mock( + input_feature_column_specs=[column_spec_a, column_spec_b] + ) model = mock.Mock() - model.configure_mock(tables_model_metadata=model_metadata, - name='my_model') - client = self.tables_client({ - 'get_model.return_value': model - }, {}) - client.predict({'a': '1'}, model_name='my_model') - client.prediction_client.predict.assert_called_with('my_model', { - 'row': { - 'values': [{'string_value': '1'}, {'null_value': 0}] - } - }) + model.configure_mock(tables_model_metadata=model_metadata, name="my_model") + client = self.tables_client({"get_model.return_value": model}, {}) + client.predict({"a": "1"}, model_name="my_model") + client.prediction_client.predict.assert_called_with( + "my_model", {"row": {"values": [{"string_value": "1"}, {"null_value": 0}]}} + ) def test_predict_all_types(self): float_type = mock.Mock(type_code=data_types_pb2.FLOAT64) @@ -1028,71 +1113,70 @@ def test_predict_all_types(self): array_type = mock.Mock(type_code=data_types_pb2.ARRAY) struct_type = mock.Mock(type_code=data_types_pb2.STRUCT) category_type = 
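# predict-from-dict orders row values by the model's input feature column
# specs and back-fills any missing feature with a null value, which is what
# test_predict_from_dict_missing asserts. A sketch of that assembly for the
# all-CATEGORY case (hedged; inputs is the caller's feature dict):
values = []
for spec in model.tables_model_metadata.input_feature_column_specs:
    if spec.display_name in inputs:
        values.append({"string_value": inputs[spec.display_name]})
    else:
        values.append({"null_value": 0})
payload = {"row": {"values": values}}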
mock.Mock(type_code=data_types_pb2.CATEGORY) - column_spec_float = mock.Mock(display_name='float', - data_type=float_type) - column_spec_timestamp = mock.Mock(display_name='timestamp', - data_type=timestamp_type) - column_spec_string = mock.Mock(display_name='string', - data_type=string_type) - column_spec_array = mock.Mock(display_name='array', - data_type=array_type) - column_spec_struct = mock.Mock(display_name='struct', - data_type=struct_type) - column_spec_category = mock.Mock(display_name='category', - data_type=category_type) - column_spec_null = mock.Mock(display_name='null', - data_type=category_type) - model_metadata = mock.Mock(input_feature_column_specs=[ - column_spec_float, - column_spec_timestamp, - column_spec_string, - column_spec_array, - column_spec_struct, - column_spec_category, - column_spec_null, - ]) + column_spec_float = mock.Mock(display_name="float", data_type=float_type) + column_spec_timestamp = mock.Mock( + display_name="timestamp", data_type=timestamp_type + ) + column_spec_string = mock.Mock(display_name="string", data_type=string_type) + column_spec_array = mock.Mock(display_name="array", data_type=array_type) + column_spec_struct = mock.Mock(display_name="struct", data_type=struct_type) + column_spec_category = mock.Mock( + display_name="category", data_type=category_type + ) + column_spec_null = mock.Mock(display_name="null", data_type=category_type) + model_metadata = mock.Mock( + input_feature_column_specs=[ + column_spec_float, + column_spec_timestamp, + column_spec_string, + column_spec_array, + column_spec_struct, + column_spec_category, + column_spec_null, + ] + ) model = mock.Mock() - model.configure_mock(tables_model_metadata=model_metadata, - name='my_model') - client = self.tables_client({ - 'get_model.return_value': model - }, {}) - client.predict({ - 'float': 1.0, - 'timestamp': 'EST', - 'string': 'text', - 'array': [1], - 'struct': {'a': 'b'}, - 'category': 'a', - 'null': None, - } , model_name='my_model') - client.prediction_client.predict.assert_called_with('my_model', { - 'row': { - 'values': [ - {'number_value': 1.0}, - {'string_value': 'EST'}, - {'string_value': 'text'}, - {'list_value': [1]}, - {'struct_value': {'a': 'b'}}, - {'string_value': 'a'}, - {'null_value': 0}, - ], - } - }) + model.configure_mock(tables_model_metadata=model_metadata, name="my_model") + client = self.tables_client({"get_model.return_value": model}, {}) + client.predict( + { + "float": 1.0, + "timestamp": "EST", + "string": "text", + "array": [1], + "struct": {"a": "b"}, + "category": "a", + "null": None, + }, + model_name="my_model", + ) + client.prediction_client.predict.assert_called_with( + "my_model", + { + "row": { + "values": [ + {"number_value": 1.0}, + {"string_value": "EST"}, + {"string_value": "text"}, + {"list_value": [1]}, + {"struct_value": {"a": "b"}}, + {"string_value": "a"}, + {"null_value": 0}, + ] + } + }, + ) def test_predict_from_array_missing(self): data_type = mock.Mock(type_code=data_types_pb2.CATEGORY) - column_spec = mock.Mock(display_name='a', data_type=data_type) + column_spec = mock.Mock(display_name="a", data_type=data_type) model_metadata = mock.Mock(input_feature_column_specs=[column_spec]) model = mock.Mock() - model.configure_mock(tables_model_metadata=model_metadata, - name='my_model') - client = self.tables_client({ - 'get_model.return_value': model - }, {}) + model.configure_mock(tables_model_metadata=model_metadata, name="my_model") + client = self.tables_client({"get_model.return_value": model}, {}) error = None try: - 
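# test_predict_all_types (above) fixes the type_code-to-Value-field mapping
# the helper must implement; collected in one place for reference:
TYPE_CODE_TO_VALUE_FIELD = {
    data_types_pb2.FLOAT64: "number_value",
    data_types_pb2.TIMESTAMP: "string_value",
    data_types_pb2.STRING: "string_value",
    data_types_pb2.ARRAY: "list_value",
    data_types_pb2.STRUCT: "struct_value",
    data_types_pb2.CATEGORY: "string_value",
}
# A None input short-circuits to {"null_value": 0} regardless of type_code.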
client.predict([], model_name='my_model') + client.predict([], model_name="my_model") except ValueError as e: error = e assert error is not None @@ -1100,24 +1184,26 @@ def test_predict_from_array_missing(self): def test_batch_predict(self): client = self.tables_client({}, {}) - client.batch_predict(model_name='my_model', - gcs_input_uris='gs://input', - gcs_output_uri_prefix='gs://output') - client.prediction_client.batch_predict.assert_called_with('my_model', - { 'gcs_source': { - 'input_uris': ['gs://input'], - }}, { 'gcs_source': { - 'output_uri_prefix': 'gs://output', - }}, + client.batch_predict( + model_name="my_model", + gcs_input_uris="gs://input", + gcs_output_uri_prefix="gs://output", + ) + client.prediction_client.batch_predict.assert_called_with( + "my_model", + {"gcs_source": {"input_uris": ["gs://input"]}}, + {"gcs_source": {"output_uri_prefix": "gs://output"}}, ) def test_batch_predict_missing_input_gcs_uri(self): client = self.tables_client({}, {}) error = None try: - client.batch_predict(model_name='my_model', - gcs_input_uris=None, - gcs_output_uri_prefix='gs://output') + client.batch_predict( + model_name="my_model", + gcs_input_uris=None, + gcs_output_uri_prefix="gs://output", + ) except ValueError as e: error = e assert error is not None @@ -1127,23 +1213,25 @@ def test_batch_predict_missing_input_gcs_uri(self): client = self.tables_client({}, {}) error = None try: - client.batch_predict(model_name='my_model', - gcs_input_uris='gs://input', - gcs_output_uri_prefix=None) + client.batch_predict( + model_name="my_model", + gcs_input_uris="gs://input", + gcs_output_uri_prefix=None, + ) except ValueError as e: error = e assert error is not None client.prediction_client.batch_predict.assert_not_called() def test_batch_predict_missing_model(self): - client = self.tables_client({ - 'list_models.return_value': [], - }, {}) + client = self.tables_client({"list_models.return_value": []}, {}) error = None try: - client.batch_predict(model_display_name='my_model', - gcs_input_uris='gs://input', - gcs_output_uri_prefix='gs://output') + client.batch_predict( + model_display_name="my_model", + gcs_input_uris="gs://input", + gcs_output_uri_prefix="gs://output", + ) except exceptions.NotFound as e: error = e assert error is not None @@ -1153,8 +1241,9 @@ def test_batch_predict_no_model(self): client = self.tables_client({}, {}) error = None try: - client.batch_predict(gcs_input_uris='gs://input', - gcs_output_uri_prefix='gs://output') + client.batch_predict( + gcs_input_uris="gs://input", gcs_output_uri_prefix="gs://output" + ) except ValueError as e: error = e assert error is not None From 83e6d50497c643157b6af778ec65b256bbb1abec Mon Sep 17 00:00:00 2001 From: Lars Wander Date: Mon, 22 Jul 2019 15:05:28 -0400 Subject: [PATCH 07/11] Passes **kwargs through to client & implements missing methods --- .../automl_v1beta1/tables/tables_client.py | 460 ++++++++++++++++-- .../v1beta1/test_system_tables_client_v1.py | 26 +- .../v1beta1/test_tables_client_v1beta1.py | 262 ++++++---- 3 files changed, 595 insertions(+), 153 deletions(-) diff --git a/automl/google/cloud/automl_v1beta1/tables/tables_client.py b/automl/google/cloud/automl_v1beta1/tables/tables_client.py index e243a4533813..09a8e633f655 100644 --- a/automl/google/cloud/automl_v1beta1/tables/tables_client.py +++ b/automl/google/cloud/automl_v1beta1/tables/tables_client.py @@ -101,11 +101,11 @@ def __init__( client_info_.gapic_version = version if client is None: - self.client = gapic.auto_ml_client.AutoMlClient( + self.auto_ml_client 
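# A hedged usage sketch of batch_predict as exercised by the tests above
# (URIs are placeholders; client is a configured TablesClient):
client.batch_predict(
    model_name="my_model",
    gcs_input_uris="gs://input",
    gcs_output_uri_prefix="gs://output",
)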
= gapic.auto_ml_client.AutoMlClient( client_info=client_info_, **kwargs ) else: - self.client = client + self.auto_ml_client = client if prediction_client is None: self.prediction_client = gapic.prediction_service_client.PredictionServiceClient( @@ -136,7 +136,7 @@ def __location_path(self, project=None, region=None): ) region = self.region - return self.client.location_path(project, region) + return self.auto_ml_client.location_path(project, region) # the returned metadata object doesn't allow for updating fields, so # we need to manually copy user-updated fields over @@ -156,6 +156,7 @@ def __dataset_from_args( dataset_name=None, project=None, region=None, + **kwargs ): if dataset is None and dataset_display_name is None and dataset_name is None: raise ValueError( @@ -172,6 +173,7 @@ def __dataset_from_args( dataset_name=dataset_name, project=project, region=region, + **kwargs ) def __model_from_args( @@ -181,6 +183,7 @@ def __model_from_args( model_name=None, project=None, region=None, + **kwargs ): if model is None and model_display_name is None and model_name is None: raise ValueError( @@ -196,6 +199,7 @@ def __model_from_args( model_name=model_name, project=project, region=region, + **kwargs ) def __dataset_name_from_args( @@ -205,6 +209,7 @@ def __dataset_name_from_args( dataset_name=None, project=None, region=None, + **kwargs ): if dataset is None and dataset_display_name is None and dataset_name is None: raise ValueError( @@ -218,12 +223,15 @@ def __dataset_name_from_args( dataset_display_name=dataset_display_name, project=project, region=region, + **kwargs ) dataset_name = dataset.name else: # we do this to force a NotFound error when needed - self.get_dataset(dataset_name=dataset_name, project=project, region=region) + self.get_dataset( + dataset_name=dataset_name, project=project, region=region, **kwargs + ) return dataset_name def __table_spec_name_from_args( @@ -234,6 +242,7 @@ def __table_spec_name_from_args( dataset_name=None, project=None, region=None, + **kwargs ): dataset_name = self.__dataset_name_from_args( dataset=dataset, @@ -241,9 +250,12 @@ def __table_spec_name_from_args( dataset_display_name=dataset_display_name, project=project, region=region, + **kwargs ) - table_specs = [t for t in self.list_table_specs(dataset_name=dataset_name)] + table_specs = [ + t for t in self.list_table_specs(dataset_name=dataset_name, **kwargs) + ] table_spec_full_id = table_specs[table_spec_index].name return table_spec_full_id @@ -255,6 +267,7 @@ def __model_name_from_args( model_name=None, project=None, region=None, + **kwargs ): if model is None and model_display_name is None and model_name is None: raise ValueError( @@ -267,11 +280,14 @@ def __model_name_from_args( model_display_name=model_display_name, project=project, region=region, + **kwargs ) model_name = model.name else: # we do this to force a NotFound error when needed - self.get_model(model_name=model_name, project=project, region=region) + self.get_model( + model_name=model_name, project=project, region=region, **kwargs + ) return model_name def __column_spec_name_from_args( @@ -285,6 +301,7 @@ def __column_spec_name_from_args( column_spec_display_name=None, project=None, region=None, + **kwargs ): column_specs = self.list_column_specs( dataset=dataset, @@ -294,6 +311,7 @@ def __column_spec_name_from_args( table_spec_index=table_spec_index, project=project, region=region, + **kwargs ) if column_spec_display_name is not None: column_specs = {s.display_name: s for s in column_specs} @@ -338,7 +356,7 @@ def 
__type_code_to_value_type(self, type_code, value): else: raise ValueError("Unknown type_code: {}".format(type_code)) - def list_datasets(self, project=None, region=None): + def list_datasets(self, project=None, region=None, **kwargs): """List all datasets in a particular project and region. Example: @@ -380,12 +398,17 @@ def list_datasets(self, project=None, region=None): to a retryable error and retry attempts failed. ValueError: If required parameters are missing. """ - return self.client.list_datasets( - self.__location_path(project=project, region=region) + return self.auto_ml_client.list_datasets( + self.__location_path(project=project, region=region), **kwargs ) def get_dataset( - self, project=None, region=None, dataset_name=None, dataset_display_name=None + self, + project=None, + region=None, + dataset_name=None, + dataset_display_name=None, + **kwargs ): """Gets a single dataset in a particular project and region. @@ -438,12 +461,12 @@ def get_dataset( ) if dataset_name is not None: - return self.client.get_dataset(dataset_name) + return self.auto_ml_client.get_dataset(dataset_name, **kwargs) result = next( ( d - for d in self.list_datasets(project, region) + for d in self.list_datasets(project, region, **kwargs) if d.display_name == dataset_display_name ), None, @@ -459,7 +482,7 @@ def get_dataset( return result def create_dataset( - self, dataset_display_name, metadata={}, project=None, region=None + self, dataset_display_name, metadata={}, project=None, region=None, **kwargs ): """Create a dataset. Keep in mind, importing data is a separate step. @@ -497,9 +520,10 @@ def create_dataset( to a retryable error and retry attempts failed. ValueError: If required parameters are missing. """ - return self.client.create_dataset( + return self.auto_ml_client.create_dataset( self.__location_path(project, region), {"display_name": dataset_display_name, "tables_dataset_metadata": metadata}, + **kwargs ) def delete_dataset( @@ -509,6 +533,7 @@ def delete_dataset( dataset_name=None, project=None, region=None, + **kwargs ): """Deletes a dataset. This does not delete any models trained on this dataset. @@ -567,12 +592,13 @@ def delete_dataset( dataset_display_name=dataset_display_name, project=project, region=region, + **kwargs ) # delete is idempotent except exceptions.NotFound: return None - return self.client.delete_dataset(dataset_name) + return self.auto_ml_client.delete_dataset(dataset_name, **kwargs) def import_data( self, @@ -583,6 +609,7 @@ def import_data( bigquery_input_uri=None, project=None, region=None, + **kwargs ): """Imports data into a dataset. @@ -652,6 +679,7 @@ def import_data( dataset_display_name=dataset_display_name, project=project, region=region, + **kwargs ) request = {} @@ -666,7 +694,140 @@ def import_data( "One of 'gcs_input_uris', or " "'bigquery_input_uri' must be set." ) - return self.client.import_data(dataset_name, request) + return self.auto_ml_client.import_data(dataset_name, request, **kwargs) + + def export_data( + self, + dataset=None, + dataset_display_name=None, + dataset_name=None, + gcs_output_uri_prefix=None, + bigquery_output_uri=None, + project=None, + region=None, + **kwargs + ): + """Exports data from a dataset. + + Example: + >>> from google.cloud import automl_v1beta1 + >>> + >>> from google.oauth2 import service_account + >>> + >>> client = automl_v1beta1.TablesClient( + ... credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json') + ... project='my-project', region='us-central1') + ... 
+
+            >>> d = client.create_dataset(dataset_display_name='my_dataset')
+            >>>
+            >>> response = client.export_data(dataset=d,
+            ...     gcs_output_uri_prefix='gs://cloud-ml-tables-data/bank-marketing.csv')
+            ...
+            >>> def callback(operation_future):
+            ...    result = operation_future.result()
+            ...
+            >>> response.add_done_callback(callback)
+            >>>
+
+        Args:
+            project (Optional[string]):
+                If you have initialized the client with a value for `project`
+                it will be used if this parameter is not supplied. Keep in
+                mind, the service account this client was initialized with must
+                have access to this project.
+            region (Optional[string]):
+                If you have initialized the client with a value for `region` it
+                will be used if this parameter is not supplied.
+            dataset_display_name (Optional[string]):
+                The human-readable name given to the dataset you want to export
+                data from. This must be supplied if `dataset` or `dataset_name`
+                are not supplied.
+            dataset_name (Optional[string]):
+                The AutoML-assigned name given to the dataset you want to
+                export data from. This must be supplied if
+                `dataset_display_name` or `dataset` are not supplied.
+            dataset (Optional[Dataset]):
+                The `Dataset` instance you want to export data from. This must
+                be supplied if `dataset_display_name` or `dataset_name` are not
+                supplied.
+            gcs_output_uri_prefix (Optional[string]):
+                A single `gs://..` prefixed URI to export to. This must be
+                supplied if `bigquery_output_uri` is not.
+            bigquery_output_uri (Optional[string]):
+                A URI pointing to the BigQuery table to export data to. This
+                must be supplied if `gcs_output_uri_prefix` is not.
+
+        Returns:
+            A :class:`~google.cloud.automl_v1beta1.types._OperationFuture`
+            instance.
+
+        Raises:
+            google.api_core.exceptions.GoogleAPICallError: If the request
+                failed for any reason.
+            google.api_core.exceptions.RetryError: If the request failed due
+                to a retryable error and retry attempts failed.
+            ValueError: If required parameters are missing.
+        """
+        dataset_name = self.__dataset_name_from_args(
+            dataset=dataset,
+            dataset_name=dataset_name,
+            dataset_display_name=dataset_display_name,
+            project=project,
+            region=region,
+            **kwargs
+        )
+
+        request = {}
+        if gcs_output_uri_prefix is not None:
+            request = {"gcs_destination": {"output_uri_prefix": gcs_output_uri_prefix}}
+        elif bigquery_output_uri is not None:
+            request = {"bigquery_destination": {"output_uri": bigquery_output_uri}}
+        else:
+            raise ValueError(
+                "One of 'gcs_output_uri_prefix' or 'bigquery_output_uri' must be set."
+            )
+
+        return self.auto_ml_client.export_data(dataset_name, request, **kwargs)
+
+    def get_table_spec(self, table_spec_name, project=None, region=None, **kwargs):
+        """Gets a single table spec in a particular project and region.
+
+        Example:
+            >>> from google.cloud import automl_v1beta1
+            >>>
+            >>> from google.oauth2 import service_account
+            >>>
+            >>> client = automl_v1beta1.TablesClient(
+            ...     credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json'),
+            ...     project='my-project', region='us-central1')
+            ...
+            >>> d = client.get_table_spec('my_table_spec')
+            >>>
+
+        Args:
+            table_spec_name (string):
+                This is the fully-qualified name generated by the AutoML API
+                for this table spec.
+            project (Optional[string]):
+                If you have initialized the client with a value for `project`
+                it will be used if this parameter is not supplied. Keep in
+                mind, the service account this client was initialized with must
+                have access to this project.
+ region (Optional[string]): + If you have initialized the client with a value for `region` it + will be used if this parameter is not supplied. + + Returns: + A :class:`~google.cloud.automl_v1beta1.types.TableSpec` instance. + + Raises: + google.api_core.exceptions.GoogleAPICallError: If the request + failed for any reason. + google.api_core.exceptions.RetryError: If the request failed due + to a retryable error and retry attempts failed. + ValueError: If required parameters are missing. + """ + return self.auto_ml_client.get_table_spec(table_spec_name, **kwargs) def list_table_specs( self, @@ -675,6 +836,7 @@ def list_table_specs( dataset_name=None, project=None, region=None, + **kwargs ): """Lists table specs. @@ -734,9 +896,50 @@ def list_table_specs( dataset_display_name=dataset_display_name, project=project, region=region, + **kwargs ) - return self.client.list_table_specs(dataset_name) + return self.auto_ml_client.list_table_specs(dataset_name, **kwargs) + + def get_column_spec(self, column_spec_name, project=None, region=None, **kwargs): + """Gets a single column spec in a particular project and region. + + Example: + >>> from google.cloud import automl_v1beta1 + >>> + >>> from google.oauth2 import service_account + >>> + >>> client = automl_v1beta1.TablesClient( + ... credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json') + ... project='my-project', region='us-central1') + ... + >>> d = client.get_column_spec('my_column_spec') + >>> + + Args: + column_spec_name (string): + This is the fully-qualified name generated by the AutoML API + for this column spec. + project (Optional[string]): + If you have initialized the client with a value for `project` + it will be used if this parameter is not supplied. Keep in + mind, the service account this client was initialized with must + have access to this project. + region (Optional[string]): + If you have initialized the client with a value for `region` it + will be used if this parameter is not supplied. + + Returns: + A :class:`~google.cloud.automl_v1beta1.types.ColumnSpec` instance. + + Raises: + google.api_core.exceptions.GoogleAPICallError: If the request + failed for any reason. + google.api_core.exceptions.RetryError: If the request failed due + to a retryable error and retry attempts failed. + ValueError: If required parameters are missing. + """ + return self.auto_ml_client.get_column_spec(column_spec_name, **kwargs) def list_column_specs( self, @@ -747,6 +950,7 @@ def list_column_specs( table_spec_index=0, project=None, region=None, + **kwargs ): """Lists column specs. @@ -824,12 +1028,13 @@ def list_column_specs( dataset_name=dataset_name, project=project, region=region, + **kwargs ) ] table_spec_name = table_specs[table_spec_index].name - return self.client.list_column_specs(table_spec_name) + return self.auto_ml_client.list_column_specs(table_spec_name, **kwargs) def update_column_spec( self, @@ -844,6 +1049,7 @@ def update_column_spec( nullable=None, project=None, region=None, + **kwargs ): """Updates a column's specs. 
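A usage sketch of the `update_column_spec` surface these hunks wire up (not part of the diff; `my_dataset` and `age` are placeholder names, and default application credentials are assumed):

    from google.cloud import automl_v1beta1
    from google.cloud.automl_v1beta1.proto import data_types_pb2

    client = automl_v1beta1.TablesClient(project='my-project', region='us-central1')

    # Retype the 'age' column and mark it nullable; when type_code is omitted,
    # the method reuses the column's current type code.
    client.update_column_spec(
        dataset_display_name='my_dataset',
        column_spec_display_name='age',
        type_code=data_types_pb2.FLOAT64,
        nullable=True,
    )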
@@ -926,6 +1132,7 @@ def update_column_spec( column_spec_display_name=column_spec_display_name, project=project, region=region, + **kwargs ) # type code must always be set @@ -942,6 +1149,7 @@ def update_column_spec( table_spec_index=table_spec_index, project=project, region=region, + **kwargs ) }[column_spec_name].data_type.type_code @@ -953,7 +1161,7 @@ def update_column_spec( request = {"name": column_spec_name, "data_type": data_type} - return self.client.update_column_spec(request) + return self.auto_ml_client.update_column_spec(request, **kwargs) def set_target_column( self, @@ -966,6 +1174,7 @@ def set_target_column( column_spec_display_name=None, project=None, region=None, + **kwargs ): """Sets the target column for a given table. @@ -1050,6 +1259,7 @@ def set_target_column( column_spec_display_name=column_spec_display_name, project=project, region=region, + **kwargs ) column_spec_id = column_spec_name.rsplit("/", 1)[-1] @@ -1059,6 +1269,7 @@ def set_target_column( dataset_display_name=dataset_display_name, project=project, region=region, + **kwargs ) metadata = dataset.tables_dataset_metadata metadata = self.__update_metadata( @@ -1067,7 +1278,7 @@ def set_target_column( request = {"name": dataset.name, "tables_dataset_metadata": metadata} - return self.client.update_dataset(request) + return self.auto_ml_client.update_dataset(request, **kwargs) def set_time_column( self, @@ -1080,6 +1291,7 @@ def set_time_column( column_spec_display_name=None, project=None, region=None, + **kwargs ): """Sets the time column which designates which data will be of type timestamp and will be used for the timeseries data. @@ -1088,8 +1300,8 @@ def set_time_column( Example: >>> from google.cloud import automl_v1beta1 >>> - >>> client = automl_v1beta1.tables.ClientHelper( - ... client=automl_v1beta1.AutoMlClient(), + >>> client = automl_v1beta1.TablesClient( + ... credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json') ... project='my-project', region='us-central1') ... >>> client.set_time_column(dataset_display_name='my_dataset', @@ -1144,7 +1356,7 @@ def set_time_column( `table_spec_name`, `dataset_name` or `dataset_display_name` are not supplied. Returns: - A :class:`~google.cloud.automl_v1beta1.types.Dataset` instance. + A :class:`~google.cloud.automl_v1beta1.types.TableSpec` instance. Raises: google.api_core.exceptions.GoogleAPICallError: If the request failed for any reason. @@ -1162,6 +1374,7 @@ def set_time_column( column_spec_display_name=column_spec_display_name, project=project, region=region, + **kwargs ) column_spec_id = column_spec_name.rsplit("/", 1)[-1] @@ -1171,17 +1384,19 @@ def set_time_column( dataset_display_name=dataset_display_name, project=project, region=region, + **kwargs ) - table_spec_full_id = self.__table_spec_name_from_args(dataset_name=dataset_name) + table_spec_full_id = self.__table_spec_name_from_args( + dataset_name=dataset_name, **kwargs + ) my_table_spec = { "name": table_spec_full_id, "time_column_spec_id": column_spec_id, } - self.client.update_table_spec(my_table_spec) - return self.get_dataset(dataset_name=dataset_name) + return self.auto_ml_client.update_table_spec(my_table_spec, **kwargs) def clear_time_column( self, @@ -1190,6 +1405,7 @@ def clear_time_column( dataset_name=None, project=None, region=None, + **kwargs ): """Clears the time column which designates which data will be of type timestamp and will be used for the timeseries data. 
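A usage sketch of `set_time_column` as documented above (not part of the diff; 'event_time' is a placeholder column name, and credentials are assumed to come from the environment):

    from google.cloud import automl_v1beta1

    client = automl_v1beta1.TablesClient(project='my-project', region='us-central1')

    # Designate the timestamp column for timeseries data; per the hunks above,
    # the change is applied via update_table_spec on the dataset's table spec.
    client.set_time_column(
        dataset_display_name='my_dataset',
        column_spec_display_name='event_time',
    )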
@@ -1197,8 +1413,8 @@ def clear_time_column( Example: >>> from google.cloud import automl_v1beta1 >>> - >>> client = automl_v1beta1.tables.ClientHelper( - ... client=automl_v1beta1.AutoMlClient(), + >>> client = automl_v1beta1.TablesClient( + ... credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json') ... project='my-project', region='us-central1') ... >>> client.set_time_column(dataset_display_name='my_dataset') @@ -1236,7 +1452,7 @@ def clear_time_column( not supplied. Returns: - A :class:`~google.cloud.automl_v1beta1.types.Dataset` instance. + A :class:`~google.cloud.automl_v1beta1.types.TableSpec` instance. Raises: google.api_core.exceptions.GoogleAPICallError: If the request @@ -1251,14 +1467,16 @@ def clear_time_column( dataset_display_name=dataset_display_name, project=project, region=region, + **kwargs ) - table_spec_full_id = self.__table_spec_name_from_args(dataset_name=dataset_name) + table_spec_full_id = self.__table_spec_name_from_args( + dataset_name=dataset_name, **kwargs + ) my_table_spec = {"name": table_spec_full_id, "time_column_spec_id": None} - response = self.client.update_table_spec(my_table_spec) - return self.get_dataset(dataset_name=dataset_name) + return self.auto_ml_client.update_table_spec(my_table_spec, **kwargs) def set_weight_column( self, @@ -1271,6 +1489,7 @@ def set_weight_column( column_spec_display_name=None, project=None, region=None, + **kwargs ): """Sets the weight column for a given table. @@ -1355,6 +1574,7 @@ def set_weight_column( column_spec_display_name=column_spec_display_name, project=project, region=region, + **kwargs ) column_spec_id = column_spec_name.rsplit("/", 1)[-1] @@ -1364,6 +1584,7 @@ def set_weight_column( dataset_display_name=dataset_display_name, project=project, region=region, + **kwargs ) metadata = dataset.tables_dataset_metadata metadata = self.__update_metadata( @@ -1372,7 +1593,7 @@ def set_weight_column( request = {"name": dataset.name, "tables_dataset_metadata": metadata} - return self.client.update_dataset(request) + return self.auto_ml_client.update_dataset(request, **kwargs) def clear_weight_column( self, @@ -1381,6 +1602,7 @@ def clear_weight_column( dataset_name=None, project=None, region=None, + **kwargs ): """Clears the weight column for a given dataset. @@ -1443,13 +1665,14 @@ def clear_weight_column( dataset_display_name=dataset_display_name, project=project, region=region, + **kwargs ) metadata = dataset.tables_dataset_metadata metadata = self.__update_metadata(metadata, "weight_column_spec_id", None) request = {"name": dataset.name, "tables_dataset_metadata": metadata} - return self.client.update_dataset(request) + return self.auto_ml_client.update_dataset(request, **kwargs) def set_test_train_column( self, @@ -1462,6 +1685,7 @@ def set_test_train_column( column_spec_display_name=None, project=None, region=None, + **kwargs ): """Sets the test/train (ml_use) column which designates which data belongs to the test and train sets. This column must be categorical. 
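A usage sketch of `set_test_train_column` as documented above (not part of the diff; 'split' is a placeholder for a categorical column, and credentials are assumed to come from the environment):

    from google.cloud import automl_v1beta1

    client = automl_v1beta1.TablesClient(project='my-project', region='us-central1')

    # Assign the categorical column that marks each row's test/train membership;
    # clear_test_train_column below removes the assignment again.
    client.set_test_train_column(
        dataset_display_name='my_dataset',
        column_spec_display_name='split',
    )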
@@ -1547,6 +1771,7 @@ def set_test_train_column( column_spec_display_name=column_spec_display_name, project=project, region=region, + **kwargs ) column_spec_id = column_spec_name.rsplit("/", 1)[-1] @@ -1556,6 +1781,7 @@ def set_test_train_column( dataset_display_name=dataset_display_name, project=project, region=region, + **kwargs ) metadata = dataset.tables_dataset_metadata metadata = self.__update_metadata( @@ -1564,7 +1790,7 @@ def set_test_train_column( request = {"name": dataset.name, "tables_dataset_metadata": metadata} - return self.client.update_dataset(request) + return self.auto_ml_client.update_dataset(request, **kwargs) def clear_test_train_column( self, @@ -1573,6 +1799,7 @@ def clear_test_train_column( dataset_name=None, project=None, region=None, + **kwargs ): """Clears the test/train (ml_use) column which designates which data belongs to the test and train sets. @@ -1636,15 +1863,16 @@ def clear_test_train_column( dataset_display_name=dataset_display_name, project=project, region=region, + **kwargs ) metadata = dataset.tables_dataset_metadata metadata = self.__update_metadata(metadata, "ml_use_column_spec_id", None) request = {"name": dataset.name, "tables_dataset_metadata": metadata} - return self.client.update_dataset(request) + return self.auto_ml_client.update_dataset(request, **kwargs) - def list_models(self, project=None, region=None): + def list_models(self, project=None, region=None, **kwargs): """List all models in a particular project and region. Example: @@ -1686,10 +1914,84 @@ def list_models(self, project=None, region=None): to a retryable error and retry attempts failed. ValueError: If required parameters are missing. """ - return self.client.list_models( - self.__location_path(project=project, region=region) + return self.auto_ml_client.list_models( + self.__location_path(project=project, region=region), **kwargs ) + def list_model_evaluations( + self, + project=None, + region=None, + model=None, + model_display_name=None, + model_name=None, + **kwargs + ): + """List all model evaluations for a given model. + + Example: + >>> from google.cloud import automl_v1beta1 + >>> + >>> from google.oauth2 import service_account + >>> + >>> client = automl_v1beta1.TablesClient( + ... credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json') + ... project='my-project', region='us-central1') + ... + >>> ms = client.list_model_evaluations(model_display_name='my_model') + >>> + >>> for m in ms: + ... # do something + ... pass + ... + + Args: + project (Optional[string]): + If you have initialized the client with a value for `project` + it will be used if this parameter is not supplied. Keep in + mind, the service account this client was initialized with must + have access to this project. + region (Optional[string]): + If you have initialized the client with a value for `region` it + will be used if this parameter is not supplied. + model_display_name (Optional[string]): + The human-readable name given to the model you want to list + evaluations for. This must be supplied if `model` or + `model_name` are not supplied. + model_name (Optional[string]): + The AutoML-assigned name given to the model you want to list + evaluations for. This must be supplied if `model_display_name` + or `model` are not supplied. + model (Optional[model]): + The `model` instance you want to list evaluations for. This + must be supplied if `model_display_name` or `model_name` are + not supplied. 
+
+        Returns:
+            A :class:`~google.api_core.page_iterator.PageIterator` instance.
+            An iterable of
+            :class:`~google.cloud.automl_v1beta1.types.ModelEvaluation`
+            instances. You can also iterate over the pages of the response
+            using its `pages` property.
+
+        Raises:
+            google.api_core.exceptions.GoogleAPICallError: If the request
+                failed for any reason.
+            google.api_core.exceptions.RetryError: If the request failed due
+                to a retryable error and retry attempts failed.
+            ValueError: If required parameters are missing.
+        """
+        model_name = self.__model_name_from_args(
+            model=model,
+            model_name=model_name,
+            model_display_name=model_display_name,
+            project=project,
+            region=region,
+            **kwargs
+        )
+
+        return self.auto_ml_client.list_model_evaluations(model_name, **kwargs)
+
     def create_model(
         self,
         model_display_name,
@@ -1702,6 +2004,7 @@ def create_model(
         model_metadata={},
         include_column_spec_names=None,
         exclude_column_spec_names=None,
+        **kwargs
     ):
         """Create a model. This will train your model on the given
         dataset.
@@ -1790,6 +2093,7 @@ def create_model(
             dataset_display_name=dataset_display_name,
             project=project,
             region=region,
+            **kwargs
         )
 
         model_metadata["train_budget_milli_node_hours"] = train_budget_milli_node_hours
@@ -1801,6 +2105,7 @@ def create_model(
                 dataset=dataset,
                 dataset_name=dataset_name,
                 dataset_display_name=dataset_display_name,
+                **kwargs
             )
         ]
 
@@ -1824,8 +2129,8 @@ def create_model(
             "tables_model_metadata": model_metadata,
         }
 
-        return self.client.create_model(
-            self.__location_path(project=project, region=region), request
+        return self.auto_ml_client.create_model(
+            self.__location_path(project=project, region=region), request, **kwargs
         )
 
     def delete_model(
@@ -1835,6 +2140,7 @@ def delete_model(
         model_name=None,
         project=None,
         region=None,
+        **kwargs
     ):
         """Deletes a model. Note this will not delete any datasets associated
         with this model.
@@ -1893,15 +2199,63 @@ def delete_model(
                 model_display_name=model_display_name,
                 project=project,
                 region=region,
+                **kwargs
             )
         # delete is idempotent
         except exceptions.NotFound:
             return None
 
-        return self.client.delete_model(model_name)
+        return self.auto_ml_client.delete_model(model_name, **kwargs)
+
+    def get_model_evaluation(
+        self, model_evaluation_name, project=None, region=None, **kwargs
+    ):
+        """Gets a single model evaluation in a particular project and region.
+
+        Example:
+            >>> from google.cloud import automl_v1beta1
+            >>>
+            >>> from google.oauth2 import service_account
+            >>>
+            >>> client = automl_v1beta1.TablesClient(
+            ...     credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json'),
+            ...     project='my-project', region='us-central1')
+            ...
+            >>> d = client.get_model_evaluation('my_model_evaluation')
+            >>>
+
+        Args:
+            model_evaluation_name (string):
+                This is the fully-qualified name generated by the AutoML API
+                for this model evaluation.
+            project (Optional[string]):
+                If you have initialized the client with a value for `project`
+                it will be used if this parameter is not supplied. Keep in
+                mind, the service account this client was initialized with must
+                have access to this project.
+            region (Optional[string]):
+                If you have initialized the client with a value for `region` it
+                will be used if this parameter is not supplied.
+
+        Returns:
+            A :class:`~google.cloud.automl_v1beta1.types.ModelEvaluation` instance.
+
+        Raises:
+            google.api_core.exceptions.GoogleAPICallError: If the request
+                failed for any reason.
+ google.api_core.exceptions.RetryError: If the request failed due + to a retryable error and retry attempts failed. + ValueError: If required parameters are missing. + """ + return self.auto_ml_client.get_model_evaluation(model_evaluation_name, **kwargs) def get_model( - self, project=None, region=None, model_name=None, model_display_name=None + self, + project=None, + region=None, + model_name=None, + model_display_name=None, + **kwargs ): """Gets a single model in a particular project and region. @@ -1953,12 +2307,12 @@ def get_model( ) if model_name is not None: - return self.client.get_model(model_name) + return self.auto_ml_client.get_model(model_name, **kwargs) model = next( ( d - for d in self.list_models(project, region) + for d in self.list_models(project, region, **kwargs) if d.display_name == model_display_name ), None, @@ -1980,6 +2334,7 @@ def deploy_model( model_display_name=None, project=None, region=None, + **kwargs ): """Deploys a model. This allows you make online predictions using the model you've deployed. @@ -2037,9 +2392,10 @@ def deploy_model( model_display_name=model_display_name, project=project, region=region, + **kwargs ) - return self.client.deploy_model(model_name) + return self.auto_ml_client.deploy_model(model_name, **kwargs) def undeploy_model( self, @@ -2048,6 +2404,7 @@ def undeploy_model( model_display_name=None, project=None, region=None, + **kwargs ): """Undeploys a model. @@ -2104,9 +2461,10 @@ def undeploy_model( model_display_name=model_display_name, project=project, region=region, + **kwargs ) - return self.client.undeploy_model(model_name) + return self.auto_ml_client.undeploy_model(model_name, **kwargs) ## TODO(lwander): support pandas DataFrame as input type def predict( @@ -2117,6 +2475,7 @@ def predict( model_display_name=None, project=None, region=None, + **kwargs ): """Makes a prediction on a deployed model. This will fail if the model was not deployed. @@ -2178,6 +2537,7 @@ def predict( model_display_name=model_display_name, project=project, region=region, + **kwargs ) column_specs = model.tables_model_metadata.input_feature_column_specs @@ -2200,7 +2560,7 @@ def predict( request = {"row": {"values": values}} - return self.prediction_client.predict(model.name, request) + return self.prediction_client.predict(model.name, request, **kwargs) def batch_predict( self, @@ -2212,6 +2572,7 @@ def batch_predict( project=None, region=None, inputs=None, + **kwargs ): """Makes a batch prediction on a model. This does _not_ require the model to be deployed. 
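A usage sketch of `batch_predict` at this point in the series, before the BigQuery support added in PATCH 08 (not part of the diff; bucket paths are placeholders, and the return value is assumed to be the operation future the docstring describes):

    from google.cloud import automl_v1beta1

    client = automl_v1beta1.TablesClient(project='my-project', region='us-central1')

    # Batch prediction does not require the model to be deployed; results are
    # written under the GCS output prefix.
    response = client.batch_predict(
        model_display_name='my_model',
        gcs_input_uris='gs://bucket/input.csv',
        gcs_output_uri_prefix='gs://bucket/output',
    )
    response.result()  # blocks until the batch job completes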
@@ -2281,6 +2642,7 @@ def batch_predict( model_display_name=model_display_name, project=project, region=region, + **kwargs ) if type(gcs_input_uris) != list: @@ -2291,5 +2653,5 @@ def batch_predict( output_request = {"gcs_source": {"output_uri_prefix": gcs_output_uri_prefix}} return self.prediction_client.batch_predict( - model_name, input_request, output_request + model_name, input_request, output_request, **kwargs ) diff --git a/automl/tests/system/gapic/v1beta1/test_system_tables_client_v1.py b/automl/tests/system/gapic/v1beta1/test_system_tables_client_v1.py index 9ff6c2c7b79a..fdbae6381f6d 100644 --- a/automl/tests/system/gapic/v1beta1/test_system_tables_client_v1.py +++ b/automl/tests/system/gapic/v1beta1/test_system_tables_client_v1.py @@ -29,7 +29,7 @@ MAX_WAIT_TIME_SECONDS = 30 MAX_SLEEP_TIME_SECONDS = 5 STATIC_DATASET = "test_dataset_do_not_delete" -STATIC_MODEL='test_model_do_not_delete' +STATIC_MODEL = "test_model_do_not_delete" ID = "{rand}_{time}".format( rand="".join( @@ -58,7 +58,7 @@ def cancel_and_wait(self, op): def test_list_datasets(self): client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION) # we need to unroll the iterator to actually make client calls - [d for d in client.list_datasets()] + [d for d in client.list_datasets(timeout=10)] def test_list_models(self): client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION) @@ -109,11 +109,19 @@ def test_list_column_specs(self): # we need to unroll the iterator to actually make client calls [d for d in client.list_column_specs(dataset=dataset)] + def test_get_column_spec(self): + client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION) + dataset = self.ensure_dataset_ready(client) + # we need to unroll the iterator to actually make client calls + cs = [d for d in client.list_column_specs(dataset=dataset)] + client.get_column_spec(cs[0].name) + def test_list_table_specs(self): client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION) dataset = self.ensure_dataset_ready(client) # we need to unroll the iterator to actually make client calls - [d for d in client.list_table_specs(dataset=dataset)] + ts = [d for d in client.list_table_specs(dataset=dataset)] + client.get_table_spec(ts[0].name) def test_set_column_nullable(self): client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION) @@ -172,6 +180,18 @@ def test_create_delete_model(self): self.cancel_and_wait(op) client.delete_model(model_display_name=display_name) + def test_list_model_evaluations(self): + client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION) + model = self.ensure_model_online(client) + # we need to unroll the iterator to actually make client calls + [m for m in client.list_model_evaluations(model=model)] + + def test_get_model_evaluation(self): + client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION) + model = self.ensure_model_online(client) + me = [m for m in client.list_model_evaluations(model=model)] + client.get_model_evaluation(model_evaluation_name=me[0].name) + def test_online_predict(self): client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION) model = self.ensure_model_online(client) diff --git a/automl/tests/unit/gapic/v1beta1/test_tables_client_v1beta1.py b/automl/tests/unit/gapic/v1beta1/test_tables_client_v1beta1.py index cbdbff97d2e3..1bf7478e0343 100644 --- a/automl/tests/unit/gapic/v1beta1/test_tables_client_v1beta1.py +++ b/automl/tests/unit/gapic/v1beta1/test_tables_client_v1beta1.py @@ -48,8 +48,8 @@ def 
test_list_datasets_empty(self):
             {},
         )
         ds = client.list_datasets()
-        client.client.location_path.assert_called_with(PROJECT, REGION)
-        client.client.list_datasets.assert_called_with(LOCATION_PATH)
+        client.auto_ml_client.location_path.assert_called_with(PROJECT, REGION)
+        client.auto_ml_client.list_datasets.assert_called_with(LOCATION_PATH)
         assert ds == []
 
     def test_list_datasets_not_empty(self):
@@ -62,8 +62,8 @@ def test_list_datasets_not_empty(self):
             {},
         )
         ds = client.list_datasets()
-        client.client.location_path.assert_called_with(PROJECT, REGION)
-        client.client.list_datasets.assert_called_with(LOCATION_PATH)
+        client.auto_ml_client.location_path.assert_called_with(PROJECT, REGION)
+        client.auto_ml_client.list_datasets.assert_called_with(LOCATION_PATH)
         assert len(ds) == 1
         assert ds[0] == "some_dataset"
 
@@ -76,13 +76,13 @@ def test_get_dataset_no_value(self):
         except ValueError as e:
             error = e
         assert error is not None
-        client.client.get_dataset.assert_not_called()
+        client.auto_ml_client.get_dataset.assert_not_called()
 
     def test_get_dataset_name(self):
         dataset_actual = "dataset"
         client = self.tables_client({"get_dataset.return_value": dataset_actual}, {})
         dataset = client.get_dataset(dataset_name="my_dataset")
-        client.client.get_dataset.assert_called_with("my_dataset")
+        client.auto_ml_client.get_dataset.assert_called_with("my_dataset")
         assert dataset == dataset_actual
 
     def test_get_no_dataset(self):
@@ -95,7 +95,7 @@ def test_get_no_dataset(self):
         except exceptions.NotFound as e:
             error = e
         assert error is not None
-        client.client.get_dataset.assert_called_with("my_dataset")
+        client.auto_ml_client.get_dataset.assert_called_with("my_dataset")
 
     def test_get_dataset_from_empty_list(self):
         client = self.tables_client({"list_datasets.return_value": []}, {})
@@ -140,8 +140,8 @@ def test_create_dataset(self):
         )
         metadata = {"metadata": "values"}
         dataset = client.create_dataset("name", metadata=metadata)
-        client.client.location_path.assert_called_with(PROJECT, REGION)
-        client.client.create_dataset.assert_called_with(
+        client.auto_ml_client.location_path.assert_called_with(PROJECT, REGION)
+        client.auto_ml_client.create_dataset.assert_called_with(
             LOCATION_PATH, {"display_name": "name", "tables_dataset_metadata": metadata}
         )
         assert dataset.display_name == "name"
@@ -151,17 +151,42 @@ def test_delete_dataset(self):
         dataset.configure_mock(name="name")
         client = self.tables_client({"delete_dataset.return_value": None}, {})
         client.delete_dataset(dataset=dataset)
-        client.client.delete_dataset.assert_called_with("name")
+        client.auto_ml_client.delete_dataset.assert_called_with("name")
 
     def test_delete_dataset_not_found(self):
         client = self.tables_client({"list_datasets.return_value": []}, {})
         client.delete_dataset(dataset_display_name="not_found")
-        client.client.delete_dataset.assert_not_called()
+        client.auto_ml_client.delete_dataset.assert_not_called()
 
     def test_delete_dataset_name(self):
         client = self.tables_client({"delete_dataset.return_value": None}, {})
         client.delete_dataset(dataset_name="name")
-        client.client.delete_dataset.assert_called_with("name")
+        client.auto_ml_client.delete_dataset.assert_called_with("name")
+
+    def test_export_not_found(self):
+        client = self.tables_client({"list_datasets.return_value": []}, {})
+        error = None
+        try:
+            client.export_data(dataset_display_name="name", gcs_output_uri_prefix="uri")
+        except exceptions.NotFound as e:
+            error = e
+        assert error is not None
+
+        client.auto_ml_client.export_data.assert_not_called()
+
+    def test_export_gcs_uri(self):
+        client = 
self.tables_client({"export_data.return_value": None}, {}) + client.export_data(dataset_name="name", gcs_output_uri_prefix="uri") + client.auto_ml_client.export_data.assert_called_with( + "name", {"gcs_destination": {"output_uri_prefix": "uri"}} + ) + + def test_export_bq_uri(self): + client = self.tables_client({"export_data.return_value": None}, {}) + client.export_data(dataset_name="name", bigquery_output_uri="uri") + client.auto_ml_client.export_data.assert_called_with( + "name", {"bigquery_destination": {"output_uri": "uri"}} + ) def test_import_not_found(self): client = self.tables_client({"list_datasets.return_value": []}, {}) @@ -172,33 +197,33 @@ def test_import_not_found(self): error = e assert error is not None - client.client.import_data.assert_not_called() + client.auto_ml_client.import_data.assert_not_called() def test_import_gcs_uri(self): client = self.tables_client({"import_data.return_value": None}, {}) client.import_data(dataset_name="name", gcs_input_uris="uri") - client.client.import_data.assert_called_with( + client.auto_ml_client.import_data.assert_called_with( "name", {"gcs_source": {"input_uris": ["uri"]}} ) def test_import_gcs_uris(self): client = self.tables_client({"import_data.return_value": None}, {}) client.import_data(dataset_name="name", gcs_input_uris=["uri", "uri"]) - client.client.import_data.assert_called_with( + client.auto_ml_client.import_data.assert_called_with( "name", {"gcs_source": {"input_uris": ["uri", "uri"]}} ) def test_import_bq_uri(self): client = self.tables_client({"import_data.return_value": None}, {}) client.import_data(dataset_name="name", bigquery_input_uri="uri") - client.client.import_data.assert_called_with( + client.auto_ml_client.import_data.assert_called_with( "name", {"bigquery_source": {"input_uri": "uri"}} ) def test_list_table_specs(self): client = self.tables_client({"list_table_specs.return_value": None}, {}) client.list_table_specs(dataset_name="name") - client.client.list_table_specs.assert_called_with("name") + client.auto_ml_client.list_table_specs.assert_called_with("name") def test_list_table_specs_not_found(self): client = self.tables_client( @@ -210,7 +235,17 @@ def test_list_table_specs_not_found(self): except exceptions.NotFound as e: error = e assert error is not None - client.client.list_table_specs.assert_called_with("name") + client.auto_ml_client.list_table_specs.assert_called_with("name") + + def test_get_table_spec(self): + client = self.tables_client({}, {}) + client.get_table_spec("name") + client.auto_ml_client.get_table_spec.assert_called_with("name") + + def test_get_column_spec(self): + client = self.tables_client({}, {}) + client.get_column_spec("name") + client.auto_ml_client.get_column_spec.assert_called_with("name") def test_list_column_specs(self): table_spec_mock = mock.Mock() @@ -224,8 +259,8 @@ def test_list_column_specs(self): {}, ) client.list_column_specs(dataset_name="name") - client.client.list_table_specs.assert_called_with("name") - client.client.list_column_specs.assert_called_with("table") + client.auto_ml_client.list_table_specs.assert_called_with("name") + client.auto_ml_client.list_column_specs.assert_called_with("table") def test_update_column_spec_not_found(self): table_spec_mock = mock.Mock() @@ -249,9 +284,9 @@ def test_update_column_spec_not_found(self): except exceptions.NotFound as e: error = e assert error is not None - client.client.list_table_specs.assert_called_with("name") - client.client.list_column_specs.assert_called_with("table") - 
client.client.update_column_spec.assert_not_called() + client.auto_ml_client.list_table_specs.assert_called_with("name") + client.auto_ml_client.list_column_specs.assert_called_with("table") + client.auto_ml_client.update_column_spec.assert_not_called() def test_update_column_spec_display_name_not_found(self): table_spec_mock = mock.Mock() @@ -277,9 +312,9 @@ def test_update_column_spec_display_name_not_found(self): except exceptions.NotFound as e: error = e assert error is not None - client.client.list_table_specs.assert_called_with("name") - client.client.list_column_specs.assert_called_with("table") - client.client.update_column_spec.assert_not_called() + client.auto_ml_client.list_table_specs.assert_called_with("name") + client.auto_ml_client.list_column_specs.assert_called_with("table") + client.auto_ml_client.update_column_spec.assert_not_called() def test_update_column_spec_name_no_args(self): table_spec_mock = mock.Mock() @@ -298,9 +333,9 @@ def test_update_column_spec_name_no_args(self): {}, ) client.update_column_spec(dataset_name="name", column_spec_name="column/2") - client.client.list_table_specs.assert_called_with("name") - client.client.list_column_specs.assert_called_with("table") - client.client.update_column_spec.assert_called_with( + client.auto_ml_client.list_table_specs.assert_called_with("name") + client.auto_ml_client.list_column_specs.assert_called_with("table") + client.auto_ml_client.update_column_spec.assert_called_with( {"name": "column/2", "data_type": {"type_code": "type_code"}} ) @@ -323,9 +358,9 @@ def test_update_column_spec_no_args(self): client.update_column_spec( dataset_name="name", column_spec_display_name="column" ) - client.client.list_table_specs.assert_called_with("name") - client.client.list_column_specs.assert_called_with("table") - client.client.update_column_spec.assert_called_with( + client.auto_ml_client.list_table_specs.assert_called_with("name") + client.auto_ml_client.list_column_specs.assert_called_with("table") + client.auto_ml_client.update_column_spec.assert_called_with( {"name": "column", "data_type": {"type_code": "type_code"}} ) @@ -348,9 +383,9 @@ def test_update_column_spec_nullable(self): client.update_column_spec( dataset_name="name", column_spec_display_name="column", nullable=True ) - client.client.list_table_specs.assert_called_with("name") - client.client.list_column_specs.assert_called_with("table") - client.client.update_column_spec.assert_called_with( + client.auto_ml_client.list_table_specs.assert_called_with("name") + client.auto_ml_client.list_column_specs.assert_called_with("table") + client.auto_ml_client.update_column_spec.assert_called_with( { "name": "column", "data_type": {"type_code": "type_code", "nullable": True}, @@ -378,9 +413,9 @@ def test_update_column_spec_type_code(self): column_spec_display_name="column", type_code="type_code2", ) - client.client.list_table_specs.assert_called_with("name") - client.client.list_column_specs.assert_called_with("table") - client.client.update_column_spec.assert_called_with( + client.auto_ml_client.list_table_specs.assert_called_with("name") + client.auto_ml_client.list_column_specs.assert_called_with("table") + client.auto_ml_client.update_column_spec.assert_called_with( {"name": "column", "data_type": {"type_code": "type_code2"}} ) @@ -406,9 +441,9 @@ def test_update_column_spec_type_code_nullable(self): column_spec_display_name="column", type_code="type_code2", ) - client.client.list_table_specs.assert_called_with("name") - 
client.client.list_column_specs.assert_called_with("table") - client.client.update_column_spec.assert_called_with( + client.auto_ml_client.list_table_specs.assert_called_with("name") + client.auto_ml_client.list_column_specs.assert_called_with("table") + client.auto_ml_client.update_column_spec.assert_called_with( { "name": "column", "data_type": {"type_code": "type_code2", "nullable": True}, @@ -437,9 +472,9 @@ def test_update_column_spec_type_code_nullable_false(self): column_spec_display_name="column", type_code="type_code2", ) - client.client.list_table_specs.assert_called_with("name") - client.client.list_column_specs.assert_called_with("table") - client.client.update_column_spec.assert_called_with( + client.auto_ml_client.list_table_specs.assert_called_with("name") + client.auto_ml_client.list_column_specs.assert_called_with("table") + client.auto_ml_client.update_column_spec.assert_called_with( { "name": "column", "data_type": {"type_code": "type_code2", "nullable": False}, @@ -458,9 +493,9 @@ def test_set_target_column_table_not_found(self): except exceptions.NotFound as e: error = e assert error is not None - client.client.list_table_specs.assert_called_with("name") - client.client.list_column_specs.assert_not_called() - client.client.update_dataset.assert_not_called() + client.auto_ml_client.list_table_specs.assert_called_with("name") + client.auto_ml_client.list_column_specs.assert_not_called() + client.auto_ml_client.update_dataset.assert_not_called() def test_set_target_column_not_found(self): table_spec_mock = mock.Mock() @@ -483,9 +518,9 @@ def test_set_target_column_not_found(self): except exceptions.NotFound as e: error = e assert error is not None - client.client.list_table_specs.assert_called_with("name") - client.client.list_column_specs.assert_called_with("table") - client.client.update_dataset.assert_not_called() + client.auto_ml_client.list_table_specs.assert_called_with("name") + client.auto_ml_client.list_column_specs.assert_called_with("table") + client.auto_ml_client.update_dataset.assert_not_called() def test_set_target_column(self): table_spec_mock = mock.Mock() @@ -512,9 +547,9 @@ def test_set_target_column(self): {}, ) client.set_target_column(dataset_name="name", column_spec_display_name="column") - client.client.list_table_specs.assert_called_with("name") - client.client.list_column_specs.assert_called_with("table") - client.client.update_dataset.assert_called_with( + client.auto_ml_client.list_table_specs.assert_called_with("name") + client.auto_ml_client.list_column_specs.assert_called_with("table") + client.auto_ml_client.update_dataset.assert_called_with( { "name": "dataset", "tables_dataset_metadata": { @@ -535,9 +570,9 @@ def test_set_weight_column_table_not_found(self): ) except exceptions.NotFound: pass - client.client.list_table_specs.assert_called_with("name") - client.client.list_column_specs.assert_not_called() - client.client.update_dataset.assert_not_called() + client.auto_ml_client.list_table_specs.assert_called_with("name") + client.auto_ml_client.list_column_specs.assert_not_called() + client.auto_ml_client.update_dataset.assert_not_called() def test_set_weight_column_not_found(self): table_spec_mock = mock.Mock() @@ -560,9 +595,9 @@ def test_set_weight_column_not_found(self): except exceptions.NotFound as e: error = e assert error is not None - client.client.list_table_specs.assert_called_with("name") - client.client.list_column_specs.assert_called_with("table") - client.client.update_dataset.assert_not_called() + 
client.auto_ml_client.list_table_specs.assert_called_with("name") + client.auto_ml_client.list_column_specs.assert_called_with("table") + client.auto_ml_client.update_dataset.assert_not_called() def test_set_weight_column(self): table_spec_mock = mock.Mock() @@ -589,9 +624,9 @@ def test_set_weight_column(self): {}, ) client.set_weight_column(dataset_name="name", column_spec_display_name="column") - client.client.list_table_specs.assert_called_with("name") - client.client.list_column_specs.assert_called_with("table") - client.client.update_dataset.assert_called_with( + client.auto_ml_client.list_table_specs.assert_called_with("name") + client.auto_ml_client.list_column_specs.assert_called_with("table") + client.auto_ml_client.update_dataset.assert_called_with( { "name": "dataset", "tables_dataset_metadata": { @@ -615,7 +650,7 @@ def test_clear_weight_column(self): ) client = self.tables_client({"get_dataset.return_value": dataset_mock}, {}) client.clear_weight_column(dataset_name="name") - client.client.update_dataset.assert_called_with( + client.auto_ml_client.update_dataset.assert_called_with( { "name": "dataset", "tables_dataset_metadata": { @@ -638,9 +673,9 @@ def test_set_test_train_column_table_not_found(self): except exceptions.NotFound as e: error = e assert error is not None - client.client.list_table_specs.assert_called_with("name") - client.client.list_column_specs.assert_not_called() - client.client.update_dataset.assert_not_called() + client.auto_ml_client.list_table_specs.assert_called_with("name") + client.auto_ml_client.list_column_specs.assert_not_called() + client.auto_ml_client.update_dataset.assert_not_called() def test_set_test_train_column_not_found(self): table_spec_mock = mock.Mock() @@ -663,9 +698,9 @@ def test_set_test_train_column_not_found(self): except exceptions.NotFound as e: error = e assert error is not None - client.client.list_table_specs.assert_called_with("name") - client.client.list_column_specs.assert_called_with("table") - client.client.update_dataset.assert_not_called() + client.auto_ml_client.list_table_specs.assert_called_with("name") + client.auto_ml_client.list_column_specs.assert_called_with("table") + client.auto_ml_client.update_dataset.assert_not_called() def test_set_test_train_column(self): table_spec_mock = mock.Mock() @@ -694,9 +729,9 @@ def test_set_test_train_column(self): client.set_test_train_column( dataset_name="name", column_spec_display_name="column" ) - client.client.list_table_specs.assert_called_with("name") - client.client.list_column_specs.assert_called_with("table") - client.client.update_dataset.assert_called_with( + client.auto_ml_client.list_table_specs.assert_called_with("name") + client.auto_ml_client.list_column_specs.assert_called_with("table") + client.auto_ml_client.update_dataset.assert_called_with( { "name": "dataset", "tables_dataset_metadata": { @@ -720,7 +755,7 @@ def test_clear_test_train_column(self): ) client = self.tables_client({"get_dataset.return_value": dataset_mock}, {}) client.clear_test_train_column(dataset_name="name") - client.client.update_dataset.assert_called_with( + client.auto_ml_client.update_dataset.assert_called_with( { "name": "dataset", "tables_dataset_metadata": { @@ -748,9 +783,9 @@ def test_set_time_column(self): {}, ) client.set_time_column(dataset_name="name", column_spec_display_name="column") - client.client.list_table_specs.assert_called_with("name") - client.client.list_column_specs.assert_called_with("table") - client.client.update_table_spec.assert_called_with( + 
client.auto_ml_client.list_table_specs.assert_called_with("name") + client.auto_ml_client.list_column_specs.assert_called_with("table") + client.auto_ml_client.update_table_spec.assert_called_with( {"name": "table", "time_column_spec_id": "3"} ) @@ -768,10 +803,35 @@ def test_clear_time_column(self): {}, ) client.clear_time_column(dataset_name="name") - client.client.update_table_spec.assert_called_with( + client.auto_ml_client.update_table_spec.assert_called_with( {"name": "table", "time_column_spec_id": None} ) + def test_get_model_evaluation(self): + client = self.tables_client({}, {}) + ds = client.get_model_evaluation(model_evaluation_name="x") + client.auto_ml_client.get_model_evaluation.assert_called_with("x") + + def test_list_model_evaluations_empty(self): + client = self.tables_client({"list_model_evaluations.return_value": []}, {}) + ds = client.list_model_evaluations(model_name="model") + client.auto_ml_client.list_model_evaluations.assert_called_with("model") + assert ds == [] + + def test_list_model_evaluations_not_empty(self): + evaluations = ["eval"] + client = self.tables_client( + { + "list_model_evaluations.return_value": evaluations, + "location_path.return_value": LOCATION_PATH, + }, + {}, + ) + ds = client.list_model_evaluations(model_name="model") + client.auto_ml_client.list_model_evaluations.assert_called_with("model") + assert len(ds) == 1 + assert ds[0] == "eval" + def test_list_models_empty(self): client = self.tables_client( { @@ -781,8 +841,8 @@ def test_list_models_empty(self): {}, ) ds = client.list_models() - client.client.location_path.assert_called_with(PROJECT, REGION) - client.client.list_models.assert_called_with(LOCATION_PATH) + client.auto_ml_client.location_path.assert_called_with(PROJECT, REGION) + client.auto_ml_client.list_models.assert_called_with(LOCATION_PATH) assert ds == [] def test_list_models_not_empty(self): @@ -795,8 +855,8 @@ def test_list_models_not_empty(self): {}, ) ds = client.list_models() - client.client.location_path.assert_called_with(PROJECT, REGION) - client.client.list_models.assert_called_with(LOCATION_PATH) + client.auto_ml_client.location_path.assert_called_with(PROJECT, REGION) + client.auto_ml_client.list_models.assert_called_with(LOCATION_PATH) assert len(ds) == 1 assert ds[0] == "some_model" @@ -804,7 +864,7 @@ def test_get_model_name(self): model_actual = "model" client = self.tables_client({"get_model.return_value": model_actual}, {}) model = client.get_model(model_name="my_model") - client.client.get_model.assert_called_with("my_model") + client.auto_ml_client.get_model.assert_called_with("my_model") assert model == model_actual def test_get_no_model(self): @@ -817,7 +877,7 @@ def test_get_no_model(self): except exceptions.NotFound as e: error = e assert error is not None - client.client.get_model.assert_called_with("my_model") + client.auto_ml_client.get_model.assert_called_with("my_model") def test_get_model_from_empty_list(self): client = self.tables_client({"list_models.return_value": []}, {}) @@ -857,17 +917,17 @@ def test_delete_model(self): model.configure_mock(name="name") client = self.tables_client({"delete_model.return_value": None}, {}) client.delete_model(model=model) - client.client.delete_model.assert_called_with("name") + client.auto_ml_client.delete_model.assert_called_with("name") def test_delete_model_not_found(self): client = self.tables_client({"list_models.return_value": []}, {}) client.delete_model(model_display_name="not_found") - client.client.delete_model.assert_not_called() + 
client.auto_ml_client.delete_model.assert_not_called() def test_delete_model_name(self): client = self.tables_client({"delete_model.return_value": None}, {}) client.delete_model(model_name="name") - client.client.delete_model.assert_called_with("name") + client.auto_ml_client.delete_model.assert_called_with("name") def test_deploy_model_no_args(self): client = self.tables_client({}, {}) @@ -877,12 +937,12 @@ def test_deploy_model_no_args(self): except ValueError as e: error = e assert error is not None - client.client.deploy_model.assert_not_called() + client.auto_ml_client.deploy_model.assert_not_called() def test_deploy_model(self): client = self.tables_client({}, {}) client.deploy_model(model_name="name") - client.client.deploy_model.assert_called_with("name") + client.auto_ml_client.deploy_model.assert_called_with("name") def test_deploy_model_not_found(self): client = self.tables_client({"list_models.return_value": []}, {}) @@ -892,12 +952,12 @@ def test_deploy_model_not_found(self): except exceptions.NotFound as e: error = e assert error is not None - client.client.deploy_model.assert_not_called() + client.auto_ml_client.deploy_model.assert_not_called() def test_undeploy_model(self): client = self.tables_client({}, {}) client.undeploy_model(model_name="name") - client.client.undeploy_model.assert_called_with("name") + client.auto_ml_client.undeploy_model.assert_called_with("name") def test_undeploy_model_not_found(self): client = self.tables_client({"list_models.return_value": []}, {}) @@ -907,7 +967,7 @@ def test_undeploy_model_not_found(self): except exceptions.NotFound as e: error = e assert error is not None - client.client.undeploy_model.assert_not_called() + client.auto_ml_client.undeploy_model.assert_not_called() def test_create_model(self): table_spec_mock = mock.Mock() @@ -926,7 +986,7 @@ def test_create_model(self): client.create_model( "my_model", dataset_name="my_dataset", train_budget_milli_node_hours=1000 ) - client.client.create_model.assert_called_with( + client.auto_ml_client.create_model.assert_called_with( LOCATION_PATH, { "display_name": "my_model", @@ -960,7 +1020,7 @@ def test_create_model_include_columns(self): include_column_spec_names=["column1"], train_budget_milli_node_hours=1000, ) - client.client.create_model.assert_called_with( + client.auto_ml_client.create_model.assert_called_with( LOCATION_PATH, { "display_name": "my_model", @@ -997,7 +1057,7 @@ def test_create_model_exclude_columns(self): exclude_column_spec_names=["column1"], train_budget_milli_node_hours=1000, ) - client.client.create_model.assert_called_with( + client.auto_ml_client.create_model.assert_called_with( LOCATION_PATH, { "display_name": "my_model", @@ -1019,7 +1079,7 @@ def test_create_model_invalid_hours_small(self): except ValueError as e: error = e assert error is not None - client.client.create_model.assert_not_called() + client.auto_ml_client.create_model.assert_not_called() def test_create_model_invalid_hours_large(self): client = self.tables_client({}, {}) @@ -1033,7 +1093,7 @@ def test_create_model_invalid_hours_large(self): except ValueError as e: error = e assert error is not None - client.client.create_model.assert_not_called() + client.auto_ml_client.create_model.assert_not_called() def test_create_model_invalid_no_dataset(self): client = self.tables_client({}, {}) @@ -1043,8 +1103,8 @@ def test_create_model_invalid_no_dataset(self): except ValueError as e: error = e assert error is not None - client.client.get_dataset.assert_not_called() - 
client.client.create_model.assert_not_called()
+        client.auto_ml_client.get_dataset.assert_not_called()
+        client.auto_ml_client.create_model.assert_not_called()

     def test_create_model_invalid_include_exclude(self):
         client = self.tables_client({}, {})
@@ -1060,8 +1120,8 @@ def test_create_model_invalid_include_exclude(self):
         except ValueError as e:
             error = e
         assert error is not None
-        client.client.get_dataset.assert_not_called()
-        client.client.create_model.assert_not_called()
+        client.auto_ml_client.get_dataset.assert_not_called()
+        client.auto_ml_client.create_model.assert_not_called()

     def test_predict_from_array(self):
         data_type = mock.Mock(type_code=data_types_pb2.CATEGORY)
@@ -1247,5 +1307,5 @@ def test_batch_predict_no_model(self):
         except ValueError as e:
             error = e
         assert error is not None
-        client.client.list_models.assert_not_called()
+        client.auto_ml_client.list_models.assert_not_called()
         client.prediction_client.batch_predict.assert_not_called()

From 3dc7975ba8f5e6d1f3f5c6bc29dccb8df2ee1534 Mon Sep 17 00:00:00 2001
From: Lars Wander
Date: Tue, 23 Jul 2019 12:14:36 -0400
Subject: [PATCH 08/11] Support BQ as input/output in batch_predict

---
 .../automl_v1beta1/tables/tables_client.py    | 53 +++++++++++-----
 .../v1beta1/test_tables_client_v1beta1.py     | 60 ++++++++++++++++++-
 2 files changed, 96 insertions(+), 17 deletions(-)

diff --git a/automl/google/cloud/automl_v1beta1/tables/tables_client.py b/automl/google/cloud/automl_v1beta1/tables/tables_client.py
index 09a8e633f655..303818b1536a 100644
--- a/automl/google/cloud/automl_v1beta1/tables/tables_client.py
+++ b/automl/google/cloud/automl_v1beta1/tables/tables_client.py
@@ -1999,6 +1999,7 @@ def create_model(
         dataset_display_name=None,
         dataset_name=None,
         train_budget_milli_node_hours=None,
+        optimization_objective=None,
         project=None,
         region=None,
         model_metadata={},
@@ -2037,6 +2038,10 @@ def create_model(
                 The amount of time (in thousandths of an hour) to spend
                 training. This value must be between 1,000 and 72,000 inclusive
                 (between 1 and 72 hours).
+            optimization_objective (Optional[string]):
+                The metric AutoML Tables should optimize the model for;
+                for example, 'MINIMIZE_RMSE' (regression) or
+                'MAXIMIZE_AU_ROC' (classification).
             dataset_display_name (Optional[string]):
                 The human-readable name given to the dataset you want to train
                 your model on. This must be supplied if `dataset` or
@@ -2097,6 +2100,8 @@ def create_model(
         )

         model_metadata["train_budget_milli_node_hours"] = train_budget_milli_node_hours
+        if optimization_objective is not None:
+            model_metadata["optimization_objective"] = optimization_objective

         dataset_id = dataset_name.rsplit("/", 1)[-1]
         columns = [
@@ -2564,8 +2569,10 @@ def predict(

     def batch_predict(
         self,
-        gcs_input_uris,
-        gcs_output_uri_prefix,
+        bigquery_input_uri=None,
+        bigquery_output_uri=None,
+        gcs_input_uris=None,
+        gcs_output_uri_prefix=None,
         model=None,
         model_name=None,
         model_display_name=None,
@@ -2602,11 +2609,15 @@ def batch_predict(
             region (Optional[string]):
                 If you have initialized the client with a value for `region`
                 it will be used if this parameter is not supplied.
-            gcs_input_uris (Union[List[string], string])
+            gcs_input_uris (Optional[Union[List[string], string]]):
                 Either a list of or a single GCS URI containing the data you
                 want to predict off of.
-            gcs_output_uri_prefix (string)
+            gcs_output_uri_prefix (Optional[string]):
                 The folder in GCS you want to write output to.
+            bigquery_input_uri (Optional[string]):
+                The BigQuery table to read input data from.
+            bigquery_output_uri (Optional[string]):
+                The BigQuery table to write output data to.
             model_display_name (Optional[string]):
                 The human-readable name given to the model you want to
                 predict with. This must be supplied if `model` or
                 `model_name` are not
@@ -2631,11 +2642,6 @@ def batch_predict(
                 to a retryable error and retry attempts failed.
             ValueError: If required parameters are missing.
         """
-        if gcs_input_uris is None or gcs_output_uri_prefix is None:
-            raise ValueError(
-                "Both 'gcs_input_uris' and " "'gcs_output_uri_prefix' must be set."
-            )
-
         model_name = self.__model_name_from_args(
             model=model,
             model_name=model_name,
@@ -2645,12 +2651,31 @@ def batch_predict(
             **kwargs
         )

-        if type(gcs_input_uris) != list:
-            gcs_input_uris = [gcs_input_uris]
-
-        input_request = {"gcs_source": {"input_uris": gcs_input_uris}}
+        input_request = None
+        if gcs_input_uris is not None:
+            if not isinstance(gcs_input_uris, list):
+                gcs_input_uris = [gcs_input_uris]
+            input_request = {"gcs_source": {"input_uris": gcs_input_uris}}
+        elif bigquery_input_uri is not None:
+            input_request = {"bigquery_source": {"input_uri": bigquery_input_uri}}
+        else:
+            raise ValueError(
+                "One of 'gcs_input_uris'/'bigquery_input_uri' must be set"
+            )

-        output_request = {"gcs_source": {"output_uri_prefix": gcs_output_uri_prefix}}
+        output_request = None
+        if gcs_output_uri_prefix is not None:
+            output_request = {
+                "gcs_destination": {"output_uri_prefix": gcs_output_uri_prefix}
+            }
+        elif bigquery_output_uri is not None:
+            output_request = {
+                "bigquery_destination": {"output_uri": bigquery_output_uri}
+            }
+        else:
+            raise ValueError(
+                "One of 'gcs_output_uri_prefix'/'bigquery_output_uri' must be set"
+            )

         return self.prediction_client.batch_predict(
             model_name, input_request, output_request, **kwargs
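With the change above, a batch prediction can read from and write to BigQuery instead of GCS. As a rough usage sketch (the model display name and the `bq://` URIs are placeholders, mirroring the values used in the unit tests that follow):

    # Hypothetical call against the new BigQuery-backed batch_predict;
    # "my_model" and the bq:// URIs are illustrative placeholders.
    client.batch_predict(
        model_display_name="my_model",
        bigquery_input_uri="bq://input",
        bigquery_output_uri="bq://output",
    )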
diff --git a/automl/tests/unit/gapic/v1beta1/test_tables_client_v1beta1.py b/automl/tests/unit/gapic/v1beta1/test_tables_client_v1beta1.py
index 1bf7478e0343..f8192baeaf5f 100644
--- a/automl/tests/unit/gapic/v1beta1/test_tables_client_v1beta1.py
+++ b/automl/tests/unit/gapic/v1beta1/test_tables_client_v1beta1.py
@@ -1242,7 +1242,7 @@ def test_predict_from_array_missing(self):
         assert error is not None
         client.prediction_client.predict.assert_not_called()

-    def test_batch_predict(self):
+    def test_batch_predict_gcs(self):
         client = self.tables_client({}, {})
         client.batch_predict(
             model_name="my_model",
@@ -1252,7 +1252,33 @@ def test_batch_predict(self):
         client.prediction_client.batch_predict.assert_called_with(
             "my_model",
             {"gcs_source": {"input_uris": ["gs://input"]}},
-            {"gcs_source": {"output_uri_prefix": "gs://output"}},
+            {"gcs_destination": {"output_uri_prefix": "gs://output"}},
         )

+    def test_batch_predict_bigquery(self):
+        client = self.tables_client({}, {})
+        client.batch_predict(
+            model_name="my_model",
+            bigquery_input_uri="bq://input",
+            bigquery_output_uri="bq://output",
+        )
+        client.prediction_client.batch_predict.assert_called_with(
+            "my_model",
+            {"bigquery_source": {"input_uri": "bq://input"}},
+            {"bigquery_destination": {"output_uri": "bq://output"}},
+        )
+
+    def test_batch_predict_mixed(self):
+        client = self.tables_client({}, {})
+        client.batch_predict(
+            model_name="my_model",
+            gcs_input_uris="gs://input",
+            bigquery_output_uri="bq://output",
+        )
+        client.prediction_client.batch_predict.assert_called_with(
+            "my_model",
+            {"gcs_source": {"input_uris": ["gs://input"]}},
+            {"bigquery_destination": {"output_uri": "bq://output"}},
+        )
+
     def test_batch_predict_missing_input_gcs_uri(self):
         client = self.tables_client({}, {})
         error = None
@@ -1269,7 +1295,21 @@ def test_batch_predict_missing_input_gcs_uri(self):
         assert error is not None
         client.prediction_client.batch_predict.assert_not_called()

-    def test_batch_predict_missing_input_gcs_uri(self):
+    def test_batch_predict_missing_input_bigquery_uri(self):
+        client = self.tables_client({}, {})
+        error = None
+        try:
+            client.batch_predict(
+                model_name="my_model",
+                bigquery_input_uri=None,
+                gcs_output_uri_prefix="gs://output",
+            )
+        except ValueError as e:
+            error = e
+        assert error is not None
+        client.prediction_client.batch_predict.assert_not_called()
+
+    def test_batch_predict_missing_output_gcs_uri(self):
         client = self.tables_client({}, {})
         error = None
         try:
@@ -1283,6 +1323,20 @@ def test_batch_predict_missing_input_gcs_uri(self):
         assert error is not None
         client.prediction_client.batch_predict.assert_not_called()

+    def test_batch_predict_missing_output_bigquery_uri(self):
+        client = self.tables_client({}, {})
+        error = None
+        try:
+            client.batch_predict(
+                model_name="my_model",
+                gcs_input_uris="gs://input",
+                bigquery_output_uri=None,
+            )
+        except ValueError as e:
+            error = e
+        assert error is not None
+        client.prediction_client.batch_predict.assert_not_called()
+
     def test_batch_predict_missing_model(self):
         client = self.tables_client({"list_models.return_value": []}, {})
         error = None
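This patch also threads `optimization_objective` through `create_model`. A minimal sketch of how a caller might use it (the model and dataset names are placeholders; 'MINIMIZE_RMSE' is one objective the Tables API accepts for regression models):

    # Hypothetical call; the names and objective value are illustrative.
    client.create_model(
        "my_model",
        dataset_display_name="my_dataset",
        train_budget_milli_node_hours=1000,
        optimization_objective="MINIMIZE_RMSE",
    )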
From b564d962bdd6f0de77b8fe74f33293a6617fa901 Mon Sep 17 00:00:00 2001
From: Lars Wander
Date: Thu, 1 Aug 2019 11:08:13 -0400
Subject: [PATCH 09/11] Address first round of feedback

---
 automl/README.rst                              |  4 ++--
 automl/google/cloud/automl_v1beta1/__init__.py |  8 ++++----
 automl/synth.py                                | 16 ++++++++++++++++
 3 files changed, 22 insertions(+), 6 deletions(-)

diff --git a/automl/README.rst b/automl/README.rst
index 994cf397ef55..ac7b6d53c77d 100644
--- a/automl/README.rst
+++ b/automl/README.rst
@@ -117,7 +117,7 @@ development environment:
 .. code-block:: console

     pip install -r ../docs/requirements.txt
-    pip install -U mock pytest
+    pip install -U nox mock pytest

 3. If you want to run all tests, you will need a billing-enabled
    `GCP project`_, and a `service account`_ with access to the AutoML
    APIs.
@@ -131,5 +131,5 @@ development environment:
 .. code-block:: console

     export PROJECT_ID= GOOGLE_APPLICATION_CREDENTIALS=
-    pytest
+    nox
diff --git a/automl/google/cloud/automl_v1beta1/__init__.py b/automl/google/cloud/automl_v1beta1/__init__.py
index 20055ef19a50..474b05550c81 100644
--- a/automl/google/cloud/automl_v1beta1/__init__.py
+++ b/automl/google/cloud/automl_v1beta1/__init__.py
@@ -24,6 +24,10 @@
 from google.cloud.automl_v1beta1.tables import tables_client


+class TablesClient(tables_client.TablesClient):
+    __doc__ = tables_client.TablesClient.__doc__
+
+
 class AutoMlClient(auto_ml_client.AutoMlClient):
     __doc__ = auto_ml_client.AutoMlClient.__doc__
     enums = enums
@@ -34,8 +38,4 @@ class PredictionServiceClient(prediction_service_client.PredictionServiceClient)
     enums = enums


-class TablesClient(tables_client.TablesClient):
-    __doc__ = tables_client.TablesClient.__doc__
-
-
 __all__ = ("enums", "types", "AutoMlClient", "PredictionServiceClient", "TablesClient")
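With `TablesClient` now exported from the package root (and the synth.py change below re-applying that export whenever the client is regenerated), callers no longer need to reach into the `tables` submodule. A small sketch with a placeholder project:

    from google.cloud import automl_v1beta1

    # "my-project" is a placeholder; the region matches the one used in the tests.
    client = automl_v1beta1.TablesClient(project="my-project", region="us-central1")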
diff --git a/automl/synth.py b/automl/synth.py
index 4318ff31fc2b..ab93031fc583 100644
--- a/automl/synth.py
+++ b/automl/synth.py
@@ -33,6 +33,22 @@
     s.move(library / f"tests/unit/gapic/{version}")
     s.move(library / f"docs/gapic/{version}")

+    s.replace(
+        f"google/cloud/automl_{version}/__init__.py",
+        f"from google.cloud.automl_v1beta1.gapic import prediction_service_client",
+        f"from google.cloud.automl_v1beta1.gapic import prediction_service_client\n"
+        f"from google.cloud.automl_v1beta1.tables import tables_client"
+        f"\n\n"
+        f"class TablesClient(tables_client.TablesClient):\n"
+        f"    __doc__ = tables_client.TablesClient.__doc__"
+    )
+
+    s.replace(
+        f"google/cloud/automl_{version}/__init__.py",
+        f"__all__ = (\"enums\", \"types\", \"AutoMlClient\", \"PredictionServiceClient\")",
+        f"__all__ = (\"enums\", \"types\", \"AutoMlClient\", \"PredictionServiceClient\", \"TablesClient\")"
+    )
+
     s.move(library / f"docs/conf.py")

     # Use the highest version library to generate import alias.

From 4c85d7481ec20cb5419ce23923f633332bab4485 Mon Sep 17 00:00:00 2001
From: Lars Wander
Date: Thu, 1 Aug 2019 13:23:18 -0400
Subject: [PATCH 10/11] Switch to pytest.raises, fix .rst formatting exception

---
 automl/README.rst                             |   6 +-
 .../v1beta1/test_tables_client_v1beta1.py     | 186 +++--------------
 2 files changed, 34 insertions(+), 158 deletions(-)

diff --git a/automl/README.rst b/automl/README.rst
index ac7b6d53c77d..b3332b112dc4 100644
--- a/automl/README.rst
+++ b/automl/README.rst
@@ -114,10 +114,10 @@ development environment:
 1. Make sure you have `virtualenv`_ installed and activated as shown above.
 2. Run the following one-time setup (it will be persisted in your
    virtualenv):
-.. code-block:: console
+   .. code-block:: console

-    pip install -r ../docs/requirements.txt
-    pip install -U nox mock pytest
+       pip install -r ../docs/requirements.txt
+       pip install -U nox mock pytest

 3. If you want to run all tests, you will need a billing-enabled
    `GCP project`_, and a `service account`_ with access to the AutoML
    APIs.
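The test-file diff that follows applies one mechanical transformation throughout: the manual bookkeeping around expected exceptions is replaced with `pytest.raises`, which fails the test by itself if nothing is raised. Schematically:

    # before: track the exception by hand
    error = None
    try:
        client.deploy_model()
    except ValueError as e:
        error = e
    assert error is not None

    # after: let pytest assert that the exception occurred
    with pytest.raises(ValueError):
        client.deploy_model()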
diff --git a/automl/tests/unit/gapic/v1beta1/test_tables_client_v1beta1.py b/automl/tests/unit/gapic/v1beta1/test_tables_client_v1beta1.py index f8192baeaf5f..5176a31cce6d 100644 --- a/automl/tests/unit/gapic/v1beta1/test_tables_client_v1beta1.py +++ b/automl/tests/unit/gapic/v1beta1/test_tables_client_v1beta1.py @@ -70,12 +70,8 @@ def test_list_datasets_not_empty(self): def test_get_dataset_no_value(self): dataset_actual = "dataset" client = self.tables_client({}, {}) - error = None - try: + with pytest.raises(ValueError): dataset = client.get_dataset() - except ValueError as e: - error = e - assert error is not None client.auto_ml_client.get_dataset.assert_not_called() def test_get_dataset_name(self): @@ -89,33 +85,21 @@ def test_get_no_dataset(self): client = self.tables_client( {"get_dataset.side_effect": exceptions.NotFound("err")}, {} ) - error = None - try: + with pytest.raises(exceptions.NotFound): client.get_dataset(dataset_name="my_dataset") - except exceptions.NotFound as e: - error = e - assert error is not None client.auto_ml_client.get_dataset.assert_called_with("my_dataset") def test_get_dataset_from_empty_list(self): client = self.tables_client({"list_datasets.return_value": []}, {}) - error = None - try: + with pytest.raises(exceptions.NotFound): client.get_dataset(dataset_display_name="my_dataset") - except exceptions.NotFound as e: - error = e - assert error is not None def test_get_dataset_from_list_not_found(self): client = self.tables_client( {"list_datasets.return_value": [mock.Mock(display_name="not_it")]}, {} ) - error = None - try: + with pytest.raises(exceptions.NotFound): client.get_dataset(dataset_display_name="my_dataset") - except exceptions.NotFound as e: - error = e - assert error is not None def test_get_dataset_from_list(self): client = self.tables_client( @@ -165,12 +149,8 @@ def test_delete_dataset_name(self): def test_export_not_found(self): client = self.tables_client({"list_datasets.return_value": []}, {}) - error = None - try: + with pytest.raises(exceptions.NotFound): client.export_data(dataset_display_name="name", gcs_input_uris="uri") - except exceptions.NotFound as e: - error = e - assert error is not None client.auto_ml_client.export_data.assert_not_called() @@ -190,12 +170,8 @@ def test_export_bq_uri(self): def test_import_not_found(self): client = self.tables_client({"list_datasets.return_value": []}, {}) - error = None - try: + with pytest.raises(exceptions.NotFound): client.import_data(dataset_display_name="name", gcs_input_uris="uri") - except exceptions.NotFound as e: - error = e - assert error is not None client.auto_ml_client.import_data.assert_not_called() @@ -229,12 +205,8 @@ def test_list_table_specs_not_found(self): client = self.tables_client( {"list_table_specs.side_effect": exceptions.NotFound("not found")}, {} ) - error = None - try: + with pytest.raises(exceptions.NotFound): client.list_table_specs(dataset_name="name") - except exceptions.NotFound as e: - error = e - assert error is not None client.auto_ml_client.list_table_specs.assert_called_with("name") def test_get_table_spec(self): @@ -278,12 +250,8 @@ def test_update_column_spec_not_found(self): }, {}, ) - error = None - try: + with pytest.raises(exceptions.NotFound): client.update_column_spec(dataset_name="name", column_spec_name="column2") - except exceptions.NotFound as e: - error = e - assert error is not None client.auto_ml_client.list_table_specs.assert_called_with("name") client.auto_ml_client.list_column_specs.assert_called_with("table") 
client.auto_ml_client.update_column_spec.assert_not_called() @@ -304,14 +272,10 @@ def test_update_column_spec_display_name_not_found(self): }, {}, ) - error = None - try: + with pytest.raises(exceptions.NotFound): client.update_column_spec( dataset_name="name", column_spec_display_name="column2" ) - except exceptions.NotFound as e: - error = e - assert error is not None client.auto_ml_client.list_table_specs.assert_called_with("name") client.auto_ml_client.list_column_specs.assert_called_with("table") client.auto_ml_client.update_column_spec.assert_not_called() @@ -485,14 +449,10 @@ def test_set_target_column_table_not_found(self): client = self.tables_client( {"list_table_specs.side_effect": exceptions.NotFound("err")}, {} ) - error = None - try: + with pytest.raises(exceptions.NotFound): client.set_target_column( dataset_name="name", column_spec_display_name="column2" ) - except exceptions.NotFound as e: - error = e - assert error is not None client.auto_ml_client.list_table_specs.assert_called_with("name") client.auto_ml_client.list_column_specs.assert_not_called() client.auto_ml_client.update_dataset.assert_not_called() @@ -510,14 +470,10 @@ def test_set_target_column_not_found(self): }, {}, ) - error = None - try: + with pytest.raises(exceptions.NotFound): client.set_target_column( dataset_name="name", column_spec_display_name="column2" ) - except exceptions.NotFound as e: - error = e - assert error is not None client.auto_ml_client.list_table_specs.assert_called_with("name") client.auto_ml_client.list_column_specs.assert_called_with("table") client.auto_ml_client.update_dataset.assert_not_called() @@ -587,14 +543,10 @@ def test_set_weight_column_not_found(self): }, {}, ) - error = None - try: + with pytest.raises(exceptions.NotFound): client.set_weight_column( dataset_name="name", column_spec_display_name="column2" ) - except exceptions.NotFound as e: - error = e - assert error is not None client.auto_ml_client.list_table_specs.assert_called_with("name") client.auto_ml_client.list_column_specs.assert_called_with("table") client.auto_ml_client.update_dataset.assert_not_called() @@ -665,14 +617,10 @@ def test_set_test_train_column_table_not_found(self): client = self.tables_client( {"list_table_specs.side_effect": exceptions.NotFound("err")}, {} ) - error = None - try: + with pytest.raises(exceptions.NotFound): client.set_test_train_column( dataset_name="name", column_spec_display_name="column2" ) - except exceptions.NotFound as e: - error = e - assert error is not None client.auto_ml_client.list_table_specs.assert_called_with("name") client.auto_ml_client.list_column_specs.assert_not_called() client.auto_ml_client.update_dataset.assert_not_called() @@ -690,14 +638,10 @@ def test_set_test_train_column_not_found(self): }, {}, ) - error = None - try: + with pytest.raises(exceptions.NotFound): client.set_test_train_column( dataset_name="name", column_spec_display_name="column2" ) - except exceptions.NotFound as e: - error = e - assert error is not None client.auto_ml_client.list_table_specs.assert_called_with("name") client.auto_ml_client.list_column_specs.assert_called_with("table") client.auto_ml_client.update_dataset.assert_not_called() @@ -871,33 +815,21 @@ def test_get_no_model(self): client = self.tables_client( {"get_model.side_effect": exceptions.NotFound("err")}, {} ) - error = None - try: + with pytest.raises(exceptions.NotFound): client.get_model(model_name="my_model") - except exceptions.NotFound as e: - error = e - assert error is not None 
client.auto_ml_client.get_model.assert_called_with("my_model") def test_get_model_from_empty_list(self): client = self.tables_client({"list_models.return_value": []}, {}) - error = None - try: + with pytest.raises(exceptions.NotFound): client.get_model(model_display_name="my_model") - except exceptions.NotFound as e: - error = e - assert error is not None def test_get_model_from_list_not_found(self): client = self.tables_client( {"list_models.return_value": [mock.Mock(display_name="not_it")]}, {} ) - error = None - try: + with pytest.raises(exceptions.NotFound): client.get_model(model_display_name="my_model") - except exceptions.NotFound as e: - error = e - assert error is not None def test_get_model_from_list(self): client = self.tables_client( @@ -931,12 +863,8 @@ def test_delete_model_name(self): def test_deploy_model_no_args(self): client = self.tables_client({}, {}) - error = None - try: + with pytest.raises(ValueError): client.deploy_model() - except ValueError as e: - error = e - assert error is not None client.auto_ml_client.deploy_model.assert_not_called() def test_deploy_model(self): @@ -946,12 +874,8 @@ def test_deploy_model(self): def test_deploy_model_not_found(self): client = self.tables_client({"list_models.return_value": []}, {}) - error = None - try: + with pytest.raises(exceptions.NotFound): client.deploy_model(model_display_name="name") - except exceptions.NotFound as e: - error = e - assert error is not None client.auto_ml_client.deploy_model.assert_not_called() def test_undeploy_model(self): @@ -961,12 +885,8 @@ def test_undeploy_model(self): def test_undeploy_model_not_found(self): client = self.tables_client({"list_models.return_value": []}, {}) - error = None - try: + with pytest.raises(exceptions.NotFound): client.undeploy_model(model_display_name="name") - except exceptions.NotFound as e: - error = e - assert error is not None client.auto_ml_client.undeploy_model.assert_not_called() def test_create_model(self): @@ -1071,45 +991,32 @@ def test_create_model_exclude_columns(self): def test_create_model_invalid_hours_small(self): client = self.tables_client({}, {}) - error = None - try: + with pytest.raises(ValueError): client.create_model( "my_model", dataset_name="my_dataset", train_budget_milli_node_hours=1 ) - except ValueError as e: - error = e - assert error is not None client.auto_ml_client.create_model.assert_not_called() def test_create_model_invalid_hours_large(self): client = self.tables_client({}, {}) - error = None - try: + with pytest.raises(ValueError): client.create_model( "my_model", dataset_name="my_dataset", train_budget_milli_node_hours=1000000, ) - except ValueError as e: - error = e - assert error is not None client.auto_ml_client.create_model.assert_not_called() def test_create_model_invalid_no_dataset(self): client = self.tables_client({}, {}) - error = None - try: + with pytest.raises(ValueError): client.create_model("my_model", train_budget_milli_node_hours=1000) - except ValueError as e: - error = e - assert error is not None client.auto_ml_client.get_dataset.assert_not_called() client.auto_ml_client.create_model.assert_not_called() def test_create_model_invalid_include_exclude(self): client = self.tables_client({}, {}) - error = None - try: + with pytest.raises(ValueError): client.create_model( "my_model", dataset_name="my_dataset", @@ -1117,9 +1024,6 @@ def test_create_model_invalid_include_exclude(self): exclude_column_spec_names=["b"], train_budget_milli_node_hours=1000, ) - except ValueError as e: - error = e - assert error is not None 
client.auto_ml_client.get_dataset.assert_not_called() client.auto_ml_client.create_model.assert_not_called() @@ -1234,12 +1138,8 @@ def test_predict_from_array_missing(self): model = mock.Mock() model.configure_mock(tables_model_metadata=model_metadata, name="my_model") client = self.tables_client({"get_model.return_value": model}, {}) - error = None - try: + with pytest.raises(ValueError): client.predict([], model_name="my_model") - except ValueError as e: - error = e - assert error is not None client.prediction_client.predict.assert_not_called() def test_batch_predict_gcs(self): @@ -1283,83 +1183,59 @@ def test_batch_predict_mixed(self): def test_batch_predict_missing_input_gcs_uri(self): client = self.tables_client({}, {}) - error = None - try: + with pytest.raises(ValueError): client.batch_predict( model_name="my_model", gcs_input_uris=None, gcs_output_uri_prefix="gs://output", ) - except ValueError as e: - error = e - assert error is not None client.prediction_client.batch_predict.assert_not_called() def test_batch_predict_missing_input_bigquery_uri(self): client = self.tables_client({}, {}) - error = None - try: + with pytest.raises(ValueError): client.batch_predict( model_name="my_model", bigquery_input_uri=None, gcs_output_uri_prefix="gs://output", ) - except ValueError as e: - error = e - assert error is not None client.prediction_client.batch_predict.assert_not_called() def test_batch_predict_missing_output_gcs_uri(self): client = self.tables_client({}, {}) - error = None - try: + with pytest.raises(ValueError): client.batch_predict( model_name="my_model", gcs_input_uris="gs://input", gcs_output_uri_prefix=None, ) - except ValueError as e: - error = e - assert error is not None client.prediction_client.batch_predict.assert_not_called() def test_batch_predict_missing_output_bigquery_uri(self): client = self.tables_client({}, {}) - error = None - try: + with pytest.raises(ValueError): client.batch_predict( model_name="my_model", gcs_input_uris="gs://input", bigquery_output_uri=None, ) - except ValueError as e: - error = e - assert error is not None client.prediction_client.batch_predict.assert_not_called() def test_batch_predict_missing_model(self): client = self.tables_client({"list_models.return_value": []}, {}) - error = None - try: + with pytest.raises(exceptions.NotFound): client.batch_predict( model_display_name="my_model", gcs_input_uris="gs://input", gcs_output_uri_prefix="gs://output", ) - except exceptions.NotFound as e: - error = e - assert error is not None client.prediction_client.batch_predict.assert_not_called() def test_batch_predict_no_model(self): client = self.tables_client({}, {}) - error = None - try: + with pytest.raises(ValueError): client.batch_predict( gcs_input_uris="gs://input", gcs_output_uri_prefix="gs://output" ) - except ValueError as e: - error = e - assert error is not None client.auto_ml_client.list_models.assert_not_called() client.prediction_client.batch_predict.assert_not_called() From 95871501e03cc4ddfa129bd22ae1f113a83e5b1c Mon Sep 17 00:00:00 2001 From: Lars Wander Date: Thu, 8 Aug 2019 14:30:03 -0400 Subject: [PATCH 11/11] Make list system tests more stringent --- .../v1beta1/test_system_tables_client_v1.py | 52 +++++++++++++------ 1 file changed, 36 insertions(+), 16 deletions(-) diff --git a/automl/tests/system/gapic/v1beta1/test_system_tables_client_v1.py b/automl/tests/system/gapic/v1beta1/test_system_tables_client_v1.py index fdbae6381f6d..2a763cdc24dc 100644 --- a/automl/tests/system/gapic/v1beta1/test_system_tables_client_v1.py +++ 
b/automl/tests/system/gapic/v1beta1/test_system_tables_client_v1.py @@ -57,13 +57,19 @@ def cancel_and_wait(self, op): def test_list_datasets(self): client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION) - # we need to unroll the iterator to actually make client calls - [d for d in client.list_datasets(timeout=10)] + dataset = self.ensure_dataset_ready(client) + # will raise if not found + next( + iter( + [d for d in client.list_datasets(timeout=10) if d.name == dataset.name] + ) + ) def test_list_models(self): client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION) - # we need to unroll the iterator to actually make client calls - [m for m in client.list_models()] + model = self.ensure_model_ready(client) + # will raise if not found + next(iter([m for m in client.list_models(timeout=10) if m.name == model.name])) def test_create_delete_dataset(self): client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION) @@ -106,22 +112,28 @@ def ensure_dataset_ready(self, client): def test_list_column_specs(self): client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION) dataset = self.ensure_dataset_ready(client) - # we need to unroll the iterator to actually make client calls - [d for d in client.list_column_specs(dataset=dataset)] + # will raise if not found + next( + iter( + [ + d + for d in client.list_column_specs(dataset=dataset) + if d.display_name == "Deposit" + ] + ) + ) def test_get_column_spec(self): client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION) dataset = self.ensure_dataset_ready(client) - # we need to unroll the iterator to actually make client calls - cs = [d for d in client.list_column_specs(dataset=dataset)] - client.get_column_spec(cs[0].name) + name = [d for d in client.list_column_specs(dataset=dataset)][0].name + assert client.get_column_spec(name).name == name def test_list_table_specs(self): client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION) dataset = self.ensure_dataset_ready(client) - # we need to unroll the iterator to actually make client calls - ts = [d for d in client.list_table_specs(dataset=dataset)] - client.get_table_spec(ts[0].name) + name = [d for d in client.list_table_specs(dataset=dataset)][0].name + assert client.get_table_spec(name).name == name def test_set_column_nullable(self): client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION) @@ -183,14 +195,22 @@ def test_create_delete_model(self): def test_list_model_evaluations(self): client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION) model = self.ensure_model_online(client) - # we need to unroll the iterator to actually make client calls - [m for m in client.list_model_evaluations(model=model)] + # will raise if not found + next( + iter( + [ + m + for m in client.list_model_evaluations(model=model) + if m.display_name is not None + ] + ) + ) def test_get_model_evaluation(self): client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION) model = self.ensure_model_online(client) - me = [m for m in client.list_model_evaluations(model=model)] - client.get_model_evaluation(model_evaluation_name=me[0].name) + name = [m for m in client.list_model_evaluations(model=model)][0].name + assert client.get_model_evaluation(model_evaluation_name=name).name == name def test_online_predict(self): client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION)