14 changes: 9 additions & 5 deletions airflow/providers/google/cloud/operators/bigquery.py
@@ -802,7 +802,7 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
     :param dataset_id: The dataset ID of the requested table. (templated)
     :param table_id: The table ID of the requested table. (templated)
     :param project_id: (Optional) The name of the project where the data
-        will be returned from. (templated)
+        will be returned from. If None, it will be derived from the hook's project ID. (templated)
     :param max_results: The maximum number of records (rows) to be fetched
         from the table. (templated)
     :param selected_fields: List of fields to return (comma-separated). If
@@ -872,7 +872,7 @@ def _submit_job(
         hook: BigQueryHook,
         job_id: str,
     ) -> BigQueryJob:
-        get_query = self.generate_query()
+        get_query = self.generate_query(hook=hook)
         configuration = {"query": {"query": get_query, "useLegacySql": self.use_legacy_sql}}
         """Submit a new job and get the job id for polling the status using Triggerer."""
         return hook.insert_job(
@@ -883,17 +883,21 @@ def _submit_job(
             nowait=True,
         )
 
-    def generate_query(self) -> str:
+    def generate_query(self, hook: BigQueryHook) -> str:
         """
         Generate a select query if selected fields are given or with *
         for the given dataset and table id
+        :param hook: BigQuery Hook
         """
         query = "select "
         if self.selected_fields:
             query += self.selected_fields
         else:
             query += "*"
-        query += f" from `{self.project_id}.{self.dataset_id}.{self.table_id}` limit {self.max_results}"
+        query += (
+            f" from `{self.project_id or hook.project_id}.{self.dataset_id}"
+            f".{self.table_id}` limit {self.max_results}"
+        )
         return query
 
     def execute(self, context: Context):
@@ -906,7 +910,7 @@ def execute(self, context: Context):
         if not self.deferrable:
             self.log.info(
                 "Fetching Data from %s.%s.%s max results: %s",
-                self.project_id,
+                self.project_id or hook.project_id,
                 self.dataset_id,
                 self.table_id,
                 self.max_results,
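Taken together, the operator changes make project_id optional in practice: when it is None, generate_query() and the log line fall back to hook.project_id, i.e. the project configured on the GCP connection. A minimal usage sketch (connection, dataset, and table names are hypothetical; the surrounding DAG boilerplate is assumed):

    from airflow.providers.google.cloud.operators.bigquery import BigQueryGetDataOperator

    get_data = BigQueryGetDataOperator(
        task_id="get_data",
        gcp_conn_id="google_cloud_default",  # hook.project_id comes from this connection
        dataset_id="my_dataset",             # hypothetical dataset
        table_id="my_table",                 # hypothetical table
        max_results=10,
        use_legacy_sql=False,
        # project_id omitted on purpose: the generated query becomes
        # select * from `<connection project>.my_dataset.my_table` limit 10
    )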
1 change: 1 addition & 0 deletions airflow/providers/google/cloud/triggers/bigquery.py
@@ -187,6 +187,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"project_id": self.project_id,
"table_id": self.table_id,
"poll_interval": self.poll_interval,
"as_dict": self.as_dict,
},
)

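Serializing as_dict matters because the triggerer rebuilds the trigger from the (classpath, kwargs) pair returned by serialize(); any constructor argument left out of kwargs silently resets to its default after deferral. A rough round-trip sketch (argument values are hypothetical, and it assumes the trigger's constructor accepts the same keyword arguments it serializes):

    from airflow.providers.google.cloud.triggers.bigquery import BigQueryGetDataTrigger

    trigger = BigQueryGetDataTrigger(
        conn_id="google_cloud_default",  # hypothetical connection id
        job_id="job-123",                # hypothetical job id
        project_id="my-project",
        dataset_id="my_dataset",
        table_id="my_table",
        poll_interval=4.0,
        as_dict=True,
    )
    classpath, kwargs = trigger.serialize()
    # Before this change, "as_dict" was missing from kwargs, so the trigger
    # rebuilt via cls(**kwargs) fell back to the default as_dict=False.
    assert kwargs["as_dict"] is True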
32 changes: 32 additions & 0 deletions tests/providers/google/cloud/operators/test_bigquery.py
@@ -814,6 +814,38 @@ def test_execute(self, mock_hook, as_dict):
             location=TEST_DATASET_LOCATION,
         )
 
+    @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
+    def test_generate_query__with_project_id(self, mock_hook):
+        operator = BigQueryGetDataOperator(
+            gcp_conn_id=GCP_CONN_ID,
+            task_id=TASK_ID,
+            dataset_id=TEST_DATASET,
+            table_id=TEST_TABLE_ID,
+            project_id=TEST_GCP_PROJECT_ID,
+            max_results=100,
+            use_legacy_sql=False,
+        )
+        assert (
+            operator.generate_query(hook=mock_hook) == f"select * from `{TEST_GCP_PROJECT_ID}."
+            f"{TEST_DATASET}.{TEST_TABLE_ID}` limit 100"
+        )
+
+    @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
+    def test_generate_query__without_project_id(self, mock_hook):
+        hook_project_id = mock_hook.project_id
+        operator = BigQueryGetDataOperator(
+            gcp_conn_id=GCP_CONN_ID,
+            task_id=TASK_ID,
+            dataset_id=TEST_DATASET,
+            table_id=TEST_TABLE_ID,
+            max_results=100,
+            use_legacy_sql=False,
+        )
+        assert (
+            operator.generate_query(hook=mock_hook) == f"select * from `{hook_project_id}."
+            f"{TEST_DATASET}.{TEST_TABLE_ID}` limit 100"
+        )
+
     @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
     def test_bigquery_get_data_operator_async_with_selected_fields(
         self, mock_hook, create_task_instance_of_operator
1 change: 1 addition & 0 deletions tests/providers/google/cloud/triggers/test_bigquery.py
@@ -224,6 +224,7 @@ def test_bigquery_get_data_trigger_serialization(self, get_data_trigger):
         classpath, kwargs = get_data_trigger.serialize()
         assert classpath == "airflow.providers.google.cloud.triggers.bigquery.BigQueryGetDataTrigger"
         assert kwargs == {
+            "as_dict": False,
             "conn_id": TEST_CONN_ID,
             "job_id": TEST_JOB_ID,
             "dataset_id": TEST_DATASET_ID,