diff --git a/src/dataworkbench/datacatalogue.py b/src/dataworkbench/datacatalogue.py index c7c2a69..ca89c70 100644 --- a/src/dataworkbench/datacatalogue.py +++ b/src/dataworkbench/datacatalogue.py @@ -4,7 +4,7 @@ from pyspark.sql import DataFrame -from dataworkbench.utils import get_secret +from dataworkbench.utils import get_secret, SparkDataFrame from dataworkbench.storage import DeltaStorage from dataworkbench.gateway import Gateway @@ -46,9 +46,6 @@ def __build_storage_table_root_url(self, folder_id: uuid.UUID) -> str: if not isinstance(folder_id, uuid.UUID): raise TypeError("folder_id must be uuid") - if not folder_id: - raise ValueError("folder_id cannot be empty") - return f"{self.storage_base_url}/{folder_id}" def __build_storage_table_processed_url(self, folder_id: uuid.UUID) -> str: @@ -110,7 +107,7 @@ def save( ... ) """ # Validate input parameters - if not hasattr(df, "write"): + if not isinstance(df, SparkDataFrame): raise TypeError("df must be a DataFrame") if not isinstance(dataset_name, str) or not dataset_name: diff --git a/src/dataworkbench/gateway.py b/src/dataworkbench/gateway.py index dff246d..14859a3 100644 --- a/src/dataworkbench/gateway.py +++ b/src/dataworkbench/gateway.py @@ -152,8 +152,9 @@ def import_dataset( if e.response is not None else None ) - logger.error(f"Error creating data catalog entry: {e}") - return { - "error": "Failed to create data catalog entry.", - "correlation-id": trace_id, - } + error_msg = ( + f"Failed to create data catalog entry. correlation-id: {trace_id}" + ) + + logger.error(error_msg) + raise type(e)(error_msg) from e diff --git a/src/dataworkbench/utils.py b/src/dataworkbench/utils.py index e69746e..04f14cb 100644 --- a/src/dataworkbench/utils.py +++ b/src/dataworkbench/utils.py @@ -1,5 +1,5 @@ import os -from pyspark.sql import SparkSession +from pyspark.sql import SparkSession, DataFrame from dataworkbench.log import setup_logger @@ -70,6 +70,14 @@ def get_secret(key: str, scope: str = "dwsecrets") -> str: return secret +if is_databricks(): + from pyspark.sql.connect.dataframe import DataFrame as DatabricksDataFrame + + SparkDataFrame = DataFrame | DatabricksDataFrame +else: + SparkDataFrame = DataFrame + + # Example usage if __name__ == "__main__": CLIENT_ID = get_secret("ClientId") diff --git a/tests/test_gateway.py b/tests/test_gateway.py index 64707a9..0d5ec2d 100644 --- a/tests/test_gateway.py +++ b/tests/test_gateway.py @@ -2,6 +2,7 @@ import requests from unittest.mock import patch, MagicMock from dataworkbench.gateway import Gateway +from requests.exceptions import RequestException import json @pytest.fixture @@ -32,6 +33,7 @@ def test_import_dataset_success(mock_gateway, mock_post): assert result == {"status": "success"} mock_post.assert_called_once() + def test_import_dataset_failure(mock_gateway, mock_post): """Test dataset import failure.""" @@ -47,7 +49,8 @@ def test_import_dataset_failure(mock_gateway, mock_post): mock_response.raise_for_status.side_effect = http_error mock_post.return_value = mock_response - result = mock_gateway.import_dataset("dataset_name", "dataset_description", "schema_id", {"tag": "value"}, "folder_id") + with pytest.raises(RequestException) as e: + mock_gateway.import_dataset("dataset_name", "dataset_description", "schema_id", {"tag": "value"}, "folder_id") - assert result == {"error": "Failed to create data catalog entry.", "correlation-id": response_body["traceId"]} + assert e.value.args[0] == f"Failed to create data catalog entry. correlation-id: {response_body['traceId']}" mock_post.assert_called_once()