From c7b86a11e856c626fc3448e91f0d87c7813ceaa7 Mon Sep 17 00:00:00 2001 From: Stephen Belanger Date: Mon, 29 Sep 2025 22:25:20 +0800 Subject: [PATCH] Make a few more safety improvements --- src/asgi/mod.rs | 45 +++++++++++++++------------------------------ 1 file changed, 15 insertions(+), 30 deletions(-) diff --git a/src/asgi/mod.rs b/src/asgi/mod.rs index 3298fe7..05b8a69 100644 --- a/src/asgi/mod.rs +++ b/src/asgi/mod.rs @@ -3,7 +3,7 @@ use std::{ ffi::CString, fs::{read_dir, read_to_string}, path::{Path, PathBuf}, - sync::{Arc, OnceLock, RwLock, Weak}, + sync::{Arc, Mutex, OnceLock, Weak}, }; #[cfg(target_os = "linux")] @@ -26,19 +26,17 @@ type HttpResponseResult = Result; static FALLBACK_RUNTIME: OnceLock = OnceLock::new(); fn fallback_handle() -> tokio::runtime::Handle { - if let Ok(handle) = tokio::runtime::Handle::try_current() { - handle - } else { + tokio::runtime::Handle::try_current().unwrap_or_else(|_| { // No runtime exists, create a fallback one let rt = FALLBACK_RUNTIME.get_or_init(|| { tokio::runtime::Runtime::new().expect("Failed to create fallback tokio runtime") }); rt.handle().clone() - } + }) } /// Global Python event loop handle storage -static PYTHON_EVENT_LOOP: OnceLock>> = OnceLock::new(); +static PYTHON_EVENT_LOOP: OnceLock>> = OnceLock::new(); mod http; mod http_method; @@ -90,22 +88,16 @@ unsafe impl Sync for EventLoopHandle {} /// Ensure a Python event loop exists and return a handle to it fn ensure_python_event_loop() -> Result, HandlerError> { - let weak_handle = PYTHON_EVENT_LOOP.get_or_init(|| RwLock::new(Weak::new())); + let mut guard = PYTHON_EVENT_LOOP + .get_or_init(|| Mutex::new(Weak::new())) + .lock()?; // Try to upgrade the weak reference - if let Some(handle) = weak_handle.read()?.upgrade() { - return Ok(handle); - } - - // Need write lock to create new handle - let mut guard = weak_handle.write()?; - - // Double-check in case another thread created it if let Some(handle) = guard.upgrade() { return Ok(handle); } - // Create new event loop handle + // Create new handle let new_handle = Arc::new(create_event_loop_handle()?); *guard = Arc::downgrade(&new_handle); @@ -159,16 +151,10 @@ impl Asgi { docroot: Option, app_target: Option, ) -> Result { - // Determine document root - let docroot = PathBuf::from(if let Some(docroot) = docroot { - docroot - } else { - current_dir() - .map(|path| path.to_string_lossy().to_string()) - .map_err(HandlerError::CurrentDirectoryError)? - }); - let target = app_target.unwrap_or_default(); + let docroot = docroot + .map(|d| Ok(PathBuf::from(d))) + .unwrap_or_else(|| current_dir().map_err(HandlerError::CurrentDirectoryError))?; // Get or create shared Python event loop let event_loop_handle = ensure_python_event_loop()?; @@ -181,8 +167,10 @@ impl Asgi { .canonicalize() .map_err(HandlerError::EntrypointNotFoundError)?; - let code = read_to_string(entrypoint).map_err(HandlerError::EntrypointNotFoundError)?; - let code = CString::new(code).map_err(HandlerError::StringCovertError)?; + let code = read_to_string(entrypoint) + .map_err(HandlerError::EntrypointNotFoundError) + .and_then(|s| CString::new(s).map_err(HandlerError::StringCovertError))?; + let file_name = CString::new(format!("{}.py", target.file)).map_err(HandlerError::StringCovertError)?; let module_name = @@ -374,9 +362,6 @@ fn setup_python_paths(py: Python, docroot: &Path) -> PyResult<()> { /// Start a Python thread that runs the event loop forever fn start_python_event_loop_thread(event_loop: PyObject) { - // Initialize Python for this thread - pyo3::prepare_freethreaded_python(); - Python::with_gil(|py| { // Set the event loop for this thread and run it let asyncio = py.import("asyncio")?;