diff --git a/.circleci/continue_config.yml b/.circleci/continue_config.yml index afaf0e080b..2c7687c57c 100644 --- a/.circleci/continue_config.yml +++ b/.circleci/continue_config.yml @@ -303,7 +303,8 @@ workflows: - bigquery - clickhouse-cloud - athena - - fabric + # todo: enable fabric when cicd catalog create/drop implemented in manage-test-db.sh + #- fabric filters: branches: only: diff --git a/sqlmesh/core/config/connection.py b/sqlmesh/core/config/connection.py index a14de94cba..f912e76dd3 100644 --- a/sqlmesh/core/config/connection.py +++ b/sqlmesh/core/config/connection.py @@ -1735,7 +1735,9 @@ def create_fabric_connection( def _extra_engine_config(self) -> t.Dict[str, t.Any]: return { "database": self.database, - "catalog_support": CatalogSupport.FULL_SUPPORT, + # more operations than not require a specific catalog to be already active + # in particular, create/drop view, create/drop schema and querying information_schema + "catalog_support": CatalogSupport.REQUIRES_SET_CATALOG, "workspace_id": self.workspace_id, "tenant_id": self.tenant_id, "user": self.user, diff --git a/sqlmesh/core/engine_adapter/fabric.py b/sqlmesh/core/engine_adapter/fabric.py index 684bae1e08..b2764e79b1 100644 --- a/sqlmesh/core/engine_adapter/fabric.py +++ b/sqlmesh/core/engine_adapter/fabric.py @@ -3,16 +3,20 @@ import typing as t import logging import requests +from functools import cached_property from sqlglot import exp from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_result from sqlmesh.core.engine_adapter.mssql import MSSQLEngineAdapter -from sqlmesh.core.engine_adapter.shared import InsertOverwriteStrategy, SourceQuery +from sqlmesh.core.engine_adapter.shared import ( + InsertOverwriteStrategy, + SourceQuery, +) from sqlmesh.core.engine_adapter.base import EngineAdapter from sqlmesh.utils.errors import SQLMeshError from sqlmesh.utils.connection_pool import ConnectionPool if t.TYPE_CHECKING: - from sqlmesh.core._typing import TableName, SchemaName + from sqlmesh.core._typing import TableName from sqlmesh.core.engine_adapter.mixins import LogicalMergeMixin @@ -34,92 +38,24 @@ class FabricEngineAdapter(LogicalMergeMixin, MSSQLEngineAdapter): def __init__( self, connection_factory_or_pool: t.Union[t.Callable, t.Any], *args: t.Any, **kwargs: t.Any ) -> None: - # Wrap connection factory to support catalog switching + # Wrap connection factory to support changing the catalog dynamically at runtime if not isinstance(connection_factory_or_pool, ConnectionPool): original_connection_factory = connection_factory_or_pool - def catalog_aware_factory(*args: t.Any, **kwargs: t.Any) -> t.Any: - # Try to pass target_catalog if the factory accepts it - try: - return original_connection_factory( - target_catalog=self._target_catalog, *args, **kwargs - ) - except TypeError: - # Factory doesn't accept target_catalog, call without it - return original_connection_factory(*args, **kwargs) - - connection_factory_or_pool = catalog_aware_factory + connection_factory_or_pool = lambda *args, **kwargs: original_connection_factory( + target_catalog=self._target_catalog, *args, **kwargs + ) super().__init__(connection_factory_or_pool, *args, **kwargs) @property def _target_catalog(self) -> t.Optional[str]: - """Thread-local target catalog storage.""" return self._connection_pool.get_attribute("target_catalog") @_target_catalog.setter def _target_catalog(self, value: t.Optional[str]) -> None: - """Thread-local target catalog storage.""" self._connection_pool.set_attribute("target_catalog", value) - def _switch_to_catalog_if_needed( - self, table_or_name: t.Union[exp.Table, TableName, SchemaName] - ) -> exp.Table: - # Switch catalog context if needed for cross-catalog operations - table = exp.to_table(table_or_name) - - if table.catalog: - catalog_name = table.catalog - logger.debug(f"Switching to catalog '{catalog_name}' for operation") - self.set_current_catalog(catalog_name) - - # Return table without catalog for SQL generation - return exp.Table(this=table.name, db=table.db) - - return table - - def _handle_schema_with_catalog(self, schema_name: SchemaName) -> t.Tuple[t.Optional[str], str]: - # Parse and handle catalog-qualified schema names for cross-catalog operations - # Handle Table objects created by schema_() function - if isinstance(schema_name, exp.Table) and not schema_name.name: - # This is a schema Table object - check for catalog qualification - if schema_name.catalog: - # Catalog-qualified schema: catalog.schema - catalog_name = schema_name.catalog - schema_only = schema_name.db - logger.debug( - f"Detected catalog-qualified schema: catalog='{catalog_name}', schema='{schema_only}'" - ) - # Switch to the catalog first - self.set_current_catalog(catalog_name) - return catalog_name, schema_only - # Schema only, no catalog - schema_only = schema_name.db - logger.debug(f"Detected schema-only: schema='{schema_only}'") - return None, schema_only - # Handle string or table name inputs by parsing as table - table = exp.to_table(schema_name) - - if table.catalog: - # 3-part name detected (catalog.db.table) - this shouldn't happen for schema operations - raise SQLMeshError( - f"Invalid schema name format: {schema_name}. Expected 'schema' or 'catalog.schema', got 3-part name" - ) - elif table.db: - # Catalog-qualified schema: catalog.schema - catalog_name = table.db - schema_only = table.name - logger.debug( - f"Detected catalog.schema format: catalog='{catalog_name}', schema='{schema_only}'" - ) - # Switch to the catalog first - self.set_current_catalog(catalog_name) - return catalog_name, schema_only - else: - # No catalog qualification, use as-is - logger.debug(f"No catalog detected, using original: {schema_name}") - return None, str(schema_name) - def _insert_overwrite_by_condition( self, table_name: TableName, @@ -140,219 +76,52 @@ def _insert_overwrite_by_condition( **kwargs, ) - def _get_access_token(self) -> str: - """Get access token using Service Principal authentication.""" - tenant_id = self._extra_config.get("tenant_id") - client_id = self._extra_config.get("user") - client_secret = self._extra_config.get("password") - - if not all([tenant_id, client_id, client_secret]): + @property + def api_client(self) -> FabricHttpClient: + # the requests Session is not guaranteed to be threadsafe + # so we create a http client per thread on demand + if existing_client := self._connection_pool.get_attribute("api_client"): + return existing_client + + tenant_id: t.Optional[str] = self._extra_config.get("tenant_id") + workspace_id: t.Optional[str] = self._extra_config.get("workspace_id") + client_id: t.Optional[str] = self._extra_config.get("user") + client_secret: t.Optional[str] = self._extra_config.get("password") + + if not tenant_id or not client_id or not client_secret: raise SQLMeshError( "Service Principal authentication requires tenant_id, client_id, and client_secret " "in the Fabric connection configuration" ) - # Use Azure AD OAuth2 token endpoint - token_url = f"https://login.microsoftonline.com/{tenant_id}/oauth2/v2.0/token" - - data = { - "grant_type": "client_credentials", - "client_id": client_id, - "client_secret": client_secret, - "scope": "https://api.fabric.microsoft.com/.default", - } - - try: - response = requests.post(token_url, data=data) - response.raise_for_status() - token_data = response.json() - return token_data["access_token"] - except requests.exceptions.RequestException as e: - raise SQLMeshError(f"Failed to authenticate with Azure AD: {e}") - except KeyError: - raise SQLMeshError("Invalid response from Azure AD token endpoint") - - def _get_fabric_auth_headers(self) -> t.Dict[str, str]: - """Get authentication headers for Fabric REST API calls.""" - access_token = self._get_access_token() - return {"Authorization": f"Bearer {access_token}", "Content-Type": "application/json"} - - def _make_fabric_api_request( - self, - method: str, - endpoint: str, - data: t.Optional[t.Dict[str, t.Any]] = None, - include_response_headers: bool = False, - ) -> t.Dict[str, t.Any]: - """Make a request to the Fabric REST API.""" - - workspace_id = self._extra_config.get("workspace_id") if not workspace_id: raise SQLMeshError( - "workspace_id parameter is required in connection config for Fabric catalog operations" + "Fabric requires the workspace_id to be configured in the connection configuration to create / drop catalogs" ) - base_url = "https://api.fabric.microsoft.com/v1" - url = f"{base_url}/workspaces/{workspace_id}/{endpoint}" - - headers = self._get_fabric_auth_headers() - - try: - if method.upper() == "GET": - response = requests.get(url, headers=headers) - elif method.upper() == "POST": - response = requests.post(url, headers=headers, json=data) - elif method.upper() == "DELETE": - response = requests.delete(url, headers=headers) - else: - raise SQLMeshError(f"Unsupported HTTP method: {method}") - - response.raise_for_status() - - if include_response_headers: - result: t.Dict[str, t.Any] = {"status_code": response.status_code} - - # Extract location header for polling - if "location" in response.headers: - result["location"] = response.headers["location"] - - # Include response body if present - if response.content: - json_data = response.json() - if json_data: - result.update(json_data) - - return result - if response.status_code == 204: # No content - return {} - - return response.json() if response.content else {} - - except requests.exceptions.HTTPError as e: - error_details = "" - try: - if response.content: - error_response = response.json() - error_details = error_response.get("error", {}).get( - "message", str(error_response) - ) - except (ValueError, AttributeError): - error_details = response.text if hasattr(response, "text") else str(e) - - raise SQLMeshError(f"Fabric API HTTP error ({response.status_code}): {error_details}") - except requests.exceptions.RequestException as e: - raise SQLMeshError(f"Fabric API request failed: {e}") - - def _make_fabric_api_request_with_location( - self, method: str, endpoint: str, data: t.Optional[t.Dict[str, t.Any]] = None - ) -> t.Dict[str, t.Any]: - """Make a request to the Fabric REST API and return response with status code and location.""" - return self._make_fabric_api_request(method, endpoint, data, include_response_headers=True) - - @retry( - wait=wait_exponential(multiplier=1, min=1, max=30), - stop=stop_after_attempt(60), - retry=retry_if_result(lambda result: result not in ["Succeeded", "Failed"]), - ) - def _check_operation_status(self, location_url: str, operation_name: str) -> str: - """Check the operation status and return the status string.""" - - headers = self._get_fabric_auth_headers() - - try: - response = requests.get(location_url, headers=headers) - response.raise_for_status() - - result = response.json() - status = result.get("status", "Unknown") - - logger.info(f"Operation {operation_name} status: {status}") - - if status == "Failed": - error_msg = result.get("error", {}).get("message", "Unknown error") - raise SQLMeshError(f"Operation {operation_name} failed: {error_msg}") - elif status in ["InProgress", "Running"]: - logger.info(f"Operation {operation_name} still in progress...") - elif status not in ["Succeeded"]: - logger.warning(f"Unknown status '{status}' for operation {operation_name}") - - return status - - except requests.exceptions.RequestException as e: - logger.warning(f"Failed to poll status: {e}") - raise SQLMeshError(f"Failed to poll operation status: {e}") + client = FabricHttpClient( + tenant_id=tenant_id, + workspace_id=workspace_id, + client_id=client_id, + client_secret=client_secret, + ) - def _poll_operation_status(self, location_url: str, operation_name: str) -> None: - """Poll the operation status until completion.""" - try: - final_status = self._check_operation_status(location_url, operation_name) - if final_status != "Succeeded": - raise SQLMeshError( - f"Operation {operation_name} completed with status: {final_status}" - ) - except Exception as e: - if "retry" in str(e).lower(): - raise SQLMeshError(f"Operation {operation_name} did not complete within timeout") - raise + self._connection_pool.set_attribute("api_client", client) + return client def _create_catalog(self, catalog_name: exp.Identifier) -> None: """Create a catalog (warehouse) in Microsoft Fabric via REST API.""" warehouse_name = catalog_name.sql(dialect=self.dialect, identify=False) logger.info(f"Creating Fabric warehouse: {warehouse_name}") - request_data = { - "displayName": warehouse_name, - "description": f"Warehouse created by SQLMesh: {warehouse_name}", - } - - response = self._make_fabric_api_request_with_location("POST", "warehouses", request_data) - - # Handle direct success (201) or async creation (202) - if response.get("status_code") == 201: - logger.info(f"Successfully created Fabric warehouse: {warehouse_name}") - return - - if response.get("status_code") == 202 and response.get("location"): - logger.info(f"Warehouse creation initiated for: {warehouse_name}") - self._poll_operation_status(response["location"], warehouse_name) - logger.info(f"Successfully created Fabric warehouse: {warehouse_name}") - else: - raise SQLMeshError(f"Unexpected response from warehouse creation: {response}") + self.api_client.create_warehouse(warehouse_name) def _drop_catalog(self, catalog_name: exp.Identifier) -> None: """Drop a catalog (warehouse) in Microsoft Fabric via REST API.""" warehouse_name = catalog_name.sql(dialect=self.dialect, identify=False) logger.info(f"Deleting Fabric warehouse: {warehouse_name}") - - try: - # Get the warehouse ID by listing warehouses - warehouses = self._make_fabric_api_request("GET", "warehouses") - - warehouse_id = next( - ( - warehouse.get("id") - for warehouse in warehouses.get("value", []) - if warehouse.get("displayName") == warehouse_name - ), - None, - ) - - if not warehouse_id: - logger.info(f"Fabric warehouse does not exist: {warehouse_name}") - return - - # Delete the warehouse by ID - self._make_fabric_api_request("DELETE", f"warehouses/{warehouse_id}") - logger.info(f"Successfully deleted Fabric warehouse: {warehouse_name}") - - except SQLMeshError as e: - error_msg = str(e).lower() - if "not found" in error_msg or "does not exist" in error_msg: - logger.info(f"Fabric warehouse does not exist: {warehouse_name}") - return - logger.error(f"Failed to delete Fabric warehouse {warehouse_name}: {e}") - raise + self.api_client.delete_warehouse(warehouse_name) def set_current_catalog(self, catalog_name: str) -> None: """ @@ -382,158 +151,141 @@ def set_current_catalog(self, catalog_name: str) -> None: logger.info(f"Switching from catalog '{current_catalog}' to '{catalog_name}'") - # Set the target catalog for our custom connection factory - self._target_catalog = catalog_name + # note: we call close() on the connection pool instead of self.close() because self.close() calls close_all() + # on the connection pool but we just want to close the connection for this thread + self._connection_pool.close() + self._target_catalog = catalog_name # new connections will use this catalog - # Save the target catalog before closing (close() clears thread-local storage) - target_catalog = self._target_catalog + catalog_after_switch = self.get_current_catalog() - # Close all existing connections since Fabric requires reconnection for catalog changes - self.close() + if catalog_after_switch != catalog_name: + # We need to raise an error if the catalog switch failed to prevent the operation that needed the catalog switch from being run against the wrong catalog + raise SQLMeshError( + f"Unable to switch catalog to {catalog_name}, catalog ended up as {catalog_after_switch}" + ) - # Restore the target catalog after closing - self._target_catalog = target_catalog - # Verify the catalog switch worked by getting a new connection - try: - actual_catalog = self.get_current_catalog() - if actual_catalog and actual_catalog == catalog_name: - logger.debug(f"Successfully switched to catalog '{catalog_name}'") - else: - logger.warning( - f"Catalog switch may have failed. Expected '{catalog_name}', got '{actual_catalog}'" - ) - except Exception as e: - logger.debug(f"Could not verify catalog switch: {e}") +class FabricHttpClient: + def __init__(self, tenant_id: str, workspace_id: str, client_id: str, client_secret: str): + self.tenant_id = tenant_id + self.client_id = client_id + self.client_secret = client_secret + self.workspace_id = workspace_id - logger.debug(f"Updated target catalog to '{catalog_name}' and closed connections") + def create_warehouse(self, warehouse_name: str) -> None: + """Create a catalog (warehouse) in Microsoft Fabric via REST API.""" + logger.info(f"Creating Fabric warehouse: {warehouse_name}") - def drop_schema( - self, - schema_name: SchemaName, - ignore_if_not_exists: bool = True, - cascade: bool = False, - **drop_args: t.Any, - ) -> None: - """ - Override drop_schema to handle catalog-qualified schema names. - Fabric doesn't support 'DROP SCHEMA [catalog].[schema]' syntax. - """ - logger.debug(f"drop_schema called with: {schema_name} (type: {type(schema_name)})") + request_data = { + "displayName": warehouse_name, + "description": f"Warehouse created by SQLMesh: {warehouse_name}", + } - # Use helper to handle catalog switching and get schema name - catalog_name, schema_only = self._handle_schema_with_catalog(schema_name) + response = self.session.post(self._endpoint_url("warehouses"), json=request_data) + response.raise_for_status() - # Use just the schema name for the operation - super().drop_schema(schema_only, ignore_if_not_exists, cascade, **drop_args) + # Handle direct success (201) or async creation (202) + if response.status_code == 201: + logger.info(f"Successfully created Fabric warehouse: {warehouse_name}") + return - def create_schema( - self, - schema_name: SchemaName, - ignore_if_exists: bool = True, - **kwargs: t.Any, - ) -> None: - """ - Override create_schema to handle catalog-qualified schema names. - Fabric doesn't support 'CREATE SCHEMA [catalog].[schema]' syntax. - """ - # Use helper to handle catalog switching and get schema name - catalog_name, schema_only = self._handle_schema_with_catalog(schema_name) + if response.status_code == 202 and (location_header := response.headers.get("location")): + logger.info(f"Warehouse creation initiated for: {warehouse_name}") + self._wait_for_completion(location_header, warehouse_name) + logger.info(f"Successfully created Fabric warehouse: {warehouse_name}") + else: + logger.error(f"Unexpected response from Fabric API: {response}\n{response.text}") + raise SQLMeshError(f"Unable to create warehouse: {response}") - # Use just the schema name for the operation - super().create_schema(schema_only, ignore_if_exists, **kwargs) + def delete_warehouse(self, warehouse_name: str) -> None: + """Drop a catalog (warehouse) in Microsoft Fabric via REST API.""" + logger.info(f"Deleting Fabric warehouse: {warehouse_name}") - def _ensure_schema_exists(self, table_name: TableName) -> None: - """ - Ensure that the schema for a table exists before creating the table. - This is necessary for Fabric because schemas must exist before tables can be created in them. - """ - table = exp.to_table(table_name) - if table.db: - schema_name = table.db - catalog_name = table.catalog + # Get the warehouse ID by listing warehouses + response = self.session.get(self._endpoint_url("warehouses")) + response.raise_for_status() - # Build the full schema name - full_schema_name = f"{catalog_name}.{schema_name}" if catalog_name else schema_name + warehouse_name_to_id = { + warehouse.get("displayName"): warehouse.get("id") + for warehouse in response.json().get("value", []) + } - logger.debug(f"Ensuring schema exists: {full_schema_name}") + warehouse_id = warehouse_name_to_id.get(warehouse_name, None) - try: - # Create the schema if it doesn't exist - self.create_schema(full_schema_name, ignore_if_exists=True) - except Exception as e: - logger.debug(f"Error creating schema {full_schema_name}: {e}") - # Continue anyway - the schema might already exist or we might not have permissions + if not warehouse_id: + logger.error( + f"Fabric warehouse does not exist: {warehouse_name}\n(available warehouses: {', '.join(warehouse_name_to_id)})" + ) + raise SQLMeshError( + f"Unable to delete Fabric warehouse {warehouse_name} as it doesnt exist" + ) - def _create_table( - self, - table_name_or_schema: t.Union[exp.Schema, TableName], - expression: t.Optional[exp.Expression], - exists: bool = True, - replace: bool = False, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, - table_description: t.Optional[str] = None, - column_descriptions: t.Optional[t.Dict[str, str]] = None, - table_kind: t.Optional[str] = None, - **kwargs: t.Any, - ) -> None: - """ - Override _create_table to ensure schema exists before creating tables. - """ - # Extract table name for schema creation - if isinstance(table_name_or_schema, exp.Schema): - table_name = table_name_or_schema.this - else: - table_name = table_name_or_schema + # Delete the warehouse by ID + response = self.session.delete(self._endpoint_url(f"warehouses/{warehouse_id}")) + response.raise_for_status() - # Ensure the schema exists before creating the table - self._ensure_schema_exists(table_name) + logger.info(f"Successfully deleted Fabric warehouse: {warehouse_name}") - # Call the parent implementation - super()._create_table( - table_name_or_schema=table_name_or_schema, - expression=expression, - exists=exists, - replace=replace, - columns_to_types=columns_to_types, - table_description=table_description, - column_descriptions=column_descriptions, - table_kind=table_kind, - **kwargs, - ) + @cached_property + def session(self) -> requests.Session: + s = requests.Session() - def create_view( - self, - view_name: SchemaName, - query_or_df: t.Any, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, - replace: bool = True, - materialized: bool = False, - materialized_properties: t.Optional[t.Dict[str, t.Any]] = None, - table_description: t.Optional[str] = None, - column_descriptions: t.Optional[t.Dict[str, str]] = None, - view_properties: t.Optional[t.Dict[str, exp.Expression]] = None, - **create_kwargs: t.Any, - ) -> None: - """ - Override create_view to handle catalog-qualified view names and ensure schema exists. - Fabric doesn't support 'CREATE VIEW [catalog].[schema].[view]' syntax. - """ - # Switch to catalog if needed and get unqualified table - unqualified_view = self._switch_to_catalog_if_needed(view_name) - - # Ensure schema exists for the view - self._ensure_schema_exists(unqualified_view) - - super().create_view( - unqualified_view, - query_or_df, - columns_to_types, - replace, - materialized, - materialized_properties, - table_description, - column_descriptions, - view_properties, - **create_kwargs, + access_token = self._get_access_token() + s.headers.update({"Authorization": f"Bearer {access_token}"}) + + return s + + def _endpoint_url(self, endpoint: str) -> str: + if endpoint.startswith("/"): + endpoint = endpoint[1:] + + return f"https://api.fabric.microsoft.com/v1/workspaces/{self.workspace_id}/{endpoint}" + + def _get_access_token(self) -> str: + """Get access token using Service Principal authentication.""" + + # Use Azure AD OAuth2 token endpoint + token_url = f"https://login.microsoftonline.com/{self.tenant_id}/oauth2/v2.0/token" + + data = { + "grant_type": "client_credentials", + "client_id": self.client_id, + "client_secret": self.client_secret, + "scope": "https://api.fabric.microsoft.com/.default", + } + + response = requests.post(token_url, data=data) + response.raise_for_status() + token_data = response.json() + return token_data["access_token"] + + def _wait_for_completion(self, location_url: str, operation_name: str) -> None: + """Poll the operation status until completion.""" + + @retry( + wait=wait_exponential(multiplier=1, min=1, max=30), + stop=stop_after_attempt(20), + retry=retry_if_result(lambda result: result not in ["Succeeded", "Failed"]), ) + def _poll() -> str: + response = self.session.get(location_url) + response.raise_for_status() + + result = response.json() + status = result.get("status", "Unknown") + + logger.debug(f"Operation {operation_name} status: {status}") + + if status == "Failed": + error_msg = result.get("error", {}).get("message", "Unknown error") + raise SQLMeshError(f"Operation {operation_name} failed: {error_msg}") + elif status in ["InProgress", "Running"]: + logger.debug(f"Operation {operation_name} still in progress...") + elif status not in ["Succeeded"]: + logger.warning(f"Unknown status '{status}' for operation {operation_name}") + + return status + + final_status = _poll() + if final_status != "Succeeded": + raise SQLMeshError(f"Operation {operation_name} completed with status: {final_status}") diff --git a/tests/conftest.py b/tests/conftest.py index ad09deff6f..01fef852f7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -478,7 +478,7 @@ def _make_function( connection_mock.cursor.return_value = cursor_mock cursor_mock.connection.return_value = connection_mock adapter = klass( - lambda: connection_mock, + lambda *args, **kwargs: connection_mock, dialect=dialect or klass.DIALECT, register_comments=register_comments, default_catalog=default_catalog, diff --git a/tests/core/engine_adapter/test_fabric.py b/tests/core/engine_adapter/test_fabric.py index 0d283fe064..0ae036bec9 100644 --- a/tests/core/engine_adapter/test_fabric.py +++ b/tests/core/engine_adapter/test_fabric.py @@ -3,10 +3,12 @@ import typing as t import pytest +from pytest_mock import MockerFixture from sqlglot import exp, parse_one from sqlmesh.core.engine_adapter import FabricEngineAdapter from tests.core.engine_adapter import to_sql_calls +from sqlmesh.core.engine_adapter.shared import DataObject pytestmark = [pytest.mark.engine, pytest.mark.fabric] @@ -71,13 +73,18 @@ def test_insert_overwrite_by_time_partition(adapter: FabricEngineAdapter): ] -def test_replace_query(adapter: FabricEngineAdapter): - adapter.cursor.fetchone.return_value = (1,) - adapter.replace_query("test_table", parse_one("SELECT a FROM tbl"), {"a": "int"}) +def test_replace_query(adapter: FabricEngineAdapter, mocker: MockerFixture): + mocker.patch.object( + adapter, + "_get_data_objects", + return_value=[DataObject(schema="", name="test_table", type="table")], + ) + adapter.replace_query( + "test_table", parse_one("SELECT a FROM tbl"), {"a": exp.DataType.build("int")} + ) # This behavior is inherited from MSSQLEngineAdapter and should be TRUNCATE + INSERT assert to_sql_calls(adapter) == [ - """SELECT 1 FROM [INFORMATION_SCHEMA].[TABLES] WHERE [TABLE_NAME] = 'test_table';""", "TRUNCATE TABLE [test_table];", "INSERT INTO [test_table] ([a]) SELECT [a] FROM [tbl];", ]