diff --git a/httpstan/app.py b/httpstan/app.py index 65bdff4ed..c704c5b8c 100644 --- a/httpstan/app.py +++ b/httpstan/app.py @@ -7,6 +7,7 @@ import aiohttp.web +import httpstan.pools import httpstan.routes try: @@ -41,5 +42,7 @@ def make_app() -> aiohttp.web.Application: httpstan.routes.setup_routes(app) # startup and shutdown tasks app["operations"] = {} + httpstan.pools.setup_pools(app) app.on_cleanup.append(_warn_unfinished_operations) + app.on_cleanup.append(httpstan.pools.shutdown_pools) return app diff --git a/httpstan/pools.py b/httpstan/pools.py new file mode 100644 index 000000000..b0102d46e --- /dev/null +++ b/httpstan/pools.py @@ -0,0 +1,47 @@ +import concurrent.futures +import multiprocessing as mp +import signal + +import aiohttp.web + + +def init_call_worker() -> None: + signal.signal(signal.SIGINT, signal.SIG_IGN) # ignore KeyboardInterrupt + + +def setup_pools(app: aiohttp.web.Application) -> None: + """Create any process or thread pools needed by the application. + + The pools are not created immediately but lazily, in case the feature + that uses them is never needed. That is why the pools are represented by a + function instead of the pool executor object itself. + + """ + fit_executor = None + + def create_fit_executor(shutdown=False): + nonlocal fit_executor + + if shutdown: + if fit_executor is None: + return + + fit_executor.shutdown() + return + + if fit_executor is not None: + return fit_executor + + # Use `get_context` to get a package-specific multiprocessing context. + # See "Contexts and start methods" in the `multiprocessing` docs for details. 
+ fit_executor = concurrent.futures.ProcessPoolExecutor( + mp_context=mp.get_context("fork"), initializer=init_call_worker + ) + + return fit_executor + + app["create_fit_executor"] = create_fit_executor + + +async def shutdown_pools(app: aiohttp.web.Application) -> None: + app["create_fit_executor"](shutdown=True) diff --git a/httpstan/services_stub.py b/httpstan/services_stub.py index 2fc622797..8cd892f3d 100644 --- a/httpstan/services_stub.py +++ b/httpstan/services_stub.py @@ -8,14 +8,11 @@ """ import asyncio import collections -import concurrent.futures import functools import io import logging -import multiprocessing as mp import os import select -import signal import socket import tempfile import typing @@ -23,17 +20,9 @@ import httpstan.cache import httpstan.models -import httpstan.services.arguments as arguments from httpstan.config import HTTPSTAN_DEBUG +from httpstan.services import arguments - -# Use `get_context` to get a package-specific multiprocessing context. -# See "Contexts and start methods" in the `multiprocessing` docs for details. -def init_worker() -> None: - signal.signal(signal.SIGINT, signal.SIG_IGN) # ignore KeyboardInterrupt - - -executor = concurrent.futures.ProcessPoolExecutor(mp_context=mp.get_context("fork"), initializer=init_worker) logger = logging.getLogger("httpstan") @@ -59,6 +48,7 @@ async def call( function_name: str, model_name: str, fit_name: str, + executor, logger_callback: typing.Optional[typing.Callable] = None, **kwargs: dict, ) -> None: diff --git a/httpstan/views.py b/httpstan/views.py index 669783235..feecdffd2 100644 --- a/httpstan/views.py +++ b/httpstan/views.py @@ -3,13 +3,13 @@ Handlers are separated from the endpoint names. Endpoints are defined in `httpstan.routes`. 
""" -import asyncio import functools import gzip import http import logging import re import traceback +from types import CoroutineType from typing import Optional, Sequence, cast import aiohttp.web @@ -364,7 +364,7 @@ async def handle_create_fit(request: aiohttp.web.Request) -> aiohttp.web.Respons request.app["operations"][operation_name] = operation_dict return aiohttp.web.json_response(operation_dict, status=201) - def _services_call_done(operation: dict, future: asyncio.Future) -> None: + async def _services_call_done(operation: dict, coroutine: CoroutineType) -> None: """Called when services call (i.e., an operation) is done. This needs to handle both successful and exception-raising calls. @@ -374,11 +374,12 @@ def _services_call_done(operation: dict, future: asyncio.Future) -> None: future: Finished future """ - # either the call succeeded or it raised an exception. - operation["done"] = True - exc = future.exception() - if exc: + try: + await coroutine + logger.info("Operation `%s` finished.", operation["name"]) + operation["result"] = schemas.Fit().load(operation["metadata"]["fit"]) + except Exception as exc: # e.g., "hmc_nuts_diag_e_adapt_wrapper() got an unexpected keyword argument, ..." # e.g., dimension errors in variable declarations # e.g., initialization failed @@ -394,9 +395,9 @@ def _services_call_done(operation: dict, future: asyncio.Future) -> None: httpstan.cache.delete_fit(operation["metadata"]["fit"]["name"]) except KeyError: pass - else: - logger.info(f"Operation `{operation['name']}` finished.") - operation["result"] = schemas.Fit().load(operation["metadata"]["fit"]) + finally: + # either the call succeeded or it raised an exception. 
+ operation["done"] = True operation_name = f'operations/{name.split("/")[-1]}' operation_dict = schemas.Operation().load( @@ -414,12 +415,16 @@ def logger_callback(operation: dict, message: bytes) -> None: operation["metadata"]["progress"] = iteration_info_re.findall(message).pop().decode() logger_callback_partial = functools.partial(logger_callback, operation_dict) - task = asyncio.create_task( - services_stub.call( - function, model_name, operation_dict["metadata"]["fit"]["name"], logger_callback_partial, **args - ) + + call = services_stub.call( + function, + model_name, + operation_dict["metadata"]["fit"]["name"], + request.app["create_fit_executor"](), + logger_callback_partial, + **args, ) - task.add_done_callback(functools.partial(_services_call_done, operation_dict)) + await _services_call_done(operation_dict, call) request.app["operations"][operation_name] = operation_dict return aiohttp.web.json_response(operation_dict, status=201)