Skip to content
This repository was archived by the owner on May 2, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/intro.md
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,8 @@ List of available abstract mixins:

- `AbstractMixin_TimeTravel` - Only snowflake & bigquery

- `AbstractMixin_OptimizerHints` - Only oracle & mysql

More will be added in the future.

Note that it's still possible to use user-defined mixins that aren't on this list.
Expand Down
10 changes: 10 additions & 0 deletions sqeleton/abcs/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,13 @@ def time_travel(

Must specify exactly one of `timestamp`, `offset` or `statement`.
"""


class AbstractMixin_OptimizerHints(AbstractMixin):
    @abstractmethod
    def optimizer_hints(self, optimizer_hints: str) -> str:
        """Build the dialect-compatible text for a set of optimizer hints.

        Parameters:
            optimizer_hints - the optimizer hints, given as a single string

        Returns:
            A string carrying those hints in the form the dialect accepts.
        """
12 changes: 11 additions & 1 deletion sqeleton/databases/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,12 @@
Boolean,
)
from ..abcs.mixins import Compilable
from ..abcs.mixins import AbstractMixin_Schema, AbstractMixin_RandomSample, AbstractMixin_NormalizeValue
from ..abcs.mixins import (
AbstractMixin_Schema,
AbstractMixin_RandomSample,
AbstractMixin_NormalizeValue,
AbstractMixin_OptimizerHints,
)
from ..bound_exprs import bound_table

logger = logging.getLogger("database")
Expand Down Expand Up @@ -134,6 +139,11 @@ def random_sample_ratio_approx(self, tbl: AbstractTable, ratio: float) -> Abstra
return tbl.where(Random() < ratio)


class Mixin_OptimizerHints(AbstractMixin_OptimizerHints):
    """Optimizer-hints mixin that renders hints as a ``/*+ ... */`` comment."""

    def optimizer_hints(self, hints: str) -> str:
        """Return `hints` wrapped as ``/*+ <hints> */`` followed by one space."""
        rendered = f"/*+ {hints} */ "
        return rendered


class BaseDialect(AbstractDialect):
SUPPORTS_PRIMARY_KEY = False
SUPPORTS_INDEXES = False
Expand Down
7 changes: 5 additions & 2 deletions sqeleton/databases/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
AbstractMixin_Regex,
AbstractMixin_RandomSample,
)
from .base import ThreadedDatabase, import_helper, ConnectError, BaseDialect, Compilable
from .base import Mixin_OptimizerHints, ThreadedDatabase, import_helper, ConnectError, BaseDialect, Compilable
from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, TIMESTAMP_PRECISION_POS, Mixin_Schema, Mixin_RandomSample
from ..queries.ast_classes import BinBoolOp

Expand Down Expand Up @@ -54,7 +54,7 @@ def test_regex(self, string: Compilable, pattern: Compilable) -> Compilable:
return BinBoolOp("REGEXP", [string, pattern])


class Dialect(BaseDialect, Mixin_Schema):
class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints):
name = "MySQL"
ROUNDS_ON_PREC_LOSS = True
SUPPORTS_PRIMARY_KEY = True
Expand Down Expand Up @@ -109,6 +109,9 @@ def type_repr(self, t) -> str:
def explain_as_text(self, query: str) -> str:
    """Prefix `query` with ``EXPLAIN FORMAT=TREE`` and return the result."""
    explained = f"EXPLAIN FORMAT=TREE {query}"
    return explained

def optimizer_hints(self, s: str) -> str:
    """Wrap the hint text `s` in a ``/*+ ... */`` comment.

    Parameters:
        s - string of optimizer hints

    Returns:
        ``/*+ <s> */ `` including the trailing space, so the caller can
        concatenate it directly in front of the rest of the statement.
    """
    # NOTE(review): this produces the same text as
    # Mixin_OptimizerHints.optimizer_hints, which this Dialect also lists as a
    # base class; the override looks redundant - confirm before removing.
    return f"/*+ {s} */ "

def set_timezone_to_utc(self) -> str:
    """Return the statement that sets the session time zone to ``+00:00``."""
    statement = "SET @@session.time_zone='+00:00'"
    return statement

Expand Down
12 changes: 10 additions & 2 deletions sqeleton/databases/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,15 @@
from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue, AbstractMixin_Schema
from ..abcs import Compilable
from ..queries import this, table, SKIP
from .base import BaseDialect, ThreadedDatabase, import_helper, ConnectError, QueryError, Mixin_RandomSample
from .base import (
BaseDialect,
Mixin_OptimizerHints,
ThreadedDatabase,
import_helper,
ConnectError,
QueryError,
Mixin_RandomSample,
)
from .base import TIMESTAMP_PRECISION_POS

SESSION_TIME_ZONE = None # Changed by the tests
Expand Down Expand Up @@ -72,7 +80,7 @@ def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable:
)


class Dialect(BaseDialect, Mixin_Schema):
class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints):
name = "Oracle"
SUPPORTS_PRIMARY_KEY = True
SUPPORTS_INDEXES = True
Expand Down
3 changes: 3 additions & 0 deletions sqeleton/databases/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,9 @@ def table_information(self) -> Compilable:
def set_timezone_to_utc(self) -> str:
    """Return the ``ALTER SESSION`` statement that sets the time zone to UTC."""
    statement = "ALTER SESSION SET TIMEZONE = 'UTC'"
    return statement

def optimizer_hints(self, hints: str) -> str:
    """Always raise: optimizer hints are not implemented for this dialect.

    Raises:
        NotImplementedError - on every call, regardless of `hints`.
    """
    message = "Optimizer hints not yet implemented in snowflake"
    raise NotImplementedError(message)


class Snowflake(Database):
dialect = Dialect()
Expand Down
16 changes: 11 additions & 5 deletions sqeleton/queries/ast_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,14 @@ class ITable(AbstractTable):
source_table: Any
schema: Schema = None

def select(self, *exprs, distinct=SKIP, optimizer_hints=SKIP, **named_exprs):
    """Create a new table with the specified fields

    Parameters:
        exprs - positional column expressions; SKIP entries are dropped
        distinct - forwarded unchanged to Select.make
        optimizer_hints - forwarded unchanged to Select.make
        named_exprs - keyword column expressions, appended after `exprs`
            as aliases; SKIP entries are dropped
    """
    columns = _drop_skips(args_as_tuple(exprs))
    # Named expressions come after the positional ones, converted to aliases.
    columns += _named_exprs_as_aliases(_drop_skips_dict(named_exprs))
    resolve_names(self.source_table, columns)
    return Select.make(self, columns=columns, distinct=distinct, optimizer_hints=optimizer_hints)

def where(self, *exprs):
exprs = args_as_tuple(exprs)
Expand Down Expand Up @@ -682,6 +682,7 @@ class Select(ExprNode, ITable, Root):
having_exprs: Sequence[Expr] = None
limit_expr: int = None
distinct: bool = False
optimizer_hints: Sequence[Expr] = None

@property
def schema(self):
Expand All @@ -699,7 +700,8 @@ def compile(self, parent_c: Compiler) -> str:

columns = ", ".join(map(c.compile, self.columns)) if self.columns else "*"
distinct = "DISTINCT " if self.distinct else ""
select = f"SELECT {distinct}{columns}"
optimizer_hints = c.dialect.optimizer_hints(self.optimizer_hints) if self.optimizer_hints else ""
select = f"SELECT {optimizer_hints}{distinct}{columns}"

if self.table:
select += " FROM " + c.compile(self.table)
Expand Down Expand Up @@ -729,15 +731,19 @@ def compile(self, parent_c: Compiler) -> str:
return select

@classmethod
def make(cls, table: ITable, distinct: bool = SKIP, **kwargs):
def make(cls, table: ITable, distinct: bool = SKIP, optimizer_hints: str = SKIP, **kwargs):
assert "table" not in kwargs

if not isinstance(table, cls): # If not Select
if distinct is not SKIP:
kwargs["distinct"] = distinct
if optimizer_hints is not SKIP:
kwargs["optimizer_hints"] = optimizer_hints
return cls(table, **kwargs)

# We can safely assume isinstance(table, Select)
if optimizer_hints is not SKIP:
kwargs["optimizer_hints"] = optimizer_hints

if distinct is not SKIP:
if distinct == False and table.distinct:
Expand All @@ -752,7 +758,7 @@ def make(cls, table: ITable, distinct: bool = SKIP, **kwargs):
if getattr(table, k) is not None:
if k == "where_exprs": # Additive attribute
kwargs[k] = getattr(table, k) + v
elif k == "distinct":
elif k in ["distinct", "optimizer_hints"]:
pass
else:
raise ValueError(k)
Expand Down
2 changes: 1 addition & 1 deletion sqeleton/repl.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def repl(uri):
continue
try:
path = db.parse_table_name(table_name)
print('->', path)
print("->", path)
schema = db.query_table_schema(path)
except Exception as e:
logging.error(e)
Expand Down
15 changes: 13 additions & 2 deletions sqeleton/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,16 @@
from typing import Iterable, Iterator, MutableMapping, Union, Any, Sequence, Dict, Hashable, TypeVar, TYPE_CHECKING, List
from typing import (
Iterable,
Iterator,
MutableMapping,
Union,
Any,
Sequence,
Dict,
Hashable,
TypeVar,
TYPE_CHECKING,
List,
)
from abc import abstractmethod
from weakref import ref
import math
Expand Down Expand Up @@ -256,7 +268,6 @@ def __eq__(self, other):
return NotImplemented
return self._str == other._str


def new(self, *args, **kw):
    """Build another instance of this object's own type.

    All positional and keyword arguments are passed through to the
    constructor, together with ``max_len`` taken from ``self._max_len``.
    """
    cls = type(self)
    return cls(*args, **kw, max_len=self._max_len)

Expand Down
22 changes: 22 additions & 0 deletions tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def normalize_spaces(s: str):
class MockDialect(AbstractDialect):
name = "MockDialect"

PLACEHOLDER_TABLE = None
ROUNDS_ON_PREC_LOSS = False

def quote(self, s: str) -> str:
Expand Down Expand Up @@ -50,6 +51,9 @@ def timestamp_value(self, t: datetime) -> str:
def set_timezone_to_utc(self) -> str:
    """Return the mock dialect's fixed statement for switching to UTC."""
    statement = "set timezone 'UTC'"
    return statement

def optimizer_hints(self, s: str) -> str:
    """Wrap the hint text `s` as ``/*+ <s> */ `` (with a trailing space).

    Parameters:
        s - string of optimizer hints
    """
    return f"/*+ {s} */ "

def load_mixins(self):
    """Not supported by the mock dialect; always raises NotImplementedError."""
    error = NotImplementedError()
    raise error

Expand Down Expand Up @@ -189,6 +193,24 @@ def test_select_distinct(self):
q = c.compile(t.select(this.b, distinct=True).select(distinct=False))
self.assertEqual(q, "SELECT * FROM (SELECT DISTINCT b FROM a) tmp2")

def test_select_with_optimizer_hints(self):
    """Check that optimizer hints are emitted immediately after SELECT.

    Covers a plain select, a select over a filtered table, a select over a
    limited table, and a bare select over a grouped/aggregated query.
    """
    c = Compiler(MockDatabase())
    t = table("a")
    hints = "PARALLEL(a 16)"

    # Plain select on a table.
    q = c.compile(t.select(this.b, optimizer_hints=hints))
    self.assertEqual(q, "SELECT /*+ PARALLEL(a 16) */ b FROM a")

    # Select on top of a WHERE clause.
    q = c.compile(t.where(this.b > 10).select(this.b, optimizer_hints=hints))
    self.assertEqual(q, "SELECT /*+ PARALLEL(a 16) */ b FROM a WHERE (b > 10)")

    # Select on top of a LIMIT: the hint goes on the outer SELECT.
    q = c.compile(t.limit(10).select(this.b, optimizer_hints=hints))
    self.assertEqual(q, "SELECT /*+ PARALLEL(a 16) */ b FROM (SELECT * FROM a LIMIT 10) tmp1")

    # Select with no columns on top of a group-by/aggregate.
    q = c.compile(t.select(this.a).group_by(this.b).agg(this.c).select(optimizer_hints=hints))
    self.assertEqual(
        q, "SELECT /*+ PARALLEL(a 16) */ * FROM (SELECT b, c FROM (SELECT a FROM a) tmp2 GROUP BY 1) tmp3"
    )

def test_table_ops(self):
c = Compiler(MockDatabase())
a = table("a").select(this.x)
Expand Down