Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions lib/grandpa/commits_tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package grandpa

import (
"container/list"
"sync"

"github.com/ChainSafe/gossamer/lib/common"
)
Expand All @@ -14,6 +15,7 @@ import (
// its maximum capacity is reached.
// It is NOT THREAD SAFE to use.
type commitsTracker struct {
sync.Mutex
Comment thread
kishansagathiya marked this conversation as resolved.
// map of commit block hash to linked list commit message.
mapping map[common.Hash]*list.Element
// double linked list of commit messages
Expand All @@ -36,6 +38,9 @@ func newCommitsTracker(capacity int) commitsTracker {
// If the commit message tracker capacity is reached,
// the oldest commit message is removed.
func (ct *commitsTracker) add(commitMessage *CommitMessage) {
ct.Lock()
defer ct.Unlock()

blockHash := commitMessage.Vote.Hash

listElement, has := ct.mapping[blockHash]
Expand Down Expand Up @@ -75,6 +80,9 @@ func (ct *commitsTracker) cleanup() {
// delete deletes all the vote messages for a particular
// block hash from the vote messages tracker.
func (ct *commitsTracker) delete(blockHash common.Hash) {
ct.Lock()
defer ct.Unlock()

listElement, has := ct.mapping[blockHash]
if !has {
return
Expand All @@ -90,6 +98,9 @@ func (ct *commitsTracker) delete(blockHash common.Hash) {
// does not exist in the tracker
func (ct *commitsTracker) message(blockHash common.Hash) (
message *CommitMessage) {
ct.Lock()
Comment thread
kishansagathiya marked this conversation as resolved.
defer ct.Unlock()
Comment thread
kishansagathiya marked this conversation as resolved.

listElement, ok := ct.mapping[blockHash]
if !ok {
return nil
Expand Down
52 changes: 51 additions & 1 deletion lib/grandpa/commits_tracker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@ import (
"container/list"
"crypto/rand"
"sort"
"sync"
"testing"
"time"

"github.com/ChainSafe/gossamer/dot/types"
"github.com/ChainSafe/gossamer/lib/common"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -51,7 +54,9 @@ func Test_newCommitsTracker(t *testing.T) {
}
vt := newCommitsTracker(capacity)

assert.Equal(t, expected, vt)
assert.Equal(t, expected.mapping, vt.mapping)
assert.Equal(t, expected.linkedList, vt.linkedList)
assert.Equal(t, expected.capacity, vt.capacity)
Comment thread
kishansagathiya marked this conversation as resolved.
}

// We cannot really unit test each method independently
Expand Down Expand Up @@ -319,3 +324,48 @@ func Benchmark_ForEachVsSlice(b *testing.B) {
}
})
}

func Test_commitsTracker_threadSafety(t *testing.T) {
// This test is meant to be run with the `-race` flag
// to detect any data race.
t.Parallel()

const capacity = 2
commitsTracker := newCommitsTracker(capacity)

const parallelism = 10

var endWg sync.WaitGroup
defer endWg.Wait()

for i := 1; i < parallelism; i++ {
endWg.Add(1)
go func(i int) {
defer endWg.Done()

blockHash := common.Hash{byte(i)}

commitMessage := &CommitMessage{
Round: 1,
SetID: 1,
Vote: types.GrandpaVote{
Hash: blockHash,
Number: uint32(i),
},
}

timer := time.NewTimer(50 * time.Millisecond)
for {
select {
case <-timer.C:
return
default:
}

commitsTracker.add(commitMessage)
commitsTracker.delete(blockHash)
_ = commitsTracker.message(blockHash)
}
}(i)
}
}
13 changes: 0 additions & 13 deletions lib/grandpa/message_tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ type tracker struct {
handler *MessageHandler
votes votesTracker
commits commitsTracker
mapLock sync.Mutex
in chan *types.Block // receive imported block from BlockState
stopped chan struct{}

Expand All @@ -38,7 +37,6 @@ func newTracker(bs BlockState, handler *MessageHandler) *tracker {
handler: handler,
votes: newVotesTracker(votesCapacity),
commits: newCommitsTracker(commitsCapacity),
mapLock: sync.Mutex{},
in: bs.GetImportedBlockNotifierChannel(),
stopped: make(chan struct{}),
catchUpResponseMessages: make(map[uint64]*CatchUpResponse),
Expand All @@ -59,15 +57,10 @@ func (t *tracker) addVote(peerID peer.ID, message *VoteMessage) {
return
}

t.mapLock.Lock()
defer t.mapLock.Unlock()

t.votes.add(peerID, message)
}

func (t *tracker) addCommit(cm *CommitMessage) {
t.mapLock.Lock()
defer t.mapLock.Unlock()
t.commits.add(cm)
}

Expand Down Expand Up @@ -100,9 +93,6 @@ func (t *tracker) handleBlocks() {
}

func (t *tracker) handleBlock(b *types.Block) {
t.mapLock.Lock()
defer t.mapLock.Unlock()

h := b.Header.Hash()
vms := t.votes.messages(h)
for _, v := range vms {
Expand All @@ -128,9 +118,6 @@ func (t *tracker) handleBlock(b *types.Block) {
}

func (t *tracker) handleTick() {
t.mapLock.Lock()
defer t.mapLock.Unlock()

for _, networkVoteMessage := range t.votes.networkVoteMessages() {
peerID := networkVoteMessage.from
message := networkVoteMessage.msg
Expand Down
27 changes: 19 additions & 8 deletions lib/grandpa/message_tracker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package grandpa

import (
"container/list"
"testing"
"time"

Expand All @@ -16,12 +17,12 @@ import (
"github.com/stretchr/testify/require"
)

// getMessageFromVotesTracker returns the vote message
// getMessageFromVotesMapping returns the vote message
// from the votes tracker for the given block hash and authority ID.
func getMessageFromVotesTracker(votes votesTracker,
func getMessageFromVotesMapping(votesMapping map[common.Hash]map[ed25519.PublicKeyBytes]*list.Element,
blockHash common.Hash, authorityID ed25519.PublicKeyBytes) (
message *VoteMessage) {
authorityIDToElement, has := votes.mapping[blockHash]
authorityIDToElement, has := votesMapping[blockHash]
if !has {
return nil
}
Expand Down Expand Up @@ -54,7 +55,7 @@ func TestMessageTracker_ValidateMessage(t *testing.T) {
_, err = gs.validateVoteMessage("", msg)
require.Equal(t, err, ErrBlockDoesNotExist)
authorityID := kr.Alice().Public().(*ed25519.PublicKey).AsBytes()
voteMessage := getMessageFromVotesTracker(gs.tracker.votes, fake.Hash(), authorityID)
voteMessage := getMessageFromVotesMapping(gs.tracker.votes.mapping, fake.Hash(), authorityID)
require.Equal(t, msg, voteMessage)
}

Expand Down Expand Up @@ -91,7 +92,7 @@ func TestMessageTracker_SendMessage(t *testing.T) {
_, err = gs.validateVoteMessage("", msg)
require.Equal(t, err, ErrBlockDoesNotExist)
authorityID := kr.Alice().Public().(*ed25519.PublicKey).AsBytes()
voteMessage := getMessageFromVotesTracker(gs.tracker.votes, next.Hash(), authorityID)
voteMessage := getMessageFromVotesMapping(gs.tracker.votes.mapping, next.Hash(), authorityID)
require.Equal(t, msg, voteMessage)

err = gs.blockState.(*state.BlockState).AddBlock(&types.Block{
Expand Down Expand Up @@ -143,7 +144,7 @@ func TestMessageTracker_ProcessMessage(t *testing.T) {
_, err = gs.validateVoteMessage("", msg)
require.Equal(t, ErrBlockDoesNotExist, err)
authorityID := kr.Alice().Public().(*ed25519.PublicKey).AsBytes()
voteMessage := getMessageFromVotesTracker(gs.tracker.votes, next.Hash(), authorityID)
voteMessage := getMessageFromVotesMapping(gs.tracker.votes.mapping, next.Hash(), authorityID)
require.Equal(t, msg, voteMessage)

err = gs.blockState.(*state.BlockState).AddBlock(&types.Block{
Expand All @@ -159,7 +160,7 @@ func TestMessageTracker_ProcessMessage(t *testing.T) {
}
pv, has := gs.prevotes.Load(kr.Alice().Public().(*ed25519.PublicKey).AsBytes())
require.True(t, has)
require.Equal(t, expectedVote, &pv.(*SignedVote).Vote, gs.tracker.votes)
require.Equal(t, expectedVote, &pv.(*SignedVote).Vote)
}

func TestMessageTracker_MapInsideMap(t *testing.T) {
Expand All @@ -186,7 +187,7 @@ func TestMessageTracker_MapInsideMap(t *testing.T) {

gs.tracker.addVote("", msg)

voteMessage := getMessageFromVotesTracker(gs.tracker.votes, hash, authorityID)
voteMessage := getMessageFromVotesMapping(gs.tracker.votes.mapping, hash, authorityID)
require.NotEmpty(t, voteMessage)
}

Expand Down Expand Up @@ -227,6 +228,15 @@ func TestMessageTracker_handleTick(t *testing.T) {
},
}
gs.tracker.addVote("", msg)
commitMessage := &CommitMessage{
Round: 100,
SetID: 1,
Vote: types.GrandpaVote{
Hash: testHash,
Number: 1,
},
}
gs.tracker.addCommit(commitMessage)

gs.tracker.handleTick()

Expand All @@ -239,4 +249,5 @@ func TestMessageTracker_handleTick(t *testing.T) {

// should be deleted as round in message < grandpa round
require.Empty(t, gs.tracker.votes.messages(testHash))
require.Empty(t, gs.tracker.commits.message(testHash))
}
14 changes: 14 additions & 0 deletions lib/grandpa/votes_tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package grandpa

import (
"container/list"
"sync"

"github.com/ChainSafe/gossamer/lib/common"
"github.com/ChainSafe/gossamer/lib/crypto/ed25519"
Expand All @@ -16,6 +17,7 @@ import (
// its maximum capacity is reached.
// It is NOT THREAD SAFE to use.
type votesTracker struct {
sync.Mutex
Comment thread
kishansagathiya marked this conversation as resolved.
// map of vote block hash to authority ID (ed25519 public Key)
// to linked list element pointer
mapping map[common.Hash]map[ed25519.PublicKeyBytes]*list.Element
Expand All @@ -38,6 +40,9 @@ func newVotesTracker(capacity int) votesTracker {
// If the vote message tracker capacity is reached,
// the oldest vote message is removed.
func (vt *votesTracker) add(peerID peer.ID, voteMessage *VoteMessage) {
vt.Lock()
defer vt.Unlock()

signedMessage := voteMessage.Message
blockHash := signedMessage.BlockHash
authorityID := signedMessage.AuthorityID
Expand Down Expand Up @@ -101,6 +106,9 @@ func (vt *votesTracker) cleanup() {
// delete deletes all the vote messages for a particular
// block hash from the vote messages tracker.
func (vt *votesTracker) delete(blockHash common.Hash) {
vt.Lock()
defer vt.Unlock()

authIDToElement, has := vt.mapping[blockHash]
if !has {
return
Expand All @@ -119,6 +127,9 @@ func (vt *votesTracker) delete(blockHash common.Hash) {
// It returns nil if the block hash does not exist.
func (vt *votesTracker) messages(blockHash common.Hash) (
messages []networkVoteMessage) {
vt.Lock()
defer vt.Unlock()
Comment thread
kishansagathiya marked this conversation as resolved.

authIDToElement, ok := vt.mapping[blockHash]
if !ok {
// Note authIDToElement cannot be empty
Expand All @@ -138,6 +149,9 @@ func (vt *votesTracker) messages(blockHash common.Hash) (
// as a slice of networkVoteMessages.
func (vt *votesTracker) networkVoteMessages() (
messages []networkVoteMessage) {
vt.Lock()
defer vt.Unlock()
Comment thread
kishansagathiya marked this conversation as resolved.

messages = make([]networkVoteMessage, 0, vt.linkedList.Len())
for _, authorityIDToElement := range vt.mapping {
for _, element := range authorityIDToElement {
Expand Down
4 changes: 3 additions & 1 deletion lib/grandpa/votes_tracker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@ func Test_newVotesTracker(t *testing.T) {
}
vt := newVotesTracker(capacity)

assert.Equal(t, expected, vt)
assert.Equal(t, expected.mapping, vt.mapping)
assert.Equal(t, expected.linkedList, vt.linkedList)
assert.Equal(t, expected.capacity, vt.capacity)
Comment thread
kishansagathiya marked this conversation as resolved.
}

Comment thread
kishansagathiya marked this conversation as resolved.
// We cannot really unit test each method independently
Expand Down