Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
test:
runs-on:
- codebuild-defguard-gateway-runner-${{ github.run_id }}-${{ github.run_attempt }}
container: rust:1
container: public.ecr.aws/docker/library/rust:1

steps:
- name: Debug
Expand Down
3 changes: 2 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ version = "1.5.0"
edition = "2021"

[dependencies]
defguard_version = { git = "https://github.com/DefGuard/defguard.git", rev = "f61ce40927a4d21095ea53a691219d5ae46e3e4e" }
defguard_version = { git = "https://github.com/DefGuard/defguard.git", rev = "a5709e7117103458ad8417d4437a8a369ca5bbce" }
axum = { version = "0.8", features = ["macros"] }
base64 = "0.22"
clap = { version = "4.5", features = ["derive", "env"] }
Expand Down Expand Up @@ -32,6 +32,7 @@ tonic = { version = "0.14", default-features = false, features = [
] }
tracing = "0.1"
tonic-prost = "0.14"
tower = "0.5.2"

[target.'cfg(target_os = "linux")'.dependencies]
nftnl = { git = "https://github.com/DefGuard/nftnl-rs.git", rev = "1a1147271f43b9d7182a114bb056a5224c35d38f" }
Expand Down
78 changes: 31 additions & 47 deletions src/gateway.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use defguard_version::{
client::version_interceptor, parse_metadata, ComponentInfo, DefguardComponent, Version,
client::ClientVersionInterceptor, get_tracing_variables, parse_metadata, ComponentInfo,
DefguardComponent, Version,
};
use defguard_wireguard_rs::{net::IpAddrMask, WireguardInterfaceApi};
use gethostname::gethostname;
Expand All @@ -23,10 +24,11 @@ use tokio_stream::wrappers::UnboundedReceiverStream;
use tonic::{
codegen::InterceptedService,
metadata::{Ascii, MetadataValue},
service::Interceptor,
service::{Interceptor, InterceptorLayer},
transport::{Certificate, Channel, ClientTlsConfig, Endpoint},
Request, Status, Streaming,
};
use tower::ServiceBuilder;
use tracing::{instrument, Instrument};

use crate::{
Expand All @@ -41,6 +43,7 @@ use crate::{
gateway_service_client::GatewayServiceClient, stats_update::Payload, update, Configuration,
ConfigurationRequest, Peer, StatsUpdate, Update,
},
version::ensure_core_version_supported,
VERSION,
};

Expand Down Expand Up @@ -72,50 +75,36 @@ impl From<Configuration> for InterfaceConfiguration {
}
}

type InterceptorFn = Box<dyn Fn(Request<()>) -> Result<Request<()>, Status> + Send + Sync>;

