Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 51 additions & 2 deletions crates/core/src/checkpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use core::fmt;
use core::ops::RangeBounds;

use alloc::sync::Arc;
use alloc::vec::Vec;
use bitcoin::{block::Header, BlockHash};

use crate::BlockId;
Expand Down Expand Up @@ -56,10 +57,10 @@ impl<D> Drop for CPInner<D> {

/// Trait that converts [`CheckPoint`] `data` to [`BlockHash`].
///
/// Implementations of [`ToBlockHash`] must always return the blocks consensus-defined hash. If
/// Implementations of [`ToBlockHash`] must always return the block's consensus-defined hash. If
/// your type contains extra fields (timestamps, metadata, etc.), these must be ignored. For
/// example, [`BlockHash`] trivially returns itself, [`Header`] calls its `block_hash()`, and a
/// wrapper type around a [`Header`] should delegate to the headers hash rather than derive one
/// wrapper type around a [`Header`] should delegate to the header's hash rather than derive one
/// from other fields.
pub trait ToBlockHash {
/// Returns the [`BlockHash`] for the associated [`CheckPoint`] `data` type.
Expand All @@ -78,6 +79,20 @@ impl ToBlockHash for Header {
}
}

/// Trait that extracts a block time from [`CheckPoint`] `data`.
///
/// `data` types that contain a block time should implement this.
pub trait ToBlockTime {
/// Returns the block time from the [`CheckPoint`] `data`.
fn to_blocktime(&self) -> u32;
}

impl ToBlockTime for Header {
fn to_blocktime(&self) -> u32 {
self.time
}
}

impl<D> PartialEq for CheckPoint<D> {
fn eq(&self, other: &Self) -> bool {
let self_cps = self.iter().map(|cp| cp.block_id());
Expand Down Expand Up @@ -191,6 +206,8 @@ impl<D> CheckPoint<D>
where
D: ToBlockHash + fmt::Debug + Copy,
{
const MTP_BLOCK_COUNT: u32 = 11;

/// Construct a new base [`CheckPoint`] from given `height` and `data` at the front of a linked
/// list.
pub fn new(height: u32, data: D) -> Self {
Expand All @@ -204,6 +221,38 @@ where
}))
}

/// Calculate the median time past (MTP) for this checkpoint.
///
/// Uses 11 blocks (heights h-10 through h, where h is the current height) to compute the MTP
/// for the current block. This is used in Bitcoin's consensus rules for time-based validations
/// (BIP-0113).
///
/// Note: This is a pseudo-median that doesn't average the two middle values.
///
/// Returns `None` if the data type doesn't support block times or if any of the required
/// 11 sequential blocks are missing.
pub fn median_time_past(&self) -> Option<u32>
where
D: ToBlockTime,
{
let current_height = self.height();
let earliest_height = current_height.saturating_sub(Self::MTP_BLOCK_COUNT - 1);

let mut timestamps = (earliest_height..=current_height)
.map(|height| {
// Return `None` for missing blocks or missing block times
let cp = self.get(height)?;
let block_time = cp.data_ref().to_blocktime();
Some(block_time)
})
.collect::<Option<Vec<u32>>>()?;
timestamps.sort_unstable();

// If there are more than 1 middle values, use the higher middle value.
// This is mathematically incorrect, but this is the BIP-0113 specification.
Some(timestamps[timestamps.len() / 2])
}

/// Construct from an iterator of block data.
///
/// Returns `Err(None)` if `blocks` doesn't yield any data. If the blocks are not in ascending
Expand Down
133 changes: 132 additions & 1 deletion crates/core/tests/test_checkpoint.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use bdk_core::CheckPoint;
use bdk_core::{CheckPoint, ToBlockHash, ToBlockTime};
use bdk_testenv::{block_id, hash};
use bitcoin::hashes::Hash;
use bitcoin::BlockHash;

/// Inserting a block that already exists in the checkpoint chain must always succeed.
Expand Down Expand Up @@ -55,3 +56,133 @@ fn checkpoint_destruction_is_sound() {
}
assert_eq!(cp.iter().count() as u32, end);
}

/// Test helper: A block data type that includes timestamp
/// Fields are (height, time)
#[derive(Debug, Clone, Copy)]
struct BlockWithTime(u32, u32);

impl ToBlockHash for BlockWithTime {
fn to_blockhash(&self) -> BlockHash {
// Generate a deterministic hash from the height
let hash_bytes = bitcoin::hashes::sha256d::Hash::hash(&self.0.to_le_bytes());
BlockHash::from_raw_hash(hash_bytes)
}
}

impl ToBlockTime for BlockWithTime {
fn to_blocktime(&self) -> u32 {
self.1
}
}

