diff --git a/quantgov/__main__.py b/quantgov/__main__.py index 8eac757..188208d 100644 --- a/quantgov/__main__.py +++ b/quantgov/__main__.py @@ -118,6 +118,9 @@ def parse_args(): estimate.add_argument( '--probability', action='store_true', help='output probabilities instead of predictions') + estimate.add_argument( + '--precision', default=4, type=int, + help='number of decimal places to round the probabilities') estimate.add_argument( '-o', '--outfile', type=lambda x: open(x, 'w', newline='', encoding='utf-8'), @@ -187,7 +190,7 @@ def run_estimator(args): elif args.subcommand == "estimate": quantgov.estimator.estimate( args.vectorizer, args.model, args.corpus, args.probability, - args.outfile + args.precision, args.outfile ) diff --git a/quantgov/estimator/estimation.py b/quantgov/estimator/estimation.py index af3da83..c77dd2d 100644 --- a/quantgov/estimator/estimation.py +++ b/quantgov/estimator/estimation.py @@ -45,7 +45,7 @@ def estimate_simple(vectorizer, model, streamer): yield from zip(streamer.index, pipeline.predict(texts)) -def estimate_probability(vectorizer, model, streamer): +def estimate_probability(vectorizer, model, streamer, precision): """ Generate probabilities for a one-label estimator @@ -61,11 +61,13 @@ def estimate_probability(vectorizer, model, streamer): pipeline = get_pipeline(vectorizer, model) texts = (doc.text for doc in streamer) truecol = list(int(i) for i in model.model.classes_).index(1) - predicted = (i[truecol] for i in pipeline.predict_proba(texts)) + predicted = ( + i[truecol] for i in pipeline.predict_proba(texts).round(precision) + ) yield from zip(streamer.index, predicted) -def estimate_probability_multilabel(vectorizer, model, streamer): +def estimate_probability_multilabel(vectorizer, model, streamer, precision): """ Generate probabilities for a multilabel binary estimator @@ -96,13 +98,13 @@ def estimate_probability_multilabel(vectorizer, model, streamer): try: for i, docidx in enumerate(streamer.index): yield docidx, tuple( - label_predictions[i, truecols[j]] + label_predictions[i, truecols[j]].round(int(precision)) for j, label_predictions in enumerate(predicted)) except IndexError: - yield from zip(streamer.index, predicted) + yield from zip(streamer.index, predicted.round(int(precision))) -def estimate_probability_multiclass(vectorizer, model, streamer): +def estimate_probability_multiclass(vectorizer, model, streamer, precision): """ Generate probabilities for a one-label, multiclass estimator @@ -117,10 +119,14 @@ def estimate_probability_multiclass(vectorizer, model, streamer): """ pipeline = get_pipeline(vectorizer, model) texts = (doc.text for doc in streamer) - yield from zip(streamer.index, pipeline.predict_proba(texts)) + yield from zip( + streamer.index, + (i for i in pipeline.predict_proba(texts).round(precision)) + ) -def estimate_probability_multilabel_multiclass(vectorizer, model, streamer): +def estimate_probability_multilabel_multiclass( + vectorizer, model, streamer, precision): """ Generate probabilities for a multilabel, multiclass estimator @@ -137,8 +143,8 @@ def estimate_probability_multilabel_multiclass(vectorizer, model, streamer): texts = (doc.text for doc in streamer) predicted = pipeline.predict_proba(texts) for i, docidx in enumerate(streamer.index): - yield docidx, tuple(label_predictions[i] - for label_predictions in predicted) + yield docidx, tuple(label_predictions[i] for label_predictions + in predicted.round(precision)) def is_multiclass(classes): @@ -152,7 +158,7 @@ def is_multiclass(classes): return True -def estimate(vectorizer, model, corpus, probability, outfile): +def estimate(vectorizer, model, corpus, probability, precision, outfile): """ Estimate label values for documents in corpus @@ -184,7 +190,7 @@ def estimate(vectorizer, model, corpus, probability, outfile): if multilabel: if multiclass: # Multilabel-multiclass probability results = estimate_probability_multilabel_multiclass( - vectorizer, model, streamer) + vectorizer, model, streamer, precision) writer.writerow(corpus.index_labels + ('label', 'class', 'probability')) writer.writerows( @@ -198,7 +204,7 @@ def estimate(vectorizer, model, corpus, probability, outfile): ) else: # Multilabel probability results = estimate_probability_multilabel( - vectorizer, model, streamer) + vectorizer, model, streamer, precision) writer.writerow(corpus.index_labels + ('label', 'probability')) writer.writerows( docidx + (label_name, prediction) @@ -209,7 +215,7 @@ def estimate(vectorizer, model, corpus, probability, outfile): elif multiclass: # Multiclass probability writer.writerow(corpus.index_labels + ('class', 'probability')) results = estimate_probability_multiclass( - vectorizer, model, streamer) + vectorizer, model, streamer, precision) writer.writerows( docidx + (class_name, prediction) for docidx, predictions in results @@ -217,7 +223,8 @@ def estimate(vectorizer, model, corpus, probability, outfile): model.model.classes_, predictions) ) else: # Simple probability - results = estimate_probability(vectorizer, model, streamer) + results = estimate_probability( + vectorizer, model, streamer, precision) writer.writerow( corpus.index_labels + (model.label_names[0] + '_prob',)) writer.writerows( diff --git a/tests/pseudo_estimator/.gitignore b/tests/pseudo_estimator/.gitignore new file mode 100644 index 0000000..b5d93a5 --- /dev/null +++ b/tests/pseudo_estimator/.gitignore @@ -0,0 +1,92 @@ +.snakemake +notebooks/ + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +env/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +*.egg-info/ +.installed.cfg +*.egg + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*,cover +.hypothesis/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# IPython Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# dotenv +.env + +# virtualenv +venv/ +ENV/ + +# Spyder project settings +.spyderproject + +# Rope project settings +.ropeproject diff --git a/tests/pseudo_estimator/data/model.pickle b/tests/pseudo_estimator/data/model.pickle new file mode 100644 index 0000000..2ffaac6 Binary files /dev/null and b/tests/pseudo_estimator/data/model.pickle differ diff --git a/tests/pseudo_estimator/data/modelmulticlass.pickle b/tests/pseudo_estimator/data/modelmulticlass.pickle new file mode 100644 index 0000000..2071d94 Binary files /dev/null and b/tests/pseudo_estimator/data/modelmulticlass.pickle differ diff --git a/tests/pseudo_estimator/data/vectorizer.pickle b/tests/pseudo_estimator/data/vectorizer.pickle new file mode 100644 index 0000000..0fdaee6 Binary files /dev/null and b/tests/pseudo_estimator/data/vectorizer.pickle differ diff --git a/tests/test_estimator.py b/tests/test_estimator.py new file mode 100644 index 0000000..07876d5 --- /dev/null +++ b/tests/test_estimator.py @@ -0,0 +1,71 @@ +import pytest +import quantgov.estimator +import subprocess + +from pathlib import Path + + +PSEUDO_CORPUS_PATH = Path(__file__).resolve().parent.joinpath('pseudo_corpus') +PSEUDO_ESTIMATOR_PATH = ( + Path(__file__).resolve().parent + .joinpath('pseudo_estimator') +) + + +def check_output(cmd): + return ( + subprocess.check_output(cmd, universal_newlines=True) + .replace('\n\n', '\n') + ) + + +def test_simple_estimator(): + output = check_output( + ['quantgov', 'estimator', 'estimate', + str(PSEUDO_ESTIMATOR_PATH.joinpath('data', 'vectorizer.pickle')), + str(PSEUDO_ESTIMATOR_PATH.joinpath('data', 'model.pickle')), + str(PSEUDO_CORPUS_PATH)] + ) + assert output == 'file,is_world\ncfr,False\nmoby,False\n' + + +def test_probability_estimator(): + output = check_output( + ['quantgov', 'estimator', 'estimate', + str(PSEUDO_ESTIMATOR_PATH.joinpath('data', 'vectorizer.pickle')), + str(PSEUDO_ESTIMATOR_PATH.joinpath('data', 'model.pickle')), + str(PSEUDO_CORPUS_PATH), '--probability'] + ) + assert output == ('file,is_world_prob\ncfr,0.0899\nmoby,0.0216\n') + + +def test_probability_estimator_6decimals(): + output = check_output( + ['quantgov', 'estimator', 'estimate', + str(PSEUDO_ESTIMATOR_PATH.joinpath('data', 'vectorizer.pickle')), + str(PSEUDO_ESTIMATOR_PATH.joinpath('data', 'model.pickle')), + str(PSEUDO_CORPUS_PATH), '--probability', '--precision', '6'] + ) + assert output == ('file,is_world_prob\ncfr,0.089898\nmoby,0.02162\n') + + +def test_multiclass_probability_estimator(): + output = check_output( + ['quantgov', 'estimator', 'estimate', + str(PSEUDO_ESTIMATOR_PATH.joinpath('data', 'vectorizer.pickle')), + str(PSEUDO_ESTIMATOR_PATH.joinpath('data', 'modelmulticlass.pickle')), + str(PSEUDO_CORPUS_PATH), '--probability'] + ) + assert output == ('file,class,probability\n' + 'cfr,business-and-industry,0.1765\n' + 'cfr,environment,0.1294\n' + 'cfr,health-and-public-welfare,0.1785\n' + 'cfr,money,0.169\n' + 'cfr,science-and-technology,0.147\n' + 'cfr,world,0.1997\n' + 'moby,business-and-industry,0.1804\n' + 'moby,environment,0.1529\n' + 'moby,health-and-public-welfare,0.205\n' + 'moby,money,0.1536\n' + 'moby,science-and-technology,0.1671\n' + 'moby,world,0.141\n')