diff --git a/sqlmesh/core/engine_adapter/base.py b/sqlmesh/core/engine_adapter/base.py index c48ce2154d..94900f0193 100644 --- a/sqlmesh/core/engine_adapter/base.py +++ b/sqlmesh/core/engine_adapter/base.py @@ -1633,6 +1633,30 @@ def _insert_overwrite_by_condition( target_columns_to_types=target_columns_to_types, order_projections=False, ) + elif insert_overwrite_strategy.is_merge: + columns = [exp.column(col) for col in target_columns_to_types] + when_not_matched_by_source = exp.When( + matched=False, + source=True, + condition=where, + then=exp.Delete(), + ) + when_not_matched_by_target = exp.When( + matched=False, + source=False, + then=exp.Insert( + this=exp.Tuple(expressions=columns), + expression=exp.Tuple(expressions=columns), + ), + ) + self._merge( + target_table=table_name, + query=query, + on=exp.false(), + whens=exp.Whens( + expressions=[when_not_matched_by_source, when_not_matched_by_target] + ), + ) else: insert_exp = exp.insert( query, diff --git a/sqlmesh/core/engine_adapter/bigquery.py b/sqlmesh/core/engine_adapter/bigquery.py index b3d02d8bbf..00b33f67a5 100644 --- a/sqlmesh/core/engine_adapter/bigquery.py +++ b/sqlmesh/core/engine_adapter/bigquery.py @@ -9,7 +9,6 @@ from sqlmesh.core.dialect import to_schema from sqlmesh.core.engine_adapter.mixins import ( - InsertOverwriteWithMergeMixin, ClusteredByMixin, RowDiffMixin, TableAlterClusterByOperation, @@ -20,6 +19,7 @@ DataObjectType, SourceQuery, set_catalog, + InsertOverwriteStrategy, ) from sqlmesh.core.node import IntervalUnit from sqlmesh.core.schema_diff import TableAlterOperation, NestedSupport @@ -54,7 +54,7 @@ @set_catalog() -class BigQueryEngineAdapter(InsertOverwriteWithMergeMixin, ClusteredByMixin, RowDiffMixin): +class BigQueryEngineAdapter(ClusteredByMixin, RowDiffMixin): """ BigQuery Engine Adapter using the `google-cloud-bigquery` library's DB API. """ @@ -68,6 +68,7 @@ class BigQueryEngineAdapter(InsertOverwriteWithMergeMixin, ClusteredByMixin, Row MAX_COLUMN_COMMENT_LENGTH = 1024 SUPPORTS_QUERY_EXECUTION_TRACKING = True SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["SCHEMA"] + INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.MERGE SCHEMA_DIFFER_KWARGS = { "compatible_types": { diff --git a/sqlmesh/core/engine_adapter/fabric.py b/sqlmesh/core/engine_adapter/fabric.py index a528be3cb4..8e2fb0e496 100644 --- a/sqlmesh/core/engine_adapter/fabric.py +++ b/sqlmesh/core/engine_adapter/fabric.py @@ -7,22 +7,15 @@ from functools import cached_property from sqlglot import exp from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_result +from sqlmesh.core.engine_adapter.mixins import LogicalMergeMixin from sqlmesh.core.engine_adapter.mssql import MSSQLEngineAdapter from sqlmesh.core.engine_adapter.shared import ( InsertOverwriteStrategy, - SourceQuery, ) -from sqlmesh.core.engine_adapter.base import EngineAdapter from sqlmesh.utils.errors import SQLMeshError from sqlmesh.utils.connection_pool import ConnectionPool -if t.TYPE_CHECKING: - from sqlmesh.core._typing import TableName - - -from sqlmesh.core.engine_adapter.mixins import LogicalMergeMixin - logger = logging.getLogger(__name__) @@ -58,26 +51,6 @@ def _target_catalog(self) -> t.Optional[str]: def _target_catalog(self, value: t.Optional[str]) -> None: self._connection_pool.set_attribute("target_catalog", value) - def _insert_overwrite_by_condition( - self, - table_name: TableName, - source_queries: t.List[SourceQuery], - target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, - where: t.Optional[exp.Condition] = None, - insert_overwrite_strategy_override: t.Optional[InsertOverwriteStrategy] = None, - **kwargs: t.Any, - ) -> None: - # Override to avoid MERGE statement which isn't fully supported in Fabric - return EngineAdapter._insert_overwrite_by_condition( - self, - table_name=table_name, - source_queries=source_queries, - target_columns_to_types=target_columns_to_types, - where=where, - insert_overwrite_strategy_override=InsertOverwriteStrategy.DELETE_INSERT, - **kwargs, - ) - @property def api_client(self) -> FabricHttpClient: # the requests Session is not guaranteed to be threadsafe diff --git a/sqlmesh/core/engine_adapter/mixins.py b/sqlmesh/core/engine_adapter/mixins.py index 865e47fb93..1d66da0607 100644 --- a/sqlmesh/core/engine_adapter/mixins.py +++ b/sqlmesh/core/engine_adapter/mixins.py @@ -9,7 +9,6 @@ from sqlglot.helper import seq_get from sqlmesh.core.engine_adapter.base import EngineAdapter -from sqlmesh.core.engine_adapter.shared import InsertOverwriteStrategy, SourceQuery from sqlmesh.core.node import IntervalUnit from sqlmesh.core.dialect import schema_ from sqlmesh.core.schema_diff import TableAlterOperation @@ -75,52 +74,6 @@ def _fetch_native_df( return df -class InsertOverwriteWithMergeMixin(EngineAdapter): - def _insert_overwrite_by_condition( - self, - table_name: TableName, - source_queries: t.List[SourceQuery], - target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, - where: t.Optional[exp.Condition] = None, - insert_overwrite_strategy_override: t.Optional[InsertOverwriteStrategy] = None, - **kwargs: t.Any, - ) -> None: - """ - Some engines do not support `INSERT OVERWRITE` but instead support - doing an "INSERT OVERWRITE" using a Merge expression but with the - predicate being `False`. - """ - target_columns_to_types = target_columns_to_types or self.columns(table_name) - for source_query in source_queries: - with source_query as query: - query = self._order_projections_and_filter( - query, target_columns_to_types, where=where - ) - columns = [exp.column(col) for col in target_columns_to_types] - when_not_matched_by_source = exp.When( - matched=False, - source=True, - condition=where, - then=exp.Delete(), - ) - when_not_matched_by_target = exp.When( - matched=False, - source=False, - then=exp.Insert( - this=exp.Tuple(expressions=columns), - expression=exp.Tuple(expressions=columns), - ), - ) - self._merge( - target_table=table_name, - query=query, - on=exp.false(), - whens=exp.Whens( - expressions=[when_not_matched_by_source, when_not_matched_by_target] - ), - ) - - class HiveMetastoreTablePropertiesMixin(EngineAdapter): MAX_TABLE_COMMENT_LENGTH = 4000 MAX_COLUMN_COMMENT_LENGTH = 4000 diff --git a/sqlmesh/core/engine_adapter/mssql.py b/sqlmesh/core/engine_adapter/mssql.py index 50a67b4b37..fd0bf1011b 100644 --- a/sqlmesh/core/engine_adapter/mssql.py +++ b/sqlmesh/core/engine_adapter/mssql.py @@ -16,7 +16,6 @@ ) from sqlmesh.core.engine_adapter.mixins import ( GetCurrentCatalogFromFunctionMixin, - InsertOverwriteWithMergeMixin, PandasNativeFetchDFSupportMixin, VarcharSizeWorkaroundMixin, RowDiffMixin, @@ -41,7 +40,6 @@ class MSSQLEngineAdapter( EngineAdapterWithIndexSupport, PandasNativeFetchDFSupportMixin, - InsertOverwriteWithMergeMixin, GetCurrentCatalogFromFunctionMixin, VarcharSizeWorkaroundMixin, RowDiffMixin, @@ -74,6 +72,7 @@ class MSSQLEngineAdapter( }, } VARIABLE_LENGTH_DATA_TYPES = {"binary", "varbinary", "char", "varchar", "nchar", "nvarchar"} + INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.MERGE @property def catalog_support(self) -> CatalogSupport: diff --git a/sqlmesh/core/engine_adapter/shared.py b/sqlmesh/core/engine_adapter/shared.py index 55f04a995e..ba0e1fa619 100644 --- a/sqlmesh/core/engine_adapter/shared.py +++ b/sqlmesh/core/engine_adapter/shared.py @@ -243,6 +243,8 @@ class InsertOverwriteStrategy(Enum): # Issue a single INSERT query to replace a data range. The assumption is that the query engine will transparently match partition bounds # and replace data rather than append to it. Trino is an example of this when `hive.insert-existing-partitions-behavior=OVERWRITE` is configured INTO_IS_OVERWRITE = 4 + # Do the INSERT OVERWRITE using merge since the engine doesn't support it natively + MERGE = 5 @property def is_delete_insert(self) -> bool: @@ -260,6 +262,10 @@ def is_replace_where(self) -> bool: def is_into_is_overwrite(self) -> bool: return self == InsertOverwriteStrategy.INTO_IS_OVERWRITE + @property + def is_merge(self) -> bool: + return self == InsertOverwriteStrategy.MERGE + class SourceQuery: def __init__( diff --git a/tests/core/engine_adapter/test_base.py b/tests/core/engine_adapter/test_base.py index b2dfcc7ccc..220c3291f7 100644 --- a/tests/core/engine_adapter/test_base.py +++ b/tests/core/engine_adapter/test_base.py @@ -13,7 +13,6 @@ from sqlmesh.core import dialect as d from sqlmesh.core.dialect import normalize_model_name from sqlmesh.core.engine_adapter import EngineAdapter, EngineAdapterWithIndexSupport -from sqlmesh.core.engine_adapter.mixins import InsertOverwriteWithMergeMixin from sqlmesh.core.engine_adapter.shared import InsertOverwriteStrategy, DataObject from sqlmesh.core.schema_diff import SchemaDiffer, TableAlterOperation, NestedSupport from sqlmesh.utils import columns_to_types_to_struct @@ -21,8 +20,6 @@ from sqlmesh.utils.errors import SQLMeshError, UnsupportedCatalogOperationError from tests.core.engine_adapter import to_sql_calls -if t.TYPE_CHECKING: - pass pytestmark = pytest.mark.engine @@ -482,7 +479,8 @@ def test_insert_overwrite_no_where(make_mocked_engine_adapter: t.Callable): def test_insert_overwrite_by_condition_column_contains_unsafe_characters( make_mocked_engine_adapter: t.Callable, mocker: MockerFixture ): - adapter = make_mocked_engine_adapter(InsertOverwriteWithMergeMixin) + adapter = make_mocked_engine_adapter(EngineAdapter) + adapter.INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.MERGE source_queries, columns_to_types = adapter._get_source_queries_and_columns_to_types( parse_one("SELECT 1 AS c"), None, target_table="test_table" diff --git a/tests/core/engine_adapter/test_mssql.py b/tests/core/engine_adapter/test_mssql.py index 5923afa217..a405bb7576 100644 --- a/tests/core/engine_adapter/test_mssql.py +++ b/tests/core/engine_adapter/test_mssql.py @@ -16,7 +16,6 @@ from sqlmesh.core.engine_adapter.shared import ( DataObject, DataObjectType, - InsertOverwriteStrategy, ) from sqlmesh.utils.date import to_ds from tests.core.engine_adapter import to_sql_calls @@ -342,46 +341,6 @@ def test_insert_overwrite_by_time_partition_supports_insert_overwrite_pandas_exi ] -def test_insert_overwrite_by_time_partition_replace_where_pandas( - make_mocked_engine_adapter: t.Callable, mocker: MockerFixture, make_temp_table_name: t.Callable -): - mocker.patch( - "sqlmesh.core.engine_adapter.mssql.MSSQLEngineAdapter.table_exists", - return_value=False, - ) - - adapter = make_mocked_engine_adapter(MSSQLEngineAdapter) - adapter.INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.REPLACE_WHERE - - temp_table_mock = mocker.patch("sqlmesh.core.engine_adapter.EngineAdapter._get_temp_table") - table_name = "test_table" - temp_table_id = "abcdefgh" - temp_table_mock.return_value = make_temp_table_name(table_name, temp_table_id) - - df = pd.DataFrame({"a": [1, 2], "ds": ["2022-01-01", "2022-01-02"]}) - adapter.insert_overwrite_by_time_partition( - table_name, - df, - start="2022-01-01", - end="2022-01-02", - time_formatter=lambda x, _: exp.Literal.string(to_ds(x)), - time_column="ds", - target_columns_to_types={ - "a": exp.DataType.build("INT"), - "ds": exp.DataType.build("STRING"), - }, - ) - adapter._connection_pool.get().bulk_copy.assert_called_with( - f"__temp_test_table_{temp_table_id}", [(1, "2022-01-01"), (2, "2022-01-02")] - ) - - assert to_sql_calls(adapter) == [ - f"""IF NOT EXISTS (SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = '__temp_test_table_{temp_table_id}') EXEC('CREATE TABLE [__temp_test_table_{temp_table_id}] ([a] INTEGER, [ds] VARCHAR(MAX))');""", - f"""MERGE INTO [test_table] AS [__MERGE_TARGET__] USING (SELECT [a] AS [a], [ds] AS [ds] FROM (SELECT CAST([a] AS INTEGER) AS [a], CAST([ds] AS VARCHAR(MAX)) AS [ds] FROM [__temp_test_table_{temp_table_id}]) AS [_subquery] WHERE [ds] BETWEEN '2022-01-01' AND '2022-01-02') AS [__MERGE_SOURCE__] ON (1 = 0) WHEN NOT MATCHED BY SOURCE AND [ds] BETWEEN '2022-01-01' AND '2022-01-02' THEN DELETE WHEN NOT MATCHED THEN INSERT ([a], [ds]) VALUES ([a], [ds]);""", - f"DROP TABLE IF EXISTS [__temp_test_table_{temp_table_id}];", - ] - - def test_insert_append_pandas( make_mocked_engine_adapter: t.Callable, mocker: MockerFixture, make_temp_table_name: t.Callable ):