diff --git a/sqeleton/databases/postgresql.py b/sqeleton/databases/postgresql.py index 7c7b47d..a7034f1 100644 --- a/sqeleton/databases/postgresql.py +++ b/sqeleton/databases/postgresql.py @@ -1,4 +1,5 @@ from ..abcs.database_types import ( + DbPath, Timestamp, TimestampTZ, Float, @@ -122,3 +123,27 @@ def create_connection(self): return c except pg.OperationalError as e: raise ConnectError(*e.args) from e + + def select_table_schema(self, path: DbPath) -> str: + database, schema, table = self._normalize_table_path(path) + + info_schema_path = ["information_schema", "columns"] + if database: + info_schema_path.insert(0, database) + + return ( + f"SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale FROM {'.'.join(info_schema_path)} " + f"WHERE table_name = '{table}' AND table_schema = '{schema}'" + ) + + def _normalize_table_path(self, path: DbPath) -> DbPath: + if len(path) == 1: + return None, self.default_schema, path[0] + elif len(path) == 2: + return None, path[0], path[1] + elif len(path) == 3: + return path + + raise ValueError( + f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected format: table, schema.table, or database.schema.table" + ) diff --git a/tests/test_database.py b/tests/test_database.py index 069f75d..fb6c074 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -80,3 +80,27 @@ def test_current_timestamp(self): db = get_conn(self.db_cls) res = db.query(current_timestamp(), datetime) assert isinstance(res, datetime), (res, type(res)) + + +@test_each_database +class TestThreePartIds(unittest.TestCase): + def test_three_part_support(self): + if self.db_cls not in [dbs.PostgreSQL, dbs.Redshift, dbs.Snowflake]: + self.skipTest("Limited support for 3 part ids") + + table_name = "tbl_" + random_table_suffix() + db = get_conn(self.db_cls) + db_res = db.query("SELECT CURRENT_DATABASE()") + schema_res = db.query("SELECT CURRENT_SCHEMA()") + db_name = db_res.rows[0][0] + schema_name = schema_res.rows[0][0] + + table_one_part = table((table_name,), schema={"id": int}) + table_two_part = table((schema_name, table_name), schema={"id": int}) + table_three_part = table((db_name, schema_name, table_name), schema={"id": int}) + + for part in (table_one_part, table_two_part, table_three_part): + db.query(part.create()) + d = db.query_table_schema(part.path) + assert len(d) == 1 + db.query(part.drop())