Skip to content
This repository was archived by the owner on May 2, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions sqeleton/databases/postgresql.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from ..abcs.database_types import (
DbPath,
Timestamp,
TimestampTZ,
Float,
Expand Down Expand Up @@ -122,3 +123,27 @@ def create_connection(self):
return c
except pg.OperationalError as e:
raise ConnectError(*e.args) from e

def select_table_schema(self, path: DbPath) -> str:
database, schema, table = self._normalize_table_path(path)

info_schema_path = ["information_schema", "columns"]
if database:
info_schema_path.insert(0, database)

return (
f"SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale FROM {'.'.join(info_schema_path)} "
f"WHERE table_name = '{table}' AND table_schema = '{schema}'"
)

def _normalize_table_path(self, path: DbPath) -> DbPath:
if len(path) == 1:
return None, self.default_schema, path[0]
elif len(path) == 2:
return None, path[0], path[1]
elif len(path) == 3:
return path

raise ValueError(
f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected format: table, schema.table, or database.schema.table"
)
24 changes: 24 additions & 0 deletions tests/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,27 @@ def test_current_timestamp(self):
db = get_conn(self.db_cls)
res = db.query(current_timestamp(), datetime)
assert isinstance(res, datetime), (res, type(res))


@test_each_database
class TestThreePartIds(unittest.TestCase):
def test_three_part_support(self):
if self.db_cls not in [dbs.PostgreSQL, dbs.Redshift, dbs.Snowflake]:
self.skipTest("Limited support for 3 part ids")

table_name = "tbl_" + random_table_suffix()
db = get_conn(self.db_cls)
db_res = db.query("SELECT CURRENT_DATABASE()")
schema_res = db.query("SELECT CURRENT_SCHEMA()")
db_name = db_res.rows[0][0]
schema_name = schema_res.rows[0][0]

table_one_part = table((table_name,), schema={"id": int})
table_two_part = table((schema_name, table_name), schema={"id": int})
table_three_part = table((db_name, schema_name, table_name), schema={"id": int})

for part in (table_one_part, table_two_part, table_three_part):
db.query(part.create())
d = db.query_table_schema(part.path)
assert len(d) == 1
db.query(part.drop())