/// Intercepts all grpc requests adding authentication and version metadata
struct RequestInterceptor {
struct AuthInterceptor {
hostname: MetadataValue<Ascii>,
token: MetadataValue<Ascii>,
version: defguard_version::Version,
version_interceptor_fn: InterceptorFn,
}

impl Clone for RequestInterceptor {
impl Clone for AuthInterceptor {
fn clone(&self) -> Self {
Self {
hostname: self.hostname.clone(),
token: self.token.clone(),
version: self.version.clone(),
version_interceptor_fn: Box::new(version_interceptor(self.version.clone())),
}
}
}

impl RequestInterceptor {
fn new(token: &str, version: Version) -> Result<Self, GatewayError> {
impl AuthInterceptor {
fn new(token: &str) -> Result<Self, GatewayError> {
let token = MetadataValue::try_from(token)?;
let hostname = MetadataValue::try_from(
gethostname()
.to_str()
.expect("Unable to get current hostname during gRPC connection setup."),
)?;

Ok(Self {
hostname,
token,
version: version.clone(),
version_interceptor_fn: Box::new(version_interceptor(version)),
})
Ok(Self { hostname, token })
}
}

impl Interceptor for RequestInterceptor {
fn call(&mut self, request: Request<()>) -> Result<Request<()>, Status> {
// Apply version interceptor - adds version headers
let mut request = (self.version_interceptor_fn)(request)?;

impl Interceptor for AuthInterceptor {
fn call(&mut self, mut request: Request<()>) -> Result<Request<()>, Status> {
// Add auth headers
let metadata = request.metadata_mut();
metadata.insert("authorization", self.token.clone());
Expand All @@ -126,6 +115,9 @@ impl Interceptor for RequestInterceptor {
}

type PubKey = String;
type GatewayClientType = GatewayServiceClient<
InterceptedService<InterceptedService<Channel, AuthInterceptor>, ClientVersionInterceptor>,
>;

pub struct Gateway {
config: Config,
Expand All @@ -137,7 +129,7 @@ pub struct Gateway {
#[cfg_attr(not(target_os = "linux"), allow(unused))]
firewall_config: Option<FirewallConfig>,
pub connected: Arc<AtomicBool>,
client: GatewayServiceClient<InterceptedService<Channel, RequestInterceptor>>,
client: GatewayClientType,
core_info: Option<ComponentInfo>,
stats_thread: Option<JoinHandle<()>>,
}
Expand Down Expand Up @@ -282,7 +274,7 @@ impl Gateway {

#[instrument(skip_all)]
async fn handle_stats_thread(
mut client: GatewayServiceClient<InterceptedService<Channel, RequestInterceptor>>,
mut client: GatewayClientType,
rx: UnboundedReceiverStream<StatsUpdate>,
) {
let status = client.stats(rx).await;
Expand Down Expand Up @@ -463,19 +455,6 @@ impl Gateway {
Ok(())
}

fn get_tracing_variables(&self) -> (String, String) {
let version = self
.core_info
.as_ref()
.map_or(String::from("?"), |info| info.version.to_string());
let info = self
.core_info
.as_ref()
.map_or(String::from("?"), |info| info.system.to_string());

(version, info)
}

/// Continuously tries to connect to gRPC endpoint. Once the connection is established
/// configures the interface, starts the stats thread, connects and returns the updates stream.
async fn connect(&mut self) -> Streaming<Update> {
Expand All @@ -499,7 +478,7 @@ impl Gateway {
match (response, stream) {
(Ok(response), Ok(stream)) => {
self.core_info = parse_metadata(response.metadata());
let (version, info) = self.get_tracing_variables();
let (version, info) = get_tracing_variables(&self.core_info);
let span = tracing::info_span!(
"core_configuration",
component = %DefguardComponent::Core,
Expand All @@ -508,6 +487,10 @@ impl Gateway {
);
let _guard = span.enter();

// check core version and exit if it's not supported
let version = self.core_info.as_ref().map(|info| &info.version);
ensure_core_version_supported(version);

if let Err(err) = self.configure(response.into_inner()) {
error!("Interface configuration failed: {err}");
continue;
Expand All @@ -532,10 +515,7 @@ impl Gateway {
}
}

fn setup_client(
config: &Config,
) -> Result<GatewayServiceClient<InterceptedService<Channel, RequestInterceptor>>, GatewayError>
{
fn setup_client(config: &Config) -> Result<GatewayClientType, GatewayError> {
debug!("Preparing gRPC client configuration");
let tls = ClientTlsConfig::new();
// Use CA if provided, otherwise load certificates from system.
Expand All @@ -554,9 +534,13 @@ impl Gateway {
.keep_alive_while_idle(true)
.tls_config(tls)?;
let channel = endpoint.connect_lazy();
let version = Version::parse(VERSION)?;
let request_interceptor = RequestInterceptor::new(&config.token, version)?;
let client = GatewayServiceClient::with_interceptor(channel, request_interceptor);
let version_interceptor = ClientVersionInterceptor::new(Version::parse(VERSION)?);
let auth_interceptor = AuthInterceptor::new(&config.token)?;
let channel = ServiceBuilder::new()
.layer(InterceptorLayer::new(version_interceptor))
.layer(InterceptorLayer::new(auth_interceptor))
.service(channel);
let client = GatewayServiceClient::new(channel);

debug!("gRPC client configuration done");
Ok(client)
Expand Down Expand Up @@ -700,7 +684,7 @@ impl Gateway {
debug!("Executing specified POST_UP command: {post_up}");
execute_command(post_up)?;
}
let (version, info) = self.get_tracing_variables();
let (version, info) = get_tracing_variables(&self.core_info);
let span = tracing::info_span!(
"core_grpc",
component = %DefguardComponent::Core,
Expand Down
3 changes: 2 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ pub mod config;
pub mod error;
pub mod gateway;
pub mod server;
mod version;

pub mod proto {
pub mod gateway {
Expand All @@ -26,7 +27,7 @@ use syslog::{BasicLogger, Facility, Formatter3164};

pub mod enterprise;

pub const VERSION: &str = concat!(env!("CARGO_PKG_VERSION"), "-", env!("VERGEN_GIT_SHA"));
pub const VERSION: &str = concat!(env!("CARGO_PKG_VERSION"), "+", env!("VERGEN_GIT_SHA"));

/// Masks object's field with "***" string.
/// Used to log sensitive/secret objects.
Expand Down
18 changes: 18 additions & 0 deletions src/version.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
use defguard_version::Version;

const MIN_CORE_VERSION: Version = Version::new(1, 5, 0);

/// Ensures the core version meets minimum version requirements.
/// Terminates the process if it doesn't.
pub(crate) fn ensure_core_version_supported(core_version: Option<&Version>) {
let Some(core_version) = core_version else {
error!("Missing core component version information. This most likely means that core component uses unsupported version. Exiting.");
std::process::exit(1);
};
if core_version < &MIN_CORE_VERSION {
error!("Core version {core_version} is not supported. Minimal supported core version is {MIN_CORE_VERSION}. Exiting.");
std::process::exit(1);
}

info!("Core version {core_version} is supported");
}