1919import datetime
2020import logging
2121from contextlib import closing
22- from typing import TYPE_CHECKING
22+ from typing import TYPE_CHECKING , Any
2323from urllib .parse import quote , urlparse , urlunparse
2424
2525from airflow .providers .common .compat .openlineage .check import require_openlineage_version
3131 from openlineage .client .facet_v2 import JobFacet
3232
3333 from airflow .providers .snowflake .hooks .snowflake import SnowflakeHook
34+ from airflow .providers .snowflake .hooks .snowflake_sql_api import SnowflakeSqlApiHook
3435
3536
3637log = logging .getLogger (__name__ )
@@ -204,9 +205,29 @@ def _run_single_query_with_hook(hook: SnowflakeHook, sql: str) -> list[dict]:
204205 return result
205206
206207
def _run_single_query_with_api_hook(
    hook: SnowflakeSqlApiHook,
    sql: str,
    poll_interval: int = 1,
    timeout: int = 3,
) -> list[dict[str, Any]]:
    """Execute a query against Snowflake API without adding extra logging or instrumentation.

    Args:
        hook: The SnowflakeSqlApiHook used to submit and poll the statement.
        sql: A single SQL statement to execute.
        poll_interval: Seconds between status polls while waiting for completion.
        timeout: Maximum seconds to wait for the query to finish.

    Returns:
        The result rows of the successfully executed query.

    Raises:
        Whatever ``hook.wait_for_query`` raises on query failure or timeout
        (``raise_error=True``).
    """
    # `hook.execute_query` resets the hook's `query_ids`, so we need to save them
    # and re-assign after we're done, to avoid clobbering state other code relies on.
    query_ids_before_execution = list(hook.query_ids)
    try:
        submitted_query_ids = hook.execute_query(sql=sql, statement_count=0)
        hook.wait_for_query(
            query_id=submitted_query_ids[0],
            raise_error=True,
            poll_interval=poll_interval,
            timeout=timeout,
        )
        return hook.get_result_from_successful_sql_api_query(query_id=submitted_query_ids[0])
    finally:
        # Restore even on failure so the hook's bookkeeping is never lost.
        hook.query_ids = query_ids_before_execution
