Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
465 changes: 217 additions & 248 deletions Cargo.lock

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ regex = { version = "1.12" }
reqwest = { version = "0.12", features = ["json", "http2", "gzip", "deflate"] }
rmp = { version = "0.8" }
schemars = { version = "1.1" }
sea-orm = { version = "1.1", features = ["sqlx-mysql", "sqlx-postgres", "runtime-tokio", "runtime-tokio-native-tls"] }
sea-orm = { version = "1.1", features = ["debug-print", "sqlx-mysql", "sqlx-postgres", "runtime-tokio", "runtime-tokio-native-tls"] }
sea-orm-migration = { version = "1.1", features = ["sqlx-mysql", "sqlx-postgres", "runtime-tokio"] }
secrecy = { version = "0.10", features = ["serde"] }
serde = { version = "1.0" }
serde_bytes = { version = "0.11" }
serde_json = { version = "1.0" }
Expand All @@ -76,6 +77,7 @@ hyper = { version = "1.8", features = ["http1"] }
hyper-util = { version = "0.1", features = ["tokio", "http1"] }
keycloak = { version = "26.4" }
mockall = { version = "0.14" }
rand = "0.9"
reqwest = { version = "0.12", features = ["json", "multipart"] }
sea-orm = { version = "1.1", features = ["mock", "sqlx-sqlite" ]}
serde_urlencoded = { version = "0.7" }
Expand Down
19 changes: 12 additions & 7 deletions src/bin/keystone.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
//!
//! This is the entry point of the `keystone` binary.

use axum::extract::DefaultBodyLimit;
use axum::http::{self, HeaderName, Request, header};
use clap::{Parser, ValueEnum};
use color_eyre::eyre::{Report, Result};
use eyre::WrapErr;
use sea_orm::ConnectOptions;
use sea_orm::Database;
use sea_orm::{ConnectOptions, Database};
use secrecy::ExposeSecret;
use std::io;
use std::net::{Ipv4Addr, SocketAddr};
use std::path::PathBuf;
Expand All @@ -36,6 +37,7 @@ use tower_http::{
};
use tracing::{Level, debug, error, info, info_span, trace};
use tracing_subscriber::{
Layer,
filter::{LevelFilter, Targets},
prelude::*,
};
Expand All @@ -53,6 +55,9 @@ use openstack_keystone::plugin_manager::PluginManager;
use openstack_keystone::policy::PolicyFactory;
use openstack_keystone::provider::Provider;

// Default body limit 256kB
const DEFAULT_BODY_LIMIT: usize = 1024 * 256;

