diff --git a/src/jua/weather/_query_engine.py b/src/jua/weather/_query_engine.py index 1cca0e6..8a536fc 100644 --- a/src/jua/weather/_query_engine.py +++ b/src/jua/weather/_query_engine.py @@ -382,12 +382,18 @@ def load_raw_forecast( stream: bool = False, print_progress: bool | None = None, ) -> pd.DataFrame: - if payload.geo.type == "point" and payload.geo.method == "bilinear" and stream: - logger.warning( - "Cannot use streaming responses with bilinear interpolation. Setting " - "stream=False." - ) - stream = False + group_by = payload.group_by or [ + "model", + "init_time", + "prediction_timedelta", + "latitude", + "longitude", + ] + if payload.geo.type == "point": + group_by.append("point") + + if all(get_model_meta_info(model).has_grid_access for model in payload.models): + payload.group_by = group_by est_requested_points = payload.num_requested_points() if est_requested_points > self._MAX_POINTS_PER_REQUEST: diff --git a/src/jua/weather/_types/query_payload_types.py b/src/jua/weather/_types/query_payload_types.py index 5d554e9..32777ff 100644 --- a/src/jua/weather/_types/query_payload_types.py +++ b/src/jua/weather/_types/query_payload_types.py @@ -52,6 +52,7 @@ class ForecastQueryPayload(BaseModel): timedelta_unit: Literal["h", "m", "d"] = "m" aggregation: list[str] | None = None variables: list[str] | None = None + group_by: list[str] | None = None def num_requested_points(self) -> int: """Estimate number of requested data rows for this payload.