Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 62 additions & 6 deletions airflow/contrib/hooks/ftp_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,11 @@ def delete_directory(self, path):
conn = self.get_conn()
conn.rmd(path)

def retrieve_file(self, remote_full_path, local_full_path_or_buffer):
def retrieve_file(
self,
remote_full_path,
local_full_path_or_buffer,
callback=None):
"""
Transfers the remote file to a local location.

Expand All @@ -161,23 +165,59 @@ def retrieve_file(self, remote_full_path, local_full_path_or_buffer):
:param local_full_path_or_buffer: full path to the local file or a
file-like buffer
:type local_full_path_or_buffer: str or file-like buffer
:param callback: callback which is called each time a block of data
is read. if you do not use a callback, these blocks will be written
to the file or buffer passed in. if you do pass in a callback, note
that writing to a file or buffer will need to be handled inside the
callback.
[default: output_handle.write()]
:type callback: callable

Example::
hook = FTPHook(ftp_conn_id='my_conn')

remote_path = '/path/to/remote/file'
local_path = '/path/to/local/file'

# with a custom callback (in this case displaying progress on each read)
def print_progress(percent_progress):
self.log.info('Percent Downloaded: %s%%' % percent_progress)

total_downloaded = 0
total_file_size = hook.get_size(remote_path)
output_handle = open(local_path, 'wb')
def write_to_file_with_progress(data):
total_downloaded += len(data)
output_handle.write(data)
percent_progress = (total_downloaded / total_file_size) * 100
print_progress(percent_progress)
hook.retrieve_file(remote_path, None, callback=write_to_file_with_progress)

# without a custom callback data is written to the local_path
hook.retrieve_file(remote_path, local_path)
"""
conn = self.get_conn()

is_path = isinstance(local_full_path_or_buffer, basestring)

if is_path:
output_handle = open(local_full_path_or_buffer, 'wb')
# without a callback, default to writing to a user-provided file or
# file-like buffer
if not callback:
if is_path:
output_handle = open(local_full_path_or_buffer, 'wb')
else:
output_handle = local_full_path_or_buffer
callback = output_handle.write
else:
output_handle = local_full_path_or_buffer
output_handle = None

remote_path, remote_file_name = os.path.split(remote_full_path)
conn.cwd(remote_path)
self.log.info('Retrieving file from FTP: %s', remote_full_path)
conn.retrbinary('RETR %s' % remote_file_name, output_handle.write)
conn.retrbinary('RETR %s' % remote_file_name, callback)
self.log.info('Finished retrieving file from FTP: %s', remote_full_path)

if is_path:
if is_path and output_handle:
output_handle.close()

def store_file(self, remote_full_path, local_full_path_or_buffer):
Expand Down Expand Up @@ -230,6 +270,12 @@ def rename(self, from_name, to_name):
return conn.rename(from_name, to_name)

def get_mod_time(self, path):
"""
Returns a datetime object representing the last time the file was modified

:param path: remote file path
:type path: string
"""
conn = self.get_conn()
ftp_mdtm = conn.sendcmd('MDTM ' + path)
time_val = ftp_mdtm[4:]
Expand All @@ -239,6 +285,16 @@ def get_mod_time(self, path):
except ValueError:
return datetime.datetime.strptime(time_val, '%Y%m%d%H%M%S')

def get_size(self, path):
"""
Returns the size of a file (in bytes)

:param path: remote file path
:type path: string
"""
conn = self.get_conn()
return conn.size(path)


class FTPSHook(FTPHook):

Expand Down
23 changes: 23 additions & 0 deletions tests/contrib/hooks/test_ftp_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#

import mock
import six
import unittest

from airflow.contrib.hooks import ftp_hook as fh
Expand Down Expand Up @@ -101,6 +102,28 @@ def test_mod_time_micro(self):

self.conn_mock.sendcmd.assert_called_once_with('MDTM ' + path)

def test_get_size(self):
self.conn_mock.size.return_value = 1942

path = '/path/file'
with fh.FTPHook() as ftp_hook:
ftp_hook.get_size(path)

self.conn_mock.size.assert_called_once_with(path)

def test_retrieve_file(self):
_buffer = six.StringIO('buffer')
with fh.FTPHook() as ftp_hook:
ftp_hook.retrieve_file(self.path, _buffer)
self.conn_mock.retrbinary.assert_called_once_with('RETR path', _buffer.write)

def test_retrieve_file_with_callback(self):
func = mock.Mock()
_buffer = six.StringIO('buffer')
with fh.FTPHook() as ftp_hook:
ftp_hook.retrieve_file(self.path, _buffer, callback=func)
self.conn_mock.retrbinary.assert_called_once_with('RETR path', func)


if __name__ == '__main__':
unittest.main()