diff --git a/quest/api/catalog.py b/quest/api/catalog.py
index 7b67f41b..55545192 100644
--- a/quest/api/catalog.py
+++ b/quest/api/catalog.py
@@ -12,7 +12,7 @@
 from .. import util
 from ..plugins import load_providers
 from ..static import DatasetSource, UriType
-from ..database.database import get_db, db_session
+from ..database.database import get_db, db_session, select_datasets
 
 
 @add_async
@@ -81,80 +81,93 @@ def search_catalog(uris=None, expand=False, as_dataframe=False, as_geojson=False
     """
     uris = list(itertools.chain(util.listify(uris) or []))
-    grouped_uris = util.classify_uris(uris, as_dataframe=False, exclude=[UriType.DATASET, UriType.COLLECTION])
+    grouped_uris = util.classify_uris(uris, as_dataframe=False, exclude=[UriType.DATASET],
+                                      raise_if_empty=False)
     services = grouped_uris.get(UriType.SERVICE) or []
+    collections = grouped_uris.get(UriType.COLLECTION) or []
 
-    all_datasets = []
+    catalog_entries = [d['catalog_entry'] for d in select_datasets(lambda c: c.collection.name in collections)]
+
+    all_catalog_entries = list()
 
     filters = filters or dict()
     for name in services:
-        provider, service, _ = util.parse_service_uri(name)
+        provider, service, catalog_entry = util.parse_service_uri(name)
+        if catalog_entry is not None:
+            catalog_entries.append(name)
+            continue
         provider_plugin = load_providers()[provider]
-        tmp_datasets = provider_plugin.search_catalog(service, update_cache=update_cache, **filters)
-        all_datasets.append(tmp_datasets)
-
-        all_datasets.append(tmp_datasets)
-
-    # drop duplicates fails when some columns have nested list/tuples like
-    # _geom_coords. so drop based on index
-    datasets = pd.concat(all_datasets)
-    datasets['index'] = datasets.index
-    datasets = datasets.drop_duplicates(subset='index')
-    datasets = datasets.set_index('index').sort_index()
+        tmp_catalog_entries = provider_plugin.search_catalog(service, update_cache=update_cache, **filters)
+        all_catalog_entries.append(tmp_catalog_entries)
+
+    if catalog_entries:
+        all_catalog_entries.append(get_metadata(catalog_entries, as_dataframe=True))
+
+    if all_catalog_entries:
+        # drop_duplicates fails when some columns have nested lists/tuples like
+        # _geom_coords, so drop based on the index instead
+        catalog_entries = pd.concat(all_catalog_entries)
+        catalog_entries['index'] = catalog_entries.index
+        catalog_entries = catalog_entries.drop_duplicates(subset='index')
+        catalog_entries = catalog_entries.set_index('index').sort_index()
+    else:
+        catalog_entries = pd.DataFrame()
 
     # apply any specified filters
     for k, v in filters.items():
-        if datasets.empty:
+        if catalog_entries.empty:
             break  # if dataframe is empty then doesn't try filtering any further
         else:
             if k == 'bbox':
                 bbox = util.bbox2poly(*[float(x) for x in util.listify(v)], as_shapely=True)
-                idx = datasets.intersects(bbox)  # http://geopandas.org/reference.html#GeoSeries.intersects
-                datasets = datasets[idx]
+                idx = catalog_entries.intersects(bbox)  # http://geopandas.org/reference.html#GeoSeries.intersects
+                catalog_entries = catalog_entries[idx]
 
             elif k == 'geom_type':
-                idx = datasets.geom_type.str.contains(v).fillna(value=False)
-                datasets = datasets[idx]
+                idx = catalog_entries.geom_type.str.contains(v).fillna(value=False)
+                catalog_entries = catalog_entries[idx]
 
             elif k == 'parameter':
-                idx = datasets.parameters.str.contains(v)
-                datasets = datasets[idx]
+                idx = catalog_entries.parameters.str.contains(v)
+                catalog_entries = catalog_entries[idx]
 
             elif k == 'display_name':
-                idx = datasets.display_name.str.contains(v)
-                datasets = datasets[idx]
+                idx = catalog_entries.display_name.str.contains(v)
+                catalog_entries = catalog_entries[idx]
 
             elif k == 'description':
-                idx = datasets.display_name.str.contains(v)
-                datasets = datasets[idx]
+                idx = catalog_entries.description.str.contains(v)
+                catalog_entries = catalog_entries[idx]
 
             elif k == 'search_terms':
-                idx = np.column_stack([datasets[col].apply(str).str.contains(search_term, na=False)
-                                       for col, search_term in itertools.product(datasets.columns, v)]).any(axis=1)
-                datasets = datasets[idx]
+                idx = np.column_stack([
+                    catalog_entries[col].apply(str).str.contains(search_term, na=False)
+                    for col, search_term in itertools.product(catalog_entries.columns, v)
+                ]).any(axis=1)
+                catalog_entries = catalog_entries[idx]
             else:
-                idx = datasets.metadata.map(lambda x: _multi_index(x, k) == v)
-                datasets = datasets[idx]
+                idx = catalog_entries.metadata.map(lambda x: _multi_index(x, k) == v)
+                catalog_entries = catalog_entries[idx]
 
     if queries is not None:
         for query in queries:
-            datasets = datasets.query(query)
+            catalog_entries = catalog_entries.query(query)
 
     if not (expand or as_dataframe or as_geojson):
-        return datasets.index.astype('unicode').tolist()
+        return catalog_entries.index.astype('unicode').tolist()
 
     if as_geojson:
-        if datasets.empty:
+        if catalog_entries.empty:
             return geojson.FeatureCollection([])
         else:
-            return json.loads(datasets.to_json(default=util.to_json_default_handler))
+            return json.loads(catalog_entries.to_json(default=util.to_json_default_handler))
 
     if not as_dataframe:
-        datasets = datasets.to_dict(orient='index')
+        catalog_entries = catalog_entries.to_dict(orient='index')
 
-    return datasets
+    return catalog_entries
 
 
 def _multi_index(d, index):
diff --git a/test/data.py b/test/data.py
index e131341f..e5bf9a5f 100644
--- a/test/data.py
+++ b/test/data.py
@@ -307,8 +307,8 @@
     ('svc://usgs-nwis:dv', 36368, 1000),
     ('svc://usgs-nwis:iv', 20412, 1000),
     ('svc://wmts:seamless_imagery', 1, 1),
-    ('svc://cuahsi-hydroshare:hs_geo', 1090, 1000),
-    ('svc://cuahsi-hydroshare:hs_norm', 2331, 1000),
+    # ('svc://cuahsi-hydroshare:hs_geo', 1090, 1000),
+    # ('svc://cuahsi-hydroshare:hs_norm', 2331, 1000),
 ]
 
 
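Note on the dedup-by-index workaround in the hunk above: pandas drop_duplicates() hashes column values, so it raises TypeError when a column holds unhashable values such as the nested lists in _geom_coords. Copying the index into an ordinary column and deduplicating on that column alone sidesteps the hashing. A minimal standalone sketch with made-up data (not Quest code):

    import pandas as pd

    # Two frames that share the index label 'a' and carry a list-valued column.
    df = pd.concat([
        pd.DataFrame({'_geom_coords': [[(0, 0), (1, 1)]], 'name': ['site-a']}, index=['a']),
        pd.DataFrame({'_geom_coords': [[(0, 0), (1, 1)]], 'name': ['site-a']}, index=['a']),
    ])

    # df.drop_duplicates() would raise TypeError: unhashable type: 'list',
    # because it tries to hash the list in _geom_coords. Deduplicate on the
    # (hashable) index instead, as search_catalog does:
    df['index'] = df.index
    df = df.drop_duplicates(subset='index').set_index('index').sort_index()
    assert len(df) == 1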
diff --git a/test/test_catalog.py b/test/test_catalog.py
index 77b56af5..a03cae01 100644
--- a/test/test_catalog.py
+++ b/test/test_catalog.py
@@ -27,6 +27,16 @@ def test_add_datasets(api, catalog_entry):
     assert b == c
 
 
+def test_search_catalog_with_no_uris(api):
+    catalog_entries = api.search_catalog()
+    assert catalog_entries == []
+
+
+def test_search_catalog_with_collection(api):
+    api.set_active_project('test_data')
+    catalog_entries = api.search_catalog('col1')
+    assert catalog_entries == ['svc://usgs-nwis:iv/01516350']
+
 @pytest.mark.slow
 @pytest.mark.parametrize("service, expected, tolerance", SERVICES_CATALOG_COUNT)
 def test_search_catalog_from_service(api, service, expected, tolerance):
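The new tests exercise the three-way unpack that the reworked loop in catalog.py relies on: util.parse_service_uri now yields (provider, service, catalog_entry), with the third element None for a bare service URI. Quest's actual parser is not shown in this diff; the following is only a rough sketch of the apparent svc://provider:service/catalog_entry convention, using a hypothetical helper name, matching the URIs seen in the tests:

    def parse_uri_sketch(uri):
        """Split 'svc://provider:service/catalog_entry' into its parts.

        The catalog-entry segment is optional; a bare service URI such as
        'svc://usgs-nwis:iv' yields catalog_entry=None, which is what lets
        the `if catalog_entry is not None` branch in search_catalog skip
        the provider search and treat the URI as a concrete entry.
        """
        body = uri[len('svc://'):]
        provider_service, _, catalog_entry = body.partition('/')
        provider, _, service = provider_service.partition(':')
        return provider, service, catalog_entry or None

    assert parse_uri_sketch('svc://usgs-nwis:iv/01516350') == ('usgs-nwis', 'iv', '01516350')
    assert parse_uri_sketch('svc://usgs-nwis:iv') == ('usgs-nwis', 'iv', None)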