diff --git a/Makefile b/Makefile index b67c0831..e4d49ecd 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,6 @@ QUERY_TESTING_FILE = spec/test/stored_queries/test_query.py -.PHONY: test reset full_query_testing sampling_query_testing +.PHONY: test reset full_query_testing sampling_query_testing graph_query_testing test: docker-compose build @@ -21,3 +21,12 @@ full_query_testing: sampling_query_testing: DO_QUERY_TESTING=sampling time python -m pytest -s $(QUERY_TESTING_FILE) + +compare_query_testing: + DO_QUERY_TESTING=compare time python -m pytest -s $(QUERY_TESTING_FILE) + +graph_query_testing: + # invocation example: + # make graph_query_testing data_new_fp="tmp/blah.json" data_old_fp="tmp/bleh.json" + # where `data_new_fp` and `data_old_fp` are generated by `make compare_query_testing` + DO_QUERY_TESTING=graph python $(QUERY_TESTING_FILE) $(data_new_fp) $(data_old_fp) diff --git a/dev-requirements.txt b/dev-requirements.txt index 96007184..de91a89d 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -7,5 +7,3 @@ coverage==5.2.1 typed-ast>=1.4.0 black==20.8b1 pytest==6.2.5 -python-arango==5.4.0 -numpy==1.21.2 diff --git a/spec/test/stored_queries/test_query.py b/spec/test/stored_queries/test_query.py index cf0dbe49..e76d19f5 100644 --- a/spec/test/stored_queries/test_query.py +++ b/spec/test/stored_queries/test_query.py @@ -1,28 +1,40 @@ -import traceback as tb -import sys +""" +This script can be run from `make` +Essentially it was created to run stored queries against the ncbi_taxon collection +and collect data and stats. +""" + import os -import json -import datetime -import time -import random -import textwrap -import warnings -import pytest -from typing import Tuple, List -from requests.exceptions import ReadTimeout import unittest -from arango import ArangoClient -import numpy as np - -from relation_engine_server.utils import json_validation - # Skip entire module if env var not set +# to avoid non-Docker-container imports or otherwise +# specific/costly operations in script if not os.environ.get("DO_QUERY_TESTING"): raise unittest.SkipTest( "Env var DO_QUERY_TESTING not set. Skipping query testing module" ) +import traceback as tb # noqa E402 +import sys # noqa E402 +import json # noqa E402 +import datetime # noqa E402 +import time # noqa E402 +import random # noqa E402 +import textwrap # noqa E402 +import warnings # noqa E402 +import pytest # noqa E402 +from typing import Tuple, List # noqa E402 +from requests.exceptions import ReadTimeout # noqa E402 + +from arango import ArangoClient # noqa E402 +import numpy as np # noqa E402 +import pandas as pd # noqa E402 +import seaborn as sns # noqa E402 +import matplotlib.pyplot as plt # noqa E402 + +from relation_engine_server.utils import json_validation # noqa E402 + warnings.filterwarnings("ignore") # Directories and files @@ -39,6 +51,9 @@ STORED_QUERY_NO_SORT_FP = os.path.join( ROOT_DIR, "spec/stored_queries/taxonomy/taxonomy_search_species_strain_no_sort.yaml" ) +STORED_QUERY_OLD_FP = os.path.join( + ROOT_DIR, "spec/stored_queries/taxonomy/taxonomy_search_species.yaml" +) if not os.path.exists(TMP_OUT_DIR): os.mkdir(TMP_OUT_DIR) @@ -49,7 +64,7 @@ CONFIG = json.load(fh) CLIENT = ArangoClient(hosts=CONFIG["host"]) DB = CLIENT.db("ci", username=CONFIG["username"], password=CONFIG["password"]) -except Exception as e: +except Exception: help_msg = """ Please set host URL, username, and password in arango_live_server_config.json, e.g., { @@ -71,6 +86,7 @@ # Load the queries QUERY = json_validation.load_json_yaml(STORED_QUERY_FP)["query"] QUERY_NO_SORT = json_validation.load_json_yaml(STORED_QUERY_NO_SORT_FP)["query"] +QUERY_OLD = json_validation.load_json_yaml(STORED_QUERY_OLD_FP)["query"] # Set query bind parameters LIMIT = 20 @@ -102,7 +118,10 @@ def use_sort(search_text): - """Determine whether to use the sorting or non-sorting query""" + """ + Determine whether to use the sorting or non-sorting stored query for the new query. + Smaller search texts' results will not be sorted on. + """ return len(search_text) > 3 @@ -124,8 +143,28 @@ def jprint(jo, dry=False): print(txt) -def taxonomy_search_species_strain(search_text): - """Make the query""" +def do_taxonomy_search_species_query(search_text): + """Do the old query""" + cursor = DB.aql.execute( + QUERY_OLD, + bind_vars={ + "@taxon_coll": "ncbi_taxon", + "sciname_field": "scientific_name", + "search_text": "prefix:" + search_text, # how the old query was set up + "ts": NOW, + "offset": None, + "limit": LIMIT, + "select": ["scientific_name"], + }, + ) + return { + "results": [e["scientific_name"] for e in list(cursor.batch())], + **cursor.statistics(), + } + + +def do_taxonomy_search_species_strain_query(search_text): + """Do the new query""" cursor = DB.aql.execute( QUERY if use_sort(search_text) else QUERY_NO_SORT, bind_vars={ @@ -146,8 +185,8 @@ def taxonomy_search_species_strain(search_text): def get_search_text_samplings( resample=True, - cap_scinames=2000, - cap_scinames_prefixes=5000, + cap_scinames=1000, + cap_scinames_prefixes=1000, ): """ Get samplings of scinames or prefixes thereof to gauge execution time @@ -165,14 +204,12 @@ def get_search_text_samplings( samplings = json.load(fh) return samplings - print("Sampling search texts and prefixes thereof ...") - - seen_prefixes = set() + print("\nSampling search texts and prefixes thereof ...") def get_capped_samplings(styp: str) -> Tuple[list, list]: """ Randomly sample scinames - Then take all prefixes (not already seen in accumulated prefixes) + Then take all prefixes, deduplicated "Wild" just means the exclusion of "simple" """ if styp not in ["simple", "wild"]: @@ -185,19 +222,17 @@ def get_capped_samplings(styp: str) -> Tuple[list, list]: if is_simple(sciname) == (styp == "simple") ] random.shuffle(sampling) - sampling = sampling[:cap_scinames] - sampling_prefixes = [ - sciname[:i] for sciname in sampling for i in range(1, len(sciname)) - ] - sampling_prefixes = [ - sciname - for sciname in sampling_prefixes - if sciname not in seen_prefixes - and not seen_prefixes.add( - sciname - ) # latter operand always evaluates to true - ] - return sampling, sampling_prefixes[:cap_scinames_prefixes] + sampling = sampling[ + :cap_scinames + ] # cap this first to avoid generating overabundant prefixes + + sampling_prefixes = list( + set([sciname[:i] for sciname in sampling for i in range(1, len(sciname))]) + ) + random.shuffle(sampling_prefixes) + sampling_prefixes = sampling_prefixes[:cap_scinames_prefixes] + + return sampling, sampling_prefixes scinames_simple, scinames_simple_prefixes = get_capped_samplings("simple") scinames_wild, scinames_wild_prefixes = get_capped_samplings("wild") @@ -233,7 +268,7 @@ def get_capped_samplings(styp: str) -> Tuple[list, list]: return samplings -def handle_err(msg, dat, failed): +def handle_err(msg, dat=None): """ During sampling/sciname/query loops, if error arises, @@ -241,11 +276,12 @@ def handle_err(msg, dat, failed): """ print(msg) tb.print_exc() - jprint(dat) - failed.append(dat) + if dat: + dat["failed"] = True + jprint(dat) -def update_print_timekeepers(i, t0, exe_times, sampling, failed): +def update_print_timekeepers(i, t0, exe_times, sampling, num_failed): """ Calculate and print * Running average time per iteration @@ -258,10 +294,10 @@ def update_print_timekeepers(i, t0, exe_times, sampling, failed): tper_iter, tper_exe, tmed_exe, tmin_exe, tmax_exe = 0, 0, 0, 0, 0 else: tper_iter = (time.time() - t0) / i - tper_exe = np.mean(exe_times) - tmed_exe = np.median(exe_times) - tmin_exe = np.min(exe_times) - tmax_exe = np.max(exe_times) + tper_exe = np.nanmean(exe_times) + tmed_exe = np.nanmedian(exe_times) + tmin_exe = np.nanmin(exe_times) + tmax_exe = np.nanmax(exe_times) print( f"[{datetime.datetime.now().strftime('%b%d %H:%M').upper()}]", "...", @@ -277,20 +313,22 @@ def update_print_timekeepers(i, t0, exe_times, sampling, failed): "...", f"{'%.3fs' % tper_iter} per round trip", "...", - f"{'%d/%d' % (len(failed), i)} failed", + f"{'%d/%d' % (num_failed, i)} failed", ) -################################################################################ -################################################################################ +######################################################################################################################## +######################################################################################################################## def do_query_testing( samplings: dict, + do_query_func=do_taxonomy_search_species_strain_query, expect_hits: list = [ "scinames_simple", "scinames_wild", "scinames_latest", "scinames_latest_permute", ], + permute: bool = True, update_period: int = 100, ): """ @@ -298,9 +336,10 @@ def do_query_testing( Periodically outputs accumulated mean and median execution times """ # Permute since the scinames tend to start out simpler - for styp, sampling in samplings.items(): - samplings[styp] = sampling[:] - random.shuffle(samplings[styp]) + if permute: + for styp, sampling in samplings.items(): + samplings[styp] = sampling[:] + random.shuffle(samplings[styp]) # Get some nice stats to print out samplings_metadata = [ @@ -312,11 +351,10 @@ def do_query_testing( w = 120 dec = "=" * w prelude = textwrap.wrap( - "\n".join( - [ - f"samplings_num_queries={samplings_metadata},", - f"total_num_queries={total_num_queries},", - ] + ( + f"do_query_func={do_query_func.__name__}, " + f"samplings_num_queries={samplings_metadata}, " + f"total_num_queries={total_num_queries}, " ), width=w, ) @@ -330,13 +368,11 @@ def do_query_testing( # Data structures accumulating all info data_all = dict() # For all queries - failed_all = dict() # For failed queries try: for j, (styp, sampling) in enumerate(samplings.items()): - failed: List[dict] = [] - failed_all[styp] = failed + num_failed: int = 0 data: List[dict] = [] data_all[styp] = data @@ -354,52 +390,54 @@ def do_query_testing( for i, search_text in enumerate(sampling): # Calculate and print running time stats if not i % update_period: - update_print_timekeepers(i, t0, exe_times, sampling, failed) + update_print_timekeepers(i, t0, exe_times, sampling, num_failed) dat = { - "styp": styp, "i": i, "search_text": search_text, + "failed": False, } data.append(dat) try: - query_res = taxonomy_search_species_strain(search_text) + query_res = do_query_func(search_text) except Exception: - handle_err("Something went wrong in the query!", dat, failed) + handle_err("Something went wrong in the query!", dat) + query_res = { + "execution_time": np.nan, + "results": [], + } exe_times.append(query_res["execution_time"]) dat.update(query_res) + # Set `has_results` + dat["has_results"] = len(query_res["results"]) > 0 + # Set `failed` if styp in expect_hits: + hits = query_res["results"] + # Given that limit=20, + # test that sciname is in top 20, + # and they aren't >20 duplicates. + # Raise to get traceback in stdout try: - hits = query_res["results"] - # Given that limit=20, - # test that sciname is in top 20, - # and they aren't >20 duplicates. - # Raise to get traceback in stdout - if search_text not in hits or ( + assert search_text in hits # nosec B101 + assert not ( # nosec B101 len(hits) == LIMIT and all([hit == search_text for hit in hits]) - ): - raise AssertionError( - "Target sciname not in results " - "or results are all duplicates" - ) + ) except AssertionError: + num_failed += 1 handle_err( "Something went wrong in the expect hit assertion!", dat, - failed, ) # One last time after all of sampling has run - update_print_timekeepers(i + 1, t0, exe_times, sampling, failed) + update_print_timekeepers(i + 1, t0, exe_times, sampling, num_failed) except Exception: - handle_err( - "Something went wrong in the samplings/scinames/query loops!", dat, failed - ) + handle_err("Something went wrong in the samplings/scinames/query loops!") finally: results_fp = os.path.join( @@ -409,6 +447,8 @@ def do_query_testing( "__" f"{datetime.datetime.now().strftime('%d%b%Y_%H:%M').upper()}" "__" + f"{do_query_func.__name__}" + "__" f"{len(samplings)}_samplings" "__" f"{total_num_queries}_search_texts" @@ -416,24 +456,28 @@ def do_query_testing( ), ) data_meta = { + "do_query_func": do_query_func.__name__, "samplings": list(samplings.keys()), "expect_hits": expect_hits, "total_num_queries": total_num_queries, - "sampling": styp, - "i": i, + "_sampling": styp, # where it may have + "_i": i, # stopped at "data_all": data_all, - "failed_all": failed_all, } - print(f"\nWriting results/failures to {results_fp}") + print(dec) + print(f"\nWriting results to {results_fp}") + print(dec) with open(results_fp, "w") as fh: json.dump(data_meta, fh, indent=3) return data_meta +######################################################################################################################## +######################################################################################################################## @pytest.mark.skipif( not os.environ.get("DO_QUERY_TESTING") == "full", - reason="This can take a couple days, and only needs to be ascertained once", + reason="This can take a couple days, and only needs to be ascertained sporadically", ) def test_all_ncbi_latest_scinames(): do_query_testing({"scinames_latest": SCINAMES_LATEST}) @@ -441,7 +485,147 @@ def test_all_ncbi_latest_scinames(): @pytest.mark.skipif( not os.environ.get("DO_QUERY_TESTING") == "sampling", - reason="This can take a few hours, and only needs to be ascertained once", + reason="This can take an hour or so, and only needs to be ascertained sporadically", ) def test_samplings(): - do_query_testing(get_search_text_samplings()) + do_query_testing( + samplings=get_search_text_samplings(resample=True), + do_query_func=do_taxonomy_search_species_strain_query, + ) + + +@pytest.mark.skipif( + not os.environ.get("DO_QUERY_TESTING") == "compare", + reason="This can take an hour or so, and only needs to be ascertained sporadically", +) +def test_compare_queries(): + do_query_testing( + samplings=get_search_text_samplings( + resample=True, cap_scinames=500, cap_scinames_prefixes=500 + ), + do_query_func=do_taxonomy_search_species_strain_query, + permute=False, + ) + do_query_testing( + samplings=get_search_text_samplings(resample=False), + do_query_func=do_taxonomy_search_species_query, + permute=False, + ) + + +def do_graph(data_new_fp, data_old_fp): + """ + { + "data_all": { + "styp0": [ + { + "i": int, # index in sampling + "search_text": str, + "failed": bool, + "results": [ # resulting scinames + ... + ], + "execution_time": float, # s + ... + }, + ... + ], + "styp1": [ + ... + ], + ... + }, + ... + } + """ + with open(data_new_fp) as fh: + data_new = json.load(fh)["data_all"] + with open(data_old_fp) as fh: + data_old = json.load(fh)["data_all"] + + # Not meaningful/large enough to make the figure + if "edge_cases" in data_new: + del data_new["edge_cases"] + if "edge_cases" in data_old: + del data_old["edge_cases"] + + # Count num queries where the old stored query `has_results`/`failed` + old_failed_counts = { + styp: ( + len([1 for dat in data if not dat["failed"]]), + len([1 for dat in data if dat["failed"]]), + ) + for styp, data in data_old.items() + } + old_has_results_counts = { + styp: ( + len([1 for dat in data if not dat["results"]]), + len([1 for dat in data if dat["results"]]), + ) + for styp, data in data_old.items() + } + + # Sanity checks + # Should have same ordering in `styp` and `search_text` + for (styp0, data0), (styp1, data1) in zip(data_new.items(), data_old.items()): + assert styp0 == styp1 # nosec B101 + assert len(data0) == len(data1) # nosec B101 + for dat0, dat1 in zip(data0, data1): + assert dat0["search_text"] == dat1["search_text"] # nosec B101 + assert not np.isnan(dat0["execution_time"]) # nosec B101 + assert not np.isnan(dat1["execution_time"]) # nosec B101 + # old_has_results and old_failed counts should add up + for counts in [old_failed_counts, old_has_results_counts]: + for styp, count in counts.items(): + assert sum(count) == len(data_old[styp]) # nosec B101 + + df_data = [] + df_columns = [ + "exe_time_ms", + "stored_query", + "sampling", + "failed", + "has_results", + "old_failed", + "old_has_results", + ] + for sq, data_epoch in zip(["new", "old"], [data_new, data_old]): + for styp, data in data_epoch.items(): + for i, dat in enumerate(data): + # Toggle the literal strings here in tandem with + # toggling the `hue` below + df_row = [ + int(dat["execution_time"] * 1000), + sq, + f"{styp}\nn = {len(data)} ({old_failed_counts[styp][0]}/{old_failed_counts[styp][1]})", + # f"{styp}\nn = {len(data)} ({old_has_results_counts[styp][0]}/{old_has_results_counts[styp][1]})", + dat["failed"], + dat["has_results"], + data_old[styp][i]["failed"], + data_old[styp][i]["has_results"], + ] + df_data.append(df_row) + + df = pd.DataFrame(df_data, columns=df_columns) + + sns.catplot( + x="stored_query", + y="exe_time_ms", + hue="old_failed", # Toggle the `hue` here in tandem with + # hue="old_has_results", # toggling the literal strings n `df_row` above + scale="area", + scale_hue=False, + col="sampling", + data=df, + kind="violin", + split=True, + cut=0, + aspect=0.7, + bw=0.2, + ) + + plt.show() + + +if __name__ == "__main__": + do_graph(sys.argv[1], sys.argv[2])