diff --git a/airflow/contrib/hooks/ftp_hook.py b/airflow/contrib/hooks/ftp_hook.py index 8beefb372916c..03849012a3d5b 100644 --- a/airflow/contrib/hooks/ftp_hook.py +++ b/airflow/contrib/hooks/ftp_hook.py @@ -148,7 +148,11 @@ def delete_directory(self, path): conn = self.get_conn() conn.rmd(path) - def retrieve_file(self, remote_full_path, local_full_path_or_buffer): + def retrieve_file( + self, + remote_full_path, + local_full_path_or_buffer, + callback=None): """ Transfers the remote file to a local location. @@ -161,23 +165,59 @@ def retrieve_file(self, remote_full_path, local_full_path_or_buffer): :param local_full_path_or_buffer: full path to the local file or a file-like buffer :type local_full_path_or_buffer: str or file-like buffer + :param callback: callback which is called each time a block of data + is read. if you do not use a callback, these blocks will be written + to the file or buffer passed in. if you do pass in a callback, note + that writing to a file or buffer will need to be handled inside the + callback. + [default: output_handle.write()] + :type callback: callable + + Example:: + hook = FTPHook(ftp_conn_id='my_conn') + + remote_path = '/path/to/remote/file' + local_path = '/path/to/local/file' + + # with a custom callback (in this case displaying progress on each read) + def print_progress(percent_progress): + self.log.info('Percent Downloaded: %s%%' % percent_progress) + + total_downloaded = 0 + total_file_size = hook.get_size(remote_path) + output_handle = open(local_path, 'wb') + def write_to_file_with_progress(data): + total_downloaded += len(data) + output_handle.write(data) + percent_progress = (total_downloaded / total_file_size) * 100 + print_progress(percent_progress) + hook.retrieve_file(remote_path, None, callback=write_to_file_with_progress) + + # without a custom callback data is written to the local_path + hook.retrieve_file(remote_path, local_path) """ conn = self.get_conn() is_path = isinstance(local_full_path_or_buffer, basestring) - if is_path: - output_handle = open(local_full_path_or_buffer, 'wb') + # without a callback, default to writing to a user-provided file or + # file-like buffer + if not callback: + if is_path: + output_handle = open(local_full_path_or_buffer, 'wb') + else: + output_handle = local_full_path_or_buffer + callback = output_handle.write else: - output_handle = local_full_path_or_buffer + output_handle = None remote_path, remote_file_name = os.path.split(remote_full_path) conn.cwd(remote_path) self.log.info('Retrieving file from FTP: %s', remote_full_path) - conn.retrbinary('RETR %s' % remote_file_name, output_handle.write) + conn.retrbinary('RETR %s' % remote_file_name, callback) self.log.info('Finished retrieving file from FTP: %s', remote_full_path) - if is_path: + if is_path and output_handle: output_handle.close() def store_file(self, remote_full_path, local_full_path_or_buffer): @@ -230,6 +270,12 @@ def rename(self, from_name, to_name): return conn.rename(from_name, to_name) def get_mod_time(self, path): + """ + Returns a datetime object representing the last time the file was modified + + :param path: remote file path + :type path: string + """ conn = self.get_conn() ftp_mdtm = conn.sendcmd('MDTM ' + path) time_val = ftp_mdtm[4:] @@ -239,6 +285,16 @@ def get_mod_time(self, path): except ValueError: return datetime.datetime.strptime(time_val, '%Y%m%d%H%M%S') + def get_size(self, path): + """ + Returns the size of a file (in bytes) + + :param path: remote file path + :type path: string + """ + conn = self.get_conn() + return conn.size(path) + class FTPSHook(FTPHook): diff --git a/tests/contrib/hooks/test_ftp_hook.py b/tests/contrib/hooks/test_ftp_hook.py index 8b9ae2cd59556..1274990827096 100644 --- a/tests/contrib/hooks/test_ftp_hook.py +++ b/tests/contrib/hooks/test_ftp_hook.py @@ -19,6 +19,7 @@ # import mock +import six import unittest from airflow.contrib.hooks import ftp_hook as fh @@ -101,6 +102,28 @@ def test_mod_time_micro(self): self.conn_mock.sendcmd.assert_called_once_with('MDTM ' + path) + def test_get_size(self): + self.conn_mock.size.return_value = 1942 + + path = '/path/file' + with fh.FTPHook() as ftp_hook: + ftp_hook.get_size(path) + + self.conn_mock.size.assert_called_once_with(path) + + def test_retrieve_file(self): + _buffer = six.StringIO('buffer') + with fh.FTPHook() as ftp_hook: + ftp_hook.retrieve_file(self.path, _buffer) + self.conn_mock.retrbinary.assert_called_once_with('RETR path', _buffer.write) + + def test_retrieve_file_with_callback(self): + func = mock.Mock() + _buffer = six.StringIO('buffer') + with fh.FTPHook() as ftp_hook: + ftp_hook.retrieve_file(self.path, _buffer, callback=func) + self.conn_mock.retrbinary.assert_called_once_with('RETR path', func) + if __name__ == '__main__': unittest.main()