diff --git a/sqlmesh/core/engine_adapter/mssql.py b/sqlmesh/core/engine_adapter/mssql.py index 9a198d5324..240352b1f0 100644 --- a/sqlmesh/core/engine_adapter/mssql.py +++ b/sqlmesh/core/engine_adapter/mssql.py @@ -10,7 +10,11 @@ from sqlglot import exp from sqlmesh.core.dialect import to_schema -from sqlmesh.core.engine_adapter.base import EngineAdapterWithIndexSupport +from sqlmesh.core.engine_adapter.base import ( + EngineAdapterWithIndexSupport, + EngineAdapter, + InsertOverwriteStrategy, +) from sqlmesh.core.engine_adapter.mixins import ( GetCurrentCatalogFromFunctionMixin, InsertOverwriteWithMergeMixin, @@ -281,3 +285,45 @@ def _rename_table( # The function that renames tables in MSSQL takes string literals as arguments instead of identifiers, # so we shouldn't quote the identifiers. self.execute(exp.rename_table(old_table_name, new_table_name), quote_identifiers=False) + + def _insert_overwrite_by_condition( + self, + table_name: TableName, + source_queries: t.List[SourceQuery], + columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + where: t.Optional[exp.Condition] = None, + insert_overwrite_strategy_override: t.Optional[InsertOverwriteStrategy] = None, + **kwargs: t.Any, + ) -> None: + if not where or where == exp.true(): + # this is a full table replacement, call the base strategy to do DELETE+INSERT + # which will result in TRUNCATE+INSERT due to how we have overridden self.delete_from() + return EngineAdapter._insert_overwrite_by_condition( + self, + table_name=table_name, + source_queries=source_queries, + columns_to_types=columns_to_types, + where=where, + insert_overwrite_strategy_override=InsertOverwriteStrategy.DELETE_INSERT, + **kwargs, + ) + + # For actual conditional overwrites, use MERGE from InsertOverwriteWithMergeMixin + return super()._insert_overwrite_by_condition( + table_name=table_name, + source_queries=source_queries, + columns_to_types=columns_to_types, + where=where, + insert_overwrite_strategy_override=insert_overwrite_strategy_override, + **kwargs, + ) + + def delete_from(self, table_name: TableName, where: t.Union[str, exp.Expression]) -> None: + if where == exp.true(): + # "A TRUNCATE TABLE operation can be rolled back within a transaction." + # ref: https://learn.microsoft.com/en-us/sql/t-sql/statements/truncate-table-transact-sql?view=sql-server-ver15#remarks + return self.execute( + exp.TruncateTable(expressions=[exp.to_table(table_name, dialect=self.dialect)]) + ) + + return super().delete_from(table_name, where) diff --git a/tests/core/engine_adapter/test_mssql.py b/tests/core/engine_adapter/test_mssql.py index cd961c96ac..a5e8aa8ecf 100644 --- a/tests/core/engine_adapter/test_mssql.py +++ b/tests/core/engine_adapter/test_mssql.py @@ -24,9 +24,12 @@ pytestmark = [pytest.mark.engine, pytest.mark.mssql] -def test_columns(make_mocked_engine_adapter: t.Callable): - adapter = make_mocked_engine_adapter(MSSQLEngineAdapter) +@pytest.fixture +def adapter(make_mocked_engine_adapter: t.Callable) -> MSSQLEngineAdapter: + return make_mocked_engine_adapter(MSSQLEngineAdapter) + +def test_columns(adapter: MSSQLEngineAdapter): adapter.cursor.fetchall.return_value = [ ("decimal_ps", "decimal", None, 5, 4), ("decimal", "decimal", None, 18, 0), @@ -504,7 +507,8 @@ def test_replace_query(make_mocked_engine_adapter: t.Callable): assert to_sql_calls(adapter) == [ """SELECT 1 FROM [information_schema].[tables] WHERE [table_name] = 'test_table';""", - "MERGE INTO [test_table] AS [__MERGE_TARGET__] USING (SELECT [a] AS [a] FROM [tbl]) AS [__MERGE_SOURCE__] ON (1 = 0) WHEN NOT MATCHED BY SOURCE THEN DELETE WHEN NOT MATCHED THEN INSERT ([a]) VALUES ([a]);", + "TRUNCATE TABLE [test_table];", + "INSERT INTO [test_table] ([a]) SELECT [a] FROM [tbl];", ] @@ -551,7 +555,8 @@ def temp_table_exists(table: exp.Table) -> bool: assert to_sql_calls(adapter) == [ f"""IF NOT EXISTS (SELECT * FROM information_schema.tables WHERE table_name = '{temp_table_name}') EXEC('CREATE TABLE [{temp_table_name}] ([a] INTEGER, [b] INTEGER)');""", - "MERGE INTO [test_table] AS [__MERGE_TARGET__] USING (SELECT CAST([a] AS INTEGER) AS [a], CAST([b] AS INTEGER) AS [b] FROM [__temp_test_table_abcdefgh]) AS [__MERGE_SOURCE__] ON (1 = 0) WHEN NOT MATCHED BY SOURCE THEN DELETE WHEN NOT MATCHED THEN INSERT ([a], [b]) VALUES ([a], [b]);", + "TRUNCATE TABLE [test_table];", + f"INSERT INTO [test_table] ([a], [b]) SELECT CAST([a] AS INTEGER) AS [a], CAST([b] AS INTEGER) AS [b] FROM [{temp_table_name}];", f"DROP TABLE IF EXISTS [{temp_table_name}];", ] @@ -751,3 +756,68 @@ def test_create_table_from_query(make_mocked_engine_adapter: t.Callable, mocker: "CREATE VIEW [__temp_ctas_test_random_id] AS SELECT * FROM (SELECT TOP 1 * FROM [t]);" in to_sql_calls(adapter) ) + + +def test_replace_query_strategy(adapter: MSSQLEngineAdapter, mocker: MockerFixture): + # ref issue 4472: https://github.com/TobikoData/sqlmesh/issues/4472 + # The FULL strategy calls EngineAdapter.replace_query() which calls _insert_overwrite_by_condition() should use DELETE+INSERT and not MERGE + expressions = d.parse( + f""" + MODEL ( + name db.table, + kind FULL, + dialect tsql + ); + + select a, b from db.upstream_table; + """ + ) + model = load_sql_based_model(expressions) + + exists_mock = mocker.patch( + "sqlmesh.core.engine_adapter.mssql.MSSQLEngineAdapter.table_exists", + return_value=False, + ) + + assert not adapter.table_exists("test_table") + + # initial - table doesnt exist + adapter.replace_query( + "test_table", + model.render_query_or_raise(), + table_format=model.table_format, + storage_format=model.storage_format, + partitioned_by=model.partitioned_by, + partition_interval_unit=model.partition_interval_unit, + clustered_by=model.clustered_by, + table_properties=model.physical_properties, + table_description=model.description, + column_descriptions=model.column_descriptions, + columns_to_types=model.columns_to_types_or_raise, + ) + + # subsequent - table exists + exists_mock.return_value = True + assert adapter.table_exists("test_table") + + adapter.replace_query( + "test_table", + model.render_query_or_raise(), + table_format=model.table_format, + storage_format=model.storage_format, + partitioned_by=model.partitioned_by, + partition_interval_unit=model.partition_interval_unit, + clustered_by=model.clustered_by, + table_properties=model.physical_properties, + table_description=model.description, + column_descriptions=model.column_descriptions, + columns_to_types=model.columns_to_types_or_raise, + ) + + assert to_sql_calls(adapter) == [ + # initial - create table if not exists + "IF NOT EXISTS (SELECT * FROM information_schema.tables WHERE table_name = 'test_table') EXEC('SELECT * INTO [test_table] FROM (SELECT [a] AS [a], [b] AS [b] FROM [db].[upstream_table] AS [upstream_table]) AS temp');", + # subsequent - truncate + insert + "TRUNCATE TABLE [test_table];", + "INSERT INTO [test_table] ([a], [b]) SELECT [a] AS [a], [b] AS [b] FROM [db].[upstream_table] AS [upstream_table];", + ]