From fb44de80ff541d064ccbe38a1559257a4b04b800 Mon Sep 17 00:00:00 2001 From: vincentsarago Date: Sun, 16 Feb 2025 16:44:26 +0100 Subject: [PATCH 1/3] add csv and geojson-seq output format --- runtimes/eoapi/stac/eoapi/stac/api.py | 35 +++ runtimes/eoapi/stac/eoapi/stac/app.py | 5 +- runtimes/eoapi/stac/eoapi/stac/client.py | 306 +++++++++++++++++-- runtimes/eoapi/stac/eoapi/stac/extensions.py | 21 ++ 4 files changed, 344 insertions(+), 23 deletions(-) diff --git a/runtimes/eoapi/stac/eoapi/stac/api.py b/runtimes/eoapi/stac/eoapi/stac/api.py index 0fef3d1..5a7f05c 100644 --- a/runtimes/eoapi/stac/eoapi/stac/api.py +++ b/runtimes/eoapi/stac/eoapi/stac/api.py @@ -157,6 +157,8 @@ def register_get_item_collection(self): "content": { MimeTypes.geojson.value: {}, MimeTypes.html.value: {}, + MimeTypes.csv.value: {}, + MimeTypes.geojsonseq.value: {}, }, "model": api.ItemCollection, }, @@ -187,6 +189,8 @@ def register_get_search(self): "content": { MimeTypes.geojson.value: {}, MimeTypes.html.value: {}, + MimeTypes.csv.value: {}, + MimeTypes.geojsonseq.value: {}, }, "model": api.ItemCollection, }, @@ -199,3 +203,34 @@ def register_get_search(self): self.client.get_search, self.search_get_request_model ), ) + + def register_post_search(self): + """Register search endpoint (POST /search). + + Returns: + None + """ + self.router.add_api_route( + name="Search", + path="/search", + response_model=api.ItemCollection + if self.settings.enable_response_models + else None, + responses={ + 200: { + "content": { + MimeTypes.geojson.value: {}, + MimeTypes.csv.value: {}, + MimeTypes.geojsonseq.value: {}, + }, + "model": api.ItemCollection, + }, + }, + response_class=GeoJSONResponse, + response_model_exclude_unset=True, + response_model_exclude_none=True, + methods=["POST"], + endpoint=create_async_endpoint( + self.client.post_search, self.search_post_request_model + ), + ) diff --git a/runtimes/eoapi/stac/eoapi/stac/app.py b/runtimes/eoapi/stac/eoapi/stac/app.py index 1d2b555..6fb7bd3 100644 --- a/runtimes/eoapi/stac/eoapi/stac/app.py +++ b/runtimes/eoapi/stac/eoapi/stac/app.py @@ -43,6 +43,7 @@ from .client import FiltersClient, PgSTACClient from .config import Settings from .extensions import ( + HTMLorGeoMultiOutputExtension, HTMLorGeoOutputExtension, HTMLorJSONOutputExtension, ItemCollectionFilterExtension, @@ -84,7 +85,7 @@ FieldsExtension(), SearchFilterExtension(client=FiltersClient()), # type: ignore TokenPaginationExtension(), - HTMLorGeoOutputExtension(), + HTMLorGeoMultiOutputExtension(), ] # collection_search extensions @@ -111,7 +112,7 @@ FieldsExtension(conformance_classes=[FieldsConformanceClasses.ITEMS]), ItemCollectionFilterExtension(client=FiltersClient()), # type: ignore TokenPaginationExtension(), - HTMLorGeoOutputExtension(), + HTMLorGeoMultiOutputExtension(), ] # Request Models diff --git a/runtimes/eoapi/stac/eoapi/stac/client.py b/runtimes/eoapi/stac/eoapi/stac/client.py index bfee9d3..9a21384 100644 --- a/runtimes/eoapi/stac/eoapi/stac/client.py +++ b/runtimes/eoapi/stac/eoapi/stac/client.py @@ -1,15 +1,31 @@ """eoapi-devseed: Custom pgstac client.""" +import csv import re -from typing import Any, Dict, List, Literal, Optional, Type, get_args -from urllib.parse import urljoin +from typing import ( + Any, + Dict, + Generator, + Iterable, + List, + Literal, + Optional, + Type, + get_args, +) +from urllib.parse import unquote_plus, urljoin import attr import jinja2 +import orjson from fastapi import Request +from geojson_pydantic.geometries import parse_geometry_obj +from stac_fastapi.api.models import JSONResponse from stac_fastapi.pgstac.core import CoreCrudClient from stac_fastapi.pgstac.extensions.filter import FiltersClient as PgSTACFiltersClient +from stac_fastapi.pgstac.models.links import ItemCollectionLinks from stac_fastapi.pgstac.types.search import PgstacSearch +from stac_fastapi.types.errors import NotFoundError from stac_fastapi.types.requests import get_base_url from stac_fastapi.types.stac import ( Collection, @@ -20,12 +36,15 @@ LandingPage, ) from stac_pydantic.links import Relations -from stac_pydantic.shared import MimeTypes +from stac_pydantic.shared import BBox, MimeTypes +from starlette.responses import StreamingResponse from starlette.templating import Jinja2Templates, _TemplateResponse ResponseType = Literal["json", "html"] GeoResponseType = Literal["geojson", "html"] QueryablesResponseType = Literal["jsonschema", "html"] +GeoMultiResponseType = Literal["geojson", "html", "geojsonseq", "csv"] +PostMultiResponseType = Literal["geojson", "geojsonseq", "csv"] jinja2_env = jinja2.Environment( @@ -136,6 +155,57 @@ def create_html_response( ) +def _create_csv_rows(data: Iterable[Dict]) -> Generator[str, None, None]: + """Creates an iterator that returns lines of csv from an iterable of dicts.""" + + class DummyWriter: + """Dummy writer that implements write for use with csv.writer.""" + + def write(self, line: str): + """Return line.""" + return line + + # Get the first row and construct the column names + row = next(data) # type: ignore + fieldnames = row.keys() + writer = csv.DictWriter(DummyWriter(), fieldnames=fieldnames) + + # Write header + yield writer.writerow(dict(zip(fieldnames, fieldnames))) + + # Write first row + yield writer.writerow(row) + + # Write all remaining rows + for row in data: + yield writer.writerow(row) + + +def items_to_csv_rows(items: Iterable[Dict]) -> Generator[str, None, None]: + """Creates an iterator that returns lines of csv from an iterable of dicts.""" + if any(f.get("geometry", None) is not None for f in items): + rows = ( + { + "itemId": f.get("id"), + "collectionId": f.get("collection"), + **f.get("properties", {}), + "geometry": parse_geometry_obj(f["geometry"]).wkt, + } + for f in items + ) + else: + rows = ( + { + "itemId": f.get("id"), + "collectionId": f.get("collection"), + **f.get("properties", {}), + } + for f in items + ) + + return _create_csv_rows(rows) + + @attr.s class FiltersClient(PgSTACFiltersClient): async def get_queryables( @@ -367,43 +437,115 @@ async def item_collection( self, collection_id: str, request: Request, - *args, + bbox: Optional[BBox] = None, + datetime: Optional[str] = None, + limit: Optional[int] = None, + # Extensions + query: Optional[str] = None, + fields: Optional[List[str]] = None, + sortby: Optional[str] = None, + filter_expr: Optional[str] = None, + filter_lang: Optional[str] = None, + token: Optional[str] = None, f: Optional[str] = None, **kwargs, ) -> ItemCollection: - items = await super().item_collection(collection_id, request, *args, **kwargs) - output_type: Optional[MimeTypes] if f: output_type = MimeTypes[f] else: - accepted_media = [MimeTypes[v] for v in get_args(GeoResponseType)] + accepted_media = [MimeTypes[v] for v in get_args(GeoMultiResponseType)] output_type = accept_media_type( request.headers.get("accept", ""), accepted_media ) + # Check if collection exist + await self.get_collection(collection_id, request=request) + + base_args = { + "collections": [collection_id], + "bbox": bbox, + "datetime": datetime, + "limit": limit, + "token": token, + "query": orjson.loads(unquote_plus(query)) if query else query, + } + clean = self._clean_search_args( + base_args=base_args, + filter_query=filter_expr, + filter_lang=filter_lang, + fields=fields, + sortby=sortby, + ) + + search_request = self.pgstac_search_model(**clean) + item_collection = await self._search_base(search_request, request=request) + item_collection["links"] = await ItemCollectionLinks( + collection_id=collection_id, request=request + ).get_links(extra_links=item_collection["links"]) + + # Additional Headers for StreamingResponse + additional_headers = {} + links = item_collection.get("links", []) + next_link = next(filter(lambda link: link["rel"] == "next", links), None) + prev_link = next( + filter(lambda link: link["rel"] in ["prev", "previous"], links), None + ) + if next_link or prev_link: + additional_headers["Link"] = ",".join( + [ + f'{link["href"]}; rel="{link["rel"]}"' + for link in [next_link, prev_link] + if link + ] + ) + if output_type == MimeTypes.html: - items["id"] = collection_id + item_collection["id"] = collection_id return create_html_response( request, - items, + item_collection, template_name="items", title=f"{collection_id} items", ) - return items + elif output_type == MimeTypes.csv: + return StreamingResponse( + items_to_csv_rows(item_collection["features"]), + media_type=MimeTypes.csv, + headers={ + "Content-Disposition": "attachment;filename=items.csv", + **additional_headers, + }, + ) + + elif output_type == MimeTypes.geojsonseq: + return StreamingResponse( + (orjson.dumps(f) + b"\n" for f in item_collection["features"]), + media_type=MimeTypes.geojsonseq, + headers={ + "Content-Disposition": "attachment;filename=items.geojson", + **additional_headers, + }, + ) + + # If we have the `fields` extension enabled + # we need to avoid Pydantic validation because the + # Items might not be a valid STAC Item objects + if fields := getattr(search_request, "fields", None): + if fields.include or fields.exclude: + return JSONResponse(item_collection) # type: ignore + + return ItemCollection(**item_collection) async def get_item( self, item_id: str, collection_id: str, request: Request, - *args, f: Optional[str] = None, **kwargs, ) -> Item: - item = await super().get_item(item_id, collection_id, request, *args, **kwargs) - output_type: Optional[MimeTypes] if f: output_type = MimeTypes[f] @@ -413,39 +555,161 @@ async def get_item( request.headers.get("accept", ""), accepted_media ) + # Check if collection exist + await self.get_collection(collection_id, request=request) + + search_request = self.pgstac_search_model( + ids=[item_id], collections=[collection_id], limit=1 + ) + item_collection = await self._search_base(search_request, request=request) + if not item_collection["features"]: + raise NotFoundError( + f"Item {item_id} in Collection {collection_id} does not exist." + ) + if output_type == MimeTypes.html: return create_html_response( request, - item, + item_collection["features"][0], template_name="item", title=f"{collection_id}/{item_id} item", ) - return item + return Item(**item_collection["features"][0]) async def get_search( self, request: Request, - *args, + collections: Optional[List[str]] = None, + ids: Optional[List[str]] = None, + bbox: Optional[BBox] = None, + intersects: Optional[str] = None, + datetime: Optional[str] = None, + limit: Optional[int] = None, + # Extensions + query: Optional[str] = None, + fields: Optional[List[str]] = None, + sortby: Optional[str] = None, + filter_expr: Optional[str] = None, + filter_lang: Optional[str] = None, + token: Optional[str] = None, f: Optional[str] = None, **kwargs, ) -> ItemCollection: - items = await super().get_search(request, *args, **kwargs) - output_type: Optional[MimeTypes] if f: output_type = MimeTypes[f] else: - accepted_media = [MimeTypes[v] for v in get_args(GeoResponseType)] + accepted_media = [MimeTypes[v] for v in get_args(GeoMultiResponseType)] output_type = accept_media_type( request.headers.get("accept", ""), accepted_media ) + # Parse request parameters + base_args = { + "collections": collections, + "ids": ids, + "bbox": bbox, + "limit": limit, + "token": token, + "query": orjson.loads(unquote_plus(query)) if query else query, + } + + clean = self._clean_search_args( + base_args=base_args, + intersects=intersects, + datetime=datetime, + fields=fields, + sortby=sortby, + filter_query=filter_expr, + filter_lang=filter_lang, + ) + + search_request = self.pgstac_search_model(**clean) + item_collection = await self._search_base(search_request, request=request) + + # Additional Headers for StreamingResponse + additional_headers = {} + links = item_collection.get("links", []) + next_link = next(filter(lambda link: link["rel"] == "next", links), None) + prev_link = next( + filter(lambda link: link["rel"] in ["prev", "previous"], links), None + ) + if next_link or prev_link: + additional_headers["Link"] = ",".join( + [ + f'{link["href"]}; rel="{link["rel"]}"' + for link in [next_link, prev_link] + if link + ] + ) + if output_type == MimeTypes.html: return create_html_response( request, - items, + item_collection, template_name="search", ) - return items + elif output_type == MimeTypes.csv: + return StreamingResponse( + items_to_csv_rows(item_collection["features"]), + media_type=MimeTypes.csv, + headers={ + "Content-Disposition": "attachment;filename=items.csv", + **additional_headers, + }, + ) + + elif output_type == MimeTypes.geojsonseq: + return StreamingResponse( + (orjson.dumps(f) + b"\n" for f in item_collection["features"]), + media_type=MimeTypes.geojsonseq, + headers={ + "Content-Disposition": "attachment;filename=items.geojson", + **additional_headers, + }, + ) + + if fields := getattr(search_request, "fields", None): + if fields.include or fields.exclude: + return JSONResponse(item_collection) # type: ignore + + return ItemCollection(**item_collection) + + async def post_search( + self, + search_request: PgstacSearch, + request: Request, + **kwargs, + ) -> ItemCollection: + accepted_media = [MimeTypes[v] for v in get_args(PostMultiResponseType)] + output_type = accept_media_type( + request.headers.get("accept", ""), accepted_media + ) + + item_collection = await self._search_base(search_request, request=request) + + if output_type == MimeTypes.csv: + return StreamingResponse( + items_to_csv_rows(item_collection["features"]), + media_type=MimeTypes.csv, + headers={ + "Content-Disposition": "attachment;filename=items.csv", + }, + ) + + elif output_type == MimeTypes.geojsonseq: + return StreamingResponse( + (orjson.dumps(f) + b"\n" for f in item_collection["features"]), + media_type=MimeTypes.geojsonseq, + headers={ + "Content-Disposition": "attachment;filename=items.geojson", + }, + ) + + if fields := getattr(search_request, "fields", None): + if fields.include or fields.exclude: + return JSONResponse(item_collection) # type: ignore + + return ItemCollection(**item_collection) diff --git a/runtimes/eoapi/stac/eoapi/stac/extensions.py b/runtimes/eoapi/stac/eoapi/stac/extensions.py index db4a276..3677457 100644 --- a/runtimes/eoapi/stac/eoapi/stac/extensions.py +++ b/runtimes/eoapi/stac/eoapi/stac/extensions.py @@ -131,6 +131,16 @@ class HTMLorGeoGetRequest(APIRequest): ] = attr.ib(default=None) +@attr.s +class HTMLorGeoGetRequestMulti(APIRequest): + """HTML, GeoJSON, GeoJSONSeq or CSV output.""" + + f: Annotated[ + Optional[Literal["geojson", "html", "csv", "geojsonseq"]], + Query(description="Response MediaType."), + ] = attr.ib(default=None) + + @attr.s(kw_only=True) class HTMLorJSONOutputExtension(ApiExtension): """TiTiler extension.""" @@ -153,6 +163,17 @@ def register(self, app: FastAPI) -> None: pass +@attr.s(kw_only=True) +class HTMLorGeoMultiOutputExtension(ApiExtension): + """TiTiler extension.""" + + GET = HTMLorGeoGetRequestMulti + POST = None + + def register(self, app: FastAPI) -> None: + pass + + @attr.s(kw_only=True) class HTMLorSchemaGetRequest(APIRequest): f: Annotated[ From 76f6b6d8abc260c244c04bb7a57b6ce70ae4143a Mon Sep 17 00:00:00 2001 From: vincentsarago Date: Sun, 16 Feb 2025 16:50:15 +0100 Subject: [PATCH 2/3] fix link format --- runtimes/eoapi/stac/eoapi/stac/client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/runtimes/eoapi/stac/eoapi/stac/client.py b/runtimes/eoapi/stac/eoapi/stac/client.py index 9a21384..0f51019 100644 --- a/runtimes/eoapi/stac/eoapi/stac/client.py +++ b/runtimes/eoapi/stac/eoapi/stac/client.py @@ -494,7 +494,7 @@ async def item_collection( if next_link or prev_link: additional_headers["Link"] = ",".join( [ - f'{link["href"]}; rel="{link["rel"]}"' + f'<{link["href"]}>; rel="{link["rel"]}"' for link in [next_link, prev_link] if link ] @@ -638,7 +638,7 @@ async def get_search( if next_link or prev_link: additional_headers["Link"] = ",".join( [ - f'{link["href"]}; rel="{link["rel"]}"' + f'<{link["href"]}>; rel="{link["rel"]}"' for link in [next_link, prev_link] if link ] From d6466e97bf71b873a440cb3867adbe912229a1e4 Mon Sep 17 00:00:00 2001 From: vincentsarago Date: Tue, 18 Feb 2025 14:38:04 -0500 Subject: [PATCH 3/3] add token headers for POST --- runtimes/eoapi/stac/eoapi/stac/client.py | 18 ++++++++++++++++++ runtimes/eoapi/stac/pyproject.toml | 2 +- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/runtimes/eoapi/stac/eoapi/stac/client.py b/runtimes/eoapi/stac/eoapi/stac/client.py index 0f51019..16d4726 100644 --- a/runtimes/eoapi/stac/eoapi/stac/client.py +++ b/runtimes/eoapi/stac/eoapi/stac/client.py @@ -690,12 +690,29 @@ async def post_search( item_collection = await self._search_base(search_request, request=request) + # Additional Headers for StreamingResponse + additional_headers = {} + links = item_collection.get("links", []) + next_link = next(filter(lambda link: link["rel"] == "next", links), None) + prev_link = next( + filter(lambda link: link["rel"] in ["prev", "previous"], links), None + ) + if next_link or prev_link: + additional_headers["Pagination-Token"] = ",".join( + [ + f'<{link["body"]["token"]}>; rel="{link["rel"]}"' + for link in [next_link, prev_link] + if link + ] + ) + if output_type == MimeTypes.csv: return StreamingResponse( items_to_csv_rows(item_collection["features"]), media_type=MimeTypes.csv, headers={ "Content-Disposition": "attachment;filename=items.csv", + **additional_headers, }, ) @@ -705,6 +722,7 @@ async def post_search( media_type=MimeTypes.geojsonseq, headers={ "Content-Disposition": "attachment;filename=items.geojson", + **additional_headers, }, ) diff --git a/runtimes/eoapi/stac/pyproject.toml b/runtimes/eoapi/stac/pyproject.toml index 5da6c54..d594ee4 100644 --- a/runtimes/eoapi/stac/pyproject.toml +++ b/runtimes/eoapi/stac/pyproject.toml @@ -20,7 +20,7 @@ classifiers = [ ] dynamic = ["version"] dependencies = [ - "stac-fastapi.pgstac>=4.0,<4.1", + "stac-fastapi.pgstac>=4.0.2,<4.1", "jinja2>=2.11.2,<4.0.0", "starlette-cramjam>=0.4,<0.5", "psycopg_pool",