diff --git a/Cargo.lock b/Cargo.lock index c87a0e3c3d..98016019e7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1098,6 +1098,7 @@ dependencies = [ "defguard_event_logger", "defguard_event_router", "defguard_mail", + "defguard_proxy_manager", "defguard_session_manager", "defguard_version", "dotenvy", @@ -1280,6 +1281,29 @@ dependencies = [ "tonic-prost-build", ] +[[package]] +name = "defguard_proxy_manager" +version = "0.0.0" +dependencies = [ + "anyhow", + "axum", + "chrono", + "defguard_common", + "defguard_core", + "defguard_mail", + "defguard_proto", + "defguard_version", + "openidconnect", + "reqwest", + "semver", + "sqlx", + "thiserror 2.0.17", + "tokio", + "tokio-stream", + "tonic", + "tracing", +] + [[package]] name = "defguard_session_manager" version = "0.0.0" diff --git a/Cargo.toml b/Cargo.toml index 29c703a168..07b2e3a0c5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,7 @@ defguard_event_logger = { path = "./crates/defguard_event_logger", version = "0. defguard_event_router = { path = "./crates/defguard_event_router", version = "0.0.0" } defguard_mail = { path = "./crates/defguard_mail", version = "0.0.0" } defguard_proto = { path = "./crates/defguard_proto", version = "0.0.0" } +defguard_proxy_manager = { path = "./crates/defguard_proxy_manager", version = "0.0.0" } defguard_session_manager = { path = "./crates/defguard_session_manager", version = "0.0.0" } defguard_version = { path = "./crates/defguard_version", version = "0.0.0" } defguard_web_ui = { path = "./crates/defguard_web_ui", version = "0.0.0" } diff --git a/crates/defguard/Cargo.toml b/crates/defguard/Cargo.toml index 0dc170ff56..edc03cd2eb 100644 --- a/crates/defguard/Cargo.toml +++ b/crates/defguard/Cargo.toml @@ -14,6 +14,7 @@ defguard_core = { workspace = true } defguard_event_router = { workspace = true } defguard_event_logger = { workspace = true } defguard_mail = { workspace = true } +defguard_proxy_manager = { workspace = true } defguard_session_manager = { workspace = true } defguard_version = { workspace = true } diff --git a/crates/defguard/src/main.rs b/crates/defguard/src/main.rs index 087ac62bdf..ea4023dabc 100644 --- a/crates/defguard/src/main.rs +++ b/crates/defguard/src/main.rs @@ -29,7 +29,7 @@ use defguard_core::{ grpc::{ WorkerState, gateway::{client_state::ClientMap, events::GatewayEvent, map::GatewayMap}, - run_grpc_bidi_stream, run_grpc_server, + run_grpc_server, }, init_dev_env, init_vpn_location, run_web_server, utility_thread::run_utility_thread, @@ -40,6 +40,7 @@ use defguard_core::{ use defguard_event_logger::{message::EventLoggerMessage, run_event_logger}; use defguard_event_router::{RouterReceiverSet, run_event_router}; use defguard_mail::{Mail, run_mail_handler}; +use defguard_proxy_manager::{ProxyOrchestrator, ProxyTxSet}; // use defguard_session_manager::run_session_manager; use secrecy::ExposeSecret; use tokio::sync::{broadcast, mpsc::unbounded_channel}; @@ -158,15 +159,13 @@ async fn main() -> Result<(), anyhow::Error> { } } + let proxy_tx = ProxyTxSet::new(wireguard_tx.clone(), mail_tx.clone(), bidi_event_tx.clone()); + let proxy_orchestrator = + ProxyOrchestrator::new(pool.clone(), proxy_tx, Arc::clone(&incompatible_components)); + // run services tokio::select! { - res = run_grpc_bidi_stream( - pool.clone(), - wireguard_tx.clone(), - mail_tx.clone(), - bidi_event_tx, - Arc::clone(&incompatible_components), - ), if config.proxy_url.is_some() => error!("Proxy gRPC stream returned early: {res:?}"), + res = proxy_orchestrator.run(&config.proxy_url) => error!("ProxyOrchestrator returned early: {res:?}"), res = run_grpc_server( Arc::clone(&worker_state), pool.clone(), diff --git a/crates/defguard_core/src/db/models/enrollment.rs b/crates/defguard_core/src/db/models/enrollment.rs index de33f97f51..129610fbea 100644 --- a/crates/defguard_core/src/db/models/enrollment.rs +++ b/crates/defguard_core/src/db/models/enrollment.rs @@ -4,14 +4,18 @@ use defguard_common::{ config::server_config, db::{ Id, - models::{Settings, user::User}, + models::{Settings, settings::defaults::WELCOME_EMAIL_SUBJECT, user::User}, }, random::gen_alphanumeric, }; -use defguard_mail::templates::{self, TemplateError, safe_tera}; -use sqlx::{Error as SqlxError, PgConnection, PgExecutor, PgPool, query, query_as}; +use defguard_mail::{ + Mail, + templates::{self, TemplateError, safe_tera}, +}; +use sqlx::{Error as SqlxError, PgConnection, PgExecutor, PgPool, Transaction, query, query_as}; use tera::Context; use thiserror::Error; +use tokio::sync::mpsc::UnboundedSender; use tonic::{Code, Status}; pub static ENROLLMENT_TOKEN_TYPE: &str = "ENROLLMENT"; @@ -380,6 +384,80 @@ impl Token { device_info, )?) } + + // Send configured welcome email to user after finishing enrollment + pub async fn send_welcome_email( + &self, + transaction: &mut Transaction<'_, sqlx::Postgres>, + mail_tx: &UnboundedSender, + user: &User, + settings: &Settings, + ip_address: &str, + device_info: Option<&str>, + ) -> Result<(), TokenError> { + debug!("Sending welcome mail to {}", user.username); + let mail = Mail { + to: user.email.clone(), + subject: settings + .enrollment_welcome_email_subject + .clone() + .unwrap_or_else(|| WELCOME_EMAIL_SUBJECT.to_string()), + content: self + .get_welcome_email_content(&mut *transaction, ip_address, device_info) + .await?, + attachments: Vec::new(), + result_tx: None, + }; + match mail_tx.send(mail) { + Ok(()) => { + info!("Sent enrollment welcome mail to {}", user.username); + Ok(()) + } + Err(err) => { + error!("Error sending welcome mail: {err}"); + Err(TokenError::NotificationError(err.to_string())) + } + } + } + + // Notify admin that a user has completed enrollment + pub fn send_admin_notification( + mail_tx: &UnboundedSender, + admin: &User, + user: &User, + ip_address: &str, + device_info: Option<&str>, + ) -> Result<(), TokenError> { + debug!( + "Sending enrollment success notification for user {} to {}", + user.username, admin.username + ); + let mail = Mail { + to: admin.email.clone(), + subject: "[defguard] User enrollment completed".into(), + content: templates::enrollment_admin_notification( + &user.clone().into(), + &admin.clone().into(), + ip_address, + device_info, + )?, + attachments: Vec::new(), + result_tx: None, + }; + match mail_tx.send(mail) { + Ok(()) => { + info!( + "Sent enrollment success notification for user {} to {}", + user.username, admin.username + ); + Ok(()) + } + Err(err) => { + error!("Error sending welcome mail: {err}"); + Err(TokenError::NotificationError(err.to_string())) + } + } + } } pub fn enrollment_welcome_message(settings: &Settings) -> Result { diff --git a/crates/defguard_core/src/enterprise/db/models/openid_provider.rs b/crates/defguard_core/src/enterprise/db/models/openid_provider.rs index 7a3ef88604..575679ab3c 100644 --- a/crates/defguard_core/src/enterprise/db/models/openid_provider.rs +++ b/crates/defguard_core/src/enterprise/db/models/openid_provider.rs @@ -208,10 +208,7 @@ impl OpenIdProvider { } impl OpenIdProvider { - pub(crate) async fn find_by_name<'e, E>( - executor: E, - name: &str, - ) -> Result, SqlxError> + pub async fn find_by_name<'e, E>(executor: E, name: &str) -> Result, SqlxError> where E: PgExecutor<'e>, { @@ -230,7 +227,7 @@ impl OpenIdProvider { .await } - pub(crate) async fn get_current<'e, E>(executor: E) -> Result, SqlxError> + pub async fn get_current<'e, E>(executor: E) -> Result, SqlxError> where E: PgExecutor<'e>, { diff --git a/crates/defguard_core/src/enterprise/directory_sync/mod.rs b/crates/defguard_core/src/enterprise/directory_sync/mod.rs index 3c770b61bc..236db1b2c8 100644 --- a/crates/defguard_core/src/enterprise/directory_sync/mod.rs +++ b/crates/defguard_core/src/enterprise/directory_sync/mod.rs @@ -409,7 +409,7 @@ pub(crate) async fn test_directory_sync_connection( } /// Sync user groups with the directory if directory sync is enabled and configured, skip otherwise -pub(crate) async fn sync_user_groups_if_configured( +pub async fn sync_user_groups_if_configured( user: &User, pool: &PgPool, wg_tx: &Sender, diff --git a/crates/defguard_core/src/enterprise/grpc/desktop_client_mfa.rs b/crates/defguard_core/src/enterprise/grpc/desktop_client_mfa.rs index b76c9170cf..fc975e5425 100644 --- a/crates/defguard_core/src/enterprise/grpc/desktop_client_mfa.rs +++ b/crates/defguard_core/src/enterprise/grpc/desktop_client_mfa.rs @@ -10,7 +10,7 @@ use crate::{ }, events::{BidiRequestContext, BidiStreamEvent, BidiStreamEventType, DesktopClientMfaEvent}, grpc::{ - client_mfa::{ClientLoginSession, ClientMfaServer}, + proxy::client_mfa::{ClientLoginSession, ClientMfaServer}, utils::parse_client_ip_agent, }, }; diff --git a/crates/defguard_core/src/enterprise/handlers/openid_login.rs b/crates/defguard_core/src/enterprise/handlers/openid_login.rs index b17860caa2..855e52695d 100644 --- a/crates/defguard_core/src/enterprise/handlers/openid_login.rs +++ b/crates/defguard_core/src/enterprise/handlers/openid_login.rs @@ -31,7 +31,7 @@ static CSRF_COOKIE_NAME: &str = "csrf"; static NONCE_COOKIE_NAME: &str = "nonce"; // The select_account prompt is not supported by all providers, most notably not by JumpCloud. // Currently it's only enabled for Google, as it was tested to work there. -pub(crate) const SELECT_ACCOUNT_SUPPORTED_PROVIDERS: &[&str] = &["Google"]; +pub const SELECT_ACCOUNT_SUPPORTED_PROVIDERS: &[&str] = &["Google"]; use super::LicenseInfo; use crate::{ @@ -126,7 +126,7 @@ async fn get_provider_metadata(url: &str) -> Result) -> CsrfToken { +pub fn build_state(state_data: Option) -> CsrfToken { let csrf_token = CsrfToken::new_random(); if let Some(data) = state_data { let combined = format!("{}.{data}", csrf_token.secret()); @@ -155,7 +155,7 @@ pub(crate) fn extract_state_data(state: &str) -> Option { /// Build OpenID Connect client. /// `url`: redirect/callback URL -pub(crate) async fn make_oidc_client( +pub async fn make_oidc_client( url: Url, provider: &OpenIdProvider, ) -> Result< @@ -186,7 +186,7 @@ pub(crate) async fn make_oidc_client( } /// Get or create `User` from OpenID claims. -pub(crate) async fn user_from_claims( +pub async fn user_from_claims( pool: &PgPool, nonce: Nonce, code: AuthorizationCode, diff --git a/crates/defguard_core/src/enterprise/ldap/utils.rs b/crates/defguard_core/src/enterprise/ldap/utils.rs index 066e0dbbae..2ae32db260 100644 --- a/crates/defguard_core/src/enterprise/ldap/utils.rs +++ b/crates/defguard_core/src/enterprise/ldap/utils.rs @@ -57,7 +57,7 @@ pub(crate) async fn login_through_ldap( } /// Convenience wrapper around [`ldap_update_users_state`] to update a single user. -pub(crate) async fn ldap_update_user_state(user: &mut User, pool: &PgPool) { +pub async fn ldap_update_user_state(user: &mut User, pool: &PgPool) { let vec = vec![user]; Box::pin(ldap_update_users_state(vec, pool)).await; } @@ -77,7 +77,7 @@ pub(crate) async fn ldap_update_users_state(users: Vec<&mut User>, pool: &Pg /// This will set the `ldap_pass_randomized` field to `true` in the user. /// /// If the user already exists, the creation will be skipped. -pub(crate) async fn ldap_add_user(user: &mut User, password: Option<&str>, pool: &PgPool) { +pub async fn ldap_add_user(user: &mut User, password: Option<&str>, pool: &PgPool) { let _: Result<(), LdapError> = with_ldap_status(pool, async { debug!("Creating user {user} in LDAP"); if !ldap_sync_allowed_for_user(user, pool).await? { @@ -273,7 +273,7 @@ pub(crate) async fn ldap_remove_users_from_groups( .await; } -pub(crate) async fn ldap_change_password(user: &mut User, password: &str, pool: &PgPool) { +pub async fn ldap_change_password(user: &mut User, password: &str, pool: &PgPool) { let _: Result<(), LdapError> = with_ldap_status(pool, async { debug!("Changing password for user {user} in LDAP"); if !ldap_sync_allowed_for_user(user, pool).await? { diff --git a/crates/defguard_core/src/enterprise/mod.rs b/crates/defguard_core/src/enterprise/mod.rs index 4d16e6a429..7cc7699b1c 100644 --- a/crates/defguard_core/src/enterprise/mod.rs +++ b/crates/defguard_core/src/enterprise/mod.rs @@ -16,12 +16,12 @@ use limits::get_counts; use crate::enterprise::license::LicenseTier; /// Helper function to gate features which require a base license (Team or Business tier) -pub(crate) fn is_business_license_active() -> bool { +pub fn is_business_license_active() -> bool { is_license_tier_active(LicenseTier::Business) } /// Helper function to gate features which require an Enterprise tier license -pub(crate) fn is_enterprise_license_active() -> bool { +pub fn is_enterprise_license_active() -> bool { is_license_tier_active(LicenseTier::Enterprise) } diff --git a/crates/defguard_core/src/grpc/client_version.rs b/crates/defguard_core/src/grpc/client_version.rs index d8bc7be95e..af2c03cf0b 100644 --- a/crates/defguard_core/src/grpc/client_version.rs +++ b/crates/defguard_core/src/grpc/client_version.rs @@ -46,7 +46,7 @@ pub(crate) fn parse_client_version_platform( /// Represents a client feature that may have minimum version and OS family requirements. #[derive(Debug)] -pub(crate) enum ClientFeature { +pub enum ClientFeature { ServiceLocations, } @@ -63,7 +63,7 @@ impl ClientFeature { } } - pub(crate) fn is_supported_by_device(&self, info: Option<&DeviceInfo>) -> bool { + pub fn is_supported_by_device(&self, info: Option<&DeviceInfo>) -> bool { let (version, platform) = parse_client_version_platform(info); // No minimum version = matches all diff --git a/crates/defguard_core/src/grpc/mod.rs b/crates/defguard_core/src/grpc/mod.rs index 7c1bc5fc9e..b1758a3aaa 100644 --- a/crates/defguard_core/src/grpc/mod.rs +++ b/crates/defguard_core/src/grpc/mod.rs @@ -1,84 +1,51 @@ use std::{ collections::hash_map::HashMap, - fs::read_to_string, net::{IpAddr, Ipv4Addr, SocketAddr}, sync::{Arc, Mutex, RwLock}, time::{Duration, Instant}, }; -use axum::http::Uri; +use reqwest::Url; +use serde::Serialize; +use sqlx::PgPool; +use tokio::sync::{broadcast::Sender, mpsc::UnboundedSender}; +use tonic::transport::{Identity, Server, ServerTlsConfig, server::Router}; +use tower::ServiceBuilder; + use defguard_common::{ VERSION, auth::claims::ClaimsType, db::{Id, models::Settings}, }; use defguard_mail::Mail; -use defguard_version::{ - ComponentInfo, DefguardComponent, Version, client::ClientVersionInterceptor, - get_tracing_variables, server::DefguardVersionLayer, -}; -use openidconnect::{AuthorizationCode, Nonce, Scope, core::CoreAuthenticationFlow}; -use reqwest::Url; -use serde::Serialize; -use sqlx::PgPool; -use tokio::{ - sync::{ - broadcast::Sender, - mpsc::{self, UnboundedSender}, - }, - time::sleep, -}; -use tokio_stream::wrappers::UnboundedReceiverStream; -use tonic::{ - Code, Streaming, - transport::{ - Certificate, ClientTlsConfig, Endpoint, Identity, Server, ServerTlsConfig, server::Router, - }, -}; -use tower::ServiceBuilder; +use defguard_version::{Version, server::DefguardVersionLayer}; use self::{ - auth::AuthServer, client_mfa::ClientMfaServer, enrollment::EnrollmentServer, - gateway::GatewayServer, interceptor::JwtInterceptor, password_reset::PasswordResetServer, - worker::WorkerServer, + auth::AuthServer, gateway::GatewayServer, interceptor::JwtInterceptor, worker::WorkerServer, }; pub use crate::version::MIN_GATEWAY_VERSION; use crate::{ auth::failed_login::FailedLoginMap, - db::{ - AppEvent, - models::enrollment::{ENROLLMENT_TOKEN_TYPE, Token}, - }, - enrollment_management::clear_unused_enrollment_tokens, + db::AppEvent, enterprise::{ db::models::{ enterprise_settings::{ClientTrafficPolicy, EnterpriseSettings}, openid_provider::OpenIdProvider, }, - directory_sync::sync_user_groups_if_configured, - grpc::polling::PollingServer, - handlers::openid_login::{ - SELECT_ACCOUNT_SUPPORTED_PROVIDERS, build_state, make_oidc_client, user_from_claims, - }, is_business_license_active, - ldap::utils::ldap_update_user_state, }, - events::{BidiStreamEvent, GrpcEvent}, + events::GrpcEvent, grpc::gateway::{client_state::ClientMap, events::GatewayEvent, map::GatewayMap}, server_config, - version::{IncompatibleComponents, IncompatibleProxyData, is_proxy_version_supported}, + version::IncompatibleComponents, }; -static VERSION_ZERO: Version = Version::new(0, 0, 0); - mod auth; -pub(crate) mod client_mfa; pub mod client_version; -pub mod enrollment; pub mod gateway; mod interceptor; -pub mod password_reset; -pub(crate) mod utils; +pub mod proxy; +pub mod utils; pub mod worker; pub mod proto { @@ -92,10 +59,6 @@ pub mod proto { use defguard_proto::{ auth::auth_service_server::AuthServiceServer, gateway::gateway_service_server::GatewayServiceServer, - proxy::{ - AuthCallbackResponse, AuthInfoResponse, CoreError, CoreRequest, CoreResponse, core_request, - core_response, proxy_client::ProxyClient, - }, worker::worker_service_server::WorkerServiceServer, }; @@ -107,554 +70,6 @@ pub static HOSTNAME_HEADER: &str = "hostname"; const TEN_SECS: Duration = Duration::from_secs(10); -struct ProxyMessageLoopContext<'a> { - pool: PgPool, - tx: UnboundedSender, - wireguard_tx: Sender, - resp_stream: &'a mut Streaming, - enrollment_server: &'a mut EnrollmentServer, - password_reset_server: &'a mut PasswordResetServer, - client_mfa_server: &'a mut ClientMfaServer, - polling_server: &'a mut PollingServer, - endpoint_uri: &'a Uri, -} - -#[instrument(skip_all)] -async fn handle_proxy_message_loop( - context: ProxyMessageLoopContext<'_>, -) -> Result<(), anyhow::Error> { - let pool = context.pool.clone(); - 'message: loop { - match context.resp_stream.message().await { - Ok(None) => { - info!("stream was closed by the sender"); - break 'message; - } - Ok(Some(received)) => { - debug!("Received message from proxy; ID={}", received.id); - let payload = match received.payload { - // rpc CodeMfaSetupStart return (CodeMfaSetupStartResponse) - Some(core_request::Payload::CodeMfaSetupStart(request)) => { - match context - .enrollment_server - .register_code_mfa_start(request) - .await - { - Ok(response) => { - Some(core_response::Payload::CodeMfaSetupStartResponse(response)) - } - Err(err) => { - error!("Register mfa start error {err}"); - Some(core_response::Payload::CoreError(err.into())) - } - } - } - // rpc CodeMfaSetupFinish return (CodeMfaSetupFinishResponse) - Some(core_request::Payload::CodeMfaSetupFinish(request)) => { - match context - .enrollment_server - .register_code_mfa_finish(request) - .await - { - Ok(response) => { - Some(core_response::Payload::CodeMfaSetupFinishResponse(response)) - } - Err(err) => { - error!("Register MFA finish error {err}"); - Some(core_response::Payload::CoreError(err.into())) - } - } - } - // rpc ClientMfaTokenValidation return (ClientMfaTokenValidationResponse) - Some(core_request::Payload::ClientMfaTokenValidation(request)) => { - match context.client_mfa_server.validate_mfa_token(request).await { - Ok(response_payload) => Some( - core_response::Payload::ClientMfaTokenValidation(response_payload), - ), - Err(err) => { - error!("Client MFA validate token error {err}"); - Some(core_response::Payload::CoreError(err.into())) - } - } - } - // rpc RegisterMobileAuth (RegisterMobileAuthRequest) return (google.protobuf.Empty) - Some(core_request::Payload::RegisterMobileAuth(request)) => { - match context - .enrollment_server - .register_mobile_auth(request) - .await - { - Ok(()) => Some(core_response::Payload::Empty(())), - Err(err) => { - error!("Register mobile auth error {err}"); - Some(core_response::Payload::CoreError(err.into())) - } - } - } - // rpc StartEnrollment (EnrollmentStartRequest) returns (EnrollmentStartResponse) - Some(core_request::Payload::EnrollmentStart(request)) => { - match context - .enrollment_server - .start_enrollment(request, received.device_info) - .await - { - Ok(response_payload) => { - Some(core_response::Payload::EnrollmentStart(response_payload)) - } - Err(err) => { - error!("start enrollment error {err}"); - Some(core_response::Payload::CoreError(err.into())) - } - } - } - // rpc ActivateUser (ActivateUserRequest) returns (google.protobuf.Empty) - Some(core_request::Payload::ActivateUser(request)) => { - match context - .enrollment_server - .activate_user(request, received.device_info) - .await - { - Ok(()) => Some(core_response::Payload::Empty(())), - Err(err) => { - error!("activate user error {err}"); - Some(core_response::Payload::CoreError(err.into())) - } - } - } - // rpc CreateDevice (NewDevice) returns (DeviceConfigResponse) - Some(core_request::Payload::NewDevice(request)) => { - match context - .enrollment_server - .create_device(request, received.device_info) - .await - { - Ok(response_payload) => { - Some(core_response::Payload::DeviceConfig(response_payload)) - } - Err(err) => { - error!("create device error {err}"); - Some(core_response::Payload::CoreError(err.into())) - } - } - } - // rpc GetNetworkInfo (ExistingDevice) returns (DeviceConfigResponse) - Some(core_request::Payload::ExistingDevice(request)) => { - match context - .enrollment_server - .get_network_info(request, received.device_info) - .await - { - Ok(response_payload) => { - Some(core_response::Payload::DeviceConfig(response_payload)) - } - Err(err) => { - error!("get network info error {err}"); - Some(core_response::Payload::CoreError(err.into())) - } - } - } - // rpc RequestPasswordReset (PasswordResetInitializeRequest) returns (google.protobuf.Empty) - Some(core_request::Payload::PasswordResetInit(request)) => { - match context - .password_reset_server - .request_password_reset(request, received.device_info) - .await - { - Ok(()) => Some(core_response::Payload::Empty(())), - Err(err) => { - error!("password reset init error {err}"); - Some(core_response::Payload::CoreError(err.into())) - } - } - } - // rpc StartPasswordReset (PasswordResetStartRequest) returns (PasswordResetStartResponse) - Some(core_request::Payload::PasswordResetStart(request)) => { - match context - .password_reset_server - .start_password_reset(request, received.device_info) - .await - { - Ok(response_payload) => { - Some(core_response::Payload::PasswordResetStart(response_payload)) - } - Err(err) => { - error!("password reset start error {err}"); - Some(core_response::Payload::CoreError(err.into())) - } - } - } - // rpc ResetPassword (PasswordResetRequest) returns (google.protobuf.Empty) - Some(core_request::Payload::PasswordReset(request)) => { - match context - .password_reset_server - .reset_password(request, received.device_info) - .await - { - Ok(()) => Some(core_response::Payload::Empty(())), - Err(err) => { - error!("password reset error {err}"); - Some(core_response::Payload::CoreError(err.into())) - } - } - } - // rpc ClientMfaStart (ClientMfaStartRequest) returns (ClientMfaStartResponse) - Some(core_request::Payload::ClientMfaStart(request)) => { - match context - .client_mfa_server - .start_client_mfa_login(request) - .await - { - Ok(response_payload) => { - Some(core_response::Payload::ClientMfaStart(response_payload)) - } - Err(err) => { - error!("client MFA start error {err}"); - Some(core_response::Payload::CoreError(err.into())) - } - } - } - // rpc ClientMfaFinish (ClientMfaFinishRequest) returns (ClientMfaFinishResponse) - Some(core_request::Payload::ClientMfaFinish(request)) => { - match context - .client_mfa_server - .finish_client_mfa_login(request, received.device_info) - .await - { - Ok(response_payload) => { - Some(core_response::Payload::ClientMfaFinish(response_payload)) - } - Err(err) => { - match err.code() { - Code::FailedPrecondition => { - // User not yet done with OIDC authentication. Don't log it - // as an error. - debug!("Client MFA finish error: {err}"); - } - _ => { - // Log other errors as errors. - error!("Client MFA finish error: {err}"); - } - } - Some(core_response::Payload::CoreError(err.into())) - } - } - } - Some(core_request::Payload::ClientMfaOidcAuthenticate(request)) => { - match context - .client_mfa_server - .auth_mfa_session_with_oidc(request, received.device_info) - .await - { - Ok(()) => Some(core_response::Payload::Empty(())), - Err(err) => { - error!("client MFA OIDC authenticate error {err}"); - Some(core_response::Payload::CoreError(err.into())) - } - } - } - // rpc LocationInfo (LocationInfoRequest) returns (LocationInfoResponse) - Some(core_request::Payload::InstanceInfo(request)) => { - match context - .polling_server - .info(request, received.device_info) - .await - { - Ok(response_payload) => { - Some(core_response::Payload::InstanceInfo(response_payload)) - } - Err(err) => { - if Code::FailedPrecondition == err.code() { - // Ignore the case when we are not enterprise but the client is - // trying to fetch the instance config, - // to avoid spamming the logs with misleading errors. - - debug!( - "A client tried to fetch the instance config, but we are \ - not enterprise." - ); - Some(core_response::Payload::CoreError(err.into())) - } else { - error!("Instance info error {err}"); - Some(core_response::Payload::CoreError(err.into())) - } - } - } - } - Some(core_request::Payload::AuthInfo(request)) => { - if !is_business_license_active() { - warn!("Enterprise license required"); - Some(core_response::Payload::CoreError(CoreError { - status_code: Code::FailedPrecondition as i32, - message: "no valid license".into(), - })) - } else if let Ok(redirect_url) = Url::parse(&request.redirect_url) { - if let Some(provider) = OpenIdProvider::get_current(&pool).await? { - match make_oidc_client(redirect_url, &provider).await { - Ok((_client_id, client)) => { - let mut authorize_url_builder = client - .authorize_url( - CoreAuthenticationFlow::AuthorizationCode, - || build_state(request.state), - Nonce::new_random, - ) - .add_scope(Scope::new("email".to_string())) - .add_scope(Scope::new("profile".to_string())); - - if SELECT_ACCOUNT_SUPPORTED_PROVIDERS - .iter() - .all(|p| p.eq_ignore_ascii_case(&provider.name)) - { - authorize_url_builder = authorize_url_builder - .add_prompt( - openidconnect::core::CoreAuthPrompt::SelectAccount, - ); - } - let (url, csrf_token, nonce) = authorize_url_builder.url(); - - Some(core_response::Payload::AuthInfo(AuthInfoResponse { - url: url.into(), - csrf_token: csrf_token.secret().to_owned(), - nonce: nonce.secret().to_owned(), - button_display_name: provider.display_name, - })) - } - Err(err) => { - error!( - "Failed to setup external OIDC provider client: {err}" - ); - Some(core_response::Payload::CoreError(CoreError { - status_code: Code::Internal as i32, - message: "failed to build OIDC client".into(), - })) - } - } - } else { - error!("Failed to get current OpenID provider"); - Some(core_response::Payload::CoreError(CoreError { - status_code: Code::NotFound as i32, - message: "failed to get current OpenID provider".into(), - })) - } - } else { - error!( - "Invalid redirect URL in authentication info request: {}", - request.redirect_url - ); - Some(core_response::Payload::CoreError(CoreError { - status_code: Code::Internal as i32, - message: "invalid redirect URL".into(), - })) - } - } - Some(core_request::Payload::AuthCallback(request)) => { - match Url::parse(&request.callback_url) { - Ok(callback_url) => { - let code = AuthorizationCode::new(request.code); - match user_from_claims( - &pool, - Nonce::new(request.nonce), - code, - callback_url, - ) - .await - { - Ok(mut user) => { - clear_unused_enrollment_tokens(&user, &pool).await?; - if let Err(err) = sync_user_groups_if_configured( - &user, - &pool, - &context.wireguard_tx, - ) - .await - { - error!( - "Failed to sync user groups for user {} with the \ - directory while the user was logging in through an \ - external provider: {err}", - user.username, - ); - } else { - ldap_update_user_state(&mut user, &pool).await; - } - debug!("Cleared unused tokens for {}.", user.username); - debug!( - "Creating a new desktop activation token for user {} \ - as a result of proxy OpenID auth callback.", - user.username - ); - let config = server_config(); - let desktop_configuration = Token::new( - user.id, - Some(user.id), - Some(user.email), - config.enrollment_token_timeout.as_secs(), - Some(ENROLLMENT_TOKEN_TYPE.to_string()), - ); - debug!("Saving a new desktop configuration token..."); - desktop_configuration.save(&pool).await?; - debug!( - "Saved desktop configuration token. Responding to \ - proxy with the token." - ); - - Some(core_response::Payload::AuthCallback( - AuthCallbackResponse { - url: config.enrollment_url.clone().into(), - token: desktop_configuration.id, - }, - )) - } - Err(err) => { - let message = format!("OpenID auth error {err}"); - error!(message); - Some(core_response::Payload::CoreError(CoreError { - status_code: Code::Internal as i32, - message, - })) - } - } - } - Err(err) => { - error!( - "Proxy requested an OpenID authentication info for a callback \ - URL ({}) that couldn't be parsed. Details: {err}", - request.callback_url - ); - Some(core_response::Payload::CoreError(CoreError { - status_code: Code::Internal as i32, - message: "invalid callback URL".into(), - })) - } - } - } - // Reply without payload. - None => None, - }; - let req = CoreResponse { - id: received.id, - payload, - }; - context.tx.send(req).unwrap(); - } - Err(err) => { - error!("Disconnected from proxy at {}: {err}", context.endpoint_uri); - debug!("waiting 10s to re-establish the connection"); - sleep(TEN_SECS).await; - break 'message; - } - } - } - - Ok(()) -} - -/// Bi-directional gRPC stream for communication with Defguard Proxy. -#[instrument(skip_all)] -pub async fn run_grpc_bidi_stream( - pool: PgPool, - wireguard_tx: Sender, - mail_tx: UnboundedSender, - bidi_event_tx: UnboundedSender, - incompatible_components: Arc>, -) -> Result<(), anyhow::Error> { - let config = server_config(); - - // TODO: merge the two - let mut enrollment_server = EnrollmentServer::new( - pool.clone(), - wireguard_tx.clone(), - mail_tx.clone(), - bidi_event_tx.clone(), - ); - let mut password_reset_server = - PasswordResetServer::new(pool.clone(), mail_tx.clone(), bidi_event_tx.clone()); - let mut client_mfa_server = - ClientMfaServer::new(pool.clone(), mail_tx, wireguard_tx.clone(), bidi_event_tx); - let mut polling_server = PollingServer::new(pool.clone()); - - let endpoint = Endpoint::from_shared(config.proxy_url.as_deref().unwrap())?; - let endpoint = endpoint - .http2_keep_alive_interval(TEN_SECS) - .tcp_keepalive(Some(TEN_SECS)) - .keep_alive_while_idle(true); - let endpoint = if let Some(ca) = &config.proxy_grpc_ca { - let ca = read_to_string(ca)?; - let tls = ClientTlsConfig::new().ca_certificate(Certificate::from_pem(ca)); - endpoint.tls_config(tls)? - } else { - endpoint.tls_config(ClientTlsConfig::new().with_enabled_roots())? - }; - - loop { - debug!("Connecting to proxy at {}", endpoint.uri()); - let interceptor = ClientVersionInterceptor::new(Version::parse(VERSION)?); - let mut client = ProxyClient::with_interceptor(endpoint.connect_lazy(), interceptor); - let (tx, rx) = mpsc::unbounded_channel(); - let response = match client.bidi(UnboundedReceiverStream::new(rx)).await { - Ok(response) => response, - Err(err) => { - match err.code() { - Code::FailedPrecondition => { - error!( - "Failed to connect to proxy @ {}, version check failed, retrying in \ - 10s: {err}", - endpoint.uri() - ); - // TODO push event - } - err => { - error!( - "Failed to connect to proxy @ {}, retrying in 10s: {err}", - endpoint.uri() - ); - } - } - sleep(TEN_SECS).await; - continue; - } - }; - let maybe_info = ComponentInfo::from_metadata(response.metadata()); - - // Check proxy version and continue if it's not supported. - let (version, info) = get_tracing_variables(&maybe_info); - let proxy_is_supported = is_proxy_version_supported(Some(&version)); - - let span = tracing::info_span!("proxy_bidi", component = %DefguardComponent::Proxy, - version = version.to_string(), info); - let _guard = span.enter(); - if !proxy_is_supported { - // Store incompatible proxy - let maybe_version = if version == VERSION_ZERO { - None - } else { - Some(version) - }; - let data = IncompatibleProxyData::new(maybe_version); - data.insert(&incompatible_components); - - // Sleep before trying to reconnect - sleep(TEN_SECS).await; - continue; - } - IncompatibleComponents::remove_proxy(&incompatible_components); - - info!("Connected to proxy at {}", endpoint.uri()); - let mut resp_stream = response.into_inner(); - handle_proxy_message_loop(ProxyMessageLoopContext { - pool: pool.clone(), - tx, - wireguard_tx: wireguard_tx.clone(), - resp_stream: &mut resp_stream, - enrollment_server: &mut enrollment_server, - password_reset_server: &mut password_reset_server, - client_mfa_server: &mut client_mfa_server, - polling_server: &mut polling_server, - endpoint_uri: endpoint.uri(), - }) - .await?; - } -} - /// Runs gRPC server with core services. #[instrument(skip_all)] pub async fn run_grpc_server( diff --git a/crates/defguard_core/src/grpc/client_mfa.rs b/crates/defguard_core/src/grpc/proxy/client_mfa.rs similarity index 99% rename from crates/defguard_core/src/grpc/client_mfa.rs rename to crates/defguard_core/src/grpc/proxy/client_mfa.rs index 2069f25f89..03180fa4a9 100644 --- a/crates/defguard_core/src/grpc/client_mfa.rs +++ b/crates/defguard_core/src/grpc/proxy/client_mfa.rs @@ -58,7 +58,7 @@ pub(crate) struct ClientLoginSession { pub(crate) biometric_challenge: Option, } -pub(crate) struct ClientMfaServer { +pub struct ClientMfaServer { pub(crate) pool: PgPool, mail_tx: UnboundedSender, wireguard_tx: Sender, @@ -112,7 +112,7 @@ impl ClientMfaServer { /// Allows proxy to verify if token is valid and active #[instrument(skip_all)] - pub(crate) async fn validate_mfa_token( + pub async fn validate_mfa_token( &mut self, request: ClientMfaTokenValidationRequest, ) -> Result { diff --git a/crates/defguard_core/src/grpc/proxy/mod.rs b/crates/defguard_core/src/grpc/proxy/mod.rs new file mode 100644 index 0000000000..00048e4f2c --- /dev/null +++ b/crates/defguard_core/src/grpc/proxy/mod.rs @@ -0,0 +1 @@ +pub mod client_mfa; diff --git a/crates/defguard_core/src/grpc/utils.rs b/crates/defguard_core/src/grpc/utils.rs index 6955e02e22..67d851a0ac 100644 --- a/crates/defguard_core/src/grpc/utils.rs +++ b/crates/defguard_core/src/grpc/utils.rs @@ -28,10 +28,7 @@ use crate::{ }; // Create a new token for configuration polling. -pub(crate) async fn new_polling_token( - pool: &PgPool, - device: &Device, -) -> Result { +pub async fn new_polling_token(pool: &PgPool, device: &Device) -> Result { debug!( "Making a new polling token for device {}", device.wireguard_pubkey @@ -69,7 +66,7 @@ pub(crate) async fn new_polling_token( Ok(new_token.token) } -pub(crate) async fn build_device_config_response( +pub async fn build_device_config_response( pool: &PgPool, device: Device, token: Option, @@ -251,7 +248,7 @@ pub(crate) async fn build_device_config_response( } /// Parses `DeviceInfo` returning client IP address and user agent. -pub(crate) fn parse_client_ip_agent(info: &Option) -> Result<(IpAddr, String), String> { +pub fn parse_client_ip_agent(info: &Option) -> Result<(IpAddr, String), String> { let Some(info) = info else { error!("Missing DeviceInfo in proxy request"); return Err("missing device info".to_string()); diff --git a/crates/defguard_core/src/handlers/mod.rs b/crates/defguard_core/src/handlers/mod.rs index 60570d8e6d..070709e6a2 100644 --- a/crates/defguard_core/src/handlers/mod.rs +++ b/crates/defguard_core/src/handlers/mod.rs @@ -32,7 +32,7 @@ pub(crate) mod app_info; pub(crate) mod auth; pub(crate) mod forward_auth; pub(crate) mod group; -pub(crate) mod mail; +pub mod mail; pub mod network_devices; pub mod openid_clients; pub mod openid_flow; diff --git a/crates/defguard_core/src/handlers/user.rs b/crates/defguard_core/src/handlers/user.rs index a901bde27e..d2eda7df6c 100644 --- a/crates/defguard_core/src/handlers/user.rs +++ b/crates/defguard_core/src/handlers/user.rs @@ -96,7 +96,7 @@ pub fn check_username(username: &str) -> Result<(), WebError> { Ok(()) } -pub(crate) fn check_password_strength(password: &str) -> Result<(), WebError> { +pub fn check_password_strength(password: &str) -> Result<(), WebError> { if !(8..=128).contains(&password.len()) { return Err(WebError::Serialization("Incorrect password length".into())); } diff --git a/crates/defguard_core/src/headers.rs b/crates/defguard_core/src/headers.rs index ca9c2b78b2..c15f2cda40 100644 --- a/crates/defguard_core/src/headers.rs +++ b/crates/defguard_core/src/headers.rs @@ -26,7 +26,7 @@ pub(crate) static USER_AGENT_PARSER: LazyLock = LazyLock::new(| }); #[must_use] -pub(crate) fn get_device_info(user_agent: &str) -> String { +pub fn get_device_info(user_agent: &str) -> String { let escaped = tera::escape_html(user_agent); let client = USER_AGENT_PARSER.parse(&escaped); get_user_agent_device(&client) diff --git a/crates/defguard_core/src/lib.rs b/crates/defguard_core/src/lib.rs index c67f9c8e31..5005fe003d 100644 --- a/crates/defguard_core/src/lib.rs +++ b/crates/defguard_core/src/lib.rs @@ -939,7 +939,7 @@ pub async fn init_vpn_location( Ok(token) } -pub(crate) fn is_valid_phone_number(number: &str) -> bool { +pub fn is_valid_phone_number(number: &str) -> bool { PHONE_NUMBER_REGEX.is_match(number) } diff --git a/crates/defguard_core/src/version.rs b/crates/defguard_core/src/version.rs index 849c232337..976e437cc6 100644 --- a/crates/defguard_core/src/version.rs +++ b/crates/defguard_core/src/version.rs @@ -14,7 +14,7 @@ pub const MIN_GATEWAY_VERSION: Version = Version::new(1, 5, 0); static OUTDATED_COMPONENT_LIFETIME: TimeDelta = TimeDelta::hours(1); /// Checks if Defguard Proxy version meets minimum version requirements. -pub(crate) fn is_proxy_version_supported(version: Option<&Version>) -> bool { +pub fn is_proxy_version_supported(version: Option<&Version>) -> bool { let Some(version) = version else { error!( "Missing proxy component version information. This most likely means that proxy \ diff --git a/crates/defguard_proxy_manager/Cargo.toml b/crates/defguard_proxy_manager/Cargo.toml new file mode 100644 index 0000000000..63580a5b12 --- /dev/null +++ b/crates/defguard_proxy_manager/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "defguard_proxy_manager" +version = "0.0.0" +edition.workspace = true +license-file.workspace = true +homepage.workspace = true +repository.workspace = true +rust-version.workspace = true + +[dependencies] +# internal dependencies +defguard_common.workspace = true +defguard_core.workspace = true +defguard_mail.workspace = true +defguard_proto.workspace = true +defguard_version.workspace = true +openidconnect.workspace = true +reqwest.workspace = true +semver.workspace = true +tokio-stream.workspace = true + +anyhow.workspace = true +axum.workspace = true +chrono.workspace = true +sqlx.workspace = true +thiserror.workspace = true +tokio.workspace = true +tonic.workspace = true +tracing.workspace = true diff --git a/crates/defguard_core/src/grpc/enrollment.rs b/crates/defguard_proxy_manager/src/enrollment.rs similarity index 93% rename from crates/defguard_core/src/grpc/enrollment.rs rename to crates/defguard_proxy_manager/src/enrollment.rs index 0cfcb51fd6..e0358ce9cf 100644 --- a/crates/defguard_core/src/grpc/enrollment.rs +++ b/crates/defguard_proxy_manager/src/enrollment.rs @@ -1,36 +1,33 @@ use std::collections::HashSet; use defguard_common::{ + config::server_config, csv::AsCsv, db::{ Id, models::{ BiometricAuth, Device, DeviceConfig, DeviceType, MFAMethod, Settings, User, WireguardNetwork, device::DeviceInfo, polling_token::PollingToken, - settings::defaults::WELCOME_EMAIL_SUBJECT, wireguard::ServiceLocationMode, + wireguard::ServiceLocationMode, }, }, }; -use defguard_mail::{ - Mail, - templates::{self, TemplateLocation}, -}; +use defguard_mail::{Mail, templates::TemplateLocation}; use defguard_proto::proxy::{ ActivateUserRequest, AdminInfo, CodeMfaSetupFinishRequest, CodeMfaSetupFinishResponse, CodeMfaSetupStartRequest, CodeMfaSetupStartResponse, DeviceConfigResponse, EnrollmentStartRequest, EnrollmentStartResponse, ExistingDevice, InitialUserInfo, MfaMethod, NewDevice, RegisterMobileAuthRequest, }; -use sqlx::{PgPool, Transaction, query_scalar}; +use sqlx::{PgPool, query_scalar}; use tokio::sync::{ broadcast::Sender, mpsc::{UnboundedSender, error::SendError}, }; use tonic::Status; -use super::InstanceInfo; -use crate::{ - db::models::enrollment::{ENROLLMENT_TOKEN_TYPE, Token, TokenError}, +use defguard_core::{ + db::models::enrollment::{ENROLLMENT_TOKEN_TYPE, Token}, enterprise::{ db::models::{enterprise_settings::EnterpriseSettings, openid_provider::OpenIdProvider}, firewall::try_get_location_firewall_config, @@ -39,6 +36,7 @@ use crate::{ }, events::{BidiRequestContext, BidiStreamEvent, BidiStreamEventType, EnrollmentEvent}, grpc::{ + InstanceInfo, client_version::ClientFeature, gateway::events::GatewayEvent, utils::{build_device_config_response, new_polling_token, parse_client_ip_agent}, @@ -50,7 +48,7 @@ use crate::{ user::check_password_strength, }, headers::get_device_info, - is_valid_phone_number, server_config, + is_valid_phone_number, }; pub(super) struct EnrollmentServer { @@ -1060,81 +1058,6 @@ async fn initial_info_from_user( is_admin, }) } -impl Token { - // Send configured welcome email to user after finishing enrollment - async fn send_welcome_email( - &self, - transaction: &mut Transaction<'_, sqlx::Postgres>, - mail_tx: &UnboundedSender, - user: &User, - settings: &Settings, - ip_address: &str, - device_info: Option<&str>, - ) -> Result<(), TokenError> { - debug!("Sending welcome mail to {}", user.username); - let mail = Mail { - to: user.email.clone(), - subject: settings - .enrollment_welcome_email_subject - .clone() - .unwrap_or_else(|| WELCOME_EMAIL_SUBJECT.to_string()), - content: self - .get_welcome_email_content(&mut *transaction, ip_address, device_info) - .await?, - attachments: Vec::new(), - result_tx: None, - }; - match mail_tx.send(mail) { - Ok(()) => { - info!("Sent enrollment welcome mail to {}", user.username); - Ok(()) - } - Err(err) => { - error!("Error sending welcome mail: {err}"); - Err(TokenError::NotificationError(err.to_string())) - } - } - } - - // Notify admin that a user has completed enrollment - fn send_admin_notification( - mail_tx: &UnboundedSender, - admin: &User, - user: &User, - ip_address: &str, - device_info: Option<&str>, - ) -> Result<(), TokenError> { - debug!( - "Sending enrollment success notification for user {} to {}", - user.username, admin.username - ); - let mail = Mail { - to: admin.email.clone(), - subject: "[defguard] User enrollment completed".into(), - content: templates::enrollment_admin_notification( - &user.clone().into(), - &admin.clone().into(), - ip_address, - device_info, - )?, - attachments: Vec::new(), - result_tx: None, - }; - match mail_tx.send(mail) { - Ok(()) => { - info!( - "Sent enrollment success notification for user {} to {}", - user.username, admin.username - ); - Ok(()) - } - Err(err) => { - error!("Error sending welcome mail: {err}"); - Err(TokenError::NotificationError(err.to_string())) - } - } - } -} #[cfg(test)] mod test { @@ -1148,12 +1071,11 @@ mod test { setup_pool, }, }; + use defguard_core::db::models::enrollment::{ENROLLMENT_TOKEN_TYPE, Token}; use defguard_mail::Mail; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; use tokio::sync::mpsc::unbounded_channel; - use crate::db::models::enrollment::{ENROLLMENT_TOKEN_TYPE, Token}; - #[sqlx::test] async fn dg25_11_test_enrollment_welcome_email(_: PgPoolOptions, options: PgConnectOptions) { let pool = setup_pool(options).await; diff --git a/crates/defguard_proxy_manager/src/lib.rs b/crates/defguard_proxy_manager/src/lib.rs new file mode 100644 index 0000000000..f950214de9 --- /dev/null +++ b/crates/defguard_proxy_manager/src/lib.rs @@ -0,0 +1,845 @@ +use std::{ + collections::HashMap, + fs::read_to_string, + str::FromStr, + sync::{Arc, RwLock}, + time::Duration, +}; + +use axum::http::Uri; +use openidconnect::{AuthorizationCode, Nonce, Scope, core::CoreAuthenticationFlow}; +use reqwest::Url; +use semver::Version; +use sqlx::PgPool; +use thiserror::Error; +use tokio::{ + sync::{ + broadcast::Sender, + mpsc::{self, UnboundedSender}, + }, + task::JoinSet, + time::sleep, +}; +use tokio_stream::wrappers::UnboundedReceiverStream; +use tonic::{ + Code, Streaming, + transport::{Certificate, ClientTlsConfig, Endpoint}, +}; + +use defguard_common::{VERSION, config::server_config}; +use defguard_core::{ + db::models::enrollment::{ENROLLMENT_TOKEN_TYPE, Token, TokenError}, + enrollment_management::clear_unused_enrollment_tokens, + enterprise::{ + db::models::openid_provider::OpenIdProvider, + directory_sync::sync_user_groups_if_configured, + grpc::polling::PollingServer, + handlers::openid_login::{ + SELECT_ACCOUNT_SUPPORTED_PROVIDERS, build_state, make_oidc_client, user_from_claims, + }, + is_business_license_active, + ldap::utils::ldap_update_user_state, + }, + events::BidiStreamEvent, + grpc::{gateway::events::GatewayEvent, proxy::client_mfa::ClientMfaServer}, + version::{IncompatibleComponents, IncompatibleProxyData, is_proxy_version_supported}, +}; +use defguard_mail::Mail; +use defguard_proto::proxy::{ + AuthCallbackResponse, AuthInfoResponse, CoreError, CoreRequest, CoreResponse, core_request, + core_response, proxy_client::ProxyClient, +}; +use defguard_version::{ + ComponentInfo, DefguardComponent, client::ClientVersionInterceptor, get_tracing_variables, +}; + +use crate::{enrollment::EnrollmentServer, password_reset::PasswordResetServer}; + +mod enrollment; +pub(crate) mod password_reset; + +#[macro_use] +extern crate tracing; + +const TEN_SECS: Duration = Duration::from_secs(10); +static VERSION_ZERO: Version = Version::new(0, 0, 0); + +#[derive(Error, Debug)] +pub enum ProxyError { + #[error(transparent)] + InvalidUriError(#[from] axum::http::uri::InvalidUri), + #[error("Failed to read CA certificate: {0}")] + CaCertReadError(std::io::Error), + #[error(transparent)] + TonicError(#[from] tonic::transport::Error), + #[error(transparent)] + SemverError(#[from] semver::Error), + #[error(transparent)] + SqlxError(#[from] sqlx::Error), + #[error(transparent)] + TokenError(#[from] TokenError), +} + +/// Maintains routing state for proxy-specific responses by associating +/// correlation tokens with the proxy senders that should receive them. +#[derive(Default)] +struct ProxyRouter { + response_map: HashMap>>, +} + +impl ProxyRouter { + /// Records the proxy sender associated with a request that expects a routed response. + pub(crate) fn register_request( + &mut self, + request: &CoreRequest, + sender: &UnboundedSender, + ) { + match &request.payload { + // Mobile-assisted MFA completion responses must go to the proxy that owns the WebSocket + // so it can send the preshared key. + // Corresponds to the `core_response::Payload::ClientMfaFinish(response)` response. + // https://github.com/DefGuard/defguard/issues/1700 + Some(core_request::Payload::ClientMfaTokenValidation(request)) => { + self.response_map + .insert(request.token.clone(), vec![sender.clone()]); + } + Some(core_request::Payload::ClientMfaFinish(request)) => { + if let Some(senders) = self.response_map.get_mut(&request.token) { + senders.push(sender.clone()); + } + } + _ => {} + } + } + + /// Determines whether the given `CoreResponse` must be routed to a specific proxy instance. + pub(crate) fn route_response( + &mut self, + response: &CoreResponse, + ) -> Option>> { + #[allow(clippy::single_match)] + match &response.payload { + // Mobile-assisted MFA completion responses must go to the proxy that owns the WebSocket + // so it can send the preshared key. + // Corresponds to the `core_request::Payload::ClientMfaTokenValidation(request)` request. + // https://github.com/DefGuard/defguard/issues/1700 + Some(core_response::Payload::ClientMfaFinish(response)) => { + if let Some(ref token) = response.token { + return self.response_map.remove(token); + } + } + _ => {} + } + None + } +} + +/// Coordinates communication between the Core and multiple proxy instances. +/// +/// Responsibilities include: +/// - instantiating and supervising proxy connections, +/// - routing responses to the appropriate proxy based on correlation state, +/// - providing shared infrastructure (database access, outbound channels), +pub struct ProxyOrchestrator { + pool: PgPool, + tx: ProxyTxSet, + incompatible_components: Arc>, + router: Arc>, +} + +impl ProxyOrchestrator { + pub fn new( + pool: PgPool, + tx: ProxyTxSet, + incompatible_components: Arc>, + ) -> Self { + Self { + pool, + tx, + incompatible_components, + router: Default::default(), + } + } + + /// Spawns and supervises asynchronous tasks for all configured proxies. + /// + /// Each proxy runs in its own task and shares Core-side infrastructure + /// such as routing state and compatibility tracking. + pub async fn run(self, url: &Option) -> Result<(), ProxyError> { + // TODO retrieve proxies from db + let Some(url) = url else { + return Ok(()); + }; + let proxies = vec![Proxy::new( + self.pool.clone(), + Uri::from_str(url)?, + self.tx.clone(), + Arc::clone(&self.router), + )?]; + let mut tasks = JoinSet::>::new(); + for proxy in proxies { + tasks.spawn(proxy.run(self.tx.clone(), self.incompatible_components.clone())); + } + while let Some(result) = tasks.join_next().await { + match result { + Ok(Ok(())) => error!("Proxy task returned prematurely"), + Ok(Err(err)) => error!("Proxy task returned with error: {err}"), + Err(err) => error!("Proxy task execution failed: {err}"), + } + } + Ok(()) + } +} + +/// Shared set of outbound channels that proxy instances use to forward +/// events, notifications, and side effects to Core components. +#[derive(Clone)] +pub struct ProxyTxSet { + wireguard: Sender, + mail: UnboundedSender, + bidi_events: UnboundedSender, +} + +impl ProxyTxSet { + pub fn new( + wireguard: Sender, + mail: UnboundedSender, + bidi_events: UnboundedSender, + ) -> Self { + Self { + wireguard, + mail, + bidi_events, + } + } +} + +/// Represents a single Core - Proxy connection. +/// +/// A `Proxy` is responsible for establishing and maintaining a gRPC +/// bidirectional stream to one proxy instance, handling incoming requests +/// from that proxy, and forwarding responses back through the same stream. +/// Each `Proxy` runs independently and is supervised by the +/// `ProxyOrchestrator`. +struct Proxy { + pool: PgPool, + /// Proxy server gRPC URI + endpoint: Endpoint, + /// gRPC servers + services: ProxyServices, + /// Router shared between proxies and the orchestrator + router: Arc>, +} + +impl Proxy { + pub fn new( + pool: PgPool, + uri: Uri, + tx: ProxyTxSet, + router: Arc>, + ) -> Result { + let endpoint = Endpoint::from(uri); + + // Set endpoint keep-alive to avoid connectivity issues in proxied deployments. + let endpoint = endpoint + .http2_keep_alive_interval(TEN_SECS) + .tcp_keepalive(Some(TEN_SECS)) + .keep_alive_while_idle(true); + + // Setup certs. + let config = server_config(); + let endpoint = if let Some(ca) = &config.proxy_grpc_ca { + let ca = read_to_string(ca).map_err(|err| { + error!("Failed to read CA certificate: {err:?}"); + ProxyError::CaCertReadError(err) + })?; + let tls = ClientTlsConfig::new().ca_certificate(Certificate::from_pem(ca)); + endpoint.tls_config(tls)? + } else { + endpoint.tls_config(ClientTlsConfig::new().with_enabled_roots())? + }; + + // Instantiate gRPC servers. + let services = ProxyServices::new(pool.clone(), tx); + + Ok(Self { + pool, + endpoint, + router, + services, + }) + } + + /// Establishes and maintains a gRPC bidirectional stream to the proxy. + /// + /// The proxy connection is retried on failure, compatibility is checked + /// on each successful connection, and incoming messages are handled + /// until the stream is closed. + pub(crate) async fn run( + mut self, + tx_set: ProxyTxSet, + incompatible_components: Arc>, + ) -> Result<(), ProxyError> { + loop { + debug!("Connecting to proxy at {}", self.endpoint.uri()); + let interceptor = ClientVersionInterceptor::new(Version::parse(VERSION)?); + let mut client = + ProxyClient::with_interceptor(self.endpoint.connect_lazy(), interceptor); + let (tx, rx) = mpsc::unbounded_channel(); + let response = match client.bidi(UnboundedReceiverStream::new(rx)).await { + Ok(response) => response, + Err(err) => { + match err.code() { + Code::FailedPrecondition => { + error!( + "Failed to connect to proxy @ {}, version check failed, retrying in \ + 10s: {err}", + self.endpoint.uri() + ); + // TODO push event + } + err => { + error!( + "Failed to connect to proxy @ {}, retrying in 10s: {err}", + self.endpoint.uri() + ); + } + } + sleep(TEN_SECS).await; + continue; + } + }; + let maybe_info = ComponentInfo::from_metadata(response.metadata()); + + // Check proxy version and continue if it's not supported. + let (version, info) = get_tracing_variables(&maybe_info); + let proxy_is_supported = is_proxy_version_supported(Some(&version)); + + let span = tracing::info_span!("proxy_bidi", component = %DefguardComponent::Proxy, + version = version.to_string(), info); + let _guard = span.enter(); + if !proxy_is_supported { + // Store incompatible proxy + let maybe_version = if version == VERSION_ZERO { + None + } else { + Some(version) + }; + let data = IncompatibleProxyData::new(maybe_version); + data.insert(&incompatible_components); + + // Sleep before trying to reconnect + sleep(TEN_SECS).await; + continue; + } + IncompatibleComponents::remove_proxy(&incompatible_components); + + info!("Connected to proxy at {}", self.endpoint.uri()); + let mut resp_stream = response.into_inner(); + self.message_loop(tx, tx_set.wireguard.clone(), &mut resp_stream) + .await?; + } + } + + /// Processes incoming requests from the proxy over an active gRPC stream. + /// + /// This loop receives `CoreRequest` messages from the proxy, dispatches + /// them to the appropriate Core-side handlers, and sends corresponding + /// `CoreResponse` messages back through the stream. Certain requests may + /// also register routing state for future responses. + async fn message_loop( + &mut self, + tx: UnboundedSender, + wireguard_tx: Sender, + resp_stream: &mut Streaming, + ) -> Result<(), ProxyError> { + let pool = self.pool.clone(); + 'message: loop { + match resp_stream.message().await { + Ok(None) => { + info!("stream was closed by the sender"); + break 'message; + } + Ok(Some(received)) => { + debug!("Received message from proxy; ID={}", received.id); + self.router + .write() + .unwrap() + .register_request(&received, &tx); + let payload = match received.payload { + // rpc CodeMfaSetupStart return (CodeMfaSetupStartResponse) + Some(core_request::Payload::CodeMfaSetupStart(request)) => { + match self + .services + .enrollment + .register_code_mfa_start(request) + .await + { + Ok(response) => Some( + core_response::Payload::CodeMfaSetupStartResponse(response), + ), + Err(err) => { + error!("Register mfa start error {err}"); + Some(core_response::Payload::CoreError(err.into())) + } + } + } + // rpc CodeMfaSetupFinish return (CodeMfaSetupFinishResponse) + Some(core_request::Payload::CodeMfaSetupFinish(request)) => { + match self + .services + .enrollment + .register_code_mfa_finish(request) + .await + { + Ok(response) => Some( + core_response::Payload::CodeMfaSetupFinishResponse(response), + ), + Err(err) => { + error!("Register MFA finish error {err}"); + Some(core_response::Payload::CoreError(err.into())) + } + } + } + // rpc ClientMfaTokenValidation return (ClientMfaTokenValidationResponse) + Some(core_request::Payload::ClientMfaTokenValidation(request)) => { + match self.services.client_mfa.validate_mfa_token(request).await { + Ok(response_payload) => { + Some(core_response::Payload::ClientMfaTokenValidation( + response_payload, + )) + } + Err(err) => { + error!("Client MFA validate token error {err}"); + Some(core_response::Payload::CoreError(err.into())) + } + } + } + // rpc RegisterMobileAuth (RegisterMobileAuthRequest) return (google.protobuf.Empty) + Some(core_request::Payload::RegisterMobileAuth(request)) => { + match self.services.enrollment.register_mobile_auth(request).await { + Ok(()) => Some(core_response::Payload::Empty(())), + Err(err) => { + error!("Register mobile auth error {err}"); + Some(core_response::Payload::CoreError(err.into())) + } + } + } + // rpc StartEnrollment (EnrollmentStartRequest) returns (EnrollmentStartResponse) + Some(core_request::Payload::EnrollmentStart(request)) => { + match self + .services + .enrollment + .start_enrollment(request, received.device_info) + .await + { + Ok(response_payload) => { + Some(core_response::Payload::EnrollmentStart(response_payload)) + } + Err(err) => { + error!("start enrollment error {err}"); + Some(core_response::Payload::CoreError(err.into())) + } + } + } + // rpc ActivateUser (ActivateUserRequest) returns (google.protobuf.Empty) + Some(core_request::Payload::ActivateUser(request)) => { + match self + .services + .enrollment + .activate_user(request, received.device_info) + .await + { + Ok(()) => Some(core_response::Payload::Empty(())), + Err(err) => { + error!("activate user error {err}"); + Some(core_response::Payload::CoreError(err.into())) + } + } + } + // rpc CreateDevice (NewDevice) returns (DeviceConfigResponse) + Some(core_request::Payload::NewDevice(request)) => { + match self + .services + .enrollment + .create_device(request, received.device_info) + .await + { + Ok(response_payload) => { + Some(core_response::Payload::DeviceConfig(response_payload)) + } + Err(err) => { + error!("create device error {err}"); + Some(core_response::Payload::CoreError(err.into())) + } + } + } + // rpc GetNetworkInfo (ExistingDevice) returns (DeviceConfigResponse) + Some(core_request::Payload::ExistingDevice(request)) => { + match self + .services + .enrollment + .get_network_info(request, received.device_info) + .await + { + Ok(response_payload) => { + Some(core_response::Payload::DeviceConfig(response_payload)) + } + Err(err) => { + error!("get network info error {err}"); + Some(core_response::Payload::CoreError(err.into())) + } + } + } + // rpc RequestPasswordReset (PasswordResetInitializeRequest) returns (google.protobuf.Empty) + Some(core_request::Payload::PasswordResetInit(request)) => { + match self + .services + .password_reset + .request_password_reset(request, received.device_info) + .await + { + Ok(()) => Some(core_response::Payload::Empty(())), + Err(err) => { + error!("password reset init error {err}"); + Some(core_response::Payload::CoreError(err.into())) + } + } + } + // rpc StartPasswordReset (PasswordResetStartRequest) returns (PasswordResetStartResponse) + Some(core_request::Payload::PasswordResetStart(request)) => { + match self + .services + .password_reset + .start_password_reset(request, received.device_info) + .await + { + Ok(response_payload) => Some( + core_response::Payload::PasswordResetStart(response_payload), + ), + Err(err) => { + error!("password reset start error {err}"); + Some(core_response::Payload::CoreError(err.into())) + } + } + } + // rpc ResetPassword (PasswordResetRequest) returns (google.protobuf.Empty) + Some(core_request::Payload::PasswordReset(request)) => { + match self + .services + .password_reset + .reset_password(request, received.device_info) + .await + { + Ok(()) => Some(core_response::Payload::Empty(())), + Err(err) => { + error!("password reset error {err}"); + Some(core_response::Payload::CoreError(err.into())) + } + } + } + // rpc ClientMfaStart (ClientMfaStartRequest) returns (ClientMfaStartResponse) + Some(core_request::Payload::ClientMfaStart(request)) => { + match self + .services + .client_mfa + .start_client_mfa_login(request) + .await + { + Ok(response_payload) => { + Some(core_response::Payload::ClientMfaStart(response_payload)) + } + Err(err) => { + error!("client MFA start error {err}"); + Some(core_response::Payload::CoreError(err.into())) + } + } + } + // rpc ClientMfaFinish (ClientMfaFinishRequest) returns (ClientMfaFinishResponse) + Some(core_request::Payload::ClientMfaFinish(request)) => { + match self + .services + .client_mfa + .finish_client_mfa_login(request, received.device_info) + .await + { + Ok(response_payload) => { + Some(core_response::Payload::ClientMfaFinish(response_payload)) + } + Err(err) => { + match err.code() { + Code::FailedPrecondition => { + // User not yet done with OIDC authentication. Don't log it + // as an error. + debug!("Client MFA finish error: {err}"); + } + _ => { + // Log other errors as errors. + error!("Client MFA finish error: {err}"); + } + } + Some(core_response::Payload::CoreError(err.into())) + } + } + } + Some(core_request::Payload::ClientMfaOidcAuthenticate(request)) => { + match self + .services + .client_mfa + .auth_mfa_session_with_oidc(request, received.device_info) + .await + { + Ok(()) => Some(core_response::Payload::Empty(())), + Err(err) => { + error!("client MFA OIDC authenticate error {err}"); + Some(core_response::Payload::CoreError(err.into())) + } + } + } + // rpc LocationInfo (LocationInfoRequest) returns (LocationInfoResponse) + Some(core_request::Payload::InstanceInfo(request)) => { + match self + .services + .polling + .info(request, received.device_info) + .await + { + Ok(response_payload) => { + Some(core_response::Payload::InstanceInfo(response_payload)) + } + Err(err) => { + if Code::FailedPrecondition == err.code() { + // Ignore the case when we are not enterprise but the client is + // trying to fetch the instance config, + // to avoid spamming the logs with misleading errors. + + debug!( + "A client tried to fetch the instance config, but we are \ + not enterprise." + ); + Some(core_response::Payload::CoreError(err.into())) + } else { + error!("Instance info error {err}"); + Some(core_response::Payload::CoreError(err.into())) + } + } + } + } + Some(core_request::Payload::AuthInfo(request)) => { + if !is_business_license_active() { + warn!("Enterprise license required"); + Some(core_response::Payload::CoreError(CoreError { + status_code: Code::FailedPrecondition as i32, + message: "no valid license".into(), + })) + } else if let Ok(redirect_url) = Url::parse(&request.redirect_url) { + if let Some(provider) = OpenIdProvider::get_current(&pool).await? { + match make_oidc_client(redirect_url, &provider).await { + Ok((_client_id, client)) => { + let mut authorize_url_builder = client + .authorize_url( + CoreAuthenticationFlow::AuthorizationCode, + || build_state(request.state), + Nonce::new_random, + ) + .add_scope(Scope::new("email".to_string())) + .add_scope(Scope::new("profile".to_string())); + + if SELECT_ACCOUNT_SUPPORTED_PROVIDERS + .iter() + .all(|p| p.eq_ignore_ascii_case(&provider.name)) + { + authorize_url_builder = authorize_url_builder + .add_prompt( + openidconnect::core::CoreAuthPrompt::SelectAccount, + ); + } + let (url, csrf_token, nonce) = + authorize_url_builder.url(); + + Some(core_response::Payload::AuthInfo( + AuthInfoResponse { + url: url.into(), + csrf_token: csrf_token.secret().to_owned(), + nonce: nonce.secret().to_owned(), + button_display_name: provider.display_name, + }, + )) + } + Err(err) => { + error!( + "Failed to setup external OIDC provider client: {err}" + ); + Some(core_response::Payload::CoreError(CoreError { + status_code: Code::Internal as i32, + message: "failed to build OIDC client".into(), + })) + } + } + } else { + error!("Failed to get current OpenID provider"); + Some(core_response::Payload::CoreError(CoreError { + status_code: Code::NotFound as i32, + message: "failed to get current OpenID provider".into(), + })) + } + } else { + error!( + "Invalid redirect URL in authentication info request: {}", + request.redirect_url + ); + Some(core_response::Payload::CoreError(CoreError { + status_code: Code::Internal as i32, + message: "invalid redirect URL".into(), + })) + } + } + Some(core_request::Payload::AuthCallback(request)) => { + match Url::parse(&request.callback_url) { + Ok(callback_url) => { + let code = AuthorizationCode::new(request.code); + match user_from_claims( + &pool, + Nonce::new(request.nonce), + code, + callback_url, + ) + .await + { + Ok(mut user) => { + clear_unused_enrollment_tokens(&user, &pool).await?; + if let Err(err) = sync_user_groups_if_configured( + &user, + &pool, + &wireguard_tx, + ) + .await + { + error!( + "Failed to sync user groups for user {} with the \ + directory while the user was logging in through an \ + external provider: {err}", + user.username, + ); + } else { + ldap_update_user_state(&mut user, &pool).await; + } + debug!("Cleared unused tokens for {}.", user.username); + debug!( + "Creating a new desktop activation token for user {} \ + as a result of proxy OpenID auth callback.", + user.username + ); + let config = server_config(); + let desktop_configuration = Token::new( + user.id, + Some(user.id), + Some(user.email), + config.enrollment_token_timeout.as_secs(), + Some(ENROLLMENT_TOKEN_TYPE.to_string()), + ); + debug!("Saving a new desktop configuration token..."); + desktop_configuration.save(&pool).await?; + debug!( + "Saved desktop configuration token. Responding to \ + proxy with the token." + ); + + Some(core_response::Payload::AuthCallback( + AuthCallbackResponse { + url: config.enrollment_url.clone().into(), + token: desktop_configuration.id, + }, + )) + } + Err(err) => { + let message = format!("OpenID auth error {err}"); + error!(message); + Some(core_response::Payload::CoreError(CoreError { + status_code: Code::Internal as i32, + message, + })) + } + } + } + Err(err) => { + error!( + "Proxy requested an OpenID authentication info for a callback \ + URL ({}) that couldn't be parsed. Details: {err}", + request.callback_url + ); + Some(core_response::Payload::CoreError(CoreError { + status_code: Code::Internal as i32, + message: "invalid callback URL".into(), + })) + } + } + } + // Reply without payload. + None => None, + }; + + let req = CoreResponse { + id: received.id, + payload, + }; + if let Some(txs) = self.router.write().unwrap().route_response(&req) { + for tx in txs { + let _ = tx.send(req.clone()); + } + } else { + let _ = tx.send(req); + }; + } + Err(err) => { + error!("Disconnected from proxy at {}: {err}", self.endpoint.uri()); + debug!("waiting 10s to re-establish the connection"); + sleep(TEN_SECS).await; + break 'message; + } + } + } + + Ok(()) + } +} + +/// Groups Core-side service handlers used to process requests originating +/// from a proxy instance. +/// +/// Each `ProxyServices` instance is owned by a single `Proxy` and provides +/// the concrete handlers for enrollment, authentication, and polling-related +/// requests received over the gRPC bidirectional stream. +struct ProxyServices { + enrollment: EnrollmentServer, + password_reset: PasswordResetServer, + client_mfa: ClientMfaServer, + polling: PollingServer, +} + +impl ProxyServices { + pub fn new(pool: PgPool, tx: ProxyTxSet) -> Self { + let enrollment = EnrollmentServer::new( + pool.clone(), + tx.wireguard.clone(), + tx.mail.clone(), + tx.bidi_events.clone(), + ); + let password_reset = + PasswordResetServer::new(pool.clone(), tx.mail.clone(), tx.bidi_events.clone()); + let client_mfa = ClientMfaServer::new( + pool.clone(), + tx.mail.clone(), + tx.wireguard.clone(), + tx.bidi_events.clone(), + ); + let polling = PollingServer::new(pool.clone()); + + Self { + enrollment, + password_reset, + client_mfa, + polling, + } + } +} diff --git a/crates/defguard_core/src/grpc/password_reset.rs b/crates/defguard_proxy_manager/src/password_reset.rs similarity index 99% rename from crates/defguard_core/src/grpc/password_reset.rs rename to crates/defguard_proxy_manager/src/password_reset.rs index f3d44b6d99..e27bc4776e 100644 --- a/crates/defguard_core/src/grpc/password_reset.rs +++ b/crates/defguard_proxy_manager/src/password_reset.rs @@ -1,4 +1,4 @@ -use defguard_common::db::models::User; +use defguard_common::{config::server_config, db::models::User}; use defguard_mail::Mail; use defguard_proto::proxy::{ DeviceInfo, PasswordResetInitializeRequest, PasswordResetRequest, PasswordResetStartRequest, @@ -8,7 +8,7 @@ use sqlx::PgPool; use tokio::sync::mpsc::{UnboundedSender, error::SendError}; use tonic::Status; -use crate::{ +use defguard_core::{ db::models::enrollment::{PASSWORD_RESET_TOKEN_TYPE, Token}, enterprise::ldap::utils::ldap_change_password, events::{BidiRequestContext, BidiStreamEvent, BidiStreamEventType, PasswordResetEvent}, @@ -18,7 +18,6 @@ use crate::{ user::check_password_strength, }, headers::get_device_info, - server_config, }; pub(super) struct PasswordResetServer { diff --git a/deny.toml b/deny.toml index df0ff43787..b6b55bf1bf 100644 --- a/deny.toml +++ b/deny.toml @@ -133,6 +133,9 @@ exceptions = [ { allow = [ "AGPL-3.0-only", "AGPL-3.0-or-later", ], crate = "defguard_event_logger" }, + { allow = [ + "AGPL-3.0-only", "AGPL-3.0-or-later", + ], crate = "defguard_proxy_manager" }, { allow = [ "AGPL-3.0-only", "AGPL-3.0-or-later", ], crate = "defguard_session_manager" },