From 89517032bf0fc66416e99e92d66c348159593c96 Mon Sep 17 00:00:00 2001 From: noot Date: Mon, 18 Oct 2021 12:18:15 -0400 Subject: [PATCH 01/13] fix block response logic --- dot/state/block.go | 10 + dot/sync/chain_sync.go | 22 +- dot/sync/chain_sync_test.go | 24 +++ dot/sync/interface.go | 3 + dot/sync/message.go | 385 ++++++++++++++++++++++++++++++----- dot/sync/message_test.go | 346 +++++++++++++++++++++++++++++-- dot/sync/mocks/BlockState.go | 67 ++++++ dot/sync/tip_syncer.go | 23 ++- 8 files changed, 807 insertions(+), 73 deletions(-) diff --git a/dot/state/block.go b/dot/state/block.go index a69b15a523..d942ff2fd8 100644 --- a/dot/state/block.go +++ b/dot/state/block.go @@ -494,6 +494,16 @@ func (bs *BlockState) AddBlockToBlockTree(header *types.Header) error { return bs.bt.AddBlock(header, uint64(arrivalTime.UnixNano())) } +// GetAllBlocksAtNumber returns all unfinalised blocks with the given number +func (bs *BlockState) GetAllBlocksAtNumber(num *big.Int) ([]common.Hash, error) { + header, err := bs.GetHeaderByNumber(num) + if err != nil { + return nil, err + } + + return bs.GetAllBlocksAtDepth(header.ParentHash), nil +} + // GetAllBlocksAtDepth returns all hashes with the depth of the given hash plus one func (bs *BlockState) GetAllBlocksAtDepth(hash common.Hash) []common.Hash { return bs.bt.GetAllBlocksAtDepth(hash) diff --git a/dot/sync/chain_sync.go b/dot/sync/chain_sync.go index 69a9e8063a..eadad60846 100644 --- a/dot/sync/chain_sync.go +++ b/dot/sync/chain_sync.go @@ -897,6 +897,12 @@ func workerToRequests(w *worker) ([]*network.BlockRequestMessage, error) { } else { // in tip-syncing mode, we know the hash of the block on the fork we wish to sync start, _ = variadic.NewUint64OrHash(w.startHash) + + // if we're doing descending requests and not at the last (highest starting) request, + // then use number as start block + if w.direction == network.Descending && i != numRequests-1 { + start, _ = variadic.NewUint64OrHash(startNumber) + } } var end *common.Hash @@ -911,7 +917,21 @@ func workerToRequests(w *worker) ([]*network.BlockRequestMessage, error) { Direction: w.direction, Max: &max, } - startNumber += maxResponseSize + + switch w.direction { + case network.Ascending: + startNumber += maxResponseSize + case network.Descending: + startNumber -= maxResponseSize + } + } + + // if our direction is descending, we want to send out the request with the lowest + // startNumber first + if w.direction == network.Descending { + for i, j := 0, len(reqs)-1; i < j; i, j = i+1, j-1 { + reqs[i], reqs[j] = reqs[j], reqs[i] + } } return reqs, nil diff --git a/dot/sync/chain_sync_test.go b/dot/sync/chain_sync_test.go index ef215e69b6..55c208c9d0 100644 --- a/dot/sync/chain_sync_test.go +++ b/dot/sync/chain_sync_test.go @@ -405,6 +405,30 @@ func TestWorkerToRequests(t *testing.T) { }, }, }, + { + w: &worker{ + startNumber: big.NewInt(1 + maxResponseSize + (maxResponseSize / 2)), + targetNumber: big.NewInt(1), + direction: network.Descending, + requestData: bootstrapRequestData, + }, + expected: []*network.BlockRequestMessage{ + { + RequestedData: network.RequestedDataHeader + network.RequestedDataBody + network.RequestedDataJustification, + StartingBlock: *variadic.MustNewUint64OrHash(1 + (maxResponseSize / 2)), + EndBlockHash: nil, + Direction: network.Descending, + Max: &max64, + }, + { + RequestedData: bootstrapRequestData, + StartingBlock: *variadic.MustNewUint64OrHash(1 + maxResponseSize + (maxResponseSize / 2)), + EndBlockHash: nil, + Direction: network.Descending, + Max: &max128, + }, + }, + }, } for i, tc := range testCases { diff --git a/dot/sync/interface.go b/dot/sync/interface.go index 92f3883501..ba1140414e 100644 --- a/dot/sync/interface.go +++ b/dot/sync/interface.go @@ -55,6 +55,9 @@ type BlockState interface { StoreRuntime(common.Hash, runtime.Instance) GetHighestFinalisedHeader() (*types.Header, error) GetFinalisedNotifierChannel() chan *types.FinalisationInfo + GetHeaderByNumber(num *big.Int) (*types.Header, error) + GetAllBlocksAtNumber(num *big.Int) ([]common.Hash, error) + IsDescendantOf(parent, child common.Hash) (bool, error) } // StorageState is the interface for the storage state diff --git a/dot/sync/message.go b/dot/sync/message.go index 107ce787e4..82a0c4c9c1 100644 --- a/dot/sync/message.go +++ b/dot/sync/message.go @@ -32,108 +32,393 @@ const ( ) // CreateBlockResponse creates a block response message from a block request message -func (s *Service) CreateBlockResponse(blockRequest *network.BlockRequestMessage) (*network.BlockResponseMessage, error) { +func (s *Service) CreateBlockResponse(req *network.BlockRequestMessage) (*network.BlockResponseMessage, error) { + switch req.Direction { + case network.Ascending: + return s.handleAscendingRequest(req) + case network.Descending: + return s.handleDescendingRequest(req) + default: + return nil, errors.New("invalid request direction") + } +} + +func (s *Service) handleAscendingRequest(req *network.BlockRequestMessage) (*network.BlockResponseMessage, error) { var ( - startHash, endHash common.Hash - startHeader, endHeader *types.Header - err error - respSize uint32 + startHash *common.Hash + endHash = req.EndBlockHash + startNumber, endNumber uint64 + max uint32 = maxResponseSize ) - if blockRequest.Max != nil { - respSize = *blockRequest.Max - if respSize > maxResponseSize { - respSize = maxResponseSize - } - } else { - respSize = maxResponseSize + // determine maximum response size + if req.Max != nil && *req.Max < maxResponseSize { + max = *req.Max } - switch startBlock := blockRequest.StartingBlock.Value().(type) { + switch startBlock := req.StartingBlock.Value().(type) { case uint64: if startBlock == 0 { startBlock = 1 } - block, err := s.blockState.GetBlockByNumber(big.NewInt(0).SetUint64(startBlock)) //nolint + bestBlockNumber, err := s.blockState.BestBlockNumber() if err != nil { - return nil, fmt.Errorf("failed to get start block %d for request: %w", startBlock, err) + return nil, fmt.Errorf("failed to get best block %d for request: %w", bestBlockNumber, err) } - startHeader = &block.Header - startHash = block.Header.Hash() + // if request start is higher than our best block, return error + if bestBlockNumber.Uint64() < startBlock { + return nil, errors.New("request start number is higher than our best block") + } + + startNumber = startBlock + + if endHash != nil { + // TODO: end is hash is provided but start hash isn't. need to determine start block + // that is an ancestor of the end block + sh, err := s.blockState.GetHashByNumber(big.NewInt(int64(startNumber))) + if err != nil { + return nil, fmt.Errorf("failed to get start block %d for request: %w", startNumber, err) + } + + is, err := s.blockState.IsDescendantOf(sh, *endHash) + if err != nil { + return nil, err + } + + if !is { + return nil, fmt.Errorf("failed to get ancestor of end block with hash %s", *endHash) + } + + startHash = &sh + } case common.Hash: - startHash = startBlock - startHeader, err = s.blockState.GetHeader(startHash) + startHash = &startBlock + + // make sure we actually have the starting block + header, err := s.blockState.GetHeader(*startHash) if err != nil { return nil, fmt.Errorf("failed to get start block %s for request: %w", startHash, err) } + + startNumber = header.Number.Uint64() default: return nil, ErrInvalidBlockRequest } - if blockRequest.EndBlockHash != nil { - endHash = *blockRequest.EndBlockHash - endHeader, err = s.blockState.GetHeader(endHash) + if endHash == nil { + endNumber = startNumber + uint64(max) - 1 + bestBlockNumber, err := s.blockState.BestBlockNumber() if err != nil { - return nil, fmt.Errorf("failed to get end block %s for request: %w", endHash, err) + return nil, fmt.Errorf("failed to get best block %d for request: %w", bestBlockNumber, err) + } + + if endNumber > bestBlockNumber.Uint64() { + endNumber = bestBlockNumber.Uint64() } } else { - endNumber := big.NewInt(0).Add(startHeader.Number, big.NewInt(int64(respSize-1))) + header, err := s.blockState.GetHeader(*endHash) + if err != nil { + return nil, fmt.Errorf("failed to get end block %s: %w", *endHash, err) + } + + endNumber = header.Number.Uint64() + } + + // start hash provided, need to determine end hash that is descendant of start hash + if startHash != nil { + eh, err := s.checkOrGetDescendantHash(*startHash, endHash, big.NewInt(int64(endNumber))) + endHash = &eh + if err != nil { + return nil, err + } + } + + if startHash == nil && endHash == nil { + logger.Debug("handling BlockRequestMessage", + "start", startNumber, + "end", endNumber, + "direction", req.Direction, + ) + return s.handleAscendingByNumber(startNumber, endNumber, req.RequestedData) + } + + if startHash == nil { + panic("startHash is nil!") + } + + if endHash == nil { + panic("endHash is nil!") + } + + logger.Debug("handling BlockRequestMessage", + "start", *startHash, + "end", *endHash, + "direction", req.Direction, + ) + return s.handleChainByHash(*startHash, *endHash, max, req.RequestedData, req.Direction) +} + +func (s *Service) handleDescendingRequest(req *network.BlockRequestMessage) (*network.BlockResponseMessage, error) { + var ( + startHash *common.Hash + endHash = req.EndBlockHash + startNumber, endNumber uint64 + max uint32 = maxResponseSize + ) + + // determine maximum response size + if req.Max != nil && *req.Max < maxResponseSize { + max = *req.Max + } + + switch startBlock := req.StartingBlock.Value().(type) { + case uint64: bestBlockNumber, err := s.blockState.BestBlockNumber() if err != nil { return nil, fmt.Errorf("failed to get best block %d for request: %w", bestBlockNumber, err) } - if endNumber.Cmp(bestBlockNumber) == 1 { - endNumber = bestBlockNumber + // if request start is higher than our best block, only return blocks from our best block and below + if bestBlockNumber.Uint64() < startBlock { + startNumber = bestBlockNumber.Uint64() + } else { + startNumber = startBlock } + case common.Hash: + startHash = &startBlock - endBlock, err := s.blockState.GetBlockByNumber(endNumber) + // make sure we actually have the starting block + header, err := s.blockState.GetHeader(*startHash) if err != nil { - return nil, fmt.Errorf("failed to get end block %d for request: %w", endNumber, err) + return nil, fmt.Errorf("failed to get start block %s for request: %w", startHash, err) } - endHeader = &endBlock.Header - endHash = endHeader.Hash() + + startNumber = header.Number.Uint64() + default: + return nil, ErrInvalidBlockRequest } - logger.Debug("handling BlockRequestMessage", "start", startHeader.Number, "end", endHeader.Number, "startHash", startHash, "endHash", endHash) + // end hash provided, need to determine start hash that is descendant of end hash + if endHash != nil { + sh, err := s.checkOrGetDescendantHash(*endHash, startHash, big.NewInt(int64(startNumber))) + startHash = &sh + if err != nil { + return nil, err + } + } - responseData := []*types.BlockData{} + // end hash is not provided, calculate end by number + if endHash == nil { + if startNumber <= uint64(max+1) { + endNumber = 1 + } else { + endNumber = startNumber - uint64(max) + 1 + } - switch blockRequest.Direction { - case network.Ascending: - for i := startHeader.Number.Int64(); i <= endHeader.Number.Int64(); i++ { - blockData, err := s.getBlockData(big.NewInt(i), blockRequest.RequestedData) + if startHash != nil { + // need to get blocks by subchain if start hash is provided, get end hash + endHeader, err := s.blockState.GetHeaderByNumber(big.NewInt(int64(endNumber))) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to get end block %d for request: %w", endNumber, err) } - responseData = append(responseData, blockData) + + hash := endHeader.Hash() + endHash = &hash } - case network.Descending: - for i := endHeader.Number.Int64(); i >= startHeader.Number.Int64(); i-- { - blockData, err := s.getBlockData(big.NewInt(i), blockRequest.RequestedData) + } + + if startHash == nil && endHash == nil { + logger.Debug("handling BlockRequestMessage", + "start", startNumber, + "end", endNumber, + "direction", req.Direction, + ) + return s.handleDescendingByNumber(startNumber, endNumber, req.RequestedData) + } + + if startHash == nil { + panic("startHash is nil!") + } + + if endHash == nil { + panic("endHash is nil!") + } + + logger.Debug("handling BlockRequestMessage", + "start", *startHash, + "end", *endHash, + "direction", req.Direction, + ) + return s.handleChainByHash(*endHash, *startHash, max, req.RequestedData, req.Direction) +} + +// checkOrGetDescendantHash checks if the provided `descedant` is on the same chain as the `ancestor`, if it's provided, +// otherwise, it sets `descendant` to a block with number=`descendantNumber` that is a descendant of the ancestor +// if used with an Ascending request, ancestor is the start block and descendant is the end block +// if used with an Descending request, ancestor is the end block and descendant is the start block +func (s *Service) checkOrGetDescendantHash(ancestor common.Hash, descendant *common.Hash, descendantNumber *big.Int) (common.Hash, error) { + if descendantNumber == nil { + return common.EmptyHash, errors.New("descendantNumber is nil") + } + + // if `descendant` was provided, check that it's a descendant of `ancestor` + if descendant != nil { + header, err := s.blockState.GetHeader(ancestor) + if err != nil { + return common.EmptyHash, fmt.Errorf("failed to get descendant %s: %w", *descendant, err) + } + + // if descendant number is lower than ancestor number, this is an error + if header.Number.Cmp(descendantNumber) > 0 { + return common.EmptyHash, fmt.Errorf("invalid request, descendant number %d is higher than ancestor %d", header.Number, descendantNumber) + } + + // check if provided start hash is descendant of provided descendant hash + is, err := s.blockState.IsDescendantOf(ancestor, *descendant) + if err != nil { + return common.EmptyHash, err + } + + if !is { + return common.EmptyHash, errors.New("request start and end hash are not on the same chain") + } + + return *descendant, nil + } + + // otherwise, get block on canonical chain by descendantNumber + hash, err := s.blockState.GetHashByNumber(descendantNumber) + if err != nil { + return common.EmptyHash, err + } + + // check if it's a descendant of the provided ancestor hash + is, err := s.blockState.IsDescendantOf(ancestor, hash) + if err != nil { + return common.EmptyHash, err + } + + if !is { + // if it's not a descedant, search for a block that has number=descendantNumber that is + hashes, err := s.blockState.GetAllBlocksAtNumber(descendantNumber) + if err != nil { + return common.EmptyHash, fmt.Errorf("failed to get blocks at number %d: %w", descendantNumber, err) + } + + for _, hash := range hashes { + is, err := s.blockState.IsDescendantOf(ancestor, hash) if err != nil { - return nil, err + continue + } + + if !is { + continue } - responseData = append(responseData, blockData) + + // this sets the descendant hash to whatever the first block we find with descendantNumber + // is, however there might be multiple blocks that fit this criteria + h := common.EmptyHash + copy(h[:], hash[:]) + descendant = &h + break + } + + if descendant == nil { + return common.EmptyHash, fmt.Errorf("failed to find descedant block with number %d", descendantNumber) + } + } else { + // if it is, set descendant hash to our block w/ descendantNumber + descendant = &hash + } + + logger.Trace("determined descendant", + "ancestor", ancestor, + "descendant", *descendant, + "number", descendantNumber, + ) + return *descendant, nil +} + +func (s *Service) handleAscendingByNumber(start, end uint64, requestedData byte) (*network.BlockResponseMessage, error) { + var err error + data := make([]*types.BlockData, (end-start)+1) + + idx := 0 + for i := start; i <= end; i++ { + data[idx], err = s.getBlockDataByNumber(big.NewInt(int64(i)), requestedData) + if err != nil { + return nil, err + } + idx++ + } + + return &network.BlockResponseMessage{ + BlockData: data, + }, nil +} + +func (s *Service) handleDescendingByNumber(start, end uint64, requestedData byte) (*network.BlockResponseMessage, error) { + var err error + data := make([]*types.BlockData, (start-end)+1) + + idx := 0 + for i := start; i >= end; i-- { + data[idx], err = s.getBlockDataByNumber(big.NewInt(int64(i)), requestedData) + if err != nil { + return nil, err + } + idx++ + } + + return &network.BlockResponseMessage{ + BlockData: data, + }, nil +} + +func (s *Service) handleChainByHash(ancestor, descendant common.Hash, max uint32, requestedData byte, direction network.SyncDirection) (*network.BlockResponseMessage, error) { + subchain, err := s.blockState.SubChain(ancestor, descendant) + if err != nil { + return nil, err + } + + if uint32(len(subchain)) > max { + subchain = subchain[:max] + } + + data := make([]*types.BlockData, len(subchain)) + + for i, hash := range subchain { + data[i], err = s.getBlockData(hash, requestedData) + if err != nil { + return nil, err + } + } + + // reverse BlockData, if descending request + if direction == network.Descending { + for i, j := 0, len(data)-1; i < j; i, j = i+1, j-1 { + data[i], data[j] = data[j], data[i] } - default: - return nil, errors.New("invalid BlockRequest direction") } - logger.Debug("sending BlockResponseMessage", "start", startHeader.Number, "end", endHeader.Number) return &network.BlockResponseMessage{ - BlockData: responseData, + BlockData: data, }, nil } -func (s *Service) getBlockData(num *big.Int, requestedData byte) (*types.BlockData, error) { +func (s *Service) getBlockDataByNumber(num *big.Int, requestedData byte) (*types.BlockData, error) { hash, err := s.blockState.GetHashByNumber(num) if err != nil { return nil, err } + return s.getBlockData(hash, requestedData) +} + +func (s *Service) getBlockData(hash common.Hash, requestedData byte) (*types.BlockData, error) { + var err error blockData := &types.BlockData{ Hash: hash, } @@ -145,14 +430,14 @@ func (s *Service) getBlockData(num *big.Int, requestedData byte) (*types.BlockDa if (requestedData & network.RequestedDataHeader) == 1 { blockData.Header, err = s.blockState.GetHeader(hash) if err != nil { - logger.Debug("failed to get header for block", "number", num, "hash", hash, "error", err) + logger.Debug("failed to get header for block", "hash", hash, "error", err) } } if (requestedData&network.RequestedDataBody)>>1 == 1 { blockData.Body, err = s.blockState.GetBlockBody(hash) if err != nil { - logger.Debug("failed to get body for block", "number", num, "hash", hash, "error", err) + logger.Debug("failed to get body for block", "hash", hash, "error", err) } } diff --git a/dot/sync/message_test.go b/dot/sync/message_test.go index 17f77f4867..e522358bfb 100644 --- a/dot/sync/message_test.go +++ b/dot/sync/message_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/ChainSafe/gossamer/dot/network" + "github.com/ChainSafe/gossamer/dot/state" "github.com/ChainSafe/gossamer/dot/types" "github.com/ChainSafe/gossamer/lib/common/optional" "github.com/ChainSafe/gossamer/lib/common/variadic" @@ -38,8 +39,9 @@ func addTestBlocksToState(t *testing.T, depth int, blockState BlockState) { func TestService_CreateBlockResponse_MaxSize(t *testing.T) { s := newTestSyncer(t) - addTestBlocksToState(t, int(maxResponseSize), s.blockState) + addTestBlocksToState(t, int(maxResponseSize*2), s.blockState) + // test ascending start, err := variadic.NewUint64OrHash(uint64(1)) require.NoError(t, err) @@ -47,7 +49,7 @@ func TestService_CreateBlockResponse_MaxSize(t *testing.T) { RequestedData: 3, StartingBlock: *start, EndBlockHash: nil, - Direction: 0, + Direction: network.Ascending, Max: nil, } @@ -62,7 +64,7 @@ func TestService_CreateBlockResponse_MaxSize(t *testing.T) { RequestedData: 3, StartingBlock: *start, EndBlockHash: nil, - Direction: 0, + Direction: network.Ascending, Max: &max, } @@ -71,12 +73,79 @@ func TestService_CreateBlockResponse_MaxSize(t *testing.T) { require.Equal(t, int(maxResponseSize), len(resp.BlockData)) require.Equal(t, big.NewInt(1), resp.BlockData[0].Number()) require.Equal(t, big.NewInt(128), resp.BlockData[127].Number()) + + max = uint32(16) + req = &network.BlockRequestMessage{ + RequestedData: 3, + StartingBlock: *start, + EndBlockHash: nil, + Direction: network.Ascending, + Max: &max, + } + + resp, err = s.CreateBlockResponse(req) + require.NoError(t, err) + require.Equal(t, int(max), len(resp.BlockData)) + require.Equal(t, big.NewInt(1), resp.BlockData[0].Number()) + require.Equal(t, big.NewInt(16), resp.BlockData[15].Number()) + + // test descending + start, err = variadic.NewUint64OrHash(uint64(128)) + require.NoError(t, err) + + req = &network.BlockRequestMessage{ + RequestedData: 3, + StartingBlock: *start, + EndBlockHash: nil, + Direction: network.Descending, + Max: nil, + } + + resp, err = s.CreateBlockResponse(req) + require.NoError(t, err) + require.Equal(t, int(maxResponseSize), len(resp.BlockData)) + require.Equal(t, big.NewInt(128), resp.BlockData[0].Number()) + require.Equal(t, big.NewInt(1), resp.BlockData[127].Number()) + + max = uint32(maxResponseSize + 100) + start, err = variadic.NewUint64OrHash(uint64(256)) + require.NoError(t, err) + + req = &network.BlockRequestMessage{ + RequestedData: 3, + StartingBlock: *start, + EndBlockHash: nil, + Direction: network.Descending, + Max: &max, + } + + resp, err = s.CreateBlockResponse(req) + require.NoError(t, err) + require.Equal(t, int(maxResponseSize), len(resp.BlockData)) + require.Equal(t, big.NewInt(256), resp.BlockData[0].Number()) + require.Equal(t, big.NewInt(129), resp.BlockData[127].Number()) + + max = uint32(16) + req = &network.BlockRequestMessage{ + RequestedData: 3, + StartingBlock: *start, + EndBlockHash: nil, + Direction: network.Descending, + Max: &max, + } + + resp, err = s.CreateBlockResponse(req) + require.NoError(t, err) + require.Equal(t, int(max), len(resp.BlockData)) + require.Equal(t, big.NewInt(256), resp.BlockData[0].Number()) + require.Equal(t, big.NewInt(241), resp.BlockData[15].Number()) } func TestService_CreateBlockResponse_StartHash(t *testing.T) { s := newTestSyncer(t) - addTestBlocksToState(t, int(maxResponseSize), s.blockState) + addTestBlocksToState(t, int(maxResponseSize*2), s.blockState) + // test ascending with nil endBlockHash startHash, err := s.blockState.GetHashByNumber(big.NewInt(1)) require.NoError(t, err) @@ -87,7 +156,149 @@ func TestService_CreateBlockResponse_StartHash(t *testing.T) { RequestedData: 3, StartingBlock: *start, EndBlockHash: nil, - Direction: 0, + Direction: network.Ascending, + Max: nil, + } + + resp, err := s.CreateBlockResponse(req) + require.NoError(t, err) + require.Equal(t, int(maxResponseSize), len(resp.BlockData)) + require.Equal(t, big.NewInt(1), resp.BlockData[0].Number()) + require.Equal(t, big.NewInt(128), resp.BlockData[127].Number()) + + endHash, err := s.blockState.GetHashByNumber(big.NewInt(16)) + require.NoError(t, err) + + // test ascending with non-nil endBlockHash + req = &network.BlockRequestMessage{ + RequestedData: 3, + StartingBlock: *start, + EndBlockHash: &endHash, + Direction: network.Ascending, + Max: nil, + } + + resp, err = s.CreateBlockResponse(req) + require.NoError(t, err) + require.Equal(t, int(16), len(resp.BlockData)) + require.Equal(t, big.NewInt(1), resp.BlockData[0].Number()) + require.Equal(t, big.NewInt(16), resp.BlockData[15].Number()) + + // test descending with nil endBlockHash + startHash, err = s.blockState.GetHashByNumber(big.NewInt(16)) + require.NoError(t, err) + + start, err = variadic.NewUint64OrHash(startHash) + require.NoError(t, err) + + req = &network.BlockRequestMessage{ + RequestedData: 3, + StartingBlock: *start, + EndBlockHash: nil, + Direction: network.Descending, + Max: nil, + } + + resp, err = s.CreateBlockResponse(req) + require.NoError(t, err) + require.Equal(t, int(16), len(resp.BlockData)) + require.Equal(t, big.NewInt(16), resp.BlockData[0].Number()) + require.Equal(t, big.NewInt(1), resp.BlockData[15].Number()) + + // test descending with non-nil endBlockHash + endHash, err = s.blockState.GetHashByNumber(big.NewInt(1)) + require.NoError(t, err) + + req = &network.BlockRequestMessage{ + RequestedData: 3, + StartingBlock: *start, + EndBlockHash: &endHash, + Direction: network.Descending, + Max: nil, + } + + resp, err = s.CreateBlockResponse(req) + require.NoError(t, err) + require.Equal(t, int(16), len(resp.BlockData)) + require.Equal(t, big.NewInt(16), resp.BlockData[0].Number()) + require.Equal(t, big.NewInt(1), resp.BlockData[15].Number()) + + // test descending with nil endBlockHash and start > maxResponseSize + startHash, err = s.blockState.GetHashByNumber(big.NewInt(256)) + require.NoError(t, err) + + start, err = variadic.NewUint64OrHash(startHash) + require.NoError(t, err) + + req = &network.BlockRequestMessage{ + RequestedData: 3, + StartingBlock: *start, + EndBlockHash: nil, + Direction: network.Descending, + Max: nil, + } + + resp, err = s.CreateBlockResponse(req) + require.NoError(t, err) + require.Equal(t, int(maxResponseSize), len(resp.BlockData)) + require.Equal(t, big.NewInt(256), resp.BlockData[0].Number()) + require.Equal(t, big.NewInt(129), resp.BlockData[127].Number()) + + startHash, err = s.blockState.GetHashByNumber(big.NewInt(128)) + require.NoError(t, err) + + start, err = variadic.NewUint64OrHash(startHash) + require.NoError(t, err) + + req = &network.BlockRequestMessage{ + RequestedData: 3, + StartingBlock: *start, + EndBlockHash: nil, + Direction: network.Descending, + Max: nil, + } + + resp, err = s.CreateBlockResponse(req) + require.NoError(t, err) + require.Equal(t, int(maxResponseSize), len(resp.BlockData)) + require.Equal(t, big.NewInt(128), resp.BlockData[0].Number()) + require.Equal(t, big.NewInt(1), resp.BlockData[127].Number()) +} + +func TestService_CreateBlockResponse_Ascending_EndHash(t *testing.T) { + s := newTestSyncer(t) + addTestBlocksToState(t, int(maxResponseSize+1), s.blockState) + + // should error if end < start + start, err := variadic.NewUint64OrHash(uint64(128)) + require.NoError(t, err) + + end, err := s.blockState.GetHashByNumber(big.NewInt(1)) + require.NoError(t, err) + + req := &network.BlockRequestMessage{ + RequestedData: 3, + StartingBlock: *start, + EndBlockHash: &end, + Direction: network.Ascending, + Max: nil, + } + + _, err = s.CreateBlockResponse(req) + require.Error(t, err) + + // base case + start, err = variadic.NewUint64OrHash(uint64(1)) + require.NoError(t, err) + + end, err = s.blockState.GetHashByNumber(big.NewInt(128)) + require.NoError(t, err) + + req = &network.BlockRequestMessage{ + RequestedData: 3, + StartingBlock: *start, + EndBlockHash: &end, + Direction: network.Ascending, Max: nil, } @@ -98,21 +309,40 @@ func TestService_CreateBlockResponse_StartHash(t *testing.T) { require.Equal(t, big.NewInt(128), resp.BlockData[127].Number()) } -func TestService_CreateBlockResponse_Descending(t *testing.T) { +func TestService_CreateBlockResponse_Descending_EndHash(t *testing.T) { s := newTestSyncer(t) - addTestBlocksToState(t, int(maxResponseSize), s.blockState) + addTestBlocksToState(t, int(maxResponseSize+1), s.blockState) - startHash, err := s.blockState.GetHashByNumber(big.NewInt(1)) + // should error if start < end + start, err := variadic.NewUint64OrHash(uint64(1)) require.NoError(t, err) - start, err := variadic.NewUint64OrHash(startHash) + end, err := s.blockState.GetHashByNumber(big.NewInt(128)) require.NoError(t, err) req := &network.BlockRequestMessage{ RequestedData: 3, StartingBlock: *start, - EndBlockHash: nil, - Direction: 1, + EndBlockHash: &end, + Direction: network.Descending, + Max: nil, + } + + _, err = s.CreateBlockResponse(req) + require.Error(t, err) + + // base case + start, err = variadic.NewUint64OrHash(uint64(128)) + require.NoError(t, err) + + end, err = s.blockState.GetHashByNumber(big.NewInt(1)) + require.NoError(t, err) + + req = &network.BlockRequestMessage{ + RequestedData: 3, + StartingBlock: *start, + EndBlockHash: &end, + Direction: network.Descending, Max: nil, } @@ -123,8 +353,90 @@ func TestService_CreateBlockResponse_Descending(t *testing.T) { require.Equal(t, big.NewInt(1), resp.BlockData[127].Number()) } -// tests the ProcessBlockRequestMessage method -func TestService_CreateBlockResponse(t *testing.T) { +func TestService_checkOrGetDescendantHash(t *testing.T) { + s := newTestSyncer(t) + branches := map[int]int{ + 8: 1, + } + state.AddBlocksToStateWithFixedBranches(t, s.blockState.(*state.BlockState), 16, branches, 1) + + // base case + ancestor, err := s.blockState.GetHashByNumber(big.NewInt(1)) + require.NoError(t, err) + descendant, err := s.blockState.GetHashByNumber(big.NewInt(16)) + require.NoError(t, err) + descendantNumber := big.NewInt(16) + + res, err := s.checkOrGetDescendantHash(ancestor, &descendant, descendantNumber) + require.NoError(t, err) + require.Equal(t, descendant, res) + + // supply descendant that's not on canonical chain + leaves := s.blockState.(*state.BlockState).Leaves() + require.Equal(t, 2, len(leaves)) + + ancestor, err = s.blockState.GetHashByNumber(big.NewInt(1)) + require.NoError(t, err) + descendant, err = s.blockState.GetHashByNumber(big.NewInt(16)) + require.NoError(t, err) + + for _, leaf := range leaves { + if !leaf.Equal(descendant) { + descendant = leaf + break + } + } + + res, err = s.checkOrGetDescendantHash(ancestor, &descendant, descendantNumber) + require.NoError(t, err) + require.Equal(t, descendant, res) + + // supply descedant that's not on same chain as ancestor + ancestor, err = s.blockState.GetHashByNumber(big.NewInt(9)) + require.NoError(t, err) + res, err = s.checkOrGetDescendantHash(ancestor, &descendant, descendantNumber) + require.Error(t, err) + + // don't supply descendant, should return block on canonical chain + // as ancestor is on canonical chain + expected, err := s.blockState.GetHashByNumber(big.NewInt(16)) + require.NoError(t, err) + + res, err = s.checkOrGetDescendantHash(ancestor, nil, descendantNumber) + require.NoError(t, err) + require.Equal(t, expected, res) + + // don't supply descendant and provide ancestor not on canonical chain + // should return descendant block also not on canonical chain + block9s, err := s.blockState.GetAllBlocksAtNumber(big.NewInt(9)) + require.NoError(t, err) + canonical, err := s.blockState.GetHashByNumber(big.NewInt(9)) + require.NoError(t, err) + + // set ancestor to non-canonical block 9 + for _, block := range block9s { + if !canonical.Equal(block) { + ancestor = block + break + } + } + + // expected is non-canonical block 16 + for _, leaf := range leaves { + is, err := s.blockState.IsDescendantOf(ancestor, leaf) //nolint + require.NoError(t, err) + if is { + expected = leaf + break + } + } + + res, err = s.checkOrGetDescendantHash(ancestor, nil, descendantNumber) + require.NoError(t, err) + require.Equal(t, expected, res) +} + +func TestService_CreateBlockResponse_Fields(t *testing.T) { s := newTestSyncer(t) addTestBlocksToState(t, 2, s.blockState) @@ -172,7 +484,7 @@ func TestService_CreateBlockResponse(t *testing.T) { RequestedData: 3, StartingBlock: *start, EndBlockHash: &endHash, - Direction: 0, + Direction: network.Ascending, Max: nil, }, expectedMsgValue: &network.BlockResponseMessage{ @@ -191,7 +503,7 @@ func TestService_CreateBlockResponse(t *testing.T) { RequestedData: 1, StartingBlock: *start, EndBlockHash: &endHash, - Direction: 0, + Direction: network.Ascending, Max: nil, }, expectedMsgValue: &network.BlockResponseMessage{ @@ -210,7 +522,7 @@ func TestService_CreateBlockResponse(t *testing.T) { RequestedData: 4, StartingBlock: *start, EndBlockHash: &endHash, - Direction: 0, + Direction: network.Ascending, Max: nil, }, expectedMsgValue: &network.BlockResponseMessage{ @@ -230,7 +542,7 @@ func TestService_CreateBlockResponse(t *testing.T) { RequestedData: 8, StartingBlock: *start, EndBlockHash: &endHash, - Direction: 0, + Direction: network.Ascending, Max: nil, }, expectedMsgValue: &network.BlockResponseMessage{ diff --git a/dot/sync/mocks/BlockState.go b/dot/sync/mocks/BlockState.go index c4c7d4863b..51d087b7e7 100644 --- a/dot/sync/mocks/BlockState.go +++ b/dot/sync/mocks/BlockState.go @@ -122,6 +122,29 @@ func (_m *MockBlockState) CompareAndSetBlockData(bd *types.BlockData) error { return r0 } +// GetAllBlocksAtNumber provides a mock function with given fields: num +func (_m *MockBlockState) GetAllBlocksAtNumber(num *big.Int) ([]common.Hash, error) { + ret := _m.Called(num) + + var r0 []common.Hash + if rf, ok := ret.Get(0).(func(*big.Int) []common.Hash); ok { + r0 = rf(num) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]common.Hash) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(*big.Int) error); ok { + r1 = rf(num) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // GetBlockBody provides a mock function with given fields: _a0 func (_m *MockBlockState) GetBlockBody(_a0 common.Hash) (*types.Body, error) { ret := _m.Called(_a0) @@ -253,6 +276,29 @@ func (_m *MockBlockState) GetHeader(_a0 common.Hash) (*types.Header, error) { return r0, r1 } +// GetHeaderByNumber provides a mock function with given fields: num +func (_m *MockBlockState) GetHeaderByNumber(num *big.Int) (*types.Header, error) { + ret := _m.Called(num) + + var r0 *types.Header + if rf, ok := ret.Get(0).(func(*big.Int) *types.Header); ok { + r0 = rf(num) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.Header) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(*big.Int) error); ok { + r1 = rf(num) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // GetHighestFinalisedHeader provides a mock function with given fields: func (_m *MockBlockState) GetHighestFinalisedHeader() (*types.Header, error) { ret := _m.Called() @@ -410,6 +456,27 @@ func (_m *MockBlockState) HasHeader(hash common.Hash) (bool, error) { return r0, r1 } +// IsDescendantOf provides a mock function with given fields: parent, child +func (_m *MockBlockState) IsDescendantOf(parent common.Hash, child common.Hash) (bool, error) { + ret := _m.Called(parent, child) + + var r0 bool + if rf, ok := ret.Get(0).(func(common.Hash, common.Hash) bool); ok { + r0 = rf(parent, child) + } else { + r0 = ret.Get(0).(bool) + } + + var r1 error + if rf, ok := ret.Get(1).(func(common.Hash, common.Hash) error); ok { + r1 = rf(parent, child) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // SetFinalisedHash provides a mock function with given fields: hash, round, setID func (_m *MockBlockState) SetFinalisedHash(hash common.Hash, round uint64, setID uint64) error { ret := _m.Called(hash, round, setID) diff --git a/dot/sync/tip_syncer.go b/dot/sync/tip_syncer.go index 45987f56b1..416fe6c3bb 100644 --- a/dot/sync/tip_syncer.go +++ b/dot/sync/tip_syncer.go @@ -65,16 +65,25 @@ func (s *tipSyncer) handleWorkerResult(res *worker) (*worker, error) { return nil, nil } - if errors.Is(res.err.err, errUnknownParent) { - // handleTick will handle the errUnknownParent case - return nil, nil - } - fin, err := s.blockState.GetHighestFinalisedHeader() if err != nil { return nil, err } + if errors.Is(res.err.err, errUnknownParent) { + // handleTick will handle the errUnknownParent case + // TODO: determine if handleTick is working?? + + w := &worker{ + startHash: res.startHash, + startNumber: res.startNumber, + targetNumber: fin.Number, + direction: network.Descending, + requestData: bootstrapRequestData, + } + return w, nil + } + // don't retry if we're requesting blocks lower than finalised switch res.direction { case network.Ascending: @@ -157,6 +166,8 @@ func (*tipSyncer) hasCurrentWorker(w *worker, workers map[uint64]*worker) bool { // handleTick traverses the pending blocks set to find which forks still need to be requested func (s *tipSyncer) handleTick() ([]*worker, error) { + logger.Debug("handling tick...", "num pending blocks", s.pendingBlocks.size()) + if s.pendingBlocks.size() == 0 { return nil, nil } @@ -181,6 +192,8 @@ func (s *tipSyncer) handleTick() ([]*worker, error) { continue } + logger.Debug("handleTick handling pending block", "hash", block.hash, "number", block.number) + if block.header == nil { // case 1 workers = append(workers, &worker{ From f0250e9c60c4f0570adeb9587e7fbdad81511cef Mon Sep 17 00:00:00 2001 From: noot Date: Tue, 19 Oct 2021 07:23:13 -0400 Subject: [PATCH 02/13] update logs --- dot/sync/tip_syncer.go | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/dot/sync/tip_syncer.go b/dot/sync/tip_syncer.go index 416fe6c3bb..a69b29adcf 100644 --- a/dot/sync/tip_syncer.go +++ b/dot/sync/tip_syncer.go @@ -74,14 +74,15 @@ func (s *tipSyncer) handleWorkerResult(res *worker) (*worker, error) { // handleTick will handle the errUnknownParent case // TODO: determine if handleTick is working?? - w := &worker{ - startHash: res.startHash, - startNumber: res.startNumber, - targetNumber: fin.Number, - direction: network.Descending, - requestData: bootstrapRequestData, - } - return w, nil + // w := &worker{ + // startHash: res.startHash, + // startNumber: res.startNumber, + // targetNumber: fin.Number, + // direction: network.Descending, + // requestData: bootstrapRequestData, + // } + // return w, nil + return nil, nil } // don't retry if we're requesting blocks lower than finalised @@ -166,7 +167,7 @@ func (*tipSyncer) hasCurrentWorker(w *worker, workers map[uint64]*worker) bool { // handleTick traverses the pending blocks set to find which forks still need to be requested func (s *tipSyncer) handleTick() ([]*worker, error) { - logger.Debug("handling tick...", "num pending blocks", s.pendingBlocks.size()) + logger.Debug("handling tick...", "pending blocks count", s.pendingBlocks.size()) if s.pendingBlocks.size() == 0 { return nil, nil @@ -192,7 +193,7 @@ func (s *tipSyncer) handleTick() ([]*worker, error) { continue } - logger.Debug("handleTick handling pending block", "hash", block.hash, "number", block.number) + logger.Trace("handling pending block", "hash", block.hash, "number", block.number) if block.header == nil { // case 1 From 7b07d9d98667fa35a3f7705dd1736da458292f51 Mon Sep 17 00:00:00 2001 From: noot Date: Tue, 19 Oct 2021 07:51:55 -0400 Subject: [PATCH 03/13] cleanup --- dot/sync/tip_syncer.go | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/dot/sync/tip_syncer.go b/dot/sync/tip_syncer.go index a69b29adcf..ce95fceff1 100644 --- a/dot/sync/tip_syncer.go +++ b/dot/sync/tip_syncer.go @@ -72,16 +72,6 @@ func (s *tipSyncer) handleWorkerResult(res *worker) (*worker, error) { if errors.Is(res.err.err, errUnknownParent) { // handleTick will handle the errUnknownParent case - // TODO: determine if handleTick is working?? - - // w := &worker{ - // startHash: res.startHash, - // startNumber: res.startNumber, - // targetNumber: fin.Number, - // direction: network.Descending, - // requestData: bootstrapRequestData, - // } - // return w, nil return nil, nil } From bdc55658daf900be4554539340e29d4d0eaa489a Mon Sep 17 00:00:00 2001 From: noot Date: Tue, 19 Oct 2021 07:52:48 -0400 Subject: [PATCH 04/13] cleanup more --- dot/sync/message.go | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/dot/sync/message.go b/dot/sync/message.go index 82a0c4c9c1..58bb307d27 100644 --- a/dot/sync/message.go +++ b/dot/sync/message.go @@ -144,14 +144,6 @@ func (s *Service) handleAscendingRequest(req *network.BlockRequestMessage) (*net return s.handleAscendingByNumber(startNumber, endNumber, req.RequestedData) } - if startHash == nil { - panic("startHash is nil!") - } - - if endHash == nil { - panic("endHash is nil!") - } - logger.Debug("handling BlockRequestMessage", "start", *startHash, "end", *endHash, @@ -238,14 +230,6 @@ func (s *Service) handleDescendingRequest(req *network.BlockRequestMessage) (*ne return s.handleDescendingByNumber(startNumber, endNumber, req.RequestedData) } - if startHash == nil { - panic("startHash is nil!") - } - - if endHash == nil { - panic("endHash is nil!") - } - logger.Debug("handling BlockRequestMessage", "start", *startHash, "end", *endHash, From 3e64089fed8373c48ed6f4403e99f04911a48709 Mon Sep 17 00:00:00 2001 From: noot Date: Tue, 19 Oct 2021 07:53:42 -0400 Subject: [PATCH 05/13] cleanup more --- dot/sync/tip_syncer.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/dot/sync/tip_syncer.go b/dot/sync/tip_syncer.go index ce95fceff1..a20f68551a 100644 --- a/dot/sync/tip_syncer.go +++ b/dot/sync/tip_syncer.go @@ -65,16 +65,16 @@ func (s *tipSyncer) handleWorkerResult(res *worker) (*worker, error) { return nil, nil } - fin, err := s.blockState.GetHighestFinalisedHeader() - if err != nil { - return nil, err - } - if errors.Is(res.err.err, errUnknownParent) { // handleTick will handle the errUnknownParent case return nil, nil } + fin, err := s.blockState.GetHighestFinalisedHeader() + if err != nil { + return nil, err + } + // don't retry if we're requesting blocks lower than finalised switch res.direction { case network.Ascending: From 21cb17d808b842e758e39288b9e09c3ad51bc050 Mon Sep 17 00:00:00 2001 From: noot Date: Mon, 25 Oct 2021 09:27:36 -0400 Subject: [PATCH 06/13] address comments --- dot/sync/chain_sync.go | 4 +--- dot/sync/errors.go | 5 +++- dot/sync/message.go | 49 ++++++++++++++++++---------------------- dot/sync/message_test.go | 1 + dot/sync/syncer.go | 7 ++++++ 5 files changed, 35 insertions(+), 31 deletions(-) diff --git a/dot/sync/chain_sync.go b/dot/sync/chain_sync.go index eadad60846..663c818225 100644 --- a/dot/sync/chain_sync.go +++ b/dot/sync/chain_sync.go @@ -646,9 +646,7 @@ func (cs *chainSync) doSync(req *network.BlockRequestMessage) *workerError { if req.Direction == network.Descending { // reverse blocks before pre-validating and placing in ready queue - for i, j := 0, len(resp.BlockData)-1; i < j; i, j = i+1, j-1 { - resp.BlockData[i], resp.BlockData[j] = resp.BlockData[j], resp.BlockData[i] - } + resp.BlockData = reverseBlockData(resp.BlockData) } // perform some pre-validation of response, error if failure diff --git a/dot/sync/errors.go b/dot/sync/errors.go index 53d2f2458a..b685516b2c 100644 --- a/dot/sync/errors.go +++ b/dot/sync/errors.go @@ -40,7 +40,10 @@ var ( ErrInvalidBlock = errors.New("could not verify block") // ErrInvalidBlockRequest is returned when an invalid block request is received - ErrInvalidBlockRequest = errors.New("invalid block request") + ErrInvalidBlockRequest = errors.New("invalid block request") + errInvalidRequestDirection = errors.New("invalid request direction") + errRequestStartTooHigh = errors.New("request start number is higher than our best block") + errFailedToGetEndHashAncestor = errors.New("failed to get ancestor of end block") // chainSync errors errEmptyBlockData = errors.New("empty block data") diff --git a/dot/sync/message.go b/dot/sync/message.go index 58bb307d27..1a03077793 100644 --- a/dot/sync/message.go +++ b/dot/sync/message.go @@ -39,7 +39,7 @@ func (s *Service) CreateBlockResponse(req *network.BlockRequestMessage) (*networ case network.Descending: return s.handleDescendingRequest(req) default: - return nil, errors.New("invalid request direction") + return nil, errInvalidRequestDirection } } @@ -69,13 +69,13 @@ func (s *Service) handleAscendingRequest(req *network.BlockRequestMessage) (*net // if request start is higher than our best block, return error if bestBlockNumber.Uint64() < startBlock { - return nil, errors.New("request start number is higher than our best block") + return nil, errRequestStartTooHigh } startNumber = startBlock if endHash != nil { - // TODO: end is hash is provided but start hash isn't. need to determine start block + // TODO: end hash is provided but start hash isn't, so we need to determine a start block // that is an ancestor of the end block sh, err := s.blockState.GetHashByNumber(big.NewInt(int64(startNumber))) if err != nil { @@ -88,7 +88,7 @@ func (s *Service) handleAscendingRequest(req *network.BlockRequestMessage) (*net } if !is { - return nil, fmt.Errorf("failed to get ancestor of end block with hash %s", *endHash) + return nil, fmt.Errorf("%w: hash=%s", errFailedToGetEndHashAncestor, *endHash) } startHash = &sh @@ -129,10 +129,11 @@ func (s *Service) handleAscendingRequest(req *network.BlockRequestMessage) (*net // start hash provided, need to determine end hash that is descendant of start hash if startHash != nil { eh, err := s.checkOrGetDescendantHash(*startHash, endHash, big.NewInt(int64(endNumber))) - endHash = &eh if err != nil { return nil, err } + + endHash = &eh } if startHash == nil && endHash == nil { @@ -238,35 +239,35 @@ func (s *Service) handleDescendingRequest(req *network.BlockRequestMessage) (*ne return s.handleChainByHash(*endHash, *startHash, max, req.RequestedData, req.Direction) } -// checkOrGetDescendantHash checks if the provided `descedant` is on the same chain as the `ancestor`, if it's provided, +// checkOrGetDescendantHash checks if the provided `descendant` is on the same chain as the `ancestor`, if it's provided, // otherwise, it sets `descendant` to a block with number=`descendantNumber` that is a descendant of the ancestor // if used with an Ascending request, ancestor is the start block and descendant is the end block // if used with an Descending request, ancestor is the end block and descendant is the start block func (s *Service) checkOrGetDescendantHash(ancestor common.Hash, descendant *common.Hash, descendantNumber *big.Int) (common.Hash, error) { if descendantNumber == nil { - return common.EmptyHash, errors.New("descendantNumber is nil") + return common.Hash{}, errors.New("descendantNumber is nil") } // if `descendant` was provided, check that it's a descendant of `ancestor` if descendant != nil { header, err := s.blockState.GetHeader(ancestor) if err != nil { - return common.EmptyHash, fmt.Errorf("failed to get descendant %s: %w", *descendant, err) + return common.Hash{}, fmt.Errorf("failed to get descendant %s: %w", *descendant, err) } // if descendant number is lower than ancestor number, this is an error if header.Number.Cmp(descendantNumber) > 0 { - return common.EmptyHash, fmt.Errorf("invalid request, descendant number %d is higher than ancestor %d", header.Number, descendantNumber) + return common.Hash{}, fmt.Errorf("invalid request, descendant number %d is higher than ancestor %d", header.Number, descendantNumber) } // check if provided start hash is descendant of provided descendant hash is, err := s.blockState.IsDescendantOf(ancestor, *descendant) if err != nil { - return common.EmptyHash, err + return common.Hash{}, err } if !is { - return common.EmptyHash, errors.New("request start and end hash are not on the same chain") + return common.Hash{}, errors.New("request start and end hash are not on the same chain") } return *descendant, nil @@ -275,42 +276,38 @@ func (s *Service) checkOrGetDescendantHash(ancestor common.Hash, descendant *com // otherwise, get block on canonical chain by descendantNumber hash, err := s.blockState.GetHashByNumber(descendantNumber) if err != nil { - return common.EmptyHash, err + return common.Hash{}, err } // check if it's a descendant of the provided ancestor hash is, err := s.blockState.IsDescendantOf(ancestor, hash) if err != nil { - return common.EmptyHash, err + return common.Hash{}, err } if !is { - // if it's not a descedant, search for a block that has number=descendantNumber that is + // if it's not a descendant, search for a block that has number=descendantNumber that is hashes, err := s.blockState.GetAllBlocksAtNumber(descendantNumber) if err != nil { - return common.EmptyHash, fmt.Errorf("failed to get blocks at number %d: %w", descendantNumber, err) + return common.Hash{}, fmt.Errorf("failed to get blocks at number %d: %w", descendantNumber, err) } for _, hash := range hashes { is, err := s.blockState.IsDescendantOf(ancestor, hash) - if err != nil { - continue - } - - if !is { + if err != nil || !is { continue } // this sets the descendant hash to whatever the first block we find with descendantNumber // is, however there might be multiple blocks that fit this criteria - h := common.EmptyHash + h := common.Hash{} copy(h[:], hash[:]) descendant = &h break } if descendant == nil { - return common.EmptyHash, fmt.Errorf("failed to find descedant block with number %d", descendantNumber) + return common.Hash{}, fmt.Errorf("failed to find descendant block with number %d", descendantNumber) } } else { // if it is, set descendant hash to our block w/ descendantNumber @@ -330,8 +327,8 @@ func (s *Service) handleAscendingByNumber(start, end uint64, requestedData byte) data := make([]*types.BlockData, (end-start)+1) idx := 0 - for i := start; i <= end; i++ { - data[idx], err = s.getBlockDataByNumber(big.NewInt(int64(i)), requestedData) + for blockNumber := start; blockNumber <= end; blockNumber++ { + data[idx], err = s.getBlockDataByNumber(big.NewInt(int64(blockNumber)), requestedData) if err != nil { return nil, err } @@ -382,9 +379,7 @@ func (s *Service) handleChainByHash(ancestor, descendant common.Hash, max uint32 // reverse BlockData, if descending request if direction == network.Descending { - for i, j := 0, len(data)-1; i < j; i, j = i+1, j-1 { - data[i], data[j] = data[j], data[i] - } + data = reverseBlockData(data) } return &network.BlockResponseMessage{ diff --git a/dot/sync/message_test.go b/dot/sync/message_test.go index e522358bfb..5c9b407855 100644 --- a/dot/sync/message_test.go +++ b/dot/sync/message_test.go @@ -354,6 +354,7 @@ func TestService_CreateBlockResponse_Descending_EndHash(t *testing.T) { } func TestService_checkOrGetDescendantHash(t *testing.T) { + t.Parallel() s := newTestSyncer(t) branches := map[int]int{ 8: 1, diff --git a/dot/sync/syncer.go b/dot/sync/syncer.go index 9c4330f45b..5dba88630e 100644 --- a/dot/sync/syncer.go +++ b/dot/sync/syncer.go @@ -133,3 +133,10 @@ func (s *Service) HandleBlockAnnounce(from peer.ID, msg *network.BlockAnnounceMe func (s *Service) IsSynced() bool { return s.chainSync.syncState() == tip } + +func reverseBlockData(data []*types.BlockData) []*types.BlockData { + for i, j := 0, len(data)-1; i < j; i, j = i+1, j-1 { + data[i], data[j] = data[j], data[i] + } + return data +} From 5d921eb5de319b334e0cb6132a49a12bb301ce3d Mon Sep 17 00:00:00 2001 From: noot Date: Mon, 25 Oct 2021 09:28:13 -0400 Subject: [PATCH 07/13] address comments --- dot/sync/message.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dot/sync/message.go b/dot/sync/message.go index 1a03077793..035f4d77fa 100644 --- a/dot/sync/message.go +++ b/dot/sync/message.go @@ -345,8 +345,8 @@ func (s *Service) handleDescendingByNumber(start, end uint64, requestedData byte data := make([]*types.BlockData, (start-end)+1) idx := 0 - for i := start; i >= end; i-- { - data[idx], err = s.getBlockDataByNumber(big.NewInt(int64(i)), requestedData) + for blockNumber := start; blockNumber >= end; blockNumber-- { + data[idx], err = s.getBlockDataByNumber(big.NewInt(int64(blockNumber)), requestedData) if err != nil { return nil, err } From 23126092858811ca89293f8ed652ae54a26488ad Mon Sep 17 00:00:00 2001 From: noot Date: Wed, 27 Oct 2021 13:07:35 -0400 Subject: [PATCH 08/13] address comments --- dot/sync/chain_sync.go | 2 +- dot/sync/message.go | 18 +++++++----------- dot/sync/message_test.go | 1 + dot/sync/syncer.go | 3 +-- 4 files changed, 10 insertions(+), 14 deletions(-) diff --git a/dot/sync/chain_sync.go b/dot/sync/chain_sync.go index 663c818225..dcb6d4ddfd 100644 --- a/dot/sync/chain_sync.go +++ b/dot/sync/chain_sync.go @@ -646,7 +646,7 @@ func (cs *chainSync) doSync(req *network.BlockRequestMessage) *workerError { if req.Direction == network.Descending { // reverse blocks before pre-validating and placing in ready queue - resp.BlockData = reverseBlockData(resp.BlockData) + reverseBlockData(resp.BlockData) } // perform some pre-validation of response, error if failure diff --git a/dot/sync/message.go b/dot/sync/message.go index 035f4d77fa..f62c1cf23d 100644 --- a/dot/sync/message.go +++ b/dot/sync/message.go @@ -136,7 +136,7 @@ func (s *Service) handleAscendingRequest(req *network.BlockRequestMessage) (*net endHash = &eh } - if startHash == nil && endHash == nil { + if startHash == nil || endHash == nil { logger.Debug("handling BlockRequestMessage", "start", startNumber, "end", endNumber, @@ -222,7 +222,7 @@ func (s *Service) handleDescendingRequest(req *network.BlockRequestMessage) (*ne } } - if startHash == nil && endHash == nil { + if startHash == nil || endHash == nil { logger.Debug("handling BlockRequestMessage", "start", startNumber, "end", endNumber, @@ -326,13 +326,11 @@ func (s *Service) handleAscendingByNumber(start, end uint64, requestedData byte) var err error data := make([]*types.BlockData, (end-start)+1) - idx := 0 - for blockNumber := start; blockNumber <= end; blockNumber++ { - data[idx], err = s.getBlockDataByNumber(big.NewInt(int64(blockNumber)), requestedData) + for idx, bn := 0, start; bn <= end; idx, bn = idx+1, bn+1 { + data[idx], err = s.getBlockDataByNumber(big.NewInt(int64(bn)), requestedData) if err != nil { return nil, err } - idx++ } return &network.BlockResponseMessage{ @@ -344,13 +342,11 @@ func (s *Service) handleDescendingByNumber(start, end uint64, requestedData byte var err error data := make([]*types.BlockData, (start-end)+1) - idx := 0 - for blockNumber := start; blockNumber >= end; blockNumber-- { - data[idx], err = s.getBlockDataByNumber(big.NewInt(int64(blockNumber)), requestedData) + for idx, bn := 0, start; bn <= end; idx, bn = idx+1, bn+1 { + data[idx], err = s.getBlockDataByNumber(big.NewInt(int64(bn)), requestedData) if err != nil { return nil, err } - idx++ } return &network.BlockResponseMessage{ @@ -379,7 +375,7 @@ func (s *Service) handleChainByHash(ancestor, descendant common.Hash, max uint32 // reverse BlockData, if descending request if direction == network.Descending { - data = reverseBlockData(data) + reverseBlockData(data) } return &network.BlockResponseMessage{ diff --git a/dot/sync/message_test.go b/dot/sync/message_test.go index 5c9b407855..b12888d4b8 100644 --- a/dot/sync/message_test.go +++ b/dot/sync/message_test.go @@ -266,6 +266,7 @@ func TestService_CreateBlockResponse_StartHash(t *testing.T) { } func TestService_CreateBlockResponse_Ascending_EndHash(t *testing.T) { + t.Parallel() s := newTestSyncer(t) addTestBlocksToState(t, int(maxResponseSize+1), s.blockState) diff --git a/dot/sync/syncer.go b/dot/sync/syncer.go index 5dba88630e..9df49c15b5 100644 --- a/dot/sync/syncer.go +++ b/dot/sync/syncer.go @@ -134,9 +134,8 @@ func (s *Service) IsSynced() bool { return s.chainSync.syncState() == tip } -func reverseBlockData(data []*types.BlockData) []*types.BlockData { +func reverseBlockData(data []*types.BlockData) { for i, j := 0, len(data)-1; i < j; i, j = i+1, j-1 { data[i], data[j] = data[j], data[i] } - return data } From 95431ced5ee9cd0062a12ee080c7ee13b037a0a9 Mon Sep 17 00:00:00 2001 From: noot Date: Wed, 27 Oct 2021 13:10:11 -0400 Subject: [PATCH 09/13] fix handleDescendingByNumber for loop --- dot/sync/message.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dot/sync/message.go b/dot/sync/message.go index f62c1cf23d..5cc8c462c7 100644 --- a/dot/sync/message.go +++ b/dot/sync/message.go @@ -342,7 +342,7 @@ func (s *Service) handleDescendingByNumber(start, end uint64, requestedData byte var err error data := make([]*types.BlockData, (start-end)+1) - for idx, bn := 0, start; bn <= end; idx, bn = idx+1, bn+1 { + for idx, bn := 0, start; bn >= end; idx, bn = idx+1, bn-1 { data[idx], err = s.getBlockDataByNumber(big.NewInt(int64(bn)), requestedData) if err != nil { return nil, err From 8985831c172add4b0ed417fbbe1b2263a71f96d5 Mon Sep 17 00:00:00 2001 From: noot Date: Wed, 27 Oct 2021 13:26:25 -0400 Subject: [PATCH 10/13] fix merge issues --- dot/sync/mocks/block_state.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dot/sync/mocks/block_state.go b/dot/sync/mocks/block_state.go index 8ec79c3742..2a3dcd60bc 100644 --- a/dot/sync/mocks/block_state.go +++ b/dot/sync/mocks/block_state.go @@ -123,7 +123,7 @@ func (_m *BlockState) CompareAndSetBlockData(bd *types.BlockData) error { } // GetAllBlocksAtNumber provides a mock function with given fields: num -func (_m *MockBlockState) GetAllBlocksAtNumber(num *big.Int) ([]common.Hash, error) { +func (_m *BlockState) GetAllBlocksAtNumber(num *big.Int) ([]common.Hash, error) { ret := _m.Called(num) var r0 []common.Hash @@ -277,7 +277,7 @@ func (_m *BlockState) GetHeader(_a0 common.Hash) (*types.Header, error) { } // GetHeaderByNumber provides a mock function with given fields: num -func (_m *MockBlockState) GetHeaderByNumber(num *big.Int) (*types.Header, error) { +func (_m *BlockState) GetHeaderByNumber(num *big.Int) (*types.Header, error) { ret := _m.Called(num) var r0 *types.Header @@ -457,7 +457,7 @@ func (_m *BlockState) HasHeader(hash common.Hash) (bool, error) { } // IsDescendantOf provides a mock function with given fields: parent, child -func (_m *MockBlockState) IsDescendantOf(parent common.Hash, child common.Hash) (bool, error) { +func (_m *BlockState) IsDescendantOf(parent common.Hash, child common.Hash) (bool, error) { ret := _m.Called(parent, child) var r0 bool From 8e465101bd12516c8eaed2264adc4c85f51041c9 Mon Sep 17 00:00:00 2001 From: noot Date: Thu, 28 Oct 2021 13:51:58 -0400 Subject: [PATCH 11/13] update workerToRequests target hash param --- dot/sync/chain_sync.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/dot/sync/chain_sync.go b/dot/sync/chain_sync.go index 7c17711c48..792454983c 100644 --- a/dot/sync/chain_sync.go +++ b/dot/sync/chain_sync.go @@ -904,7 +904,9 @@ func workerToRequests(w *worker) ([]*network.BlockRequestMessage, error) { } var end *common.Hash - if !w.targetHash.IsEmpty() { + if !w.targetHash.IsEmpty() && i == numRequests-1 { + // if we're on our last request (which should contain the target hash), + // then add it end = &w.targetHash } From 997f988f25e39dc953eb7189b1a8e869e23089be Mon Sep 17 00:00:00 2001 From: noot Date: Thu, 28 Oct 2021 14:14:38 -0400 Subject: [PATCH 12/13] address comments --- dot/sync/chain_sync.go | 2 +- dot/sync/errors.go | 3 +++ dot/sync/message.go | 7 +++---- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/dot/sync/chain_sync.go b/dot/sync/chain_sync.go index 792454983c..e7d4d5c9b4 100644 --- a/dot/sync/chain_sync.go +++ b/dot/sync/chain_sync.go @@ -899,7 +899,7 @@ func workerToRequests(w *worker) ([]*network.BlockRequestMessage, error) { // if we're doing descending requests and not at the last (highest starting) request, // then use number as start block if w.direction == network.Descending && i != numRequests-1 { - start, _ = variadic.NewUint64OrHash(startNumber) + start = variadic.MustNewUint64OrHash(startNumber) } } diff --git a/dot/sync/errors.go b/dot/sync/errors.go index b685516b2c..6bfc7f6e93 100644 --- a/dot/sync/errors.go +++ b/dot/sync/errors.go @@ -60,6 +60,9 @@ var ( errUnknownParent = errors.New("parent of first block in block response is unknown") errUnknownBlockForJustification = errors.New("received justification for unknown block") errFailedToGetParent = errors.New("failed to get parent header") + errNilDescendantNumber = errors.New("descendant number is nil") + errStartAndEndMismatch = errors.New("request start and end hash are not on the same chain") + errFailedToGetDescendant = errors.New("failed to find descendant block") ) // ErrNilChannel is returned if a channel is nil diff --git a/dot/sync/message.go b/dot/sync/message.go index 5cc8c462c7..d2c7bd217a 100644 --- a/dot/sync/message.go +++ b/dot/sync/message.go @@ -17,7 +17,6 @@ package sync import ( - "errors" "fmt" "math/big" @@ -245,7 +244,7 @@ func (s *Service) handleDescendingRequest(req *network.BlockRequestMessage) (*ne // if used with an Descending request, ancestor is the end block and descendant is the start block func (s *Service) checkOrGetDescendantHash(ancestor common.Hash, descendant *common.Hash, descendantNumber *big.Int) (common.Hash, error) { if descendantNumber == nil { - return common.Hash{}, errors.New("descendantNumber is nil") + return common.Hash{}, errNilDescendantNumber } // if `descendant` was provided, check that it's a descendant of `ancestor` @@ -267,7 +266,7 @@ func (s *Service) checkOrGetDescendantHash(ancestor common.Hash, descendant *com } if !is { - return common.Hash{}, errors.New("request start and end hash are not on the same chain") + return common.Hash{}, errStartAndEndMismatch } return *descendant, nil @@ -307,7 +306,7 @@ func (s *Service) checkOrGetDescendantHash(ancestor common.Hash, descendant *com } if descendant == nil { - return common.Hash{}, fmt.Errorf("failed to find descendant block with number %d", descendantNumber) + return common.Hash{}, fmt.Errorf("%d with number %d", errFailedToGetDescendant, descendantNumber) } } else { // if it is, set descendant hash to our block w/ descendantNumber From 08d513cad67c4f314efd905f9ed848108c123ddc Mon Sep 17 00:00:00 2001 From: noot Date: Thu, 28 Oct 2021 17:40:59 -0400 Subject: [PATCH 13/13] address comments --- dot/network/notifications.go | 2 +- dot/sync/message.go | 12 +++++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/dot/network/notifications.go b/dot/network/notifications.go index b236157d4e..7ba62745e7 100644 --- a/dot/network/notifications.go +++ b/dot/network/notifications.go @@ -349,7 +349,7 @@ func (s *Service) sendData(peer peer.ID, hs Handshake, info *notificationsProtoc err := s.host.writeToStream(hsData.stream, msg) if err != nil { - logger.Trace("failed to send message to peer", "peer", peer, "error", err) + logger.Debug("failed to send message to peer", "peer", peer, "error", err) } } diff --git a/dot/sync/message.go b/dot/sync/message.go index d2c7bd217a..5b3d324895 100644 --- a/dot/sync/message.go +++ b/dot/sync/message.go @@ -306,7 +306,7 @@ func (s *Service) checkOrGetDescendantHash(ancestor common.Hash, descendant *com } if descendant == nil { - return common.Hash{}, fmt.Errorf("%d with number %d", errFailedToGetDescendant, descendantNumber) + return common.Hash{}, fmt.Errorf("%w with number %d", errFailedToGetDescendant, descendantNumber) } } else { // if it is, set descendant hash to our block w/ descendantNumber @@ -325,8 +325,9 @@ func (s *Service) handleAscendingByNumber(start, end uint64, requestedData byte) var err error data := make([]*types.BlockData, (end-start)+1) - for idx, bn := 0, start; bn <= end; idx, bn = idx+1, bn+1 { - data[idx], err = s.getBlockDataByNumber(big.NewInt(int64(bn)), requestedData) + for i := uint64(0); start+i <= end; i++ { + blockNumber := start + i + data[i], err = s.getBlockDataByNumber(big.NewInt(int64(blockNumber)), requestedData) if err != nil { return nil, err } @@ -341,8 +342,9 @@ func (s *Service) handleDescendingByNumber(start, end uint64, requestedData byte var err error data := make([]*types.BlockData, (start-end)+1) - for idx, bn := 0, start; bn >= end; idx, bn = idx+1, bn-1 { - data[idx], err = s.getBlockDataByNumber(big.NewInt(int64(bn)), requestedData) + for i := uint64(0); start-i >= end; i++ { + blockNumber := start - i + data[i], err = s.getBlockDataByNumber(big.NewInt(int64(blockNumber)), requestedData) if err != nil { return nil, err }