From c519fbbf9489e39e3f1cf6d73a34fd43e596ae26 Mon Sep 17 00:00:00 2001 From: jakedave Date: Mon, 20 Mar 2023 13:08:59 -0600 Subject: [PATCH 1/4] minimum required changes for priv key auth --- data_diff/dbt.py | 45 ++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 42 insertions(+), 3 deletions(-) diff --git a/data_diff/dbt.py b/data_diff/dbt.py index 0baec9d5..b40d34f3 100644 --- a/data_diff/dbt.py +++ b/data_diff/dbt.py @@ -8,6 +8,11 @@ from typing import List, Optional, Dict, Tuple from pathlib import Path +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.primitives.asymmetric import dsa +from cryptography.hazmat.primitives import serialization + import requests @@ -366,22 +371,56 @@ def _get_connection_creds(self) -> Tuple[Dict[str, str], str]: return credentials, conn_type + @staticmethod + def _get_snowflake_private_key(credentials): + """Get Snowflake private key by path, from a Base64 encoded DER bytestring or None.""" + if credentials.get("private_key") and credentials.get("private_key_path"): + raise Exception("Cannot specify both `private_key` and `private_key_path`") + + if credentials.get("private_key_passphrase"): + encoded_passphrase = credentials.get("private_key_passphrase").encode() + else: + encoded_passphrase = None + + if credentials.get("private_key"): + p_key = serialization.load_der_private_key( + base64.b64decode(credentials.get("private_key")), + password=encoded_passphrase, + backend=default_backend(), + ) + elif credentials.get("private_key_path"): + with open(credentials.get("private_key_path"), "rb") as key: + p_key = serialization.load_pem_private_key( + key.read(), password=encoded_passphrase, backend=default_backend() + ) + else: + return None + + return p_key.private_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + def set_connection(self): credentials, conn_type = self._get_connection_creds() if conn_type == "snowflake": - if credentials.get("password") is None or credentials.get("private_key_path") is not None: - raise Exception("Only password authentication is currently supported for Snowflake.") + if credentials.get("authenticator") is not None: + raise Exception("Federated authentication is not currently supported for Snowflake.") conn_info = { "driver": conn_type, "user": credentials.get("user"), - "password": credentials.get("password"), "account": credentials.get("account"), "database": credentials.get("database"), "warehouse": credentials.get("warehouse"), "role": credentials.get("role"), "schema": credentials.get("schema"), } + if credentials.get("password") is not None: + conn_info["password"] = credentials.get("password") + else: + conn_info["private_key"] = self._get_snowflake_private_key(credentials) self.threads = credentials.get("threads") self.requires_upper = True elif conn_type == "bigquery": From 528d130b77f0414749547666538782b7cdab4a11 Mon Sep 17 00:00:00 2001 From: jakedave Date: Wed, 22 Mar 2023 09:54:08 -0600 Subject: [PATCH 2/4] handle multiple provided auth methods --- data_diff/dbt.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/data_diff/dbt.py b/data_diff/dbt.py index b40d34f3..9de5aa1c 100644 --- a/data_diff/dbt.py +++ b/data_diff/dbt.py @@ -374,9 +374,6 @@ def _get_connection_creds(self) -> Tuple[Dict[str, str], str]: @staticmethod def _get_snowflake_private_key(credentials): """Get Snowflake private key by path, from a Base64 encoded DER bytestring or None.""" - if credentials.get("private_key") and credentials.get("private_key_path"): - raise Exception("Cannot specify both `private_key` and `private_key_path`") - if credentials.get("private_key_passphrase"): encoded_passphrase = credentials.get("private_key_passphrase").encode() else: @@ -417,9 +414,14 @@ def set_connection(self): "role": credentials.get("role"), "schema": credentials.get("schema"), } + if credentials.get("password") is not None: + if credentials.get("private_key") is not None or credentials.get("private_key_path") is not None: + raise Exception("Cannot use password and key at the same time") conn_info["password"] = credentials.get("password") else: + if credentials.get("private_key") is not None and credentials.get("private_key_path") is not None: + raise Exception("Cannot specify both `private_key` and `private_key_path`") conn_info["private_key"] = self._get_snowflake_private_key(credentials) self.threads = credentials.get("threads") self.requires_upper = True From 1740c82b049149486f25ca29d56b50871514386b Mon Sep 17 00:00:00 2001 From: jakedave Date: Wed, 22 Mar 2023 10:11:43 -0600 Subject: [PATCH 3/4] more verbose error handling - fix original unit test --- data_diff/dbt.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/data_diff/dbt.py b/data_diff/dbt.py index 9de5aa1c..d9bec277 100644 --- a/data_diff/dbt.py +++ b/data_diff/dbt.py @@ -374,6 +374,8 @@ def _get_connection_creds(self) -> Tuple[Dict[str, str], str]: @staticmethod def _get_snowflake_private_key(credentials): """Get Snowflake private key by path, from a Base64 encoded DER bytestring or None.""" + if credentials.get("private_key") and credentials.get("private_key_path"): + raise Exception("Cannot specify both `private_key` and `private_key_path`") if credentials.get("private_key_passphrase"): encoded_passphrase = credentials.get("private_key_passphrase").encode() else: @@ -415,14 +417,15 @@ def set_connection(self): "schema": credentials.get("schema"), } - if credentials.get("password") is not None: - if credentials.get("private_key") is not None or credentials.get("private_key_path") is not None: + if credentials.get("private_key") is not None or credentials.get("private_key_path") is not None: + if credentials.get("password") is not None: raise Exception("Cannot use password and key at the same time") + conn_info["private_key"] = self._get_snowflake_private_key(credentials) + elif credentials.get("password") is not None: conn_info["password"] = credentials.get("password") else: - if credentials.get("private_key") is not None and credentials.get("private_key_path") is not None: - raise Exception("Cannot specify both `private_key` and `private_key_path`") - conn_info["private_key"] = self._get_snowflake_private_key(credentials) + raise Exception("Password or key authentication not provided.") + self.threads = credentials.get("threads") self.requires_upper = True elif conn_type == "bigquery": From bbe3a0f3a46463d64a3981866544389840203863 Mon Sep 17 00:00:00 2001 From: jakedave Date: Wed, 22 Mar 2023 10:56:18 -0600 Subject: [PATCH 4/4] unit tests --- tests/test_dbt.py | 49 +++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 47 insertions(+), 2 deletions(-) diff --git a/tests/test_dbt.py b/tests/test_dbt.py index 376c7e5c..019eb399 100644 --- a/tests/test_dbt.py +++ b/tests/test_dbt.py @@ -144,7 +144,7 @@ def test_set_project_dict(self, mock_open): self.assertEqual(mock_self.project_dict, expected_dict) mock_open.assert_called_once_with(Path(PROJECT_FILE)) - def test_set_connection_snowflake_success(self): + def test_set_connection_snowflake_password_success(self): expected_driver = "snowflake" expected_credentials = {"user": "user", "password": "password"} mock_self = Mock() @@ -158,7 +158,52 @@ def test_set_connection_snowflake_success(self): self.assertEqual(mock_self.connection.get("password"), expected_credentials["password"]) self.assertEqual(mock_self.requires_upper, True) - def test_set_connection_snowflake_no_password(self): + def test_set_connection_snowflake_private_key_success(self): + expected_driver = "snowflake" + expected_credentials = {"user": "user", "private_key": "password", "private_key_passphrase": "pass"} + expected_connection = {"user": "user", "private_key": "password"} + mock_self = Mock() + mock_self._get_connection_creds.return_value = (expected_credentials, expected_driver) + mock_self._get_snowflake_private_key.return_value = expected_connection["private_key"] + + DbtParser.set_connection(mock_self) + + self.assertIsInstance(mock_self.connection, dict) + self.assertEqual(mock_self.connection.get("driver"), expected_driver) + self.assertEqual(mock_self.connection.get("user"), expected_connection["user"]) + self.assertEqual(mock_self.connection.get("private_key"), expected_connection["private_key"]) + self.assertEqual(mock_self.connection.get("private_key_passphrase"), None) + self.assertEqual(mock_self.requires_upper, True) + + def test_set_connection_snowflake_private_key_path_success(self): + expected_driver = "snowflake" + expected_credentials = {"user": "user", "private_key_path": "password", "private_key_passphrase": "pass"} + expected_connection = {"user": "user", "private_key": "password"} + mock_self = Mock() + mock_self._get_connection_creds.return_value = (expected_credentials, expected_driver) + mock_self._get_snowflake_private_key.return_value = expected_connection["private_key"] + + DbtParser.set_connection(mock_self) + + self.assertIsInstance(mock_self.connection, dict) + self.assertEqual(mock_self.connection.get("driver"), expected_driver) + self.assertEqual(mock_self.connection.get("user"), expected_connection["user"]) + self.assertEqual(mock_self.connection.get("private_key"), expected_connection["private_key"]) + self.assertEqual(mock_self.connection.get("private_key_passphrase"), None) + self.assertEqual(mock_self.requires_upper, True) + + def test_set_connection_snowflake_multiple_authentication(self): + expected_driver = "snowflake" + expected_credentials = {"user": "user", "password": "password", "private_key": "password"} + mock_self = Mock() + mock_self._get_connection_creds.return_value = (expected_credentials, expected_driver) + + with self.assertRaises(Exception): + DbtParser.set_connection(mock_self) + + self.assertNotIsInstance(mock_self.connection, dict) + + def test_set_connection_snowflake_no_authentication(self): expected_driver = "snowflake" expected_credentials = {"user": "user"} mock_self = Mock()