From f55d8d4a6dc2dc6d01ea2df444da09e41122f0f1 Mon Sep 17 00:00:00 2001 From: Ben Capodanno Date: Tue, 4 Mar 2025 00:12:06 -0800 Subject: [PATCH 1/6] Add statistics router for counts of published records --- src/mavedb/routers/statistics.py | 38 +++++++++++++++++++- tests/routers/test_statistics.py | 61 ++++++++++++++++++++++++++++++++ 2 files changed, 98 insertions(+), 1 deletion(-) diff --git a/src/mavedb/routers/statistics.py b/src/mavedb/routers/statistics.py index ad720526..c49f2515 100644 --- a/src/mavedb/routers/statistics.py +++ b/src/mavedb/routers/statistics.py @@ -1,5 +1,7 @@ +import itertools +from collections import OrderedDict from enum import Enum -from typing import Union +from typing import Union, Optional from fastapi import APIRouter, Depends, HTTPException from sqlalchemy import Table, func, select @@ -77,6 +79,11 @@ class RecordFields(str, Enum): createdBy = "created-by" +class GroupBy(str, Enum): + month = "month" + year = "year" + + def _target_from_field_and_model( db: Session, model: Union[type[TargetAccession], type[TargetSequence]], @@ -344,3 +351,32 @@ def record_object_statistics( count_data = _record_from_field_and_model(db, model, field) return {field_val: count for field_val, count in count_data if field_val is not None} + + +@router.get("/record/{model}/published/count", status_code=200, response_model=dict[str, int]) +def record_counts(model: RecordNames, group: Optional[GroupBy] = None, db: Session = Depends(get_db)) -> dict[str, int]: + """ + Returns a dictionary of counts for the number of records in each table. + """ + models: dict[RecordNames, Union[type[Experiment], type[ScoreSet]]] = { + RecordNames.experiment: Experiment, + RecordNames.scoreSet: ScoreSet, + } + + queried_model = models[model] + + # Protects against Nonetype publication dates with where clause and ignore mypy typing errors in dictcomps. + objs = db.scalars( + select(queried_model.published_date) + .where(queried_model.published_date.isnot(None)) + .order_by(queried_model.published_date) + ).all() + + if group == GroupBy.month: + grouped = {k: len(list(g)) for k, g in itertools.groupby(objs, lambda t: t.strftime("%Y-%m"))} # type: ignore + elif group == GroupBy.year: + grouped = {k: len(list(g)) for k, g in itertools.groupby(objs, lambda t: t.strftime("%Y"))} # type: ignore + else: + grouped = {"all": len(objs)} + + return OrderedDict(sorted(grouped.items())) diff --git a/tests/routers/test_statistics.py b/tests/routers/test_statistics.py index 249c86a7..0cdf6ca6 100644 --- a/tests/routers/test_statistics.py +++ b/tests/routers/test_statistics.py @@ -367,3 +367,64 @@ def test_record_statistics_invalid_record_and_field(client): assert response.json()["detail"][0]["ctx"]["enum_values"] == RECORD_MODELS assert response.json()["detail"][1]["loc"] == ["path", "field"] assert response.json()["detail"][1]["ctx"]["enum_values"] == RECORD_SHARED_FIELDS + + +# Test record counts statistics +@pytest.mark.parametrize("model_value", RECORD_MODELS) +def test_record_counts_no_published_data(client, model_value, setup_router_db): + """Test record counts endpoint for published experiments and score sets.""" + response = client.get(f"/api/v1/statistics/record/{model_value}/published/count") + assert response.status_code == 200 + assert "all" in response.json() + assert response.json()["all"] == 0 + + +@pytest.mark.parametrize("model_value", RECORD_MODELS) +def test_record_counts(client, model_value, setup_router_db, setup_seq_scoreset): + """Test record counts endpoint for published experiments and score sets.""" + response = client.get(f"/api/v1/statistics/record/{model_value}/published/count") + assert response.status_code == 200 + assert "all" in response.json() + assert response.json()["all"] == 1 + + +@pytest.mark.parametrize("model_value", RECORD_MODELS) +@pytest.mark.parametrize("group_value", ["month", "year"]) +def test_record_counts_grouped_no_published_data(client, model_value, group_value, setup_router_db): + """Test record counts endpoint grouped by month and year for published experiments and score sets.""" + response = client.get(f"/api/v1/statistics/record/{model_value}/published/count?group={group_value}") + assert response.status_code == 200 + assert isinstance(response.json(), dict) + for key, value in response.json().items(): + assert isinstance(key, str) + assert isinstance(value, int) + + +@pytest.mark.parametrize("model_value", RECORD_MODELS) +@pytest.mark.parametrize("group_value", ["month", "year"]) +def test_record_counts_grouped( + session, client, model_value, group_value, setup_router_db, setup_seq_scoreset, setup_acc_scoreset +): + """Test record counts endpoint grouped by month and year for published experiments and score sets.""" + response = client.get(f"/api/v1/statistics/record/{model_value}/published/count?group={group_value}") + assert response.status_code == 200 + assert isinstance(response.json(), dict) + for key, value in response.json().items(): + assert isinstance(key, str) + assert value == 2 + + +def test_record_counts_invalid_model(client): + """Test record counts endpoint with an invalid model.""" + response = client.get("/api/v1/statistics/record/invalid-model/published/count") + assert response.status_code == 422 + assert response.json()["detail"][0]["loc"] == ["path", "model"] + assert response.json()["detail"][0]["ctx"]["enum_values"] == RECORD_MODELS + + +def test_record_counts_invalid_group(client): + """Test record counts endpoint with an invalid group.""" + response = client.get("/api/v1/statistics/record/experiment/published/count?group=invalid-group") + assert response.status_code == 422 + assert response.json()["detail"][0]["loc"] == ["query", "group"] + assert response.json()["detail"][0]["ctx"]["enum_values"] == ["month", "year"] From 8076cf5155de9770805dede953cb2a5d09e01792 Mon Sep 17 00:00:00 2001 From: Ben Capodanno Date: Wed, 5 Mar 2025 14:34:53 -0800 Subject: [PATCH 2/6] Refactor statistics endpoints for simplicity The previous version of these statistics endpoints featured code that was overly complex for the task at hand. This change refactors statistics code in a way that slightly increases duplicated code (mostly API decorators) while decreasing the complexity of shared functions and logic switches. The result is much clearer and concise endpoints for existing routes. --- src/mavedb/routers/statistics.py | 564 +++++++++++++++++-------------- tests/routers/test_statistics.py | 77 +++-- 2 files changed, 357 insertions(+), 284 deletions(-) diff --git a/src/mavedb/routers/statistics.py b/src/mavedb/routers/statistics.py index c49f2515..137d97f2 100644 --- a/src/mavedb/routers/statistics.py +++ b/src/mavedb/routers/statistics.py @@ -1,10 +1,10 @@ import itertools from collections import OrderedDict from enum import Enum -from typing import Union, Optional +from typing import Any, Union, Optional from fastapi import APIRouter, Depends, HTTPException -from sqlalchemy import Table, func, select +from sqlalchemy import Table, func, select, Select from sqlalchemy.orm import Session from mavedb.deps import get_db @@ -43,27 +43,20 @@ responses={404: {"description": "Not found"}}, ) -## Enum classes hold valid endpoints for different statistics routes. - - -class TargetGeneFields(str, Enum): - category = "category" - organism = "organism" +TARGET_ACCESSION_TAXONOMY = "Homo sapiens" - ensemblIdentifier = "ensembl-identifier" - refseqIdentifier = "refseq-identifier" - uniprotIdentifier = "uniprot-identifier" +## Union types +RecordModels = Union[type[Experiment], type[ScoreSet]] +RecordAssociationTables = Union[ + Table, + type[ExperimentControlledKeywordAssociation], + type[ExperimentPublicationIdentifierAssociation], + type[ScoreSetPublicationIdentifierAssociation], +] -class TargetAccessionFields(str, Enum): - accession = "accession" - assembly = "assembly" - gene = "gene" - -class TargetSequenceFields(str, Enum): - sequence = "sequence" - sequenceType = "sequence-type" +## Enum classes hold valid endpoints for different statistics routes. class RecordNames(str, Enum): @@ -84,288 +77,189 @@ class GroupBy(str, Enum): year = "year" -def _target_from_field_and_model( - db: Session, - model: Union[type[TargetAccession], type[TargetSequence]], - field: Union[TargetAccessionFields, TargetSequenceFields], -): +def _model_and_association_from_record_field( + record: RecordNames, field: Optional[RecordFields] +) -> tuple[RecordModels, Optional[RecordAssociationTables]]: """ - Given either the target accession or target sequence model, generate counts that can be used to create - a statistic for those fields. + Given a member of the RecordNames and RecordFields Enums, generate the model and association table that can be used + to generate statistics for those fields. + + This function should be used for generating statistics for fields shared between Experiments and Score Sets. + If necessary, Experiment Sets can be handled in a similar manner in the future. """ - # Protection from this case occurs via FastApi/pydantic Enum validation on endpoints that reference this function. - # If we are careful with our enumeration definitons, we should not end up here. - if (model == TargetAccession and field not in TargetAccessionFields) or ( - model == TargetSequence and field not in TargetSequenceFields - ): - raise HTTPException(422, f"Field `{field.name}` is incompatible with target model `{model}`.") + record_to_model_map: dict[RecordNames, RecordModels] = { + RecordNames.experiment: Experiment, + RecordNames.scoreSet: ScoreSet, + } + record_to_assc_map: dict[RecordNames, dict[RecordFields, RecordAssociationTables]] = { + RecordNames.experiment: { + RecordFields.doiIdentifiers: experiments_doi_identifiers_association_table, + RecordFields.publicationIdentifiers: ExperimentPublicationIdentifierAssociation, + RecordFields.rawReadIdentifiers: experiments_raw_read_identifiers_association_table, + RecordFields.keywords: ExperimentControlledKeywordAssociation, + }, + RecordNames.scoreSet: { + RecordFields.doiIdentifiers: score_sets_doi_identifiers_association_table, + RecordFields.publicationIdentifiers: ScoreSetPublicationIdentifierAssociation, + RecordFields.rawReadIdentifiers: score_sets_raw_read_identifiers_association_table, + }, + } - published_score_sets_stmt = select(ScoreSet).where(ScoreSet.published_date.is_not(None)).subquery() + queried_model = record_to_model_map[record] + queried_model_assc = record_to_assc_map[record] - # getattr obscures MyPy errors by coercing return type to Any - model_field = field.value.replace("-", "_") - column_field = getattr(model, model_field) - query = ( - select(column_field, func.count(column_field)) - .join(TargetGene) - .group_by(column_field) - .join_from(TargetGene, published_score_sets_stmt) - ) + if field is None or field not in queried_model_assc: + return queried_model, None - return db.execute(query).all() + return queried_model, queried_model_assc[field] -# Accession based targets only. -@router.get("/target/accession/{field}", status_code=200, response_model=dict[str, int]) -def target_accessions_by_field(field: TargetAccessionFields, db: Session = Depends(get_db)) -> dict[str, int]: - """ - Returns a dictionary of counts for the distinct values of the provided `field` (member of the `target_accessions` table). - Don't include any NULL field values. - """ - return { - field_val: count - for field_val, count in _target_from_field_and_model(db, TargetAccession, field) - if field_val is not None - } +def _join_model_and_filter_unpublished(query: Select, model: RecordModels) -> Select: + return query.join(model).where(model.published_date.is_not(None)) -# Sequence based targets only. -@router.get("/target/sequence/{field}", status_code=200, response_model=dict[str, int]) -def target_sequences_by_field(field: TargetSequenceFields, db: Session = Depends(get_db)) -> dict[str, int]: - """ - Returns a dictionary of counts for the distinct values of the provided `field` (member of the `target_sequences` table). - Don't include any NULL field values. - """ - return { - field_val: count - for field_val, count in _target_from_field_and_model(db, TargetSequence, field) - if field_val is not None - } +def _count_for_identifier_in_query(db: Session, query: Select[tuple[Any, int]]) -> dict[Any, int]: + return {value: count for value, count in db.execute(query).all() if value is not None} + + +######################################################################################## +# Record statistics +######################################################################################## -# Statistics on fields relevant to both accession and sequence based targets. Generally, these require custom logic to harmonize both target sub types. -@router.get("/target/gene/{field}", status_code=200, response_model=dict[str, int]) -def target_genes_by_field(field: TargetGeneFields, db: Session = Depends(get_db)) -> dict[str, int]: +@router.get( + "/record/{record}/keywords", status_code=200, response_model=Union[dict[str, int], dict[str, dict[str, int]]] +) +def experiment_keyword_statistics( + record: RecordNames, db: Session = Depends(get_db) +) -> Union[dict[str, int], dict[str, dict[str, int]]]: """ - Returns a dictionary of counts for the distinct values of the provided `field` (member of the `target_sequences` table). - Don't include any NULL field values. Each field here is handled individually because of the unique structure of this - target gene object- fields might require information from both TargetGene subtypes (accession and sequence). + Returns a dictionary of counts for the distinct values of the `value` field (member of the `controlled_keywords` table). + Don't include any NULL field values. Don't include any keywords from unpublished experiments. """ - association_tables: dict[TargetGeneFields, Union[type[EnsemblOffset], type[RefseqOffset], type[UniprotOffset]]] = { - TargetGeneFields.ensemblIdentifier: EnsemblOffset, - TargetGeneFields.refseqIdentifier: RefseqOffset, - TargetGeneFields.uniprotIdentifier: UniprotOffset, - } - identifier_models: dict[ - TargetGeneFields, Union[type[EnsemblIdentifier], type[RefseqIdentifier], type[UniprotIdentifier]] - ] = { - TargetGeneFields.ensemblIdentifier: EnsemblIdentifier, - TargetGeneFields.refseqIdentifier: RefseqIdentifier, - TargetGeneFields.uniprotIdentifier: UniprotIdentifier, - } + if record == RecordNames.scoreSet: + raise HTTPException( + 422, + "The 'keywords' field can only be used with the 'experiment' model. Score sets do not have associated keywords.", + ) - # All targets linked to a published score set. - published_score_sets_stmt = select(TargetGene).join(ScoreSet).where(ScoreSet.published_date.is_not(None)).subquery() + queried_model, queried_assc = _model_and_association_from_record_field(record, RecordFields.keywords) - # Assumes identifiers cannot be duplicated within a Target. - if field in identifier_models.keys(): - # getattr obscures MyPy errors by coercing return type to Any - attr_for_identifier = getattr(identifier_models[field], "identifier") + if queried_assc is None: + raise HTTPException(500, "No association table associated with the keywords field when one was expected.") - query = ( - select(attr_for_identifier, func.count(attr_for_identifier)) - .join(association_tables[field]) - .join(published_score_sets_stmt) - .group_by(attr_for_identifier) - ) + query = _join_model_and_filter_unpublished( + select(ControlledKeyword.value, func.count(ControlledKeyword.value)).join(queried_assc), queried_model + ).group_by(ControlledKeyword.value) - return {identifier: count for identifier, count in db.execute(query).all() if identifier is not None} + return _count_for_identifier_in_query(db, query) - # Can't join a TargetGene query to TargetGene query, so just select the desired columns directly from the subquery. - elif field is TargetGeneFields.category: - query = select(published_score_sets_stmt.c.category, func.count(published_score_sets_stmt.c.category)).group_by( - published_score_sets_stmt.c.category - ) - return {category: count for category, count in db.execute(query).all() if category is not None} +@router.get("/record/{record}/publication-identifiers", status_code=200, response_model=dict[str, dict[str, int]]) +def experiment_publication_identifier_statistics( + record: RecordNames, db: Session = Depends(get_db) +) -> dict[str, dict[str, int]]: + """ + Returns a dictionary of counts for the distinct values of the `identifier` field (member of the `publication_identifiers` table). + Don't include any publication identifiers from unpublished experiments. + """ + queried_model, queried_assc = _model_and_association_from_record_field(record, RecordFields.publicationIdentifiers) - # Target gene organism needs special handling: it is stored differently between accession and sequence Targets. - elif field is TargetGeneFields.organism: - sequence_based_targets_query = ( - select(Taxonomy.organism_name, func.count(Taxonomy.organism_name)) - .join(TargetSequence) - .join(published_score_sets_stmt) - .group_by(Taxonomy.organism_name) + if queried_assc is None: + raise HTTPException( + 500, "No association table associated with the publication identifiers field when one was expected." ) - accession_based_targets_query = select(func.count(TargetAccession.id)).join(published_score_sets_stmt) - organisms: dict[str, int] = { - organism: count - for organism, count in db.execute(sequence_based_targets_query).all() - if organism is not None - } - accession_count = db.execute(accession_based_targets_query).scalar_one_or_none() + query = _join_model_and_filter_unpublished( + select( + PublicationIdentifier.identifier, + PublicationIdentifier.db_name, + func.count(PublicationIdentifier.identifier), + ).join(queried_assc), + queried_model, + ).group_by(PublicationIdentifier.identifier, PublicationIdentifier.db_name) - # NOTE: For now (forever?), all accession based targets are human genomic sequences. It is possible this - # assumption changes if we add mouse (or other non-human) genomes to MaveDB. - if "Homo sapiens" in organisms and accession_count: - organisms["Homo sapiens"] += accession_count - elif accession_count: - organisms["Homo sapiens"] = accession_count + publication_identifiers: dict[str, dict[str, int]] = {} - return organisms + for identifier, db_name, count in db.execute(query).all(): + # We don't need to worry about overwriting existing identifiers within these internal dictionaries because + # of the SQL group by clause. + if db_name in publication_identifiers: + publication_identifiers[db_name][identifier] = count + else: + publication_identifiers[db_name] = {identifier: count} - # Protection from this case occurs via FastApi/pydantic Enum validation. - else: - raise ValueError(f"Unknown field: {field}") + return publication_identifiers -def _record_from_field_and_model( - db: Session, - model: RecordNames, - field: RecordFields, -): +@router.get("/record/{record}/raw-read-identifiers", status_code=200, response_model=dict[str, int]) +def experiment_raw_read_identifier_statistics(record: RecordNames, db: Session = Depends(get_db)) -> dict[str, int]: """ - Given a member of the RecordNames and RecordFields Enums, generate counts that can be used to create a - statistic for those enums. - - This function should be used for generating statistics for fields shared between Experiments and Score Sets. - If necessary, Experiment Sets can be handled in a similar manner in the future. + Returns a dictionary of counts for the distinct values of the `identifier` field (member of the `raw_read_identifiers` table). + Don't include any raw read identifiers from unpublished experiments. """ - association_tables: dict[ - RecordNames, - dict[ - RecordFields, - Union[ - Table, - type[ExperimentControlledKeywordAssociation], - type[ExperimentPublicationIdentifierAssociation], - type[ScoreSetPublicationIdentifierAssociation], - ], - ], - ] = { - RecordNames.experiment: { - RecordFields.doiIdentifiers: experiments_doi_identifiers_association_table, - RecordFields.publicationIdentifiers: ExperimentPublicationIdentifierAssociation, - RecordFields.rawReadIdentifiers: experiments_raw_read_identifiers_association_table, - RecordFields.keywords: ExperimentControlledKeywordAssociation, - }, - RecordNames.scoreSet: { - RecordFields.doiIdentifiers: score_sets_doi_identifiers_association_table, - RecordFields.publicationIdentifiers: ScoreSetPublicationIdentifierAssociation, - RecordFields.rawReadIdentifiers: score_sets_raw_read_identifiers_association_table, - }, - } - - models: dict[RecordNames, Union[type[Experiment], type[ScoreSet]]] = { - RecordNames.experiment: Experiment, - RecordNames.scoreSet: ScoreSet, - } - - queried_model = models[model] - - # created-by field does not operate on association tables and is defined directly on score set / experiment - # records, so we operate directly on those records. - # getattr obscures MyPy errors by coercing return type to Any - model_created_by_field = getattr(queried_model, "created_by_id") - model_published_data_field = getattr(queried_model, "published_date") - if field is RecordFields.createdBy: - query = ( - select(User.username, func.count(User.id)) - .join(queried_model, model_created_by_field == User.id) - .where(model_published_data_field.is_not(None)) - .group_by(User.id) - ) + queried_model, queried_assc = _model_and_association_from_record_field(record, RecordFields.rawReadIdentifiers) - return db.execute(query).all() - else: - # All assc table identifiers which are linked to a published model. - queried_assc_table = association_tables[model][field] - published_score_sets_statement = ( - select(queried_assc_table).join(queried_model).where(model_published_data_field.is_not(None)).subquery() + if queried_assc is None: + raise HTTPException( + 500, "No association table associated with the raw read identifiers field when one was expected." ) - # Assumes any identifiers / keywords may not be duplicated within a record. - if field is RecordFields.doiIdentifiers: - query = select(DoiIdentifier.identifier, func.count(DoiIdentifier.identifier)).group_by( - DoiIdentifier.identifier - ) - elif field is RecordFields.keywords: - query = select(ControlledKeyword.value, func.count(ControlledKeyword.value)).group_by(ControlledKeyword.value) - elif field is RecordFields.rawReadIdentifiers: - query = select(RawReadIdentifier.identifier, func.count(RawReadIdentifier.identifier)).group_by( - RawReadIdentifier.identifier - ) + query = _join_model_and_filter_unpublished( + select(RawReadIdentifier.identifier, func.count(RawReadIdentifier.identifier)).join(queried_assc), queried_model + ).group_by(RawReadIdentifier.identifier) - # Handle publication identifiers separately since they may have duplicated identifiers - elif field is RecordFields.publicationIdentifiers: - publication_query = ( - select( - PublicationIdentifier.identifier, - PublicationIdentifier.db_name, - func.count(PublicationIdentifier.identifier), - ) - .join(published_score_sets_statement) - .group_by(PublicationIdentifier.identifier, PublicationIdentifier.db_name) - ) + return _count_for_identifier_in_query(db, query) - publication_identifiers: dict[str, dict[str, int]] = {} - for identifier, db_name, count in db.execute(publication_query).all(): - # We don't need to worry about overwriting existing identifiers within these internal dictionaries because - # of the SQL group by clause. - if db_name in publication_identifiers: - publication_identifiers[db_name][identifier] = count - else: - publication_identifiers[db_name] = {identifier: count} +@router.get("/record/{record}/doi-identifiers", status_code=200, response_model=dict[str, int]) +def experiment_doi_identifiers_statistics(record: RecordNames, db: Session = Depends(get_db)) -> dict[str, int]: + """ + Returns a dictionary of counts for the distinct values of the `identifier` field (member of the `doi_identifiers` table). + Don't include any DOI identifiers from unpublished experiments. + """ + queried_model, queried_assc = _model_and_association_from_record_field(record, RecordFields.doiIdentifiers) - return [(db_name, identifiers) for db_name, identifiers in publication_identifiers.items()] + if queried_assc is None: + raise HTTPException( + 500, "No association table associated with the doi identifiers field when one was expected." + ) - # Protection from this case occurs via FastApi/pydantic Enum validation on methods which reference this one. - else: - return [] + query = _join_model_and_filter_unpublished( + select(DoiIdentifier.identifier, func.count(DoiIdentifier.identifier)).join(queried_assc), queried_model + ).group_by(DoiIdentifier.identifier) - return db.execute(query.join(published_score_sets_statement)).all() + return _count_for_identifier_in_query(db, query) -# Model based statistics for shared fields. -# -# NOTE: If custom logic is needed for record models with more specific endpoint paths, -# i.e. non-shared fields, define them above this route so as not to obscure them. -@router.get("/record/{model}/{field}", status_code=200, response_model=Union[dict[str, int], dict[str, dict[str, int]]]) -def record_object_statistics( - model: RecordNames, field: RecordFields, db: Session = Depends(get_db) -) -> Union[dict[str, int], dict[str, dict[str, int]]]: +@router.get("/record/{record}/created-by", status_code=200, response_model=dict[str, int]) +def experiment_created_by_statistics(record: RecordNames, db: Session = Depends(get_db)) -> dict[str, int]: """ - Resolve a dictionary of statistics based on the provided model name and model field. - - Model names and fields should be members of the Enum classes defined above. Providing an invalid model name or - model field will yield a 422 Unprocessable Entity error with details about valid enum values. + Returns a dictionary of counts for the distinct values of the `username` field (member of the `users` table). + Don't include any usernames from unpublished experiments. """ - # Validation to ensure 'keywords' is only used with 'experiment'. - if model == RecordNames.scoreSet and field == RecordFields.keywords: - raise HTTPException( - status_code=422, detail="The 'keywords' field can only be used with the 'experiment' model." - ) + queried_model, queried_assc = _model_and_association_from_record_field(record, RecordFields.createdBy) - count_data = _record_from_field_and_model(db, model, field) + query = ( + select(User.username, func.count(User.id)) + .join(queried_model, queried_model.created_by_id == User.id) + .filter(queried_model.published_date.is_not(None)) + .group_by(User.id) + ) - return {field_val: count for field_val, count in count_data if field_val is not None} + return _count_for_identifier_in_query(db, query) @router.get("/record/{model}/published/count", status_code=200, response_model=dict[str, int]) def record_counts(model: RecordNames, group: Optional[GroupBy] = None, db: Session = Depends(get_db)) -> dict[str, int]: """ - Returns a dictionary of counts for the number of records in each table. + Returns a dictionary of counts for the number of published records of the `model` parameter. + Optionally, group the counts by the published month or year. """ - models: dict[RecordNames, Union[type[Experiment], type[ScoreSet]]] = { - RecordNames.experiment: Experiment, - RecordNames.scoreSet: ScoreSet, - } - - queried_model = models[model] + queried_model, queried_assc = _model_and_association_from_record_field(model, None) - # Protects against Nonetype publication dates with where clause and ignore mypy typing errors in dictcomps. + # Protect against Nonetype publication dates with where clause. + # We can safely ignore Mypy Nonetype errors in the following dictcomps. objs = db.scalars( select(queried_model.published_date) .where(queried_model.published_date.isnot(None)) @@ -380,3 +274,173 @@ def record_counts(model: RecordNames, group: Optional[GroupBy] = None, db: Sessi grouped = {"all": len(objs)} return OrderedDict(sorted(grouped.items())) + + +######################################################################################## +# Target statistics +######################################################################################## + + +##### Accession based targets ##### + + +@router.get("/target/accession/accession", status_code=200, response_model=dict[str, int]) +def target_accessions_accession_counts(db: Session = Depends(get_db)) -> dict[str, int]: + """ + Returns a dictionary of counts for the distinct values of the `accession` field (member of the `target_accessions` table). + Don't include any NULL field values. Don't include any targets from unpublished score sets. + """ + query = _join_model_and_filter_unpublished( + select(TargetAccession.accession, func.count(TargetAccession.accession)).join(TargetGene), ScoreSet + ).group_by(TargetAccession.accession) + + return _count_for_identifier_in_query(db, query) + + +@router.get("/target/accession/assembly", status_code=200, response_model=dict[str, int]) +def target_accessions_assembly_counts(db: Session = Depends(get_db)) -> dict[str, int]: + """ + Returns a dictionary of counts for the distinct values of the `assembly` field (member of the `target_accessions` table). + Don't include any NULL field values. Don't include any targets from unpublished score sets. + """ + query = _join_model_and_filter_unpublished( + select(TargetAccession.assembly, func.count(TargetAccession.assembly)).join(TargetGene), ScoreSet + ).group_by(TargetAccession.assembly) + + return _count_for_identifier_in_query(db, query) + + +@router.get("/target/accession/gene", status_code=200, response_model=dict[str, int]) +def target_accessions_gene_counts(db: Session = Depends(get_db)) -> dict[str, int]: + """ + Returns a dictionary of counts for the distinct values of the `gene` field (member of the `target_accessions` table). + Don't include any NULL field values. Don't include any targets from unpublished score sets. + """ + query = _join_model_and_filter_unpublished( + select(TargetAccession.gene, func.count(TargetAccession.gene)).join(TargetGene), ScoreSet + ).group_by(TargetAccession.gene) + + return _count_for_identifier_in_query(db, query) + + +##### Sequence based targets ##### + + +@router.get("/target/sequence/sequence", status_code=200, response_model=dict[str, int]) +def target_sequences_sequence_counts(db: Session = Depends(get_db)) -> dict[str, int]: + """ + Returns a dictionary of counts for the distinct values of the `sequence` field (member of the `target_sequences` table). + Don't include any NULL field values. Don't include any targets from unpublished score sets. + """ + query = _join_model_and_filter_unpublished( + select(TargetSequence.sequence, func.count(TargetSequence.sequence)).join(TargetGene), ScoreSet + ).group_by(TargetSequence.sequence) + + return _count_for_identifier_in_query(db, query) + + +@router.get("/target/sequence/sequence-type", status_code=200, response_model=dict[str, int]) +def target_sequences_sequence_type_counts(db: Session = Depends(get_db)) -> dict[str, int]: + """ + Returns a dictionary of counts for the distinct values of the `sequence_type` field (member of the `target_sequences` table). + Don't include any NULL field values. Don't include any targets from unpublished score sets. + """ + query = _join_model_and_filter_unpublished( + select(TargetSequence.sequence_type, func.count(TargetSequence.sequence_type)).join(TargetGene), ScoreSet + ).group_by(TargetSequence.sequence_type) + + return _count_for_identifier_in_query(db, query) + + +##### Target genes ##### + + +@router.get("/target/gene/category", status_code=200, response_model=dict[str, int]) +def target_genes_category_counts(db: Session = Depends(get_db)) -> dict[str, int]: + """ + Returns a dictionary of counts for the distinct values of the `category` field (member of the `target_sequences` table). + Don't include any NULL field values. Don't include any targets from unpublished score sets. + """ + query = _join_model_and_filter_unpublished( + select(TargetGene.category, func.count(TargetGene.category)), ScoreSet + ).group_by(TargetGene.category) + + return _count_for_identifier_in_query(db, query) + + +@router.get("/target/gene/organism", status_code=200, response_model=dict[str, int]) +def target_genes_organism_counts(db: Session = Depends(get_db)) -> dict[str, int]: + """ + Returns a dictionary of counts for the distinct values of the `organism` field (member of the `taxonomies` table). + Don't include any NULL field values. Don't include any targets from unpublished score sets. + + NOTE: For now (and perhaps forever), all accession based targets are human genomic sequences (ie: of taxonomy `Homo sapiens`). + It is possible this assumption changes if we add mouse (or other non-human) genomes to MaveDB. + """ + target_sequence_query = _join_model_and_filter_unpublished( + select(Taxonomy.organism_name, func.count(Taxonomy.organism_name)).join(TargetSequence).join(TargetGene), + ScoreSet, + ).group_by(Taxonomy.organism_name) + target_accession_query = _join_model_and_filter_unpublished( + select(func.count(TargetAccession.id)).join(TargetGene), ScoreSet + ) + + # Ensure the `Homo sapiens` key exists in the organisms counts dictionary. + organisms = _count_for_identifier_in_query(db, target_sequence_query) + organisms.setdefault(TARGET_ACCESSION_TAXONOMY, 0) + + count_accession_based_targets = db.execute(target_accession_query).scalar_one_or_none() + if count_accession_based_targets: + organisms[TARGET_ACCESSION_TAXONOMY] += count_accession_based_targets + else: + organisms.pop(TARGET_ACCESSION_TAXONOMY) + + return organisms + + +@router.get("/target/gene/ensembl-identifier", status_code=200, response_model=dict[str, int]) +def target_genes_ensembl_identifier_counts(db: Session = Depends(get_db)) -> dict[str, int]: + """ + Returns a dictionary of counts for the distinct values of the `identifier` field (member of the `ensembl_identifiers` table). + Don't include any NULL field values. Don't include any targets from unpublished score sets. + """ + query = _join_model_and_filter_unpublished( + select(EnsemblIdentifier.identifier, func.count(EnsemblIdentifier.identifier)) + .join(EnsemblOffset) + .join(TargetGene), + ScoreSet, + ).group_by(EnsemblIdentifier.identifier) + + return _count_for_identifier_in_query(db, query) + + +@router.get("/target/gene/refseq-identifier", status_code=200, response_model=dict[str, int]) +def target_genes_refseq_identifier_counts(db: Session = Depends(get_db)) -> dict[str, int]: + """ + Returns a dictionary of counts for the distinct values of the `identifier` field (member of the `refseq_identifiers` table). + Don't include any NULL field values. Don't include any targets from unpublished score sets. + """ + query = _join_model_and_filter_unpublished( + select(RefseqIdentifier.identifier, func.count(RefseqIdentifier.identifier)) + .join(RefseqOffset) + .join(TargetGene), + ScoreSet, + ).group_by(RefseqIdentifier.identifier) + + return _count_for_identifier_in_query(db, query) + + +@router.get("/target/gene/uniprot-identifier", status_code=200, response_model=dict[str, int]) +def target_genes_uniprot_identifier_counts(db: Session = Depends(get_db)) -> dict[str, int]: + """ + Returns a dictionary of counts for the distinct values of the `identifier` field (member of the `uniprot_identifiers` table). + Don't include any NULL field values. Don't include any targets from unpublished score sets. + """ + query = _join_model_and_filter_unpublished( + select(UniprotIdentifier.identifier, func.count(UniprotIdentifier.identifier)) + .join(UniprotOffset) + .join(TargetGene), + ScoreSet, + ).group_by(UniprotIdentifier.identifier) + + return _count_for_identifier_in_query(db, query) diff --git a/tests/routers/test_statistics.py b/tests/routers/test_statistics.py index 0cdf6ca6..7196196d 100644 --- a/tests/routers/test_statistics.py +++ b/tests/routers/test_statistics.py @@ -53,11 +53,14 @@ def assert_statistic(desired_field_value, response): ), f"Target accession statistic {desired_field_value} should appear on one (and only one) test score set." -# Test base case empty database responses for each statistic endpoint. +#################################################################################################### +# Test empty database statistics +#################################################################################################### -def test_empty_database_statistics(client): - stats_endpoints = ( +@pytest.mark.parametrize( + "stats_endpoint", + ( "target/accession/accession", "target/accession/assembly", "target/accession/gene", @@ -77,14 +80,19 @@ def test_empty_database_statistics(client): "record/score-set/doi-identifiers", "record/score-set/raw-read-identifiers", "record/score-set/created-by", - ) - for endpoint in stats_endpoints: - response = client.get(f"/api/v1/statistics/{endpoint}") - assert response.status_code == 200, f"Non-200 status code for endpoint {endpoint}." - assert response.json() == {}, f"Non-empty response for endpoint {endpoint}." + ), +) +def test_empty_database_statistics(client, stats_endpoint): + response = client.get(f"/api/v1/statistics/{stats_endpoint}") + assert response.status_code == 200, f"Non-200 status code for endpoint {stats_endpoint}." + assert response.json() == {}, f"Non-empty response for endpoint {stats_endpoint}." +#################################################################################################### # Test target accession statistics +#################################################################################################### + + @pytest.mark.parametrize( "field_value", TARGET_ACCESSION_FIELDS, @@ -100,19 +108,20 @@ def test_target_accession_statistics(client, field_value, setup_acc_scoreset): def test_target_accession_invalid_field(client): """Test target accession statistic response for an invalid target accession field.""" response = client.get("/api/v1/statistics/target/accession/invalid-field") - assert response.status_code == 422 - assert response.json()["detail"][0]["loc"] == ["path", "field"] - assert response.json()["detail"][0]["ctx"]["enum_values"] == TARGET_ACCESSION_FIELDS + assert response.status_code == 404 def test_target_accession_empty_field(client): """Test target accession statistic response for an empty field.""" response = client.get("/api/v1/statistics/target/accession/") assert response.status_code == 404 - assert response.json()["detail"] == "Not Found" +#################################################################################################### # Test target sequence statistics +#################################################################################################### + + @pytest.mark.parametrize( "field_value", TARGET_SEQUENCE_FIELDS, @@ -128,19 +137,18 @@ def test_target_sequence_statistics(client, field_value, setup_seq_scoreset): def test_target_sequence_invalid_field(client): """Test target sequence statistic response for an invalid field.""" response = client.get("/api/v1/statistics/target/sequence/invalid-field") - assert response.status_code == 422 - assert response.json()["detail"][0]["loc"] == ["path", "field"] - assert response.json()["detail"][0]["ctx"]["enum_values"] == TARGET_SEQUENCE_FIELDS + assert response.status_code == 404 def test_target_sequence_empty_field(client): """Test target sequence statistic response for an empty field.""" response = client.get("/api/v1/statistics/target/sequence/") assert response.status_code == 404 - assert response.json()["detail"] == "Not Found" -# Test target gene statistics. +#################################################################################################### +# Test target gene statistics +#################################################################################################### # Desired values live in different spots for fields on target genes because of the differing target sequence @@ -213,19 +221,20 @@ def test_target_gene_identifier_statistiscs( def test_target_gene_invalid_field(client): """Test target gene statistic response for an invalid field.""" response = client.get("/api/v1/statistics/target/gene/invalid-field") - assert response.status_code == 422 - assert response.json()["detail"][0]["loc"] == ["path", "field"] - assert response.json()["detail"][0]["ctx"]["enum_values"] == TARGET_GENE_FIELDS + TARGET_GENE_IDENTIFIER_FIELDS + assert response.status_code == 404 def test_target_gene_empty_field(client): """Test target gene statistic response for an empty field.""" response = client.get("/api/v1/statistics/target/gene/") assert response.status_code == 404 - assert response.json()["detail"] == "Not Found" -# Test Experiment and Score Set statistics +#################################################################################################### +# Test record statistics +#################################################################################################### + + @pytest.mark.parametrize("model_value", RECORD_MODELS) @pytest.mark.parametrize( "mock_publication_fetch", @@ -340,13 +349,20 @@ def test_record_raw_read_identifier_statistics( assert response.json() == {} +@pytest.mark.parametrize("field_value", RECORD_SHARED_FIELDS) +def test_record_statistics_invalid_record(client, field_value): + """Test record model statistic response for a record we don't provide statisticss on.""" + response = client.get(f"/api/v1/statistics/record/invalid-record/{field_value}") + assert response.status_code == 422 + assert response.json()["detail"][0]["loc"] == ["path", "record"] + assert response.json()["detail"][0]["ctx"]["enum_values"] == RECORD_MODELS + + @pytest.mark.parametrize("model_value", RECORD_MODELS) def test_record_statistics_invalid_field(client, model_value): - """Test record model statistic response for an invalid field.""" + """Test record model statistic response for a field we don't provide statisticss on.""" response = client.get(f"/api/v1/statistics/record/{model_value}/invalid-field") - assert response.status_code == 422 - assert response.json()["detail"][0]["loc"] == ["path", "field"] - assert response.json()["detail"][0]["ctx"]["enum_values"] == RECORD_SHARED_FIELDS + assert response.status_code == 404 @pytest.mark.parametrize("model_value", RECORD_MODELS) @@ -360,16 +376,9 @@ def test_record_statistics_empty_field(client, model_value): def test_record_statistics_invalid_record_and_field(client): """Test record model statistic response for an invalid model and field.""" response = client.get("/api/v1/statistics/record/invalid-model/invalid-field") - - # The order of this list should be reliable. - assert response.status_code == 422 - assert response.json()["detail"][0]["loc"] == ["path", "model"] - assert response.json()["detail"][0]["ctx"]["enum_values"] == RECORD_MODELS - assert response.json()["detail"][1]["loc"] == ["path", "field"] - assert response.json()["detail"][1]["ctx"]["enum_values"] == RECORD_SHARED_FIELDS + assert response.status_code == 404 -# Test record counts statistics @pytest.mark.parametrize("model_value", RECORD_MODELS) def test_record_counts_no_published_data(client, model_value, setup_router_db): """Test record counts endpoint for published experiments and score sets.""" From e407383b617969684c2790aacc142e7eb43c5181 Mon Sep 17 00:00:00 2001 From: Ben Capodanno Date: Wed, 5 Mar 2025 17:49:44 -0800 Subject: [PATCH 3/6] Script for populating mapped gene targets via ClinGen For historical data, we do not always have the genes that we mapped a target to. This script can be used to infer the mapped gene from the HGVS string of a mapped variant. --- .../mapped_gene_from_mapped_variant.py | 126 ++++++++++++++++++ 1 file changed, 126 insertions(+) create mode 100644 src/mavedb/scripts/mapped_gene_from_mapped_variant.py diff --git a/src/mavedb/scripts/mapped_gene_from_mapped_variant.py b/src/mavedb/scripts/mapped_gene_from_mapped_variant.py new file mode 100644 index 00000000..f1cb641b --- /dev/null +++ b/src/mavedb/scripts/mapped_gene_from_mapped_variant.py @@ -0,0 +1,126 @@ +import json +import logging +import requests +from typing import Sequence, Optional + +import click +from sqlalchemy import select +from sqlalchemy.orm import Session + +from mavedb.models.score_set import ScoreSet +from mavedb.models.mapped_variant import MappedVariant +from mavedb.models.target_gene import TargetGene +from mavedb.models.variant import Variant + +from mavedb.scripts.environment import script_environment, with_database_session + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + +CLINGEN_API_URL = "https://reg.test.genome.network/allele" + + +def get_gene_symbol_from_clingen(hgvs_string: str) -> Optional[str]: + response = requests.get(f"{CLINGEN_API_URL}?hgvs={hgvs_string}") + if response.status_code != 200: + logger.error(f"Failed to query ClinGen API for {hgvs_string}: {response.status_code}") + return None + + data = response.json() + if "aminoAcidAlleles" in data: + return data["aminoAcidAlleles"][0].get("geneSymbol") + elif "transcriptAlleles" in data: + return data["transcriptAlleles"][0].get("geneSymbol") + + return None + + +@script_environment.command() +@with_database_session +@click.argument("urns", nargs=-1) +@click.option("--all", help="Generate gene mappings for every score set in MaveDB.", is_flag=True) +def generate_gene_mappings(db: Session, urns: Sequence[Optional[str]], all: bool): + score_set_ids: Sequence[Optional[int]] + if all: + score_set_ids = db.scalars(select(ScoreSet.id)).all() + logger.info( + f"Command invoked with --all. Routine will generate gene mappings for {len(score_set_ids)} score sets." + ) + else: + score_set_ids = db.scalars(select(ScoreSet.id).where(ScoreSet.urn.in_(urns))).all() + logger.info(f"Generating gene mappings for the provided score sets ({len(score_set_ids)}).") + + for ss_id in score_set_ids: + if not ss_id: + continue + + score_set = db.scalar(select(ScoreSet).where(ScoreSet.id == ss_id)) + if not score_set: + logger.warning(f"Could not fetch score set with id={ss_id}.") + continue + + try: + mapped_variant = db.scalars( + select(MappedVariant) + .join(Variant) + .where( + Variant.score_set_id == ss_id, + MappedVariant.current.is_(True), + MappedVariant.post_mapped.isnot(None), + ) + .limit(1) + ).one_or_none() + + if not mapped_variant: + logger.info(f"No current mapped variant found for score set {score_set.urn}.") + continue + + # From line 69, this object must not be None. + hgvs_string = mapped_variant.post_mapped.get("expressions", {})[0].get("value") # type: ignore + if not hgvs_string: + logger.warning(f"No HGVS string found in post_mapped for variant {mapped_variant.id}.") + continue + + gene_symbol = get_gene_symbol_from_clingen(hgvs_string) + if not gene_symbol: + logger.warning(f"No gene symbol found for HGVS string {hgvs_string}.") + continue + + # This script has been designed to work prior to the introduction of multi-target mapping. + # This .one() call reflects those assumptions. + target_gene = db.scalars(select(TargetGene).where(TargetGene.score_set_id == ss_id)).one() + + if target_gene.post_mapped_metadata is None: + logger.warning( + f"Target gene for score set {score_set.urn} has no post_mapped_metadata despite containing current mapped variants." + ) + continue + + # Cannot update JSONB fields directly. They can be converted to dictionaries over Mypy's objections. + post_mapped_metadata = json.loads(json.dumps(dict(target_gene.post_mapped_metadata.copy()))) # type: ignore + if "genomic" in post_mapped_metadata: + key = "genomic" + elif "protein" in post_mapped_metadata: + key = "protein" + else: + logger.warning(f"Unknown post_mapped type for variant {mapped_variant.id}.") + + if "sequence_genes" not in post_mapped_metadata[key]: + post_mapped_metadata[key]["sequence_genes"] = [] + post_mapped_metadata[key]["sequence_genes"].append(gene_symbol) + + target_gene.post_mapped_metadata = post_mapped_metadata + + db.add(target_gene) + db.commit() + logger.info(f"Gene symbol {gene_symbol} added to target gene for score set {score_set.urn}.") + + except Exception as e: + logger.error(f"Failed to generate gene mappings for score set {score_set.urn}: {str(e)}") + db.rollback() + + logger.info("Done generating gene mappings.") + + +if __name__ == "__main__": + generate_gene_mappings() From 9acd2abfbb05e04be39bf9d03d6e1247b37ed668 Mon Sep 17 00:00:00 2001 From: Ben Capodanno Date: Wed, 5 Mar 2025 17:50:52 -0800 Subject: [PATCH 4/6] Statistics count endpoints for variants, mapped variants, and mapped gene targets --- src/mavedb/routers/statistics.py | 86 ++++++++++++++++++++- tests/routers/test_statistics.py | 124 +++++++++++++++++++++++++------ 2 files changed, 184 insertions(+), 26 deletions(-) diff --git a/src/mavedb/routers/statistics.py b/src/mavedb/routers/statistics.py index 137d97f2..bb4b9ea6 100644 --- a/src/mavedb/routers/statistics.py +++ b/src/mavedb/routers/statistics.py @@ -1,5 +1,5 @@ import itertools -from collections import OrderedDict +from collections import OrderedDict, Counter from enum import Enum from typing import Any, Union, Optional @@ -19,6 +19,7 @@ ) from mavedb.models.experiment_controlled_keyword import ExperimentControlledKeywordAssociation from mavedb.models.experiment_publication_identifier import ExperimentPublicationIdentifierAssociation +from mavedb.models.mapped_variant import MappedVariant from mavedb.models.publication_identifier import PublicationIdentifier from mavedb.models.raw_read_identifier import RawReadIdentifier from mavedb.models.refseq_identifier import RefseqIdentifier @@ -36,6 +37,7 @@ from mavedb.models.uniprot_identifier import UniprotIdentifier from mavedb.models.uniprot_offset import UniprotOffset from mavedb.models.user import User +from mavedb.models.variant import Variant router = APIRouter( prefix="/api/v1/statistics", @@ -271,7 +273,7 @@ def record_counts(model: RecordNames, group: Optional[GroupBy] = None, db: Sessi elif group == GroupBy.year: grouped = {k: len(list(g)) for k, g in itertools.groupby(objs, lambda t: t.strftime("%Y"))} # type: ignore else: - grouped = {"all": len(objs)} + grouped = {"count": len(objs)} return OrderedDict(sorted(grouped.items())) @@ -444,3 +446,83 @@ def target_genes_uniprot_identifier_counts(db: Session = Depends(get_db)) -> dic ).group_by(UniprotIdentifier.identifier) return _count_for_identifier_in_query(db, query) + + +# TODO: Test coverage for this route. +@router.get("/target/mapped/gene") +def mapped_target_gene_counts(db: Session = Depends(get_db)) -> dict[str, int]: + """ + Returns a dictionary of counts for the distinct values of the `gene` property within the `post_mapped_metadata` + field (member of the `target_gene` table). Don't include any NULL field values. Don't include any targets from + unpublished score sets. + """ + query = _join_model_and_filter_unpublished( + select(TargetGene.post_mapped_metadata), + ScoreSet, + ).where(TargetGene.post_mapped_metadata.isnot(None)) + + mapping_metadata = db.scalars(query).all() + gene_counts = Counter( + gene + for metadata in mapping_metadata + for key in ("genomic", "protein") + if key in metadata + for gene in metadata[key].get("sequence_genes", []) + ) + + # The gene will always be a string + return dict(gene_counts) # type: ignore + + +######################################################################################## +# Variant (and mapped variant) statistics +######################################################################################## + + +@router.get("/variant/count", status_code=200, response_model=dict[str, int]) +def variant_counts(group: Optional[GroupBy] = None, db: Session = Depends(get_db)) -> dict[str, int]: + """ + Returns a dictionary of counts for the number of published and distinct variants in the database. + Optionally, group the counts by the day on which the score set (and by extension, the variant) was published. + """ + query = _join_model_and_filter_unpublished(select(ScoreSet.published_date, func.count(Variant.id)), ScoreSet) + + variants = db.execute(query.group_by(ScoreSet.published_date).order_by(ScoreSet.published_date)).all() + if group == GroupBy.month: + grouped = {k: sum(c for _, c in g) for k, g in itertools.groupby(variants, lambda t: t[0].strftime("%Y-%m"))} + elif group == GroupBy.year: + grouped = {k: sum(c for _, c in g) for k, g in itertools.groupby(variants, lambda t: t[0].strftime("%Y"))} + else: + grouped = {"count": sum(count for _, count in variants)} + + return OrderedDict(sorted(grouped.items())) + + +@router.get("/mapped-variant/count", status_code=200, response_model=dict[str, int]) +def mapped_variant_counts( + group: Optional[GroupBy] = None, onlyCurrent: bool = True, db: Session = Depends(get_db) +) -> dict[str, int]: + """ + Returns a dictionary of counts for the number of published and distinct variants in the database. + Optionally, group the counts by the day on which the score set (and by extension, the variant) was published. + Optionally, return the count of all mapped variants, not just the current/most up to date ones. + """ + query = _join_model_and_filter_unpublished( + select(ScoreSet.published_date, func.count(MappedVariant.id)).join( + Variant, Variant.id == MappedVariant.variant_id + ), + ScoreSet, + ) + + if onlyCurrent: + query = query.where(MappedVariant.current.is_(True)) + + variants = db.execute(query.group_by(ScoreSet.published_date).order_by(ScoreSet.published_date)).all() + if group == GroupBy.month: + grouped = {k: sum(c for _, c in g) for k, g in itertools.groupby(variants, lambda t: t[0].strftime("%Y-%m"))} + elif group == GroupBy.year: + grouped = {k: sum(c for _, c in g) for k, g in itertools.groupby(variants, lambda t: t[0].strftime("%Y"))} + else: + grouped = {"count": sum(count for _, count in variants)} + + return OrderedDict(sorted(grouped.items())) diff --git a/tests/routers/test_statistics.py b/tests/routers/test_statistics.py index 7196196d..e39987e4 100644 --- a/tests/routers/test_statistics.py +++ b/tests/routers/test_statistics.py @@ -53,6 +53,14 @@ def assert_statistic(desired_field_value, response): ), f"Target accession statistic {desired_field_value} should appear on one (and only one) test score set." +def add_query_param(url, query_name, query_value): + """Add a group value to the URL if one is provided.""" + if query_name and query_value: + return f"{url}?{query_name}={query_value}" + + return url + + #################################################################################################### # Test empty database statistics #################################################################################################### @@ -380,42 +388,29 @@ def test_record_statistics_invalid_record_and_field(client): @pytest.mark.parametrize("model_value", RECORD_MODELS) -def test_record_counts_no_published_data(client, model_value, setup_router_db): - """Test record counts endpoint for published experiments and score sets.""" - response = client.get(f"/api/v1/statistics/record/{model_value}/published/count") - assert response.status_code == 200 - assert "all" in response.json() - assert response.json()["all"] == 0 - - -@pytest.mark.parametrize("model_value", RECORD_MODELS) -def test_record_counts(client, model_value, setup_router_db, setup_seq_scoreset): - """Test record counts endpoint for published experiments and score sets.""" - response = client.get(f"/api/v1/statistics/record/{model_value}/published/count") - assert response.status_code == 200 - assert "all" in response.json() - assert response.json()["all"] == 1 - - -@pytest.mark.parametrize("model_value", RECORD_MODELS) -@pytest.mark.parametrize("group_value", ["month", "year"]) -def test_record_counts_grouped_no_published_data(client, model_value, group_value, setup_router_db): +@pytest.mark.parametrize("group_value", ["month", "year", None]) +def test_record_counts_no_published_data(client, model_value, group_value, setup_router_db): """Test record counts endpoint grouped by month and year for published experiments and score sets.""" - response = client.get(f"/api/v1/statistics/record/{model_value}/published/count?group={group_value}") + response = client.get( + add_query_param(f"/api/v1/statistics/record/{model_value}/published/count", "group", group_value) + ) + assert response.status_code == 200 assert isinstance(response.json(), dict) for key, value in response.json().items(): assert isinstance(key, str) - assert isinstance(value, int) + assert value == 0 @pytest.mark.parametrize("model_value", RECORD_MODELS) -@pytest.mark.parametrize("group_value", ["month", "year"]) +@pytest.mark.parametrize("group_value", ["month", "year", None]) def test_record_counts_grouped( session, client, model_value, group_value, setup_router_db, setup_seq_scoreset, setup_acc_scoreset ): """Test record counts endpoint grouped by month and year for published experiments and score sets.""" - response = client.get(f"/api/v1/statistics/record/{model_value}/published/count?group={group_value}") + response = client.get( + add_query_param(f"/api/v1/statistics/record/{model_value}/published/count", "group", group_value) + ) assert response.status_code == 200 assert isinstance(response.json(), dict) for key, value in response.json().items(): @@ -437,3 +432,84 @@ def test_record_counts_invalid_group(client): assert response.status_code == 422 assert response.json()["detail"][0]["loc"] == ["query", "group"] assert response.json()["detail"][0]["ctx"]["enum_values"] == ["month", "year"] + + +#################################################################################################### +# Test variant statistics +#################################################################################################### + + +@pytest.mark.parametrize("group_value", ["month", "year", None]) +def test_variant_counts(client, group_value, setup_router_db, setup_seq_scoreset): + """Test variant counts endpoint for published variants.""" + response = client.get(add_query_param("/api/v1/statistics/variant/count", "group", group_value)) + assert response.status_code == 200 + assert isinstance(response.json(), dict) + + for key, value in response.json().items(): + assert isinstance(key, str) + assert value == 3 + + +@pytest.mark.parametrize("group_value", ["month", "year", None]) +def test_variant_counts_no_published_data(client, group_value, setup_router_db): + """Test variant counts endpoint with no published variants.""" + response = client.get(add_query_param("/api/v1/statistics/variant/count", "group", group_value)) + assert response.status_code == 200 + assert isinstance(response.json(), dict) + + for key, value in response.json().items(): + assert isinstance(key, str) + assert value == 0 + + +@pytest.mark.parametrize("group_value", ["month", "year", None]) +def test_mapped_variant_counts_groups(client, group_value, setup_router_db, setup_seq_scoreset): + """Test variant counts endpoint for published variants.""" + url_with_group = add_query_param("/api/v1/statistics/mapped-variant/count", "group", group_value) + response = client.get(url_with_group) + assert response.status_code == 200 + assert isinstance(response.json(), dict) + + for key, value in response.json().items(): + assert isinstance(key, str) + assert isinstance(value, int) + + +@pytest.mark.parametrize("group_value", ["month", "year", None]) +def test_mapped_variant_counts_groups_no_published_data(client, group_value, setup_router_db): + """Test variant counts endpoint with no published variants.""" + url_with_group = add_query_param("/api/v1/statistics/mapped-variant/count", "group", group_value) + response = client.get(url_with_group) + assert response.status_code == 200 + assert isinstance(response.json(), dict) + + for key, value in response.json().items(): + assert isinstance(key, str) + assert value == 0 + + +@pytest.mark.parametrize("current_value", [True, False]) +def test_mapped_variant_counts_current(client, current_value, setup_router_db, setup_seq_scoreset): + """Test variant counts endpoint for published variants.""" + url_with_current = add_query_param("/api/v1/statistics/mapped-variant/count", "current", current_value) + response = client.get(url_with_current) + assert response.status_code == 200 + assert isinstance(response.json(), dict) + + for key, value in response.json().items(): + assert isinstance(key, str) + assert isinstance(value, int) + + +@pytest.mark.parametrize("current_value", [True, False]) +def test_mapped_variant_counts_current_no_published_data(client, current_value, setup_router_db): + """Test variant counts endpoint with no published variants.""" + url_with_current = add_query_param("/api/v1/statistics/mapped-variant/count", "current", current_value) + response = client.get(url_with_current) + assert response.status_code == 200 + assert isinstance(response.json(), dict) + + for key, value in response.json().items(): + assert isinstance(key, str) + assert value == 0 From c6a53bab684a10cc0fafb7d599593fd33097b5be Mon Sep 17 00:00:00 2001 From: Ben Capodanno Date: Thu, 6 Mar 2025 11:08:25 -0800 Subject: [PATCH 5/6] Test utility for inserting mapped variants to DB --- tests/helpers/constants.py | 25 +++++++++++++++++++++++++ tests/helpers/util.py | 26 ++++++++++++++++++++++++++ tests/routers/conftest.py | 2 ++ 3 files changed, 53 insertions(+) diff --git a/tests/helpers/constants.py b/tests/helpers/constants.py index a0aeb42e..03abc856 100644 --- a/tests/helpers/constants.py +++ b/tests/helpers/constants.py @@ -637,6 +637,20 @@ } +TEST_MINIMAL_PRE_MAPPED_METADATA = { + "genomic": {"sequence_id": "ga4gh:SQ.em9khDCUYXrVWBfWr9r8fjBUrTjj1aig", "sequence_type": "dna"} +} + + +TEST_MINIMAL_POST_MAPPED_METADATA = { + "genomic": { + "sequence_id": "ga4gh:SQ.em9khDCUYXrVWBfWr9r8fjBUrTjj1aig", + "sequence_type": "dna", + "sequence_accessions": [VALID_ACCESSION], + "sequence_genes": [VALID_GENE], + } +} + TEST_VARIANT_MAPPING_SCAFFOLD = { "metadata": {}, "computed_genomic_reference_sequence": { @@ -656,6 +670,17 @@ } +TEST_MINIMAL_MAPPED_VARIANT = { + "pre_mapped": {}, + "post_mapped": {}, + "modification_date": datetime.isoformat(datetime.now()), + "mapped_date": datetime.isoformat(datetime.now()), + "current": True, + "vrs_version": "2.0", + "mapping_api_version": "pytest.0.0", +} + + TEST_SCORESET_RANGE = { "wt_score": 1.0, "ranges": [ diff --git a/tests/helpers/util.py b/tests/helpers/util.py index f4df4586..0cbfe4c9 100644 --- a/tests/helpers/util.py +++ b/tests/helpers/util.py @@ -11,9 +11,13 @@ from mavedb.lib.validation.dataframe import validate_and_standardize_dataframe_pair from mavedb.models.contributor import Contributor from mavedb.models.enums.processing_state import ProcessingState +from mavedb.models.enums.mapping_state import MappingState +from mavedb.models.mapped_variant import MappedVariant from mavedb.models.score_set import ScoreSet as ScoreSetDbModel from mavedb.models.license import License +from mavedb.models.target_gene import TargetGene from mavedb.models.user import User +from mavedb.models.variant import Variant from mavedb.view_models.collection import Collection from mavedb.view_models.experiment import Experiment, ExperimentCreate from mavedb.view_models.score_set import ScoreSet, ScoreSetCreate @@ -23,7 +27,10 @@ TEST_COLLECTION, TEST_MINIMAL_ACC_SCORESET, TEST_MINIMAL_EXPERIMENT, + TEST_MINIMAL_PRE_MAPPED_METADATA, + TEST_MINIMAL_POST_MAPPED_METADATA, TEST_MINIMAL_SEQ_SCORESET, + TEST_MINIMAL_MAPPED_VARIANT, ) @@ -185,6 +192,25 @@ def mock_worker_variant_insertion(client, db, data_provider, score_set, scores_c return client.get(f"/api/v1/score-sets/{score_set['urn']}").json() +def create_mapped_variants_for_score_set(db, score_set_urn): + score_set = db.scalar(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)) + targets = db.scalars(select(TargetGene).where(TargetGene.score_set_id == score_set.id)) + variants = db.scalars(select(Variant).where(Variant.score_set_id == score_set.id)).all() + + for variant in variants: + mv = MappedVariant(**TEST_MINIMAL_MAPPED_VARIANT, variant_id=variant.id) + db.add(mv) + + for target in targets: + target.pre_mapped_metadata = TEST_MINIMAL_PRE_MAPPED_METADATA + target.post_mapped_metadata = TEST_MINIMAL_POST_MAPPED_METADATA + db.add(target) + + score_set.mapping_state = MappingState.complete + db.commit() + return + + def create_seq_score_set_with_variants( client, db, data_provider, experiment_urn, scores_csv_path, update=None, counts_csv_path=None ): diff --git a/tests/routers/conftest.py b/tests/routers/conftest.py index f16ff93b..d5a69cd9 100644 --- a/tests/routers/conftest.py +++ b/tests/routers/conftest.py @@ -28,6 +28,7 @@ create_acc_score_set_with_variants, create_experiment, create_seq_score_set_with_variants, + create_mapped_variants_for_score_set, publish_score_set, ) @@ -76,6 +77,7 @@ def setup_seq_scoreset(setup_router_db, session, data_provider, client, data_fil score_set = create_seq_score_set_with_variants( client, session, data_provider, experiment["urn"], data_files / "scores.csv" ) + create_mapped_variants_for_score_set(session, score_set["urn"]) publish_score_set(client, score_set["urn"]) From 9aedb21b3e230b4e22a666eff10e96b89df435fb Mon Sep 17 00:00:00 2001 From: Ben Capodanno Date: Thu, 6 Mar 2025 11:08:49 -0800 Subject: [PATCH 6/6] Add test for statistics mapped target gene counts --- src/mavedb/routers/statistics.py | 1 - tests/routers/test_statistics.py | 19 +++++++++++++++++-- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/src/mavedb/routers/statistics.py b/src/mavedb/routers/statistics.py index bb4b9ea6..a1308a25 100644 --- a/src/mavedb/routers/statistics.py +++ b/src/mavedb/routers/statistics.py @@ -448,7 +448,6 @@ def target_genes_uniprot_identifier_counts(db: Session = Depends(get_db)) -> dic return _count_for_identifier_in_query(db, query) -# TODO: Test coverage for this route. @router.get("/target/mapped/gene") def mapped_target_gene_counts(db: Session = Depends(get_db)) -> dict[str, int]: """ diff --git a/tests/routers/test_statistics.py b/tests/routers/test_statistics.py index e39987e4..279147e6 100644 --- a/tests/routers/test_statistics.py +++ b/tests/routers/test_statistics.py @@ -12,6 +12,7 @@ TEST_MINIMAL_ACC_SCORESET, TEST_MINIMAL_SEQ_SCORESET, TEST_PUBMED_IDENTIFIER, + VALID_GENE, ) from tests.helpers.util import ( create_acc_score_set_with_variants, @@ -238,6 +239,20 @@ def test_target_gene_empty_field(client): assert response.status_code == 404 +#################################################################################################### +# Test mapped target gene statistics +#################################################################################################### + + +def test_mapped_target_gene_counts(client, setup_router_db, setup_seq_scoreset): + """Test mapped target gene counts endpoint for published score sets.""" + response = client.get("/api/v1/statistics/target/mapped/gene") + assert response.status_code == 200 + assert isinstance(response.json(), dict) + assert len(response.json().keys()) == 1 + assert response.json()[VALID_GENE] == 1 + + #################################################################################################### # Test record statistics #################################################################################################### @@ -473,7 +488,7 @@ def test_mapped_variant_counts_groups(client, group_value, setup_router_db, setu for key, value in response.json().items(): assert isinstance(key, str) - assert isinstance(value, int) + assert value == 3 @pytest.mark.parametrize("group_value", ["month", "year", None]) @@ -499,7 +514,7 @@ def test_mapped_variant_counts_current(client, current_value, setup_router_db, s for key, value in response.json().items(): assert isinstance(key, str) - assert isinstance(value, int) + assert value == 3 @pytest.mark.parametrize("current_value", [True, False])