#[test]
fn test_median_time_past_with_timestamps() {
// Create a chain with 12 blocks (heights 0-11) with incrementing timestamps
let blocks: Vec<_> = (0..=11)
.map(|i| (i, BlockWithTime(i, 1000 + i * 10)))
.collect();

let cp = CheckPoint::from_blocks(blocks).expect("must construct valid chain");

// Height 11: 11 previous blocks (11..=1), pseudo-median at index 6 = 1060
assert_eq!(cp.median_time_past(), Some(1060));

// Height 10: 11 previous blocks (10..=0), pseudo-median at index 5 = 1050
assert_eq!(cp.get(10).unwrap().median_time_past(), Some(1050));

// Height 5: 6 previous blocks (5..=0), pseudo-median at index 3 = 1030
assert_eq!(cp.get(5).unwrap().median_time_past(), Some(1030));

// Height 3: 4 previous blocks (3..=0), pseudo-median at index 2 = 1020
assert_eq!(cp.get(3).unwrap().median_time_past(), Some(1020));

// Height 0: 1 block at index 0 = 1000
assert_eq!(cp.get(0).unwrap().median_time_past(), Some(1000));
}

#[test]
fn test_previous_median_time_past_edge_cases() {
// Test with minimum required blocks (11)
let blocks: Vec<_> = (0..=10)
.map(|i| (i, BlockWithTime(i, 1000 + i * 100)))
.collect();

let cp = CheckPoint::from_blocks(blocks).expect("must construct valid chain");

// At height 10: next_mtp uses all 11 blocks (0-10)
// Times: [1000, 1100, 1200, 1300, 1400, 1500, 1600, 1700, 1800, 1900, 2000]
// Median at index 5 = 1500
assert_eq!(cp.median_time_past(), Some(1500));

// At height 9: mtp uses blocks 0-9 (10 blocks)
// Times: [1000, 1100, 1200, 1300, 1400, 1500, 1600, 1700, 1800, 1900]
// Median at index 5 = 1400
assert_eq!(cp.get(9).unwrap().median_time_past(), Some(1500));

// Test sparse chain where next_mtp returns None due to missing blocks
let sparse = vec![
(0, BlockWithTime(0, 1000)),
(5, BlockWithTime(5, 1050)),
(10, BlockWithTime(10, 1100)),
];
let sparse_cp = CheckPoint::from_blocks(sparse).expect("must construct valid chain");

// At height 10: next_mtp needs blocks 0-10 but many are missing
assert_eq!(sparse_cp.median_time_past(), None);
}

#[test]
fn test_mtp_with_non_monotonic_times() {
// Test both methods with shuffled timestamps
let blocks = vec![
(0, BlockWithTime(0, 1500)),
(1, BlockWithTime(1, 1200)),
(2, BlockWithTime(2, 1800)),
(3, BlockWithTime(3, 1100)),
(4, BlockWithTime(4, 1900)),
(5, BlockWithTime(5, 1300)),
(6, BlockWithTime(6, 1700)),
(7, BlockWithTime(7, 1400)),
(8, BlockWithTime(8, 1600)),
(9, BlockWithTime(9, 1000)),
(10, BlockWithTime(10, 2000)),
(11, BlockWithTime(11, 1650)),
];

let cp = CheckPoint::from_blocks(blocks).expect("must construct valid chain");

// Height 10:
// mtp uses blocks 0-10: sorted
// [1000,1100,1200,1300,1400,1500,1600,1700,1800,1900,2000] Median at index 5 = 1500
assert_eq!(cp.get(10).unwrap().median_time_past(), Some(1500));

// Height 11:
// mtp uses blocks 1-11: sorted
// [1000,1100,1200,1300,1400,1600,1650,1700,1800,1900,2000] Median at index 5 = 1600
assert_eq!(cp.median_time_past(), Some(1600));

// Test with smaller chain to verify sorting at different heights
let cp3 = cp.get(3).unwrap();
// Height 3: timestamps [1100, 1800, 1200, 1500] -> sorted [1100, 1200, 1500, 1800]
// Pseudo-median at index 2 = 1500
assert_eq!(cp3.median_time_past(), Some(1500));
}

#[test]
fn test_mtp_sparse_chain() {
// Sparse chain missing required sequential blocks
let blocks = vec![
(0, BlockWithTime(0, 1000)),
(3, BlockWithTime(3, 1030)),
(7, BlockWithTime(7, 1070)),
(11, BlockWithTime(11, 1110)),
(15, BlockWithTime(15, 1150)),
];

let cp = CheckPoint::from_blocks(blocks).expect("must construct valid chain");

// All heights should return None due to missing sequential blocks
assert_eq!(cp.median_time_past(), None);
assert_eq!(cp.get(11).unwrap().median_time_past(), None);
}