Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
290 changes: 148 additions & 142 deletions airflow/contrib/hooks/ssh_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@

import getpass
import os
import warnings

import paramiko
from paramiko.config import SSH_PORT
from sshtunnel import SSHTunnelForwarder

from contextlib import contextmanager
from airflow.exceptions import AirflowException
from airflow.hooks.base_hook import BaseHook
from airflow.utils.log.logging_mixin import LoggingMixin
Expand Down Expand Up @@ -65,7 +66,7 @@ def __init__(self,
username=None,
password=None,
key_file=None,
port=SSH_PORT,
port=None,
timeout=10,
keepalive_interval=30
):
Expand All @@ -75,162 +76,167 @@ def __init__(self,
self.username = username
self.password = password
self.key_file = key_file
self.port = port
self.timeout = timeout
self.keepalive_interval = keepalive_interval

# Default values, overridable from Connection
self.compress = True
self.no_host_key_check = True
self.host_proxy = None

# Placeholder for deprecated __enter__
self.client = None
self.port = port

# Use connection to override defaults
if self.ssh_conn_id is not None:
conn = self.get_connection(self.ssh_conn_id)
if self.username is None:
self.username = conn.login
if self.password is None:
self.password = conn.password
if self.remote_host is None:
self.remote_host = conn.host
if self.port is None:
self.port = conn.port
if conn.extra is not None:
extra_options = conn.extra_dejson
self.key_file = extra_options.get("key_file")

if "timeout" in extra_options:
self.timeout = int(extra_options["timeout"], 10)

if "compress" in extra_options\
and str(extra_options["compress"]).lower() == 'false':
self.compress = False
if "no_host_key_check" in extra_options\
and\
str(extra_options["no_host_key_check"]).lower() == 'false':
self.no_host_key_check = False

if not self.remote_host:
raise AirflowException("Missing required param: remote_host")

# Auto detecting username values from system
if not self.username:
self.log.debug(
"username to ssh to host: %s is not specified for connection id"
" %s. Using system's default provided by getpass.getuser()",
self.remote_host, self.ssh_conn_id
)
self.username = getpass.getuser()

user_ssh_config_filename = os.path.expanduser('~/.ssh/config')
if os.path.isfile(user_ssh_config_filename):
ssh_conf = paramiko.SSHConfig()
ssh_conf.parse(open(user_ssh_config_filename))
host_info = ssh_conf.lookup(self.remote_host)
if host_info and host_info.get('proxycommand'):
self.host_proxy = paramiko.ProxyCommand(host_info.get('proxycommand'))

if not (self.password or self.key_file):
if host_info and host_info.get('identityfile'):
self.key_file = host_info.get('identityfile')[0]

self.port = self.port or SSH_PORT

def get_conn(self):
if not self.client:
self.log.debug('Creating SSH client for conn_id: %s', self.ssh_conn_id)
if self.ssh_conn_id is not None:
conn = self.get_connection(self.ssh_conn_id)
if self.username is None:
self.username = conn.login
if self.password is None:
self.password = conn.password
if self.remote_host is None:
self.remote_host = conn.host
if conn.port is not None:
self.port = conn.port
if conn.extra is not None:
extra_options = conn.extra_dejson
self.key_file = extra_options.get("key_file")

if "timeout" in extra_options:
self.timeout = int(extra_options["timeout"], 10)

if "compress" in extra_options \
and str(extra_options["compress"]).lower() == 'false':
self.compress = False
if "no_host_key_check" in extra_options \
and \
str(extra_options["no_host_key_check"]).lower() == 'false':
self.no_host_key_check = False

if not self.remote_host:
raise AirflowException("Missing required param: remote_host")

# Auto detecting username values from system
if not self.username:
self.log.debug(
"username to ssh to host: %s is not specified for connection id"
" %s. Using system's default provided by getpass.getuser()",
self.remote_host, self.ssh_conn_id
)
self.username = getpass.getuser()

host_proxy = None
user_ssh_config_filename = os.path.expanduser('~/.ssh/config')
if os.path.isfile(user_ssh_config_filename):
ssh_conf = paramiko.SSHConfig()
ssh_conf.parse(open(user_ssh_config_filename))
host_info = ssh_conf.lookup(self.remote_host)
if host_info and host_info.get('proxycommand'):
host_proxy = paramiko.ProxyCommand(host_info.get('proxycommand'))

if not (self.password or self.key_file):
if host_info and host_info.get('identityfile'):
self.key_file = host_info.get('identityfile')[0]

try:
client = paramiko.SSHClient()
client.load_system_host_keys()
if self.no_host_key_check:
# Default is RejectPolicy
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())

