Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 47 additions & 1 deletion sqlmesh/core/engine_adapter/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@
from sqlglot import exp

from sqlmesh.core.dialect import to_schema
from sqlmesh.core.engine_adapter.base import EngineAdapterWithIndexSupport
from sqlmesh.core.engine_adapter.base import (
EngineAdapterWithIndexSupport,
EngineAdapter,
InsertOverwriteStrategy,
)
from sqlmesh.core.engine_adapter.mixins import (
GetCurrentCatalogFromFunctionMixin,
InsertOverwriteWithMergeMixin,
Expand Down Expand Up @@ -281,3 +285,45 @@ def _rename_table(
# The function that renames tables in MSSQL takes string literals as arguments instead of identifiers,
# so we shouldn't quote the identifiers.
self.execute(exp.rename_table(old_table_name, new_table_name), quote_identifiers=False)

def _insert_overwrite_by_condition(
self,
table_name: TableName,
source_queries: t.List[SourceQuery],
columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
where: t.Optional[exp.Condition] = None,
insert_overwrite_strategy_override: t.Optional[InsertOverwriteStrategy] = None,
**kwargs: t.Any,
) -> None:
if not where or where == exp.true():
# this is a full table replacement, call the base strategy to do DELETE+INSERT
# which will result in TRUNCATE+INSERT due to how we have overridden self.delete_from()
return EngineAdapter._insert_overwrite_by_condition(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

super(EngineAdapter, self) threw an exception and super(EngineAdapterWithIndexSupport, self) returned the _insert_overwrite_by_condition method from InsertOverwriteWithMergeMixin instead of from EngineAdapter.

So I worked around this by making a direct call to EngineAdapter

self,
table_name=table_name,
source_queries=source_queries,
columns_to_types=columns_to_types,
where=where,
insert_overwrite_strategy_override=InsertOverwriteStrategy.DELETE_INSERT,
**kwargs,
)

# For actual conditional overwrites, use MERGE from InsertOverwriteWithMergeMixin
return super()._insert_overwrite_by_condition(
table_name=table_name,
source_queries=source_queries,
columns_to_types=columns_to_types,
where=where,
insert_overwrite_strategy_override=insert_overwrite_strategy_override,
**kwargs,
)

def delete_from(self, table_name: TableName, where: t.Union[str, exp.Expression]) -> None:
if where == exp.true():
# "A TRUNCATE TABLE operation can be rolled back within a transaction."
# ref: https://learn.microsoft.com/en-us/sql/t-sql/statements/truncate-table-transact-sql?view=sql-server-ver15#remarks
return self.execute(
exp.TruncateTable(expressions=[exp.to_table(table_name, dialect=self.dialect)])
)

return super().delete_from(table_name, where)
78 changes: 74 additions & 4 deletions tests/core/engine_adapter/test_mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,12 @@
pytestmark = [pytest.mark.engine, pytest.mark.mssql]


def test_columns(make_mocked_engine_adapter: t.Callable):
adapter = make_mocked_engine_adapter(MSSQLEngineAdapter)
@pytest.fixture
def adapter(make_mocked_engine_adapter: t.Callable) -> MSSQLEngineAdapter:
return make_mocked_engine_adapter(MSSQLEngineAdapter)


def test_columns(adapter: MSSQLEngineAdapter):
adapter.cursor.fetchall.return_value = [
("decimal_ps", "decimal", None, 5, 4),
("decimal", "decimal", None, 18, 0),
Expand Down Expand Up @@ -504,7 +507,8 @@ def test_replace_query(make_mocked_engine_adapter: t.Callable):

assert to_sql_calls(adapter) == [
"""SELECT 1 FROM [information_schema].[tables] WHERE [table_name] = 'test_table';""",
"MERGE INTO [test_table] AS [__MERGE_TARGET__] USING (SELECT [a] AS [a] FROM [tbl]) AS [__MERGE_SOURCE__] ON (1 = 0) WHEN NOT MATCHED BY SOURCE THEN DELETE WHEN NOT MATCHED THEN INSERT ([a]) VALUES ([a]);",
"TRUNCATE TABLE [test_table];",
"INSERT INTO [test_table] ([a]) SELECT [a] FROM [tbl];",
]


Expand Down Expand Up @@ -551,7 +555,8 @@ def temp_table_exists(table: exp.Table) -> bool:

assert to_sql_calls(adapter) == [
f"""IF NOT EXISTS (SELECT * FROM information_schema.tables WHERE table_name = '{temp_table_name}') EXEC('CREATE TABLE [{temp_table_name}] ([a] INTEGER, [b] INTEGER)');""",
"MERGE INTO [test_table] AS [__MERGE_TARGET__] USING (SELECT CAST([a] AS INTEGER) AS [a], CAST([b] AS INTEGER) AS [b] FROM [__temp_test_table_abcdefgh]) AS [__MERGE_SOURCE__] ON (1 = 0) WHEN NOT MATCHED BY SOURCE THEN DELETE WHEN NOT MATCHED THEN INSERT ([a], [b]) VALUES ([a], [b]);",
"TRUNCATE TABLE [test_table];",
f"INSERT INTO [test_table] ([a], [b]) SELECT CAST([a] AS INTEGER) AS [a], CAST([b] AS INTEGER) AS [b] FROM [{temp_table_name}];",
f"DROP TABLE IF EXISTS [{temp_table_name}];",
]

Expand Down Expand Up @@ -751,3 +756,68 @@ def test_create_table_from_query(make_mocked_engine_adapter: t.Callable, mocker:
"CREATE VIEW [__temp_ctas_test_random_id] AS SELECT * FROM (SELECT TOP 1 * FROM [t]);"
in to_sql_calls(adapter)
)


def test_replace_query_strategy(adapter: MSSQLEngineAdapter, mocker: MockerFixture):
# ref issue 4472: https://github.com/TobikoData/sqlmesh/issues/4472
# The FULL strategy calls EngineAdapter.replace_query() which calls _insert_overwrite_by_condition() should use DELETE+INSERT and not MERGE
expressions = d.parse(
f"""
MODEL (
name db.table,
kind FULL,
dialect tsql
);

select a, b from db.upstream_table;
"""
)
model = load_sql_based_model(expressions)

exists_mock = mocker.patch(
"sqlmesh.core.engine_adapter.mssql.MSSQLEngineAdapter.table_exists",
return_value=False,
)

assert not adapter.table_exists("test_table")

# initial - table doesnt exist
adapter.replace_query(
"test_table",
model.render_query_or_raise(),
table_format=model.table_format,
storage_format=model.storage_format,
partitioned_by=model.partitioned_by,
partition_interval_unit=model.partition_interval_unit,
clustered_by=model.clustered_by,
table_properties=model.physical_properties,
table_description=model.description,
column_descriptions=model.column_descriptions,
columns_to_types=model.columns_to_types_or_raise,
)

# subsequent - table exists
exists_mock.return_value = True
assert adapter.table_exists("test_table")

adapter.replace_query(
"test_table",
model.render_query_or_raise(),
table_format=model.table_format,
storage_format=model.storage_format,
partitioned_by=model.partitioned_by,
partition_interval_unit=model.partition_interval_unit,
clustered_by=model.clustered_by,
table_properties=model.physical_properties,
table_description=model.description,
column_descriptions=model.column_descriptions,
columns_to_types=model.columns_to_types_or_raise,
)

assert to_sql_calls(adapter) == [
# initial - create table if not exists
"IF NOT EXISTS (SELECT * FROM information_schema.tables WHERE table_name = 'test_table') EXEC('SELECT * INTO [test_table] FROM (SELECT [a] AS [a], [b] AS [b] FROM [db].[upstream_table] AS [upstream_table]) AS temp');",
# subsequent - truncate + insert
"TRUNCATE TABLE [test_table];",
"INSERT INTO [test_table] ([a], [b]) SELECT [a] AS [a], [b] AS [b] FROM [db].[upstream_table] AS [upstream_table];",
]