diff --git a/Cargo.lock b/Cargo.lock index c6d8594146..3eac623a5b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -15,7 +15,7 @@ version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ee49baf6cb617b853aa8d93bf420db2383fab46d314482ca2803b40d5fde979b" dependencies = [ - "winapi", + "winapi 0.3.8", ] [[package]] @@ -24,17 +24,6 @@ version = "1.0.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "85bb70cc08ec97ca5450e6eba421deeea5f172c0fc61f78b5357b2a8e8be195f" -[[package]] -name = "async-trait" -version = "0.1.31" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26c4f3195085c36ea8d24d32b2f828d23296a9370a28aa39d111f6f16bef9f3b" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "atty" version = "0.2.14" @@ -43,7 +32,7 @@ checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8" dependencies = [ "hermit-abi", "libc", - "winapi", + "winapi 0.3.8", ] [[package]] @@ -75,19 +64,34 @@ name = "casper-node" version = "0.1.0" dependencies = [ "anyhow", - "async-trait", + "displaydoc", "either", "enum-iterator", "futures", + "hex_fmt", + "maplit", + "openssl", + "rmp-serde", "serde", + "serde-big-array", "smallvec", "structopt", + "thiserror", "tokio", + "tokio-openssl", + "tokio-serde", + "tokio-util", "toml", "tracing", "tracing-subscriber", ] +[[package]] +name = "cc" +version = "1.0.53" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "404b1fe4f65288577753b17e3b36a04596ee784493ec249bf81c7f2d2acd751c" + [[package]] name = "cfg-if" version = "0.1.10" @@ -120,6 +124,28 @@ dependencies = [ "vec_map", ] +[[package]] +name = "derivative" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb582b60359da160a9477ee80f15c8d784c477e69c217ef2cdd4169c24ea380f" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "displaydoc" +version = "0.1.6" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6269d127174b18c665e683e23c2c55d3735fadbec4181c7c70b0450b764bfa5" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "either" version = "1.5.3" @@ -152,6 +178,37 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + +[[package]] +name = "fuchsia-zircon" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e9763c69ebaae630ba35f74888db465e49e259ba1bc0eda7d06f4a067615d82" +dependencies = [ + "bitflags", + "fuchsia-zircon-sys", +] + +[[package]] +name = "fuchsia-zircon-sys" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3dcaa9ae7725d12cdb85b3ad99a434db70b468c09ded17e012d86b5c1010f7a7" + [[package]] name = "futures" version = "0.3.5" @@ -265,12 +322,37 @@ dependencies = [ "libc", ] +[[package]] +name = "hex_fmt" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b07f60793ff0a4d9cef0f18e63b5357e06209987153a64648c972c1e5aff336f" + +[[package]] +name = "iovec" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2b3ea6ff95e175473f8ffe6a7eb7c00d054240321b84c57051175fe3c1e075e" +dependencies = [ + "libc", +] + [[package]] name = "itoa" version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" 
checksum = "b8b7a7c0c47db5545ed3fef7468ee7bb5b74691498139e4b3f6a20685dc6dd8e" +[[package]] +name = "kernel32-sys" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7507624b29483431c0ba2d82aece8ca6cdba9382bff4ddd0f7490560c056098d" +dependencies = [ + "winapi 0.2.8", + "winapi-build", +] + [[package]] name = "lazy_static" version = "1.4.0" @@ -292,6 +374,12 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "maplit" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3e2e65a1a2e43cfcb47a895c4c8b10d1f4a61097f9f254f183aee60cad9c651d" + [[package]] name = "matchers" version = "0.0.1" @@ -307,6 +395,48 @@ version = "2.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3728d817d99e5ac407411fa471ff9800a778d88a24685968b36824eaf4bee400" +[[package]] +name = "mio" +version = "0.6.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fce347092656428bc8eaf6201042cb551b8d67855af7374542a92a0fbfcac430" +dependencies = [ + "cfg-if", + "fuchsia-zircon", + "fuchsia-zircon-sys", + "iovec", + "kernel32-sys", + "libc", + "log", + "miow", + "net2", + "slab", + "winapi 0.2.8", +] + +[[package]] +name = "miow" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c1f2f3b1cf331de6896aabf6e9d55dca90356cc9960cca7eaaf408a355ae919" +dependencies = [ + "kernel32-sys", + "net2", + "winapi 0.2.8", + "ws2_32-sys", +] + +[[package]] +name = "net2" +version = "0.2.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ba7c918ac76704fb42afcbbb43891e72731f3dcca3bef2a19786297baf14af7" +dependencies = [ + "cfg-if", + "libc", + "winapi 0.3.8", +] + [[package]] name = "num-integer" version = "0.1.42" @@ -342,6 +472,33 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b631f7e854af39a1739f401cf34a8a013dfe09eac4fa4dba91e9768bd28168d" 
+[[package]] +name = "openssl" +version = "0.10.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cee6d85f4cb4c4f59a6a85d5b68a233d280c82e29e822913b9c8b129fbf20bdd" +dependencies = [ + "bitflags", + "cfg-if", + "foreign-types", + "lazy_static", + "libc", + "openssl-sys", +] + +[[package]] +name = "openssl-sys" +version = "0.9.56" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f02309a7f127000ed50594f0b50ecc69e7c654e16d41b4e8156d1b3df8e0b52e" +dependencies = [ + "autocfg", + "cc", + "libc", + "pkg-config", + "vcpkg", +] + [[package]] name = "pin-project" version = "0.4.16" @@ -374,6 +531,12 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "pkg-config" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05da548ad6865900e60eaba7f589cc0783590a92e940c26953ff81ddbab2d677" + [[package]] name = "proc-macro-error" version = "1.0.2" @@ -458,6 +621,27 @@ version = "0.6.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7fe5bd57d1d7414c6b5ed48563a2c855d995ff777729dcd91c369ec7fea395ae" +[[package]] +name = "rmp" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f10b46df14cf1ee1ac7baa4d2fbc2c52c0622a4b82fa8740e37bc452ac0184f" +dependencies = [ + "byteorder", + "num-traits", +] + +[[package]] +name = "rmp-serde" +version = "0.14.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c1ee98f14fe8b8e9c5ea13d25da7b2a1796169202c57a09d7288de90d56222b" +dependencies = [ + "byteorder", + "rmp", + "serde", +] + [[package]] name = "ryu" version = "1.0.4" @@ -473,6 +657,16 @@ dependencies = [ "serde_derive", ] +[[package]] +name = "serde-big-array" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"52309f7932ab258e58bcf73cc89037e307ffef3bcfb7ce7a246580c26f81dc55" +dependencies = [ + "serde", + "serde_derive", +] + [[package]] name = "serde_derive" version = "1.0.110" @@ -577,6 +771,26 @@ dependencies = [ "unicode-width", ] +[[package]] +name = "thiserror" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5976891d6950b4f68477850b5b9e5aa64d955961466f9e174363f573e54e8ca7" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab81dbd1cd69cd2ce22ecfbdd3bdb73334ba25350649408cc6c085f46d89573d" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "thread_local" version = "1.0.1" @@ -593,7 +807,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ca8a50ef2360fbd1eeb0ecd46795a87a19024eb4b53c5dc916ca1fd95fe62438" dependencies = [ "libc", - "winapi", + "winapi 0.3.8", ] [[package]] @@ -604,6 +818,10 @@ checksum = "d099fa27b9702bed751524694adbe393e18b36b204da91eb1cbbbbb4a5ee2d58" dependencies = [ "bytes", "fnv", + "futures-core", + "iovec", + "lazy_static", + "mio", "num_cpus", "pin-project-lite", "tokio-macros", @@ -620,6 +838,44 @@ dependencies = [ "syn", ] +[[package]] +name = "tokio-openssl" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c4b08c5f4208e699ede3df2520aca2e82401b2de33f45e96696a074480be594" +dependencies = [ + "openssl", + "tokio", +] + +[[package]] +name = "tokio-serde" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebdd897b01021779294eb09bb3b52b6e11b0747f9f7e333a84bef532b656de99" +dependencies = [ + "bytes", + "derivative", + "futures", + "pin-project", + "rmp-serde", + "serde", +] + +[[package]] +name = "tokio-util" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"be8242891f2b6cbef26a2d7e8605133c2c554cd35b3e4948ea892d6d68436499" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "log", + "pin-project-lite", + "tokio", +] + [[package]] name = "toml" version = "0.5.6" @@ -719,6 +975,12 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "826e7639553986605ec5979c7dd957c7895e93eabed50ab2ffa7f6128a75097c" +[[package]] +name = "vcpkg" +version = "0.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fc439f2794e98976c88a2a2dafce96b930fe8010b0a256b3c2199a773933168" + [[package]] name = "vec_map" version = "0.8.2" @@ -731,6 +993,12 @@ version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "078775d0255232fb988e6fccf26ddc9d1ac274299aaedcedce21c6f72cc533ce" +[[package]] +name = "winapi" +version = "0.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "167dc9d6949a9b857f3451275e911c3f44255842c1f7a76f33c55103a909087a" + [[package]] name = "winapi" version = "0.3.8" @@ -741,6 +1009,12 @@ dependencies = [ "winapi-x86_64-pc-windows-gnu", ] +[[package]] +name = "winapi-build" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d315eee3b34aca4797b2da6b13ed88266e6d612562a0c46390af8299fc699bc" + [[package]] name = "winapi-i686-pc-windows-gnu" version = "0.4.0" @@ -752,3 +1026,13 @@ name = "winapi-x86_64-pc-windows-gnu" version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + +[[package]] +name = "ws2_32-sys" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d59cefebd0c892fa2dd6de581e937301d8552cb44489cdff035c6187cb63fa5e" +dependencies = [ + "winapi 0.2.8", + "winapi-build", +] diff --git a/Cargo.toml b/Cargo.toml index 6c0e7ca4da..18c428eb72 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ 
-1,23 +1,31 @@ [package] name = "casper-node" version = "0.1.0" -authors = ["Marc Brinkmann ", - "Fraser Hutchison "] +authors = ["Marc Brinkmann ", "Fraser Hutchison "] edition = "2018" -description = "The CasperLabs blockchain node server" +description = "The CasperLabs blockchain node" publish = false # Prevent accidental `cargo publish` for now. license-file = "LICENSE" [dependencies] anyhow = "1.0.28" -async-trait = "0.1.31" +displaydoc = "0.1.6" either = "1.5.3" enum-iterator = "0.6.0" futures = "0.3.5" +hex_fmt = "0.3.0" +maplit = "1.0.2" +openssl = "0.10.29" +rmp-serde = "0.14.3" serde = { version = "1.0.110", features = ["derive"] } +serde-big-array = "0.3.0" smallvec = "1.4.0" structopt = "0.3.14" -tokio = { version = "0.2.20", features = ["macros", "rt-threaded", "sync"] } +thiserror = "1.0.18" +tokio = { version = "0.2.20", features = ["macros", "rt-threaded", "sync", "tcp"] } +tokio-openssl = "0.4.0" +tokio-serde = { version = "0.6.1", features = ["messagepack"] } +tokio-util = { version = "0.3.1", features = ["codec"] } toml = "0.5.6" tracing = "0.1.14" tracing-subscriber = "0.2.5" diff --git a/README.md b/README.md index 733de2ba2a..0ead1ebf80 100644 --- a/README.md +++ b/README.md @@ -1,25 +1,21 @@ -# CasperLabs node +# casper-node -The is the core application for the CasperLabs blockchain. - -## Building - -To compile this application, simply run `cargo build` on a recent stable Rust (`>= 1.43.1`) version. +This is the core application for the CasperLabs blockchain. 
## Running a validator node -Launching a validator node with the default configuration is done by simply launching the application: +To run a validator node with the default configuration: ``` -casper-node validator +cargo run --release -- validator ``` It is very likely that the configuration requires editing though, so typically one will want to generate a configuration file first, edit it and then launch: ``` -casper-node generate-config > mynode.toml +cargo run --release -- generate-config > mynode.toml # ... edit mynode.toml -casper-node validator -c mynode.toml +cargo run --release -- validator --config=mynode.toml ``` ## Development @@ -30,4 +26,4 @@ A good starting point is to build the documentation and read it in your browser: cargo doc --no-deps --open ``` -When generating a configuration file, it is usually helpful to set the log-level to `DEBUG` during development. \ No newline at end of file +When generating a configuration file, it is usually helpful to set the log-level to `DEBUG` during development. diff --git a/images/CasperLabs_Logo_Favicon_RGB_50px.png b/images/CasperLabs_Logo_Favicon_RGB_50px.png new file mode 100644 index 0000000000..593254f7bd Binary files /dev/null and b/images/CasperLabs_Logo_Favicon_RGB_50px.png differ diff --git a/images/CasperLabs_Logo_Symbol_RGB.png b/images/CasperLabs_Logo_Symbol_RGB.png new file mode 100644 index 0000000000..fcc47b453c Binary files /dev/null and b/images/CasperLabs_Logo_Symbol_RGB.png differ diff --git a/src/cli.rs b/src/cli.rs index aea99a76a7..855fcdd9c4 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -1,18 +1,32 @@ //! Command-line option parsing. //! -//! Most configuration is done through the configuration, which is the only required command-line -//! argument. However some configuration values can be overwritten for convenience's sake. -use std::{io, io::Write, path}; +//! Most configuration is done via config files (see [`config`](../config/index.html) for details). 
+ +use std::{io, io::Write, path::PathBuf}; + +use anyhow::bail; use structopt::StructOpt; +use tracing::Level; -use crate::{config, reactor}; +use crate::{ + config, + reactor::{self, validator::Reactor}, + tls, +}; // Note: The docstring on `Cli` is the help shown when calling the binary with `--help`. #[derive(Debug, StructOpt)] /// CasperLabs blockchain node. pub enum Cli { + /// Generate a self-signed node certificate. + GenerateCert { + /// Output path base of the certificate. The certificate will be stored as + /// `output.crt.pem`, while the key will be stored as `output.key.pem`. + output: PathBuf, + }, /// Generate a configuration file from defaults and dump it to stdout. GenerateConfig {}, + /// Run the validator node. /// /// Loads the configuration values from the given configuration file or uses defaults if not @@ -20,30 +34,54 @@ pub enum Cli { Validator { #[structopt(short, long, env)] /// Path to configuration file. - config: Option, + config: Option, + + /// Override log-level, forcing debug output. + #[structopt(short, long)] + debug: bool, }, } impl Cli { - /// Execute selected CLI command. + /// Executes selected CLI command. pub async fn run(self) -> anyhow::Result<()> { match self { + Cli::GenerateCert { output } => { + if output.file_name().is_none() { + bail!("not a valid output path"); + } + + let mut cert_path = output.clone(); + cert_path.set_extension("crt.pem"); + + let mut key_path = output; + key_path.set_extension("key.pem"); + + let (cert, key) = tls::generate_node_cert()?; + + tls::save_cert(&cert, cert_path)?; + tls::save_private_key(&key, key_path)?; + + Ok(()) + } Cli::GenerateConfig {} => { let cfg_str = config::to_string(&Default::default())?; io::stdout().write_all(cfg_str.as_bytes())?; Ok(()) } - Cli::Validator { config } => { + Cli::Validator { config, debug } => { // We load the specified config, if any, otherwise use defaults. - let cfg = config + let mut cfg = config .map(config::load_from_file) .transpose()? 
.unwrap_or_default(); - + if debug { + cfg.log.level = Level::DEBUG; + } cfg.log.setup_logging()?; - reactor::launch::(cfg).await + reactor::launch::(cfg).await } } } diff --git a/src/components.rs b/src/components.rs new file mode 100644 index 0000000000..37eb4946fa --- /dev/null +++ b/src/components.rs @@ -0,0 +1,5 @@ +//! Components +//! +//! Docs to be written, sorry. + +pub mod small_network; diff --git a/src/components/small_network.rs b/src/components/small_network.rs new file mode 100644 index 0000000000..0fc15729b1 --- /dev/null +++ b/src/components/small_network.rs @@ -0,0 +1,761 @@ +//! Fully connected overlay network +//! +//! The *small network* is an overlay network where each node participating is connected to every +//! other node on the network. The *small* portion of the name stems from the fact that this +//! approach is not scalable, as it requires at least $O(n)$ network connections and broadcast will +//! result in $O(n^2)$ messages. +//! +//! # Node IDs +//! +//! Each node has a self-generated node ID based on its self-signed TLS certificate. Whenever a +//! connection is made to another node, it verifies the "server"'s certificate to check that it +//! connected to the correct node and sends its own certificate during the TLS handshake, +//! establishing identity. +//! +//! # Messages and payloads +//! +//! The network itself is best-effort, during regular operation, no messages should be lost. A node +//! will attempt to reconnect when it loses a connection, however messages and broadcasts may be +//! lost during that time. +//! +//! # Connection +//! +//! Every node has an ID and a listening address. The objective of each node is to constantly +//! maintain an outgoing connection to each other node (and thus have an incoming connection from +//! these nodes as well). +//! +//! Any incoming connection is strictly read from, while any outgoing connection is strictly used +//! for sending messages. +//! +//! 
Nodes track the signed (timestamp, listening address, certificate) tuples called "endpoints" +//! internally and whenever they connect to a new node, they share this state with the other +//! node, as well as notifying them about any updates they receive. +//! +//! # Joining the network +//! +//! When a node connects to any other network node, it sends its current list of endpoints down the +//! new outgoing connection. This will cause the receiving node to initiate a connection attempt to +//! all nodes in the list and simultaneously tell all of its connected nodes about the new node, +//! repeating the process. + +use std::{ + cmp::Ordering, + collections::{HashMap, HashSet}, + fmt::{self, Debug, Display, Formatter}, + hash::Hash, + io, + net::{SocketAddr, TcpListener}, + sync::Arc, + time::{SystemTime, UNIX_EPOCH}, +}; + +use anyhow::{anyhow, bail, Context}; +use futures::{ + stream::{SplitSink, SplitStream}, + FutureExt, SinkExt, StreamExt, +}; +use maplit::hashmap; +use openssl::{pkey, x509}; +use pkey::{PKey, Private}; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use tokio::{ + net::TcpStream, + sync::mpsc::{self, UnboundedReceiver, UnboundedSender}, +}; +use tokio_openssl::SslStream; +use tokio_serde::{formats::SymmetricalMessagePack, SymmetricallyFramed}; +use tokio_util::codec::{Framed, LengthDelimitedCodec}; +use tracing::{debug, error, info, warn}; +use x509::X509; + +use crate::{ + config, + effect::{Effect, EffectExt, EffectResultExt}, + reactor::{EventQueueHandle, QueueKind, Reactor}, + tls::{self, KeyFingerprint, Signed, TlsCert}, + utils::{DisplayIter, Multiple}, +}; + +/// A node ID. +/// +/// The key fingerprint found on TLS certificates. +pub type NodeId = KeyFingerprint; + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub enum Message

