Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 5 additions & 4 deletions crates/defguard/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use defguard_core::{
license::{run_periodic_license_check, set_cached_license, License},
limits::update_counts,
},
events::{ApiEvent, BidiStreamEvent, GrpcEvent},
events::{ApiEvent, BidiStreamEvent, GrpcEvent, InternalEvent},
grpc::{run_grpc_bidi_stream, run_grpc_server, GatewayMap, WorkerState},
init_dev_env, init_vpn_location,
mail::{run_mail_handler, Mail},
Expand All @@ -27,7 +27,7 @@ use defguard_core::{
SERVER_CONFIG, VERSION,
};
use defguard_event_logger::{message::EventLoggerMessage, run_event_logger};
use defguard_event_router::run_event_router;
use defguard_event_router::{run_event_router, RouterReceiverSet};
use secrecy::ExposeSecret;
use tokio::sync::{broadcast, mpsc::unbounded_channel};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
Expand Down Expand Up @@ -88,6 +88,7 @@ async fn main() -> Result<(), anyhow::Error> {
// create event channels for services
let (api_event_tx, api_event_rx) = unbounded_channel::<ApiEvent>();
let (bidi_event_tx, bidi_event_rx) = unbounded_channel::<BidiStreamEvent>();
let (internal_event_tx, internal_event_rx) = unbounded_channel::<InternalEvent>();
let (grpc_event_tx, grpc_event_rx) = unbounded_channel::<GrpcEvent>();

// Audit stream setup
Expand Down Expand Up @@ -147,11 +148,11 @@ async fn main() -> Result<(), anyhow::Error> {
res = run_grpc_server(Arc::clone(&worker_state), pool.clone(), Arc::clone(&gateway_state), wireguard_tx.clone(), mail_tx.clone(), grpc_cert, grpc_key, failed_logins.clone(), grpc_event_tx) => error!("gRPC server returned early: {res:?}"),
res = run_web_server(worker_state, gateway_state, webhook_tx, webhook_rx, wireguard_tx.clone(), mail_tx.clone(), pool.clone(), failed_logins, api_event_tx) => error!("Web server returned early: {res:?}"),
res = run_mail_handler(mail_rx) => error!("Mail handler returned early: {res:?}"),
res = run_periodic_peer_disconnect(pool.clone(), wireguard_tx.clone()) => error!("Periodic peer disconnect task returned early: {res:?}"),
res = run_periodic_peer_disconnect(pool.clone(), wireguard_tx.clone(), internal_event_tx.clone()) => error!("Periodic peer disconnect task returned early: {res:?}"),
res = run_periodic_stats_purge(pool.clone(), config.stats_purge_frequency.into(), config.stats_purge_threshold.into()), if !config.disable_stats_purge => error!("Periodic stats purge task returned early: {res:?}"),
res = run_periodic_license_check(&pool) => error!("Periodic license check task returned early: {res:?}"),
res = run_utility_thread(&pool, wireguard_tx.clone()) => error!("Utility thread returned early: {res:?}"),
res = run_event_router( api_event_rx, grpc_event_rx, bidi_event_rx, event_logger_tx, wireguard_tx, mail_tx, audit_stream_reload_notify.clone()) => error!("Event router returned early: {res:?}"),
res = run_event_router(RouterReceiverSet::new(api_event_rx, grpc_event_rx, bidi_event_rx, internal_event_rx), event_logger_tx, wireguard_tx, mail_tx, audit_stream_reload_notify.clone()) => error!("Event router returned early: {res:?}"),
res = run_event_logger(pool.clone(), event_logger_rx, audit_messages_tx.clone()) => error!("Audit event logger returned early: {res:?}"),
res = run_audit_stream_manager(pool.clone(), audit_stream_reload_notify.clone(), audit_messages_rx) => error!("Audit stream manager returned early: {res:?}"),
}
Expand Down
7 changes: 7 additions & 0 deletions crates/defguard_core/src/db/models/audit_log/metadata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,10 @@ pub struct VpnClientMetadata {
pub location: WireguardNetwork<Id>,
pub device: Device<Id>,
}

#[derive(Serialize)]
pub struct VpnClientMfaMetadata {
pub location: WireguardNetwork<Id>,
pub device: Device<Id>,
pub method: MFAMethod,
}
3 changes: 3 additions & 0 deletions crates/defguard_core/src/db/models/audit_log/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ pub enum EventType {
// VPN client events
VpnClientConnected,
VpnClientDisconnected,
VpnClientConnectedMfa,
VpnClientDisconnectedMfa,
VpnClientMfaFailed,
}

#[derive(Model, FromRow, Serialize)]
Expand Down
14 changes: 13 additions & 1 deletion crates/defguard_core/src/db/models/user.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@ use crate::{
db::{models::group::Permission, GatewayEvent, Id, NoId, Session, Settings, WireguardNetwork},
enterprise::limits::update_counts,
error::WebError,
grpc::gateway::{send_multiple_wireguard_events, send_wireguard_event},
grpc::{
gateway::{send_multiple_wireguard_events, send_wireguard_event},
proto::proxy::MfaMethod,
},
random::{gen_alphanumeric, gen_totp_secret},
server_config,
};
Expand All @@ -50,6 +53,15 @@ pub enum MFAMethod {
Email,
}

impl From<MfaMethod> for MFAMethod {
fn from(method: MfaMethod) -> Self {
match method {
MfaMethod::Totp => Self::OneTimePassword,
MfaMethod::Email => Self::Email,
}
}
}

impl fmt::Display for MFAMethod {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
Expand Down
58 changes: 51 additions & 7 deletions crates/defguard_core/src/events.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use std::net::IpAddr;

use crate::db::{Device, Id, MFAMethod, WireguardNetwork};
use chrono::{NaiveDateTime, Utc};
use ipnetwork::IpNetwork;

/// Shared context that needs to be added to every API event
///
Expand Down Expand Up @@ -177,27 +176,35 @@ pub struct BidiRequestContext {
pub timestamp: NaiveDateTime,
pub user_id: Id,
pub username: String,
pub ip: IpNetwork,
pub device: String,
pub ip: IpAddr,
pub device: Device<Id>,
pub location: WireguardNetwork<Id>,
}

impl BidiRequestContext {
pub fn new(user_id: Id, username: String, ip: IpNetwork, device: String) -> Self {
pub fn new(
user_id: Id,
username: String,
ip: IpAddr,
device: Device<Id>,
location: WireguardNetwork<Id>,
) -> Self {
let timestamp = Utc::now().naive_utc();
Self {
timestamp,
user_id,
username,
ip,
device,
location,
}
}
}

/// Events emmited from gRPC bi-directional communication stream
#[derive(Debug)]
pub struct BidiStreamEvent {
pub request_context: BidiRequestContext,
pub context: BidiRequestContext,
pub event: BidiStreamEventType,
}

Expand All @@ -208,7 +215,7 @@ pub struct BidiStreamEvent {
pub enum BidiStreamEventType {
Enrollment(EnrollmentEvent),
PasswordReset(PasswordResetEvent),
DesktopCLientMfa(DesktopClientMfaEvent),
DesktopClientMfa(DesktopClientMfaEvent),
ConfigPolling(ConfigPollingEvent),
}

Expand All @@ -221,7 +228,44 @@ pub enum EnrollmentEvent {
pub enum PasswordResetEvent {}

#[derive(Debug)]
pub enum DesktopClientMfaEvent {}
pub enum DesktopClientMfaEvent {
Connected { method: MFAMethod },
Failed { method: MFAMethod },
}

#[derive(Debug)]
pub enum ConfigPollingEvent {}

/// Shared context for every internally-triggered event.
///
/// Similarly to `ApiRequestContexts` at the moment it's mostly meant to populate the audit log.
#[derive(Debug)]
pub struct InternalEventContext {
pub timestamp: NaiveDateTime,
pub user_id: Id,
pub username: String,
pub ip: IpAddr,
pub device: Device<Id>,
}

impl InternalEventContext {
pub fn new(user_id: Id, username: String, ip: IpAddr, device: Device<Id>) -> Self {
let timestamp = Utc::now().naive_utc();
Self {
timestamp,
user_id,
username,
ip,
device,
}
}
}

/// Events emmited by background threads, not triggered directly by users
#[derive(Debug)]
pub enum InternalEvent {
DesktopClientMfaDisconnected {
context: InternalEventContext,
location: WireguardNetwork<Id>,
},
}
62 changes: 59 additions & 3 deletions crates/defguard_core/src/grpc/desktop_client_mfa.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
use std::collections::HashMap;
use std::{collections::HashMap, net::Ipv4Addr};

use chrono::Utc;
use sqlx::PgPool;
use tokio::sync::{broadcast::Sender, mpsc::UnboundedSender};
use tonic::Status;
use thiserror::Error;
use tokio::sync::{
broadcast::Sender,
mpsc::{error::SendError, UnboundedSender},
};
use tonic::{Code, Status};

use super::proto::proxy::{
ClientMfaFinishRequest, ClientMfaFinishResponse, ClientMfaStartRequest, ClientMfaStartResponse,
Expand All @@ -15,12 +19,26 @@ use crate::{
models::device::{DeviceInfo, DeviceNetworkInfo, WireguardNetworkDevice},
Device, GatewayEvent, Id, User, UserInfo, WireguardNetwork,
},
events::{BidiRequestContext, BidiStreamEvent, BidiStreamEventType, DesktopClientMfaEvent},
handlers::mail::send_email_mfa_code_email,
mail::Mail,
};

const CLIENT_SESSION_TIMEOUT: u64 = 60 * 5; // 10 minutes

#[derive(Debug, Error)]
#[allow(clippy::large_enum_variant)]
pub enum ClientMfaServerError {
#[error("gRPC event channel error: {0}")]
BidiEventChannelError(#[from] SendError<BidiStreamEvent>),
}

impl From<ClientMfaServerError> for Status {
fn from(value: ClientMfaServerError) -> Self {
Self::new(Code::Internal, value.to_string())
}
}

struct ClientLoginSession {
method: MfaMethod,
location: WireguardNetwork<Id>,
Expand All @@ -33,6 +51,7 @@ pub(super) struct ClientMfaServer {
mail_tx: UnboundedSender<Mail>,
wireguard_tx: Sender<GatewayEvent>,
sessions: HashMap<String, ClientLoginSession>,
bidi_event_tx: UnboundedSender<BidiStreamEvent>,
}

impl ClientMfaServer {
Expand All @@ -41,11 +60,13 @@ impl ClientMfaServer {
pool: PgPool,
mail_tx: UnboundedSender<Mail>,
wireguard_tx: Sender<GatewayEvent>,
bidi_event_tx: UnboundedSender<BidiStreamEvent>,
) -> Self {
Self {
pool,
mail_tx,
wireguard_tx,
bidi_event_tx,
sessions: HashMap::new(),
}
}
Expand Down Expand Up @@ -73,6 +94,10 @@ impl ClientMfaServer {
Ok(claims.client_id)
}

fn emit_event(&self, event: BidiStreamEvent) -> Result<(), ClientMfaServerError> {
Ok(self.bidi_event_tx.send(event)?)
}

#[instrument(skip_all)]
pub async fn start_client_mfa_login(
&mut self,
Expand Down Expand Up @@ -213,17 +238,42 @@ impl ClientMfaServer {
user,
} = session;

// Prepare event context
let context = BidiRequestContext::new(
user.id,
user.username.clone(),
std::net::IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)),
device.clone(),
location.clone(),
);

// validate code
match method {
MfaMethod::Totp => {
if !user.verify_totp_code(&request.code.to_string()) {
error!("Provided TOTP code is not valid");
self.emit_event(BidiStreamEvent {
context,
event: BidiStreamEventType::DesktopClientMfa(
DesktopClientMfaEvent::Failed {
method: (*method).into(),
},
),
})?;
return Err(Status::unauthenticated("unauthorized"));
}
}
MfaMethod::Email => {
if !user.verify_email_mfa_code(&request.code.to_string()) {
error!("Provided email code is not valid");
self.emit_event(BidiStreamEvent {
context,
event: BidiStreamEventType::DesktopClientMfa(
DesktopClientMfaEvent::Failed {
method: (*method).into(),
},
),
})?;
return Err(Status::unauthenticated("unauthorized"));
}
}
Expand Down Expand Up @@ -281,6 +331,12 @@ impl ClientMfaServer {
"Desktop client login finished for {} at location {}",
user.username, location.name
);
self.emit_event(BidiStreamEvent {
context,
event: BidiStreamEventType::DesktopClientMfa(DesktopClientMfaEvent::Connected {
method: (*method).into(),
}),
})?;

// remove login session from map
self.sessions.remove(&pubkey);
Expand Down
4 changes: 2 additions & 2 deletions crates/defguard_core/src/grpc/enrollment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,11 @@ impl EnrollmentServer {
#[allow(dead_code)]
fn emit_event(
&self,
request_context: BidiRequestContext,
context: BidiRequestContext,
event: EnrollmentEvent,
) -> Result<(), SendError<BidiStreamEvent>> {
let event = BidiStreamEvent {
request_context,
context,
event: BidiStreamEventType::Enrollment(event),
};

Expand Down
Loading