diff --git a/bittensor/axon.py b/bittensor/axon.py index 55db8bcea1..8cefadfe61 100644 --- a/bittensor/axon.py +++ b/bittensor/axon.py @@ -31,8 +31,9 @@ import traceback import typing import uuid -from inspect import Parameter, Signature, signature -from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple +import warnings +from inspect import signature, Signature, Parameter +from typing import List, Optional, Tuple, Callable, Any, Dict, Awaitable import uvicorn from fastapi import APIRouter, Depends, FastAPI @@ -485,17 +486,50 @@ def verify_custom(synapse: MyCustomSynapse): async def endpoint(*args, **kwargs): start_time = time.time() - response_synapse = forward_fn(*args, **kwargs) - if isinstance(response_synapse, Awaitable): - response_synapse = await response_synapse - return await self.middleware_cls.synapse_to_response( - synapse=response_synapse, start_time=start_time - ) + response = forward_fn(*args, **kwargs) + if isinstance(response, Awaitable): + response = await response + if isinstance(response, bittensor.Synapse): + return await self.middleware_cls.synapse_to_response( + synapse=response, start_time=start_time + ) + else: + response_synapse = getattr(response, "synapse", None) + if response_synapse is None: + warnings.warn( + "The response synapse is None. The input synapse will be used as the response synapse. " + "Reliance on forward_fn modifying input synapse as a side-effect is deprecated. 
" + "Explicitly set `synapse` on response object instead.", + DeprecationWarning, + ) + # Replace with `return response` in next major version + response_synapse = args[0] + + return await self.middleware_cls.synapse_to_response( + synapse=response_synapse, + start_time=start_time, + response_override=response, + ) + + return_annotation = forward_sig.return_annotation + + if isinstance(return_annotation, type) and issubclass( + return_annotation, bittensor.Synapse + ): + if issubclass( + return_annotation, + bittensor.StreamingSynapse, + ): + warnings.warn( + "The forward_fn return annotation is a subclass of bittensor.StreamingSynapse. " + "Most likely the correct return annotation would be BTStreamingResponse." + ) + else: + return_annotation = JSONResponse - # replace the endpoint signature, but set return annotation to JSONResponse endpoint.__signature__ = Signature( # type: ignore parameters=list(forward_sig.parameters.values()), - return_annotation=JSONResponse, + return_annotation=return_annotation, ) # Add the endpoint to the router, making it available on both GET and POST methods @@ -1433,14 +1467,21 @@ async def run( @classmethod async def synapse_to_response( - cls, synapse: bittensor.Synapse, start_time: float - ) -> JSONResponse: + cls, + synapse: bittensor.Synapse, + start_time: float, + *, + response_override: Optional[Response] = None, + ) -> Response: """ Converts the Synapse object into a JSON response with HTTP headers. Args: - synapse (bittensor.Synapse): The Synapse object representing the request. - start_time (float): The timestamp when the request processing started. + synapse: The Synapse object representing the request. + start_time: The timestamp when the request processing started. + response_override: + Instead of serializing the synapse, mutate the provided response object. + This is only really useful for StreamingSynapse responses. Returns: Response: The final HTTP response, with updated headers, ready to be sent back to the client. 
@@ -1459,11 +1500,14 @@ async def synapse_to_response( synapse.axon.process_time = time.time() - start_time - serialized_synapse = await serialize_response(response_content=synapse) - response = JSONResponse( - status_code=synapse.axon.status_code, - content=serialized_synapse, - ) + if response_override: + response = response_override + else: + serialized_synapse = await serialize_response(response_content=synapse) + response = JSONResponse( + status_code=synapse.axon.status_code, + content=serialized_synapse, + ) try: updated_headers = synapse.to_headers() diff --git a/bittensor/stream.py b/bittensor/stream.py index e0dc17c42c..3a82edc15a 100644 --- a/bittensor/stream.py +++ b/bittensor/stream.py @@ -1,3 +1,5 @@ +import typing + from aiohttp import ClientResponse import bittensor @@ -49,16 +51,24 @@ class BTStreamingResponse(_StreamingResponse): provided by the subclass. """ - def __init__(self, model: BTStreamingResponseModel, **kwargs): + def __init__( + self, + model: BTStreamingResponseModel, + *, + synapse: typing.Optional["StreamingSynapse"] = None, + **kwargs, + ): """ Initializes the BTStreamingResponse with the given token streamer model. Args: model: A BTStreamingResponseModel instance containing the token streamer callable, which is responsible for generating the content of the response. + synapse: The response Synapse to be used to update the response headers etc. **kwargs: Additional keyword arguments passed to the parent StreamingResponse class. 
""" super().__init__(content=iter(()), **kwargs) self.token_streamer = model.token_streamer + self.synapse = synapse async def stream_response(self, send: Send): """ @@ -139,4 +149,4 @@ def create_streaming_response( """ model_instance = BTStreamingResponseModel(token_streamer=token_streamer) - return self.BTStreamingResponse(model_instance) + return self.BTStreamingResponse(model_instance, synapse=self) diff --git a/tests/unit_tests/test_axon.py b/tests/unit_tests/test_axon.py index c25fc2e54e..7ba433a151 100644 --- a/tests/unit_tests/test_axon.py +++ b/tests/unit_tests/test_axon.py @@ -23,20 +23,21 @@ import time from dataclasses import dataclass -from typing import Any +from typing import Any, Optional from unittest import IsolatedAsyncioTestCase from unittest.mock import AsyncMock, MagicMock, patch # Third Party +import fastapi import netaddr - +import pydantic import pytest from starlette.requests import Request from fastapi.testclient import TestClient # Bittensor import bittensor -from bittensor import Synapse, RunException +from bittensor import Synapse, RunException, StreamingSynapse from bittensor.axon import AxonMiddleware from bittensor.axon import axon as Axon from bittensor.utils.axon_utils import allowed_nonce_window_ns, calculate_diff_seconds @@ -538,6 +539,39 @@ def http_client(self, axon): async def no_verify_fn(self, synapse): return + class NonDeterministicHeaders(pydantic.BaseModel): + """ + Helper class to verify headers. + + Size headers are non-deterministic as for example, header_size depends on non-deterministic + processing-time value. 
+ """ + + bt_header_axon_process_time: float = pydantic.Field(gt=0, lt=30) + timeout: float = pydantic.Field(gt=0, lt=30) + header_size: int = pydantic.Field(None, gt=10, lt=400) + total_size: int = pydantic.Field(gt=100, lt=10000) + content_length: Optional[int] = pydantic.Field( + None, alias="content-length", gt=100, lt=10000 + ) + + def assert_headers(self, response, expected_headers): + expected_headers = { + "bt_header_axon_status_code": "200", + "bt_header_axon_status_message": "Success", + **expected_headers, + } + headers = dict(response.headers) + non_deterministic_headers_names = { + field.alias or field_name + for field_name, field in self.NonDeterministicHeaders.model_fields.items() + } + non_deterministic_headers = { + field: headers.pop(field, None) for field in non_deterministic_headers_names + } + assert headers == expected_headers + self.NonDeterministicHeaders.model_validate(non_deterministic_headers) + async def test_unknown_path(self, http_client): response = http_client.get("/no_such_path") assert (response.status_code, response.json()) == ( @@ -563,6 +597,14 @@ async def test_ping__without_verification(self, http_client, axon): assert response.status_code == 200 response_synapse = Synapse(**response.json()) assert response_synapse.axon.status_code == 200 + self.assert_headers( + response, + { + "computed_body_hash": "a7ffc6f8bf1ed76651c14756a061d662f580ff4de43b49fa82d80a4b80f8434a", + "content-type": "application/json", + "name": "Synapse", + }, + ) @pytest.fixture def custom_synapse_cls(self): @@ -571,6 +613,17 @@ class CustomSynapse(Synapse): return CustomSynapse + @pytest.fixture + def streaming_synapse_cls(self): + class CustomStreamingSynapse(StreamingSynapse): + async def process_streaming_response(self, response): + pass + + def extract_response_json(self, response) -> dict: + return {} + + return CustomStreamingSynapse + async def test_synapse__explicitly_set_status_code( self, http_client, axon, custom_synapse_cls, no_verify_axon ): 
@@ -678,3 +731,51 @@ def test_nonce_within_allowed_window(nonce_offset_seconds, expected_result): result = is_nonce_within_allowed_window(synapse_nonce, allowed_window_ns) assert result == expected_result, f"Expected {expected_result} but got {result}" + + @pytest.mark.parametrize( + "forward_fn_return_annotation", + [ + None, + fastapi.Response, + bittensor.StreamingSynapse, + ], + ) + async def test_streaming_synapse( + self, + http_client, + axon, + streaming_synapse_cls, + no_verify_axon, + forward_fn_return_annotation, + ): + tokens = [f"data{i}\n" for i in range(10)] + + async def streamer(send): + for token in tokens: + await send( + { + "type": "http.response.body", + "body": token.encode(), + "more_body": True, + } + ) + await send({"type": "http.response.body", "body": b"", "more_body": False}) + + async def forward_fn(synapse: streaming_synapse_cls): + return synapse.create_streaming_response(token_streamer=streamer) + + if forward_fn_return_annotation is not None: + forward_fn.__annotations__["return"] = forward_fn_return_annotation + + axon.attach(forward_fn) + + response = http_client.post_synapse(streaming_synapse_cls()) + assert (response.status_code, response.text) == (200, "".join(tokens)) + self.assert_headers( + response, + { + "content-type": "text/event-stream", + "name": "CustomStreamingSynapse", + "computed_body_hash": "a7ffc6f8bf1ed76651c14756a061d662f580ff4de43b49fa82d80a4b80f8434a", + }, + )