Skip to content

Commit de1444c

Browse files
committed
feat: Add explicit support for SnowflakeSqlApiHook to Ol helper
1 parent 26dbca5 commit de1444c

File tree

4 files changed

+810
-68
lines changed

4 files changed

+810
-68
lines changed

providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -617,10 +617,9 @@ def _get_openlineage_authority(self, _) -> str | None:
617617

618618
def get_openlineage_database_specific_lineage(self, task_instance) -> OperatorLineage | None:
619619
"""
620-
Generate OpenLineage metadata for a Snowflake task instance based on executed query IDs.
620+
Emit separate OpenLineage events for each Snowflake query, based on executed query IDs.
621621
622-
If a single query ID is present, attach an `ExternalQueryRunFacet` to the lineage metadata.
623-
If multiple query IDs are present, emits separate OpenLineage events for each query.
622+
If a single query ID is present, also add an `ExternalQueryRunFacet` to the returned lineage metadata.
624623
625624
Note that `get_openlineage_database_specific_lineage` is usually called after task's execution,
626625
so if multiple query IDs are present, both START and COMPLETE events for each query will be emitted
@@ -641,13 +640,22 @@ def get_openlineage_database_specific_lineage(self, task_instance) -> OperatorLi
641640
)
642641

643642
if not self.query_ids:
644-
self.log.debug("openlineage: no snowflake query ids found.")
643+
self.log.info("OpenLineage could not find snowflake query ids.")
645644
return None
646645

647646
self.log.debug("openlineage: getting connection to get database info")
648647
connection = self.get_connection(self.get_conn_id())
649648
namespace = SQLParser.create_namespace(self.get_openlineage_database_info(connection))
650649

650+
self.log.info("Separate OpenLineage events will be emitted for each query_id.")
651+
emit_openlineage_events_for_snowflake_queries(
652+
task_instance=task_instance,
653+
hook=self,
654+
query_ids=self.query_ids,
655+
query_for_extra_metadata=True,
656+
query_source_namespace=namespace,
657+
)
658+
651659
if len(self.query_ids) == 1:
652660
self.log.debug("Attaching ExternalQueryRunFacet with single query_id to OpenLineage event.")
653661
return OperatorLineage(
@@ -658,20 +666,4 @@ def get_openlineage_database_specific_lineage(self, task_instance) -> OperatorLi
658666
}
659667
)
660668

661-
self.log.info("Multiple query_ids found. Separate OpenLineage event will be emitted for each query.")
662-
try:
663-
from airflow.providers.openlineage.utils.utils import should_use_external_connection
664-
665-
use_external_connection = should_use_external_connection(self)
666-
except ImportError:
667-
# OpenLineage provider release < 1.8.0 - we always use connection
668-
use_external_connection = True
669-
670-
emit_openlineage_events_for_snowflake_queries(
671-
query_ids=self.query_ids,
672-
query_source_namespace=namespace,
673-
task_instance=task_instance,
674-
hook=self if use_external_connection else None,
675-
)
676-
677669
return None

providers/snowflake/src/airflow/providers/snowflake/utils/openlineage.py

Lines changed: 81 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import datetime
2020
import logging
2121
from contextlib import closing
22-
from typing import TYPE_CHECKING
22+
from typing import TYPE_CHECKING, Any
2323
from urllib.parse import quote, urlparse, urlunparse
2424

2525
from airflow.providers.common.compat.openlineage.check import require_openlineage_version
@@ -31,6 +31,7 @@
3131
from openlineage.client.facet_v2 import JobFacet
3232

3333
from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
34+
from airflow.providers.snowflake.hooks.snowflake_sql_api import SnowflakeSqlApiHook
3435

3536

3637
log = logging.getLogger(__name__)
@@ -204,9 +205,29 @@ def _run_single_query_with_hook(hook: SnowflakeHook, sql: str) -> list[dict]:
204205
return result
205206

206207

208+
def _run_single_query_with_api_hook(hook: SnowflakeSqlApiHook, sql: str) -> list[dict[str, Any]]:
209+
"""Execute a query against Snowflake API without adding extra logging or instrumentation."""
210+
# `hook.execute_query` resets the query_ids, so we need to save them and re-assign after we're done
211+
query_ids_before_execution = list(hook.query_ids)
212+
try:
213+
_query_ids = hook.execute_query(sql=sql, statement_count=0)
214+
hook.wait_for_query(query_id=_query_ids[0], raise_error=True, poll_interval=1, timeout=3)
215+
return hook.get_result_from_successful_sql_api_query(query_id=_query_ids[0])
216+
finally:
217+
hook.query_ids = query_ids_before_execution
218+
219+
220+
def _process_data_from_api(data: list[dict[str, Any]]) -> list[dict[str, Any]]:
221+
"""Convert 'START_TIME' and 'END_TIME' fields to UTC datetime objects."""
222+
for row in data:
223+
for key in ("START_TIME", "END_TIME"):
224+
row[key] = datetime.datetime.fromtimestamp(float(row[key]), timezone.utc)
225+
return data
226+
227+
207228
def _get_queries_details_from_snowflake(
208-
hook: SnowflakeHook, query_ids: list[str]
209-
) -> dict[str, dict[str, str]]:
229+
hook: SnowflakeHook | SnowflakeSqlApiHook, query_ids: list[str]
230+
) -> dict[str, dict[str, Any]]:
210231
"""Retrieve execution details for specific queries from Snowflake's query history."""
211232
if not query_ids:
212233
return {}
@@ -221,7 +242,16 @@ def _get_queries_details_from_snowflake(
221242
f";"
222243
)
223244

