From 04ab11b32d1629049053efc5c1eb6f8ecdfc35d5 Mon Sep 17 00:00:00 2001 From: Dan Date: Thu, 23 Feb 2023 11:26:04 -0700 Subject: [PATCH 1/2] squash add pg 3 part id support --- sqeleton/databases/postgresql.py | 25 +++++++++++++++++++++++++ tests/test_database.py | 27 +++++++++++++++++++++++++++ 2 files changed, 52 insertions(+) diff --git a/sqeleton/databases/postgresql.py b/sqeleton/databases/postgresql.py index 7c7b47d..a7034f1 100644 --- a/sqeleton/databases/postgresql.py +++ b/sqeleton/databases/postgresql.py @@ -1,4 +1,5 @@ from ..abcs.database_types import ( + DbPath, Timestamp, TimestampTZ, Float, @@ -122,3 +123,27 @@ def create_connection(self): return c except pg.OperationalError as e: raise ConnectError(*e.args) from e + + def select_table_schema(self, path: DbPath) -> str: + database, schema, table = self._normalize_table_path(path) + + info_schema_path = ["information_schema", "columns"] + if database: + info_schema_path.insert(0, database) + + return ( + f"SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale FROM {'.'.join(info_schema_path)} " + f"WHERE table_name = '{table}' AND table_schema = '{schema}'" + ) + + def _normalize_table_path(self, path: DbPath) -> DbPath: + if len(path) == 1: + return None, self.default_schema, path[0] + elif len(path) == 2: + return None, path[0], path[1] + elif len(path) == 3: + return path + + raise ValueError( + f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected format: table, schema.table, or database.schema.table" + ) diff --git a/tests/test_database.py b/tests/test_database.py index 069f75d..f04a205 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -80,3 +80,30 @@ def test_current_timestamp(self): db = get_conn(self.db_cls) res = db.query(current_timestamp(), datetime) assert isinstance(res, datetime), (res, type(res)) + + +@test_each_database +class TestThreePartIds(unittest.TestCase): + def test_three_part_support(self): + if self.db_cls not in [dbs.PostgreSQL, dbs.Redshift, dbs.Snowflake]: + self.skipTest('Limited support for 3 part ids') + + table_name = "tbl_" + random_table_suffix() + db = get_conn(self.db_cls) + db_res = db.query("SELECT CURRENT_DATABASE()") + schema_res = db.query("SELECT CURRENT_SCHEMA()") + db_name = db_res.rows[0][0] + schema_name = schema_res.rows[0][0] + + table_one_part = table((table_name,), schema={"id": int}) + table_two_part = table((schema_name, table_name), schema={"id": int}) + table_three_part = table((db_name, schema_name, table_name), schema={"id": int}) + + db.query(table_one_part.create()) + db.query(table_one_part.drop()) + + db.query(table_two_part.create()) + db.query(table_two_part.drop()) + + db.query(table_three_part.create()) + db.query(table_three_part.drop()) From 62a26056c2297c5039c2d570ad06b9916efd9129 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Fri, 24 Feb 2023 16:20:41 +0100 Subject: [PATCH 2/2] Small improvement to tests --- tests/test_database.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/tests/test_database.py b/tests/test_database.py index f04a205..fb6c074 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -86,7 +86,7 @@ def test_current_timestamp(self): class TestThreePartIds(unittest.TestCase): def test_three_part_support(self): if self.db_cls not in [dbs.PostgreSQL, dbs.Redshift, dbs.Snowflake]: - self.skipTest('Limited support for 3 part ids') + self.skipTest("Limited support for 3 part ids") table_name = "tbl_" + random_table_suffix() db = get_conn(self.db_cls) @@ -99,11 +99,8 @@ def test_three_part_support(self): table_two_part = table((schema_name, table_name), schema={"id": int}) table_three_part = table((db_name, schema_name, table_name), schema={"id": int}) - db.query(table_one_part.create()) - db.query(table_one_part.drop()) - - db.query(table_two_part.create()) - db.query(table_two_part.drop()) - - db.query(table_three_part.create()) - db.query(table_three_part.drop()) + for part in (table_one_part, table_two_part, table_three_part): + db.query(part.create()) + d = db.query_table_schema(part.path) + assert len(d) == 1 + db.query(part.drop())