diff --git a/.sqlx/query-0f00cd80f489fe50957305cd0e2cfb992f473271cfa5f4e2cb8d9c92c4fea3f8.json b/.sqlx/query-0f00cd80f489fe50957305cd0e2cfb992f473271cfa5f4e2cb8d9c92c4fea3f8.json new file mode 100644 index 0000000000..c386fd3d79 --- /dev/null +++ b/.sqlx/query-0f00cd80f489fe50957305cd0e2cfb992f473271cfa5f4e2cb8d9c92c4fea3f8.json @@ -0,0 +1,17 @@ +{ + "db_name": "PostgreSQL", + "query": "UPDATE \"pollingtoken\" SET \"token\" = $2,\"device_id\" = $3,\"created_at\" = $4 WHERE id = $1", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Int8", + "Text", + "Int8", + "Timestamp" + ] + }, + "nullable": [] + }, + "hash": "0f00cd80f489fe50957305cd0e2cfb992f473271cfa5f4e2cb8d9c92c4fea3f8" +} diff --git a/.sqlx/query-5267ddfcbe18a9db34a0bc9f51baa5ef5a214dbb487981d9864f6469a04805c2.json b/.sqlx/query-5267ddfcbe18a9db34a0bc9f51baa5ef5a214dbb487981d9864f6469a04805c2.json new file mode 100644 index 0000000000..ebc7341f4e --- /dev/null +++ b/.sqlx/query-5267ddfcbe18a9db34a0bc9f51baa5ef5a214dbb487981d9864f6469a04805c2.json @@ -0,0 +1,40 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT id \"id?\", \"token\",\"device_id\",\"created_at\" FROM \"pollingtoken\" WHERE id = $1", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id?", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "token", + "type_info": "Text" + }, + { + "ordinal": 2, + "name": "device_id", + "type_info": "Int8" + }, + { + "ordinal": 3, + "name": "created_at", + "type_info": "Timestamp" + } + ], + "parameters": { + "Left": [ + "Int8" + ] + }, + "nullable": [ + false, + false, + false, + false + ] + }, + "hash": "5267ddfcbe18a9db34a0bc9f51baa5ef5a214dbb487981d9864f6469a04805c2" +} diff --git a/.sqlx/query-65af2457c30994d33ef2d6d265e04523dc1c69fa0052089d341c41862fdc4425.json b/.sqlx/query-65af2457c30994d33ef2d6d265e04523dc1c69fa0052089d341c41862fdc4425.json new file mode 100644 index 0000000000..0111f9c656 --- /dev/null +++ b/.sqlx/query-65af2457c30994d33ef2d6d265e04523dc1c69fa0052089d341c41862fdc4425.json @@ -0,0 +1,38 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT id \"id?\", \"token\",\"device_id\",\"created_at\" FROM \"pollingtoken\"", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id?", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "token", + "type_info": "Text" + }, + { + "ordinal": 2, + "name": "device_id", + "type_info": "Int8" + }, + { + "ordinal": 3, + "name": "created_at", + "type_info": "Timestamp" + } + ], + "parameters": { + "Left": [] + }, + "nullable": [ + false, + false, + false, + false + ] + }, + "hash": "65af2457c30994d33ef2d6d265e04523dc1c69fa0052089d341c41862fdc4425" +} diff --git a/.sqlx/query-750c3343a64d8f02d6952ab115318f79cc32d8013a8255a87c745f73ef08a2df.json b/.sqlx/query-750c3343a64d8f02d6952ab115318f79cc32d8013a8255a87c745f73ef08a2df.json new file mode 100644 index 0000000000..ee81bee7e4 --- /dev/null +++ b/.sqlx/query-750c3343a64d8f02d6952ab115318f79cc32d8013a8255a87c745f73ef08a2df.json @@ -0,0 +1,40 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT id, token, device_id, created_at\n FROM pollingtoken WHERE token = $1", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "token", + "type_info": "Text" + }, + { + "ordinal": 2, + "name": "device_id", + "type_info": "Int8" + }, + { + "ordinal": 3, + "name": "created_at", + "type_info": "Timestamp" + } + ], + "parameters": { + "Left": [ + "Text" + ] + }, + "nullable": [ + false, + false, + false, + false + ] + }, + "hash": "750c3343a64d8f02d6952ab115318f79cc32d8013a8255a87c745f73ef08a2df" +} diff --git a/.sqlx/query-838909479f3e9a1538980e1299be80b362b9ca168bdc2a4f10db6157b93b53e1.json b/.sqlx/query-838909479f3e9a1538980e1299be80b362b9ca168bdc2a4f10db6157b93b53e1.json new file mode 100644 index 0000000000..f539daaee7 --- /dev/null +++ b/.sqlx/query-838909479f3e9a1538980e1299be80b362b9ca168bdc2a4f10db6157b93b53e1.json @@ -0,0 +1,24 @@ +{ + "db_name": "PostgreSQL", + "query": "INSERT INTO \"pollingtoken\" (\"token\",\"device_id\",\"created_at\") VALUES ($1,$2,$3) RETURNING id", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Int8" + } + ], + "parameters": { + "Left": [ + "Text", + "Int8", + "Timestamp" + ] + }, + "nullable": [ + false + ] + }, + "hash": "838909479f3e9a1538980e1299be80b362b9ca168bdc2a4f10db6157b93b53e1" +} diff --git a/.sqlx/query-891d8d88e0eaba3d67a23dec7de6292046c0be2a179424b57c16767dd1d8b212.json b/.sqlx/query-891d8d88e0eaba3d67a23dec7de6292046c0be2a179424b57c16767dd1d8b212.json new file mode 100644 index 0000000000..0f55514a6e --- /dev/null +++ b/.sqlx/query-891d8d88e0eaba3d67a23dec7de6292046c0be2a179424b57c16767dd1d8b212.json @@ -0,0 +1,125 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT u.id \"id?\", u.username, u.password_hash, u.last_name, u.first_name, u.email, u.phone, u.mfa_enabled, u.totp_enabled, u.email_mfa_enabled, u.totp_secret, u.email_mfa_secret, u.mfa_method \"mfa_method: _\", u.recovery_codes, u.is_active, u.openid_login FROM \"user\" as u JOIN \"device\" as d ON u.id = d.user_id WHERE d.id = $1", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id?", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "username", + "type_info": "Text" + }, + { + "ordinal": 2, + "name": "password_hash", + "type_info": "Text" + }, + { + "ordinal": 3, + "name": "last_name", + "type_info": "Text" + }, + { + "ordinal": 4, + "name": "first_name", + "type_info": "Text" + }, + { + "ordinal": 5, + "name": "email", + "type_info": "Text" + }, + { + "ordinal": 6, + "name": "phone", + "type_info": "Text" + }, + { + "ordinal": 7, + "name": "mfa_enabled", + "type_info": "Bool" + }, + { + "ordinal": 8, + "name": "totp_enabled", + "type_info": "Bool" + }, + { + "ordinal": 9, + "name": "email_mfa_enabled", + "type_info": "Bool" + }, + { + "ordinal": 10, + "name": "totp_secret", + "type_info": "Bytea" + }, + { + "ordinal": 11, + "name": "email_mfa_secret", + "type_info": "Bytea" + }, + { + "ordinal": 12, + "name": "mfa_method: _", + "type_info": { + "Custom": { + "name": "mfa_method", + "kind": { + "Enum": [ + "none", + "one_time_password", + "webauthn", + "web3", + "email" + ] + } + } + } + }, + { + "ordinal": 13, + "name": "recovery_codes", + "type_info": "TextArray" + }, + { + "ordinal": 14, + "name": "is_active", + "type_info": "Bool" + }, + { + "ordinal": 15, + "name": "openid_login", + "type_info": "Bool" + } + ], + "parameters": { + "Left": [ + "Int8" + ] + }, + "nullable": [ + false, + false, + true, + false, + false, + false, + true, + false, + false, + false, + true, + true, + false, + false, + false, + false + ] + }, + "hash": "891d8d88e0eaba3d67a23dec7de6292046c0be2a179424b57c16767dd1d8b212" +} diff --git a/.sqlx/query-fac46fc03161bac460c30d9b61af52e6a872f5e974f586ed4d6a8c1ceaed8223.json b/.sqlx/query-fac46fc03161bac460c30d9b61af52e6a872f5e974f586ed4d6a8c1ceaed8223.json new file mode 100644 index 0000000000..7af831b096 --- /dev/null +++ b/.sqlx/query-fac46fc03161bac460c30d9b61af52e6a872f5e974f586ed4d6a8c1ceaed8223.json @@ -0,0 +1,14 @@ +{ + "db_name": "PostgreSQL", + "query": "DELETE FROM \"pollingtoken\" WHERE id = $1", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Int8" + ] + }, + "nullable": [] + }, + "hash": "fac46fc03161bac460c30d9b61af52e6a872f5e974f586ed4d6a8c1ceaed8223" +} diff --git a/migrations/20240906090729_polling_token.down.sql b/migrations/20240906090729_polling_token.down.sql new file mode 100644 index 0000000000..bb7a045ee9 --- /dev/null +++ b/migrations/20240906090729_polling_token.down.sql @@ -0,0 +1 @@ +DROP TABLE pollingtoken; diff --git a/migrations/20240906090729_polling_token.up.sql b/migrations/20240906090729_polling_token.up.sql new file mode 100644 index 0000000000..933a426f90 --- /dev/null +++ b/migrations/20240906090729_polling_token.up.sql @@ -0,0 +1,7 @@ +CREATE TABLE pollingtoken ( + id bigserial PRIMARY KEY, + token TEXT NOT NULL, + device_id bigint NOT NULL, + created_at timestamp without time zone NOT NULL DEFAULT now(), + FOREIGN KEY(device_id) REFERENCES "device"(id) ON DELETE CASCADE +); diff --git a/proto b/proto index 1b8f7c5cfc..d069a0e530 160000 --- a/proto +++ b/proto @@ -1 +1 @@ -Subproject commit 1b8f7c5cfccf7e4f3a7a8422215c9db1d6fcc68e +Subproject commit d069a0e5304281cfc8b09e949a8e7a9feb5fc115 diff --git a/src/db/models/mod.rs b/src/db/models/mod.rs index e52eb7f1ac..6f1897778d 100644 --- a/src/db/models/mod.rs +++ b/src/db/models/mod.rs @@ -12,6 +12,7 @@ pub mod oauth2authorizedapp; pub mod oauth2client; #[cfg(feature = "openid")] pub mod oauth2token; +pub mod polling_token; pub mod session; pub mod settings; pub mod user; diff --git a/src/db/models/polling_token.rs b/src/db/models/polling_token.rs new file mode 100644 index 0000000000..3235f6b05c --- /dev/null +++ b/src/db/models/polling_token.rs @@ -0,0 +1,38 @@ +use chrono::{NaiveDateTime, Utc}; +use model_derive::Model; +use sqlx::{query_as, Error as SqlxError}; + +use crate::random::gen_alphanumeric; + +use super::DbPool; + +// Token used for polling requests. +#[derive(Clone, Debug, Model)] +pub struct PollingToken { + pub id: Option, + pub token: String, + pub device_id: i64, + pub created_at: NaiveDateTime, +} + +impl PollingToken { + pub fn new(device_id: i64) -> Self { + Self { + id: None, + device_id, + token: gen_alphanumeric(32), + created_at: Utc::now().naive_utc(), + } + } + + pub async fn find(pool: &DbPool, token: &str) -> Result, SqlxError> { + query_as!( + Self, + "SELECT id, token, device_id, created_at + FROM pollingtoken WHERE token = $1", + token + ) + .fetch_optional(pool) + .await + } +} diff --git a/src/db/models/user.rs b/src/db/models/user.rs index 4b873e3c96..6032543af8 100644 --- a/src/db/models/user.rs +++ b/src/db/models/user.rs @@ -871,6 +871,27 @@ impl User { } Ok(()) } + + pub async fn find_by_device_id<'e, E>( + executor: E, + device_id: i64, + ) -> Result, SqlxError> + where + E: PgExecutor<'e>, + { + query_as!( + Self, + "SELECT u.id \"id?\", u.username, u.password_hash, u.last_name, u.first_name, u.email, \ + u.phone, u.mfa_enabled, u.totp_enabled, u.email_mfa_enabled, \ + u.totp_secret, u.email_mfa_secret, u.mfa_method \"mfa_method: _\", u.recovery_codes, u.is_active, u.openid_login \ + FROM \"user\" as u \ + JOIN \"device\" as d ON u.id = d.user_id \ + WHERE d.id = $1", + device_id + ) + .fetch_optional(executor) + .await + } } #[cfg(test)] diff --git a/src/enterprise/grpc/mod.rs b/src/enterprise/grpc/mod.rs new file mode 100644 index 0000000000..505916a0a5 --- /dev/null +++ b/src/enterprise/grpc/mod.rs @@ -0,0 +1 @@ +pub mod polling; diff --git a/src/enterprise/grpc/polling.rs b/src/enterprise/grpc/polling.rs new file mode 100644 index 0000000000..f351a89be2 --- /dev/null +++ b/src/enterprise/grpc/polling.rs @@ -0,0 +1,88 @@ +use crate::{ + db::{models::polling_token::PollingToken, DbPool, Device, User}, + enterprise::license::{get_cached_license, validate_license}, + grpc::utils::build_device_config_response, +}; +use tonic::Status; + +use crate::grpc::proto::{InstanceInfoRequest, InstanceInfoResponse}; + +pub struct PollingServer { + pool: DbPool, +} + +impl PollingServer { + #[must_use] + pub fn new(pool: DbPool) -> Self { + Self { pool } + } + + /// Checks validity of polling session + async fn validate_session(&self, token: &str) -> Result { + debug!("Validating polling token. Token: {token}"); + + // Polling service is enterprise-only, check the lincense + if validate_license(get_cached_license().as_ref()).is_err() { + debug!("No valid license, denying instance polling info"); + return Err(Status::permission_denied("no valid license")); + } + + // Validate the token + let Some(token) = PollingToken::find(&self.pool, token).await.map_err(|err| { + error!("Failed to retrieve token: {err}"); + Status::internal("failed to retrieve token") + })? + else { + error!("Invalid token {token:?}"); + return Err(Status::permission_denied("invalid token")); + }; + + // Polling tokens are valid indefinitely + info!("Token validation successful {token:?}."); + Ok(token) + } + + /// Prepares instance info for polling requests. Enterprise only. + pub async fn info(&self, request: InstanceInfoRequest) -> Result { + trace!("Polling info start"); + let token = self.validate_session(&request.token).await?; + let Some(device) = Device::find_by_id(&self.pool, token.device_id) + .await + .map_err(|err| { + error!("Failed to retrieve device id {}: {err}", token.device_id); + Status::internal("failed to retrieve device") + })? + else { + error!("Device id {} not found", token.device_id); + return Err(Status::internal("device not found")); + }; + debug!("Polling info for device: {}", device.wireguard_pubkey); + + // Ensure user is active + let device_id = device.id.expect("missing device id"); + let Some(user) = User::find_by_device_id(&self.pool, device_id) + .await + .map_err(|err| { + error!("Failed to retrieve user for device id {device_id}: {err}"); + Status::internal("failed to retrieve user") + })? + else { + error!("User for device id {device_id} not found"); + return Err(Status::internal("user not found")); + }; + if !user.is_active { + warn!( + "Denying polling info for inactive user {}({:?})", + user.username, user.id + ); + return Err(Status::permission_denied("user inactive")); + } + + // Build & return polling info + let device_config = + build_device_config_response(&self.pool, &device.wireguard_pubkey).await?; + Ok(InstanceInfoResponse { + device_config: Some(device_config), + }) + } +} diff --git a/src/enterprise/license.rs b/src/enterprise/license.rs index 4500bdfb8f..1166675f26 100644 --- a/src/enterprise/license.rs +++ b/src/enterprise/license.rs @@ -488,7 +488,7 @@ async fn renew_license(db_pool: &DbPool) -> Result { /// /// This function checks the following two things: /// 1. Does the cached license exist -/// 2. Does the cached license is past its maximum expiry date +/// 2. Is the cached license past its maximum expiry date pub fn validate_license(license: Option<&License>) -> Result<(), LicenseError> { debug!("Validating if the license is present and not expired..."); match license { diff --git a/src/enterprise/mod.rs b/src/enterprise/mod.rs index 6af234a386..dadca5b4ce 100644 --- a/src/enterprise/mod.rs +++ b/src/enterprise/mod.rs @@ -1,3 +1,4 @@ pub mod db; +pub mod grpc; pub mod handlers; pub mod license; diff --git a/src/grpc/enrollment.rs b/src/grpc/enrollment.rs index 60ce1d536d..7568fa629b 100644 --- a/src/grpc/enrollment.rs +++ b/src/grpc/enrollment.rs @@ -1,7 +1,7 @@ use std::sync::Arc; +use super::InstanceInfo; use ipnetwork::IpNetwork; -use reqwest::Url; use sqlx::Transaction; use tokio::sync::{broadcast::Sender, mpsc::UnboundedSender}; use tonic::Status; @@ -15,13 +15,14 @@ use super::proto::{ use crate::{ db::{ models::{ - device::{DeviceConfig, DeviceInfo, WireguardNetworkDevice}, + device::{DeviceConfig, DeviceInfo}, enrollment::{Token, TokenError, ENROLLMENT_TOKEN_TYPE}, - wireguard::WireguardNetwork, + polling_token::PollingToken, }, DbPool, Device, GatewayEvent, Settings, User, }, enterprise::db::models::enterprise_settings::EnterpriseSettings, + grpc::utils::build_device_config_response, handlers::{mail::send_new_device_added_email, user::check_password_strength}, headers::get_device_info, ldap::utils::ldap_add_user, @@ -38,40 +39,6 @@ pub(super) struct EnrollmentServer { ldap_feature_active: bool, } -#[derive(Debug)] -struct InstanceInfo { - id: uuid::Uuid, - name: String, - url: Url, - proxy_url: Url, - username: String, -} - -impl InstanceInfo { - pub fn new>(settings: Settings, username: S) -> Self { - let config = server_config(); - InstanceInfo { - id: settings.uuid, - name: settings.instance_name, - url: config.url.clone(), - proxy_url: config.enrollment_url.clone(), - username: username.into(), - } - } -} - -impl From for super::proto::InstanceInfo { - fn from(instance: InstanceInfo) -> Self { - Self { - name: instance.name, - id: instance.id.to_string(), - url: instance.url.to_string(), - proxy_url: instance.proxy_url.to_string(), - username: instance.username, - } - } -} - impl EnrollmentServer { #[must_use] pub fn new( @@ -91,22 +58,31 @@ impl EnrollmentServer { } } - // check if token provided with request corresponds to a valid session - async fn validate_session(&self, token: Option<&str>) -> Result { - info!("Start validating session. Token {token:?}"); + /// Checks if token provided with request corresponds to a valid enrollment session + async fn validate_session(&self, token: &Option) -> Result { + info!("Validating enrollment session. Token: {token:?}"); let Some(token) = token else { error!("Missing authorization header in request"); return Err(Status::unauthenticated("Missing authorization header")); }; - debug!("Validating session token: {token}"); - let enrollment = Token::find_by_id(&self.pool, token).await?; - debug!("Verify is token valid {enrollment:?}."); + debug!("Found matching token, verifying validity: {enrollment:?}."); + if !enrollment + .token_type + .as_ref() + .is_some_and(|token_type| token_type == ENROLLMENT_TOKEN_TYPE) + { + error!( + "Invalid token type used in enrollment process: {:?}", + enrollment.token_type + ); + return Err(Status::permission_denied("invalid token")); + } if enrollment.is_session_valid(server_config().enrollment_session_timeout.as_secs()) { - info!("Session validated"); + info!("Enrollment session validated: {enrollment:?}"); Ok(enrollment) } else { - error!("Session expired"); + error!("Enrollment session expired: {enrollment:?}"); Err(Status::unauthenticated("Session expired")) } } @@ -233,7 +209,7 @@ impl EnrollmentServer { req_device_info: Option, ) -> Result<(), Status> { debug!("Activating user account: {request:?}"); - let enrollment = self.validate_session(request.token.as_deref()).await?; + let enrollment = self.validate_session(&request.token).await?; let ip_address; let device_info; @@ -344,7 +320,6 @@ impl EnrollmentServer { })?; info!("User {} activated", user.username); - Ok(()) } @@ -354,7 +329,7 @@ impl EnrollmentServer { req_device_info: Option, ) -> Result { debug!("Adding new user device: {request:?}"); - let enrollment = self.validate_session(request.token.as_deref()).await?; + let enrollment = self.validate_session(&request.token).await?; // fetch related users let user = enrollment.fetch_user(&self.pool).await?; @@ -447,6 +422,17 @@ impl EnrollmentServer { })?; debug!("Settings {settings:?}"); + // create polling token for further client communication + debug!("Creating polling token for further client communication"); + let mut token = PollingToken::new(device.id.ok_or_else(|| { + error!("No device id"); + Status::internal("unexpected error") + })?); + token.save(&mut *transaction).await.map_err(|err| { + error!("Failed to save PollingToken: {err}"); + Status::internal("failed to save polling token") + })?; + transaction.commit().await.map_err(|_| { error!("Failed to commit transaction"); Status::internal("unexpected error") @@ -481,94 +467,21 @@ impl EnrollmentServer { device: Some(device.into()), configs: configs.into_iter().map(Into::into).collect(), instance: Some(InstanceInfo::new(settings, &user.username).into()), + token: Some(token.token), }; debug!("Created a create device response {response:?}."); Ok(response) } - /// Get all information needed - /// to update instance information for desktop client + /// Get all information needed to update instance information for desktop client pub async fn get_network_info( &self, request: ExistingDevice, ) -> Result { debug!("Getting network info for device: {:?}", request.pubkey); - let enrollment = self.validate_session(request.token.as_deref()).await?; - - // get enrollment user - let user = enrollment.fetch_user(&self.pool).await?; - - Device::validate_pubkey(&request.pubkey).map_err(|_| { - error!("Invalid pubkey {}", request.pubkey); - Status::invalid_argument("invalid pubkey") - })?; - // Find existing device by public key - let device = Device::find_by_pubkey(&self.pool, &request.pubkey) - .await - .map_err(|_| { - error!("Failed to get device by its pubkey: {}", request.pubkey); - Status::internal("unexpected error") - })?; - - let settings = Settings::get_settings(&self.pool).await.map_err(|_| { - error!("Failed to get settings"); - Status::internal("unexpected error") - })?; - - let networks = WireguardNetwork::all(&self.pool).await.map_err(|err| { - error!("Failed to fetch all networks: {err}"); - Status::internal(format!("unexpected error: {err}")) - })?; - - let mut configs: Vec = Vec::new(); - if let Some(device) = device { - for network in networks { - let (Some(device_id), Some(network_id)) = (device.id, network.id) else { - continue; - }; - let wireguard_network_device = - WireguardNetworkDevice::find(&self.pool, device_id, network_id) - .await - .map_err(|err| { - error!("Failed to fetch wireguard network device for device {} and network {}: {err}", device_id, network_id); - Status::internal(format!("unexpected error: {err}")) - })?; - if let Some(wireguard_network_device) = wireguard_network_device { - let allowed_ips = network - .allowed_ips - .iter() - .map(IpNetwork::to_string) - .collect::>() - .join(","); - let config = ProtoDeviceConfig { - config: device.create_config(&network, &wireguard_network_device), - network_id, - network_name: network.name, - assigned_ip: wireguard_network_device.wireguard_ip.to_string(), - endpoint: format!("{}:{}", network.endpoint, network.port), - pubkey: network.pubkey, - allowed_ips, - dns: network.dns, - mfa_enabled: network.mfa_enabled, - keepalive_interval: network.keepalive_interval, - }; - configs.push(config); - } - } - - info!("Device {} configs fetched", device.name); - - let response = DeviceConfigResponse { - device: Some(device.into()), - configs, - instance: Some(InstanceInfo::new(settings, &user.username).into()), - }; - - Ok(response) - } else { - Err(Status::internal("device not found error")) - } + let _token = self.validate_session(&request.token).await?; + build_device_config_response(&self.pool, &request.pubkey).await } } diff --git a/src/grpc/mod.rs b/src/grpc/mod.rs index 5f97f606b3..a7be80e6c5 100644 --- a/src/grpc/mod.rs +++ b/src/grpc/mod.rs @@ -9,7 +9,9 @@ use std::{ sync::{Arc, Mutex}, }; +use crate::enterprise::grpc::polling::PollingServer; use chrono::{Duration as ChronoDuration, NaiveDateTime, Utc}; +use reqwest::Url; use serde::Serialize; use thiserror::Error; use tokio::{ @@ -42,8 +44,11 @@ use self::{ worker::{worker_service_server::WorkerServiceServer, WorkerServer}, }; use crate::{ - auth::failed_login::FailedLoginMap, db::AppEvent, - handlers::mail::send_gateway_disconnected_email, mail::Mail, server_config, + auth::failed_login::FailedLoginMap, + db::{AppEvent, Settings}, + handlers::mail::send_gateway_disconnected_email, + mail::Mail, + server_config, }; #[cfg(feature = "worker")] use crate::{ @@ -59,6 +64,7 @@ pub(crate) mod gateway; #[cfg(any(feature = "wireguard", feature = "worker"))] mod interceptor; pub mod password_reset; +pub(crate) mod utils; #[cfg(feature = "worker")] pub mod worker; @@ -355,7 +361,8 @@ pub async fn run_grpc_bidi_stream( user_agent_parser, ); let password_reset_server = PasswordResetServer::new(pool.clone(), mail_tx.clone()); - let mut client_mfa_server = ClientMfaServer::new(pool, mail_tx, wireguard_tx); + let mut client_mfa_server = ClientMfaServer::new(pool.clone(), mail_tx, wireguard_tx); + let polling_server = PollingServer::new(pool); let endpoint = Endpoint::from_shared(config.proxy_url.as_deref().unwrap())?; let endpoint = endpoint @@ -505,6 +512,18 @@ pub async fn run_grpc_bidi_stream( } } } + // rpc LocationInfo (LocationInfoRequest) returns (LocationInfoResponse) + Some(core_request::Payload::InstanceInfo(request)) => { + match polling_server.info(request).await { + Ok(response_payload) => { + Some(core_response::Payload::InstanceInfo(response_payload)) + } + Err(err) => { + error!("Instance info error {err}"); + Some(core_response::Payload::CoreError(err.into())) + } + } + } // Reply without payload. None => None, }; @@ -619,3 +638,37 @@ pub struct WorkerDetail { ip: IpAddr, connected: bool, } + +#[derive(Debug)] +pub struct InstanceInfo { + id: uuid::Uuid, + name: String, + url: Url, + proxy_url: Url, + username: String, +} + +impl InstanceInfo { + pub fn new>(settings: Settings, username: S) -> Self { + let config = server_config(); + InstanceInfo { + id: settings.uuid, + name: settings.instance_name, + url: config.url.clone(), + proxy_url: config.enrollment_url.clone(), + username: username.into(), + } + } +} + +impl From for crate::grpc::proto::InstanceInfo { + fn from(instance: InstanceInfo) -> Self { + Self { + name: instance.name, + id: instance.id.to_string(), + url: instance.url.to_string(), + proxy_url: instance.proxy_url.to_string(), + username: instance.username, + } + } +} diff --git a/src/grpc/password_reset.rs b/src/grpc/password_reset.rs index 98e021bdda..bea8fc54be 100644 --- a/src/grpc/password_reset.rs +++ b/src/grpc/password_reset.rs @@ -37,26 +37,33 @@ impl PasswordResetServer { } } - // check if token provided with request corresponds to a valid enrollment session - async fn validate_session(&self, token: Option<&str>) -> Result { - debug!("Validating enrollment session"); + /// Checks if token provided with request corresponds to a valid password reset session + async fn validate_session(&self, token: &Option) -> Result { + info!("Validating password reset session. Token: {token:?}"); let Some(token) = token else { error!("Missing authorization header in request"); return Err(Status::unauthenticated("Missing authorization header")); }; - - debug!("Validating enrollment session token: {token}"); let enrollment = Token::find_by_id(&self.pool, token).await?; + debug!("Found matching token, verifying validity: {enrollment:?}."); + if !enrollment + .token_type + .as_ref() + .is_some_and(|token_type| token_type == PASSWORD_RESET_TOKEN_TYPE) + { + error!( + "Invalid token type used in password reset process: {:?}", + enrollment.token_type + ); + return Err(Status::permission_denied("invalid token")); + } if enrollment.is_session_valid(server_config().enrollment_session_timeout.as_secs()) { - info!( - "Enrollment session validated for user {}.", - enrollment.user_id - ); + info!("Password reset session validated: {enrollment:?}.",); Ok(enrollment) } else { - error!("Enrollment session expired"); - Err(Status::unauthenticated("Enrollment session expired")) + error!("Password reset session expired: {enrollment:?}"); + Err(Status::unauthenticated("Session expired")) } } @@ -207,7 +214,7 @@ impl PasswordResetServer { req_device_info: Option, ) -> Result<(), Status> { debug!("Starting password reset: {request:?}"); - let enrollment = self.validate_session(request.token.as_deref()).await?; + let enrollment = self.validate_session(&request.token).await?; let ip_address; let user_agent; diff --git a/src/grpc/utils.rs b/src/grpc/utils.rs new file mode 100644 index 0000000000..f19a1bb555 --- /dev/null +++ b/src/grpc/utils.rs @@ -0,0 +1,92 @@ +use super::InstanceInfo; +use ipnetwork::IpNetwork; +use tonic::Status; + +use super::proto::{DeviceConfig as ProtoDeviceConfig, DeviceConfigResponse}; +use crate::db::{ + models::{device::WireguardNetworkDevice, wireguard::WireguardNetwork}, + DbPool, Device, Settings, User, +}; + +pub(crate) async fn build_device_config_response( + pool: &DbPool, + pubkey: &str, +) -> Result { + Device::validate_pubkey(pubkey).map_err(|_| { + error!("Invalid pubkey {pubkey}"); + Status::invalid_argument("invalid pubkey") + })?; + // Find existing device by public key + let device = Device::find_by_pubkey(pool, pubkey).await.map_err(|_| { + error!("Failed to get device by its pubkey: {pubkey}"); + Status::internal("unexpected error") + })?; + let settings = Settings::get_settings(pool).await.map_err(|_| { + error!("Failed to get settings"); + Status::internal("unexpected error") + })?; + + let networks = WireguardNetwork::all(pool).await.map_err(|err| { + error!("Failed to fetch all networks: {err}"); + Status::internal(format!("unexpected error: {err}")) + })?; + + let mut configs: Vec = Vec::new(); + let Some(device) = device else { + return Err(Status::internal("device not found error")); + }; + let user = User::find_by_id(pool, device.user_id) + .await + .map_err(|_| { + error!("Failed to get user: {}", device.user_id); + Status::internal("unexpected error") + })? + .ok_or_else(|| { + error!("User not found: {}", device.user_id); + Status::internal("unexpected error") + })?; + for network in networks { + let (Some(device_id), Some(network_id)) = (device.id, network.id) else { + continue; + }; + let wireguard_network_device = WireguardNetworkDevice::find(pool, device_id, network_id) + .await + .map_err(|err| { + error!( + "Failed to fetch wireguard network device for device {} and network {}: {err}", + device_id, network_id + ); + Status::internal(format!("unexpected error: {err}")) + })?; + if let Some(wireguard_network_device) = wireguard_network_device { + let allowed_ips = network + .allowed_ips + .iter() + .map(IpNetwork::to_string) + .collect::>() + .join(","); + let config = ProtoDeviceConfig { + config: device.create_config(&network, &wireguard_network_device), + network_id, + network_name: network.name, + assigned_ip: wireguard_network_device.wireguard_ip.to_string(), + endpoint: format!("{}:{}", network.endpoint, network.port), + pubkey: network.pubkey, + allowed_ips, + dns: network.dns, + mfa_enabled: network.mfa_enabled, + keepalive_interval: network.keepalive_interval, + }; + configs.push(config); + } + } + + info!("Device {} configs fetched", device.name); + + Ok(DeviceConfigResponse { + device: Some(device.into()), + configs, + instance: Some(InstanceInfo::new(settings, &user.username).into()), + token: None, + }) +}