diff --git a/airflow/gcp/hooks/bigquery.py b/airflow/gcp/hooks/bigquery.py index f42a875b55e1c..8cb38a077500e 100644 --- a/airflow/gcp/hooks/bigquery.py +++ b/airflow/gcp/hooks/bigquery.py @@ -1798,7 +1798,8 @@ def get_dataset(self, dataset_id: str, project_id: Optional[str] = None) -> Dict return dataset_resource - def get_datasets_list(self, project_id: Optional[str] = None) -> List: + @CloudBaseHook.catch_http_exception + def get_datasets_list(self, project_id=None, max_results=None, all_datasets=False): """ Method returns full list of BigQuery datasets in the current project @@ -1806,9 +1807,13 @@ def get_datasets_list(self, project_id: Optional[str] = None) -> List: For more information, see: https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets/list - :param project_id: Google Cloud Project for which you - try to get all datasets + :param project_id: Project ID of the datasets to be listed :type project_id: str + :param max_results: The maximum number of results to return in a single response page. + :type max_results: int + :param all_datasets: Whether to list all datasets, including hidden ones + :type all_datasets: bool + :return: datasets_list Example of returned datasets_list: :: @@ -1833,16 +1838,26 @@ def get_datasets_list(self, project_id: Optional[str] = None) -> List: } ] """ + dataset_project_id = project_id if project_id else self.project_id - try: - datasets_list = self.service.datasets().list( - projectId=dataset_project_id).execute(num_retries=self.num_retries)['datasets'] - self.log.info("Datasets List: %s", datasets_list) + optional_params = {'all': all_datasets} + if max_results: + optional_params['maxResults'] = max_results - except HttpError as err: - raise AirflowException( - 'BigQuery job failed. Error was: {}'.format(err.content)) + request = self.service.datasets().list( + projectId=dataset_project_id, + **optional_params) + + datasets_list = [] + + while request is not None: + response = request.execute(num_retries=self.num_retries) + datasets_list.extend(response['datasets']) + request = self.service.datasets().list_next(previous_request=request, + previous_response=response) + + self.log.info("%s items found", len(datasets_list)) return datasets_list diff --git a/tests/gcp/hooks/test_bigquery.py b/tests/gcp/hooks/test_bigquery.py index 0e5cb74d75120..8d143859c1cae 100644 --- a/tests/gcp/hooks/test_bigquery.py +++ b/tests/gcp/hooks/test_bigquery.py @@ -776,7 +776,7 @@ def test_get_dataset(self): self.assertEqual(result, expected_result) def test_get_datasets_list(self): - expected_result = {'datasets': [ + expected_result = [ { "kind": "bigquery#dataset", "location": "US", @@ -795,18 +795,48 @@ def test_get_datasets_list(self): "datasetId": "dataset_1_test" } } - ]} - project_id = "project_test"'' + ] + project_id = "project_test" - mocked = mock.Mock() - with mock.patch.object(hook.BigQueryBaseCursor(mocked, project_id).service, - 'datasets') as MockService: - MockService.return_value.list( - projectId=project_id).execute.return_value = expected_result - result = hook.BigQueryBaseCursor( - mocked, "test_create_empty_dataset").get_datasets_list( - project_id=project_id) - self.assertEqual(result, expected_result['datasets']) + mock_service = mock.Mock() + cursor = hook.BigQueryBaseCursor(mock_service, project_id) + mock_service.datasets.return_value.list.return_value.execute.return_value = { + 'datasets': expected_result} + mock_service.datasets.return_value.list_next.return_value = None + + result = cursor.get_datasets_list(project_id=project_id) + + self.assertEqual(result, expected_result) + + def test_get_datasets_list_multiple_pages(self): + dataset = { + "kind": "bigquery#dataset", + "location": "US", + "id": "your-project:dataset_2_test", + "datasetReference": { + "projectId": "your-project", + "datasetId": "dataset_2_test" + } + } + expected_result = [dataset] * 4 + project_id = "project_test" + + pages_requests = [ + mock.Mock(**{'execute.return_value': {"datasets": [dataset]}}) + for _ in range(4) + ] + datasets_mock = mock.Mock( + **{'list.return_value': pages_requests[1], + 'list_next.side_effect': pages_requests[1:] + [None]} + ) + + mock_service = mock.Mock() + cursor = hook.BigQueryBaseCursor(mock_service, project_id) + mock_service.datasets.return_value = datasets_mock + + result = cursor.get_datasets_list(project_id=project_id) + + self.assertEqual(result, expected_result) def test_get_dataset_tables_list(self): tables_list_result = [