218+
219+
def _process_data_from_api(data: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Convert 'START_TIME' and 'END_TIME' fields to UTC datetime objects.

    Rows are mutated in place; the same list is returned for convenience.
    """
    for entry in data:
        # Values arrive as epoch seconds (possibly as strings); normalize to aware UTC datetimes.
        entry["START_TIME"] = datetime.datetime.fromtimestamp(float(entry["START_TIME"]), timezone.utc)
        entry["END_TIME"] = datetime.datetime.fromtimestamp(float(entry["END_TIME"]), timezone.utc)
    return data
226+
227+
207228def _get_queries_details_from_snowflake (
208- hook : SnowflakeHook , query_ids : list [str ]
209- ) -> dict [str , dict [str , str ]]:
229+ hook : SnowflakeHook | SnowflakeSqlApiHook , query_ids : list [str ]
230+ ) -> dict [str , dict [str , Any ]]:
210231 """Retrieve execution details for specific queries from Snowflake's query history."""
211232 if not query_ids :
212233 return {}
@@ -221,7 +242,16 @@ def _get_queries_details_from_snowflake(
221242 f";"
222243 )
223244
224- result = _run_single_query_with_hook (hook = hook , sql = query )
245+ try :
246+ # Can't import the SnowflakeSqlApiHook class and do proper isinstance check - circular imports
247+ if hook .__class__ .__name__ == "SnowflakeSqlApiHook" :
248+ result = _run_single_query_with_api_hook (hook = hook , sql = query ) # type: ignore[arg-type]
249+ result = _process_data_from_api (data = result )
250+ else :
251+ result = _run_single_query_with_hook (hook = hook , sql = query )
252+ except Exception as e :
253+ log .warning ("OpenLineage could not retrieve extra metadata from Snowflake. Error encountered: %s" , e )
254+ result = []
225255
226256 return {row ["QUERY_ID" ]: row for row in result } if result else {}
227257
@@ -259,17 +289,18 @@ def _create_snowflake_event_pair(
259289
260290@require_openlineage_version (provider_min_version = "2.3.0" )
261291def emit_openlineage_events_for_snowflake_queries (
262- query_ids : list [str ],
263- query_source_namespace : str ,
264292 task_instance ,
265- hook : SnowflakeHook | None = None ,
293+ hook : SnowflakeHook | SnowflakeSqlApiHook | None = None ,
294+ query_ids : list [str ] | None = None ,
295+ query_source_namespace : str | None = None ,
296+ query_for_extra_metadata : bool = False ,
266297 additional_run_facets : dict | None = None ,
267298 additional_job_facets : dict | None = None ,
268299) -> None :
269300 """
270301 Emit OpenLineage events for executed Snowflake queries.
271302
272- Metadata retrieval from Snowflake is attempted only if a `SnowflakeHook` is provided.
303+ Metadata retrieval from Snowflake is attempted only if `get_extra_metadata` is True and hook is provided.
273304 If metadata is available, execution details such as start time, end time, execution status,
274305 error messages, and SQL text are included in the events. If no metadata is found, the function
275306 defaults to using the Airflow task instance's state and the current timestamp.
@@ -279,10 +310,16 @@ def emit_openlineage_events_for_snowflake_queries(
279310 will correspond to actual query execution times.
280311
281312 Args:
282- query_ids: A list of Snowflake query IDs to emit events for.
283- query_source_namespace: The namespace to be included in ExternalQueryRunFacet.
284313 task_instance: The Airflow task instance that run these queries.
285- hook: A SnowflakeHook instance used to retrieve query metadata if available.
314+ hook: A supported Snowflake hook instance used to retrieve query metadata if available.
315+ If omitted, `query_ids` and `query_source_namespace` must be provided explicitly and
316+ `query_for_extra_metadata` must be `False`.
317+ query_ids: A list of Snowflake query IDs to emit events for, can only be None if `hook` is provided
318+ and `hook.query_ids` are present.
319+ query_source_namespace: The namespace to be included in ExternalQueryRunFacet,
320+ can be `None` only if hook is provided.
321+ query_for_extra_metadata: Whether to query Snowflake for additional metadata about queries.
322+ Must be `False` if `hook` is not provided.
286323 additional_run_facets: Additional run facets to include in OpenLineage events.
287324 additional_job_facets: Additional job facets to include in OpenLineage events.
288325 """
@@ -297,23 +334,49 @@ def emit_openlineage_events_for_snowflake_queries(
297334 from airflow .providers .openlineage .conf import namespace
298335 from airflow .providers .openlineage .plugins .listener import get_openlineage_listener
299336
300- if not query_ids :
301- log .debug ("No Snowflake query IDs provided; skipping OpenLineage event emission." )
302- return
303-
304- query_ids = [q for q in query_ids ] # Make a copy to make sure it does not change
337+ log .info ("OpenLineage will emit events for Snowflake queries." )
305338
306339 if hook :
340+ if not query_ids :
341+ log .debug ("No Snowflake query IDs provided; Checking `hook.query_ids` property." )
342+ query_ids = getattr (hook , "query_ids" , [])
343+ if not query_ids :
344+ raise ValueError ("No Snowflake query IDs provided and `hook.query_ids` are not present." )
345+
346+ if not query_source_namespace :
347+ log .debug ("No Snowflake query namespace provided; Creating one from scratch." )
348+ from airflow .providers .openlineage .sqlparser import SQLParser
349+
350+ connection = hook .get_connection (hook .get_conn_id ())
351+ query_source_namespace = SQLParser .create_namespace (
352+ hook .get_openlineage_database_info (connection )
353+ )
354+ else :
355+ if not query_ids :
356+ raise ValueError ("If 'hook' is not provided, 'query_ids' must be set." )
357+ if not query_source_namespace :
358+ raise ValueError ("If 'hook' is not provided, 'query_source_namespace' must be set." )
359+ if query_for_extra_metadata :
360+ raise ValueError ("If 'hook' is not provided, 'query_for_extra_metadata' must be False." )
361+
362+ query_ids = [q for q in query_ids ] # Make a copy to make sure we do not change hook's attribute
363+
364+ if query_for_extra_metadata and hook :
307365 log .debug ("Retrieving metadata for %s queries from Snowflake." , len (query_ids ))
308366 snowflake_metadata = _get_queries_details_from_snowflake (hook , query_ids )
309367 else :
310- log .debug ("SnowflakeHook not provided . No extra metadata fill be fetched from Snowflake." )
368+ log .debug ("`query_for_extra_metadata` is False . No extra metadata fill be fetched from Snowflake." )
311369 snowflake_metadata = {}
312370
313371 # If real metadata is unavailable, we send events with eventTime=now
314372 default_event_time = timezone .utcnow ()
315373 # If no query metadata is provided, we use task_instance's state when checking for success
316- default_state = task_instance .state .value if hasattr (task_instance , "state" ) else ""
374+ # ti.state has no `value` attr (AF2) when task it's still running, in AF3 we get 'running', in that case
375+ # assuming it's user call and query succeeded, so we replace it with success.
376+ default_state = (
377+ getattr (task_instance .state , "value" , "running" ) if hasattr (task_instance , "state" ) else ""
378+ )
379+ default_state = "success" if default_state == "running" else default_state
317380
318381 common_run_facets = {"parent" : _get_parent_run_facet (task_instance )}
319382 common_job_facets : dict [str , JobFacet ] = {
0 commit comments