From 7c0dbcf9d7af3d6aa61fd11c8a294474a3cd1339 Mon Sep 17 00:00:00 2001 From: Jakub Dardzinski Date: Tue, 4 Jul 2023 13:26:14 +0200 Subject: [PATCH 1/3] Add support for both database and schema in PostgresHook. Signed-off-by: Jakub Dardzinski --- airflow/providers/postgres/hooks/postgres.py | 20 ++++- generated/provider_dependencies.json | 3 +- .../postgres/operators/test_postgres.py | 79 ++++++++++++++++++- 3 files changed, 99 insertions(+), 3 deletions(-) diff --git a/airflow/providers/postgres/hooks/postgres.py b/airflow/providers/postgres/hooks/postgres.py index 2e76029028da8..a8bdf502f156a 100644 --- a/airflow/providers/postgres/hooks/postgres.py +++ b/airflow/providers/postgres/hooks/postgres.py @@ -21,7 +21,7 @@ import warnings from contextlib import closing from copy import deepcopy -from typing import Any, Iterable, Union +from typing import TYPE_CHECKING, Any, Iterable, Union import psycopg2 import psycopg2.extensions @@ -33,6 +33,9 @@ from airflow.models.connection import Connection from airflow.providers.common.sql.hooks.sql import DbApiHook +if TYPE_CHECKING: + from airflow.providers.openlineage.sqlparser import DatabaseInfo + CursorType = Union[DictCursor, RealDictCursor, NamedTupleCursor] @@ -317,3 +320,18 @@ def _generate_insert_sql( sql += f"{on_conflict_str} DO NOTHING" return sql + + def get_openlineage_database_info(self, connection) -> DatabaseInfo: + from airflow.providers.openlineage.sqlparser import DatabaseInfo + + return DatabaseInfo( + scheme="postgres", + authority=DbApiHook.get_openlineage_authority_part(connection), + database=self.database or connection.schema, + ) + + def get_openlineage_database_dialect(self, _) -> str: + return "postgres" + + def get_openlineage_default_schema(self) -> str | None: + return self.get_first("SELECT CURRENT_SCHEMA;")[0] diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index f4eeb8a273415..2ce63928cd9c2 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -710,7 +710,8 @@ ], "cross-providers-deps": [ "amazon", - "common.sql" + "common.sql", + "openlineage" ], "excluded-python-versions": [] }, diff --git a/tests/providers/postgres/operators/test_postgres.py b/tests/providers/postgres/operators/test_postgres.py index b58abf3497dc2..9f467dcf7de9e 100644 --- a/tests/providers/postgres/operators/test_postgres.py +++ b/tests/providers/postgres/operators/test_postgres.py @@ -20,6 +20,7 @@ import pytest from airflow.models.dag import DAG +from airflow.providers.postgres.hooks.postgres import PostgresHook from airflow.providers.postgres.operators.postgres import PostgresOperator from airflow.utils import timezone @@ -38,7 +39,6 @@ def setup_method(self): def teardown_method(self): tables_to_drop = ["test_postgres_to_postgres", "test_airflow"] - from airflow.providers.postgres.hooks.postgres import PostgresHook with PostgresHook().get_conn() as conn: with conn.cursor() as cur: @@ -113,3 +113,80 @@ def test_runtime_parameter_setting(self): ) op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) assert op.get_db_hook().get_first("SHOW statement_timeout;")[0] == "3s" + + +@pytest.mark.backend("postgres") +class TestPostgresOpenLineage: + custom_schemas = ["another_schema"] + + def setup_method(self): + args = {"owner": "airflow", "start_date": DEFAULT_DATE} + dag = DAG(TEST_DAG_ID, default_args=args) + self.dag = dag + + with PostgresHook().get_conn() as conn: + with conn.cursor() as cur: + for schema in self.custom_schemas: + cur.execute(f"CREATE SCHEMA {schema}") + + def teardown_method(self): + tables_to_drop = ["test_postgres_to_postgres", "test_airflow"] + + with PostgresHook().get_conn() as conn: + with conn.cursor() as cur: + for table in tables_to_drop: + cur.execute(f"DROP TABLE IF EXISTS {table}") + for schema in self.custom_schemas: + cur.execute(f"DROP SCHEMA {schema} CASCADE") + + def test_postgres_operator_openlineage_implicit_schema(self): + sql = """ + CREATE TABLE IF NOT EXISTS test_airflow ( + dummy VARCHAR(50) + ); + """ + op = PostgresOperator( + task_id="basic_postgres", + sql=sql, + dag=self.dag, + hook_params={"options": "-c search_path=another_schema"}, + ) + + lineage = op.get_openlineage_facets_on_start() + assert len(lineage.inputs) == 0 + assert len(lineage.outputs) == 0 + op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) + + # OpenLineage provider runs same method on complete by default + lineage_on_complete = op.get_openlineage_facets_on_start() + assert len(lineage_on_complete.inputs) == 0 + assert len(lineage_on_complete.outputs) == 1 + assert lineage_on_complete.outputs[0].namespace == "postgres://postgres:None" + assert lineage_on_complete.outputs[0].name == "airflow.another_schema.test_airflow" + assert "schema" in lineage_on_complete.outputs[0].facets + + def test_postgres_operator_openlineage_explicit_schema(self): + sql = """ + CREATE TABLE IF NOT EXISTS public.test_airflow ( + dummy VARCHAR(50) + ); + """ + op = PostgresOperator( + task_id="basic_postgres", + sql=sql, + dag=self.dag, + hook_params={"options": "-c search_path=another_schema"}, + ) + + lineage = op.get_openlineage_facets_on_start() + assert len(lineage.inputs) == 0 + assert len(lineage.outputs) == 0 + op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) + + # OpenLineage provider runs same method on complete by default + lineage_on_complete = op.get_openlineage_facets_on_start() + assert len(lineage_on_complete.inputs) == 0 + assert len(lineage_on_complete.outputs) == 1 + assert lineage_on_complete.outputs[0].namespace == "postgres://postgres:None" + assert lineage_on_complete.outputs[0].name == "airflow.public.test_airflow" + assert "schema" in lineage_on_complete.outputs[0].facets From b060d026c77b689114ee8ff740352fafdcc4ac69 Mon Sep 17 00:00:00 2001 From: Jakub Dardzinski Date: Tue, 4 Jul 2023 23:30:06 +0200 Subject: [PATCH 2/3] Add docstrings. Fix selective checks. Signed-off-by: Jakub Dardzinski --- airflow/providers/postgres/hooks/postgres.py | 3 +++ dev/breeze/tests/test_selective_checks.py | 8 ++++---- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/airflow/providers/postgres/hooks/postgres.py b/airflow/providers/postgres/hooks/postgres.py index a8bdf502f156a..475690ec06fff 100644 --- a/airflow/providers/postgres/hooks/postgres.py +++ b/airflow/providers/postgres/hooks/postgres.py @@ -322,6 +322,7 @@ def _generate_insert_sql( return sql def get_openlineage_database_info(self, connection) -> DatabaseInfo: + """Returns Postgres specific information for OpenLineage.""" from airflow.providers.openlineage.sqlparser import DatabaseInfo return DatabaseInfo( @@ -331,7 +332,9 @@ def get_openlineage_database_info(self, connection) -> DatabaseInfo: ) def get_openlineage_database_dialect(self, _) -> str: + """Returns postgres dialect.""" return "postgres" def get_openlineage_default_schema(self) -> str | None: + """Returns current schema. This is usually changed with ``SEARCH_PATH`` parameter.""" return self.get_first("SELECT CURRENT_SCHEMA;")[0] diff --git a/dev/breeze/tests/test_selective_checks.py b/dev/breeze/tests/test_selective_checks.py index 7437642e4e9b6..148b485cd3db1 100644 --- a/dev/breeze/tests/test_selective_checks.py +++ b/dev/breeze/tests/test_selective_checks.py @@ -98,7 +98,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str): "tests/providers/postgres/file.py", ), { - "affected-providers-list-as-string": "amazon common.sql google postgres", + "affected-providers-list-as-string": "amazon common.sql google openlineage postgres", "all-python-versions": "['3.8']", "all-python-versions-list-as-string": "3.8", "python-versions": "['3.8']", @@ -110,7 +110,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str): "docs-build": "true", "upgrade-to-newer-dependencies": "false", "parallel-test-types-list-as-string": "Providers[amazon] " - "API Always Providers[common.sql,postgres] Providers[google]", + "API Always Providers[common.sql,openlineage,postgres] Providers[google]", }, id="API and providers tests and docs should run", ) @@ -164,7 +164,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str): "tests/providers/postgres/file.py", ), { - "affected-providers-list-as-string": "amazon common.sql google postgres", + "affected-providers-list-as-string": "amazon common.sql google openlineage postgres", "all-python-versions": "['3.8']", "all-python-versions-list-as-string": "3.8", "python-versions": "['3.8']", @@ -177,7 +177,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str): "run-kubernetes-tests": "true", "upgrade-to-newer-dependencies": "false", "parallel-test-types-list-as-string": "Providers[amazon] " - "Always Providers[common.sql,postgres] Providers[google]", + "Always Providers[common.sql,openlineage,postgres] Providers[google]", }, id="Helm tests, providers (both upstream and downstream)," "kubernetes tests and docs should run", From 2b65e355e4bc35d6dae8ae6c81fa097f2f95fa79 Mon Sep 17 00:00:00 2001 From: Jakub Dardzinski Date: Fri, 21 Jul 2023 14:08:26 +0200 Subject: [PATCH 3/3] Add OpenLineage support for Redshift connections. Signed-off-by: Jakub Dardzinski Add default port for Postgres connection. Signed-off-by: Jakub Dardzinski --- airflow/providers/postgres/hooks/postgres.py | 37 +++++++++++--- .../providers/postgres/hooks/test_postgres.py | 49 +++++++++++++++++++ .../postgres/operators/test_postgres.py | 4 +- 3 files changed, 82 insertions(+), 8 deletions(-) diff --git a/airflow/providers/postgres/hooks/postgres.py b/airflow/providers/postgres/hooks/postgres.py index 475690ec06fff..87edd3c87c1dc 100644 --- a/airflow/providers/postgres/hooks/postgres.py +++ b/airflow/providers/postgres/hooks/postgres.py @@ -322,18 +322,43 @@ def _generate_insert_sql( return sql def get_openlineage_database_info(self, connection) -> DatabaseInfo: - """Returns Postgres specific information for OpenLineage.""" + """Returns Postgres/Redshift specific information for OpenLineage.""" from airflow.providers.openlineage.sqlparser import DatabaseInfo + is_redshift = connection.extra_dejson.get("redshift", False) + + if is_redshift: + authority = self._get_openlineage_redshift_authority_part(connection) + else: + authority = DbApiHook.get_openlineage_authority_part(connection, default_port=5432) + return DatabaseInfo( - scheme="postgres", - authority=DbApiHook.get_openlineage_authority_part(connection), + scheme="postgres" if not is_redshift else "redshift", + authority=authority, database=self.database or connection.schema, ) - def get_openlineage_database_dialect(self, _) -> str: - """Returns postgres dialect.""" - return "postgres" + def _get_openlineage_redshift_authority_part(self, connection) -> str: + try: + from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook + except ImportError: + from airflow.exceptions import AirflowException + + raise AirflowException( + "apache-airflow-providers-amazon not installed, run: " + "pip install 'apache-airflow-providers-postgres[amazon]'." + ) + aws_conn_id = connection.extra_dejson.get("aws_conn_id", "aws_default") + + port = connection.port or 5439 + cluster_identifier = connection.extra_dejson.get("cluster-identifier", connection.host.split(".")[0]) + region_name = AwsBaseHook(aws_conn_id=aws_conn_id).region_name + + return f"{cluster_identifier}.{region_name}:{port}" + + def get_openlineage_database_dialect(self, connection) -> str: + """Returns postgres/redshift dialect.""" + return "redshift" if connection.extra_dejson.get("redshift", False) else "postgres" def get_openlineage_default_schema(self) -> str | None: """Returns current schema. This is usually changed with ``SEARCH_PATH`` parameter.""" diff --git a/tests/providers/postgres/hooks/test_postgres.py b/tests/providers/postgres/hooks/test_postgres.py index a6fa619634c45..d90b0def66722 100644 --- a/tests/providers/postgres/hooks/test_postgres.py +++ b/tests/providers/postgres/hooks/test_postgres.py @@ -269,6 +269,55 @@ def test_schema_kwarg_database_kwarg_compatibility(self): hook = PostgresHook(schema=database) assert hook.database == database + @mock.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook") + @pytest.mark.parametrize("aws_conn_id", [NOTSET, None, "mock_aws_conn"]) + @pytest.mark.parametrize("port", [5432, 5439, None]) + @pytest.mark.parametrize( + "host,conn_cluster_identifier,expected_host", + [ + ( + "cluster-identifier.ccdfre4hpd39h.us-east-1.redshift.amazonaws.com", + NOTSET, + "cluster-identifier.us-east-1", + ), + ( + "cluster-identifier.ccdfre4hpd39h.us-east-1.redshift.amazonaws.com", + "different-identifier", + "different-identifier.us-east-1", + ), + ], + ) + def test_openlineage_methods_with_redshift( + self, + mock_aws_hook_class, + aws_conn_id, + port, + host, + conn_cluster_identifier, + expected_host, + ): + mock_conn_extra = { + "iam": True, + "redshift": True, + } + if aws_conn_id is not NOTSET: + mock_conn_extra["aws_conn_id"] = aws_conn_id + if conn_cluster_identifier is not NOTSET: + mock_conn_extra["cluster-identifier"] = conn_cluster_identifier + + self.connection.extra = json.dumps(mock_conn_extra) + self.connection.host = host + self.connection.port = port + + # Mock AWS Connection + mock_aws_hook_instance = mock_aws_hook_class.return_value + mock_aws_hook_instance.region_name = "us-east-1" + + assert ( + self.db_hook._get_openlineage_redshift_authority_part(self.connection) + == f"{expected_host}:{port or 5439}" + ) + @pytest.mark.backend("postgres") class TestPostgresHook: diff --git a/tests/providers/postgres/operators/test_postgres.py b/tests/providers/postgres/operators/test_postgres.py index 9f467dcf7de9e..5bf8ee09360f6 100644 --- a/tests/providers/postgres/operators/test_postgres.py +++ b/tests/providers/postgres/operators/test_postgres.py @@ -161,7 +161,7 @@ def test_postgres_operator_openlineage_implicit_schema(self): lineage_on_complete = op.get_openlineage_facets_on_start() assert len(lineage_on_complete.inputs) == 0 assert len(lineage_on_complete.outputs) == 1 - assert lineage_on_complete.outputs[0].namespace == "postgres://postgres:None" + assert lineage_on_complete.outputs[0].namespace == "postgres://postgres:5432" assert lineage_on_complete.outputs[0].name == "airflow.another_schema.test_airflow" assert "schema" in lineage_on_complete.outputs[0].facets @@ -187,6 +187,6 @@ def test_postgres_operator_openlineage_explicit_schema(self): lineage_on_complete = op.get_openlineage_facets_on_start() assert len(lineage_on_complete.inputs) == 0 assert len(lineage_on_complete.outputs) == 1 - assert lineage_on_complete.outputs[0].namespace == "postgres://postgres:None" + assert lineage_on_complete.outputs[0].namespace == "postgres://postgres:5432" assert lineage_on_complete.outputs[0].name == "airflow.public.test_airflow" assert "schema" in lineage_on_complete.outputs[0].facets