diff --git a/VERSION b/VERSION index 080c74d0..32dc00b1 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1!10.5.0 +1!10.6.0 diff --git a/pycloudlib/cloud.py b/pycloudlib/cloud.py index f0c9e220..1e433ca2 100644 --- a/pycloudlib/cloud.py +++ b/pycloudlib/cloud.py @@ -66,7 +66,11 @@ def __init__( self.tag = get_timestamped_tag(tag) if timestamp_suffix else tag self._validate_tag(self.tag) - self.key_pair = self._get_ssh_keys() + self.key_pair = self._get_ssh_keys( + public_key_path=self.config.get("public_key_path", ""), + private_key_path=self.config.get("private_key_path", ""), + name=self.config.get("key_name", getpass.getuser()), + ) def __enter__(self): """Enter context manager for this class.""" @@ -251,7 +255,11 @@ def use_key(self, public_key_path, private_key_path=None, name=None): name: name to reference key by """ self._log.debug("using SSH key from %s", public_key_path) - self.key_pair = KeyPair(public_key_path, private_key_path, name) + self.key_pair = self._get_ssh_keys( + public_key_path=public_key_path, + private_key_path=private_key_path, + name=name, + ) def _check_and_set_config( self, @@ -310,14 +318,37 @@ def _validate_tag(tag: str): if rules_failed: raise InvalidTagNameError(tag=tag, rules_failed=rules_failed) - def _get_ssh_keys(self) -> KeyPair: - user = getpass.getuser() - # check if id_rsa or id_ed25519 keys exist in the user's .ssh directory + def _get_ssh_keys( + self, + public_key_path: Optional[str] = None, + private_key_path: Optional[str] = None, + name: Optional[str] = None, + ) -> KeyPair: + """Retrieve SSH key pair paths. + + This method attempts to retrieve the paths to the public and private SSH keys. + If no public key path is provided, it will look for default keys in the user's + `~/.ssh` directory. If no keys are found, it logs a warning and returns a KeyPair + with None values. + + Args: + public_key_path (Optional[str]): The path to the public SSH key. If not provided, + the method will search for default keys. + private_key_path (Optional[str]): The path to the private SSH key. Defaults to None. + name (Optional[str]): An optional name for the key pair. Defaults to None. + + Returns: + KeyPair: An instance of KeyPair containing the paths to the public and private keys, + and the optional name. + + Raises: + PycloudlibError: If the provided public key path does not exist. + """ possible_default_keys = [ os.path.expanduser("~/.ssh/id_rsa.pub"), os.path.expanduser("~/.ssh/id_ed25519.pub"), ] - public_key_path: Optional[str] = os.path.expanduser(self.config.get("public_key_path", "")) + public_key_path = os.path.expanduser(public_key_path or "") if not public_key_path: for pubkey in possible_default_keys: if os.path.exists(pubkey): @@ -325,18 +356,19 @@ def _get_ssh_keys(self) -> KeyPair: public_key_path = pubkey break if not public_key_path: - raise PycloudlibError( + self._log.warning( "No public key path provided and no key found in default locations: " - "'~/.ssh/id_rsa.pub' or '~/.ssh/id_ed25519.pub'" + "'~/.ssh/id_rsa.pub' or '~/.ssh/id_ed25519.pub'. SSH key authentication will " + "not be possible unless a key is later provided with the 'use_key' method." ) + return KeyPair(None, None, None) if not os.path.exists(os.path.expanduser(public_key_path)): raise PycloudlibError(f"Provided public key path '{public_key_path}' does not exist") if public_key_path not in possible_default_keys: self._log.info("Using provided public key path: '%s'", public_key_path) - private_key_path = self.config.get("private_key_path", "") return KeyPair( public_key_path=public_key_path, private_key_path=private_key_path, - name=self.config.get("key_name", user), + name=name, ) diff --git a/pycloudlib/errors.py b/pycloudlib/errors.py index b4502876..82747634 100644 --- a/pycloudlib/errors.py +++ b/pycloudlib/errors.py @@ -156,3 +156,18 @@ def __init__(self, tag: str, rules_failed: List[str]): def __str__(self) -> str: """Return string representation of the error.""" return f"Tag '{self.tag}' failed the following rules: {', '.join(self.rules_failed)}" + + +class UnsetSSHKeyError(PycloudlibException): + """Raised when a SSH key is unset and no default key can be found.""" + + def __str__(self) -> str: + """Return string representation of the error.""" + return ( + "No public key content available for unset key pair. This error occurs when no SSH " + "key is provided in the pycloudlib.toml file and no default keys can be found on " + "the system. If you wish to provide custom SSH keys at runtime, you can do so by " + "calling the `use_key` method on the `Cloud` class. If you wish to use default SSH " + "keys, make sure they are present on the system and that they are located in the " + "default locations." + ) diff --git a/pycloudlib/key.py b/pycloudlib/key.py index 7690e21c..5ecce784 100644 --- a/pycloudlib/key.py +++ b/pycloudlib/key.py @@ -2,12 +2,20 @@ """Base Key Class.""" import os +from typing import Optional + +from pycloudlib.errors import UnsetSSHKeyError class KeyPair: """Key Class.""" - def __init__(self, public_key_path, private_key_path=None, name=None): + def __init__( + self, + public_key_path: Optional[str], + private_key_path: Optional[str] = None, + name: Optional[str] = None, + ): """Initialize key class to generate key or reuse existing key. The public key path is given then the key is stored and the @@ -21,11 +29,15 @@ def __init__(self, public_key_path, private_key_path=None, name=None): """ self.name = name self.public_key_path = public_key_path - if private_key_path: - self.private_key_path = private_key_path - else: - self.private_key_path = self.public_key_path.replace(".pub", "") + # don't set private key path if public key path is None (ssh key is unset) + if self.public_key_path is None: + self.private_key_path = None + return + + self.private_key_path = private_key_path or self.public_key_path.replace(".pub", "") + + # Expand user paths after setting private key path self.public_key_path = os.path.expanduser(self.public_key_path) self.private_key_path = os.path.expanduser(self.private_key_path) @@ -40,7 +52,8 @@ def public_key_content(self): """Read the contents of the public key. Returns: - output of public key - + str: The public key content """ + if self.public_key_path is None: + raise UnsetSSHKeyError() return open(self.public_key_path, encoding="utf-8").read() diff --git a/tests/unit_tests/test_cloud.py b/tests/unit_tests/test_cloud.py index df8836dc..158e6703 100644 --- a/tests/unit_tests/test_cloud.py +++ b/tests/unit_tests/test_cloud.py @@ -1,14 +1,15 @@ """Tests related to pycloudlib.cloud module.""" from io import StringIO +import logging from textwrap import dedent -from typing import List +from typing import List, Optional import mock import pytest from pycloudlib.cloud import BaseCloud -from pycloudlib.errors import InvalidTagNameError +from pycloudlib.errors import InvalidTagNameError, UnsetSSHKeyError # mock module path MPATH = "pycloudlib.cloud." @@ -181,6 +182,30 @@ def test_missing_private_key_in_ssh_config(self, _m_expanduser, _m_exists): assert mycloud.key_pair.public_key_path == "/home/asdf/.ssh/id_rsa.pub" assert mycloud.key_pair.private_key_path == "/home/asdf/.ssh/id_rsa" + @pytest.mark.dont_mock_ssh_keys + @mock.patch("os.path.expanduser", side_effect=lambda x: x.replace("~", "/root")) + @mock.patch("os.path.exists", return_value=False) + def test_init_raises_error_when_no_ssh_keys_found( + self, + _m_expanduser, + _m_exists, + caplog, + ): + """ + Test that an error is raised when no SSH keys can be found. + + This test verifies that an error is raised when no SSH keys can be found in the default + locations and no public key path is provided in the config file. + """ + # set log level to Warning to ensure warning gets logged + caplog.set_level(logging.WARNING) + with pytest.raises(UnsetSSHKeyError) as exc_info: + cloud = CloudSubclass(tag="tag", timestamp_suffix=False, config_file=StringIO(CONFIG)) + # now we try to access the public key content to trigger the exception + cloud.key_pair.public_key_content + assert "No public key path provided and no key found in default locations" in caplog.text + assert "No public key content available for unset key pair." in str(exc_info.value) + rule1 = "All letters must be lowercase" rule2 = "Must be between 1 and 63 characters long" rule3 = "Must not start or end with a hyphen"