diff --git a/docs/intro.md b/docs/intro.md index 25dddf9..f703f9b 100644 --- a/docs/intro.md +++ b/docs/intro.md @@ -463,6 +463,8 @@ List of available abstract mixins: - `AbstractMixin_TimeTravel` - Only snowflake & bigquery +- `AbstractMixin_OptimizerHints` - Only oracle & mysql + More will be added in the future. Note that it's still possible to use user-defined mixins that aren't on this list. diff --git a/sqeleton/abcs/mixins.py b/sqeleton/abcs/mixins.py index 34d3250..bf1f277 100644 --- a/sqeleton/abcs/mixins.py +++ b/sqeleton/abcs/mixins.py @@ -145,3 +145,15 @@ def time_travel( Must specify exactly one of `timestamp`, `offset` or `statement`. """ + +class AbstractMixin_OptimizerHints(AbstractMixin): + @abstractmethod + def optimizer_hints( + self, + optimizer_hints: str + ) -> str: + """Creates a compatible optimizer_hints string + + Parameters: + optimizer_hints - string of optimizer hints + """ \ No newline at end of file diff --git a/sqeleton/databases/base.py b/sqeleton/databases/base.py index ea705ab..9d1b6c3 100644 --- a/sqeleton/databases/base.py +++ b/sqeleton/databases/base.py @@ -36,7 +36,12 @@ Boolean, ) from ..abcs.mixins import Compilable -from ..abcs.mixins import AbstractMixin_Schema, AbstractMixin_RandomSample, AbstractMixin_NormalizeValue +from ..abcs.mixins import ( + AbstractMixin_Schema, + AbstractMixin_RandomSample, + AbstractMixin_NormalizeValue, + AbstractMixin_OptimizerHints +) from ..bound_exprs import bound_table logger = logging.getLogger("database") @@ -134,6 +139,14 @@ def random_sample_ratio_approx(self, tbl: AbstractTable, ratio: float) -> Abstra return tbl.where(Random() < ratio) +class Mixin_OptimizerHints(AbstractMixin_OptimizerHints): + def optimizer_hints( + self, + hints: str + ) -> str: + return f"/*+ {hints} */ " + + class BaseDialect(AbstractDialect): SUPPORTS_PRIMARY_KEY = False SUPPORTS_INDEXES = False @@ -168,6 +181,7 @@ def current_timestamp(self) -> str: def explain_as_text(self, query: str) -> str: return f"EXPLAIN 
{query}" + def _constant_value(self, v): if v is None: return "NULL" diff --git a/sqeleton/databases/databricks.py b/sqeleton/databases/databricks.py index fff3d90..585a418 100644 --- a/sqeleton/databases/databricks.py +++ b/sqeleton/databases/databricks.py @@ -95,6 +95,7 @@ def set_timezone_to_utc(self) -> str: return "SET TIME ZONE 'UTC'" + class Databricks(ThreadedDatabase): dialect = Dialect() CONNECT_URI_HELP = "databricks://:@/" diff --git a/sqeleton/databases/mysql.py b/sqeleton/databases/mysql.py index 527282b..fd4bc29 100644 --- a/sqeleton/databases/mysql.py +++ b/sqeleton/databases/mysql.py @@ -17,7 +17,7 @@ AbstractMixin_Regex, AbstractMixin_RandomSample, ) -from .base import ThreadedDatabase, import_helper, ConnectError, BaseDialect, Compilable +from .base import Mixin_OptimizerHints, ThreadedDatabase, import_helper, ConnectError, BaseDialect, Compilable from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, TIMESTAMP_PRECISION_POS, Mixin_Schema, Mixin_RandomSample from ..queries.ast_classes import BinBoolOp @@ -54,7 +54,7 @@ def test_regex(self, string: Compilable, pattern: Compilable) -> Compilable: return BinBoolOp("REGEXP", [string, pattern]) -class Dialect(BaseDialect, Mixin_Schema): +class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints): name = "MySQL" ROUNDS_ON_PREC_LOSS = True SUPPORTS_PRIMARY_KEY = True @@ -109,6 +109,9 @@ def type_repr(self, t) -> str: def explain_as_text(self, query: str) -> str: return f"EXPLAIN FORMAT=TREE {query}" + def optimizer_hints(self, s: str): + return f"/*+ {s} */ " + def set_timezone_to_utc(self) -> str: return "SET @@session.time_zone='+00:00'" diff --git a/sqeleton/databases/oracle.py b/sqeleton/databases/oracle.py index b3da12a..79b5ec3 100644 --- a/sqeleton/databases/oracle.py +++ b/sqeleton/databases/oracle.py @@ -17,7 +17,7 @@ from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue, AbstractMixin_Schema from ..abcs import Compilable from ..queries import this, table, SKIP -from .base 
import BaseDialect, ThreadedDatabase, import_helper, ConnectError, QueryError, Mixin_RandomSample +from .base import BaseDialect, Mixin_OptimizerHints, ThreadedDatabase, import_helper, ConnectError, QueryError, Mixin_RandomSample from .base import TIMESTAMP_PRECISION_POS SESSION_TIME_ZONE = None # Changed by the tests @@ -72,7 +72,7 @@ def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: ) -class Dialect(BaseDialect, Mixin_Schema): +class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints): name = "Oracle" SUPPORTS_PRIMARY_KEY = True SUPPORTS_INDEXES = True diff --git a/sqeleton/databases/snowflake.py b/sqeleton/databases/snowflake.py index ebc9bb8..5edafc0 100644 --- a/sqeleton/databases/snowflake.py +++ b/sqeleton/databases/snowflake.py @@ -136,6 +136,9 @@ def table_information(self) -> Compilable: def set_timezone_to_utc(self) -> str: return "ALTER SESSION SET TIMEZONE = 'UTC'" + def optimizer_hints(self, hints: str) -> str: + raise NotImplementedError("Optimizer hints not yet implemented in snowflake") + class Snowflake(Database): dialect = Dialect() diff --git a/sqeleton/queries/ast_classes.py b/sqeleton/queries/ast_classes.py index 5888e56..5f676f7 100644 --- a/sqeleton/queries/ast_classes.py +++ b/sqeleton/queries/ast_classes.py @@ -91,14 +91,14 @@ class ITable(AbstractTable): source_table: Any schema: Schema = None - def select(self, *exprs, distinct=SKIP, **named_exprs): + def select(self, *exprs, distinct=SKIP, optimizer_hints=SKIP, **named_exprs): """Create a new table with the specified fields""" exprs = args_as_tuple(exprs) exprs = _drop_skips(exprs) named_exprs = _drop_skips_dict(named_exprs) exprs += _named_exprs_as_aliases(named_exprs) resolve_names(self.source_table, exprs) - return Select.make(self, columns=exprs, distinct=distinct) + return Select.make(self, columns=exprs, distinct=distinct, optimizer_hints=optimizer_hints) def where(self, *exprs): exprs = args_as_tuple(exprs) @@ -682,6 +682,7 @@ class 
Select(ExprNode, ITable, Root): having_exprs: Sequence[Expr] = None limit_expr: int = None distinct: bool = False + optimizer_hints: Sequence[Expr] = None @property def schema(self): @@ -699,7 +700,8 @@ def compile(self, parent_c: Compiler) -> str: columns = ", ".join(map(c.compile, self.columns)) if self.columns else "*" distinct = "DISTINCT " if self.distinct else "" - select = f"SELECT {distinct}{columns}" + optimizer_hints = c.dialect.optimizer_hints(self.optimizer_hints) if self.optimizer_hints else "" + select = f"SELECT {optimizer_hints}{distinct}{columns}" if self.table: select += " FROM " + c.compile(self.table) @@ -729,15 +731,19 @@ def compile(self, parent_c: Compiler) -> str: return select @classmethod - def make(cls, table: ITable, distinct: bool = SKIP, **kwargs): + def make(cls, table: ITable, distinct: bool = SKIP, optimizer_hints: str = SKIP, **kwargs): assert "table" not in kwargs if not isinstance(table, cls): # If not Select if distinct is not SKIP: kwargs["distinct"] = distinct + if optimizer_hints is not SKIP: + kwargs["optimizer_hints"] = optimizer_hints return cls(table, **kwargs) # We can safely assume isinstance(table, Select) + if optimizer_hints is not SKIP: + kwargs["optimizer_hints"] = optimizer_hints if distinct is not SKIP: if distinct == False and table.distinct: @@ -752,7 +758,7 @@ def make(cls, table: ITable, distinct: bool = SKIP, **kwargs): if getattr(table, k) is not None: if k == "where_exprs": # Additive attribute kwargs[k] = getattr(table, k) + v - elif k == "distinct": + elif k in ["distinct", "optimizer_hints"]: pass else: raise ValueError(k) diff --git a/tests/test_query.py b/tests/test_query.py index 9330092..09469cd 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -16,6 +16,7 @@ def normalize_spaces(s: str): class MockDialect(AbstractDialect): name = "MockDialect" + PLACEHOLDER_TABLE = None ROUNDS_ON_PREC_LOSS = False def quote(self, s: str) -> str: @@ -50,6 +51,9 @@ def timestamp_value(self, t: datetime) 
-> str: def set_timezone_to_utc(self) -> str: return "set timezone 'UTC'" + def optimizer_hints(self, s: str): + return f"/*+ {s} */ " + def load_mixins(self): raise NotImplementedError() @@ -189,6 +193,22 @@ def test_select_distinct(self): q = c.compile(t.select(this.b, distinct=True).select(distinct=False)) self.assertEqual(q, "SELECT * FROM (SELECT DISTINCT b FROM a) tmp2") + def test_select_with_optimizer_hints(self): + c = Compiler(MockDatabase()) + t = table("a") + + q = c.compile(t.select(this.b, optimizer_hints='PARALLEL(a 16)')) + assert q == "SELECT /*+ PARALLEL(a 16) */ b FROM a" + + q = c.compile(t.where(this.b > 10).select(this.b, optimizer_hints='PARALLEL(a 16)')) + self.assertEqual(q, "SELECT /*+ PARALLEL(a 16) */ b FROM a WHERE (b > 10)") + + q = c.compile(t.limit(10).select(this.b, optimizer_hints='PARALLEL(a 16)')) + self.assertEqual(q, "SELECT /*+ PARALLEL(a 16) */ b FROM (SELECT * FROM a LIMIT 10) tmp1") + + q = c.compile(t.select(this.a).group_by(this.b).agg(this.c).select(optimizer_hints='PARALLEL(a 16)')) + self.assertEqual(q, "SELECT /*+ PARALLEL(a 16) */ * FROM (SELECT b, c FROM (SELECT a FROM a) tmp2 GROUP BY 1) tmp3") + def test_table_ops(self): c = Compiler(MockDatabase()) a = table("a").select(this.x)