if self.password and self.password.strip():
client.connect(hostname=self.remote_host,
username=self.username,
password=self.password,
timeout=self.timeout,
compress=self.compress,
port=self.port,
sock=host_proxy)
else:
client.connect(hostname=self.remote_host,
username=self.username,
key_filename=self.key_file,
timeout=self.timeout,
compress=self.compress,
port=self.port,
sock=host_proxy)

if self.keepalive_interval:
client.get_transport().set_keepalive(self.keepalive_interval)

self.client = client
except paramiko.AuthenticationException as auth_error:
self.log.error(
"Auth failed while connecting to host: %s, error: %s",
self.remote_host, auth_error
)
except paramiko.SSHException as ssh_error:
self.log.error(
"Failed connecting to host: %s, error: %s",
self.remote_host, ssh_error
)
except Exception as error:
self.log.error(
"Error connecting to host: %s, error: %s",
self.remote_host, error
)
return self.client

@contextmanager
def create_tunnel(self, local_port, remote_port=None, remote_host="localhost"):
"""
Creates a tunnel between two hosts. Like ssh -L <LOCAL_PORT>:host:<REMOTE_PORT>.
Remember to close() the returned "tunnel" object in order to clean up
after yourself when you are done with the tunnel.
Opens a ssh connection to the remote host.

:param local_port:
:type local_port: int
:param remote_port:
:type remote_port: int
:param remote_host:
:type remote_host: str
:return:
:return paramiko.SSHClient object
"""

import subprocess
# this will ensure the connection to the ssh.remote_host from where the tunnel
# is getting created
self.get_conn()

tunnel_host = "{0}:{1}:{2}".format(local_port, remote_host, remote_port)

ssh_cmd = ["ssh", "{0}@{1}".format(self.username, self.remote_host),
"-o", "ControlMaster=no",
"-o", "UserKnownHostsFile=/dev/null",
"-o", "StrictHostKeyChecking=no"]

ssh_tunnel_cmd = ["-L", tunnel_host,
"echo -n ready && cat"
]

ssh_cmd += ssh_tunnel_cmd
self.log.debug("Creating tunnel with cmd: %s", ssh_cmd)

proc = subprocess.Popen(ssh_cmd,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
close_fds=True)
ready = proc.stdout.read(5)
assert ready == b"ready", \
"Did not get 'ready' from remote, got '{0}' instead".format(ready)
yield
proc.communicate()
assert proc.returncode == 0, \
"Tunnel process did unclean exit (returncode {}".format(proc.returncode)
self.log.debug('Creating SSH client for conn_id: %s', self.ssh_conn_id)
client = paramiko.SSHClient()
client.load_system_host_keys()
if self.no_host_key_check:
# Default is RejectPolicy
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())

if self.password and self.password.strip():
client.connect(hostname=self.remote_host,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would not work if the private key has passphrase. I tried to do that and it fail. Can you please add key_filename=self.key_file to this if statement as well? That worked for me and I don't want to create a separate PR if it can be sorted out here itself.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi kaxil, I added the key_file for both ssh and the tunnel. The SSHTunnel package has the same behaviour as paramiko and will also use the provided password for the key_file

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @NielsZeilemaker 👍

username=self.username,
password=self.password,
key_filename=self.key_file,
timeout=self.timeout,
compress=self.compress,
port=self.port,
sock=self.host_proxy)
else:
client.connect(hostname=self.remote_host,
username=self.username,
key_filename=self.key_file,
timeout=self.timeout,
compress=self.compress,
port=self.port,
sock=self.host_proxy)

