diff --git a/data_diff/dbt.py b/data_diff/dbt.py index 0baec9d5..d9bec277 100644 --- a/data_diff/dbt.py +++ b/data_diff/dbt.py @@ -8,6 +8,11 @@ from typing import List, Optional, Dict, Tuple from pathlib import Path +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.primitives.asymmetric import dsa +from cryptography.hazmat.primitives import serialization + import requests @@ -366,22 +371,61 @@ def _get_connection_creds(self) -> Tuple[Dict[str, str], str]: return credentials, conn_type + @staticmethod + def _get_snowflake_private_key(credentials): + """Get Snowflake private key by path, from a Base64 encoded DER bytestring or None.""" + if credentials.get("private_key") and credentials.get("private_key_path"): + raise Exception("Cannot specify both `private_key` and `private_key_path`") + if credentials.get("private_key_passphrase"): + encoded_passphrase = credentials.get("private_key_passphrase").encode() + else: + encoded_passphrase = None + + if credentials.get("private_key"): + p_key = serialization.load_der_private_key( + base64.b64decode(credentials.get("private_key")), + password=encoded_passphrase, + backend=default_backend(), + ) + elif credentials.get("private_key_path"): + with open(credentials.get("private_key_path"), "rb") as key: + p_key = serialization.load_pem_private_key( + key.read(), password=encoded_passphrase, backend=default_backend() + ) + else: + return None + + return p_key.private_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + def set_connection(self): credentials, conn_type = self._get_connection_creds() if conn_type == "snowflake": - if credentials.get("password") is None or credentials.get("private_key_path") is not None: - raise Exception("Only password authentication is currently supported for Snowflake.") + if credentials.get("authenticator") is not None: + raise Exception("Federated authentication is not currently supported for Snowflake.") conn_info = { "driver": conn_type, "user": credentials.get("user"), - "password": credentials.get("password"), "account": credentials.get("account"), "database": credentials.get("database"), "warehouse": credentials.get("warehouse"), "role": credentials.get("role"), "schema": credentials.get("schema"), } + + if credentials.get("private_key") is not None or credentials.get("private_key_path") is not None: + if credentials.get("password") is not None: + raise Exception("Cannot use password and key at the same time") + conn_info["private_key"] = self._get_snowflake_private_key(credentials) + elif credentials.get("password") is not None: + conn_info["password"] = credentials.get("password") + else: + raise Exception("Password or key authentication not provided.") + self.threads = credentials.get("threads") self.requires_upper = True elif conn_type == "bigquery": diff --git a/tests/test_dbt.py b/tests/test_dbt.py index 376c7e5c..019eb399 100644 --- a/tests/test_dbt.py +++ b/tests/test_dbt.py @@ -144,7 +144,7 @@ def test_set_project_dict(self, mock_open): self.assertEqual(mock_self.project_dict, expected_dict) mock_open.assert_called_once_with(Path(PROJECT_FILE)) - def test_set_connection_snowflake_success(self): + def test_set_connection_snowflake_password_success(self): expected_driver = "snowflake" expected_credentials = {"user": "user", "password": "password"} mock_self = Mock() @@ -158,7 +158,52 @@ def test_set_connection_snowflake_success(self): self.assertEqual(mock_self.connection.get("password"), expected_credentials["password"]) self.assertEqual(mock_self.requires_upper, True) - def test_set_connection_snowflake_no_password(self): + def test_set_connection_snowflake_private_key_success(self): + expected_driver = "snowflake" + expected_credentials = {"user": "user", "private_key": "password", "private_key_passphrase": "pass"} + expected_connection = {"user": "user", "private_key": "password"} + mock_self = Mock() + mock_self._get_connection_creds.return_value = (expected_credentials, expected_driver) + mock_self._get_snowflake_private_key.return_value = expected_connection["private_key"] + + DbtParser.set_connection(mock_self) + + self.assertIsInstance(mock_self.connection, dict) + self.assertEqual(mock_self.connection.get("driver"), expected_driver) + self.assertEqual(mock_self.connection.get("user"), expected_connection["user"]) + self.assertEqual(mock_self.connection.get("private_key"), expected_connection["private_key"]) + self.assertEqual(mock_self.connection.get("private_key_passphrase"), None) + self.assertEqual(mock_self.requires_upper, True) + + def test_set_connection_snowflake_private_key_path_success(self): + expected_driver = "snowflake" + expected_credentials = {"user": "user", "private_key_path": "password", "private_key_passphrase": "pass"} + expected_connection = {"user": "user", "private_key": "password"} + mock_self = Mock() + mock_self._get_connection_creds.return_value = (expected_credentials, expected_driver) + mock_self._get_snowflake_private_key.return_value = expected_connection["private_key"] + + DbtParser.set_connection(mock_self) + + self.assertIsInstance(mock_self.connection, dict) + self.assertEqual(mock_self.connection.get("driver"), expected_driver) + self.assertEqual(mock_self.connection.get("user"), expected_connection["user"]) + self.assertEqual(mock_self.connection.get("private_key"), expected_connection["private_key"]) + self.assertEqual(mock_self.connection.get("private_key_passphrase"), None) + self.assertEqual(mock_self.requires_upper, True) + + def test_set_connection_snowflake_multiple_authentication(self): + expected_driver = "snowflake" + expected_credentials = {"user": "user", "password": "password", "private_key": "password"} + mock_self = Mock() + mock_self._get_connection_creds.return_value = (expected_credentials, expected_driver) + + with self.assertRaises(Exception): + DbtParser.set_connection(mock_self) + + self.assertNotIsInstance(mock_self.connection, dict) + + def test_set_connection_snowflake_no_authentication(self): expected_driver = "snowflake" expected_credentials = {"user": "user"} mock_self = Mock()