diff --git a/crates/defguard_setup/src/migration.rs b/crates/defguard_setup/src/migration.rs index 7fccec37c2..3911bb2fa9 100644 --- a/crates/defguard_setup/src/migration.rs +++ b/crates/defguard_setup/src/migration.rs @@ -10,7 +10,7 @@ use axum::{ serve, }; use axum_extra::extract::cookie::Key; -use defguard_common::{VERSION, db::models::Settings}; +use defguard_common::{VERSION, db::models::Settings, types::proxy::ProxyControlMessage}; use defguard_core::{ auth::failed_login::FailedLoginMap, handle_404, @@ -51,16 +51,25 @@ use crate::handlers::{ migration::finish_setup, }; +/// FIXME: This is a workaround which enables us to reuse the same API handlers +/// Helper struct which holds all the event receivers so that channels are not closed. +pub struct MigrationWebapp { + pub router: Router, + _event_rx: mpsc::UnboundedReceiver<ApiEvent>, + _wireguard_rx: broadcast::Receiver<GatewayEvent>, + _proxy_control_rx: mpsc::Receiver<ProxyControlMessage>, +} + pub fn build_migration_webapp( pool: PgPool, version: Version, setup_shutdown_tx: Sender<()>, -) -> Router { +) -> MigrationWebapp { let failed_logins = Arc::new(Mutex::new(FailedLoginMap::new())); let (webhook_tx, webhook_rx) = mpsc::unbounded_channel::<AppEvent>(); - let (event_tx, mut event_rx) = mpsc::unbounded_channel::<ApiEvent>(); - let (wireguard_tx, _wireguard_rx) = broadcast::channel::<GatewayEvent>(64); - let (proxy_control_tx, _proxy_control_rx) = mpsc::channel(32); + let (event_tx, event_rx) = mpsc::unbounded_channel::<ApiEvent>(); + let (wireguard_tx, wireguard_rx) = broadcast::channel::<GatewayEvent>(64); + let (proxy_control_tx, proxy_control_rx) = mpsc::channel(32); let incompatible_components = Arc::new(RwLock::new(IncompatibleComponents::default())); spawn(async move { while event_rx.recv().await.is_some() {} }); let key = Key::from( @@ -73,15 +82,15 @@ pub fn build_migration_webapp( pool.clone(), webhook_tx, webhook_rx, - wireguard_tx, + wireguard_tx.clone(), key, failed_logins.clone(), event_tx, incompatible_components, - proxy_control_tx, + proxy_control_tx.clone(), ); - Router::new() + let router
= Router::new() .route("/", get(index)) .route("/{*path}", get(index)) .route("/fonts/{*path}", get(web_asset)) @@ -137,6 +146,14 @@ pub fn build_migration_webapp( .layer(Extension(version)) .layer(Extension(failed_logins)) .layer(Extension(Arc::new(Mutex::new(Some(setup_shutdown_tx))))) + .layer(Extension(proxy_control_tx)); + + MigrationWebapp { + router, + _event_rx: event_rx, + _wireguard_rx: wireguard_rx, + _proxy_control_rx: proxy_control_rx, + } } #[instrument(skip_all)] @@ -146,11 +163,12 @@ pub async fn run_migration_web_server( http_port: u16, ) -> Result<(), anyhow::Error> { let (setup_shutdown_tx, setup_shutdown_rx) = tokio::sync::oneshot::channel::<()>(); - let setup_webapp = build_migration_webapp( + let migration_webapp = build_migration_webapp( pool.clone(), defguard_version::Version::parse(VERSION)?, setup_shutdown_tx, ); + let router = migration_webapp.router; info!("Starting instance migration web server on port {http_port}"); let addr = SocketAddr::new( @@ -160,7 +178,7 @@ pub async fn run_migration_web_server( let listener = TcpListener::bind(&addr).await?; serve( listener, - setup_webapp.into_make_service_with_connect_info::<SocketAddr>(), + router.into_make_service_with_connect_info::<SocketAddr>(), ) .with_graceful_shutdown(async move { setup_shutdown_rx.await.ok();