if self.keepalive_interval:
client.get_transport().set_keepalive(self.keepalive_interval)

self.client = client
return client

def __enter__(self):
warnings.warn('The contextmanager of SSHHook is deprecated.'
'Please use get_conn() as a contextmanager instead.'
'This method will be removed in Airflow 2.0',
category=DeprecationWarning)
return self

def __exit__(self, exc_type, exc_val, exc_tb):
if self.client is not None:
self.client.close()
self.client = None

def get_tunnel(self, remote_port, remote_host="localhost", local_port=None):
"""
Creates a tunnel between two hosts. Like ssh -L <LOCAL_PORT>:host:<REMOTE_PORT>.

:param remote_port: The remote port to create a tunnel to
:type remote_port: int
:param remote_host: The remote host to create a tunnel to (default localhost)
:type remote_host: str
:param local_port: The local port to attach the tunnel to
:type local_port: int

:return: sshtunnel.SSHTunnelForwarder object
"""

if local_port:
local_bind_address = ('localhost', local_port)
else:
local_bind_address = ('localhost',)

if self.password and self.password.strip():
client = SSHTunnelForwarder(self.remote_host,
ssh_port=self.port,
ssh_username=self.username,
ssh_password=self.password,
ssh_pkey=self.key_file,
ssh_proxy=self.host_proxy,
local_bind_address=local_bind_address,
remote_bind_address=(remote_host, remote_port),
logger=self.log)
else:
client = SSHTunnelForwarder(self.remote_host,
ssh_port=self.port,
ssh_username=self.username,
ssh_pkey=self.key_file,
ssh_proxy=self.host_proxy,
local_bind_address=local_bind_address,
remote_bind_address=(remote_host, remote_port),
host_pkey_directories=[],
logger=self.log)

return client

def create_tunnel(self, local_port, remote_port=None, remote_host="localhost"):
warnings.warn('SSHHook.create_tunnel is deprecated, Please'
'use get_tunnel() instead. But please note that the'
'order of the parameters have changed'
'This method will be removed in Airflow 2.0',
category=DeprecationWarning)

return self.get_tunnel(remote_port, remote_host, local_port)
28 changes: 14 additions & 14 deletions airflow/contrib/operators/sftp_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,20 +86,20 @@ def execute(self, context):
if self.remote_host is not None:
self.ssh_hook.remote_host = self.remote_host

ssh_client = self.ssh_hook.get_conn()
sftp_client = ssh_client.open_sftp()
if self.operation.lower() == SFTPOperation.GET:
file_msg = "from {0} to {1}".format(self.remote_filepath,
self.local_filepath)
self.log.debug("Starting to transfer %s", file_msg)
sftp_client.get(self.remote_filepath, self.local_filepath)
else:
file_msg = "from {0} to {1}".format(self.local_filepath,
self.remote_filepath)
self.log.debug("Starting to transfer file %s", file_msg)
sftp_client.put(self.local_filepath,
self.remote_filepath,
confirm=self.confirm)
with self.ssh_hook.get_conn() as ssh_client:
sftp_client = ssh_client.open_sftp()
if self.operation.lower() == SFTPOperation.GET:
file_msg = "from {0} to {1}".format(self.remote_filepath,
self.local_filepath)
self.log.debug("Starting to transfer %s", file_msg)
sftp_client.get(self.remote_filepath, self.local_filepath)
else:
file_msg = "from {0} to {1}".format(self.local_filepath,
self.remote_filepath)
self.log.debug("Starting to transfer file %s", file_msg)
sftp_client.put(self.local_filepath,
self.remote_filepath,
confirm=self.confirm)

except Exception as e:
raise AirflowException("Error while transferring {0}, error: {1}"
Expand Down
Loading