diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py index 8cf3489ccf2ae..10aafd5d662bb 100644 --- a/airflow/providers/google/cloud/operators/bigquery.py +++ b/airflow/providers/google/cloud/operators/bigquery.py @@ -802,7 +802,7 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator): :param dataset_id: The dataset ID of the requested table. (templated) :param table_id: The table ID of the requested table. (templated) :param project_id: (Optional) The name of the project where the data - will be returned from. (templated) + will be returned from. If None, it will be derived from the hook's project ID. (templated) :param max_results: The maximum number of records (rows) to be fetched from the table. (templated) :param selected_fields: List of fields to return (comma-separated). If @@ -872,7 +872,7 @@ def _submit_job( hook: BigQueryHook, job_id: str, ) -> BigQueryJob: - get_query = self.generate_query() + get_query = self.generate_query(hook=hook) configuration = {"query": {"query": get_query, "useLegacySql": self.use_legacy_sql}} """Submit a new job and get the job id for polling the status using Triggerer.""" return hook.insert_job( @@ -883,17 +883,21 @@ def _submit_job( nowait=True, ) - def generate_query(self) -> str: + def generate_query(self, hook: BigQueryHook) -> str: """ Generate a select query if selected fields are given or with * for the given dataset and table id + :param hook: BigQuery Hook whose project ID is used in the query when ``self.project_id`` is not set """ query = "select " if self.selected_fields: query += self.selected_fields else: query += "*" - query += f" from `{self.project_id}.{self.dataset_id}.{self.table_id}` limit {self.max_results}" + query += ( + f" from `{self.project_id or hook.project_id}.{self.dataset_id}" + f".{self.table_id}` limit {self.max_results}" + ) return query def execute(self, context: Context): @@ -906,7 +910,7 @@ def execute(self, context: Context): if not self.deferrable: self.log.info( "Fetching Data from %s.%s.%s 
max results: %s", - self.project_id, + self.project_id or hook.project_id, self.dataset_id, self.table_id, self.max_results, diff --git a/airflow/providers/google/cloud/triggers/bigquery.py b/airflow/providers/google/cloud/triggers/bigquery.py index 1da7f87f90259..c7b17af2ed5f5 100644 --- a/airflow/providers/google/cloud/triggers/bigquery.py +++ b/airflow/providers/google/cloud/triggers/bigquery.py @@ -187,6 +187,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]: "project_id": self.project_id, "table_id": self.table_id, "poll_interval": self.poll_interval, + "as_dict": self.as_dict, }, ) diff --git a/tests/providers/google/cloud/operators/test_bigquery.py b/tests/providers/google/cloud/operators/test_bigquery.py index 1e871678a9c40..b0547cec3f18c 100644 --- a/tests/providers/google/cloud/operators/test_bigquery.py +++ b/tests/providers/google/cloud/operators/test_bigquery.py @@ -814,6 +814,38 @@ def test_execute(self, mock_hook, as_dict): location=TEST_DATASET_LOCATION, ) + @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") + def test_generate_query__with_project_id(self, mock_hook): + operator = BigQueryGetDataOperator( + gcp_conn_id=GCP_CONN_ID, + task_id=TASK_ID, + dataset_id=TEST_DATASET, + table_id=TEST_TABLE_ID, + project_id=TEST_GCP_PROJECT_ID, + max_results=100, + use_legacy_sql=False, + ) + assert ( + operator.generate_query(hook=mock_hook) == f"select * from `{TEST_GCP_PROJECT_ID}." + f"{TEST_DATASET}.{TEST_TABLE_ID}` limit 100" + ) + + @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") + def test_generate_query__without_project_id(self, mock_hook): + hook_project_id = mock_hook.project_id + operator = BigQueryGetDataOperator( + gcp_conn_id=GCP_CONN_ID, + task_id=TASK_ID, + dataset_id=TEST_DATASET, + table_id=TEST_TABLE_ID, + max_results=100, + use_legacy_sql=False, + ) + assert ( + operator.generate_query(hook=mock_hook) == f"select * from `{hook_project_id}." 
+ f"{TEST_DATASET}.{TEST_TABLE_ID}` limit 100" + ) + @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") def test_bigquery_get_data_operator_async_with_selected_fields( self, mock_hook, create_task_instance_of_operator diff --git a/tests/providers/google/cloud/triggers/test_bigquery.py b/tests/providers/google/cloud/triggers/test_bigquery.py index 410aa14b1a7ff..cb997259bd680 100644 --- a/tests/providers/google/cloud/triggers/test_bigquery.py +++ b/tests/providers/google/cloud/triggers/test_bigquery.py @@ -224,6 +224,7 @@ def test_bigquery_get_data_trigger_serialization(self, get_data_trigger): classpath, kwargs = get_data_trigger.serialize() assert classpath == "airflow.providers.google.cloud.triggers.bigquery.BigQueryGetDataTrigger" assert kwargs == { + "as_dict": False, "conn_id": TEST_CONN_ID, "job_id": TEST_JOB_ID, "dataset_id": TEST_DATASET_ID,