Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 25 additions & 10 deletions airflow/gcp/hooks/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -1798,17 +1798,22 @@ def get_dataset(self, dataset_id: str, project_id: Optional[str] = None) -> Dict

return dataset_resource

@CloudBaseHook.catch_http_exception
def get_datasets_list(
    self,
    project_id: Optional[str] = None,
    max_results: Optional[int] = None,
    all_datasets: bool = False,
) -> List:
    """
    Return the full list of BigQuery datasets in the given project.

    Results are paginated transparently: every page returned by the API is
    fetched via ``list_next`` and accumulated into a single list.

    .. seealso::
        For more information, see:
        https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets/list

    :param project_id: Project ID of the datasets to be listed. Falls back
        to the hook's default ``self.project_id`` when not given.
    :type project_id: str
    :param max_results: The maximum number of results to return in a single
        response page. The total number of results is not capped — this only
        controls page size.
    :type max_results: int
    :param all_datasets: Whether to list all datasets, including hidden ones.
    :type all_datasets: bool
    :return: list of dataset resources, e.g. ::

        [
            {
                "kind": "bigquery#dataset",
                "location": "US",
                "id": "your-project:dataset_2_test",
                "datasetReference": {
                    "projectId": "your-project",
                    "datasetId": "dataset_2_test"
                }
            },
            ...
        ]
    """
    dataset_project_id = project_id if project_id else self.project_id

    # 'all' is always forwarded; 'maxResults' only when explicitly set,
    # so the API default page size is preserved otherwise.
    optional_params = {'all': all_datasets}  # type: Dict[str, Any]
    if max_results:
        optional_params['maxResults'] = max_results

    request = self.service.datasets().list(
        projectId=dataset_project_id,
        **optional_params)

    datasets_list = []  # type: List
    while request is not None:
        response = request.execute(num_retries=self.num_retries)
        # The API omits the 'datasets' key entirely when a page is empty
        # (e.g. a project with no datasets) — .get avoids a KeyError.
        datasets_list.extend(response.get('datasets', []))
        request = self.service.datasets().list_next(
            previous_request=request,
            previous_response=response)

    self.log.info("%s items found", len(datasets_list))

    return datasets_list

Expand Down
54 changes: 42 additions & 12 deletions tests/gcp/hooks/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,7 +776,7 @@ def test_get_dataset(self):
self.assertEqual(result, expected_result)

def test_get_datasets_list(self):
expected_result = {'datasets': [
expected_result = [
{
"kind": "bigquery#dataset",
"location": "US",
Expand All @@ -795,18 +795,48 @@ def test_get_datasets_list(self):
"datasetId": "dataset_1_test"
}
}
]}
project_id = "project_test"
]
project_id = "project_test"

mocked = mock.Mock()
with mock.patch.object(hook.BigQueryBaseCursor(mocked, project_id).service,
'datasets') as MockService:
MockService.return_value.list(
projectId=project_id).execute.return_value = expected_result
result = hook.BigQueryBaseCursor(
mocked, "test_create_empty_dataset").get_datasets_list(
project_id=project_id)
self.assertEqual(result, expected_result['datasets'])
mock_service = mock.Mock()
cursor = hook.BigQueryBaseCursor(mock_service, project_id)
mock_service.datasets.return_value.list.return_value.execute.return_value = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you could check here by using assert_called_once_with with what parameters list method was called.

'datasets': expected_result}
mock_service.datasets.return_value.list_next.return_value = None

result = cursor.get_datasets_list(project_id=project_id)

self.assertEqual(result, expected_result)

def test_get_datasets_list_multiple_pages(self):
    """get_datasets_list must follow list_next pagination and concatenate
    every page's datasets into one flat list."""
    dataset = {
        "kind": "bigquery#dataset",
        "location": "US",
        "id": "your-project:dataset_2_test",
        "datasetReference": {
            "projectId": "your-project",
            "datasetId": "dataset_2_test"
        }
    }
    # Four pages of one dataset each -> four datasets total.
    expected_result = [dataset] * 4
    project_id = "project_test"

    pages_requests = [
        mock.Mock(**{'execute.return_value': {"datasets": [dataset]}})
        for _ in range(4)
    ]
    # list() yields the FIRST page; list_next() then yields pages 2..4 and
    # finally None to terminate the pagination loop.
    # (Fixed off-by-one: was pages_requests[1], which skipped page 0 and
    # executed page 1 twice — masked only because every page is identical.)
    datasets_mock = mock.Mock(
        **{'list.return_value': pages_requests[0],
           'list_next.side_effect': pages_requests[1:] + [None]}
    )

    mock_service = mock.Mock()
    cursor = hook.BigQueryBaseCursor(mock_service, project_id)
    mock_service.datasets.return_value = datasets_mock

    result = cursor.get_datasets_list(project_id=project_id)

    self.assertEqual(result, expected_result)

def test_get_dataset_tables_list(self):
tables_list_result = [
Expand Down