diff --git a/airflow/providers/common/sql/operators/sql.py b/airflow/providers/common/sql/operators/sql.py
index 1fd22b86b70f9..ea791992d5c95 100644
--- a/airflow/providers/common/sql/operators/sql.py
+++ b/airflow/providers/common/sql/operators/sql.py
@@ -309,6 +309,14 @@ def get_openlineage_facets_on_start(self) -> OperatorLineage | None:
         hook = self.get_db_hook()
 
+        try:
+            from airflow.providers.openlineage.utils.utils import should_use_external_connection
+
+            use_external_connection = should_use_external_connection(hook)
+        except ImportError:
+            # OpenLineage provider release < 1.8.0 - we always use connection
+            use_external_connection = True
+
         connection = hook.get_connection(getattr(hook, hook.conn_name_attr))
         try:
             database_info = hook.get_openlineage_database_info(connection)
@@ -334,6 +342,7 @@ def get_openlineage_facets_on_start(self) -> OperatorLineage | None:
             database_info=database_info,
             database=self.database,
             sqlalchemy_engine=hook.get_sqlalchemy_engine(),
+            use_connection=use_external_connection,
         )
 
         return operator_lineage
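The try/except above keeps the common.sql provider usable with any installed OpenLineage provider version: should_use_external_connection only exists from openlineage provider 1.8.0 on, and older installs fall back to the previous behaviour of always querying the connection. A minimal standalone sketch of the same guard (the wrapper name here is hypothetical; only should_use_external_connection is from the patch):

    def resolve_use_external_connection(hook) -> bool:
        try:
            from airflow.providers.openlineage.utils.utils import should_use_external_connection
        except ImportError:
            # openlineage provider < 1.8.0 does not ship the helper:
            # keep the old behaviour and always use the connection.
            return True
        return should_use_external_connection(hook)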
diff --git a/airflow/providers/openlineage/sqlparser.py b/airflow/providers/openlineage/sqlparser.py
index c27dedc53c5bd..f181ff8ccea01 100644
--- a/airflow/providers/openlineage/sqlparser.py
+++ b/airflow/providers/openlineage/sqlparser.py
@@ -29,6 +29,7 @@
     ExtractionErrorRunFacet,
     SqlJobFacet,
 )
+from openlineage.client.run import Dataset
 from openlineage.common.sql import DbTableMeta, SqlMeta, parse
 
 from airflow.providers.openlineage.extractors.base import OperatorLineage
@@ -40,7 +41,6 @@
     from airflow.typing_compat import TypedDict
 
 if TYPE_CHECKING:
-    from openlineage.client.run import Dataset
     from sqlalchemy.engine import Engine
 
     from airflow.hooks.base import BaseHook
@@ -104,6 +104,18 @@ class DatabaseInfo:
     normalize_name_method: Callable[[str], str] = default_normalize_name_method
 
 
+def from_table_meta(
+    table_meta: DbTableMeta, database: str | None, namespace: str, is_uppercase: bool
+) -> Dataset:
+    if table_meta.database:
+        name = table_meta.qualified_name
+    elif database:
+        name = f"{database}.{table_meta.schema}.{table_meta.name}"
+    else:
+        name = f"{table_meta.schema}.{table_meta.name}"
+    return Dataset(namespace=namespace, name=name if not is_uppercase else name.upper())
+
+
 class SQLParser:
     """Interface for openlineage-sql.
 
@@ -117,7 +129,7 @@ def __init__(self, dialect: str | None = None, default_schema: str | None = None
 
     def parse(self, sql: list[str] | str) -> SqlMeta | None:
         """Parse a single or a list of SQL statements."""
-        return parse(sql=sql, dialect=self.dialect)
+        return parse(sql=sql, dialect=self.dialect, default_schema=self.default_schema)
 
     def parse_table_schemas(
         self,
@@ -156,6 +168,23 @@ def parse_table_schemas(
             else None,
         )
 
+    def get_metadata_from_parser(
+        self,
+        inputs: list[DbTableMeta],
+        outputs: list[DbTableMeta],
+        database_info: DatabaseInfo,
+        namespace: str = DEFAULT_NAMESPACE,
+        database: str | None = None,
+    ) -> tuple[list[Dataset], ...]:
+        database = database if database else database_info.database
+        return [
+            from_table_meta(dataset, database, namespace, database_info.is_uppercase_names)
+            for dataset in inputs
+        ], [
+            from_table_meta(dataset, database, namespace, database_info.is_uppercase_names)
+            for dataset in outputs
+        ]
+
     def attach_column_lineage(
         self, datasets: list[Dataset], database: str | None, parse_result: SqlMeta
     ) -> None:
@@ -204,6 +233,7 @@ def generate_openlineage_metadata_from_sql(
         database_info: DatabaseInfo,
         database: str | None = None,
         sqlalchemy_engine: Engine | None = None,
+        use_connection: bool = True,
     ) -> OperatorLineage:
         """Parse SQL statement(s) and generate OpenLineage metadata.
 
@@ -242,15 +272,24 @@ def generate_openlineage_metadata_from_sql(
         )
 
         namespace = self.create_namespace(database_info=database_info)
-        inputs, outputs = self.parse_table_schemas(
-            hook=hook,
-            inputs=parse_result.in_tables,
-            outputs=parse_result.out_tables,
-            namespace=namespace,
-            database=database,
-            database_info=database_info,
-            sqlalchemy_engine=sqlalchemy_engine,
-        )
+        if use_connection:
+            inputs, outputs = self.parse_table_schemas(
+                hook=hook,
+                inputs=parse_result.in_tables,
+                outputs=parse_result.out_tables,
+                namespace=namespace,
+                database=database,
+                database_info=database_info,
+                sqlalchemy_engine=sqlalchemy_engine,
+            )
+        else:
+            inputs, outputs = self.get_metadata_from_parser(
+                inputs=parse_result.in_tables,
+                outputs=parse_result.out_tables,
+                namespace=namespace,
+                database=database,
+                database_info=database_info,
+            )
 
         self.attach_column_lineage(outputs, database or database_info.database, parse_result)
 
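get_metadata_from_parser builds the input/output Datasets purely from what openlineage-sql parsed, with no round-trip to the database. An illustration of the naming rule in from_table_meta (values are assumed for the example; DbTableMeta splits a dotted name into database/schema/table):

    from openlineage.common.sql import DbTableMeta

    # Fully qualified table: the parser-supplied qualified name wins.
    from_table_meta(DbTableMeta("DB.SCHEMA.TAB"), None, "snowflake://acc", False)
    # -> Dataset(namespace="snowflake://acc", name="DB.SCHEMA.TAB")

    # Schema-qualified table plus the hook's default database, uppercased.
    from_table_meta(DbTableMeta("schema.tab"), "db", "snowflake://acc", True)
    # -> Dataset(namespace="snowflake://acc", name="DB.SCHEMA.TAB")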
diff --git a/airflow/providers/openlineage/utils/utils.py b/airflow/providers/openlineage/utils/utils.py
index 1c777aff761c8..ad1f3b09518cd 100644
--- a/airflow/providers/openlineage/utils/utils.py
+++ b/airflow/providers/openlineage/utils/utils.py
@@ -384,3 +384,8 @@ def normalize_sql(sql: str | Iterable[str]):
         sql = [stmt for stmt in sql.split(";") if stmt != ""]
     sql = [obj for stmt in sql for obj in stmt.split(";") if obj != ""]
     return ";\n".join(sql)
+
+
+def should_use_external_connection(hook) -> bool:
+    # TODO: Add checking overrides
+    return hook.__class__.__name__ not in ["SnowflakeHook", "SnowflakeSqlApiHook"]
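The gate is deliberately coarse: it matches only on the hook's class name, so subclasses with other names are not detected yet (hence the TODO about overrides). Usage sketch under that assumption:

    from airflow.providers.openlineage.utils.utils import should_use_external_connection
    from airflow.providers.postgres.hooks.postgres import PostgresHook
    from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook

    should_use_external_connection(SnowflakeHook())  # False - lineage from the parser only
    should_use_external_connection(PostgresHook())   # True - keep querying the database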
""" - conn_config = self._get_conn_params() + conn_config = self._get_conn_params req_id = uuid.uuid4() url = f"{self.account_identifier}.snowflakecomputing.com/api/v2/statements" @@ -186,7 +186,7 @@ def execute_query( def get_headers(self) -> dict[str, Any]: """Form auth headers based on either OAuth token or JWT token from private key.""" - conn_config = self._get_conn_params() + conn_config = self._get_conn_params # Use OAuth if refresh_token and client_id and client_secret are provided if all( @@ -225,7 +225,7 @@ def get_headers(self) -> dict[str, Any]: def get_oauth_token(self) -> str: """Generate temporary OAuth access token using refresh token in connection details.""" - conn_config = self._get_conn_params() + conn_config = self._get_conn_params url = f"{self.account_identifier}.snowflakecomputing.com/oauth/token-request" data = { "grant_type": "refresh_token", diff --git a/tests/providers/amazon/aws/operators/test_redshift_sql.py b/tests/providers/amazon/aws/operators/test_redshift_sql.py index d1c6e261517f5..586172c3b87a9 100644 --- a/tests/providers/amazon/aws/operators/test_redshift_sql.py +++ b/tests/providers/amazon/aws/operators/test_redshift_sql.py @@ -158,7 +158,8 @@ def get_db_hook(self): "SVV_REDSHIFT_COLUMNS.data_type, " "SVV_REDSHIFT_COLUMNS.database_name \n" "FROM SVV_REDSHIFT_COLUMNS \n" - "WHERE SVV_REDSHIFT_COLUMNS.table_name IN ('little_table') " + "WHERE SVV_REDSHIFT_COLUMNS.schema_name = 'database.public' " + "AND SVV_REDSHIFT_COLUMNS.table_name IN ('little_table') " "OR SVV_REDSHIFT_COLUMNS.database_name = 'another_db' " "AND SVV_REDSHIFT_COLUMNS.schema_name = 'another_schema' AND " "SVV_REDSHIFT_COLUMNS.table_name IN ('popular_orders_day_of_week')" @@ -171,7 +172,8 @@ def get_db_hook(self): "SVV_REDSHIFT_COLUMNS.data_type, " "SVV_REDSHIFT_COLUMNS.database_name \n" "FROM SVV_REDSHIFT_COLUMNS \n" - "WHERE SVV_REDSHIFT_COLUMNS.table_name IN ('Test_table')" + "WHERE SVV_REDSHIFT_COLUMNS.schema_name = 'database.public' " + "AND SVV_REDSHIFT_COLUMNS.table_name IN ('Test_table')" ), ] diff --git a/tests/providers/snowflake/hooks/test_snowflake.py b/tests/providers/snowflake/hooks/test_snowflake.py index fb5b1a5514167..54a18eeca7336 100644 --- a/tests/providers/snowflake/hooks/test_snowflake.py +++ b/tests/providers/snowflake/hooks/test_snowflake.py @@ -270,7 +270,7 @@ def test_hook_should_support_prepare_basic_conn_params_and_uri( ): with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()): assert SnowflakeHook(snowflake_conn_id="test_conn").get_uri() == expected_uri - assert SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params() == expected_conn_params + assert SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params == expected_conn_params def test_get_conn_params_should_support_private_auth_in_connection( self, encrypted_temporary_private_key: Path @@ -288,7 +288,7 @@ def test_get_conn_params_should_support_private_auth_in_connection( }, } with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()): - assert "private_key" in SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params() + assert "private_key" in SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params @pytest.mark.parametrize("include_params", [True, False]) def test_hook_param_beats_extra(self, include_params): @@ -311,7 +311,7 @@ def test_hook_param_beats_extra(self, include_params): assert hook_params != extras assert SnowflakeHook( snowflake_conn_id="test_conn", **(hook_params if 
diff --git a/airflow/providers/snowflake/hooks/snowflake_sql_api.py b/airflow/providers/snowflake/hooks/snowflake_sql_api.py
index 6eec055eb578e..3f52a43a1d88d 100644
--- a/airflow/providers/snowflake/hooks/snowflake_sql_api.py
+++ b/airflow/providers/snowflake/hooks/snowflake_sql_api.py
@@ -86,7 +86,7 @@ def __init__(
     @property
     def account_identifier(self) -> str:
         """Returns snowflake account identifier."""
-        conn_config = self._get_conn_params()
+        conn_config = self._get_conn_params
         account_identifier = f"https://{conn_config['account']}"
 
         if conn_config["region"]:
@@ -147,7 +147,7 @@ def execute_query(
             When executing the statement, Snowflake replaces placeholders (? and :name) in
             the statement with these specified values.
         """
-        conn_config = self._get_conn_params()
+        conn_config = self._get_conn_params
 
         req_id = uuid.uuid4()
         url = f"{self.account_identifier}.snowflakecomputing.com/api/v2/statements"
@@ -186,7 +186,7 @@ def execute_query(
 
     def get_headers(self) -> dict[str, Any]:
         """Form auth headers based on either OAuth token or JWT token from private key."""
-        conn_config = self._get_conn_params()
+        conn_config = self._get_conn_params
 
         # Use OAuth if refresh_token and client_id and client_secret are provided
         if all(
@@ -225,7 +225,7 @@ def get_headers(self) -> dict[str, Any]:
 
     def get_oauth_token(self) -> str:
         """Generate temporary OAuth access token using refresh token in connection details."""
-        conn_config = self._get_conn_params()
+        conn_config = self._get_conn_params
         url = f"{self.account_identifier}.snowflakecomputing.com/oauth/token-request"
         data = {
             "grant_type": "refresh_token",
diff --git a/tests/providers/amazon/aws/operators/test_redshift_sql.py b/tests/providers/amazon/aws/operators/test_redshift_sql.py
index d1c6e261517f5..586172c3b87a9 100644
--- a/tests/providers/amazon/aws/operators/test_redshift_sql.py
+++ b/tests/providers/amazon/aws/operators/test_redshift_sql.py
@@ -158,7 +158,8 @@ def get_db_hook(self):
             "SVV_REDSHIFT_COLUMNS.data_type, "
             "SVV_REDSHIFT_COLUMNS.database_name \n"
             "FROM SVV_REDSHIFT_COLUMNS \n"
-            "WHERE SVV_REDSHIFT_COLUMNS.table_name IN ('little_table') "
+            "WHERE SVV_REDSHIFT_COLUMNS.schema_name = 'database.public' "
+            "AND SVV_REDSHIFT_COLUMNS.table_name IN ('little_table') "
             "OR SVV_REDSHIFT_COLUMNS.database_name = 'another_db' "
             "AND SVV_REDSHIFT_COLUMNS.schema_name = 'another_schema' AND "
             "SVV_REDSHIFT_COLUMNS.table_name IN ('popular_orders_day_of_week')"
@@ -171,7 +172,8 @@ def get_db_hook(self):
             "SVV_REDSHIFT_COLUMNS.data_type, "
             "SVV_REDSHIFT_COLUMNS.database_name \n"
             "FROM SVV_REDSHIFT_COLUMNS \n"
-            "WHERE SVV_REDSHIFT_COLUMNS.table_name IN ('Test_table')"
+            "WHERE SVV_REDSHIFT_COLUMNS.schema_name = 'database.public' "
+            "AND SVV_REDSHIFT_COLUMNS.table_name IN ('Test_table')"
         ),
     ]
 
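These expected queries change because SQLParser.parse now forwards default_schema (see sqlparser.py above), so tables that were unqualified in the SQL come back schema-qualified and the generated SVV_REDSHIFT_COLUMNS lookup gains a schema_name predicate. Roughly, with the output shape assumed for illustration:

    from openlineage.common.sql import parse

    meta = parse(sql="SELECT * FROM little_table", dialect="redshift", default_schema="public")
    # meta.in_tables now carries the default schema, e.g. DbTableMeta("public.little_table"),
    # instead of a bare DbTableMeta("little_table").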
@mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.get_private_key") - @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params") + @mock.patch( + "airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params", + new_callable=PropertyMock, + ) @mock.patch("airflow.providers.snowflake.utils.sql_api_generate_jwt.JWTGenerator.get_token") def test_get_headers_should_support_private_key(self, mock_get_token, mock_conn_param, mock_private_key): """Test get_headers method by mocking get_private_key and _get_conn_params method""" @@ -285,7 +297,10 @@ def test_get_headers_should_support_private_key(self, mock_get_token, mock_conn_ assert result == HEADERS @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.get_oauth_token") - @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params") + @mock.patch( + "airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params", + new_callable=PropertyMock, + ) def test_get_headers_should_support_oauth(self, mock_conn_param, mock_oauth_token): """Test get_headers method by mocking get_oauth_token and _get_conn_params method""" mock_conn_param.return_value = CONN_PARAMS_OAUTH @@ -296,7 +311,10 @@ def test_get_headers_should_support_oauth(self, mock_conn_param, mock_oauth_toke @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.HTTPBasicAuth") @mock.patch("requests.post") - @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params") + @mock.patch( + "airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params", + new_callable=PropertyMock, + ) def test_get_oauth_token(self, mock_conn_param, requests_post, mock_auth): """Test get_oauth_token method makes the right http request""" BASIC_AUTH = {"Authorization": "Basic usernamepassword"} diff --git a/tests/providers/snowflake/operators/test_snowflake_sql.py b/tests/providers/snowflake/operators/test_snowflake_sql.py index a79064e341994..87d77ca813df7 100644 --- a/tests/providers/snowflake/operators/test_snowflake_sql.py +++ b/tests/providers/snowflake/operators/test_snowflake_sql.py @@ -17,7 +17,8 @@ # under the License. 
diff --git a/tests/providers/snowflake/hooks/test_snowflake_sql_api.py b/tests/providers/snowflake/hooks/test_snowflake_sql_api.py
index bfd978755e6dd..5ba18d6e127ed 100644
--- a/tests/providers/snowflake/hooks/test_snowflake_sql_api.py
+++ b/tests/providers/snowflake/hooks/test_snowflake_sql_api.py
@@ -20,7 +20,7 @@
 import uuid
 from typing import TYPE_CHECKING, Any
 from unittest import mock
-from unittest.mock import AsyncMock
+from unittest.mock import AsyncMock, PropertyMock
 
 import pytest
 import requests
@@ -168,7 +168,10 @@ class TestSnowflakeSqlApiHook:
         ],
     )
     @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests")
-    @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params")
+    @mock.patch(
+        "airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params",
+        new_callable=PropertyMock,
+    )
     @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.get_headers")
     def test_execute_query(
         self,
@@ -197,7 +200,10 @@ def test_execute_query(
         [(SINGLE_STMT, 1, {"statementHandle": "uuid"}, ["uuid"])],
     )
     @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests")
-    @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params")
+    @mock.patch(
+        "airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params",
+        new_callable=PropertyMock,
+    )
     @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.get_headers")
     def test_execute_query_exception_without_statement_handle(
         self,
@@ -262,7 +268,10 @@ def test_check_query_output_exception(self, mock_geturl_header_params, query_ids
         with pytest.raises(AirflowException, match='Response: {"foo": "bar"}, Status Code: 500'):
             hook.check_query_output(query_ids)
 
-    @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params")
+    @mock.patch(
+        "airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params",
+        new_callable=PropertyMock,
+    )
     @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.get_headers")
     def test_get_request_url_header_params(self, mock_get_header, mock_conn_param):
         """Test get_request_url_header_params by mocking _get_conn_params and get_headers"""
@@ -274,7 +283,10 @@ def test_get_request_url_header_params(self, mock_get_header, mock_conn_param):
         assert url == "https://airflow.af_region.snowflakecomputing.com/api/v2/statements/uuid"
 
     @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.get_private_key")
-    @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params")
+    @mock.patch(
+        "airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params",
+        new_callable=PropertyMock,
+    )
     @mock.patch("airflow.providers.snowflake.utils.sql_api_generate_jwt.JWTGenerator.get_token")
     def test_get_headers_should_support_private_key(self, mock_get_token, mock_conn_param, mock_private_key):
         """Test get_headers method by mocking get_private_key and _get_conn_params method"""
@@ -285,7 +297,10 @@ def test_get_headers_should_support_private_key(self, mock_get_token, mock_conn_
         assert result == HEADERS
 
     @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.get_oauth_token")
-    @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params")
+    @mock.patch(
+        "airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params",
+        new_callable=PropertyMock,
+    )
     def test_get_headers_should_support_oauth(self, mock_conn_param, mock_oauth_token):
         """Test get_headers method by mocking get_oauth_token and _get_conn_params method"""
         mock_conn_param.return_value = CONN_PARAMS_OAUTH
@@ -296,7 +311,10 @@ def test_get_headers_should_support_oauth(self, mock_conn_param, mock_oauth_toke
 
     @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.HTTPBasicAuth")
     @mock.patch("requests.post")
-    @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params")
+    @mock.patch(
+        "airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params",
+        new_callable=PropertyMock,
+    )
     def test_get_oauth_token(self, mock_conn_param, requests_post, mock_auth):
         """Test get_oauth_token method makes the right http request"""
         BASIC_AUTH = {"Authorization": "Basic usernamepassword"}
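new_callable=PropertyMock is needed once _get_conn_params is a cached_property: a plain mock.patch would install a function-style MagicMock as the class attribute, so plain attribute access would hand the test the mock object itself instead of the stubbed dict. A minimal reproduction of the pattern:

    from functools import cached_property
    from unittest import mock
    from unittest.mock import PropertyMock

    class Hook:
        @cached_property
        def _get_conn_params(self) -> dict:
            raise RuntimeError("would read a real Airflow connection")

    with mock.patch.object(Hook, "_get_conn_params", new_callable=PropertyMock) as m:
        m.return_value = {"account": "test", "region": "eu"}
        assert Hook()._get_conn_params == {"account": "test", "region": "eu"}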
diff --git a/tests/providers/snowflake/operators/test_snowflake_sql.py b/tests/providers/snowflake/operators/test_snowflake_sql.py
index a79064e341994..87d77ca813df7 100644
--- a/tests/providers/snowflake/operators/test_snowflake_sql.py
+++ b/tests/providers/snowflake/operators/test_snowflake_sql.py
@@ -17,7 +17,8 @@
 # under the License.
 from __future__ import annotations
 
-from unittest.mock import MagicMock, call, patch
+from unittest import mock
+from unittest.mock import MagicMock, patch
 
 import pytest
 from _pytest.outcomes import importorskip
@@ -37,8 +38,6 @@ def Row(*args, **kwargs):
         ColumnLineageDatasetFacet,
         ColumnLineageDatasetFacetFieldsAdditional,
         ColumnLineageDatasetFacetFieldsAdditionalInputFields,
-        SchemaDatasetFacet,
-        SchemaField,
         SqlJobFacet,
     )
     from openlineage.client.run import Dataset
@@ -163,7 +162,9 @@ def test_exec_success(sql, return_last, split_statement, hook_results, hook_desc
     )
 
 
-def test_execute_openlineage_events():
+@mock.patch("airflow.providers.openlineage.utils.utils.should_use_external_connection")
+def test_execute_openlineage_events(should_use_external_connection):
+    should_use_external_connection.return_value = False
     DB_NAME = "DATABASE"
     DB_SCHEMA_NAME = "PUBLIC"
 
@@ -174,9 +175,6 @@ class SnowflakeHookForTests(SnowflakeHook):
         get_conn = MagicMock(name="conn")
         get_connection = MagicMock()
 
-        def get_first(self, *_):
-            return [f"{DB_NAME}.{DB_SCHEMA_NAME}"]
-
     dbapi_hook = SnowflakeHookForTests()
 
     class SnowflakeOperatorForTest(SnowflakeOperator):
@@ -185,7 +183,7 @@ def get_db_hook(self):
 
     sql = (
         "INSERT INTO Test_table\n"
-        "SELECT t1.*, t2.additional_constant FROM ANOTHER_db.another_schema.popular_orders_day_of_week t1\n"
+        "SELECT t1.*, t2.additional_constant FROM ANOTHER_DB.ANOTHER_SCHEMA.popular_orders_day_of_week t1\n"
         "JOIN little_table t2 ON t1.order_day_of_week = t2.order_day_of_week;\n"
         "FORGOT TO COMMENT"
     )
@@ -223,6 +221,7 @@ def get_db_hook(self):
     dbapi_hook.get_connection.return_value = Connection(
         conn_id="snowflake_default",
         conn_type="snowflake",
+        schema="PUBLIC",
         extra={
             "account": "test_account",
             "region": "us-east",
@@ -233,55 +232,17 @@ def get_db_hook(self):
     dbapi_hook.get_conn.return_value.cursor.return_value.fetchall.side_effect = rows
 
     lineage = op.get_openlineage_facets_on_start()
-    assert dbapi_hook.get_conn.return_value.cursor.return_value.execute.mock_calls == [
-        call(
-            "SELECT database.information_schema.columns.table_schema, database.information_schema.columns.table_name, "
-            "database.information_schema.columns.column_name, database.information_schema.columns.ordinal_position, "
-            "database.information_schema.columns.data_type, database.information_schema.columns.table_catalog \n"
-            "FROM database.information_schema.columns \n"
-            "WHERE database.information_schema.columns.table_name IN ('LITTLE_TABLE') "
-            "UNION ALL "
-            "SELECT another_db.information_schema.columns.table_schema, another_db.information_schema.columns.table_name, "
-            "another_db.information_schema.columns.column_name, another_db.information_schema.columns.ordinal_position, "
-            "another_db.information_schema.columns.data_type, another_db.information_schema.columns.table_catalog \n"
-            "FROM another_db.information_schema.columns \n"
-            "WHERE another_db.information_schema.columns.table_schema = 'ANOTHER_SCHEMA' "
-            "AND another_db.information_schema.columns.table_name IN ('POPULAR_ORDERS_DAY_OF_WEEK')"
-        ),
-        call(
-            "SELECT database.information_schema.columns.table_schema, database.information_schema.columns.table_name, "
-            "database.information_schema.columns.column_name, database.information_schema.columns.ordinal_position, "
-            "database.information_schema.columns.data_type, database.information_schema.columns.table_catalog \n"
-            "FROM database.information_schema.columns \n"
-            "WHERE database.information_schema.columns.table_name IN ('TEST_TABLE')"
-        ),
-    ]
+    # Not calling Snowflake
+    assert dbapi_hook.get_conn.return_value.cursor.return_value.execute.mock_calls == []
 
     assert lineage.inputs == [
         Dataset(
             namespace="snowflake://test_account.us-east.aws",
-            name=f"{ANOTHER_DB_NAME}.{ANOTHER_DB_SCHEMA}.POPULAR_ORDERS_DAY_OF_WEEK",
-            facets={
-                "schema": SchemaDatasetFacet(
-                    fields=[
-                        SchemaField(name="ORDER_DAY_OF_WEEK", type="TEXT"),
-                        SchemaField(name="ORDER_PLACED_ON", type="TIMESTAMP_NTZ"),
-                        SchemaField(name="ORDERS_PLACED", type="NUMBER"),
-                    ]
-                )
-            },
+            name=f"{DB_NAME}.{DB_SCHEMA_NAME}.LITTLE_TABLE",
         ),
         Dataset(
             namespace="snowflake://test_account.us-east.aws",
-            name=f"{DB_NAME}.{DB_SCHEMA_NAME}.LITTLE_TABLE",
-            facets={
-                "schema": SchemaDatasetFacet(
-                    fields=[
-                        SchemaField(name="ORDER_DAY_OF_WEEK", type="TEXT"),
-                        SchemaField(name="ADDITIONAL_CONSTANT", type="TEXT"),
-                    ]
-                )
-            },
+            name=f"{ANOTHER_DB_NAME}.{ANOTHER_DB_SCHEMA}.POPULAR_ORDERS_DAY_OF_WEEK",
         ),
     ]
     assert lineage.outputs == [
@@ -289,14 +250,6 @@ def get_db_hook(self):
             namespace="snowflake://test_account.us-east.aws",
             name=f"{DB_NAME}.{DB_SCHEMA_NAME}.TEST_TABLE",
             facets={
-                "schema": SchemaDatasetFacet(
-                    fields=[
-                        SchemaField(name="ORDER_DAY_OF_WEEK", type="TEXT"),
-                        SchemaField(name="ORDER_PLACED_ON", type="TIMESTAMP_NTZ"),
-                        SchemaField(name="ORDERS_PLACED", type="NUMBER"),
-                        SchemaField(name="ADDITIONAL_CONSTANT", type="TEXT"),
-                    ]
-                ),
                 "columnLineage": ColumnLineageDatasetFacet(
                     fields={
                         "additional_constant": ColumnLineageDatasetFacetFieldsAdditional(