{ + /// A pruned set of all endpoint announcements the server has received. + Snapshot(HashSet>), + /// Broadcast a new endpoint known to the sender. + BroadcastEndpoint(Signed), + /// A payload message. + Payload(P), +} + +#[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)] +pub struct Endpoint { + /// UNIX timestamp in nanoseconds resolution. + /// + /// Will overflow earliest November 2262. + timestamp_ns: u64, + /// Socket address the node is listening on. + addr: SocketAddr, + /// Certificate. + cert: TlsCert, +} + +#[derive(Debug)] +pub enum Event

{ + /// Connection to the root node succeeded. + RootConnected { cert: TlsCert, transport: Transport }, + /// Connection to the root node failed. + RootFailed { error: anyhow::Error }, + /// A new TCP connection has been established from an incoming connection. + IncomingNew { stream: TcpStream, addr: SocketAddr }, + /// The TLS handshake completed on the incoming connection. + IncomingHandshakeCompleted { + result: anyhow::Result<(NodeId, Transport)>, + addr: SocketAddr, + }, + /// Received network message. + IncomingMessage { node_id: NodeId, msg: Message

}, + /// Incoming connection closed. + IncomingClosed { + result: io::Result<()>, + addr: SocketAddr, + }, + + /// A new outgoing connection was successfully established. + OutgoingEstablished { + node_id: NodeId, + transport: Transport, + }, + /// An outgoing connection failed to connect or was terminated. + OutgoingFailed { + node_id: NodeId, + attempt_count: u32, + error: Option, + }, +} + +pub struct SmallNetwork +where + R: Reactor, +{ + /// Configuration. + cfg: config::SmallNetwork, + /// Server certificate. + cert: Arc, + /// Server private key. + private_key: Arc>, + /// Handle to event queue. + eq: EventQueueHandle>, + /// A list of known endpoints by node ID. + endpoints: HashMap, + /// Stored signed endpoints that can be sent to other nodes. + signed_endpoints: HashMap>, + /// Outgoing network connections' messages. + outgoing: HashMap>>, +} + +impl SmallNetwork +where + R: Reactor + 'static, + P: Serialize + DeserializeOwned + Clone + Debug + Send + 'static, +{ + #[allow(clippy::type_complexity)] + pub fn new( + eq: EventQueueHandle>, + cfg: config::SmallNetwork, + ) -> anyhow::Result<(SmallNetwork, Multiple>>)> + where + R: Reactor + 'static, + { + // First, we load or generate the TLS keys. + let (cert, private_key) = match (&cfg.cert, &cfg.private_key) { + // We're given a cert_file and a private_key file. Just load them, additional checking + // will be performed once we create the acceptor and connector. + (Some(cert_file), Some(private_key_file)) => ( + tls::load_cert(cert_file).context("could not load TLS certificate")?, + tls::load_private_key(private_key_file) + .context("could not load TLS private key")?, + ), + + // Neither was passed, so we auto-generate a pair. + (None, None) => tls::generate_node_cert()?, + + // If we get only one of the two, return an error. + _ => bail!("need either both or none of cert, private_key in network config"), + }; + + // We can now create a listener. 
+ let listener = create_listener(&cfg)?; + let addr = listener.local_addr()?; + + // Create the model. Initially we know our own endpoint address. + let our_endpoint = Endpoint { + timestamp_ns: SystemTime::now().duration_since(UNIX_EPOCH)?.as_nanos() as u64, + addr, + cert: tls::validate_cert(cert.clone())?, + }; + let our_fp = our_endpoint.cert.public_key_fingerprint(); + + // Run the server task. + info!(%our_endpoint, "starting server background task"); + let mut effects = server_task(eq, tokio::net::TcpListener::from_std(listener)?) + .boxed() + .ignore(); + let model = SmallNetwork { + cfg, + signed_endpoints: hashmap! { our_fp => Signed::new(&our_endpoint, &private_key)? }, + endpoints: hashmap! { our_fp => our_endpoint }, + cert: Arc::new(cert), + private_key: Arc::new(private_key), + eq, + outgoing: HashMap::new(), + }; + + // Connect to the root node (even if we are the root node, just loopback). + effects.extend(model.connect_to_root()); + + Ok((model, effects)) + } + + /// Attempts to connect to the root node. + fn connect_to_root(&self) -> Multiple>> { + connect_trusted( + self.cfg.root_addr, + self.cert.clone(), + self.private_key.clone(), + ) + .result( + move |(cert, transport)| Event::RootConnected { cert, transport }, + move |error| Event::RootFailed { error }, + ) + } + + #[allow(clippy::cognitive_complexity)] + pub fn handle_event(&mut self, ev: Event

) -> Multiple>> { + match ev { + Event::RootConnected { cert, transport } => { + // Create a pseudo-endpoint for the root node with the lowest priority (time 0) + let root_node_id = cert.public_key_fingerprint(); + let ep = Endpoint { + timestamp_ns: 0, + addr: self.cfg.root_addr, + cert, + }; + if self.endpoints.insert(root_node_id, ep).is_some() { + // This connection is the very first we will ever make, there should never be + // a root node registered, as we will never re-attempt this connection if it + // succeeded once. + error!("Encountered a second root node connection.") + } + + // We're now almost setup exactly as if the root node was any other node, proceed + // as normal. + self.setup_outgoing(root_node_id, transport) + } + Event::RootFailed { error } => { + warn!(%error, "connection to root failed"); + self.connect_to_root() + + // TODO: delay next attempt + } + Event::IncomingNew { stream, addr } => { + debug!(%addr, "Incoming connection, starting TLS handshake"); + + setup_tls(stream, self.cert.clone(), self.private_key.clone()) + .boxed() + .event(move |result| Event::IncomingHandshakeCompleted { result, addr }) + } + Event::IncomingHandshakeCompleted { result, addr } => { + match result { + Ok((fp, transport)) => { + // The sink is never used, as we only read data from incoming connections. + let (_sink, stream) = framed::

(transport).split(); + + message_reader(self.eq, stream, fp) + .event(move |result| Event::IncomingClosed { result, addr }) + } + Err(err) => { + warn!(%addr, %err, "TLS handshake failed"); + Multiple::new() + } + } + } + Event::IncomingMessage { node_id, msg } => self.handle_message(node_id, msg), + Event::IncomingClosed { result, addr } => { + match result { + Ok(()) => info!(%addr, "connection closed"), + Err(err) => warn!(%addr, %err, "connection dropped"), + } + Multiple::new() + } + Event::OutgoingEstablished { node_id, transport } => { + self.setup_outgoing(node_id, transport) + } + Event::OutgoingFailed { + node_id, + attempt_count, + error, + } => { + if let Some(err) = error { + warn!(%node_id, %err, "outgoing connection failed"); + } else { + warn!(%node_id, "outgoing connection closed"); + } + + if let Some(max) = self.cfg.max_outgoing_retries { + if attempt_count >= max { + // We're giving up connecting to the node. We will remove it completely + // (this only carries the danger of the stale addresses being sent to us by + // other nodes again). + self.endpoints.remove(&node_id); + self.signed_endpoints.remove(&node_id); + self.outgoing.remove(&node_id); + + warn!(%attempt_count, %node_id, "giving up on outgoing connection"); + } + + return Multiple::new(); + } + // TODO: Delay reconnection. + + if let Some(endpoint) = self.endpoints.get(&node_id) { + connect_outgoing( + endpoint.clone(), + self.cert.clone(), + self.private_key.clone(), + ) + .result( + move |transport| Event::OutgoingEstablished { node_id, transport }, + move |error| Event::OutgoingFailed { + node_id, + attempt_count: attempt_count + 1, + error: Some(error), + }, + ) + } else { + error!("endpoint disappeared"); + Multiple::new() + } + } + } + } + + /// Queues a message to be sent to all nodes. + fn broadcast_message(&self, msg: Message

) { + for node_id in self.outgoing.keys() { + self.send_message(*node_id, msg.clone()); + } + } + + /// Queues a message to be sent to a specific node. + fn send_message(&self, dest: NodeId, msg: Message

) { + // Try to send the message. + if let Some(sender) = self.outgoing.get(&dest) { + if let Err(msg) = sender.send(msg) { + // We lost the connection, but that fact has not reached us yet. + warn!(%dest, ?msg, "dropped outgoing message, lost connection"); + } + } else { + // We are not connected, so the reconnection is likely already in progress. + warn!(%dest, ?msg, "dropped outgoing message, no connection"); + } + } + + /// Updates the internal endpoint store from a given endpoint. + /// + /// Returns the node ID of the endpoint if it was new. + #[inline] + fn update_endpoint(&mut self, endpoint: &Endpoint) -> Option { + let fp = endpoint.cert.public_key_fingerprint(); + + if let Some(prev) = self.endpoints.get(&fp) { + if prev >= endpoint { + // Still up to date or stale, do nothing. + return None; + } + } + + self.endpoints.insert(fp, endpoint.clone()); + Some(fp) + } + + /// Updates internal endpoint store and if new, output a `BroadcastEndpoint` effect. + #[inline] + fn update_and_broadcast_if_new( + &mut self, + signed: Signed, + ) -> Multiple>> { + match signed.validate_self_signed(|endpoint| Ok(endpoint.cert.public_key())) { + Ok(endpoint) => { + // Endpoint is valid, check if it was new. + if let Some(node_id) = self.update_endpoint(&endpoint) { + debug!("new endpoint {}", endpoint); + // We learned of a new endpoint. We store it and note whether it is the first + // endpoint for the node. + self.signed_endpoints.insert(node_id, signed.clone()); + self.endpoints.insert(node_id, endpoint.clone()); + + let effect = if self.outgoing.remove(&node_id).is_none() { + info!(%node_id, ?endpoint, "new outgoing channel"); + // Initiate the connection process once we learn of a new node ID. 
+ connect_outgoing(endpoint, self.cert.clone(), self.private_key.clone()) + .result( + move |transport| Event::OutgoingEstablished { node_id, transport }, + move |error| Event::OutgoingFailed { + node_id, + attempt_count: 0, + error: Some(error), + }, + ) + } else { + // There was a previous endpoint, whose sender has now been dropped. This + // will cause the sender task to exit and trigger a reconnect. + + info!(%endpoint, "endpoint changed"); + Multiple::new() + }; + + self.broadcast_message(Message::BroadcastEndpoint(signed)); + + effect + } else { + debug!("known endpoint: {}", endpoint); + Multiple::new() + } + } + Err(err) => { + warn!(%err, ?signed, "received invalid endpoint"); + Multiple::new() + } + } + } + + /// Sets up an established outgoing connection. + fn setup_outgoing( + &mut self, + node_id: NodeId, + transport: Transport, + ) -> Multiple>> { + // This connection is send-only, we only use the sink. + let (sink, _stream) = framed::

(transport).split(); + + let (sender, receiver) = mpsc::unbounded_channel(); + if self.outgoing.insert(node_id, sender).is_some() { + // We assume that for a reconnect to have happened, the outgoing entry must have + // been either non-existent yet or cleaned up by the handler of the connection + // closing event. If this is not the case, an assumed invariant has been violated. + error!(%node_id, "did not expect leftover channel in outgoing map"); + } + + // We can now send a snapshot. + let snapshot = Message::Snapshot(self.signed_endpoints.values().cloned().collect()); + self.send_message(node_id, snapshot); + + message_sender(receiver, sink).event(move |result| Event::OutgoingFailed { + node_id, + attempt_count: 0, // reset to 0, since we have had a successful connection + error: result.err().map(Into::into), + }) + } + + /// Handles a received message. + // Internal function to keep indentation and nesting sane. + fn handle_message(&mut self, node_id: NodeId, msg: Message

) -> Multiple>> { + match msg { + Message::Snapshot(snapshot) => snapshot + .into_iter() + .map(|signed| self.update_and_broadcast_if_new(signed)) + .flatten() + .collect(), + Message::BroadcastEndpoint(signed) => self.update_and_broadcast_if_new(signed), + Message::Payload(payload) => { + // We received a message payload. + warn!( + %node_id, + ?payload, + "received message payload, but no implementation for what comes next" + ); + Multiple::new() + } + } + } +} + +/// Determines bind address for now. +/// +/// Will attempt to bind on the root address first if the `bind_interface` is the same as the +/// interface of `root_addr`. Otherwise uses an unused port on `bind_interface`. +fn create_listener(cfg: &config::SmallNetwork) -> io::Result { + if cfg.root_addr.ip() == cfg.bind_interface { + // Try to become the root node, if the root nodes interface is available. + match TcpListener::bind(cfg.root_addr) { + Ok(listener) => { + info!("we are the root node!"); + return Ok(listener); + } + Err(err) => { + warn!( + %err, + "could not bind to {}, will become a non-root node", cfg.root_addr + ); + } + }; + } + + // We did not become the root node, bind on random port. + Ok(TcpListener::bind((cfg.bind_interface, 0u16))?) +} + +/// Core accept loop for the networking server. +/// +/// Never terminates. +async fn server_task( + eq: EventQueueHandle>, + mut listener: tokio::net::TcpListener, +) { + loop { + // We handle accept errors here, since they can be caused by a temporary resource shortage + // or the remote side closing the connection while it is waiting in the queue. + match listener.accept().await { + Ok((stream, addr)) => { + // Move the incoming connection to the event queue for handling. + let ev = Event::IncomingNew { stream, addr }; + eq.schedule(ev, QueueKind::NetworkIncoming).await; + } + Err(err) => warn!(%err, "dropping incoming connection during accept"), + } + } +} + +/// Server-side TLS handshake. 
+/// +/// This function groups the TLS handshake into a convenient function, enabling the `?` operator. +async fn setup_tls( + stream: TcpStream, + cert: Arc, + private_key: Arc>, +) -> anyhow::Result<(NodeId, Transport)> { + let tls_stream = tokio_openssl::accept( + &tls::create_tls_acceptor(&cert.as_ref(), &private_key.as_ref())?, + stream, + ) + .await?; + + // We can now verify the certificate. + let peer_cert = tls_stream + .ssl() + .peer_certificate() + .ok_or_else(|| anyhow!("no peer certificate presented"))?; + + Ok(( + tls::validate_cert(peer_cert)?.public_key_fingerprint(), + tls_stream, + )) +} + +/// Network message reader. +/// +/// Schedules all received messages until the stream is closed or an error occurs. +async fn message_reader( + eq: EventQueueHandle>, + mut stream: SplitStream>, + node_id: NodeId, +) -> io::Result<()> +where + R: Reactor, + P: DeserializeOwned + Send, +{ + while let Some(msg_result) = stream.next().await { + match msg_result { + Ok(msg) => { + // We've received a message, push it to the reactor. + eq.schedule( + Event::IncomingMessage { node_id, msg }, + QueueKind::NetworkIncoming, + ) + .await; + } + Err(err) => { + warn!(%err, "receiving message failed, closing connection"); + return Err(err); + } + } + } + Ok(()) +} + +/// Network message sender. +/// +/// Reads from a channel and sends all messages, until the stream is closed or an error occurs. +async fn message_sender

( + mut queue: UnboundedReceiver>, + mut sink: SplitSink, Message

>, +) -> io::Result<()> +where + P: Serialize + Send, +{ + while let Some(payload) = queue.recv().await { + // We simply error-out if the sink fails, it means that our connection broke. + sink.send(payload).await?; + } + + Ok(()) +} + +/// Transport type alias for base encrypted connections. +type Transport = SslStream; + +/// A framed transport for `Message`s. +type FramedTransport

= SymmetricallyFramed< + Framed, + Message

, + SymmetricalMessagePack>, +>; + +/// Constructs a new framed transport on a stream. +fn framed

(stream: Transport) -> FramedTransport

{ + let length_delimited = Framed::new(stream, LengthDelimitedCodec::new()); + SymmetricallyFramed::new( + length_delimited, + SymmetricalMessagePack::>::default(), + ) +} + +/// Initiates a TLS connection to an endpoint. +async fn connect_outgoing( + endpoint: Endpoint, + cert: Arc, + private_key: Arc>, +) -> anyhow::Result { + let (server_cert, transport) = connect_trusted(endpoint.addr, cert, private_key).await?; + + let remote_id = server_cert.public_key_fingerprint(); + + if remote_id != endpoint.cert.public_key_fingerprint() { + bail!("remote node has wrong ID"); + } + + Ok(transport) +} + +/// Initiates a TLS connection to a remote address, regardless of what ID the remote node reports. +async fn connect_trusted( + addr: SocketAddr, + cert: Arc, + private_key: Arc>, +) -> anyhow::Result<(TlsCert, Transport)> { + let mut config = tls::create_tls_connector(&cert, &private_key) + .context("could not create TLS connector")? + .configure()?; + config.set_verify_hostname(false); + + let stream = tokio::net::TcpStream::connect(addr) + .await + .context("TCP connection failed")?; + + let tls_stream = tokio_openssl::connect(config, "this-will-not-be-checked.example.com", stream) + .await + .context("tls handshake failed")?; + + let server_cert = tls_stream + .ssl() + .peer_certificate() + .ok_or_else(|| anyhow!("no server certificate presented"))?; + Ok((tls::validate_cert(server_cert)?, tls_stream)) +} + +// Impose a total ordering on endpoints. Compare timestamps first, if the same, order by actual +// address. If both of these are the same, use the TLS certificate's fingerprint as a tie-breaker. 
+impl Ord for Endpoint { + fn cmp(&self, other: &Self) -> Ordering { + Ord::cmp(&self.timestamp_ns, &other.timestamp_ns) + .then_with(|| { + Ord::cmp( + &(self.addr.ip(), self.addr.port()), + &(other.addr.ip(), other.addr.port()), + ) + }) + .then_with(|| Ord::cmp(&self.cert.fingerprint(), &other.cert.fingerprint())) + } +} +impl PartialOrd for Endpoint { + #[inline] + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Display for Message

{ + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + Message::Snapshot(snapshot) => { + write!(f, "snapshot: {:10}", DisplayIter::new(snapshot.iter())) + } + Message::BroadcastEndpoint(endpoint) => write!(f, "broadcast endpoint: {}", endpoint), + Message::Payload(payload) => write!(f, "payload: {}", payload), + } + } +} + +impl Display for Event

{ + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + Event::RootConnected { cert, .. } => { + write!(f, "root connected @ {}", cert.public_key_fingerprint()) + } + Event::RootFailed { error } => write!(f, "root failed: {}", error), + Event::IncomingNew { addr, .. } => write!(f, "incoming connection from {}", addr), + Event::IncomingHandshakeCompleted { result, addr } => { + write!(f, "handshake from {}, is_err {}", addr, result.is_err()) + } + Event::IncomingMessage { node_id, msg } => write!(f, "msg from {}: {}", node_id, msg), + Event::IncomingClosed { addr, .. } => write!(f, "closed connection from {}", addr), + Event::OutgoingEstablished { node_id, .. } => { + write!(f, "established outgoing to {}", node_id) + } + Event::OutgoingFailed { + node_id, + attempt_count, + error, + } => write!( + f, + "failed outgoing {} [{}]: (is_err {})", + node_id, + attempt_count, + error.is_some() + ), + } + } +} + +impl Display for Endpoint { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!( + f, + "{}@{} [{}]", + self.cert.public_key_fingerprint(), + self.addr, + self.timestamp_ns + ) + } +} + +impl Debug for SmallNetwork +where + R: Reactor, + P: Debug, +{ + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("SmallNetwork") + .field("cert", &"") + .field("private_key", &"") + .field("eq", &"") + .field("endpoints", &self.endpoints) + .field("signed_endpoints", &self.signed_endpoints) + .field("outgoing", &self.outgoing) + .finish() + } +} diff --git a/src/config.rs b/src/config.rs index 45d452b71c..aabb81c848 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,28 +1,42 @@ -//! Configuration file management +//! Configuration file management. //! -//! Configuration for the node is loaded from TOML files, but all configuration values have -//! sensible defaults. +//! Configuration for the node is loaded from TOML files, but all configuration values have sensible +//! defaults. //! -//! 
The `cli` offers an option to generate a configuration from defaults for editing. +//! The [`Cli`](../cli/enum.Cli.html#variant.GenerateConfig) offers an option to generate a +//! configuration from defaults for editing. I.e. running the following will dump a default +//! configuration file to stdout: +//! ``` +//! cargo run --release -- generate-config +//! ``` //! //! # Adding a configuration section //! //! When adding a section to the configuration, ensure that //! -//! * it has an entry in the root configuration `Config`, +//! * it has an entry in the root configuration [`Config`](struct.Config.html), //! * `Default` is implemented (derived or manually) with sensible defaults, and //! * it is completely documented. +use std::{ + fs, io, + net::{IpAddr, Ipv4Addr, SocketAddr}, + path::{Path, PathBuf}, +}; + use anyhow::Context; use serde::{Deserialize, Serialize}; -use std::{fs, io, path}; -use tracing::debug; +use tracing::{debug, Level}; /// Root configuration. -#[derive(Debug, Default, Deserialize, Serialize)] +#[derive(Debug, Deserialize, Serialize)] pub struct Config { /// Log configuration. pub log: Log, + /// Network configuration for the validator-only network. + pub validator_net: SmallNetwork, + /// Network configuration for the public network. + pub public_net: SmallNetwork, } /// Log configuration. @@ -30,19 +44,64 @@ pub struct Config { pub struct Log { /// Log level. #[serde(with = "log_level")] - pub level: tracing::Level, + pub level: Level, } -impl Default for Log { +#[derive(Debug, Deserialize, Serialize)] +/// Small network configuration +pub struct SmallNetwork { + /// Interface to bind to. If it is the same as the in `root_addr`, attempt + /// become the root node for this particular small network. + pub bind_interface: IpAddr, + + /// Port to bind to when not the root node. Use 0 for a random port. + pub bind_port: u16, + + /// Address to connect to join the network. + pub root_addr: SocketAddr, + + /// Path to certificate file. 
+ pub cert: Option, + + /// Path to private key for certificate. + pub private_key: Option, + + /// Maximum number of retries when trying to connect to an outgoing node. Unlimited if `None`. + pub max_outgoing_retries: Option, +} + +impl SmallNetwork { + /// Creates a default instance for `SmallNetwork` with a constant port. + fn default_on_port(port: u16) -> Self { + SmallNetwork { + bind_interface: Ipv4Addr::new(127, 0, 0, 1).into(), + bind_port: 0, + root_addr: (Ipv4Addr::new(127, 0, 0, 1), port).into(), + cert: None, + private_key: None, + max_outgoing_retries: None, + } + } +} + +impl Default for Config { fn default() -> Self { - Log { - level: tracing::Level::INFO, + Config { + log: Default::default(), + validator_net: SmallNetwork::default_on_port(34553), + public_net: SmallNetwork::default_on_port(1485), } } } +impl Default for Log { + fn default() -> Self { + Log { level: Level::INFO } + } +} + impl Log { - /// Initialize logging system based on settings in configuration. + /// Initializes logging system based on settings in configuration. /// /// Will setup logging as described in this configuration for the whole application. This /// function should only be called once during the lifetime of the application. @@ -61,7 +120,7 @@ impl Log { } /// Loads a TOML-formatted configuration from a given file. 
-pub fn load_from_file>(config_path: P) -> anyhow::Result { +pub fn load_from_file>(config_path: P) -> anyhow::Result { let path_ref = config_path.as_ref(); Ok(toml::from_str( &fs::read_to_string(path_ref) @@ -70,29 +129,31 @@ pub fn load_from_file>(config_path: P) -> anyhow::Result anyhow::Result { toml::to_string_pretty(cfg).with_context(|| "Failed to serialize default configuration") } /// Serialization/deserialization mod log_level { - use serde::{self, Deserialize}; use std::str::FromStr; + + use serde::{self, de::Error, Deserialize, Deserializer, Serializer}; use tracing::Level; pub fn serialize(value: &Level, serializer: S) -> Result where - S: serde::Serializer, + S: Serializer, { serializer.serialize_str(value.to_string().as_str()) } pub fn deserialize<'de, D>(deserializer: D) -> Result where - D: serde::Deserializer<'de>, + D: Deserializer<'de>, { let s = String::deserialize(deserializer)?; - Level::from_str(s.as_str()).map_err(serde::de::Error::custom) + + Level::from_str(s.as_str()).map_err(Error::custom) } } diff --git a/src/effect.rs b/src/effect.rs index 230ab9432d..5098e78b40 100644 --- a/src/effect.rs +++ b/src/effect.rs @@ -1,32 +1,33 @@ //! Effects subsystem. //! -//! Effects describe things that the creator of the effect intends to happen, -//! producing a value upon completion. They are, in fact, futures. +//! Effects describe things that the creator of the effect intends to happen, producing a value upon +//! completion. They are, in fact, futures. //! //! A boxed, pinned future returning an event is called an effect and typed as an `Effect`, //! where `Ev` is the event's type. //! //! ## Using effects //! -//! To create an effect, an events factory is used that implements one or more of the factory -//! traits of this module. For example, given an events factory `eff`, we can create a +//! To create an effect, an events factory is used that implements one or more of the factory traits +//! of this module. 
For example, given an events factory `events_factory`, we can create a //! `set_timeout` future and turn it into an effect: //! //! ``` -//! # use std::time; +//! use std::time::Duration; //! use crate::effect::EffectExt; //! //! enum Event { -//! ThreeSecondsElapsed(time::Duration) +//! ThreeSecondsElapsed(Duration) //! } //! -//! eff.set_timeout(time::Duration::from_secs(3)) -//! .event(Event::ThreeSecondsElapsed) +//! events_factory +//! .set_timeout(Duration::from_secs(3)) +//! .event(Event::ThreeSecondsElapsed); //! ``` //! //! This example will produce an effect that, after three seconds, creates an //! `Event::ThreeSecondsElapsed`. Note that effects do nothing on their own, they need to be passed -//! to the `Reactor` (see `reactor` module) to be executed. +//! to a [`reactor`](../reactor/index.html) to be executed. //! //! ## Chaining futures and effects //! @@ -37,24 +38,32 @@ //! It is possible to create an effect from multiple effects being run in parallel using `.also`: //! //! ``` -//! # use std::time; +//! use std::time::Duration; //! use crate::effect::{EffectExt, EffectAlso}; //! //! enum Event { -//! ThreeSecondsElapsed(time::Duration), -//! FiveSecondsElapsed(time::Duration), +//! ThreeSecondsElapsed(Duration), +//! FiveSecondsElapsed(Duration), //! } //! //! // This effect produces a single event after five seconds: -//! eff.set_timeout(time::Duration::from_secs(3)) -//! .then(|_| eff.set_timeout(time::Duration::from_secs(2)) -//! .event(Event::FiveSecondsElapsed); +//! events_factory +//! .set_timeout(Duration::from_secs(3)) +//! .then(|_| { +//! events_factory +//! .set_timeout(Duration::from_secs(2)) +//! .event(Event::FiveSecondsElapsed) +//! }); //! //! // Here, two effects are run in parallel, resulting in two events: -//! eff.set_timeout(time::Duration::from_secs(3)) -//! .event(Event::ThreeSecondsElapsed) -//! .also(eff.set_timeout(time::Duration::from_secs(5)) -//! .event(Event::FiveSecondsElapsed)); +//! events_factory +//! 
.set_timeout(Duration::from_secs(3)) +//! .event(Event::ThreeSecondsElapsed) +//! .also( +//! events_factory +//! .set_timeout(Duration::from_secs(5)) +//! .event(Event::FiveSecondsElapsed), +//! ); //! ``` //! //! ## Arbitrary effects @@ -63,12 +72,12 @@ //! the effects explicitly listed in this module through traits to create them. Post-processing on //! effects to turn them into events should also be kept brief. -use crate::util::Multiple; -use futures::future::BoxFuture; -use futures::FutureExt; +use std::{future::Future, time::Duration}; + +use futures::{future::BoxFuture, FutureExt}; use smallvec::smallvec; -use std::future::Future; -use std::time; + +use crate::utils::Multiple; /// Effect type. /// @@ -79,9 +88,9 @@ pub type Effect = BoxFuture<'static, Multiple>; /// /// Used to convert futures into actual effects. pub trait EffectExt: Future + Send { - /// Finalize a future into an effect that returns an event. + /// Finalizes a future into an effect that returns an event. /// - /// The passed in function `f` is used to translate the resulting value from an effect into + /// The function `f` is used to translate the returned value from an effect into an event. fn event(self, f: F) -> Multiple> where F: FnOnce(Self::Output) -> U + 'static + Send, @@ -92,6 +101,21 @@ pub trait EffectExt: Future + Send { fn ignore(self) -> Multiple>; } +pub trait EffectResultExt { + type Value; + type Error; + + /// Finalizes a future returning a `Result` into two different effects. + /// + /// The function `f` is used to translate the returned value from an effect into an event, while + /// the function `g` does the same for a potential error. 
+ fn result(self, f_ok: F, f_err: G) -> Multiple> + where + F: FnOnce(Self::Value) -> U + 'static + Send, + G: FnOnce(Self::Error) -> U + 'static + Send, + U: 'static; +} + impl EffectExt for T where T: Future + Send + 'static + Sized, @@ -109,17 +133,37 @@ where } } +impl EffectResultExt for T +where + T: Future> + Send + 'static + Sized, +{ + type Value = V; + type Error = E; + + fn result(self, f_ok: F, f_err: G) -> Multiple> + where + F: FnOnce(V) -> U + 'static + Send, + G: FnOnce(E) -> U + 'static + Send, + U: 'static, + { + smallvec![self + .map(|result| result.map_or_else(f_err, f_ok)) + .map(|item| smallvec![item]) + .boxed()] + } +} + /// Core effects. pub trait Core { - /// Do not do anything. + /// Immediately completes without doing anything. /// - /// Immediately completes, can be used to trigger an event. + /// Can be used to trigger an event. fn immediately(self) -> BoxFuture<'static, ()>; - /// Set a timeout. + /// Sets a timeout. /// /// Once the timeout fires, it will return the actual elapsed time since the execution (not /// creation!) of this effect. Event loops typically execute effects right after a called event /// handling function completes. - fn set_timeout(self, timeout: time::Duration) -> BoxFuture<'static, time::Duration>; + fn set_timeout(self, timeout: Duration) -> BoxFuture<'static, Duration>; } diff --git a/src/main.rs b/src/main.rs index 86864a5555..344231339e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,15 +9,30 @@ //! its core event loop is found inside the [reactor](reactor/index.html). To get a tour of the //! sourcecode, be sure to run `cargo doc --open`. 
+#![doc( + html_favicon_url = "https://raw.githubusercontent.com/CasperLabs/casper-node/master/images/CasperLabs_Logo_Favicon_RGB_50px.png", + html_logo_url = "https://raw.githubusercontent.com/CasperLabs/casper-node/master/images/CasperLabs_Logo_Symbol_RGB.png", + test(attr(forbid(warnings))) +)] +#![warn( + missing_docs, + trivial_casts, + trivial_numeric_casts, + // unreachable_pub, + unused_qualifications +)] + mod cli; +mod components; mod config; mod effect; mod reactor; -mod util; +mod tls; +mod utils; use structopt::StructOpt; -/// Parse [command-line arguments](cli/index.html) and run application. +/// Parses [command-line arguments](cli/index.html) and run application. #[tokio::main] pub async fn main() -> anyhow::Result<()> { // Parse CLI args and run selected subcommand. diff --git a/src/reactor.rs b/src/reactor.rs index 4628f2051d..b0b18f0b20 100644 --- a/src/reactor.rs +++ b/src/reactor.rs @@ -2,104 +2,104 @@ //! //! Any long running instance of the node application uses an event-dispatch pattern: Events are //! generated and stored on an event queue, then processed one-by-one. This process happens inside -//! the *reactor*, which also exclusively holds the state of the application: +//! the reactor*, which also exclusively holds the state of the application besides pending events: //! -//! 1. The reactor pops an event off of the queue. +//! 1. The reactor pops an event off the event queue (called a [`Scheduler`](type.Scheduler.html)). //! 2. The event is dispatched by the reactor. Since the reactor holds mutable state, it can grant //! any component that processes an event mutable, exclusive access to its state. //! 3. Once the (synchronous) event processing has completed, the component returns an effect. -//! 4. The reactor spawns a task that executes these effects and eventually puts an event onto the -//! event queue. +//! 4. The reactor spawns a task that executes these effects and eventually schedules another event. //! 5. meanwhile go to 1. //! 
//! # Reactors //! -//! There no single reactor, but a reactor for each application type, since it defines which -//! components are used and how they are wired up. The reactor defines the state by being a `struct` -//! of components, their initialization through the `Reactor::new` and a function to `dispatch` -//! events to components. +//! There is no single reactor, but rather a reactor for each application type, since it defines +//! which components are used and how they are wired up. The reactor defines the state by being a +//! `struct` of components, their initialization through the +//! [`Reactor::new()`](trait.Reactor.html#tymethod.new) and a method +//! [`Reactor::dispatch_event()`](trait.Reactor.html#tymethod.dispatch_event) to dispatch events to +//! components. //! -//! With all these set up, a reactor can be `launch`ed, causing it to run indefinitely, processing -//! events. +//! With all these set up, a reactor can be [`launch`](fn.launch.html)ed, causing it to run +//! indefinitely, processing events. pub mod non_validator; mod queue_kind; pub mod validator; -use crate::util::Multiple; -use crate::{config, effect, util}; -use async_trait::async_trait; use std::{fmt, mem}; -use tracing::{debug, info, warn}; -pub use queue_kind::Queue; +use futures::FutureExt; +use tracing::{debug, info, trace, warn}; -/// Event queue handle +use crate::{ + config, + effect::Effect, + utils::{self, Multiple, WeightedRoundRobin}, +}; +pub use queue_kind::QueueKind; + +/// Event scheduler +/// +/// The scheduler is a combination of multiple event queues that are polled in a specific order. It +/// is the central hook for any part of the program that schedules events directly. +/// +/// Components rarely use this, but use a bound `EventQueueHandle` instead. +pub type Scheduler = WeightedRoundRobin; + +/// Bound event queue handle /// /// The event queue handle is how almost all parts of the application interact with the reactor /// outside of the normal event loop. 
It gives different parts a chance to schedule messages that /// stem from things like external IO. /// -/// It is also possible to schedule new events by directly processing effects. This allows re-use of -/// the existing code for handling particular effects, as adding events directly should be a matter -/// of last resort. +/// Every event queue handle allows scheduling events of type `Ev` onto a reactor `R`. For this it +/// carries with it a reference to a wrapper function that maps an `Ev` to a `Reactor::Event`. #[derive(Debug)] -pub struct EventQueueHandle(&'static util::round_robin::WeightedRoundRobin); +pub struct EventQueueHandle +where + R: Reactor, +{ + /// The scheduler events will be scheduled on. + scheduler: &'static Scheduler<::Event>, + /// A wrapper function translating from component event (input of `W`) to reactor event `Ev`. + wrapper: fn(Ev) -> R::Event, +} -// Copy and Clone need to be implemented manually, since `Ev` prevents derivation. -impl Copy for EventQueueHandle {} -impl Clone for EventQueueHandle { +// Implement `Clone` and `Copy` manually, as `derive` will make it depend on `R` and `Ev` otherwise. +impl Clone for EventQueueHandle +where + R: Reactor, +{ fn clone(&self) -> Self { - EventQueueHandle(self.0) + EventQueueHandle { + scheduler: self.scheduler, + wrapper: self.wrapper, + } } } +impl Copy for EventQueueHandle where R: Reactor {} -impl EventQueueHandle +impl EventQueueHandle where - Ev: Send + 'static, + R: Reactor, { - /// Create a new event queue handle. - fn new(round_robin: &'static util::round_robin::WeightedRoundRobin) -> Self { - EventQueueHandle(round_robin) - } - - /// Return the next event in the queue - /// - /// Awaits until there is an event, then returns it. - #[inline] - async fn next_event(self) -> (Ev, Queue) { - self.0.pop().await - } - - /// Process an effect. - /// - /// Spawns tasks that will process the given effects. 
- #[inline] - pub fn process_effects(self, effects: Multiple>) { - let eq = self; - // TODO: Properly carry around priorities. - let queue = Default::default(); - - for effect in effects { - tokio::spawn(async move { - for event in effect.await { - eq.schedule(event, queue).await; - } - }); - } + /// Creates a new event queue handle with an associated wrapper function. + fn bind(scheduler: &'static Scheduler, wrapper: fn(Ev) -> R::Event) -> Self { + EventQueueHandle { scheduler, wrapper } } - /// Schedule an event in the given queue. + /// Schedule an event on a specific queue. #[inline] - pub async fn schedule(self, event: Ev, queue_kind: Queue) { - self.0.push(event, queue_kind).await + pub async fn schedule(self, event: Ev, queue_kind: QueueKind) { + self.scheduler.push((self.wrapper)(event), queue_kind).await } } /// Reactor core. /// -/// Any reactor implements should implement this trait and be launched by the `launch` function. -#[async_trait] +/// Any reactor should implement this trait and be launched by the [`launch`](fn.launch.html) +/// function. pub trait Reactor: Sized { // Note: We've gone for the `Sized` bound here, since we return an instance in `new`. As an // alternative, `new` could return a boxed instance instead, removing this requirement. @@ -107,30 +107,30 @@ pub trait Reactor: Sized { /// Event type associated with reactor. /// /// Defines what kind of event the reactor processes. - type Event: Send + fmt::Debug + 'static; + type Event: Send + fmt::Debug + fmt::Display + 'static; - /// Dispatch an event on the reactor. + /// Dispatches an event on the reactor. /// /// This function is typically only called by the reactor itself to dispatch an event. It is /// safe to call regardless, but will cause the event to skip the queue and things like /// accounting. - fn dispatch_event(&mut self, event: Self::Event) -> Multiple>; + fn dispatch_event(&mut self, event: Self::Event) -> Multiple>; - /// Create a new instance of the reactor. 
+ /// Creates a new instance of the reactor. /// /// This method creates the full state, which consists of all components, and returns a reactor /// instances along with the effects the components generated upon instantiation. /// /// If any instantiation fails, an error is returned. fn new( - cfg: &config::Config, - eq: EventQueueHandle, - ) -> anyhow::Result<(Self, Multiple>)>; + cfg: config::Config, + scheduler: &'static Scheduler, + ) -> anyhow::Result<(Self, Multiple>)>; } -/// Run a reactor. +/// Runs a reactor. /// -/// Start the reactor and associated background tasks, then enter main the event processing loop. +/// Starts the reactor and associated background tasks, then enters main the event processing loop. /// /// `launch` will leak memory on start for global structures each time it is called. /// @@ -141,30 +141,81 @@ pub async fn launch(cfg: config::Config) -> anyhow::Result<()> { // Check if the event is of a reasonable size. This only emits a runtime warning at startup // right now, since storage size of events is not an issue per se, but copying might be // expensive if events get too large. - if event_size > 4 * mem::size_of::() { + if event_size > 16 * mem::size_of::() { warn!( "event size is {} bytes, consider reducing it or boxing", event_size ); } - let scheduler = util::round_robin::WeightedRoundRobin::::new(Queue::weights()); + let scheduler = Scheduler::::new(QueueKind::weights()); // Create a new event queue for this reactor run. - let eq = EventQueueHandle::new(util::leak(scheduler)); + let scheduler = utils::leak(scheduler); - let (mut reactor, initial_effects) = R::new(&cfg, eq)?; + let (mut reactor, initial_effects) = R::new(cfg, scheduler)?; // Run all effects from component instantiation. 
- eq.process_effects(initial_effects); + process_effects(scheduler, initial_effects).await; info!("entering reactor main loop"); loop { - let (event, q) = eq.next_event().await; - debug!(?event, ?q, "event"); + let (event, q) = scheduler.pop().await; + + // We log events twice, once in display and once in debug mode. + debug!(%event, ?q, "event"); + trace!(?event, ?q, "event"); // Dispatch the event, then execute the resulting effect. let effects = reactor.dispatch_event(event); - eq.process_effects(effects); + process_effects(scheduler, effects).await; + } +} + +/// Spawns tasks that will process the given effects. +#[inline] +async fn process_effects(scheduler: &'static Scheduler, effects: Multiple>) +where + Ev: Send + 'static, +{ + // TODO: Properly carry around priorities. + let queue_kind = QueueKind::default(); + + for effect in effects { + tokio::spawn(async move { + for event in effect.await { + scheduler.push(event, queue_kind).await + } + }); } } + +/// Converts a single effect into another by wrapping it. +#[inline] +pub fn wrap_effect(wrap: F, effect: Effect) -> Effect +where + F: Fn(Ev) -> REv + Send + 'static, + Ev: Send + 'static, + REv: Send + 'static, +{ + // TODO: The double-boxing here is very unfortunate =(. + (async move { + let events: Multiple = effect.await; + events.into_iter().map(wrap).collect() + }) + .boxed() +} + +/// Converts multiple effects into another by wrapping. +#[inline] +pub fn wrap_effects(wrap: F, effects: Multiple>) -> Multiple> +where + F: Fn(Ev) -> REv + Send + 'static + Clone, + Ev: Send + 'static, + REv: Send + 'static, +{ + effects + .into_iter() + .map(move |effect| wrap_effect(wrap.clone(), effect)) + .collect() +} diff --git a/src/reactor/queue_kind.rs b/src/reactor/queue_kind.rs index c05f5820b9..63c0d451ac 100644 --- a/src/reactor/queue_kind.rs +++ b/src/reactor/queue_kind.rs @@ -1,17 +1,18 @@ -//! Queue kinds +//! Queue kinds. //! //! 
The reactor's event queue uses different queues to group events by priority and polls them in a //! round-robin manner. This way, events are only competing for time within one queue, non-congested //! queues can always assume to be speedily processed. +use std::num::NonZeroUsize; + use enum_iterator::IntoEnumIterator; -use std::num; /// Scheduling priority. /// /// Priorities are ordered from lowest to highest. #[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, IntoEnumIterator)] -pub enum Queue { +pub enum QueueKind { /// Network events that were initiated outside of this node. /// /// Their load may vary and grouping them together in one queue aides DoS protection. @@ -24,34 +25,36 @@ pub enum Queue { Regular, /// Reporting events on the local node. /// - /// Metric events take precendence over most other events since missing a request for metrics + /// Metric events take precedence over most other events since missing a request for metrics /// might cause the requester to assume that the node is down and forcefully restart it. Metrics, } -impl Default for Queue { +impl Default for QueueKind { fn default() -> Self { - Queue::Regular + QueueKind::Regular } } -impl Queue { - /// Return the weight of a specific queue. +impl QueueKind { + /// Returns the weight of a specific queue. /// /// The weight determines how many events are at most processed from a specific queue during /// each event processing round. - fn weight(self) -> num::NonZeroUsize { - num::NonZeroUsize::new(match self { - Queue::NetworkIncoming => 4, - Queue::Network => 4, - Queue::Regular => 8, - Queue::Metrics => 16, + fn weight(self) -> NonZeroUsize { + NonZeroUsize::new(match self { + QueueKind::NetworkIncoming => 4, + QueueKind::Network => 4, + QueueKind::Regular => 8, + QueueKind::Metrics => 16, }) .expect("weight must be positive") } /// Return weights of all possible `Queue`s. 
- pub(super) fn weights() -> Vec<(Self, num::NonZeroUsize)> { - Queue::into_enum_iter().map(|q| (q, q.weight())).collect() + pub(super) fn weights() -> Vec<(Self, NonZeroUsize)> { + QueueKind::into_enum_iter() + .map(|q| (q, q.weight())) + .collect() } } diff --git a/src/reactor/validator.rs b/src/reactor/validator.rs index 05ecadbf2f..9a99f09cd5 100644 --- a/src/reactor/validator.rs +++ b/src/reactor/validator.rs @@ -1,31 +1,69 @@ //! Reactor for validator nodes. //! -//! Validator nodes join the validator only network upon startup. -use crate::util::Multiple; -use crate::{config, effect, reactor}; +//! Validator nodes join the validator-only network upon startup. + +use std::fmt::{self, Display, Formatter}; + +use serde::{Deserialize, Serialize}; + +use crate::{ + components::small_network::{self, SmallNetwork}, + config::Config, + effect::Effect, + reactor::{self, EventQueueHandle, Scheduler}, + utils::Multiple, +}; /// Top-level event for the reactor. #[derive(Debug)] #[must_use] -pub enum Event {} +pub enum Event { + Network(small_network::Event), +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub enum Message {} /// Validator node reactor. -pub struct Reactor; +pub struct Reactor { + net: SmallNetwork, +} impl reactor::Reactor for Reactor { type Event = Event; fn new( - _cfg: &config::Config, - _eq: reactor::EventQueueHandle, - ) -> anyhow::Result<(Self, Multiple>)> { - // TODO: Instantiate components here. 
- let mut _effects = Multiple::new(); + cfg: Config, + scheduler: &'static Scheduler, + ) -> anyhow::Result<(Self, Multiple>)> { + let (net, net_effects) = SmallNetwork::new( + EventQueueHandle::bind(scheduler, Event::Network), + cfg.validator_net, + )?; - Ok((Reactor, _effects)) + Ok(( + Reactor { net }, + reactor::wrap_effects(Event::Network, net_effects), + )) } - fn dispatch_event(&mut self, _event: Event) -> Multiple> { - todo!() + fn dispatch_event(&mut self, event: Event) -> Multiple> { + match event { + Event::Network(ev) => reactor::wrap_effects(Event::Network, self.net.handle_event(ev)), + } + } +} + +impl Display for Event { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + Event::Network(ev) => write!(f, "network: {}", ev), + } + } +} + +impl Display for Message { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "TODO: MessagePayload") } } diff --git a/src/tls.rs b/src/tls.rs new file mode 100644 index 0000000000..cde7c6cb1b --- /dev/null +++ b/src/tls.rs @@ -0,0 +1,840 @@ +//! Transport layer security and signing based on OpenSSL. +//! +//! This module wraps some of the lower-level TLS constructs to provide a reasonably safe-to-use API +//! surface for the rest of the application. It also fixes the security parameters of the TLS level +//! in a central place. +//! +//! Features include +//! +//! * a fixed set of chosen encryption parameters +//! ([`SIGNATURE_ALGORITHM`](constant.SIGNATURE_ALGORITHM.html), +//! [`SIGNATURE_CURVE`](constant.SIGNATURE_CURVE.html), +//! [`SIGNATURE_DIGEST`](constant.SIGNATURE_DIGEST.html)), +//! * construction of TLS acceptors for listening TCP sockets +//! ([`create_tls_acceptor`](fn.create_tls_acceptor.html)), +//! * construction of TLS connectors for outgoing TCP connections +//! ([`create_tls_connector`](fn.create_tls_connector.html)), +//! * creation and validation of self-signed certificates +//! ([`generate_node_cert`](fn.generate_node_cert.html)), +//! 
* signing and verification of arbitrary values using keys from certificates +//! ([`Signature`](struct.Signature.html), [`Signed`](struct.Signed.html)), and +//! * `serde` support for certificates ([`x509_serde`](x509_serde/index.html)) + +use std::{ + cmp::Ordering, + convert::TryInto, + fmt::{self, Debug, Display, Formatter}, + fs, + hash::Hash, + marker::PhantomData, + path::Path, + str, + time::{SystemTime, UNIX_EPOCH}, +}; + +use anyhow::{anyhow, Context}; +use displaydoc::Display; +use hex_fmt::HexFmt; +use nid::Nid; +use openssl::{ + asn1::{Asn1Integer, Asn1IntegerRef, Asn1Time}, + bn::{BigNum, BigNumContext}, + ec, + error::ErrorStack, + hash::{DigestBytes, MessageDigest}, + nid, + pkey::{PKey, PKeyRef, Private, Public}, + sha, + sign::{Signer, Verifier}, + ssl::{SslAcceptor, SslConnector, SslContextBuilder, SslMethod, SslVerifyMode, SslVersion}, + x509::{X509Builder, X509Name, X509NameBuilder, X509NameRef, X509Ref, X509}, +}; +use serde::{de::DeserializeOwned, Deserialize, Deserializer, Serialize, Serializer}; +use serde_big_array::big_array; +use thiserror::Error; + +big_array! { BigArray; } + +/// The chosen signature algorithm (**ECDSA with SHA512**). +const SIGNATURE_ALGORITHM: Nid = Nid::ECDSA_WITH_SHA512; + +/// The underlying elliptic curve (**P-521**). +const SIGNATURE_CURVE: Nid = Nid::SECP521R1; + +/// The chosen signature algorithm (**SHA512**). +const SIGNATURE_DIGEST: Nid = Nid::SHA512; + +/// OpenSSL result type alias. +/// +/// Many functions rely solely on `openssl` functions and return this kind of result. +pub type SslResult = Result; + +/// SHA512 hash. +#[derive(Copy, Clone, Deserialize, Serialize)] +pub struct Sha512(#[serde(with = "BigArray")] [u8; Sha512::SIZE]); + +/// Certificate fingerprint. +#[derive(Copy, Clone, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd, Serialize)] +pub struct CertFingerprint(Sha512); + +/// Public key fingerprint. 
+#[derive(Copy, Clone, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd, Serialize)] +pub struct KeyFingerprint(Sha512); + +/// Cryptographic signature. +#[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)] +pub struct Signature(Vec); + +/// TLS certificate. +/// +/// Thin wrapper around `X509` enabling things like Serde serialization and fingerprint caching. +#[derive(Clone)] +pub struct TlsCert { + /// The wrapped x509 certificate. + x509: X509, + + /// Cached certificate fingerprint. + cert_fingerprint: CertFingerprint, + + /// Cached public key fingerprint. + key_fingerprint: KeyFingerprint, +} + +// Serialization and deserialization happens only via x509, which is checked upon deserialization. +impl<'de> Deserialize<'de> for TlsCert { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + validate_cert(x509_serde::deserialize(deserializer)?).map_err(serde::de::Error::custom) + } +} + +impl Serialize for TlsCert { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + x509_serde::serialize(&self.x509, serializer) + } +} + +/// A signed value. +/// +/// Combines a value `V` with a `Signature` and a signature scheme. The signature scheme involves +/// serializing the value to bytes and signing the result. +#[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)] +pub struct Signed { + data: Vec, + signature: Signature, + _phantom: PhantomData, +} + +impl Signed +where + V: Serialize, +{ + /// Creates a new signed value. + /// + /// Serializes the value to a buffer and signs the buffer. + pub fn new(value: &V, signing_key: &PKeyRef) -> anyhow::Result { + let data = rmp_serde::to_vec(value)?; + let signature = Signature::create(signing_key, &data)?; + + Ok(Signed { + data, + signature, + _phantom: PhantomData, + }) + } +} + +impl Signed +where + V: DeserializeOwned, +{ + /// Validates signature and restore value. 
+ #[allow(dead_code)] + pub fn validate(&self, public_key: &PKeyRef) -> anyhow::Result { + if self.signature.verify(public_key, &self.data)? { + Ok(rmp_serde::from_read(self.data.as_slice())?) + } else { + Err(anyhow!("invalid signature")) + } + } + + /// Validates a self-signed value. + /// + /// Allows for extraction of a public key prior to validating a value. + #[inline] + pub fn validate_self_signed(&self, extract: F) -> anyhow::Result + where + F: FnOnce(&V) -> anyhow::Result>, + { + let unverified = rmp_serde::from_read(self.data.as_slice())?; + { + let public_key = + extract(&unverified).context("could not extract public key from self-signed")?; + if self.signature.verify(&public_key, &self.data)? { + Ok(unverified) + } else { + Err(anyhow!("invalid signature")) + } + } + } +} + +impl Sha512 { + /// Size of digest in bytes. + pub const SIZE: usize = 64; + + /// OpenSSL NID. + const NID: Nid = Nid::SHA512; + + /// Create a new Sha512 by hashing a slice. + pub fn new>(data: B) -> Self { + let mut openssl_sha = sha::Sha512::new(); + openssl_sha.update(data.as_ref()); + Sha512(openssl_sha.finish()) + } + + /// Returns bytestring of the hash, with length `Self::SIZE`. + pub fn bytes(&self) -> &[u8] { + let bs = &self.0[..]; + + debug_assert_eq!(bs.len(), Self::SIZE); + bs + } + + /// Converts an OpenSSL digest into an `Sha512`. + fn from_openssl_digest(digest: &DigestBytes) -> Self { + let digest_bytes = digest.as_ref(); + + debug_assert_eq!( + digest_bytes.len(), + Self::SIZE, + "digest is not the right size - check constants in `tls.rs`" + ); + + let mut buf = [0; Self::SIZE]; + buf.copy_from_slice(&digest_bytes[0..Self::SIZE]); + + Sha512(buf) + } + + /// Returns a new OpenSSL `MessageDigest` set to SHA-512. + fn create_message_digest() -> MessageDigest { + // This can only fail if we specify a `Nid` that does not exist, which cannot happen unless + // there is something wrong with `Self::NID`. 
+ MessageDigest::from_nid(Self::NID).expect("Sha512::NID is invalid") + } +} + +impl Signature { + /// Signs a binary blob with the blessed ciphers and TLS parameters. + pub fn create(private_key: &PKeyRef, data: &[u8]) -> SslResult { + // TODO: This needs verification to ensure we're not doing stupid/textbook RSA-ish. + + // Sha512 is hardcoded, so check we're creating the correct signature. + assert_eq!(Sha512::NID, SIGNATURE_DIGEST); + + let mut signer = Signer::new(Sha512::create_message_digest(), private_key)?; + + // The API of OpenSSL is a bit weird here; there is no constant size for the buffer required + // to create the signatures. Additionally, we need to truncate it to the returned size. + let sig_len = signer.len()?; + let mut sig_buf = vec![0; sig_len]; + let bytes_written = signer.sign_oneshot(&mut sig_buf, data)?; + sig_buf.truncate(bytes_written); + + Ok(Signature(sig_buf)) + } + + /// Verifies that signature matches on a binary blob. + pub fn verify(self: &Signature, public_key: &PKeyRef, data: &[u8]) -> SslResult { + assert_eq!(Sha512::NID, SIGNATURE_DIGEST); + + let mut verifier = Verifier::new(Sha512::create_message_digest(), public_key)?; + + verifier.verify_oneshot(&self.0, data) + } +} + +impl TlsCert { + /// Returns the certificate's fingerprint. + /// + /// In contrast to the `public_key_fingerprint`, this fingerprint also contains the certificate + /// information. + pub fn fingerprint(&self) -> CertFingerprint { + self.cert_fingerprint + } + + /// Extracts the public key from the certificate. + pub fn public_key(&self) -> PKey { + // This can never fail, we validate the certificate on construction and deserialization. + self.x509 + .public_key() + .expect("public key extraction failed, how did we end up with an invalid cert?") + } + + /// Returns the public key fingerprint. + pub fn public_key_fingerprint(&self) -> KeyFingerprint { + self.key_fingerprint + } + + #[allow(dead_code)] + /// Returns OpenSSL X509 certificate. 
+ fn x509(&self) -> &X509 { + &self.x509 + } +} + +impl Debug for TlsCert { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "TlsCert({:?})", self.fingerprint()) + } +} + +impl Hash for TlsCert { + fn hash(&self, state: &mut H) { + self.fingerprint().hash(state); + } +} + +impl PartialEq for TlsCert { + fn eq(&self, other: &Self) -> bool { + self.fingerprint() == other.fingerprint() + } +} + +impl Eq for TlsCert {} + +/// Generates a self-signed (key, certificate) pair suitable for TLS and signing. +/// +/// The common name of the certificate will be "casper-node". +pub fn generate_node_cert() -> SslResult<(X509, PKey)> { + let private_key = generate_private_key()?; + let cert = generate_cert(&private_key, "casper-node")?; + + Ok((cert, private_key)) +} + +/// Creates a TLS acceptor for a server. +/// +/// The acceptor will restrict TLS parameters to secure one defined in this crate that are +/// compatible with connectors built with `create_tls_connector`. +/// +/// Incoming certificates must still be validated using `validate_cert`. +pub fn create_tls_acceptor( + cert: &X509Ref, + private_key: &PKeyRef, +) -> SslResult { + let mut builder = SslAcceptor::mozilla_modern_v5(SslMethod::tls_server())?; + set_context_options(&mut builder, cert, private_key)?; + + Ok(builder.build()) +} + +/// Creates a TLS acceptor for a client. +/// +/// A connector compatible with the acceptor created using `create_tls_acceptor`. Server +/// certificates must always be validated using `validate_cert` after connecting. +pub fn create_tls_connector( + cert: &X509Ref, + private_key: &PKeyRef, +) -> SslResult { + let mut builder = SslConnector::builder(SslMethod::tls_client())?; + set_context_options(&mut builder, cert, private_key)?; + + Ok(builder.build()) +} + +/// Sets common options of both acceptor and connector on TLS context. +/// +/// Used internally to set various TLS parameters. 
+fn set_context_options( + ctx: &mut SslContextBuilder, + cert: &X509Ref, + private_key: &PKeyRef, +) -> SslResult<()> { + ctx.set_min_proto_version(Some(SslVersion::TLS1_3))?; + + ctx.set_certificate(cert)?; + ctx.set_private_key(private_key)?; + ctx.check_private_key()?; + + // Note that this does not seem to work as one might naively expect; the client can still send + // no certificate and there will be no error from OpenSSL. For this reason, we pass set `PEER` + // (causing the request of a cert), but pass all of them through and verify them after the + // handshake has completed. + ctx.set_verify_callback(SslVerifyMode::PEER, |_, _| true); + + Ok(()) +} + +/// Error during certificate validation. +#[derive(Debug, Display, Error)] +pub enum ValidationError { + /// error reading public key from certificate: {0:?} + CannotReadPublicKey(#[source] ErrorStack), + /// error reading subject or issuer name: {0:?} + CorruptSubjectOrIssuer(#[source] ErrorStack), + /// wrong signature scheme + WrongSignatureAlgorithm, + /// there was an issue reading or converting times: {0:?} + TimeIssue(#[source] ErrorStack), + /// the certificate is not yet valid + NotYetValid, + /// the certificate expired + Expired, + /// the serial number could not be compared to the reference: {0:?} + InvalidSerialNumber(#[source] ErrorStack), + /// wrong serial number + WrongSerialNumber, + /// no valid elliptic curve key could be extracted from certificate: {0:?} + CouldNotExtractEcKey(#[source] ErrorStack), + /// the given public key fails basic sanity checks: {0:?} + KeyFailsCheck(#[source] ErrorStack), + /// underlying elliptic curve is wrong + WrongCurve, + /// certificate is not self-signed + NotSelfSigned, + /// the signature could not be validated + FailedToValidateSignature(#[source] ErrorStack), + /// the signature is invalid + InvalidSignature, + /// failed to read fingerprint + InvalidFingerprint(#[source] ErrorStack), + /// could not create a big num context + 
BigNumContextNotAvailable(#[source] ErrorStack), + /// could not encode public key as bytes + PublicKeyEncodingFailed(#[source] ErrorStack), +} + +/// Checks that the cryptographic parameters on a certificate are correct and returns the +/// fingerprint of the public key. +/// +/// At the very least this ensures that no weaker ciphers have been used to forge a certificate. +pub fn validate_cert(cert: X509) -> Result { + if cert.signature_algorithm().object().nid() != SIGNATURE_ALGORITHM { + // The signature algorithm is not of the exact kind we are using to generate our + // certificates, an attacker could have used a weaker one to generate colliding keys. + return Err(ValidationError::WrongSignatureAlgorithm); + } + // TODO: Lock down extensions on the certificate --- if we manage to lock down the whole cert in + // a way that no additional bytes can be added (all fields are either known or of fixed + // length) we would have an additional hurdle for preimage attacks to clear. + + let subject = + name_to_string(cert.subject_name()).map_err(ValidationError::CorruptSubjectOrIssuer)?; + let issuer = + name_to_string(cert.issuer_name()).map_err(ValidationError::CorruptSubjectOrIssuer)?; + if subject != issuer { + // All of our certificates are self-signed, so it cannot hurt to check. + return Err(ValidationError::NotSelfSigned); + } + + // All our certificates have serial number 1. + if !num_eq(cert.serial_number(), 1).map_err(ValidationError::InvalidSerialNumber)? { + return Err(ValidationError::WrongSerialNumber); + } + + // Check expiration times against current time. + let asn1_now = Asn1Time::from_unix(now()).map_err(ValidationError::TimeIssue)?; + if asn1_now + .compare(cert.not_before()) + .map_err(ValidationError::TimeIssue)? + != Ordering::Greater + { + return Err(ValidationError::NotYetValid); + } + + if asn1_now + .compare(cert.not_after()) + .map_err(ValidationError::TimeIssue)? 
+ != Ordering::Less + { + return Err(ValidationError::Expired); + } + + // Ensure that the key is using the correct curve parameters. + let public_key = cert + .public_key() + .map_err(ValidationError::CannotReadPublicKey)?; + + let ec_key = public_key + .ec_key() + .map_err(ValidationError::CouldNotExtractEcKey)?; + + ec_key.check_key().map_err(ValidationError::KeyFailsCheck)?; + if ec_key.group().curve_name() != Some(SIGNATURE_CURVE) { + // The underlying curve is not the one we chose. + return Err(ValidationError::WrongCurve); + } + + // Finally we can check the actual signature. + if !cert + .verify(&public_key) + .map_err(ValidationError::FailedToValidateSignature)? + { + return Err(ValidationError::InvalidSignature); + } + + // We now have a valid certificate and can extract the fingerprint. + assert_eq!(Sha512::NID, SIGNATURE_DIGEST); + let digest = &cert + .digest(Sha512::create_message_digest()) + .map_err(ValidationError::InvalidFingerprint)?; + let cert_fingerprint = CertFingerprint(Sha512::from_openssl_digest(digest)); + + // Additionally we can calculate a fingerprint for the public key: + let mut big_num_context = + BigNumContext::new().map_err(ValidationError::BigNumContextNotAvailable)?; + + let buf = ec_key + .public_key() + .to_bytes( + ec::EcGroup::from_curve_name(SIGNATURE_CURVE) + .expect("broken constant SIGNATURE_CURVE") + .as_ref(), + ec::PointConversionForm::COMPRESSED, + &mut big_num_context, + ) + .map_err(ValidationError::PublicKeyEncodingFailed)?; + + let key_fingerprint = KeyFingerprint(Sha512::new(&buf)); + + Ok(TlsCert { + x509: cert, + cert_fingerprint, + key_fingerprint, + }) +} + +/// Loads a certificate from a file. +pub fn load_cert>(src: P) -> anyhow::Result { + let pem = fs::read(src.as_ref()) + .with_context(|| format!("failed to load certificate {:?}", src.as_ref()))?; + + Ok(X509::from_pem(&pem).context("parsing certificate")?) +} + +/// Loads a private key from a file. 
+pub fn load_private_key>(src: P) -> anyhow::Result> { + let pem = fs::read(src.as_ref()) + .with_context(|| format!("failed to load private key {:?}", src.as_ref()))?; + + // TODO: It might be that we need to call `PKey::private_key_from_pkcs8` instead. + Ok(PKey::private_key_from_pem(&pem).context("parsing private key")?) +} + +/// Saves a certificate to a file. +pub fn save_cert>(cert: &X509Ref, dest: P) -> anyhow::Result<()> { + let pem = cert.to_pem().context("converting certificate to PEM")?; + + fs::write(dest.as_ref(), pem) + .with_context(|| format!("failed to write certificate {:?}", dest.as_ref()))?; + Ok(()) +} + +/// Saves a private key to a file. +pub fn save_private_key>(key: &PKeyRef, dest: P) -> anyhow::Result<()> { + let pem = key + .private_key_to_pem_pkcs8() + .context("converting private key to PEM")?; + + fs::write(dest.as_ref(), pem) + .with_context(|| format!("failed to write private key {:?}", dest.as_ref()))?; + Ok(()) +} + +/// Returns an OpenSSL compatible timestamp. +fn now() -> i64 { + // Note: We could do the timing dance a little better going straight to the UNIX time functions, + // but this saves us having to bring in `libc` as a dependency. + let now = SystemTime::now(); + let ts: i64 = now + .duration_since(UNIX_EPOCH) + // This should work unless the clock is set to before 1970. + .expect("Great Scott! Your clock is horribly broken, Marty.") + .as_secs() + // This will fail past year 2038 on 32 bit systems and very far into the future, both cases + // we consider out of scope. + .try_into() + .expect("32-bit systems and far future are not supported"); + + ts +} + +/// Creates an ASN1 integer from a `u32`. +fn mknum(n: u32) -> Result { + let bn = BigNum::from_u32(n)?; + + bn.to_asn1_integer() +} + +/// Creates an ASN1 name from string components. +/// +/// If `c` or `o` are empty string, they are omitted from the result. 
+fn mkname(c: &str, o: &str, cn: &str) -> Result { + let mut builder = X509NameBuilder::new()?; + + if !c.is_empty() { + builder.append_entry_by_text("C", c)?; + } + + if !o.is_empty() { + builder.append_entry_by_text("O", o)?; + } + + builder.append_entry_by_text("CN", cn)?; + Ok(builder.build()) +} + +/// Converts an `X509NameRef` to a human readable string. +fn name_to_string(name: &X509NameRef) -> SslResult { + let mut output = String::new(); + + for entry in name.entries() { + output.push_str(entry.object().nid().long_name()?); + output.push_str("="); + output.push_str(entry.data().as_utf8()?.as_ref()); + output.push_str(" "); + } + + Ok(output) +} + +/// Checks if an `Asn1IntegerRef` is equal to a given u32. +fn num_eq(num: &Asn1IntegerRef, other: u32) -> SslResult { + let l = num.to_bn()?; + let r = BigNum::from_u32(other)?; + + // The `BigNum` API seems to be really lacking here. + Ok(l.is_negative() == r.is_negative() && l.ucmp(&r.as_ref()) == Ordering::Equal) +} + +/// Generates a secret key suitable for TLS encryption. +fn generate_private_key() -> SslResult> { + // We do not care about browser-compliance, so we're free to use elliptic curves that are more + // likely to hold up under pressure than the NIST ones. We want to go with ED25519 because djb + // knows best: PKey::generate_ed25519() + // + // However the following bug currently prevents us from doing so: + // https://mta.openssl.org/pipermail/openssl-users/2018-July/008362.html (The same error occurs + // when trying to sign the cert inside the builder) + + // Our second choice is 2^521-1, which is slow but a "nice prime". + // http://blog.cr.yp.to/20140323-ecdsa.html + + // An alternative is https://en.bitcoin.it/wiki/Secp256k1, which puts us at level of bitcoin. + + // TODO: Please verify this for accuracy! 
+ + let ec_group = ec::EcGroup::from_curve_name(SIGNATURE_CURVE)?; + let ec_key = ec::EcKey::generate(ec_group.as_ref())?; + + PKey::from_ec_key(ec_key) +} + +/// Generates a self-signed certificate based on `private_key` with given CN. +fn generate_cert(private_key: &PKey, cn: &str) -> SslResult { + let mut builder = X509Builder::new()?; + + // x509 v3 commonly used, the version is 0-indexed, thus 2 == v3. + builder.set_version(2)?; + + // The serial number is always one, since we are issuing only one cert. + builder.set_serial_number(mknum(1)?.as_ref())?; + + let issuer = mkname("US", "CasperLabs Blockchain", cn)?; + + // Set the issuer, subject names, putting the "self" in "self-signed". + builder.set_issuer_name(issuer.as_ref())?; + builder.set_subject_name(issuer.as_ref())?; + + let ts = now(); + // We set valid-from to one minute into the past to allow some clock-skew. + builder.set_not_before(Asn1Time::from_unix(ts - 60)?.as_ref())?; + + // Valid-until is a little under 10 years, missing at least 2 leap days. + builder.set_not_after(Asn1Time::from_unix(ts + 10 * 365 * 24 * 60 * 60)?.as_ref())?; + + // Set the public key and sign. + builder.set_pubkey(private_key.as_ref())?; + assert_eq!(Sha512::NID, SIGNATURE_DIGEST); + builder.sign(private_key.as_ref(), Sha512::create_message_digest())?; + + let cert = builder.build(); + + // Cheap sanity check. + assert!( + validate_cert(cert.clone()).is_ok(), + "newly generated cert does not pass our own validity check" + ); + + Ok(cert) +} + +/// Serde support for `openx509::X509` certificates. +/// +/// Will also check if certificates are valid according to `validate_cert` when deserializing. +mod x509_serde { + use std::str; + + use openssl::x509::X509; + use serde::{Deserialize, Deserializer, Serializer}; + + use super::validate_cert; + + /// Serde-compatible serialization for X509 certificates. 
+ pub fn serialize(value: &X509, serializer: S) -> Result + where + S: Serializer, + { + let encoded = value.to_pem().map_err(serde::ser::Error::custom)?; + + // We don't expect encoding to fail, since PEMs are ASCII, but pass the error just in case. + serializer.serialize_str(str::from_utf8(&encoded).map_err(serde::ser::Error::custom)?) + } + + /// Serde-compatible deserialization for X509 certificates. + pub fn deserialize<'de, D>(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + // Create an extra copy for simplicity here. If this becomes a bottleneck, feel free to try + // to leverage Cow here, or implement a custom visitor that handles both cases. + let s: String = Deserialize::deserialize(deserializer)?; + let x509 = X509::from_pem(s.as_bytes()).map_err(serde::de::Error::custom)?; + + validate_cert(x509) + .map_err(serde::de::Error::custom) + .map(|tc| tc.x509) + } +} + +// Below are trait implementations for signatures and fingerprints. Both implement the full set of +// traits that are required to stick into either a `HashMap` or `BTreeMap`. 
+impl PartialEq for Sha512 { + #[inline] + fn eq(&self, other: &Self) -> bool { + self.bytes() == other.bytes() + } +} + +impl Eq for Sha512 {} + +impl Ord for Sha512 { + #[inline] + fn cmp(&self, other: &Self) -> Ordering { + Ord::cmp(self.bytes(), other.bytes()) + } +} + +impl PartialOrd for Sha512 { + #[inline] + fn partial_cmp(&self, other: &Self) -> Option { + Some(Ord::cmp(self, other)) + } +} + +impl Debug for Sha512 { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "{}", HexFmt(&self.0[..])) + } +} + +impl Display for Sha512 { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "{}", HexFmt(&self.0[0..7])) + } +} + +impl Display for CertFingerprint { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + Display::fmt(&self.0, f) + } +} + +impl Display for KeyFingerprint { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + Display::fmt(&self.0, f) + } +} + +impl Display for Signature { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "{}", HexFmt(&self.0[0..7])) + } +} + +impl Display for Signed +where + T: Display + for<'de> Deserialize<'de>, +{ + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + // Decode the data here, even if it is expensive. + match rmp_serde::from_read::<_, T>(self.data.as_slice()) { + Ok(item) => write!(f, "signed[{}]<{} bytes>", self.signature, item), + Err(_err) => write!(f, "signed[{}]", self.signature), + } + } +} + +// Since all `Sha512`s are already hashes, we provide a very cheap hashing function that uses +// bytes from the fingerprint as input, cutting the number of bytes to be hashed to 1/16th. + +// If this is ever a performance bottleneck, a custom hasher can be added that passes these bytes +// through unchanged. +impl Hash for Sha512 { + #[inline] + fn hash(&self, state: &mut H) { + // Use the first eight bytes when hashing, giving 64 bits pure entropy. 
+ let mut chunk = [0u8; 8]; + + // TODO: Benchmark if this is really worthwhile over the automatic derivation. + chunk.copy_from_slice(&self.bytes()[0..8]); + + state.write_u64(u64::from_le_bytes(chunk)) + } +} + +#[cfg(test)] +mod test { + use super::{generate_node_cert, mkname, name_to_string, validate_cert, Signature, TlsCert}; + + #[test] + fn simple_name_to_string() { + let name = mkname("sc", "some_org", "some_cn").expect("could not create name"); + + assert_eq!( + name_to_string(name.as_ref()).expect("name to string failed"), + "countryName=sc organizationName=some_org commonName=some_cn " + ); + } + + #[test] + fn test_tls_cert_serde_roundtrip() { + let (cert, _private_key) = generate_node_cert().expect("failed to generate key, cert pair"); + + let tls_cert = validate_cert(cert).expect("generated cert is not valid"); + + // There is no `PartialEq` impl for `TlsCert`, so we simply serialize it twice. + let serialized = rmp_serde::to_vec(&tls_cert).expect("could not serialize"); + let deserialized: TlsCert = + rmp_serde::from_read(serialized.as_slice()).expect("could not deserialize"); + let serialized_again = rmp_serde::to_vec(&deserialized).expect("could not serialize"); + + assert_eq!(serialized, serialized_again); + } + + #[test] + fn test_signature_roundtrip() { + let (cert, private_key) = generate_node_cert().expect("failed to generate key, cert pair"); + let public_key = cert.public_key().unwrap(); + let data = vec![1, 2, 3, 4, 5]; + let sig = Signature::create(&private_key, &data).expect("signing failed"); + assert!(sig.verify(&public_key, &data).expect("verification failed")); + } +} diff --git a/src/util.rs b/src/util.rs deleted file mode 100644 index 654df2ec4a..0000000000 --- a/src/util.rs +++ /dev/null @@ -1,22 +0,0 @@ -//! Various utilities. -//! -//! The Generic functions that are not limited to a particular module, but are too small to warrant -//! being factored out into standalone crates. - -pub mod round_robin; - -/// Leak a value. 
-/// -/// Moves a value to the heap and then forgets about, leaving only a static reference behind. -#[inline] -pub fn leak(value: T) -> &'static T { - Box::leak(Box::new(value)) -} - -/// Small amount store. -/// -/// Stored in a smallvec to avoid allocations in case there are less than three items grouped. The -/// size of two items is chosen because one item is the most common use case, and large items are -/// typically boxed. In the latter case two pointers and one enum variant discriminator is almost -/// the same size as an empty vec, which is two pointers. -pub type Multiple = smallvec::SmallVec<[T; 2]>; diff --git a/src/utils.rs b/src/utils.rs new file mode 100644 index 0000000000..7da11e649e --- /dev/null +++ b/src/utils.rs @@ -0,0 +1,61 @@ +//! Various functions that are not limited to a particular module, but are too small to warrant +//! being factored out into standalone crates. + +mod round_robin; + +use std::{ + cell::RefCell, + fmt::{self, Display, Formatter}, +}; + +use smallvec::SmallVec; + +pub use round_robin::WeightedRoundRobin; + +/// Moves a value to the heap and then forgets about, leaving only a static reference behind. +#[inline] +pub fn leak(value: T) -> &'static T { + Box::leak(Box::new(value)) +} + +/// Small amount store. +/// +/// Stored in a `SmallVec` to avoid allocations in case there are less than three items grouped. The +/// size of two items is chosen because one item is the most common use case, and large items are +/// typically boxed. In the latter case two pointers and one enum variant discriminator is almost +/// the same size as an empty vec, which is two pointers. +pub type Multiple = SmallVec<[T; 2]>; + +/// A display-helper that shows iterators display joined by ",". 
+#[derive(Debug)] +pub struct DisplayIter(RefCell>); + +impl DisplayIter { + pub fn new(item: T) -> Self { + DisplayIter(RefCell::new(Some(item))) + } +} + +impl Display for DisplayIter +where + I: IntoIterator, + T: Display, +{ + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + if let Some(src) = self.0.borrow_mut().take() { + let mut first = true; + for item in src.into_iter().take(f.width().unwrap_or(usize::MAX)) { + if first { + first = false; + write!(f, "{}", item)?; + } else { + write!(f, ", {}", item)?; + } + } + + Ok(()) + } else { + write!(f, "DisplayIter:GONE") + } + } +} diff --git a/src/util/round_robin.rs b/src/utils/round_robin.rs similarity index 60% rename from src/util/round_robin.rs rename to src/utils/round_robin.rs index c83bec328e..b15677b6c5 100644 --- a/src/util/round_robin.rs +++ b/src/utils/round_robin.rs @@ -1,13 +1,16 @@ -//! Round-robin scheduling. +//! Weighted round-robin scheduling. //! //! This module implements a weighted round-robin scheduler that ensures no deadlocks occur, but -//! still allows prioriting events from one source over another. The module uses `tokio`s +//! still allows prioritizing events from one source over another. The module uses `tokio`'s //! synchronization primitives under the hood. -use std::collections::{HashMap, VecDeque}; -use std::hash::Hash; -use std::num::NonZeroUsize; -use tokio::sync; +use std::{ + collections::{HashMap, VecDeque}, + hash::Hash, + num::NonZeroUsize, +}; + +use tokio::sync::{Mutex, Semaphore}; /// Weighted round-robin scheduler. /// @@ -22,16 +25,16 @@ use tokio::sync; #[derive(Debug)] pub struct WeightedRoundRobin { /// Current iteration state. - state: sync::Mutex>, + state: Mutex>, /// A list of slots that are round-robin'd. slots: Vec>, /// Actual queues. - queues: HashMap>>, + queues: HashMap>>, /// Number of items in all queues combined. - total: sync::Semaphore, + total: Semaphore, } /// The inner state of the queue iteration. 
@@ -48,8 +51,8 @@ struct IterationState<K> { /// An internal slot in the round-robin scheduler. /// -/// A slot marks the scheduling position, i.e. which queue we are currently -/// polling and how many tickets it has left before the next one is due. +/// A slot marks the scheduling position, i.e. which queue we are currently polling and how many +/// tickets it has left before the next one is due. #[derive(Copy, Clone, Debug)] struct Slot<K> { /// The key, identifying a queue. @@ -63,17 +66,16 @@ impl<I, K> WeightedRoundRobin<I, K> where K: Copy + Clone + Eq + Hash, { - /// Create new weighted round-robin scheduler. + /// Creates a new weighted round-robin scheduler. /// - /// Creates a queue for each pair given in `weights`. The second component - /// of each `weight` is the number of times to return items from one - /// queue before moving on to the next one. + /// Creates a queue for each pair given in `weights`. The second component of each `weight` is + /// the number of times to return items from one queue before moving on to the next one. pub fn new(weights: Vec<(K, NonZeroUsize)>) -> Self { assert!(!weights.is_empty(), "must provide at least one slot"); let queues = weights .iter() - .map(|(idx, _)| (*idx, sync::Mutex::new(VecDeque::new()))) + .map(|(idx, _)| (*idx, Mutex::new(VecDeque::new()))) .collect(); let slots: Vec<Slot<K>> = weights .into_iter() @@ -85,17 +87,17 @@ where let active_slot = slots[0]; WeightedRoundRobin { - state: sync::Mutex::new(IterationState { + state: Mutex::new(IterationState { active_slot, active_slot_idx: 0, }), slots, queues, - total: sync::Semaphore::new(0), + total: Semaphore::new(0), } } - /// Push an item to a queue identified by key. + /// Pushes an item to a queue identified by key. 
/// /// ## Panics /// @@ -103,7 +105,7 @@ where pub async fn push(&self, item: I, queue: K) { self.queues .get(&queue) - .expect("tried to push to non-existant queue") + .expect("tried to push to non-existent queue") .lock() .await .push_back(item); @@ -112,10 +114,9 @@ where self.total.add_permits(1); } - /// Return the next item from queue. + /// Returns the next item from queue. /// - /// Returns `None` if the queue is empty or an internal error occurred. The - /// latter should never happen. + /// Asynchronously waits until a queue is non-empty or panics if an internal error occurred. pub async fn pop(&self) -> (I, K) { self.total.acquire().await.forget(); @@ -151,3 +152,51 @@ where } } } + +#[cfg(test)] +mod tests { + use std::num::NonZeroUsize; + + use futures::{future::FutureExt, join}; + + use super::*; + + #[repr(usize)] + #[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)] + enum QueueKind { + One = 1, + Two, + } + + fn weights() -> Vec<(QueueKind, NonZeroUsize)> { + unsafe { + vec![ + (QueueKind::One, NonZeroUsize::new_unchecked(1)), + (QueueKind::Two, NonZeroUsize::new_unchecked(2)), + ] + } + } + + #[tokio::test] + async fn should_respect_weighting() { + let scheduler = WeightedRoundRobin::<char, QueueKind>::new(weights()); + // Push three items on to each queue + let future1 = scheduler + .push('a', QueueKind::One) + .then(|_| scheduler.push('b', QueueKind::One)) + .then(|_| scheduler.push('c', QueueKind::One)); + let future2 = scheduler + .push('d', QueueKind::Two) + .then(|_| scheduler.push('e', QueueKind::Two)) + .then(|_| scheduler.push('f', QueueKind::Two)); + join!(future2, future1); + + // We should receive the popped values in the order a, d, e, b, f, c + assert_eq!(('a', QueueKind::One), scheduler.pop().await); + assert_eq!(('d', QueueKind::Two), scheduler.pop().await); + assert_eq!(('e', QueueKind::Two), scheduler.pop().await); + assert_eq!(('b', QueueKind::One), scheduler.pop().await); + assert_eq!(('f', QueueKind::Two), scheduler.pop().await); + 
assert_eq!(('c', QueueKind::One), scheduler.pop().await); + } +}