diff --git a/dot/state/block.go b/dot/state/block.go index 47e54ad912..3736502601 100644 --- a/dot/state/block.go +++ b/dot/state/block.go @@ -268,6 +268,60 @@ func (bs *BlockState) GetHashesByNumber(blockNumber uint) ([]common.Hash, error) return blockHashes, nil } +// GetAllDescendants gets all the descendants for a given block hash (including itself), by first checking in memory +// and, if not found, reading from the block state database. +func (bs *BlockState) GetAllDescendants(hash common.Hash) ([]common.Hash, error) { + allDescendants, err := bs.bt.GetAllDescendants(hash) + if err != nil && !errors.Is(err, blocktree.ErrNodeNotFound) { + return nil, err + } + + if err == nil { + return allDescendants, nil + } + + allDescendants = []common.Hash{hash} + + header, err := bs.GetHeader(hash) + if err != nil { + return nil, fmt.Errorf("getting header: %w", err) + } + + nextBlockHashes, err := bs.GetHashesByNumber(header.Number + 1) + if err != nil { + return nil, fmt.Errorf("getting hashes by number: %w", err) + } + + for _, nextBlockHash := range nextBlockHashes { + nextHeader, err := bs.GetHeader(nextBlockHash) + if err != nil { + return nil, fmt.Errorf("getting header from block hash %s: %w", nextBlockHash, err) + } + // next block is not a descendant of the block for the given hash + if nextHeader.ParentHash != hash { + return []common.Hash{hash}, nil + } + + nextDescendants, err := bs.bt.GetAllDescendants(nextBlockHash) + if err != nil && !errors.Is(err, blocktree.ErrNodeNotFound) { + return nil, fmt.Errorf("getting all descendants: %w", err) + } + if err == nil { + allDescendants = append(allDescendants, nextDescendants...) + return allDescendants, nil + } + + nextDescendants, err = bs.GetAllDescendants(nextBlockHash) + if err != nil { + return nil, err + } + + allDescendants = append(allDescendants, nextDescendants...) + } + + return allDescendants, nil +} + // GetBlockHashesBySlot gets all block hashes that were produced in the given slot. func (bs *BlockState) GetBlockHashesBySlot(slotNum uint64) ([]common.Hash, error) { highestFinalisedHash, err := bs.GetHighestFinalisedHash() @@ -275,7 +329,7 @@ func (bs *BlockState) GetBlockHashesBySlot(slotNum uint64) ([]common.Hash, error return nil, fmt.Errorf("failed to get highest finalised hash: %w", err) } - descendants, err := bs.bt.GetAllDescendants(highestFinalisedHash) + descendants, err := bs.GetAllDescendants(highestFinalisedHash) if err != nil { return nil, fmt.Errorf("failed to get descendants: %w", err) } diff --git a/dot/state/block_test.go b/dot/state/block_test.go index 17796d9646..6a4a1d1478 100644 --- a/dot/state/block_test.go +++ b/dot/state/block_test.go @@ -254,6 +254,69 @@ func TestGetHashesByNumber(t *testing.T) { require.ElementsMatch(t, blocks, []common.Hash{block.Header.Hash(), block2.Header.Hash()}) } +func TestGetAllDescendants(t *testing.T) { + t.Parallel() + + bs := newTestBlockState(t, newTriesEmpty()) + slot := uint64(77) + + babeHeader := types.NewBabeDigest() + err := babeHeader.Set(*types.NewBabePrimaryPreDigest(0, slot, [32]byte{}, [64]byte{})) + require.NoError(t, err) + data, err := scale.Marshal(babeHeader) + require.NoError(t, err) + preDigest := types.NewBABEPreRuntimeDigest(data) + + digest := types.NewDigest() + err = digest.Add(*preDigest) + require.NoError(t, err) + block := &types.Block{ + Header: types.Header{ + ParentHash: testGenesisHeader.Hash(), + Number: 1, + Digest: digest, + }, + Body: sampleBlockBody, + } + + err = bs.AddBlockWithArrivalTime(block, time.Now()) + require.NoError(t, err) + + babeHeader2 := types.NewBabeDigest() + err = babeHeader2.Set(*types.NewBabePrimaryPreDigest(1, slot+1, [32]byte{}, [64]byte{})) + require.NoError(t, err) + data2, err := scale.Marshal(babeHeader2) + require.NoError(t, err) + preDigest2 := types.NewBABEPreRuntimeDigest(data2) + + digest2 := types.NewDigest() + err = digest2.Add(*preDigest2) + require.NoError(t, err) + block2 := &types.Block{ + Header: types.Header{ + ParentHash: block.Header.Hash(), + Number: 2, + Digest: digest2, + }, + Body: sampleBlockBody, + } + err = bs.AddBlockWithArrivalTime(block2, time.Now()) + require.NoError(t, err) + + err = bs.SetFinalisedHash(block2.Header.Hash(), 1, 1) + require.NoError(t, err) + + // can't fetch given block's descendants since the given block get removed from memory after + // being finalised, using blocktree.GetAllDescendants + _, err = bs.bt.GetAllDescendants(block.Header.Hash()) + require.ErrorIs(t, err, blocktree.ErrNodeNotFound) + + // can fetch given finalised block's descendants using disk, using using blockstate.GetAllDescendants + blockHashes, err := bs.GetAllDescendants(block.Header.Hash()) + require.NoError(t, err) + require.ElementsMatch(t, blockHashes, []common.Hash{block.Header.Hash(), block2.Header.Hash()}) +} + func TestGetBlockHashesBySlot(t *testing.T) { t.Parallel() diff --git a/lib/blocktree/blocktree.go b/lib/blocktree/blocktree.go index b378d92225..9bcadad376 100644 --- a/lib/blocktree/blocktree.go +++ b/lib/blocktree/blocktree.go @@ -422,7 +422,7 @@ func (bt *BlockTree) GetAllBlocks() []Hash { return bt.root.getAllDescendants(nil) } -// GetAllDescendants returns all block hashes that are descendants of the given block hash. +// GetAllDescendants returns all block hashes that are descendants of the given block hash (including itself). func (bt *BlockTree) GetAllDescendants(hash common.Hash) ([]Hash, error) { bt.RLock() defer bt.RUnlock() diff --git a/lib/blocktree/blocktree_test.go b/lib/blocktree/blocktree_test.go index 28254b5303..6b479d9130 100644 --- a/lib/blocktree/blocktree_test.go +++ b/lib/blocktree/blocktree_test.go @@ -6,30 +6,15 @@ package blocktree import ( "bytes" "fmt" - "math/rand" "testing" "time" "github.com/ChainSafe/gossamer/dot/types" "github.com/ChainSafe/gossamer/lib/common" - "github.com/ChainSafe/gossamer/pkg/scale" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -var zeroHash, _ = common.HexToHash("0x00") -var testHeader = &types.Header{ - ParentHash: zeroHash, - Number: 0, - Digest: types.NewDigest(), -} - -type testBranch struct { - hash Hash - number uint - arrivalTime int64 -} - func newBlockTreeFromNode(root *node) *BlockTree { return &BlockTree{ root: root, @@ -37,114 +22,6 @@ func newBlockTreeFromNode(root *node) *BlockTree { } } -func createPrimaryBABEDigest(t testing.TB) scale.VaryingDataTypeSlice { - babeDigest := types.NewBabeDigest() - err := babeDigest.Set(types.BabePrimaryPreDigest{AuthorityIndex: 0}) - require.NoError(t, err) - - bdEnc, err := scale.Marshal(babeDigest) - require.NoError(t, err) - - digest := types.NewDigest() - err = digest.Add(types.PreRuntimeDigest{ - ConsensusEngineID: types.BabeEngineID, - Data: bdEnc, - }) - require.NoError(t, err) - return digest -} - -func createTestBlockTree(t *testing.T, header *types.Header, number uint) (*BlockTree, []testBranch) { - bt := NewBlockTreeFromRoot(header) - previousHash := header.Hash() - - // branch tree randomly - var branches []testBranch - r := rand.New(rand.NewSource(time.Now().UnixNano())) //skipcq: GSC-G404 - - at := int64(0) - - // create base tree - for i := uint(1); i <= number; i++ { - header := &types.Header{ - ParentHash: previousHash, - Number: i, - Digest: createPrimaryBABEDigest(t), - } - - hash := header.Hash() - err := bt.AddBlock(header, time.Unix(0, at)) - require.NoError(t, err) - - previousHash = hash - - isBranch := r.Intn(2) - if isBranch == 1 { - branches = append(branches, testBranch{ - hash: hash, - number: bt.getNode(hash).number, - arrivalTime: at, - }) - } - - at += int64(r.Intn(8)) - } - - // create tree branches - for _, branch := range branches { - at := branch.arrivalTime - previousHash = branch.hash - - for i := branch.number; i <= number; i++ { - header := &types.Header{ - ParentHash: previousHash, - Number: i + 1, - StateRoot: common.Hash{0x1}, - Digest: createPrimaryBABEDigest(t), - } - - hash := header.Hash() - err := bt.AddBlock(header, time.Unix(0, at)) - require.NoError(t, err) - - previousHash = hash - at += int64(r.Intn(8)) - - } - } - - return bt, branches -} - -func createFlatTree(t testing.TB, number uint) (*BlockTree, []common.Hash) { - rootHeader := &types.Header{ - ParentHash: zeroHash, - Digest: createPrimaryBABEDigest(t), - } - - bt := NewBlockTreeFromRoot(rootHeader) - require.NotNil(t, bt) - previousHash := bt.root.hash - - hashes := []common.Hash{bt.root.hash} - for i := uint(1); i <= number; i++ { - header := &types.Header{ - ParentHash: previousHash, - Number: i, - Digest: createPrimaryBABEDigest(t), - } - - hash := header.Hash() - hashes = append(hashes, hash) - - err := bt.AddBlock(header, time.Unix(0, 0)) - require.NoError(t, err) - previousHash = hash - } - - return bt, hashes -} - func Test_NewBlockTreeFromNode(t *testing.T) { var bt *BlockTree var branches []testBranch diff --git a/lib/blocktree/helpers_test.go b/lib/blocktree/helpers_test.go new file mode 100644 index 0000000000..1edbfade0d --- /dev/null +++ b/lib/blocktree/helpers_test.go @@ -0,0 +1,136 @@ +// Copyright 2023 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package blocktree + +import ( + "math/rand" + "testing" + "time" + + "github.com/ChainSafe/gossamer/dot/types" + "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/pkg/scale" + "github.com/stretchr/testify/require" +) + +var zeroHash = common.MustHexToHash("0x00") +var testHeader = &types.Header{ + ParentHash: zeroHash, + Number: 0, + Digest: types.NewDigest(), +} + +type testBranch struct { + hash Hash + number uint + arrivalTime int64 +} + +func createPrimaryBABEDigest(t testing.TB) scale.VaryingDataTypeSlice { + babeDigest := types.NewBabeDigest() + err := babeDigest.Set(types.BabePrimaryPreDigest{AuthorityIndex: 0}) + require.NoError(t, err) + + bdEnc, err := scale.Marshal(babeDigest) + require.NoError(t, err) + + digest := types.NewDigest() + err = digest.Add(types.PreRuntimeDigest{ + ConsensusEngineID: types.BabeEngineID, + Data: bdEnc, + }) + require.NoError(t, err) + return digest +} + +func createTestBlockTree(t *testing.T, header *types.Header, number uint) (*BlockTree, []testBranch) { + bt := NewBlockTreeFromRoot(header) + previousHash := header.Hash() + + // branch tree randomly + var branches []testBranch + r := rand.New(rand.NewSource(time.Now().UnixNano())) // skipcq + + at := int64(0) + + // create base tree + for i := uint(1); i <= number; i++ { + header := &types.Header{ + ParentHash: previousHash, + Number: i, + Digest: createPrimaryBABEDigest(t), + } + + hash := header.Hash() + err := bt.AddBlock(header, time.Unix(0, at)) + require.NoError(t, err) + + previousHash = hash + + isBranch := r.Intn(2) + if isBranch == 1 { + branches = append(branches, testBranch{ + hash: hash, + number: bt.getNode(hash).number, + arrivalTime: at, + }) + } + + at += int64(r.Intn(8)) + } + + // create tree branches + for _, branch := range branches { + at := branch.arrivalTime + previousHash = branch.hash + + for i := branch.number; i <= number; i++ { + header := &types.Header{ + ParentHash: previousHash, + Number: i + 1, + StateRoot: common.Hash{0x1}, + Digest: createPrimaryBABEDigest(t), + } + + hash := header.Hash() + err := bt.AddBlock(header, time.Unix(0, at)) + require.NoError(t, err) + + previousHash = hash + at += int64(r.Intn(8)) + + } + } + + return bt, branches +} + +func createFlatTree(t testing.TB, number uint) (*BlockTree, []common.Hash) { + rootHeader := &types.Header{ + ParentHash: zeroHash, + Digest: createPrimaryBABEDigest(t), + } + + bt := NewBlockTreeFromRoot(rootHeader) + require.NotNil(t, bt) + previousHash := bt.root.hash + + hashes := []common.Hash{bt.root.hash} + for i := uint(1); i <= number; i++ { + header := &types.Header{ + ParentHash: previousHash, + Number: i, + Digest: createPrimaryBABEDigest(t), + } + + hash := header.Hash() + hashes = append(hashes, hash) + + err := bt.AddBlock(header, time.Unix(0, 0)) + require.NoError(t, err) + previousHash = hash + } + + return bt, hashes +}