/// `OpenStack` Keystone.
///
/// Keystone is an `OpenStack` service that provides API client authentication, service discovery,
Expand Down Expand Up @@ -144,11 +149,10 @@ async fn main() -> Result<(), Report> {
let cloned_token = token.clone();

let cfg = Config::new(args.config)?;
let db_url = cfg.database.get_connection();
let mut opt = ConnectOptions::new(db_url.clone());
if args.verbose < 2 {
opt.sqlx_logging(false);
}
let opt: ConnectOptions = ConnectOptions::new(cfg.database.get_connection().expose_secret())
// Prevent dumping the password in plaintext.
.sqlx_logging(false)
.to_owned();

debug!("Establishing the database connection...");
let conn = Database::connect(opt)
Expand Down Expand Up @@ -192,6 +196,7 @@ async fn main() -> Result<(), Report> {
))
//.layer(PropagateRequestIdLayer::new(x_request_id))
.sensitive_request_headers(sensitive_headers.clone())
.layer(DefaultBodyLimit::max(DEFAULT_BODY_LIMIT))
.layer(
TraceLayer::new_for_http()
//.make_span_with(DefaultMakeSpan::new().include_headers(true))
Expand Down
11 changes: 5 additions & 6 deletions src/bin/keystone_db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
use clap::{Parser, Subcommand};
use color_eyre::Report;
use eyre::WrapErr;
use secrecy::ExposeSecret;
use std::io;
use std::path::PathBuf;
use tracing::info;
Expand Down Expand Up @@ -87,12 +88,10 @@ async fn main() -> Result<(), Report> {
// build the tracing registry
tracing_subscriber::registry().with(log_layer).init();
let cfg = Config::new(cli.config)?;
let db_url = cfg.database.get_connection();
let mut opt = ConnectOptions::new(db_url.clone());

if cli.verbose < 2 {
opt.sqlx_logging(false);
}
let opt: ConnectOptions = ConnectOptions::new(cfg.database.get_connection().expose_secret())
// Prevent dumping the password in plaintext.
.sqlx_logging(false)
.to_owned();

info!("Establishing the database connection...");
let conn = Database::connect(opt)
Expand Down
16 changes: 10 additions & 6 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
use config::{File, FileFormat};
use eyre::{Report, WrapErr};
use regex::Regex;
use secrecy::{ExposeSecret, SecretString};
use serde::{Deserialize, Deserializer};
use std::collections::HashMap;
use std::path::PathBuf;
Expand Down Expand Up @@ -111,14 +112,16 @@ pub struct FernetTokenSection {

#[derive(Debug, Default, Deserialize, Clone)]
pub struct DatabaseSection {
pub connection: String,
/// Database URL.
pub connection: SecretString,
}

impl DatabaseSection {
pub fn get_connection(&self) -> String {
if self.connection.contains("+") {
pub fn get_connection(&self) -> SecretString {
let val = self.connection.expose_secret();
if val.contains("+") {
return Regex::new(r"(?<type>\w+)\+(\w+)://")
.map(|re| re.replace(&self.connection, "${type}://").to_string())
.map(|re| SecretString::from(re.replace(val, "${type}://").to_string()))
.unwrap_or(self.connection.clone());
}
self.connection.clone()
Expand Down Expand Up @@ -321,16 +324,17 @@ impl TryFrom<config::ConfigBuilder<config::builder::DefaultState>> for Config {
#[cfg(test)]
mod tests {
use super::*;
use secrecy::ExposeSecret;

#[test]
fn test_db_connection() {
let sot = DatabaseSection {
connection: "mysql://u:p@h".into(),
};
assert_eq!("mysql://u:p@h", sot.get_connection());
assert_eq!("mysql://u:p@h", sot.get_connection().expose_secret());
let sot = DatabaseSection {
connection: "mysql+driver://u:p@h".into(),
};
assert_eq!("mysql://u:p@h", sot.get_connection());
assert_eq!("mysql://u:p@h", sot.get_connection().expose_secret());
}
}
6 changes: 4 additions & 2 deletions src/identity/backends/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,9 @@ impl IdentityBackend for SqlBackend {
{
let user_opts = user_option::get(&state.db, local_user.user_id.clone()).await?;

if password_hashing::verify_password(&self.config, auth.password, expected_hash)? {
if password_hashing::verify_password(&self.config, auth.password, expected_hash)
.await?
{
if let Some(user) = user::get(&state.db, &local_user.user_id).await? {
// TODO: Check password is expired
// TODO: reset failed login attempt
Expand Down Expand Up @@ -574,7 +576,7 @@ async fn create_user(
let password_entry = password::create(
db,
local_user.id,
password_hashing::hash_password(conf, password)?,
password_hashing::hash_password(conf, password).await?,
None,
)
.await?;
Expand Down
9 changes: 9 additions & 0 deletions src/identity/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,20 @@ impl From<IdentityDatabaseError> for IdentityProviderError {
}
}

/// Password hashing related errors.
#[derive(Error, Debug)]
pub enum IdentityProviderPasswordHashError {
/// Bcrypt error.
#[error(transparent)]
BCrypt {
#[from]
source: bcrypt::BcryptError,
},

/// Async task join error.
#[error(transparent)]
Join {
#[from]
source: tokio::task::JoinError,
},
}
57 changes: 48 additions & 9 deletions src/identity/password_hashing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// SPDX-License-Identifier: Apache-2.0

use std::cmp::max;
use tokio::task;
use tracing::warn;

use crate::config::{Config, PasswordHashingAlgo};
Expand All @@ -26,7 +27,8 @@ fn verify_length_and_trunc_password(password: &[u8], max_length: usize) -> &[u8]
password
}

pub fn hash_password<S: AsRef<[u8]>>(
/// Calculate password hash with the configuration defaults.
pub async fn hash_password<S: AsRef<[u8]>>(
conf: &Config,
password: S,
) -> Result<String, IdentityProviderPasswordHashError> {
Expand All @@ -35,14 +37,18 @@ pub fn hash_password<S: AsRef<[u8]>>(
let password_bytes = verify_length_and_trunc_password(
password.as_ref(),
max(conf.identity.max_password_length, 72),
);
)
.to_owned();
let rounds = conf.identity.password_hash_rounds.unwrap_or(12);
Ok(bcrypt::hash(password_bytes, rounds as u32)?)
let hash =
task::spawn_blocking(move || bcrypt::hash(password_bytes, rounds as u32)).await??;
Ok(hash)
}
}
}

pub fn verify_password<P: AsRef<[u8]>, H: AsRef<str>>(
/// Verify the password matches the hashed value.
pub async fn verify_password<P: AsRef<[u8]>, H: AsRef<str>>(
conf: &Config,
password: P,
hash: H,
Expand All @@ -52,15 +58,23 @@ pub fn verify_password<P: AsRef<[u8]>, H: AsRef<str>>(
let password_bytes = verify_length_and_trunc_password(
password.as_ref(),
max(conf.identity.max_password_length, 72),
);
Ok(bcrypt::verify(password_bytes, hash.as_ref())?)
)
.to_owned();
let password_hash = hash.as_ref().to_string();
// Do not block the main thread with a definitely long running call.
let verify =
task::spawn_blocking(move || bcrypt::verify(password_bytes, &password_hash))
.await??;
Ok(verify)
//Ok(bcrypt::verify(password_bytes, hash.as_ref())?)
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use rand::distr::{Alphanumeric, SampleString};

#[test]
fn test_verify_length_and_trunc_password() {
Expand All @@ -79,14 +93,39 @@ mod tests {
);
}

#[test]
fn test_hash_bcrypt() {
#[tokio::test]
async fn test_hash_bcrypt() {
let builder = config::Config::builder()
.set_override("auth.methods", "")
.unwrap()
.set_override("database.connection", "dummy")
.unwrap();
let conf: Config = Config::try_from(builder).expect("can build a valid config");
assert!(hash_password(&conf, "abcdefg").await.is_ok());
}

#[tokio::test]
async fn test_roundtrip_bcrypt() {
let builder = config::Config::builder()
.set_override("auth.methods", "")
.unwrap()
.set_override("database.connection", "dummy")
.unwrap();
let conf: Config = Config::try_from(builder).expect("can build a valid config");
let hashed = hash_password(&conf, "abcdefg").await.unwrap();
assert!(verify_password(&conf, "abcdefg", hashed).await.unwrap());
}

#[tokio::test]
async fn test_roundtrip_bcrypt_longer_than_72() {
let builder = config::Config::builder()
.set_override("auth.methods", "")
.unwrap()
.set_override("database.connection", "dummy")
.unwrap();
let conf: Config = Config::try_from(builder).expect("can build a valid config");
assert!(hash_password(&conf, "abcdefg").is_ok());
let pass = Alphanumeric.sample_string(&mut rand::rng(), 80);
let hashed = hash_password(&conf, pass.clone()).await.unwrap();
assert!(verify_password(&conf, pass, hashed).await.unwrap());
}
}
Loading