Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 53 additions & 1 deletion tests/unit_tests/test_axon.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
import asyncio
import contextlib
import re
import threading
import time
from dataclasses import dataclass
from typing import Any, Optional, Tuple
from unittest import IsolatedAsyncioTestCase
from unittest.mock import AsyncMock, MagicMock, patch

import aiohttp
import fastapi
import netaddr
import pydantic
import pytest
import uvicorn
from fastapi.testclient import TestClient
from starlette.requests import Request

from bittensor.core.axon import AxonMiddleware, Axon
from bittensor.core.axon import Axon, AxonMiddleware, FastAPIThreadedServer
from bittensor.core.errors import RunException
from bittensor.core.settings import version_as_int
from bittensor.core.stream import StreamingSynapse
Expand Down Expand Up @@ -765,3 +770,50 @@ async def forward_fn(synapse: streaming_synapse_cls):
"computed_body_hash": "a7ffc6f8bf1ed76651c14756a061d662f580ff4de43b49fa82d80a4b80f8434a",
},
)


@pytest.mark.asyncio
async def test_threaded_fastapi():
server_started = threading.Event()
server_stopped = threading.Event()

@contextlib.asynccontextmanager
async def lifespan(app):
server_started.set()
yield
server_stopped.set()

app = fastapi.FastAPI(
lifespan=lifespan,
)
app.get("/")(lambda: "Hello World")

server = FastAPIThreadedServer(
uvicorn.Config(app, loop="none"),
)
server.start()

server_started.wait(3.0)

async def wait_for_server():
while not (server.started or server_stopped.is_set()):
await asyncio.sleep(1.0)

await asyncio.wait_for(wait_for_server(), 7.0)

assert server.is_running is True

async with aiohttp.ClientSession(
base_url="http://127.0.0.1:8000",
) as session:
async with session.get("/") as response:
assert await response.text() == '"Hello World"'

server.stop()

assert server.should_exit is True

server_stopped.wait()

with pytest.raises(aiohttp.ClientConnectorError):
await session.get("/")
Loading