Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions sqlmesh/core/engine_adapter/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1633,6 +1633,30 @@ def _insert_overwrite_by_condition(
target_columns_to_types=target_columns_to_types,
order_projections=False,
)
elif insert_overwrite_strategy.is_merge:
columns = [exp.column(col) for col in target_columns_to_types]
when_not_matched_by_source = exp.When(
matched=False,
source=True,
condition=where,
then=exp.Delete(),
)
when_not_matched_by_target = exp.When(
matched=False,
source=False,
then=exp.Insert(
this=exp.Tuple(expressions=columns),
expression=exp.Tuple(expressions=columns),
),
)
self._merge(
target_table=table_name,
query=query,
on=exp.false(),
whens=exp.Whens(
expressions=[when_not_matched_by_source, when_not_matched_by_target]
),
)
else:
insert_exp = exp.insert(
query,
Expand Down
5 changes: 3 additions & 2 deletions sqlmesh/core/engine_adapter/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

from sqlmesh.core.dialect import to_schema
from sqlmesh.core.engine_adapter.mixins import (
InsertOverwriteWithMergeMixin,
ClusteredByMixin,
RowDiffMixin,
TableAlterClusterByOperation,
Expand All @@ -20,6 +19,7 @@
DataObjectType,
SourceQuery,
set_catalog,
InsertOverwriteStrategy,
)
from sqlmesh.core.node import IntervalUnit
from sqlmesh.core.schema_diff import TableAlterOperation, NestedSupport
Expand Down Expand Up @@ -54,7 +54,7 @@


@set_catalog()
class BigQueryEngineAdapter(InsertOverwriteWithMergeMixin, ClusteredByMixin, RowDiffMixin):
class BigQueryEngineAdapter(ClusteredByMixin, RowDiffMixin):
"""
BigQuery Engine Adapter using the `google-cloud-bigquery` library's DB API.
"""
Expand All @@ -68,6 +68,7 @@ class BigQueryEngineAdapter(InsertOverwriteWithMergeMixin, ClusteredByMixin, Row
MAX_COLUMN_COMMENT_LENGTH = 1024
SUPPORTS_QUERY_EXECUTION_TRACKING = True
SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["SCHEMA"]
INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.MERGE

SCHEMA_DIFFER_KWARGS = {
"compatible_types": {
Expand Down
29 changes: 1 addition & 28 deletions sqlmesh/core/engine_adapter/fabric.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,15 @@
from functools import cached_property
from sqlglot import exp
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_result
from sqlmesh.core.engine_adapter.mixins import LogicalMergeMixin
from sqlmesh.core.engine_adapter.mssql import MSSQLEngineAdapter
from sqlmesh.core.engine_adapter.shared import (
InsertOverwriteStrategy,
SourceQuery,
)
from sqlmesh.core.engine_adapter.base import EngineAdapter
from sqlmesh.utils.errors import SQLMeshError
from sqlmesh.utils.connection_pool import ConnectionPool


if t.TYPE_CHECKING:
from sqlmesh.core._typing import TableName


from sqlmesh.core.engine_adapter.mixins import LogicalMergeMixin

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -58,26 +51,6 @@ def _target_catalog(self) -> t.Optional[str]:
def _target_catalog(self, value: t.Optional[str]) -> None:
self._connection_pool.set_attribute("target_catalog", value)

# Fabric-specific override of the insert-overwrite entry point.
# It accepts the same arguments as the base implementation but always runs it
# with the DELETE_INSERT strategy; the caller-supplied
# `insert_overwrite_strategy_override` is accepted for signature compatibility
# and is not forwarded.
def _insert_overwrite_by_condition(
self,
table_name: TableName,
source_queries: t.List[SourceQuery],
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
where: t.Optional[exp.Condition] = None,
insert_overwrite_strategy_override: t.Optional[InsertOverwriteStrategy] = None,
**kwargs: t.Any,
) -> None:
# Override to avoid MERGE statement which isn't fully supported in Fabric
# EngineAdapter is called explicitly (not via super()) so that any
# implementation between this class and EngineAdapter is bypassed.
return EngineAdapter._insert_overwrite_by_condition(
self,
table_name=table_name,
source_queries=source_queries,
target_columns_to_types=target_columns_to_types,
where=where,
insert_overwrite_strategy_override=InsertOverwriteStrategy.DELETE_INSERT,
**kwargs,
)

@property
def api_client(self) -> FabricHttpClient:
# the requests Session is not guaranteed to be threadsafe
Expand Down
47 changes: 0 additions & 47 deletions sqlmesh/core/engine_adapter/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from sqlglot.helper import seq_get

from sqlmesh.core.engine_adapter.base import EngineAdapter
from sqlmesh.core.engine_adapter.shared import InsertOverwriteStrategy, SourceQuery
from sqlmesh.core.node import IntervalUnit
from sqlmesh.core.dialect import schema_
from sqlmesh.core.schema_diff import TableAlterOperation
Expand Down Expand Up @@ -75,52 +74,6 @@ def _fetch_native_df(
return df


# Mixin that implements `_insert_overwrite_by_condition` by issuing one MERGE
# statement per source query instead of a native INSERT OVERWRITE.
class InsertOverwriteWithMergeMixin(EngineAdapter):
def _insert_overwrite_by_condition(
self,
table_name: TableName,
source_queries: t.List[SourceQuery],
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
where: t.Optional[exp.Condition] = None,
insert_overwrite_strategy_override: t.Optional[InsertOverwriteStrategy] = None,
**kwargs: t.Any,
) -> None:
"""
Some engines do not support `INSERT OVERWRITE` but instead support
doing an "INSERT OVERWRITE" using a Merge expression but with the
predicate being `False`.
"""
# NOTE(review): `insert_overwrite_strategy_override` and `kwargs` are not
# read anywhere in this body.
# When no column mapping is passed, use whatever self.columns() returns
# for the target table.
target_columns_to_types = target_columns_to_types or self.columns(table_name)
for source_query in source_queries:
with source_query as query:
# The source query is passed through _order_projections_and_filter
# together with the column mapping and the `where` condition.
query = self._order_projections_and_filter(
query, target_columns_to_types, where=where
)
# The same column list is used for both the INSERT column list and
# the VALUES tuple below.
columns = [exp.column(col) for col in target_columns_to_types]
# WHEN NOT MATCHED BY SOURCE [AND <where>] THEN DELETE
when_not_matched_by_source = exp.When(
matched=False,
source=True,
condition=where,
then=exp.Delete(),
)
# WHEN NOT MATCHED THEN INSERT (<columns>) VALUES (<columns>)
when_not_matched_by_target = exp.When(
matched=False,
source=False,
then=exp.Insert(
this=exp.Tuple(expressions=columns),
expression=exp.Tuple(expressions=columns),
),
)
self._merge(
target_table=table_name,
query=query,
# Constant-false join predicate: no target row is ever "matched",
# so only the two NOT MATCHED branches above can apply.
on=exp.false(),
whens=exp.Whens(
expressions=[when_not_matched_by_source, when_not_matched_by_target]
),
)


class HiveMetastoreTablePropertiesMixin(EngineAdapter):
MAX_TABLE_COMMENT_LENGTH = 4000
MAX_COLUMN_COMMENT_LENGTH = 4000
Expand Down
3 changes: 1 addition & 2 deletions sqlmesh/core/engine_adapter/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
)
from sqlmesh.core.engine_adapter.mixins import (
GetCurrentCatalogFromFunctionMixin,
InsertOverwriteWithMergeMixin,
PandasNativeFetchDFSupportMixin,
VarcharSizeWorkaroundMixin,
RowDiffMixin,
Expand All @@ -41,7 +40,6 @@
class MSSQLEngineAdapter(
EngineAdapterWithIndexSupport,
PandasNativeFetchDFSupportMixin,
InsertOverwriteWithMergeMixin,
GetCurrentCatalogFromFunctionMixin,
VarcharSizeWorkaroundMixin,
RowDiffMixin,
Expand Down Expand Up @@ -74,6 +72,7 @@ class MSSQLEngineAdapter(
},
}
VARIABLE_LENGTH_DATA_TYPES = {"binary", "varbinary", "char", "varchar", "nchar", "nvarchar"}
INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.MERGE

@property
def catalog_support(self) -> CatalogSupport:
Expand Down
6 changes: 6 additions & 0 deletions sqlmesh/core/engine_adapter/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,8 @@ class InsertOverwriteStrategy(Enum):
# Issue a single INSERT query to replace a data range. The assumption is that the query engine will transparently match partition bounds
# and replace data rather than append to it. Trino is an example of this when `hive.insert-existing-partitions-behavior=OVERWRITE` is configured
INTO_IS_OVERWRITE = 4
# Do the INSERT OVERWRITE using merge since the engine doesn't support it natively
MERGE = 5

@property
def is_delete_insert(self) -> bool:
Expand All @@ -260,6 +262,10 @@ def is_replace_where(self) -> bool:
def is_into_is_overwrite(self) -> bool:
return self == InsertOverwriteStrategy.INTO_IS_OVERWRITE

# True when this strategy is InsertOverwriteStrategy.MERGE.
@property
def is_merge(self) -> bool:
return self == InsertOverwriteStrategy.MERGE


class SourceQuery:
def __init__(
Expand Down
6 changes: 2 additions & 4 deletions tests/core/engine_adapter/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,13 @@
from sqlmesh.core import dialect as d
from sqlmesh.core.dialect import normalize_model_name
from sqlmesh.core.engine_adapter import EngineAdapter, EngineAdapterWithIndexSupport
from sqlmesh.core.engine_adapter.mixins import InsertOverwriteWithMergeMixin
from sqlmesh.core.engine_adapter.shared import InsertOverwriteStrategy, DataObject
from sqlmesh.core.schema_diff import SchemaDiffer, TableAlterOperation, NestedSupport
from sqlmesh.utils import columns_to_types_to_struct
from sqlmesh.utils.date import to_ds
from sqlmesh.utils.errors import SQLMeshError, UnsupportedCatalogOperationError
from tests.core.engine_adapter import to_sql_calls

if t.TYPE_CHECKING:
pass

pytestmark = pytest.mark.engine

Expand Down Expand Up @@ -482,7 +479,8 @@ def test_insert_overwrite_no_where(make_mocked_engine_adapter: t.Callable):
def test_insert_overwrite_by_condition_column_contains_unsafe_characters(
make_mocked_engine_adapter: t.Callable, mocker: MockerFixture
):
adapter = make_mocked_engine_adapter(InsertOverwriteWithMergeMixin)
adapter = make_mocked_engine_adapter(EngineAdapter)
adapter.INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.MERGE

source_queries, columns_to_types = adapter._get_source_queries_and_columns_to_types(
parse_one("SELECT 1 AS c"), None, target_table="test_table"
Expand Down
41 changes: 0 additions & 41 deletions tests/core/engine_adapter/test_mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from sqlmesh.core.engine_adapter.shared import (
DataObject,
DataObjectType,
InsertOverwriteStrategy,
)
from sqlmesh.utils.date import to_ds
from tests.core.engine_adapter import to_sql_calls
Expand Down Expand Up @@ -342,46 +341,6 @@ def test_insert_overwrite_by_time_partition_supports_insert_overwrite_pandas_exi
]


# Exercises insert_overwrite_by_time_partition on the MSSQL adapter with a
# pandas DataFrame source while INSERT_OVERWRITE_STRATEGY is set to
# REPLACE_WHERE, and pins the bulk-copy call and the emitted SQL.
def test_insert_overwrite_by_time_partition_replace_where_pandas(
make_mocked_engine_adapter: t.Callable, mocker: MockerFixture, make_temp_table_name: t.Callable
):
# table_exists is patched to always report False.
mocker.patch(
"sqlmesh.core.engine_adapter.mssql.MSSQLEngineAdapter.table_exists",
return_value=False,
)

adapter = make_mocked_engine_adapter(MSSQLEngineAdapter)
adapter.INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.REPLACE_WHERE

# Pin the temp table name so the expected SQL below is deterministic.
temp_table_mock = mocker.patch("sqlmesh.core.engine_adapter.EngineAdapter._get_temp_table")
table_name = "test_table"
temp_table_id = "abcdefgh"
temp_table_mock.return_value = make_temp_table_name(table_name, temp_table_id)

df = pd.DataFrame({"a": [1, 2], "ds": ["2022-01-01", "2022-01-02"]})
adapter.insert_overwrite_by_time_partition(
table_name,
df,
start="2022-01-01",
end="2022-01-02",
time_formatter=lambda x, _: exp.Literal.string(to_ds(x)),
time_column="ds",
target_columns_to_types={
"a": exp.DataType.build("INT"),
"ds": exp.DataType.build("STRING"),
},
)
# The DataFrame rows are expected to be bulk-copied into the temp table.
adapter._connection_pool.get().bulk_copy.assert_called_with(
f"__temp_test_table_{temp_table_id}", [(1, "2022-01-01"), (2, "2022-01-02")]
)

# Expected statements: create the temp table, MERGE from it into the target
# (ON (1 = 0), delete-in-range + insert), then drop the temp table.
assert to_sql_calls(adapter) == [
f"""IF NOT EXISTS (SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = '__temp_test_table_{temp_table_id}') EXEC('CREATE TABLE [__temp_test_table_{temp_table_id}] ([a] INTEGER, [ds] VARCHAR(MAX))');""",
f"""MERGE INTO [test_table] AS [__MERGE_TARGET__] USING (SELECT [a] AS [a], [ds] AS [ds] FROM (SELECT CAST([a] AS INTEGER) AS [a], CAST([ds] AS VARCHAR(MAX)) AS [ds] FROM [__temp_test_table_{temp_table_id}]) AS [_subquery] WHERE [ds] BETWEEN '2022-01-01' AND '2022-01-02') AS [__MERGE_SOURCE__] ON (1 = 0) WHEN NOT MATCHED BY SOURCE AND [ds] BETWEEN '2022-01-01' AND '2022-01-02' THEN DELETE WHEN NOT MATCHED THEN INSERT ([a], [ds]) VALUES ([a], [ds]);""",
f"DROP TABLE IF EXISTS [__temp_test_table_{temp_table_id}];",
]
Comment on lines -345 to -382
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test didn't make sense. It was changing the MSSQL engine adapter to use REPLACE_WHERE and then showing that it did nothing. It would now do something, but that is expected. I think it was likely a copy/paste mistake from other tests, made without understanding the intent.



def test_insert_append_pandas(
make_mocked_engine_adapter: t.Callable, mocker: MockerFixture, make_temp_table_name: t.Callable
):
Expand Down