224-
result = _run_single_query_with_hook(hook=hook, sql=query)
245+
try:
246+
# Can't import the SnowflakeSqlApiHook class and do proper isinstance check - circular imports
247+
if hook.__class__.__name__ == "SnowflakeSqlApiHook":
248+
result = _run_single_query_with_api_hook(hook=hook, sql=query) # type: ignore[arg-type]
249+
result = _process_data_from_api(data=result)
250+
else:
251+
result = _run_single_query_with_hook(hook=hook, sql=query)
252+
except Exception as e:
253+
log.warning("OpenLineage could not retrieve extra metadata from Snowflake. Error encountered: %s", e)
254+
result = []
225255

226256
return {row["QUERY_ID"]: row for row in result} if result else {}
227257

@@ -259,17 +289,18 @@ def _create_snowflake_event_pair(
259289

260290
@require_openlineage_version(provider_min_version="2.3.0")
261291
def emit_openlineage_events_for_snowflake_queries(
262-
query_ids: list[str],
263-
query_source_namespace: str,
264292
task_instance,
265-
hook: SnowflakeHook | None = None,
293+
hook: SnowflakeHook | SnowflakeSqlApiHook | None = None,
294+
query_ids: list[str] | None = None,
295+
query_source_namespace: str | None = None,
296+
query_for_extra_metadata: bool = False,
266297
additional_run_facets: dict | None = None,
267298
additional_job_facets: dict | None = None,
268299
) -> None:
269300
"""
270301
Emit OpenLineage events for executed Snowflake queries.
271302
272-
Metadata retrieval from Snowflake is attempted only if a `SnowflakeHook` is provided.
303+
Metadata retrieval from Snowflake is attempted only if `query_for_extra_metadata` is True and hook is provided.
273304
If metadata is available, execution details such as start time, end time, execution status,
274305
error messages, and SQL text are included in the events. If no metadata is found, the function
275306
defaults to using the Airflow task instance's state and the current timestamp.
@@ -279,10 +310,16 @@ def emit_openlineage_events_for_snowflake_queries(
279310
will correspond to actual query execution times.
280311
281312
Args:
282-
query_ids: A list of Snowflake query IDs to emit events for.
283-
query_source_namespace: The namespace to be included in ExternalQueryRunFacet.
284313
task_instance: The Airflow task instance that ran these queries.
285-
hook: A SnowflakeHook instance used to retrieve query metadata if available.
314+
hook: A supported Snowflake hook instance used to retrieve query metadata if available.
315+
If omitted, `query_ids` and `query_source_namespace` must be provided explicitly and
316+
`query_for_extra_metadata` must be `False`.
317+
query_ids: A list of Snowflake query IDs to emit events for, can only be None if `hook` is provided
318+
and `hook.query_ids` are present.
319+
query_source_namespace: The namespace to be included in ExternalQueryRunFacet,
320+
can be `None` only if hook is provided.
321+
query_for_extra_metadata: Whether to query Snowflake for additional metadata about queries.
322+
Must be `False` if `hook` is not provided.
286323
additional_run_facets: Additional run facets to include in OpenLineage events.
287324
additional_job_facets: Additional job facets to include in OpenLineage events.
288325
"""
@@ -297,23 +334,49 @@ def emit_openlineage_events_for_snowflake_queries(
297334
from airflow.providers.openlineage.conf import namespace
298335
from airflow.providers.openlineage.plugins.listener import get_openlineage_listener
299336

300-
if not query_ids:
301-
log.debug("No Snowflake query IDs provided; skipping OpenLineage event emission.")
302-
return
303-
304-
query_ids = [q for q in query_ids] # Make a copy to make sure it does not change
337+
log.info("OpenLineage will emit events for Snowflake queries.")
305338

306339
if hook:
340+
if not query_ids:
341+
log.debug("No Snowflake query IDs provided; Checking `hook.query_ids` property.")
342+
query_ids = getattr(hook, "query_ids", [])
343+
if not query_ids:
344+
raise ValueError("No Snowflake query IDs provided and `hook.query_ids` are not present.")
345+
346+
if not query_source_namespace:
347+
log.debug("No Snowflake query namespace provided; Creating one from scratch.")
348+
from airflow.providers.openlineage.sqlparser import SQLParser
349+
350+
connection = hook.get_connection(hook.get_conn_id())
351+
query_source_namespace = SQLParser.create_namespace(
352+
hook.get_openlineage_database_info(connection)
353+
)
354+
else:
355+
if not query_ids:
356+
raise ValueError("If 'hook' is not provided, 'query_ids' must be set.")
357+
if not query_source_namespace:
358+
raise ValueError("If 'hook' is not provided, 'query_source_namespace' must be set.")
359+
if query_for_extra_metadata:
360+
raise ValueError("If 'hook' is not provided, 'query_for_extra_metadata' must be False.")
361+
362+
query_ids = [q for q in query_ids] # Make a copy to make sure we do not change hook's attribute
363+
364+
if query_for_extra_metadata and hook:
307365
log.debug("Retrieving metadata for %s queries from Snowflake.", len(query_ids))
308366
snowflake_metadata = _get_queries_details_from_snowflake(hook, query_ids)
309367
else:
310-
log.debug("SnowflakeHook not provided. No extra metadata fill be fetched from Snowflake.")
368+
log.debug("`query_for_extra_metadata` is False. No extra metadata fill be fetched from Snowflake.")
311369
snowflake_metadata = {}
312370

313371
# If real metadata is unavailable, we send events with eventTime=now
314372
default_event_time = timezone.utcnow()
315373
# If no query metadata is provided, we use task_instance's state when checking for success
316-
default_state = task_instance.state.value if hasattr(task_instance, "state") else ""
374+
# In AF2, ti.state has no `value` attr while the task is still running; in AF3 we get 'running'. In that
375+
# case we assume this is a user-initiated call and the query succeeded, so we replace it with 'success'.
376+
default_state = (
377+
getattr(task_instance.state, "value", "running") if hasattr(task_instance, "state") else ""
378+
)
379+
default_state = "success" if default_state == "running" else default_state
317380

318381
common_run_facets = {"parent": _get_parent_run_facet(task_instance)}
319382
common_job_facets: dict[str, JobFacet] = {

providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -905,14 +905,17 @@ def test_get_openlineage_default_schema_with_schema_set(self, mock_get_first):
905905
assert hook_with_schema_param.get_openlineage_default_schema() == "my_schema"
906906
mock_get_first.assert_not_called()
907907

908-
def test_get_openlineage_database_specific_lineage_with_no_query_ids(self):
908+
@mock.patch("airflow.providers.snowflake.utils.openlineage.emit_openlineage_events_for_snowflake_queries")
909+
def test_get_openlineage_database_specific_lineage_with_no_query_ids(self, mock_emit):
909910
hook = SnowflakeHook(snowflake_conn_id="test_conn")
910911
assert hook.query_ids == []
911912

912913
result = hook.get_openlineage_database_specific_lineage(None)
914+
mock_emit.assert_not_called()
913915
assert result is None
914916

915-
def test_get_openlineage_database_specific_lineage_with_single_query_id(self):
917+
@mock.patch("airflow.providers.snowflake.utils.openlineage.emit_openlineage_events_for_snowflake_queries")
918+
def test_get_openlineage_database_specific_lineage_with_single_query_id(self, mock_emit):
916919
from airflow.providers.common.compat.openlineage.facet import ExternalQueryRunFacet
917920
from airflow.providers.openlineage.extractors import OperatorLineage
918921

@@ -921,21 +924,26 @@ def test_get_openlineage_database_specific_lineage_with_single_query_id(self):
921924
hook.get_connection = mock.MagicMock()
922925
hook.get_openlineage_database_info = lambda x: mock.MagicMock(authority="auth", scheme="scheme")
923926

924-
result = hook.get_openlineage_database_specific_lineage(None)
927+
ti = mock.MagicMock()
928+
929+
result = hook.get_openlineage_database_specific_lineage(ti)
930+
mock_emit.assert_called_once_with(
931+
**{
932+
"hook": hook,
933+
"query_ids": ["query1"],
934+
"query_source_namespace": "scheme://auth",
935+
"task_instance": ti,
936+
"query_for_extra_metadata": True,
937+
}
938+
)
925939
assert result == OperatorLineage(
926940
run_facets={
927941
"externalQuery": ExternalQueryRunFacet(externalQueryId="query1", source="scheme://auth")
928942
}
929943
)
930944

931-
@pytest.mark.parametrize("use_external_connection", [True, False])
932-
@mock.patch("airflow.providers.openlineage.utils.utils.should_use_external_connection")
933945
@mock.patch("airflow.providers.snowflake.utils.openlineage.emit_openlineage_events_for_snowflake_queries")
934-
def test_get_openlineage_database_specific_lineage_with_multiple_query_ids(
935-
self, mock_emit, mock_use_conn, use_external_connection
936-
):
937-
mock_use_conn.return_value = use_external_connection
938-
946+
def test_get_openlineage_database_specific_lineage_with_multiple_query_ids(self, mock_emit):
939947
hook = SnowflakeHook(snowflake_conn_id="test_conn")
940948
hook.query_ids = ["query1", "query2"]
941949
hook.get_connection = mock.MagicMock()
@@ -944,23 +952,19 @@ def test_get_openlineage_database_specific_lineage_with_multiple_query_ids(
944952
ti = mock.MagicMock()
945953

946954
result = hook.get_openlineage_database_specific_lineage(ti)
947-
mock_use_conn.assert_called_once()
948955
mock_emit.assert_called_once_with(
949956
**{
950-
"hook": hook if use_external_connection else None,
957+
"hook": hook,
951958
"query_ids": ["query1", "query2"],
952959
"query_source_namespace": "scheme://auth",
953960
"task_instance": ti,
961+
"query_for_extra_metadata": True,
954962
}
955963
)
956964
assert result is None
957965

958-
# emit_openlineage_events_for_snowflake_queries requires OL provider 2.0.0
959966
@mock.patch("importlib.metadata.version", return_value="1.99.0")
960-
@mock.patch("airflow.providers.openlineage.utils.utils.should_use_external_connection")
961-
def test_get_openlineage_database_specific_lineage_with_old_openlineage_provider(
962-
self, mock_use_conn, mock_version
963-
):
967+
def test_get_openlineage_database_specific_lineage_with_old_openlineage_provider(self, mock_version):
964968
hook = SnowflakeHook(snowflake_conn_id="test_conn")
965969
hook.query_ids = ["query1", "query2"]
966970
hook.get_connection = mock.MagicMock()
@@ -972,7 +976,6 @@ def test_get_openlineage_database_specific_lineage_with_old_openlineage_provider
972976
)
973977
with pytest.raises(AirflowOptionalProviderFeatureException, match=expected_err):
974978
hook.get_openlineage_database_specific_lineage(mock.MagicMock())
975-
mock_use_conn.assert_called_once()
976979

977980
@pytest.mark.skipif(sys.version_info >= (3, 12), reason="Snowpark Python doesn't support Python 3.12 yet")
978981
@mock.patch("snowflake.snowpark.Session.builder")

0 commit comments

Comments
 (0)