diff --git a/README.md b/README.md index 29abb35..27c6751 100644 --- a/README.md +++ b/README.md @@ -51,6 +51,11 @@ Databases we support: - DuckDB >=0.6 - SQLite (coming soon) + +### Documentation + +[Read the docs!](https://sqeleton.readthedocs.io) + ### Basic usage ```python diff --git a/docs/intro.md b/docs/intro.md index a1c10c1..8543ceb 100644 --- a/docs/intro.md +++ b/docs/intro.md @@ -505,6 +505,72 @@ ddb: AbstractDatabase[NewAbstractDialect] = connect("duckdb://:memory:") # ddb.dialect is now known to implement NewAbstractDialect. ``` +### Query interpreter + +In addition to query expressions, `Database.query()` can accept a generator, which will behave as an "interpreter". + +The generator executes queries by yielding them. + +Using a query interpreter also guarantees that subsequent calls to `.query()` will run in the same session. That can be useful for using temporary tables, or session variables. + +Example: + +```python +def sample_using_temp_table(db: Database, source_table: ITable, sample_size: int): + "This function creates a temporary table from a query and then samples rows from it" + + results = [] + + def _sample_using_temp_table(): + nonlocal results + + yield code("CREATE TEMPORARY TABLE tmp1 AS {source_table}", source_table=source_table) + + tbl = table('tmp1') + try: + results += yield sample(tbl, sample_size) + finally: + yield tbl.drop() + + db.query(_sample_using_temp_table()) + return results +``` + ### Query params -### Query interpreter \ No newline at end of file +TODO + +## Other features + +### SQL client + +Sqeleton comes with a simple built-in SQL client, in the form of a REPL, which accepts SQL commands, and a few special commands. + +It accepts any database URL that is supported by Sqeleton. That can be useful for querying databases that don't have established clients. + +You can call it using `sqeleton repl `. + +Example: + +```bash +# Start a REPL session +$ sqeleton repl duckdb:///pii_test.ddb + +# Run SQL +DuckDB> select (22::float / 7) as almost_pi +┏━━━━━━━━━━━━━━━━━━━┓ +┃ almost_pi ┃ +┡━━━━━━━━━━━━━━━━━━━┩ +│ 3.142857074737549 │ +└───────────────────┘ + 1 rows + +# Display help +DuckDB> ? + +Commands: + ?mytable - shows schema of table 'mytable' + * - shows list of all tables + *pattern - shows list of all tables with name like pattern +Otherwise, runs regular SQL query +``` \ No newline at end of file diff --git a/sqeleton/databases/base.py b/sqeleton/databases/base.py index 7dbe912..294e228 100644 --- a/sqeleton/databases/base.py +++ b/sqeleton/databases/base.py @@ -10,6 +10,8 @@ from uuid import UUID import decimal +from runtype import dataclass + from ..utils import is_uuid, safezip, Self from ..queries import Expr, Compiler, table, Select, SKIP, Explain, Code, this from ..queries.ast_classes import Random @@ -265,6 +267,21 @@ class _DialectWithMixins(cls, *mixins, *abstract_mixins): T = TypeVar("T", bound=BaseDialect) +@dataclass +class QueryResult: + rows: list + columns: list = None + + def __iter__(self): + return iter(self.rows) + + def __len__(self): + return len(self.rows) + + def __getitem__(self, i): + return self.rows[i] + + class Database(AbstractDatabase[T]): """Base abstract class for databases. @@ -473,7 +490,8 @@ def _query_cursor(self, c, sql_code: str): try: c.execute(sql_code) if sql_code.lower().startswith(("select", "explain", "show")): - return c.fetchall() + columns = [col[0] for col in c.description] + return QueryResult(c.fetchall(), columns) except Exception as _e: # logger.exception(e) # logger.error(f'Caused by SQL: {sql_code}') @@ -519,7 +537,7 @@ def set_conn(self): assert not hasattr(self.thread_local, "conn") try: self.thread_local.conn = self.create_connection() - except ModuleNotFoundError as e: + except Exception as e: self._init_error = e def _query(self, sql_code: Union[str, ThreadLocalInterpreter]): diff --git a/sqeleton/databases/clickhouse.py b/sqeleton/databases/clickhouse.py index 8ff292c..5e9326b 100644 --- a/sqeleton/databases/clickhouse.py +++ b/sqeleton/databases/clickhouse.py @@ -163,7 +163,7 @@ def current_timestamp(self) -> str: class Clickhouse(ThreadedDatabase): dialect = Dialect() - CONNECT_URI_HELP = "clickhouse://:@/" + CONNECT_URI_HELP = "clickhouse://:@/" CONNECT_URI_PARAMS = ["database?"] def __init__(self, *, thread_count: int, **kw): diff --git a/sqeleton/databases/duckdb.py b/sqeleton/databases/duckdb.py index f6c6b77..07ae4f8 100644 --- a/sqeleton/databases/duckdb.py +++ b/sqeleton/databases/duckdb.py @@ -141,7 +141,7 @@ class DuckDB(Database): dialect = Dialect() SUPPORTS_UNIQUE_CONSTAINT = False # Temporary, until we implement it default_schema = "main" - CONNECT_URI_HELP = "duckdb://@" + CONNECT_URI_HELP = "duckdb://@" CONNECT_URI_PARAMS = ["database", "dbpath"] def __init__(self, **kw): diff --git a/sqeleton/databases/mysql.py b/sqeleton/databases/mysql.py index ab39778..522d599 100644 --- a/sqeleton/databases/mysql.py +++ b/sqeleton/databases/mysql.py @@ -109,7 +109,7 @@ class MySQL(ThreadedDatabase): dialect = Dialect() SUPPORTS_ALPHANUMS = False SUPPORTS_UNIQUE_CONSTAINT = True - CONNECT_URI_HELP = "mysql://:@/" + CONNECT_URI_HELP = "mysql://:@/" CONNECT_URI_PARAMS = ["database?"] def __init__(self, *, thread_count, **kw): diff --git a/sqeleton/databases/oracle.py b/sqeleton/databases/oracle.py index b95366b..b3da12a 100644 --- a/sqeleton/databases/oracle.py +++ b/sqeleton/databases/oracle.py @@ -160,7 +160,7 @@ def current_timestamp(self) -> str: class Oracle(ThreadedDatabase): dialect = Dialect() - CONNECT_URI_HELP = "oracle://:@/" + CONNECT_URI_HELP = "oracle://:@/" CONNECT_URI_PARAMS = ["database?"] def __init__(self, *, host, database, thread_count, **kw): diff --git a/sqeleton/databases/postgresql.py b/sqeleton/databases/postgresql.py index ecf07d0..8d1be75 100644 --- a/sqeleton/databases/postgresql.py +++ b/sqeleton/databases/postgresql.py @@ -98,12 +98,13 @@ def current_timestamp(self) -> str: class PostgreSQL(ThreadedDatabase): dialect = PostgresqlDialect() SUPPORTS_UNIQUE_CONSTAINT = True - CONNECT_URI_HELP = "postgresql://:@/" + CONNECT_URI_HELP = "postgresql://:@/" CONNECT_URI_PARAMS = ["database?"] default_schema = "public" def __init__(self, *, thread_count, **kw): + print("###", kw) self._args = kw super().__init__(thread_count=thread_count) diff --git a/sqeleton/databases/redshift.py b/sqeleton/databases/redshift.py index e44847c..eb74d36 100644 --- a/sqeleton/databases/redshift.py +++ b/sqeleton/databases/redshift.py @@ -60,7 +60,7 @@ def is_distinct_from(self, a: str, b: str) -> str: class Redshift(PostgreSQL): dialect = Dialect() - CONNECT_URI_HELP = "redshift://:@/" + CONNECT_URI_HELP = "redshift://:@/" CONNECT_URI_PARAMS = ["database?"] def select_table_schema(self, path: DbPath) -> str: diff --git a/sqeleton/databases/snowflake.py b/sqeleton/databases/snowflake.py index 1643cdf..ebc9bb8 100644 --- a/sqeleton/databases/snowflake.py +++ b/sqeleton/databases/snowflake.py @@ -139,7 +139,7 @@ def set_timezone_to_utc(self) -> str: class Snowflake(Database): dialect = Dialect() - CONNECT_URI_HELP = "snowflake://:@//?warehouse=" + CONNECT_URI_HELP = "snowflake://:@//?warehouse=" CONNECT_URI_PARAMS = ["database", "schema"] CONNECT_URI_KWPARAMS = ["warehouse"] diff --git a/sqeleton/databases/vertica.py b/sqeleton/databases/vertica.py index c80d845..3f853ea 100644 --- a/sqeleton/databases/vertica.py +++ b/sqeleton/databases/vertica.py @@ -152,7 +152,7 @@ def current_timestamp(self) -> str: class Vertica(ThreadedDatabase): dialect = Dialect() - CONNECT_URI_HELP = "vertica://:@/" + CONNECT_URI_HELP = "vertica://:@/" CONNECT_URI_PARAMS = ["database?"] default_schema = "public" diff --git a/sqeleton/repl.py b/sqeleton/repl.py index 8a7b75a..cea1bbd 100644 --- a/sqeleton/repl.py +++ b/sqeleton/repl.py @@ -4,7 +4,6 @@ # logging.basicConfig(level=logging.DEBUG) from . import connect -from .queries import table import sys @@ -54,13 +53,14 @@ def repl(uri): else: print_table([(k, v[1]) for k, v in schema.items()], ["name", "type"], f"Table '{table_name}'") else: + # Normal SQL query try: res = db.query(q) except Exception as e: logging.error(e) else: if res: - print_table(res, [str(i) for i in range(len(res[0]))], q) + print_table(res.rows, res.columns, None) def main():