diff --git a/sqeleton/databases/databricks.py b/sqeleton/databases/databricks.py index bd03b2c..fff3d90 100644 --- a/sqeleton/databases/databricks.py +++ b/sqeleton/databases/databricks.py @@ -128,9 +128,9 @@ def query_table_schema(self, path: DbPath) -> Dict[str, tuple]: conn = self.create_connection() - schema, table = self._normalize_table_path(path) + catalog, schema, table = self._normalize_table_path(path) with conn.cursor() as cursor: - cursor.columns(catalog_name=self.catalog, schema_name=schema, table_name=table) + cursor.columns(catalog_name=catalog, schema_name=schema, table_name=table) try: rows = cursor.fetchall() finally: @@ -185,3 +185,15 @@ def parse_table_name(self, name: str) -> DbPath: @property def is_autocommit(self) -> bool: return True + + def _normalize_table_path(self, path: DbPath) -> DbPath: + if len(path) == 1: + return self.catalog, self.default_schema, path[0] + elif len(path) == 2: + return self.catalog, path[0], path[1] + elif len(path) == 3: + return path + + raise ValueError( + f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected format: table, schema.table, or catalog.schema.table" + )