diff --git a/airflow/providers/postgres/hooks/postgres.py b/airflow/providers/postgres/hooks/postgres.py
index 2e76029028da8..87edd3c87c1dc 100644
--- a/airflow/providers/postgres/hooks/postgres.py
+++ b/airflow/providers/postgres/hooks/postgres.py
@@ -21,7 +21,7 @@
 import warnings
 from contextlib import closing
 from copy import deepcopy
-from typing import Any, Iterable, Union
+from typing import TYPE_CHECKING, Any, Iterable, Union
 
 import psycopg2
 import psycopg2.extensions
@@ -33,6 +33,9 @@
 from airflow.models.connection import Connection
 from airflow.providers.common.sql.hooks.sql import DbApiHook
 
+if TYPE_CHECKING:
+    from airflow.providers.openlineage.sqlparser import DatabaseInfo
+
 CursorType = Union[DictCursor, RealDictCursor, NamedTupleCursor]
 
 
@@ -317,3 +320,46 @@ def _generate_insert_sql(
             sql += f"{on_conflict_str} DO NOTHING"
 
         return sql
+
+    def get_openlineage_database_info(self, connection) -> DatabaseInfo:
+        """Returns Postgres/Redshift specific information for OpenLineage."""
+        from airflow.providers.openlineage.sqlparser import DatabaseInfo
+
+        is_redshift = connection.extra_dejson.get("redshift", False)
+
+        if is_redshift:
+            authority = self._get_openlineage_redshift_authority_part(connection)
+        else:
+            authority = DbApiHook.get_openlineage_authority_part(connection, default_port=5432)
+
+        return DatabaseInfo(
+            scheme="postgres" if not is_redshift else "redshift",
+            authority=authority,
+            database=self.database or connection.schema,
+        )
+
+    def _get_openlineage_redshift_authority_part(self, connection) -> str:
+        try:
+            from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+        except ImportError:
+            from airflow.exceptions import AirflowException
+
+            raise AirflowException(
+                "apache-airflow-providers-amazon not installed, run: "
+                "pip install 'apache-airflow-providers-postgres[amazon]'."
+            )
+        aws_conn_id = connection.extra_dejson.get("aws_conn_id", "aws_default")
+
+        port = connection.port or 5439
+        cluster_identifier = connection.extra_dejson.get("cluster-identifier", connection.host.split(".")[0])
+        region_name = AwsBaseHook(aws_conn_id=aws_conn_id).region_name
+
+        return f"{cluster_identifier}.{region_name}:{port}"
+
+    def get_openlineage_database_dialect(self, connection) -> str:
+        """Returns postgres/redshift dialect."""
+        return "redshift" if connection.extra_dejson.get("redshift", False) else "postgres"
+
+    def get_openlineage_default_schema(self) -> str | None:
+        """Returns current schema. This is usually changed with ``search_path`` parameter."""
+        return self.get_first("SELECT CURRENT_SCHEMA;")[0]
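For context, a rough usage sketch of the new hook methods (not part of the patch; the connection values are illustrative, and it assumes `apache-airflow-providers-openlineage` is installed so the in-method import succeeds). The `redshift` extra is what flips both the dialect and the authority handling:

```python
from airflow.models.connection import Connection
from airflow.providers.postgres.hooks.postgres import PostgresHook

hook = PostgresHook()

# Plain Postgres: authority comes straight from host/port, falling back to 5432.
pg_conn = Connection(conn_id="pg", conn_type="postgres", host="db.example.com", schema="airflow")
info = hook.get_openlineage_database_info(pg_conn)
# expected: info.scheme == "postgres", info.authority == "db.example.com:5432",
#           info.database == "airflow" (hook.database is unset, so connection.schema wins)

# With extra={"redshift": true}, scheme becomes "redshift" and the authority is
# rebuilt as "<cluster-identifier>.<region>:<port>" by
# _get_openlineage_redshift_authority_part (which needs the amazon provider).
```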
diff --git a/dev/breeze/tests/test_selective_checks.py b/dev/breeze/tests/test_selective_checks.py
index 7437642e4e9b6..148b485cd3db1 100644
--- a/dev/breeze/tests/test_selective_checks.py
+++ b/dev/breeze/tests/test_selective_checks.py
@@ -98,7 +98,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
                 "tests/providers/postgres/file.py",
             ),
             {
-                "affected-providers-list-as-string": "amazon common.sql google postgres",
+                "affected-providers-list-as-string": "amazon common.sql google openlineage postgres",
                 "all-python-versions": "['3.8']",
                 "all-python-versions-list-as-string": "3.8",
                 "python-versions": "['3.8']",
@@ -110,7 +110,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
                 "docs-build": "true",
                 "upgrade-to-newer-dependencies": "false",
                 "parallel-test-types-list-as-string": "Providers[amazon] "
-                "API Always Providers[common.sql,postgres] Providers[google]",
+                "API Always Providers[common.sql,openlineage,postgres] Providers[google]",
             },
             id="API and providers tests and docs should run",
         )
@@ -164,7 +164,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
                 "tests/providers/postgres/file.py",
             ),
             {
-                "affected-providers-list-as-string": "amazon common.sql google postgres",
+                "affected-providers-list-as-string": "amazon common.sql google openlineage postgres",
                 "all-python-versions": "['3.8']",
                 "all-python-versions-list-as-string": "3.8",
                 "python-versions": "['3.8']",
@@ -177,7 +177,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
                 "run-kubernetes-tests": "true",
                 "upgrade-to-newer-dependencies": "false",
                 "parallel-test-types-list-as-string": "Providers[amazon] "
-                "Always Providers[common.sql,postgres] Providers[google]",
+                "Always Providers[common.sql,openlineage,postgres] Providers[google]",
             },
             id="Helm tests, providers (both upstream and downstream),"
             "kubernetes tests and docs should run",
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index f4eeb8a273415..2ce63928cd9c2 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -710,7 +710,8 @@
     ],
     "cross-providers-deps": [
       "amazon",
-      "common.sql"
+      "common.sql",
+      "openlineage"
     ],
     "excluded-python-versions": []
   },
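The hook tests that follow pin down the Redshift authority format. Distilled from the hook code above, the derivation is: take the cluster identifier from the `cluster-identifier` extra (or the first label of the host), append the region resolved from the AWS connection, then the port with a 5439 default. A standalone restatement with made-up values (the real code resolves `region_name` via `AwsBaseHook`):

```python
# Standalone sketch of _get_openlineage_redshift_authority_part's logic; values are made up.
extras = {"redshift": True}  # no explicit "cluster-identifier" key
host = "my-cluster.ccdfre4hpd39h.us-east-1.redshift.amazonaws.com"
connection_port = None  # Connection.port unset falls back to Redshift's default

cluster_identifier = extras.get("cluster-identifier", host.split(".")[0])
region_name = "us-east-1"  # AwsBaseHook(aws_conn_id=...).region_name in the hook
port = connection_port or 5439

authority = f"{cluster_identifier}.{region_name}:{port}"
assert authority == "my-cluster.us-east-1:5439"
```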
diff --git a/tests/providers/postgres/hooks/test_postgres.py b/tests/providers/postgres/hooks/test_postgres.py
index a6fa619634c45..d90b0def66722 100644
--- a/tests/providers/postgres/hooks/test_postgres.py
+++ b/tests/providers/postgres/hooks/test_postgres.py
@@ -269,6 +269,55 @@ def test_schema_kwarg_database_kwarg_compatibility(self):
         hook = PostgresHook(schema=database)
         assert hook.database == database
 
+    @mock.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook")
+    @pytest.mark.parametrize("aws_conn_id", [NOTSET, None, "mock_aws_conn"])
+    @pytest.mark.parametrize("port", [5432, 5439, None])
+    @pytest.mark.parametrize(
+        "host,conn_cluster_identifier,expected_host",
+        [
+            (
+                "cluster-identifier.ccdfre4hpd39h.us-east-1.redshift.amazonaws.com",
+                NOTSET,
+                "cluster-identifier.us-east-1",
+            ),
+            (
+                "cluster-identifier.ccdfre4hpd39h.us-east-1.redshift.amazonaws.com",
+                "different-identifier",
+                "different-identifier.us-east-1",
+            ),
+        ],
+    )
+    def test_openlineage_methods_with_redshift(
+        self,
+        mock_aws_hook_class,
+        aws_conn_id,
+        port,
+        host,
+        conn_cluster_identifier,
+        expected_host,
+    ):
+        mock_conn_extra = {
+            "iam": True,
+            "redshift": True,
+        }
+        if aws_conn_id is not NOTSET:
+            mock_conn_extra["aws_conn_id"] = aws_conn_id
+        if conn_cluster_identifier is not NOTSET:
+            mock_conn_extra["cluster-identifier"] = conn_cluster_identifier
+
+        self.connection.extra = json.dumps(mock_conn_extra)
+        self.connection.host = host
+        self.connection.port = port
+
+        # Mock AWS Connection
+        mock_aws_hook_instance = mock_aws_hook_class.return_value
+        mock_aws_hook_instance.region_name = "us-east-1"
+
+        assert (
+            self.db_hook._get_openlineage_redshift_authority_part(self.connection)
+            == f"{expected_host}:{port or 5439}"
+        )
+
 
 @pytest.mark.backend("postgres")
 class TestPostgresHook:
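The operator tests below assert OpenLineage dataset identifiers of the form `<database>.<schema>.<table>` under a `<scheme>://<authority>` namespace. A small helper showing how the asserted strings compose (hypothetical, not a function from this patch or the openlineage provider):

```python
# Illustrative composition of the identifiers asserted in the tests below.
def dataset_identity(scheme: str, authority: str, database: str, schema: str, table: str):
    namespace = f"{scheme}://{authority}"   # e.g. "postgres://postgres:5432"
    name = f"{database}.{schema}.{table}"   # e.g. "airflow.public.test_airflow"
    return namespace, name

assert dataset_identity("postgres", "postgres:5432", "airflow", "public", "test_airflow") == (
    "postgres://postgres:5432",
    "airflow.public.test_airflow",
)
```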
diff --git a/tests/providers/postgres/operators/test_postgres.py b/tests/providers/postgres/operators/test_postgres.py
index b58abf3497dc2..5bf8ee09360f6 100644
--- a/tests/providers/postgres/operators/test_postgres.py
+++ b/tests/providers/postgres/operators/test_postgres.py
@@ -20,6 +20,7 @@
 import pytest
 
 from airflow.models.dag import DAG
+from airflow.providers.postgres.hooks.postgres import PostgresHook
 from airflow.providers.postgres.operators.postgres import PostgresOperator
 from airflow.utils import timezone
 
@@ -38,7 +39,6 @@ def setup_method(self):
 
     def teardown_method(self):
         tables_to_drop = ["test_postgres_to_postgres", "test_airflow"]
-        from airflow.providers.postgres.hooks.postgres import PostgresHook
 
         with PostgresHook().get_conn() as conn:
             with conn.cursor() as cur:
@@ -113,3 +113,80 @@ def test_runtime_parameter_setting(self):
         )
         op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
         assert op.get_db_hook().get_first("SHOW statement_timeout;")[0] == "3s"
+
+
+@pytest.mark.backend("postgres")
+class TestPostgresOpenLineage:
+    custom_schemas = ["another_schema"]
+
+    def setup_method(self):
+        args = {"owner": "airflow", "start_date": DEFAULT_DATE}
+        dag = DAG(TEST_DAG_ID, default_args=args)
+        self.dag = dag
+
+        with PostgresHook().get_conn() as conn:
+            with conn.cursor() as cur:
+                for schema in self.custom_schemas:
+                    cur.execute(f"CREATE SCHEMA {schema}")
+
+    def teardown_method(self):
+        tables_to_drop = ["test_postgres_to_postgres", "test_airflow"]
+
+        with PostgresHook().get_conn() as conn:
+            with conn.cursor() as cur:
+                for table in tables_to_drop:
+                    cur.execute(f"DROP TABLE IF EXISTS {table}")
+                for schema in self.custom_schemas:
+                    cur.execute(f"DROP SCHEMA {schema} CASCADE")
+
+    def test_postgres_operator_openlineage_implicit_schema(self):
+        sql = """
+        CREATE TABLE IF NOT EXISTS test_airflow (
+            dummy VARCHAR(50)
+        );
+        """
+        op = PostgresOperator(
+            task_id="basic_postgres",
+            sql=sql,
+            dag=self.dag,
+            hook_params={"options": "-c search_path=another_schema"},
+        )
+
+        lineage = op.get_openlineage_facets_on_start()
+        assert len(lineage.inputs) == 0
+        assert len(lineage.outputs) == 0
+        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+
+        # OpenLineage provider runs the same method on complete by default
+        lineage_on_complete = op.get_openlineage_facets_on_start()
+        assert len(lineage_on_complete.inputs) == 0
+        assert len(lineage_on_complete.outputs) == 1
+        assert lineage_on_complete.outputs[0].namespace == "postgres://postgres:5432"
+        assert lineage_on_complete.outputs[0].name == "airflow.another_schema.test_airflow"
+        assert "schema" in lineage_on_complete.outputs[0].facets
+
+    def test_postgres_operator_openlineage_explicit_schema(self):
+        sql = """
+        CREATE TABLE IF NOT EXISTS public.test_airflow (
+            dummy VARCHAR(50)
+        );
+        """
+        op = PostgresOperator(
+            task_id="basic_postgres",
+            sql=sql,
+            dag=self.dag,
+            hook_params={"options": "-c search_path=another_schema"},
+        )
+
+        lineage = op.get_openlineage_facets_on_start()
+        assert len(lineage.inputs) == 0
+        assert len(lineage.outputs) == 0
+        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+
+        # OpenLineage provider runs the same method on complete by default
+        lineage_on_complete = op.get_openlineage_facets_on_start()
+        assert len(lineage_on_complete.inputs) == 0
+        assert len(lineage_on_complete.outputs) == 1
+        assert lineage_on_complete.outputs[0].namespace == "postgres://postgres:5432"
+        assert lineage_on_complete.outputs[0].name == "airflow.public.test_airflow"
+        assert "schema" in lineage_on_complete.outputs[0].facets
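The implicit-schema test hinges on `get_openlineage_default_schema()` reading `CURRENT_SCHEMA`, which tracks the session's `search_path`. A minimal sketch of that assumption, reusing the same `options` value the test passes via `hook_params` (requires a live Postgres with an `another_schema` schema; a sketch, not part of the patch):

```python
from airflow.providers.postgres.hooks.postgres import PostgresHook

# Session-level search_path set the same way the test's hook_params does it.
hook = PostgresHook(options="-c search_path=another_schema")

# get_openlineage_default_schema() is just get_first("SELECT CURRENT_SCHEMA;")[0],
# so the schema OpenLineage reports follows the session's search_path.
print(hook.get_openlineage_default_schema())  # expected: "another_schema"
```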