diff --git a/.travis.yml b/.travis.yml index 75cc0ab..a7d71f0 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,6 +5,7 @@ python: install: - pip install ".[testing]" - pip install ".[nlp]" +- pip install ".[s3driver]" - python -m nltk.downloader punkt stopwords wordnet script: pytest deploy: diff --git a/quantgov/corpora/__init__.py b/quantgov/corpora/__init__.py index 2e7d776..da8e2b1 100644 --- a/quantgov/corpora/__init__.py +++ b/quantgov/corpora/__init__.py @@ -7,7 +7,9 @@ FlatFileCorpusDriver, RecursiveDirectoryCorpusDriver, NamePatternCorpusDriver, - IndexDriver + IndexDriver, + S3Driver, + S3DatabaseDriver ) warnings.warn( diff --git a/quantgov/corpus/__init__.py b/quantgov/corpus/__init__.py index be0ec36..f095957 100644 --- a/quantgov/corpus/__init__.py +++ b/quantgov/corpus/__init__.py @@ -5,5 +5,7 @@ FlatFileCorpusDriver, RecursiveDirectoryCorpusDriver, NamePatternCorpusDriver, - IndexDriver + IndexDriver, + S3Driver, + S3DatabaseDriver ) diff --git a/quantgov/corpus/structures.py b/quantgov/corpus/structures.py index 7030c2e..46cdd15 100644 --- a/quantgov/corpus/structures.py +++ b/quantgov/corpus/structures.py @@ -9,16 +9,40 @@ import csv import logging +from decorator import decorator from collections import namedtuple from pathlib import Path from .. import utils as qgutils +try: + import boto3 +except ImportError: + boto3 = None +try: + import sqlalchemy +except ImportError: + sqlalchemy = None + log = logging.getLogger(__name__) Document = namedtuple('Document', ['index', 'text']) +@decorator +def check_boto(func, *args, **kwargs): + if boto3 is None: + raise RuntimeError('Must install boto3 to use {}'.format(func)) + return func(*args, **kwargs) + + +@decorator +def check_sqlalchemy(func, *args, **kwargs): + if sqlalchemy is None: + raise RuntimeError('Must install sqlalchemy to use {}'.format(func)) + return func(*args, **kwargs) + + class CorpusStreamer(object): """ A knowledgable wrapper for a CorpusDriver stream @@ -243,3 +267,64 @@ def gen_indices_and_paths(self): next(reader) for row in reader: yield tuple(row[:-1]), Path(row[-1]) + + +class S3Driver(IndexDriver): + """ + Serve a whole or partial corpus from a remote file location in s3. + Filtering can be done using the values provided in the index file. + """ + + @check_boto + def __init__(self, index, bucket, encoding='utf-8', cache=True): + self.index = Path(index) + self.bucket = bucket + self.client = boto3.client('s3') + self.encoding = encoding + with self.index.open(encoding=encoding) as inf: + index_labels = next(csv.reader(inf))[:-1] + super(IndexDriver, self).__init__( + index_labels=index_labels, encoding=encoding, cache=cache) + + def read(self, docinfo): + idx, path = docinfo + body = self.client.get_object(Bucket=self.bucket, + Key=str(path))['Body'] + return Document(idx, body.read().decode(self.encoding)) + + def filter(self, pattern): + """ Filter paths based on index values. """ + raise NotImplementedError + + def stream(self): + """Yield text from an object stored in s3. """ + return qgutils.lazy_parallel(self.read, self.gen_indices_and_paths()) + + +class S3DatabaseDriver(S3Driver): + """ + Retrieves an index table from a database with an arbitrary, user-provided + query and serves documents like a normal S3Driver. + """ + + @check_boto + @check_sqlalchemy + def __init__(self, protocol, user, password, host, db, port, query, + bucket, cache=True, encoding='utf-8'): + self.bucket = bucket + self.client = boto3.client('s3') + self.index = [] + engine = sqlalchemy.create_engine('{}://{}:{}@{}:{}/{}' + .format(protocol, user, password, + host, port, db)) + conn = engine.connect() + result = conn.execute(query) + for doc in result: + self.index.append(doc) + index_labels = doc.keys() + super(IndexDriver, self).__init__( + index_labels=index_labels, encoding=encoding, cache=cache) + + def gen_indices_and_paths(self): + for row in self.index: + yield tuple(row[:-1]), row[-1] diff --git a/setup.py b/setup.py index 3d424b1..eb1c880 100644 --- a/setup.py +++ b/setup.py @@ -65,6 +65,10 @@ def find_version(*file_paths): 'nlp': [ 'textblob', 'nltk', + ], + 's3driver': [ + 'sqlalchemy', + 'boto3' ] }, entry_points={ diff --git a/tests/test_corpora.py b/tests/test_corpora.py index f77a249..488fca9 100644 --- a/tests/test_corpora.py +++ b/tests/test_corpora.py @@ -6,14 +6,14 @@ def build_recursive_directory_corpus(directory): - for path, text in (('a/1.txt', u'foo'), ('b/2.txt', u'bar')): + for path, text in (('a/1.txt', 'foo'), ('b/2.txt', 'bar')): directory.join(path).write_text(text, encoding='utf-8', ensure=True) return quantgov.corpus.RecursiveDirectoryCorpusDriver( directory=str(directory), index_labels=('letter', 'number')) def build_name_pattern_corpus(directory): - for path, text in (('a_1.txt', u'foo'), ('b_2.txt', u'bar')): + for path, text in (('a_1.txt', 'foo'), ('b_2.txt', 'bar')): path = directory.join(path).write_text( text, encoding='utf-8', ensure=True) return quantgov.corpus.NamePatternCorpusDriver( @@ -25,23 +25,39 @@ def build_name_pattern_corpus(directory): def build_index_corpus(directory): rows = [] for letter, number, path, text in ( - ('a', '1', 'first.txt', u'foo'), - ('b', '2', 'second.txt', u'bar') + ('a', '1', 'first.txt', 'foo'), + ('b', '2', 'second.txt', 'bar') ): outpath = directory.join(path, abs=1) outpath.write_text(text, encoding='utf-8') rows.append((letter, number, str(outpath))) index_path = directory.join('index.csv') with index_path.open('w', encoding='utf-8') as outf: - outf.write(u'letter,number,path\n') - outf.write(u'\n'.join(','.join(row) for row in rows)) - return quantgov.corpus.IndexDriver(str(index_path)) + outf.write('letter,number,path\n') + outf.write('\n'.join(','.join(row) for row in rows)) + return quantgov.corpora.IndexDriver(str(index_path)) + + +def build_s3_corpus(directory): + rows = [] + for letter, number, path in ( + ('a', '1', 'quantgov_tests/first.txt'), + ('b', '2', 'quantgov_tests/second.txt') + ): + rows.append((letter, number, path)) + index_path = directory.join('index.csv') + with index_path.open('w', encoding='utf-8') as outf: + outf.write('letter,number,path\n') + outf.write('\n'.join(','.join(row) for row in rows)) + return quantgov.corpora.S3Driver(str(index_path), + bucket='quantgov-databanks') BUILDERS = { 'RecursiveDirectoryCorpusDriver': build_recursive_directory_corpus, 'NamePatternCorpusDriver': build_name_pattern_corpus, 'IndexDriver': build_index_corpus, + 'S3Driver': build_s3_corpus }