diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 629c753d..3710f64c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -25,7 +25,7 @@ jobs: test: runs-on: - codebuild-defguard-gateway-runner-${{ github.run_id }}-${{ github.run_attempt }} - container: rust:1 + container: public.ecr.aws/docker/library/rust:1 steps: - name: Debug diff --git a/Cargo.lock b/Cargo.lock index d0286c1b..9336521d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -401,6 +401,7 @@ dependencies = [ "tonic", "tonic-prost", "tonic-prost-build", + "tower", "tracing", "vergen-git2", "x25519-dalek", @@ -409,7 +410,7 @@ dependencies = [ [[package]] name = "defguard_version" version = "0.0.0" -source = "git+https://github.com/DefGuard/defguard.git?rev=f61ce40927a4d21095ea53a691219d5ae46e3e4e#f61ce40927a4d21095ea53a691219d5ae46e3e4e" +source = "git+https://github.com/DefGuard/defguard.git?rev=a5709e7117103458ad8417d4437a8a369ca5bbce#a5709e7117103458ad8417d4437a8a369ca5bbce" dependencies = [ "http", "os_info", diff --git a/Cargo.toml b/Cargo.toml index 96915c03..b7d63934 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,7 @@ version = "1.5.0" edition = "2021" [dependencies] -defguard_version = { git = "https://github.com/DefGuard/defguard.git", rev = "f61ce40927a4d21095ea53a691219d5ae46e3e4e" } +defguard_version = { git = "https://github.com/DefGuard/defguard.git", rev = "a5709e7117103458ad8417d4437a8a369ca5bbce" } axum = { version = "0.8", features = ["macros"] } base64 = "0.22" clap = { version = "4.5", features = ["derive", "env"] } @@ -32,6 +32,7 @@ tonic = { version = "0.14", default-features = false, features = [ ] } tracing = "0.1" tonic-prost = "0.14" +tower = "0.5.2" [target.'cfg(target_os = "linux")'.dependencies] nftnl = { git = "https://github.com/DefGuard/nftnl-rs.git", rev = "1a1147271f43b9d7182a114bb056a5224c35d38f" } diff --git a/src/gateway.rs b/src/gateway.rs index e5155881..4f3d78e8 100644 --- a/src/gateway.rs +++ b/src/gateway.rs @@ -1,5 +1,6 @@ use defguard_version::{ - client::version_interceptor, parse_metadata, ComponentInfo, DefguardComponent, Version, + client::ClientVersionInterceptor, get_tracing_variables, parse_metadata, ComponentInfo, + DefguardComponent, Version, }; use defguard_wireguard_rs::{net::IpAddrMask, WireguardInterfaceApi}; use gethostname::gethostname; @@ -23,10 +24,11 @@ use tokio_stream::wrappers::UnboundedReceiverStream; use tonic::{ codegen::InterceptedService, metadata::{Ascii, MetadataValue}, - service::Interceptor, + service::{Interceptor, InterceptorLayer}, transport::{Certificate, Channel, ClientTlsConfig, Endpoint}, Request, Status, Streaming, }; +use tower::ServiceBuilder; use tracing::{instrument, Instrument}; use crate::{ @@ -41,6 +43,7 @@ use crate::{ gateway_service_client::GatewayServiceClient, stats_update::Payload, update, Configuration, ConfigurationRequest, Peer, StatsUpdate, Update, }, + version::ensure_core_version_supported, VERSION, }; @@ -72,29 +75,23 @@ impl From for InterfaceConfiguration { } } -type InterceptorFn = Box) -> Result, Status> + Send + Sync>; - /// Intercepts all grpc requests adding authentication and version metadata -struct RequestInterceptor { +struct AuthInterceptor { hostname: MetadataValue, token: MetadataValue, - version: defguard_version::Version, - version_interceptor_fn: InterceptorFn, } -impl Clone for RequestInterceptor { +impl Clone for AuthInterceptor { fn clone(&self) -> Self { Self { hostname: self.hostname.clone(), token: self.token.clone(), - version: self.version.clone(), - version_interceptor_fn: Box::new(version_interceptor(self.version.clone())), } } } -impl RequestInterceptor { - fn new(token: &str, version: Version) -> Result { +impl AuthInterceptor { + fn new(token: &str) -> Result { let token = MetadataValue::try_from(token)?; let hostname = MetadataValue::try_from( gethostname() @@ -102,20 +99,12 @@ impl RequestInterceptor { .expect("Unable to get current hostname during gRPC connection setup."), )?; - Ok(Self { - hostname, - token, - version: version.clone(), - version_interceptor_fn: Box::new(version_interceptor(version)), - }) + Ok(Self { hostname, token }) } } -impl Interceptor for RequestInterceptor { - fn call(&mut self, request: Request<()>) -> Result, Status> { - // Apply version interceptor - adds version headers - let mut request = (self.version_interceptor_fn)(request)?; - +impl Interceptor for AuthInterceptor { + fn call(&mut self, mut request: Request<()>) -> Result, Status> { // Add auth headers let metadata = request.metadata_mut(); metadata.insert("authorization", self.token.clone()); @@ -126,6 +115,9 @@ impl Interceptor for RequestInterceptor { } type PubKey = String; +type GatewayClientType = GatewayServiceClient< + InterceptedService, ClientVersionInterceptor>, +>; pub struct Gateway { config: Config, @@ -137,7 +129,7 @@ pub struct Gateway { #[cfg_attr(not(target_os = "linux"), allow(unused))] firewall_config: Option, pub connected: Arc, - client: GatewayServiceClient>, + client: GatewayClientType, core_info: Option, stats_thread: Option>, } @@ -282,7 +274,7 @@ impl Gateway { #[instrument(skip_all)] async fn handle_stats_thread( - mut client: GatewayServiceClient>, + mut client: GatewayClientType, rx: UnboundedReceiverStream, ) { let status = client.stats(rx).await; @@ -463,19 +455,6 @@ impl Gateway { Ok(()) } - fn get_tracing_variables(&self) -> (String, String) { - let version = self - .core_info - .as_ref() - .map_or(String::from("?"), |info| info.version.to_string()); - let info = self - .core_info - .as_ref() - .map_or(String::from("?"), |info| info.system.to_string()); - - (version, info) - } - /// Continuously tries to connect to gRPC endpoint. Once the connection is established /// configures the interface, starts the stats thread, connects and returns the updates stream. async fn connect(&mut self) -> Streaming { @@ -499,7 +478,7 @@ impl Gateway { match (response, stream) { (Ok(response), Ok(stream)) => { self.core_info = parse_metadata(response.metadata()); - let (version, info) = self.get_tracing_variables(); + let (version, info) = get_tracing_variables(&self.core_info); let span = tracing::info_span!( "core_configuration", component = %DefguardComponent::Core, @@ -508,6 +487,10 @@ impl Gateway { ); let _guard = span.enter(); + // check core version and exit if it's not supported + let version = self.core_info.as_ref().map(|info| &info.version); + ensure_core_version_supported(version); + if let Err(err) = self.configure(response.into_inner()) { error!("Interface configuration failed: {err}"); continue; @@ -532,10 +515,7 @@ impl Gateway { } } - fn setup_client( - config: &Config, - ) -> Result>, GatewayError> - { + fn setup_client(config: &Config) -> Result { debug!("Preparing gRPC client configuration"); let tls = ClientTlsConfig::new(); // Use CA if provided, otherwise load certificates from system. @@ -554,9 +534,13 @@ impl Gateway { .keep_alive_while_idle(true) .tls_config(tls)?; let channel = endpoint.connect_lazy(); - let version = Version::parse(VERSION)?; - let request_interceptor = RequestInterceptor::new(&config.token, version)?; - let client = GatewayServiceClient::with_interceptor(channel, request_interceptor); + let version_interceptor = ClientVersionInterceptor::new(Version::parse(VERSION)?); + let auth_interceptor = AuthInterceptor::new(&config.token)?; + let channel = ServiceBuilder::new() + .layer(InterceptorLayer::new(version_interceptor)) + .layer(InterceptorLayer::new(auth_interceptor)) + .service(channel); + let client = GatewayServiceClient::new(channel); debug!("gRPC client configuration done"); Ok(client) @@ -700,7 +684,7 @@ impl Gateway { debug!("Executing specified POST_UP command: {post_up}"); execute_command(post_up)?; } - let (version, info) = self.get_tracing_variables(); + let (version, info) = get_tracing_variables(&self.core_info); let span = tracing::info_span!( "core_grpc", component = %DefguardComponent::Core, diff --git a/src/lib.rs b/src/lib.rs index 251696cf..66c5ef14 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,6 +2,7 @@ pub mod config; pub mod error; pub mod gateway; pub mod server; +mod version; pub mod proto { pub mod gateway { @@ -26,7 +27,7 @@ use syslog::{BasicLogger, Facility, Formatter3164}; pub mod enterprise; -pub const VERSION: &str = concat!(env!("CARGO_PKG_VERSION"), "-", env!("VERGEN_GIT_SHA")); +pub const VERSION: &str = concat!(env!("CARGO_PKG_VERSION"), "+", env!("VERGEN_GIT_SHA")); /// Masks object's field with "***" string. /// Used to log sensitive/secret objects. diff --git a/src/version.rs b/src/version.rs new file mode 100644 index 00000000..dcd9e985 --- /dev/null +++ b/src/version.rs @@ -0,0 +1,18 @@ +use defguard_version::Version; + +const MIN_CORE_VERSION: Version = Version::new(1, 5, 0); + +/// Ensures the core version meets minimum version requirements. +/// Terminates the process if it doesn't. +pub(crate) fn ensure_core_version_supported(core_version: Option<&Version>) { + let Some(core_version) = core_version else { + error!("Missing core component version information. This most likely means that core component uses unsupported version. Exiting."); + std::process::exit(1); + }; + if core_version < &MIN_CORE_VERSION { + error!("Core version {core_version} is not supported. Minimal supported core version is {MIN_CORE_VERSION}. Exiting."); + std::process::exit(1); + } + + info!("Core version {core_version} is supported"); +}