Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 55 additions & 17 deletions juju/machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@

import ipaddress
import logging
import typing

import pyrfc3339

from . import model, tag, jasyncio
from . import jasyncio, model, tag
from .annotationhelper import _get_annotations, _set_annotations
from .client import client
from .errors import JujuError
from juju.utils import juju_ssh_key_paths
from juju.utils import juju_ssh_key_paths, block_until

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -70,7 +71,7 @@ def _format_addr(self, addr):
return fmt.format(ipaddr)

async def scp_to(self, source, destination, user='ubuntu', proxy=False,
scp_opts=''):
scp_opts='', wait_for_active=False, timeout=None):
"""Transfer files to this machine.

:param str source: Local path of file(s) to transfer
Expand All @@ -79,10 +80,13 @@ async def scp_to(self, source, destination, user='ubuntu', proxy=False,
:param bool proxy: Proxy through the Juju API server
:param scp_opts: Additional options to the `scp` command
:type scp_opts: str or list
:param bool wait_for_active: Wait until the machine is ready to take in ssh commands.
:param int timeout: Time in seconds to wait until the machine becomes ready.
"""
if proxy:
raise NotImplementedError('proxy option is not implemented')

if wait_for_active:
await block_until(lambda: self.addresses, timeout=timeout)
try:
# if dns_name is an IP address format it appropriately
address = self._format_addr(self.dns_name)
Expand All @@ -93,7 +97,7 @@ async def scp_to(self, source, destination, user='ubuntu', proxy=False,
await self._scp(source, destination, scp_opts)

async def scp_from(self, source, destination, user='ubuntu', proxy=False,
scp_opts=''):
scp_opts='', wait_for_active=False, timeout=None):
"""Transfer files from this machine.

:param str source: Remote path of file(s) to transfer
Expand All @@ -102,10 +106,13 @@ async def scp_from(self, source, destination, user='ubuntu', proxy=False,
:param bool proxy: Proxy through the Juju API server
:param scp_opts: Additional options to the `scp` command
:type scp_opts: str or list
:param bool wait_for_active: Wait until the machine is ready to take in ssh commands.
:param int timeout: Time in seconds to wait until the machine becomes ready.
"""
if proxy:
raise NotImplementedError('proxy option is not implemented')

if wait_for_active:
await block_until(lambda: self.addresses, timeout=timeout)
try:
# if dns_name is an IP address format it appropriately
address = self._format_addr(self.dns_name)
Expand All @@ -129,23 +136,37 @@ async def _scp(self, source, destination, scp_opts):
]
cmd.extend(scp_opts.split() if isinstance(scp_opts, str) else scp_opts)
cmd.extend([source, destination])
process = await jasyncio.create_subprocess_exec(*cmd)
await process.wait()
# There's a bit of a gap between the time that the machine is assigned an IP and the ssh
# service is up and listening, which creates a race for the ssh command. So we retry a
# couple of times until either we run out of attempts, or the ssh command succeeds to
# mitigate that effect.
# TODO (cderici): refactor the ssh and scp subcommand processing into a single method.
retry_backoff = 2
retries = 10
for _ in range(retries):
process = await jasyncio.create_subprocess_exec(*cmd)
await process.wait()
if process.returncode == 0:
break
await jasyncio.sleep(retry_backoff)
if process.returncode != 0:
raise JujuError("command failed: %s" % cmd)
raise JujuError(f"command failed after {retries} attempts: {cmd}")

async def ssh(
self, command, user='ubuntu', proxy=False, ssh_opts=None):
self, command, user='ubuntu', proxy=False, ssh_opts=None, wait_for_active=False, timeout=None):
"""Execute a command over SSH on this machine.

:param str command: Command to execute
:param str user: Remote username
:param bool proxy: Proxy through the Juju API server
:param str ssh_opts: Additional options to the `ssh` command

:param bool wait_for_active: Wait until the machine is ready to take in ssh commands.
:param int timeout: Time in seconds to wait until the machine becomes ready.
"""
if proxy:
raise NotImplementedError('proxy option is not implemented')
if wait_for_active:
await block_until(lambda: self.addresses, timeout=timeout)
address = self.dns_name
destination = "{}@{}".format(user, address)
_, id_path = juju_ssh_key_paths()
Expand All @@ -159,14 +180,32 @@ async def ssh(
if ssh_opts:
cmd.extend(ssh_opts.split() if isinstance(ssh_opts, str) else ssh_opts)
cmd.extend([command])
process = await jasyncio.create_subprocess_exec(
*cmd, stdout=jasyncio.subprocess.PIPE, stderr=jasyncio.subprocess.PIPE)
stdout, stderr = await process.communicate()

# There's a bit of a gap between the time that the machine is assigned an IP and the ssh
# service is up and listening, which creates a race for the ssh command. So we retry a
# couple of times until either we run out of attempts, or the ssh command succeeds to
# mitigate that effect.
retry_backoff = 2
retries = 10
for _ in range(retries):
process = await jasyncio.create_subprocess_exec(
*cmd, stdout=jasyncio.subprocess.PIPE, stderr=jasyncio.subprocess.PIPE)
stdout, stderr = await process.communicate()
if process.returncode == 0:
break
await jasyncio.sleep(retry_backoff)
if process.returncode != 0:
raise JujuError("command failed: %s with %s" % (cmd, stderr.decode()))
raise JujuError(f"command failed: {cmd} after {retries} attempts, with {stderr.decode()}")
# stdout is a bytes-like object, returning a string might be more useful
return stdout.decode()

@property
def addresses(self) -> typing.List[str]:
"""Returns the machine addresses.

"""
return self.safe_data['addresses'] or []

@property
def agent_status(self):
"""Returns the current Juju agent status string.
Expand Down Expand Up @@ -221,11 +260,10 @@ def dns_name(self):

May return None if no suitable address is found.
"""
addresses = self.safe_data['addresses'] or []
ordered_addresses = []
ordered_scopes = ['public', 'local-cloud', 'local-fan']
for scope in ordered_scopes:
for address in addresses:
for address in self.addresses:
if scope == address['scope']:
ordered_addresses.append(address)
for address in ordered_addresses:
Expand Down
10 changes: 10 additions & 0 deletions tests/integration/test_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest

from .. import base
from juju.machine import Machine


@base.bootstrapped
Expand Down Expand Up @@ -36,3 +37,12 @@ async def test_status():
machine.status_message.lower() == 'running' and
machine.agent_status == 'started')),
timeout=480)


@base.bootstrapped
async def test_machine_ssh():
async with base.CleanModel() as model:
machine: Machine = await model.add_machine()
out = await machine.ssh("echo hello world!", wait_for_active=True)

assert out == "hello world!\n"