From 0fd206e20e13603de04025f92c564c708c799713 Mon Sep 17 00:00:00 2001
From: Grzegorz Prusak
Date: Mon, 18 Aug 2025 15:25:37 +0200
Subject: [PATCH 01/41] wip

---
 abci/example/example_test.go | 6 +-
 abci/example/kvstore/kvstore_test.go | 10 +-
 abci/tests/client_server_test.go | 2 +-
 cmd/tendermint/commands/reindex_event_test.go | 2 +-
 cmd/tendermint/commands/reset_test.go | 7 +-
 cmd/tendermint/commands/rollback_test.go | 2 +-
 cmd/tendermint/commands/root_test.go | 6 +-
 go.mod | 2 +-
 internal/blocksync/pool_test.go | 6 +-
 internal/blocksync/reactor_test.go | 6 +-
 internal/consensus/byzantine_test.go | 2 +-
 internal/consensus/invalid_test.go | 2 +-
 internal/consensus/mempool_test.go | 10 +-
 internal/consensus/msgs_test.go | 2 +-
 internal/consensus/pbts_test.go | 10 +-
 internal/consensus/reactor_test.go | 18 +--
 internal/consensus/replay_test.go | 14 +--
 internal/consensus/state_test.go | 104 +++++++++---------
 .../consensus/types/height_vote_set_test.go | 2 +-
 internal/consensus/wal_test.go | 8 +-
 internal/dbsync/syncer_test.go | 12 +-
 internal/eventbus/event_bus_test.go | 18 +--
 internal/eventlog/eventlog_test.go | 2 +-
 internal/evidence/pool_test.go | 20 ++--
 internal/evidence/reactor_test.go | 12 +-
 internal/evidence/verify_test.go | 12 +-
 internal/inspect/inspect_test.go | 22 ++--
 internal/libs/autofile/autofile_test.go | 4 +-
 internal/libs/autofile/group_test.go | 14 +--
 internal/libs/queue/queue_test.go | 6 +-
 internal/mempool/mempool_bench_test.go | 2 +-
 internal/mempool/mempool_test.go | 38 +++----
 internal/mempool/reactor_test.go | 14 +--
 internal/p2p/address_test.go | 2 +-
 internal/p2p/channel_test.go | 2 +-
 internal/p2p/conn/connection_test.go | 28 ++---
 internal/p2p/peermanager_scoring_test.go | 2 +-
 internal/p2p/peermanager_test.go | 54 ++++-----
 internal/p2p/pex/reactor_test.go | 18 +--
 internal/p2p/pqueue_test.go | 2 +-
 internal/p2p/router_filter_test.go | 2 +-
 internal/p2p/router_init_test.go | 2 +-
 internal/p2p/router_test.go | 34 +++---
 internal/p2p/rqueue_test.go | 2 +-
 internal/p2p/transport_mconn_test.go | 8 +-
 internal/p2p/transport_test.go | 24 ++--
 internal/proxy/client_test.go | 10 +-
 internal/pubsub/example_test.go | 2 +-
 internal/pubsub/pubsub_test.go | 28 ++---
 internal/rpc/core/blocks_test.go | 2 +-
 internal/state/execution_test.go | 24 ++--
 internal/state/indexer/block/kv/kv_test.go | 2 +-
 .../state/indexer/indexer_service_test.go | 2 +-
 internal/state/indexer/sink/kv/kv_test.go | 10 +-
 internal/state/indexer/sink/null/null_test.go | 2 +-
 internal/state/indexer/sink/psql/psql_test.go | 2 +-
 internal/state/indexer/tx/kv/kv_bench_test.go | 2 +-
 internal/state/indexer/tx/kv/kv_test.go | 10 +-
 internal/state/rollback_test.go | 2 +-
 internal/state/state_test.go | 2 +-
 internal/state/store_test.go | 6 +-
 internal/state/validation_test.go | 6 +-
 internal/statesync/block_queue_test.go | 12 +-
 internal/statesync/reactor_test.go | 18 +--
 internal/statesync/syncer_test.go | 24 ++--
 libs/cli/setup_test.go | 8 +-
 libs/events/events_test.go | 10 +-
 libs/service/service_test.go | 4 +-
 light/client_benchmark_test.go | 8 +-
 light/client_test.go | 38 +++----
 light/detector_test.go | 14 +--
 light/dispatcher_test.go | 16 +--
 light/example_test.go | 2 +-
 light/light_test.go | 6 +-
 light/provider/http/http_test.go | 2 +-
 light/store/db/db_test.go | 10 +-
 node/node_test.go | 24 ++--
 privval/file_test.go | 10 +-
 privval/grpc/client_test.go | 6 +-
 privval/grpc/server_test.go | 6 +-
 privval/signer_client_test.go | 20 ++--
 privval/signer_listener_endpoint_test.go | 4 +-
rpc/client/eventstream/eventstream_test.go | 4 +- rpc/client/examples_test.go | 4 +- rpc/client/helpers_test.go | 2 +- rpc/client/mock/abci_test.go | 6 +- rpc/client/mock/status_test.go | 2 +- rpc/client/rpc_test.go | 14 +-- rpc/jsonrpc/client/integration_test.go | 2 +- rpc/jsonrpc/client/ws_client_test.go | 8 +- rpc/jsonrpc/jsonrpc_test.go | 2 +- rpc/jsonrpc/server/http_server_test.go | 4 +- rpc/jsonrpc/server/parse_test.go | 4 +- scripts/confix/confix_test.go | 2 +- scripts/keymigrate/migrate_test.go | 2 +- scripts/scmigrate/migrate_test.go | 4 +- test/e2e/tests/block_test.go | 2 +- test/e2e/tests/e2e_test.go | 2 +- test/e2e/tests/evidence_test.go | 2 +- test/e2e/tests/validator_test.go | 4 +- test/fuzz/tests/mempool_test.go | 2 +- types/block_test.go | 22 ++-- types/evidence_test.go | 18 +-- types/light_test.go | 6 +- types/proposal_test.go | 8 +- types/validation_test.go | 12 +- types/validator_set_test.go | 12 +- types/validator_test.go | 4 +- types/vote_set_test.go | 14 +-- types/vote_test.go | 20 ++-- 110 files changed, 559 insertions(+), 560 deletions(-) diff --git a/abci/example/example_test.go b/abci/example/example_test.go index 0503448b8..df514d8fa 100644 --- a/abci/example/example_test.go +++ b/abci/example/example_test.go @@ -28,7 +28,7 @@ func init() { } func TestKVStore(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -37,7 +37,7 @@ func TestKVStore(t *testing.T) { } func TestBaseApp(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -46,7 +46,7 @@ func TestBaseApp(t *testing.T) { } func TestGRPC(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() diff --git a/abci/example/kvstore/kvstore_test.go b/abci/example/kvstore/kvstore_test.go index 9631bc1ab..357917e71 100644 --- a/abci/example/kvstore/kvstore_test.go +++ b/abci/example/kvstore/kvstore_test.go @@ -67,7 +67,7 @@ func testKVStore(ctx context.Context, t *testing.T, app types.Application, tx [] } func TestKVStoreKV(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() kvstore := NewApplication() @@ -82,7 +82,7 @@ func TestKVStoreKV(t *testing.T) { } func TestPersistentKVStoreKV(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() dir := t.TempDir() @@ -100,7 +100,7 @@ func TestPersistentKVStoreKV(t *testing.T) { } func TestPersistentKVStoreInfo(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() dir := t.TempDir() logger := log.NewNopLogger() @@ -144,7 +144,7 @@ func TestPersistentKVStoreInfo(t *testing.T) { // add a validator, remove a validator, update a validator func TestValUpdates(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() kvstore := NewApplication() @@ -313,7 +313,7 @@ func makeGRPCClientServer( } func TestClientServer(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() diff --git 
a/abci/tests/client_server_test.go b/abci/tests/client_server_test.go index a97c0c7c4..b3a66804f 100644 --- a/abci/tests/client_server_test.go +++ b/abci/tests/client_server_test.go @@ -16,7 +16,7 @@ import ( func TestClientServerNoAddrPrefix(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() const ( diff --git a/cmd/tendermint/commands/reindex_event_test.go b/cmd/tendermint/commands/reindex_event_test.go index 971cf1c87..c845dd3ab 100644 --- a/cmd/tendermint/commands/reindex_event_test.go +++ b/cmd/tendermint/commands/reindex_event_test.go @@ -175,7 +175,7 @@ func TestReIndexEvent(t *testing.T) { {height, height, false}, } - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() conf := config.DefaultConfig() diff --git a/cmd/tendermint/commands/reset_test.go b/cmd/tendermint/commands/reset_test.go index 2919eafc4..eec7f3a64 100644 --- a/cmd/tendermint/commands/reset_test.go +++ b/cmd/tendermint/commands/reset_test.go @@ -1,7 +1,6 @@ package commands import ( - "context" "path/filepath" "testing" @@ -19,7 +18,7 @@ func Test_ResetAll(t *testing.T) { config.SetRoot(dir) logger := log.NewNopLogger() cfg.EnsureRoot(dir) - require.NoError(t, initFilesWithConfig(context.Background(), config, logger, types.ABCIPubKeyTypeEd25519)) + require.NoError(t, initFilesWithConfig(t.Context(), config, logger, types.ABCIPubKeyTypeEd25519)) pv, err := privval.LoadFilePV(config.PrivValidator.KeyFile(), config.PrivValidator.StateFile()) require.NoError(t, err) pv.LastSignState.Height = 10 @@ -43,7 +42,7 @@ func Test_ResetState(t *testing.T) { config.SetRoot(dir) logger := log.NewNopLogger() cfg.EnsureRoot(dir) - require.NoError(t, initFilesWithConfig(context.Background(), config, logger, types.ABCIPubKeyTypeEd25519)) + require.NoError(t, initFilesWithConfig(t.Context(), config, logger, types.ABCIPubKeyTypeEd25519)) pv, err := privval.LoadFilePV(config.PrivValidator.KeyFile(), config.PrivValidator.StateFile()) require.NoError(t, err) pv.LastSignState.Height = 10 @@ -67,7 +66,7 @@ func Test_UnsafeResetAll(t *testing.T) { config.SetRoot(dir) logger := log.NewNopLogger() cfg.EnsureRoot(dir) - require.NoError(t, initFilesWithConfig(context.Background(), config, logger, types.ABCIPubKeyTypeEd25519)) + require.NoError(t, initFilesWithConfig(t.Context(), config, logger, types.ABCIPubKeyTypeEd25519)) pv, err := privval.LoadFilePV(config.PrivValidator.KeyFile(), config.PrivValidator.StateFile()) require.NoError(t, err) pv.LastSignState.Height = 10 diff --git a/cmd/tendermint/commands/rollback_test.go b/cmd/tendermint/commands/rollback_test.go index 48c41044b..bdbd06d46 100644 --- a/cmd/tendermint/commands/rollback_test.go +++ b/cmd/tendermint/commands/rollback_test.go @@ -17,7 +17,7 @@ import ( func TestRollbackIntegration(t *testing.T) { var height int64 dir := t.TempDir() - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() cfg, err := rpctest.CreateConfig(t, t.Name()) require.NoError(t, err) diff --git a/cmd/tendermint/commands/root_test.go b/cmd/tendermint/commands/root_test.go index 6d38f91c9..99315823d 100644 --- a/cmd/tendermint/commands/root_test.go +++ b/cmd/tendermint/commands/root_test.go @@ -77,7 +77,7 @@ func TestRootHome(t *testing.T) { {nil, map[string]string{"TMHOME": newRoot}, newRoot}, } - ctx, cancel := 
context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() for i, tc := range cases { @@ -114,7 +114,7 @@ func TestRootFlagsEnv(t *testing.T) { {nil, map[string]string{"TM_LOG_LEVEL": "debug"}, "debug"}, // right env } - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() for i, tc := range cases { @@ -131,7 +131,7 @@ func TestRootFlagsEnv(t *testing.T) { } func TestRootConfig(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() // write non-default config diff --git a/go.mod b/go.mod index 99e13c670..e028c1a8b 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/tendermint/tendermint -go 1.21 +go 1.24.5 require ( github.com/BurntSushi/toml v1.1.0 diff --git a/internal/blocksync/pool_test.go b/internal/blocksync/pool_test.go index 894bc09e8..cf1dc8a25 100644 --- a/internal/blocksync/pool_test.go +++ b/internal/blocksync/pool_test.go @@ -107,7 +107,7 @@ func makePeerManager(peers map[types.NodeID]testPeer) *p2p.PeerManager { return peerManager } func TestBlockPoolBasic(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() start := int64(42) @@ -163,7 +163,7 @@ func TestBlockPoolBasic(t *testing.T) { } func TestBlockPoolTimeout(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() start := int64(42) @@ -213,7 +213,7 @@ func TestBlockPoolTimeout(t *testing.T) { } func TestBlockPoolRemovePeer(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() peers := make(testPeers, 10) diff --git a/internal/blocksync/reactor_test.go b/internal/blocksync/reactor_test.go index 0dfd94b80..6f6be7fca 100644 --- a/internal/blocksync/reactor_test.go +++ b/internal/blocksync/reactor_test.go @@ -277,7 +277,7 @@ func (rts *reactorTestSuite) start(ctx context.Context, t *testing.T) { } func TestReactor_AbruptDisconnect(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() cfg, err := config.ResetTestRoot(t.TempDir(), "block_sync_reactor_test") @@ -317,7 +317,7 @@ func TestReactor_AbruptDisconnect(t *testing.T) { } func TestReactor_SyncTime(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() cfg, err := config.ResetTestRoot(t.TempDir(), "block_sync_reactor_test") @@ -426,7 +426,7 @@ func TestAutoRestartIfBehind(t *testing.T) { blockSync: newAtomicBool(tt.isBlockSync), } - ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + ctx, cancel := context.WithTimeout(t.Context(), 200*time.Millisecond) defer cancel() go r.autoRestartIfBehind(ctx) diff --git a/internal/consensus/byzantine_test.go b/internal/consensus/byzantine_test.go index 03f14453b..ea8054869 100644 --- a/internal/consensus/byzantine_test.go +++ b/internal/consensus/byzantine_test.go @@ -40,7 +40,7 @@ package consensus // // kind of deadlock and hit the larger timeout. 
This timeout // // can be extended a bunch if needed, but it's good to avoid // // falling back to a much coarser timeout -// ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) +// ctx, cancel := context.WithTimeout(t.Context(), 20*time.Second) // defer cancel() // // config := configSetup(t) diff --git a/internal/consensus/invalid_test.go b/internal/consensus/invalid_test.go index 53598c661..b33b0037d 100644 --- a/internal/consensus/invalid_test.go +++ b/internal/consensus/invalid_test.go @@ -21,7 +21,7 @@ import ( ) func TestReactorInvalidPrecommit(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second) defer cancel() config := configSetup(t) diff --git a/internal/consensus/mempool_test.go b/internal/consensus/mempool_test.go index cb807b44e..28f7199c9 100644 --- a/internal/consensus/mempool_test.go +++ b/internal/consensus/mempool_test.go @@ -32,7 +32,7 @@ func assertMempool(t *testing.T, txn txNotifier) mempool.Mempool { } func TestMempoolNoProgressUntilTxsAvailable(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() baseConfig := configSetup(t) @@ -62,7 +62,7 @@ func TestMempoolNoProgressUntilTxsAvailable(t *testing.T) { func TestMempoolProgressAfterCreateEmptyBlocksInterval(t *testing.T) { baseConfig := configSetup(t) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() config, err := ResetConfig(t.TempDir(), "consensus_mempool_txs_available_test") @@ -88,7 +88,7 @@ func TestMempoolProgressAfterCreateEmptyBlocksInterval(t *testing.T) { func TestMempoolProgressInHigherRound(t *testing.T) { baseConfig := configSetup(t) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() config, err := ResetConfig(t.TempDir(), "consensus_mempool_txs_available_test") @@ -144,7 +144,7 @@ func checkTxsRange(ctx context.Context, t *testing.T, cs *State, start, end int) } func TestMempoolTxConcurrentWithCommit(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() config := configSetup(t) @@ -184,7 +184,7 @@ func TestMempoolTxConcurrentWithCommit(t *testing.T) { func TestMempoolRmBadTx(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() state, privVals := makeGenesisState(ctx, t, config, genesisStateArgs{ diff --git a/internal/consensus/msgs_test.go b/internal/consensus/msgs_test.go index 4d1c04338..3c2790764 100644 --- a/internal/consensus/msgs_test.go +++ b/internal/consensus/msgs_test.go @@ -26,7 +26,7 @@ import ( ) func TestMsgToProto(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() psh := types.PartSetHeader{ diff --git a/internal/consensus/pbts_test.go b/internal/consensus/pbts_test.go index 76f71cc7f..5123df930 100644 --- a/internal/consensus/pbts_test.go +++ b/internal/consensus/pbts_test.go @@ -338,7 +338,7 @@ func (hr heightResult) isComplete() bool { // until after the genesis time has passed. The test sets the genesis time in the // future and then ensures that the observed validator waits to propose a block. 
func TestProposerWaitsForGenesisTime(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() // create a genesis time far (enough) in the future. @@ -369,7 +369,7 @@ func TestProposerWaitsForGenesisTime(t *testing.T) { // and then verifies that the observed validator waits until after the block time // of height 4 to propose a block at height 5. func TestProposerWaitsForPreviousBlock(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() initialTime := time.Now().Add(time.Millisecond * 50) cfg := pbtsTestConfiguration{ @@ -436,7 +436,7 @@ func TestProposerWaitTime(t *testing.T) { } func TestTimelyProposal(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() initialTime := time.Now() @@ -458,7 +458,7 @@ func TestTimelyProposal(t *testing.T) { } func TestTooFarInThePastProposal(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() // localtime > proposedBlockTime + MsgDelay + Precision @@ -479,7 +479,7 @@ func TestTooFarInThePastProposal(t *testing.T) { } func TestTooFarInTheFutureProposal(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() // localtime < proposedBlockTime - Precision diff --git a/internal/consensus/reactor_test.go b/internal/consensus/reactor_test.go index 354d0f241..dc867a722 100644 --- a/internal/consensus/reactor_test.go +++ b/internal/consensus/reactor_test.go @@ -174,7 +174,7 @@ func waitForAndValidateBlock( ) { t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute) + ctx, cancel := context.WithTimeout(t.Context(), 1*time.Minute) defer cancel() fn := func(j int) { @@ -232,7 +232,7 @@ func waitForAndValidateBlockWithTx( ) { t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + ctx, cancel := context.WithTimeout(t.Context(), 2*time.Minute) defer cancel() fn := func(j int) { @@ -347,7 +347,7 @@ func ensureBlockSyncStatus(t *testing.T, msg tmpubsub.Message, complete bool, he } func TestReactorBasic(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + ctx, cancel := context.WithTimeout(t.Context(), time.Minute) defer cancel() cfg := configSetup(t) @@ -438,7 +438,7 @@ func TestReactorBasic(t *testing.T) { } func TestReactorWithEvidence(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + ctx, cancel := context.WithTimeout(t.Context(), time.Minute) defer cancel() cfg := configSetup(t) @@ -546,7 +546,7 @@ func TestReactorWithEvidence(t *testing.T) { } func TestReactorCreatesBlockWhenEmptyBlocksFalse(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + ctx, cancel := context.WithTimeout(t.Context(), time.Minute) defer cancel() cfg := configSetup(t) @@ -647,7 +647,7 @@ func TestSwitchToConsensusVoteExtensions(t *testing.T) { }, } { t.Run(testCase.name, func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) + ctx, cancel := context.WithTimeout(t.Context(), time.Second*15) defer cancel() cs, vs := makeState(ctx, t, makeStateArgs{validators: 1}) validator := vs[0] @@ -715,7 +715,7 @@ func TestSwitchToConsensusVoteExtensions(t *testing.T) { } func 
TestReactorRecordsVotesAndBlockParts(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + ctx, cancel := context.WithTimeout(t.Context(), time.Minute) defer cancel() cfg := configSetup(t) @@ -782,7 +782,7 @@ func TestReactorRecordsVotesAndBlockParts(t *testing.T) { // TODO: fix flaky test //func TestReactorVotingPowerChange(t *testing.T) { -// ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) +// ctx, cancel := context.WithTimeout(t.Context(), 2*time.Minute) // defer cancel() // // cfg := configSetup(t) @@ -907,7 +907,7 @@ func TestReactorRecordsVotesAndBlockParts(t *testing.T) { //} func TestReactorValidatorSetChanges(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + ctx, cancel := context.WithTimeout(t.Context(), 5*time.Minute) defer cancel() cfg := configSetup(t) diff --git a/internal/consensus/replay_test.go b/internal/consensus/replay_test.go index 253a8c5d4..3f20da2d1 100644 --- a/internal/consensus/replay_test.go +++ b/internal/consensus/replay_test.go @@ -136,7 +136,7 @@ func TestWALCrash(t *testing.T) { for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() consensusReplayConfig, err := ResetConfig(t.TempDir(), tc.name) @@ -596,7 +596,7 @@ func setupSimulator(ctx context.Context, t *testing.T) *simulatorTestSuite { // Sync from scratch func TestHandshakeReplayAll(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() sim := setupSimulator(ctx, t) @@ -613,7 +613,7 @@ func TestHandshakeReplayAll(t *testing.T) { // Sync many, not from scratch func TestHandshakeReplaySome(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() sim := setupSimulator(ctx, t) @@ -630,7 +630,7 @@ func TestHandshakeReplaySome(t *testing.T) { // Sync from lagging by one func TestHandshakeReplayOne(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() sim := setupSimulator(ctx, t) @@ -645,7 +645,7 @@ func TestHandshakeReplayOne(t *testing.T) { // Sync from caught up func TestHandshakeReplayNone(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() sim := setupSimulator(ctx, t) @@ -956,7 +956,7 @@ func TestHandshakeErrorsIfAppReturnsWrongAppHash(t *testing.T) { // - 0x02 // - 0x03 - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() cfg, err := ResetConfig(t.TempDir(), "handshake_test_") @@ -1244,7 +1244,7 @@ func (bs *mockBlockStore) DeleteLatestBlock() error { return nil } // Test handshake/init chain func TestHandshakeUpdatesValidators(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() diff --git a/internal/consensus/state_test.go b/internal/consensus/state_test.go index a21d4e25d..5d9d8c405 100644 --- a/internal/consensus/state_test.go +++ b/internal/consensus/state_test.go @@ -69,7 +69,7 @@ x * TestHalt1 - if we see +2/3 precommits after timing out into new round, we sh // ProposeSuite func TestStateProposerSelection0(t *testing.T) { - ctx, cancel := 
context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() config := configSetup(t) @@ -114,7 +114,7 @@ func TestStateProposerSelection0(t *testing.T) { func TestStateProposerSelection2(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() cs1, vss := makeState(ctx, t, makeStateArgs{config: config}) // test needs more work for more than 3 validators height := cs1.roundState.Height() @@ -151,7 +151,7 @@ func TestStateProposerSelection2(t *testing.T) { // a non-validator should timeout into the prevote round func TestStateEnterProposeNoPrivValidator(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() cs, _ := makeState(ctx, t, makeStateArgs{config: config, validators: 1}) @@ -174,7 +174,7 @@ func TestStateEnterProposeNoPrivValidator(t *testing.T) { // a validator should not timeout of the prevote round (TODO: unless the block is really big!) func TestStateEnterProposeYesPrivValidator(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() cs, _ := makeState(ctx, t, makeStateArgs{config: config, validators: 1}) @@ -208,7 +208,7 @@ func TestStateEnterProposeYesPrivValidator(t *testing.T) { func TestStateBadProposal(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() cs1, vss := makeState(ctx, t, makeStateArgs{config: config, validators: 2}) @@ -271,7 +271,7 @@ func TestStateBadProposal(t *testing.T) { func TestStateOversizedBlock(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() cs1, vss := makeState(ctx, t, makeStateArgs{config: config, validators: 2}) @@ -336,7 +336,7 @@ func TestStateOversizedBlock(t *testing.T) { // propose, prevote, and precommit a block func TestStateFullRound1(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() cs, vss := makeState(ctx, t, makeStateArgs{config: config, validators: 1}) @@ -366,7 +366,7 @@ func TestStateFullRound1(t *testing.T) { // nil is proposed, so prevote and precommit nil func TestStateFullRoundNil(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() cs, _ := makeState(ctx, t, makeStateArgs{config: config, validators: 1}) @@ -385,7 +385,7 @@ func TestStateFullRoundNil(t *testing.T) { // where the first validator has to wait for votes from the second func TestStateFullRound2(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() cs1, vss := makeState(ctx, t, makeStateArgs{config: config, validators: 2}) @@ -429,7 +429,7 @@ func TestStateFullRound2(t *testing.T) { // two vals take turns proposing. 
val1 locks on first one, precommits nil on everything else func TestStateLock_NoPOL(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() cs1, vss := makeState(ctx, t, makeStateArgs{config: config, validators: 2}) @@ -635,7 +635,7 @@ func TestStateLock_POLUpdateLock(t *testing.T) { config := configSetup(t) logger := log.NewNopLogger() - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() cs1, vss := makeState(ctx, t, makeStateArgs{config: config, logger: logger}) @@ -741,7 +741,7 @@ func TestStateLock_POLUpdateLock(t *testing.T) { // it receives votes representing over 2/3 of the voting power on the network // for a block that it is already locked in. func TestStateLock_POLRelock(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() config := configSetup(t) @@ -840,7 +840,7 @@ func TestStateLock_POLRelock(t *testing.T) { // TestStateLock_PrevoteNilWhenLockedAndMissProposal tests that a validator prevotes nil // if it is locked on a block and misses the proposal in a round. func TestStateLock_PrevoteNilWhenLockedAndMissProposal(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() config := configSetup(t) @@ -850,7 +850,7 @@ func TestStateLock_PrevoteNilWhenLockedAndMissProposal(t *testing.T) { timeoutWaitCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryTimeoutWait) proposalCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryCompleteProposal) - pv1, err := cs1.privValidator.GetPubKey(context.Background()) + pv1, err := cs1.privValidator.GetPubKey(t.Context()) require.NoError(t, err) addr := pv1.Address() voteCh := subscribeToVoter(ctx, t, cs1, addr) @@ -920,7 +920,7 @@ func TestStateLock_PrevoteNilWhenLockedAndMissProposal(t *testing.T) { // TestStateLock_PrevoteNilWhenLockedAndMissProposal tests that a validator prevotes nil // if it is locked on a block and misses the proposal in a round. func TestStateLock_PrevoteNilWhenLockedAndDifferentProposal(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() config := configSetup(t) @@ -937,7 +937,7 @@ func TestStateLock_PrevoteNilWhenLockedAndDifferentProposal(t *testing.T) { timeoutWaitCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryTimeoutWait) proposalCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryCompleteProposal) - pv1, err := cs1.privValidator.GetPubKey(context.Background()) + pv1, err := cs1.privValidator.GetPubKey(t.Context()) require.NoError(t, err) addr := pv1.Address() voteCh := subscribeToVoter(ctx, t, cs1, addr) @@ -1022,7 +1022,7 @@ func TestStateLock_PrevoteNilWhenLockedAndDifferentProposal(t *testing.T) { func TestStateLock_POLDoesNotUnlock(t *testing.T) { config := configSetup(t) logger := log.NewNopLogger() - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() /* All of the assertions in this test occur on the `cs1` validator. 
@@ -1039,7 +1039,7 @@ func TestStateLock_POLDoesNotUnlock(t *testing.T) { timeoutWaitCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryTimeoutWait) newRoundCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryNewRound) lockCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryLock) - pv1, err := cs1.privValidator.GetPubKey(context.Background()) + pv1, err := cs1.privValidator.GetPubKey(t.Context()) require.NoError(t, err) addr := pv1.Address() voteCh := subscribeToVoter(ctx, t, cs1, addr) @@ -1159,7 +1159,7 @@ func TestStateLock_POLDoesNotUnlock(t *testing.T) { func TestStateLock_MissingProposalWhenPOLSeenDoesNotUpdateLock(t *testing.T) { config := configSetup(t) logger := log.NewNopLogger() - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() cs1, vss := makeState(ctx, t, makeStateArgs{config: config, logger: logger}) @@ -1246,7 +1246,7 @@ func TestStateLock_MissingProposalWhenPOLSeenDoesNotUpdateLock(t *testing.T) { // block if a proposal was not seen for that block in the current round, but // was seen in a previous round. func TestStateLock_DoesNotLockOnOldProposal(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() config := configSetup(t) @@ -1256,7 +1256,7 @@ func TestStateLock_DoesNotLockOnOldProposal(t *testing.T) { timeoutWaitCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryTimeoutWait) proposalCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryCompleteProposal) - pv1, err := cs1.privValidator.GetPubKey(context.Background()) + pv1, err := cs1.privValidator.GetPubKey(t.Context()) require.NoError(t, err) addr := pv1.Address() voteCh := subscribeToVoter(ctx, t, cs1, addr) @@ -1323,7 +1323,7 @@ func TestStateLock_DoesNotLockOnOldProposal(t *testing.T) { func TestStateLock_POLSafety1(t *testing.T) { config := configSetup(t) logger := log.NewNopLogger() - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() cs1, vss := makeState(ctx, t, makeStateArgs{config: config, logger: logger}) @@ -1439,7 +1439,7 @@ func TestStateLock_POLSafety1(t *testing.T) { // dont see P0, lock on P1 at R1, dont unlock using P0 at R2 func TestStateLock_POLSafety2(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() cs1, vss := makeState(ctx, t, makeStateArgs{config: config}) @@ -1535,7 +1535,7 @@ func TestStateLock_POLSafety2(t *testing.T) { // for a block if it is locked on a different block but saw a POL for the block // it is not locked on in a previous round. func TestState_PrevotePOLFromPreviousRound(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() config := configSetup(t) logger := log.NewNopLogger() @@ -1548,7 +1548,7 @@ func TestState_PrevotePOLFromPreviousRound(t *testing.T) { timeoutWaitCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryTimeoutWait) proposalCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryCompleteProposal) - pv1, err := cs1.privValidator.GetPubKey(context.Background()) + pv1, err := cs1.privValidator.GetPubKey(t.Context()) require.NoError(t, err) addr := pv1.Address() voteCh := subscribeToVoter(ctx, t, cs1, addr) @@ -1681,7 +1681,7 @@ func TestState_PrevotePOLFromPreviousRound(t *testing.T) { // P0 proposes B0 at R3. 
func TestProposeValidBlock(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() cs1, vss := makeState(ctx, t, makeStateArgs{config: config}) @@ -1773,7 +1773,7 @@ func TestProposeValidBlock(t *testing.T) { // P0 miss to lock B but set valid block to B after receiving delayed prevote. func TestSetValidBlockOnDelayedPrevote(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() cs1, vss := makeState(ctx, t, makeStateArgs{config: config}) @@ -1842,7 +1842,7 @@ func TestSetValidBlockOnDelayedPrevote(t *testing.T) { // receiving delayed Block Proposal. func TestSetValidBlockOnDelayedProposal(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() cs1, vss := makeState(ctx, t, makeStateArgs{config: config}) @@ -1921,7 +1921,7 @@ func TestProcessProposalAccept(t *testing.T) { } { t.Run(testCase.name, func(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() m := abcimocks.NewApplication(t) @@ -1974,7 +1974,7 @@ func TestFinalizeBlockCalled(t *testing.T) { } { t.Run(testCase.name, func(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() m := abcimocks.NewApplication(t) @@ -2057,7 +2057,7 @@ func TestExtendVoteCalledWhenEnabled(t *testing.T) { } { t.Run(testCase.name, func(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() m := abcimocks.NewApplication(t) @@ -2144,7 +2144,7 @@ func TestExtendVoteCalledWhenEnabled(t *testing.T) { // method is not called for a validator's vote that is never delivered. func TestVerifyVoteExtensionNotCalledOnAbsentPrecommit(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() m := abcimocks.NewApplication(t) @@ -2214,7 +2214,7 @@ func TestVerifyVoteExtensionNotCalledOnAbsentPrecommit(t *testing.T) { // is the proposer again and ensures that the mock application receives the set of // vote extensions from the previous consensus instance. 
func TestPrepareProposalReceivesVoteExtensions(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() config := configSetup(t) @@ -2353,7 +2353,7 @@ func TestVoteExtensionEnableHeight(t *testing.T) { } { t.Run(testCase.name, func(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() numValidators := 3 @@ -2431,7 +2431,7 @@ func TestVoteExtensionEnableHeight(t *testing.T) { // What we want: // P0 waits for timeoutPrecommit before starting next round func TestWaitingTimeoutOnNilPolka(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() config := configSetup(t) @@ -2457,7 +2457,7 @@ func TestWaitingTimeoutOnNilPolka(t *testing.T) { // P0 waits for timeoutPropose in the next round before entering prevote func TestWaitingTimeoutProposeOnNewRound(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() cs1, vss := makeState(ctx, t, makeStateArgs{config: config}) @@ -2495,7 +2495,7 @@ func TestWaitingTimeoutProposeOnNewRound(t *testing.T) { // P0 jump to higher round, precommit and start precommit wait func TestRoundSkipOnNilPolkaFromHigherRound(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() cs1, vss := makeState(ctx, t, makeStateArgs{config: config}) @@ -2535,7 +2535,7 @@ func TestRoundSkipOnNilPolkaFromHigherRound(t *testing.T) { // P0 wait for timeoutPropose to expire before sending prevote. func TestWaitTimeoutProposeOnNilPolkaForTheCurrentRound(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() cs1, vss := makeState(ctx, t, makeStateArgs{config: config}) @@ -2565,7 +2565,7 @@ func TestWaitTimeoutProposeOnNilPolkaForTheCurrentRound(t *testing.T) { // P0 emit NewValidBlock event upon receiving 2/3+ Precommit for B but hasn't received block B yet func TestEmitNewValidBlockEventOnCommitWithoutBlock(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() cs1, vss := makeState(ctx, t, makeStateArgs{config: config}) @@ -2607,7 +2607,7 @@ func TestEmitNewValidBlockEventOnCommitWithoutBlock(t *testing.T) { // After receiving block, it executes block and moves to the next height. 
func TestCommitFromPreviousRound(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() cs1, vss := makeState(ctx, t, makeStateArgs{config: config}) @@ -2668,7 +2668,7 @@ func (n *fakeTxNotifier) Notify() { // start of the next round func TestStartNextHeightCorrectlyAfterTimeout(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() cs1, vss := makeState(ctx, t, makeStateArgs{config: config}) @@ -2734,7 +2734,7 @@ func TestStartNextHeightCorrectlyAfterTimeout(t *testing.T) { func TestResetTimeoutPrecommitUponNewHeight(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() cs1, vss := makeState(ctx, t, makeStateArgs{config: config}) @@ -2804,7 +2804,7 @@ func TestResetTimeoutPrecommitUponNewHeight(t *testing.T) { // we receive a final precommit after going into next round, but others might have gone to commit already! func TestStateHalt1(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() cs1, vss := makeState(ctx, t, makeStateArgs{config: config}) @@ -2877,7 +2877,7 @@ func TestStateHalt1(t *testing.T) { func TestStateOutputsBlockPartsStats(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() // create dummy peer @@ -2925,7 +2925,7 @@ func TestStateOutputsBlockPartsStats(t *testing.T) { func TestGossipTransactionKeyOnlyConfig(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() cs1, vss := makeState(ctx, t, makeStateArgs{config: config, validators: 2}) @@ -2968,7 +2968,7 @@ func TestGossipTransactionKeyOnlyConfig(t *testing.T) { func TestStateOutputVoteStats(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() cs, vss := makeState(ctx, t, makeStateArgs{config: config, validators: 2}) @@ -3009,7 +3009,7 @@ func TestStateOutputVoteStats(t *testing.T) { func TestSignSameVoteTwice(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() _, vss := makeState(ctx, t, makeStateArgs{config: config, validators: 2}) @@ -3049,7 +3049,7 @@ func TestSignSameVoteTwice(t *testing.T) { // corresponding proposal message. func TestStateTimestamp_ProposalNotMatch(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() cs1, vss := makeState(ctx, t, makeStateArgs{config: config}) @@ -3099,7 +3099,7 @@ func TestStateTimestamp_ProposalNotMatch(t *testing.T) { // corresponding proposal message. 
func TestStateTimestamp_ProposalMatch(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() cs1, vss := makeState(ctx, t, makeStateArgs{config: config}) @@ -3192,7 +3192,7 @@ func signAddPrecommitWithExtension(ctx context.Context, func TestAddProposalBlockPartMemoryLimit(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() cs1, _ := makeState(ctx, t, makeStateArgs{config: config}) @@ -3237,7 +3237,7 @@ func TestAddProposalBlockPartMemoryLimit(t *testing.T) { func TestAddProposalBlockPartWrongHeight(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() cs1, _ := makeState(ctx, t, makeStateArgs{config: config}) @@ -3267,7 +3267,7 @@ func TestAddProposalBlockPartWrongHeight(t *testing.T) { func TestAddProposalBlockPartNilProposalBlockParts(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() cs1, _ := makeState(ctx, t, makeStateArgs{config: config}) diff --git a/internal/consensus/types/height_vote_set_test.go b/internal/consensus/types/height_vote_set_test.go index a2cfd84ec..1b0d70b89 100644 --- a/internal/consensus/types/height_vote_set_test.go +++ b/internal/consensus/types/height_vote_set_test.go @@ -21,7 +21,7 @@ func TestPeerCatchupRounds(t *testing.T) { t.Fatal(err) } - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() valSet, privVals := factory.ValidatorSet(ctx, t, 10, 1) diff --git a/internal/consensus/wal_test.go b/internal/consensus/wal_test.go index 3aec25093..fe7573606 100644 --- a/internal/consensus/wal_test.go +++ b/internal/consensus/wal_test.go @@ -28,7 +28,7 @@ func TestWALTruncate(t *testing.T) { walFile := filepath.Join(walDir, "wal") logger := log.NewNopLogger() - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() // this magic number 4K can truncate the content when RotateFile. 
@@ -105,7 +105,7 @@ func TestWALWrite(t *testing.T) { walDir := t.TempDir() walFile := filepath.Join(walDir, "wal") - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() wal, err := NewWAL(ctx, log.NewNopLogger(), walFile) @@ -138,7 +138,7 @@ func TestWALWrite(t *testing.T) { } func TestWALSearchForEndHeight(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -170,7 +170,7 @@ func TestWALSearchForEndHeight(t *testing.T) { } func TestWALPeriodicSync(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() walDir := t.TempDir() diff --git a/internal/dbsync/syncer_test.go b/internal/dbsync/syncer_test.go index 0f2787d55..c728da005 100644 --- a/internal/dbsync/syncer_test.go +++ b/internal/dbsync/syncer_test.go @@ -45,7 +45,7 @@ func getTestSyncer(t *testing.T) *Syncer { func TestSetMetadata(t *testing.T) { syncer := getTestSyncer(t) // initial - syncer.SetMetadata(context.Background(), types.NodeID("someone"), &dbsync.MetadataResponse{ + syncer.SetMetadata(t.Context(), types.NodeID("someone"), &dbsync.MetadataResponse{ Height: 1, Hash: []byte("hash"), Filenames: []string{"f1"}, @@ -58,7 +58,7 @@ func TestSetMetadata(t *testing.T) { require.Equal(t, 1, len(syncer.peersToSync)) // second time - syncer.SetMetadata(context.Background(), types.NodeID("someone else"), &dbsync.MetadataResponse{ + syncer.SetMetadata(t.Context(), types.NodeID("someone else"), &dbsync.MetadataResponse{ Height: 1, Hash: []byte("hash"), Filenames: []string{"f1"}, @@ -75,7 +75,7 @@ func TestFileProcessHappyPath(t *testing.T) { syncer := getTestSyncer(t) data := []byte("data") sum := md5.Sum(data) - syncer.SetMetadata(context.Background(), types.NodeID("someone"), &dbsync.MetadataResponse{ + syncer.SetMetadata(t.Context(), types.NodeID("someone"), &dbsync.MetadataResponse{ Height: 1, Hash: []byte("hash"), Filenames: []string{"f1"}, @@ -94,7 +94,7 @@ func TestFileProcessHappyPath(t *testing.T) { Filename: "f1", Data: data, }) - syncer.Process(context.Background()) + syncer.Process(t.Context()) } func TestFileProcessTimeoutReprocess(t *testing.T) { @@ -102,7 +102,7 @@ func TestFileProcessTimeoutReprocess(t *testing.T) { syncer := getTestSyncer(t) data := []byte("data") sum := md5.Sum(data) - syncer.SetMetadata(context.Background(), types.NodeID("someone"), &dbsync.MetadataResponse{ + syncer.SetMetadata(t.Context(), types.NodeID("someone"), &dbsync.MetadataResponse{ Height: 1, Hash: []byte("hash"), Filenames: []string{"f1"}, @@ -122,5 +122,5 @@ func TestFileProcessTimeoutReprocess(t *testing.T) { Filename: "f1", Data: data, }) - syncer.Process(context.Background()) + syncer.Process(t.Context()) } diff --git a/internal/eventbus/event_bus_test.go b/internal/eventbus/event_bus_test.go index 1a857314e..c1f335fa2 100644 --- a/internal/eventbus/event_bus_test.go +++ b/internal/eventbus/event_bus_test.go @@ -19,7 +19,7 @@ import ( ) func TestEventBusPublishEventTx(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() eventBus := eventbus.NewDefault(log.NewNopLogger()) @@ -73,7 +73,7 @@ func TestEventBusPublishEventTx(t *testing.T) { } func TestEventBusPublishEventNewBlock(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := 
context.WithCancel(t.Context()) defer cancel() eventBus := eventbus.NewDefault(log.NewNopLogger()) err := eventBus.Start(ctx) @@ -127,7 +127,7 @@ func TestEventBusPublishEventNewBlock(t *testing.T) { } func TestEventBusPublishEventTxDuplicateKeys(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() eventBus := eventbus.NewDefault(log.NewNopLogger()) err := eventBus.Start(ctx) @@ -244,7 +244,7 @@ func TestEventBusPublishEventTxDuplicateKeys(t *testing.T) { } func TestEventBusPublishEventNewBlockHeader(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() eventBus := eventbus.NewDefault(log.NewNopLogger()) @@ -294,7 +294,7 @@ func TestEventBusPublishEventNewBlockHeader(t *testing.T) { } func TestEventBusPublishEventEvidenceValidated(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() eventBus := eventbus.NewDefault(log.NewNopLogger()) @@ -336,7 +336,7 @@ func TestEventBusPublishEventEvidenceValidated(t *testing.T) { } func TestEventBusPublishEventNewEvidence(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() eventBus := eventbus.NewDefault(log.NewNopLogger()) @@ -378,7 +378,7 @@ func TestEventBusPublishEventNewEvidence(t *testing.T) { } func TestEventBusPublish(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() eventBus := eventbus.NewDefault(log.NewNopLogger()) @@ -397,7 +397,7 @@ func TestEventBusPublish(t *testing.T) { count := make(chan int, 1) go func() { defer close(count) - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + ctx, cancel := context.WithTimeout(t.Context(), 1*time.Second) defer cancel() for n := 0; ; n++ { @@ -464,7 +464,7 @@ func benchmarkEventBus(numClients int, randQueries bool, randEvents bool, b *tes // for random* functions mrand.Seed(time.Now().Unix()) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(b.Context()) defer cancel() eventBus := eventbus.NewDefault(log.NewNopLogger()) // set buffer capacity to 0 so we are not testing cache diff --git a/internal/eventlog/eventlog_test.go b/internal/eventlog/eventlog_test.go index 417b7d5d0..8db70e966 100644 --- a/internal/eventlog/eventlog_test.go +++ b/internal/eventlog/eventlog_test.go @@ -104,7 +104,7 @@ func TestConcurrent(t *testing.T) { t.Fatalf("New unexpectedly failed: %v", err) } - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() var wg sync.WaitGroup diff --git a/internal/evidence/pool_test.go b/internal/evidence/pool_test.go index 8728b2da1..9a8e820ae 100644 --- a/internal/evidence/pool_test.go +++ b/internal/evidence/pool_test.go @@ -54,7 +54,7 @@ func TestEvidencePoolBasic(t *testing.T) { blockStore = &mocks.BlockStore{} ) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() valSet, privVals := factory.ValidatorSet(ctx, t, 1, 10) blockStore.On("LoadBlockMeta", mock.AnythingOfType("int64")).Return( @@ -111,7 +111,7 @@ func TestEvidencePoolBasic(t *testing.T) { // Tests inbound evidence for the right time and height func TestAddExpiredEvidence(t *testing.T) { - ctx, cancel := 
context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() var ( @@ -156,7 +156,7 @@ func TestAddExpiredEvidence(t *testing.T) { tc := tc t.Run(tc.evDescription, func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() ev, err := types.NewMockDuplicateVoteEvidenceWithValidator(ctx, tc.evHeight, tc.evTime, val, evidenceChainID) @@ -174,7 +174,7 @@ func TestAddExpiredEvidence(t *testing.T) { func TestReportConflictingVotes(t *testing.T) { var height int64 = 10 - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() pool, pv, _ := defaultTestPool(ctx, t, height) @@ -214,7 +214,7 @@ func TestReportConflictingVotes(t *testing.T) { func TestEvidencePoolUpdate(t *testing.T) { height := int64(21) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() pool, val, _ := defaultTestPool(ctx, t, height) @@ -284,7 +284,7 @@ func TestEvidencePoolUpdate(t *testing.T) { func TestVerifyPendingEvidencePasses(t *testing.T) { var height int64 = 1 - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() pool, val, _ := defaultTestPool(ctx, t, height) @@ -304,7 +304,7 @@ func TestVerifyPendingEvidencePasses(t *testing.T) { func TestVerifyDuplicatedEvidenceFails(t *testing.T) { var height int64 = 1 - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() pool, val, _ := defaultTestPool(ctx, t, height) @@ -328,7 +328,7 @@ func TestVerifyDuplicatedEvidenceFails(t *testing.T) { func TestEventOnEvidenceValidated(t *testing.T) { const height = 1 - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() pool, val, eventBus := defaultTestPool(ctx, t, height) @@ -379,7 +379,7 @@ func TestLightClientAttackEvidenceLifecycle(t *testing.T) { height int64 = 100 commonHeight int64 = 90 ) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() ev, trusted, common := makeLunaticEvidence(ctx, t, height, commonHeight, @@ -443,7 +443,7 @@ func TestLightClientAttackEvidenceLifecycle(t *testing.T) { // Tests that restarting the evidence pool after a potential failure will recover the // pending evidence and continue to gossip it func TestRecoverPendingEvidence(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() height := int64(10) diff --git a/internal/evidence/reactor_test.go b/internal/evidence/reactor_test.go index eaed82f2e..97c7bb03c 100644 --- a/internal/evidence/reactor_test.go +++ b/internal/evidence/reactor_test.go @@ -231,7 +231,7 @@ func createEvidenceList( } func TestReactorMultiDisconnect(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() val := types.NewMockPV() @@ -271,7 +271,7 @@ func TestReactorMultiDisconnect(t *testing.T) { func TestReactorBroadcastEvidence(t *testing.T) { numPeers := 7 - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() // create a stateDB for all test suites (nodes) @@ -335,7 +335,7 @@ func TestReactorBroadcastEvidence_Lagging(t *testing.T) { 
height1 := int64(numEvidence) + 10 height2 := int64(numEvidence) / 2 - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() // stateDB1 is ahead of stateDB2, where stateDB1 has all heights (1-20) and @@ -371,7 +371,7 @@ func TestReactorBroadcastEvidence_Pending(t *testing.T) { val := types.NewMockPV() height := int64(10) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() stateDB1 := initializeValidatorState(ctx, t, val, height) @@ -412,7 +412,7 @@ func TestReactorBroadcastEvidence_Committed(t *testing.T) { val := types.NewMockPV() height := int64(10) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() stateDB1 := initializeValidatorState(ctx, t, val, height) @@ -467,7 +467,7 @@ func TestReactorBroadcastEvidence_FullyConnected(t *testing.T) { stateDBs := make([]sm.Store, numPeers) val := types.NewMockPV() - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() // We need all validators saved for heights at least as high as we have diff --git a/internal/evidence/verify_test.go b/internal/evidence/verify_test.go index 6341fde1d..c83ed90c5 100644 --- a/internal/evidence/verify_test.go +++ b/internal/evidence/verify_test.go @@ -33,7 +33,7 @@ func TestVerifyLightClientAttack_Lunatic(t *testing.T) { totalVals = 10 byzVals = 4 ) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() attackTime := defaultEvidenceTime.Add(1 * time.Hour) @@ -74,7 +74,7 @@ func TestVerify_LunaticAttackAgainstState(t *testing.T) { totalVals = 10 byzVals = 4 ) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -149,7 +149,7 @@ func TestVerify_ForwardLunaticAttack(t *testing.T) { ) attackTime := defaultEvidenceTime.Add(1 * time.Hour) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -205,7 +205,7 @@ func TestVerify_ForwardLunaticAttack(t *testing.T) { } func TestVerifyLightClientAttack_Equivocation(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -307,7 +307,7 @@ func TestVerifyLightClientAttack_Equivocation(t *testing.T) { } func TestVerifyLightClientAttack_Amnesia(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -407,7 +407,7 @@ type voteData struct { } func TestVerifyDuplicateVoteEvidence(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() diff --git a/internal/inspect/inspect_test.go b/internal/inspect/inspect_test.go index 059e41e88..670da2ecf 100644 --- a/internal/inspect/inspect_test.go +++ b/internal/inspect/inspect_test.go @@ -53,7 +53,7 @@ func TestInspectRun(t *testing.T) { logger := testLogger.With(t.Name()) d, err := inspect.NewFromConfig(logger, cfg) require.NoError(t, err) - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) stoppedWG 
:= &sync.WaitGroup{} stoppedWG.Add(1) go func() { @@ -86,7 +86,7 @@ func TestBlock(t *testing.T) { rpcConfig := config.TestRPCConfig() l := log.NewNopLogger() d := inspect.New(rpcConfig, blockStoreMock, stateStoreMock, []indexer.EventSink{eventSinkMock}, l) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) wg := &sync.WaitGroup{} wg.Add(1) @@ -133,7 +133,7 @@ func TestTxSearch(t *testing.T) { rpcConfig := config.TestRPCConfig() l := log.NewNopLogger() d := inspect.New(rpcConfig, blockStoreMock, stateStoreMock, []indexer.EventSink{eventSinkMock}, l) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) wg := &sync.WaitGroup{} wg.Add(1) @@ -180,7 +180,7 @@ func TestTx(t *testing.T) { rpcConfig := config.TestRPCConfig() l := log.NewNopLogger() d := inspect.New(rpcConfig, blockStoreMock, stateStoreMock, []indexer.EventSink{eventSinkMock}, l) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) wg := &sync.WaitGroup{} wg.Add(1) @@ -229,7 +229,7 @@ func TestConsensusParams(t *testing.T) { l := log.NewNopLogger() d := inspect.New(rpcConfig, blockStoreMock, stateStoreMock, []indexer.EventSink{eventSinkMock}, l) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) wg := &sync.WaitGroup{} wg.Add(1) @@ -280,7 +280,7 @@ func TestBlockResults(t *testing.T) { l := log.NewNopLogger() d := inspect.New(rpcConfig, blockStoreMock, stateStoreMock, []indexer.EventSink{eventSinkMock}, l) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) wg := &sync.WaitGroup{} wg.Add(1) @@ -328,7 +328,7 @@ func TestCommit(t *testing.T) { l := log.NewNopLogger() d := inspect.New(rpcConfig, blockStoreMock, stateStoreMock, []indexer.EventSink{eventSinkMock}, l) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) wg := &sync.WaitGroup{} wg.Add(1) @@ -382,7 +382,7 @@ func TestBlockByHash(t *testing.T) { l := log.NewNopLogger() d := inspect.New(rpcConfig, blockStoreMock, stateStoreMock, []indexer.EventSink{eventSinkMock}, l) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) wg := &sync.WaitGroup{} wg.Add(1) @@ -435,7 +435,7 @@ func TestBlockchain(t *testing.T) { l := log.NewNopLogger() d := inspect.New(rpcConfig, blockStoreMock, stateStoreMock, []indexer.EventSink{eventSinkMock}, l) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) wg := &sync.WaitGroup{} wg.Add(1) @@ -488,7 +488,7 @@ func TestValidators(t *testing.T) { l := log.NewNopLogger() d := inspect.New(rpcConfig, blockStoreMock, stateStoreMock, []indexer.EventSink{eventSinkMock}, l) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) wg := &sync.WaitGroup{} wg.Add(1) @@ -547,7 +547,7 @@ func TestBlockSearch(t *testing.T) { l := log.NewNopLogger() d := inspect.New(rpcConfig, blockStoreMock, stateStoreMock, []indexer.EventSink{eventSinkMock}, l) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) wg := &sync.WaitGroup{} wg.Add(1) diff --git a/internal/libs/autofile/autofile_test.go b/internal/libs/autofile/autofile_test.go index 9dbba276a..34c265eab 100644 --- a/internal/libs/autofile/autofile_test.go +++ 
b/internal/libs/autofile/autofile_test.go @@ -13,7 +13,7 @@ import ( ) func TestSIGHUP(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() origDir, err := os.Getwd() @@ -102,7 +102,7 @@ func TestSIGHUP(t *testing.T) { // } func TestAutoFileSize(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() // First, create an AutoFile writing to a tempfile dir diff --git a/internal/libs/autofile/group_test.go b/internal/libs/autofile/group_test.go index 4f5e346c2..c67b8ac52 100644 --- a/internal/libs/autofile/group_test.go +++ b/internal/libs/autofile/group_test.go @@ -44,7 +44,7 @@ func assertGroupInfo(t *testing.T, gInfo GroupInfo, minIndex, maxIndex int, tota } func TestCheckHeadSizeLimit(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -118,7 +118,7 @@ func TestCheckHeadSizeLimit(t *testing.T) { func TestRotateFile(t *testing.T) { logger := log.NewNopLogger() - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() g := createTestGroupWithHeadSizeLimit(ctx, t, logger, 0) @@ -183,7 +183,7 @@ func TestRotateFile(t *testing.T) { func TestWrite(t *testing.T) { logger := log.NewNopLogger() - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() g := createTestGroupWithHeadSizeLimit(ctx, t, logger, 0) @@ -211,7 +211,7 @@ func TestWrite(t *testing.T) { func TestGroupReaderRead(t *testing.T) { logger := log.NewNopLogger() - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() g := createTestGroupWithHeadSizeLimit(ctx, t, logger, 0) @@ -249,7 +249,7 @@ func TestGroupReaderRead(t *testing.T) { func TestGroupReaderRead2(t *testing.T) { logger := log.NewNopLogger() - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() g := createTestGroupWithHeadSizeLimit(ctx, t, logger, 0) @@ -288,7 +288,7 @@ func TestGroupReaderRead2(t *testing.T) { func TestMinIndex(t *testing.T) { logger := log.NewNopLogger() - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() g := createTestGroupWithHeadSizeLimit(ctx, t, logger, 0) @@ -301,7 +301,7 @@ func TestMinIndex(t *testing.T) { func TestMaxIndex(t *testing.T) { logger := log.NewNopLogger() - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() g := createTestGroupWithHeadSizeLimit(ctx, t, logger, 0) diff --git a/internal/libs/queue/queue_test.go b/internal/libs/queue/queue_test.go index 08ecc3955..f0c08bc62 100644 --- a/internal/libs/queue/queue_test.go +++ b/internal/libs/queue/queue_test.go @@ -125,14 +125,14 @@ func TestClose(t *testing.T) { } func TestWait(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() q := mustQueue(t, Options{SoftQuota: 2, HardLimit: 2}) // A wait on an empty queue should time out. 
t.Run("WaitTimeout", func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + ctx, cancel := context.WithTimeout(t.Context(), 50*time.Millisecond) defer cancel() got, err := q.Wait(ctx) if err == nil { @@ -144,7 +144,7 @@ func TestWait(t *testing.T) { // A wait on a non-empty queue should report an item. t.Run("WaitNonEmpty", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() const input = "figgy pudding" diff --git a/internal/mempool/mempool_bench_test.go b/internal/mempool/mempool_bench_test.go index 14fb22197..a16b0a8e9 100644 --- a/internal/mempool/mempool_bench_test.go +++ b/internal/mempool/mempool_bench_test.go @@ -15,7 +15,7 @@ import ( ) func BenchmarkTxMempool_CheckTx(b *testing.B) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() client := abciclient.NewLocalClient(log.NewNopLogger(), kvstore.NewApplication()) diff --git a/internal/mempool/mempool_test.go b/internal/mempool/mempool_test.go index 3f6220765..c9795151d 100644 --- a/internal/mempool/mempool_test.go +++ b/internal/mempool/mempool_test.go @@ -233,7 +233,7 @@ func (e *TestPeerEvictor) Errored(peerID types.NodeID, err error) { } func TestTxMempool_TxsAvailable(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() client := abciclient.NewLocalClient(log.NewNopLogger(), &application{Application: kvstore.NewApplication()}) @@ -296,7 +296,7 @@ func TestTxMempool_TxsAvailable(t *testing.T) { } func TestTxMempool_Size(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() client := abciclient.NewLocalClient(log.NewNopLogger(), &application{Application: kvstore.NewApplication()}) @@ -330,7 +330,7 @@ func TestTxMempool_Size(t *testing.T) { } func TestTxMempool_Flush(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() client := abciclient.NewLocalClient(log.NewNopLogger(), &application{Application: kvstore.NewApplication()}) @@ -364,7 +364,7 @@ func TestTxMempool_Flush(t *testing.T) { } func TestTxMempool_ReapMaxBytesMaxGas(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() gasEstimated := int64(1) // gas estimated set to 1 @@ -460,7 +460,7 @@ func TestTxMempool_ReapMaxBytesMaxGas(t *testing.T) { } func TestTxMempool_ReapMaxBytesMaxGas_FallbackToGasWanted(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() gasEstimated := int64(0) // gas estimated not set so fallback to gas wanted @@ -509,7 +509,7 @@ func TestTxMempool_ReapMaxBytesMaxGas_FallbackToGasWanted(t *testing.T) { } func TestTxMempool_ReapMaxTxs(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() client := abciclient.NewLocalClient(log.NewNopLogger(), &application{Application: kvstore.NewApplication()}) @@ -583,7 +583,7 @@ func TestTxMempool_ReapMaxTxs(t *testing.T) { } func TestTxMempool_CheckTxExceedsMaxSize(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() client := 
abciclient.NewLocalClient(log.NewNopLogger(), &application{Application: kvstore.NewApplication()}) @@ -608,7 +608,7 @@ func TestTxMempool_CheckTxExceedsMaxSize(t *testing.T) { } func TestTxMempool_Prioritization(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() client := abciclient.NewLocalClient(log.NewNopLogger(), &application{Application: kvstore.NewApplication()}) @@ -677,7 +677,7 @@ func TestTxMempool_Prioritization(t *testing.T) { } func TestTxMempool_PendingStoreSize(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() client := abciclient.NewLocalClient(log.NewNopLogger(), &application{Application: kvstore.NewApplication()}) @@ -699,7 +699,7 @@ func TestTxMempool_PendingStoreSize(t *testing.T) { } func TestTxMempool_RemoveCacheWhenPendingTxIsFull(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() client := abciclient.NewLocalClient(log.NewNopLogger(), &application{Application: kvstore.NewApplication()}) @@ -721,7 +721,7 @@ func TestTxMempool_RemoveCacheWhenPendingTxIsFull(t *testing.T) { } func TestTxMempool_EVMEviction(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() client := abciclient.NewLocalClient(log.NewNopLogger(), &application{Application: kvstore.NewApplication()}) @@ -764,7 +764,7 @@ func TestTxMempool_EVMEviction(t *testing.T) { } func TestTxMempool_CheckTxSamePeer(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() client := abciclient.NewLocalClient(log.NewNopLogger(), &application{Application: kvstore.NewApplication()}) @@ -788,7 +788,7 @@ func TestTxMempool_CheckTxSamePeer(t *testing.T) { } func TestTxMempool_CheckTxSameSender(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() client := abciclient.NewLocalClient(log.NewNopLogger(), &application{Application: kvstore.NewApplication()}) @@ -819,7 +819,7 @@ func TestTxMempool_CheckTxSameSender(t *testing.T) { } func TestTxMempool_ConcurrentTxs(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() client := abciclient.NewLocalClient(log.NewNopLogger(), &application{Application: kvstore.NewApplication()}) @@ -892,7 +892,7 @@ func TestTxMempool_ConcurrentTxs(t *testing.T) { } func TestTxMempool_ExpiredTxs_NumBlocks(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() client := abciclient.NewLocalClient(log.NewNopLogger(), &application{Application: kvstore.NewApplication()}) @@ -951,7 +951,7 @@ func TestTxMempool_ExpiredTxs_NumBlocks(t *testing.T) { } func TestTxMempool_CheckTxPostCheckError(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() cases := []struct { @@ -1008,7 +1008,7 @@ func TestTxMempool_CheckTxPostCheckError(t *testing.T) { } func TestTxMempool_FailedCheckTxCount(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() client := 
abciclient.NewLocalClient(log.NewNopLogger(), &application{Application: kvstore.NewApplication()}) @@ -1050,7 +1050,7 @@ func TestTxMempool_FailedCheckTxCount(t *testing.T) { func TestAppendCheckTxErr(t *testing.T) { // Setup - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() client := abciclient.NewLocalClient(log.NewNopLogger(), &application{Application: kvstore.NewApplication()}) @@ -1075,7 +1075,7 @@ func TestAppendCheckTxErr(t *testing.T) { } func TestMempoolExpiration(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() client := abciclient.NewLocalClient(log.NewNopLogger(), &application{Application: kvstore.NewApplication()}) diff --git a/internal/mempool/reactor_test.go b/internal/mempool/reactor_test.go index 9ab4fbc70..ef3d259d2 100644 --- a/internal/mempool/reactor_test.go +++ b/internal/mempool/reactor_test.go @@ -145,7 +145,7 @@ func (rts *reactorTestSuite) waitForTxns(t *testing.T, txs []types.Tx, ids ...ty } func TestReactorBroadcastDoesNotPanic(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() const numNodes = 2 @@ -191,7 +191,7 @@ func TestReactorBroadcastDoesNotPanic(t *testing.T) { func TestReactorBroadcastTxs(t *testing.T) { numTxs := 512 numNodes := 4 - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -217,7 +217,7 @@ func TestReactorConcurrency(t *testing.T) { numTxs := 10 numNodes := 2 - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -276,7 +276,7 @@ func TestReactorNoBroadcastToSender(t *testing.T) { numTxs := 1000 numNodes := 2 - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -301,7 +301,7 @@ func TestReactor_MaxTxBytes(t *testing.T) { numNodes := 2 cfg := config.TestConfig() - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -339,7 +339,7 @@ func TestDontExhaustMaxActiveIDs(t *testing.T) { // we're creating a single node network, but not starting the // network. 
- ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -392,7 +392,7 @@ func TestBroadcastTxForPeerStopsWhenPeerStops(t *testing.T) { t.Skip("skipping test in short mode") } - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() diff --git a/internal/p2p/address_test.go b/internal/p2p/address_test.go index d4c745d8d..f6681a399 100644 --- a/internal/p2p/address_test.go +++ b/internal/p2p/address_test.go @@ -205,7 +205,7 @@ func TestParseNodeAddress(t *testing.T) { func TestNodeAddress_Resolve(t *testing.T) { id := types.NodeID("00112233445566778899aabbccddeeff00112233") - bctx, bcancel := context.WithCancel(context.Background()) + bctx, bcancel := context.WithCancel(t.Context()) defer bcancel() testcases := []struct { diff --git a/internal/p2p/channel_test.go b/internal/p2p/channel_test.go index e06e3e77e..35e178670 100644 --- a/internal/p2p/channel_test.go +++ b/internal/p2p/channel_test.go @@ -33,7 +33,7 @@ func testChannel(size int) (*channelInternal, *Channel) { func TestChannel(t *testing.T) { t.Cleanup(leaktest.Check(t)) - bctx, bcancel := context.WithCancel(context.Background()) + bctx, bcancel := context.WithCancel(t.Context()) defer bcancel() testCases := []struct { diff --git a/internal/p2p/conn/connection_test.go b/internal/p2p/conn/connection_test.go index 5a604cd23..efac9c519 100644 --- a/internal/p2p/conn/connection_test.go +++ b/internal/p2p/conn/connection_test.go @@ -51,7 +51,7 @@ func TestMConnectionSendFlushStop(t *testing.T) { server, client := net.Pipe() t.Cleanup(closeAll(t, client, server)) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() clientConn := createTestMConnection(log.NewNopLogger(), client) @@ -88,7 +88,7 @@ func TestMConnectionSend(t *testing.T) { server, client := net.Pipe() t.Cleanup(closeAll(t, client, server)) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() mconn := createTestMConnection(log.NewNopLogger(), client) @@ -135,7 +135,7 @@ func TestMConnectionReceive(t *testing.T) { } logger := log.NewNopLogger() - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() mconn1 := createMConnectionWithCallbacks(logger, client, onReceive, onError) @@ -165,7 +165,7 @@ func TestMConnectionWillEventuallyTimeout(t *testing.T) { server, client := net.Pipe() t.Cleanup(closeAll(t, client, server)) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() mconn := createMConnectionWithCallbacks(log.NewNopLogger(), client, nil, nil) @@ -221,7 +221,7 @@ func TestMConnectionMultiplePongsInTheBeginning(t *testing.T) { } } - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() mconn := createMConnectionWithCallbacks(log.NewNopLogger(), client, onReceive, onError) @@ -279,7 +279,7 @@ func TestMConnectionMultiplePings(t *testing.T) { case <-ctx.Done(): } } - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() mconn := createMConnectionWithCallbacks(log.NewNopLogger(), client, onReceive, onError) @@ -336,7 +336,7 @@ func TestMConnectionPingPongs(t *testing.T) { } } - ctx, cancel := 
context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() mconn := createMConnectionWithCallbacks(log.NewNopLogger(), client, onReceive, onError) @@ -395,7 +395,7 @@ func TestMConnectionStopsAndReturnsError(t *testing.T) { case <-ctx.Done(): } } - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() mconn := createMConnectionWithCallbacks(log.NewNopLogger(), client, onReceive, onError) @@ -466,7 +466,7 @@ func expectSend(ch chan struct{}) bool { } func TestMConnectionReadErrorBadEncoding(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() chOnErr := make(chan struct{}) @@ -482,7 +482,7 @@ func TestMConnectionReadErrorBadEncoding(t *testing.T) { } func TestMConnectionReadErrorUnknownChannel(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() chOnErr := make(chan struct{}) @@ -504,7 +504,7 @@ func TestMConnectionReadErrorLongMessage(t *testing.T) { chOnErr := make(chan struct{}) chOnRcv := make(chan struct{}) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() mconnClient, mconnServer := newClientAndServerConnsForReadErrors(ctx, t, chOnErr) @@ -544,7 +544,7 @@ func TestMConnectionReadErrorLongMessage(t *testing.T) { } func TestMConnectionReadErrorUnknownMsgType(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() chOnErr := make(chan struct{}) @@ -560,7 +560,7 @@ func TestMConnectionReadErrorUnknownMsgType(t *testing.T) { func TestMConnectionTrySend(t *testing.T) { server, client := net.Pipe() t.Cleanup(closeAll(t, client, server)) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() mconn := createTestMConnection(log.NewNopLogger(), client) @@ -609,7 +609,7 @@ func TestMConnectionChannelOverflow(t *testing.T) { chOnErr := make(chan struct{}) chOnRcv := make(chan struct{}) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() mconnClient, mconnServer := newClientAndServerConnsForReadErrors(ctx, t, chOnErr) diff --git a/internal/p2p/peermanager_scoring_test.go b/internal/p2p/peermanager_scoring_test.go index 3c56220f8..b8034ca47 100644 --- a/internal/p2p/peermanager_scoring_test.go +++ b/internal/p2p/peermanager_scoring_test.go @@ -31,7 +31,7 @@ func TestPeerScoring(t *testing.T) { require.NoError(t, err) require.True(t, added) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() t.Run("Synchronous", func(t *testing.T) { diff --git a/internal/p2p/peermanager_test.go b/internal/p2p/peermanager_test.go index 5f4cb98f9..e0040ab73 100644 --- a/internal/p2p/peermanager_test.go +++ b/internal/p2p/peermanager_test.go @@ -194,7 +194,7 @@ func TestNewPeerManager_Persistence(t *testing.T) { }, peerManager.Scores()) // Introduce a dial failure and persistent peer score should be reduced by one - ctx, _ := context.WithCancel(context.Background()) + ctx, _ := context.WithCancel(t.Context()) peerManager.DialFailed(ctx, bAddresses[0]) require.Equal(t, map[types.NodeID]p2p.PeerScore{ aID: 0, @@ -344,7 +344,7 @@ func TestPeerManager_Add(t *testing.T) { } func 
TestPeerManager_DialNext(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} @@ -370,7 +370,7 @@ func TestPeerManager_DialNext(t *testing.T) { } func TestPeerManager_DialNext_Retry(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} @@ -405,7 +405,7 @@ func TestPeerManager_DialNext_Retry(t *testing.T) { } func TestPeerManagerDeleteOnMaxRetries(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} @@ -444,7 +444,7 @@ func TestPeerManagerDeleteOnMaxRetries(t *testing.T) { } func TestPeerManager_DialNext_WakeOnDialFailed(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() peerManager, err := p2p.NewPeerManager(log.NewNopLogger(), selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{ @@ -489,7 +489,7 @@ func TestPeerManager_DialNext_WakeOnDialFailed(t *testing.T) { } func TestPeerManager_DialNext_WakeOnDialFailedRetry(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() options := p2p.PeerManagerOptions{MinRetryTime: 200 * time.Millisecond} @@ -519,7 +519,7 @@ func TestPeerManager_DialNext_WakeOnDialFailedRetry(t *testing.T) { } func TestPeerManager_DialNext_WakeOnDisconnected(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} @@ -588,7 +588,7 @@ func TestPeerManager_TryDialNext_MaxConnected(t *testing.T) { } func TestPeerManager_TryDialNext_MaxConnectedUpgrade(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} @@ -759,7 +759,7 @@ func TestPeerManager_TryDialNext_DialingConnected(t *testing.T) { } func TestPeerManager_TryDialNext_Multiple(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() aID := types.NodeID(strings.Repeat("a", 40)) @@ -811,7 +811,7 @@ func TestPeerManager_DialFailed(t *testing.T) { require.NoError(t, err) require.True(t, added) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() // Dialing and then calling DialFailed with a different address (same @@ -838,7 +838,7 @@ func TestPeerManager_DialFailed(t *testing.T) { } func TestPeerManager_DialFailed_UnreservePeer(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} @@ -1059,7 +1059,7 @@ func TestPeerManager_Dialed_Upgrade(t *testing.T) { } func TestPeerManager_Dialed_UpgradeEvenLower(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := 
context.WithCancel(t.Context()) defer cancel() a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} @@ -1116,7 +1116,7 @@ func TestPeerManager_Dialed_UpgradeEvenLower(t *testing.T) { } func TestPeerManager_Dialed_UpgradeNoEvict(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} @@ -1287,7 +1287,7 @@ func TestPeerManager_Accepted_MaxConnectedUpgrade(t *testing.T) { } func TestPeerManager_Accepted_Upgrade(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} @@ -1383,7 +1383,7 @@ func TestPeerManager_Ready(t *testing.T) { a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} b := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("b", 40))} - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() peerManager, err := p2p.NewPeerManager(log.NewNopLogger(), selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{}, p2p.NopMetrics()) @@ -1417,7 +1417,7 @@ func TestPeerManager_Ready(t *testing.T) { } func TestPeerManager_Ready_Channels(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() pm, err := p2p.NewPeerManager(log.NewNopLogger(), selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{}, p2p.NopMetrics()) @@ -1441,7 +1441,7 @@ func TestPeerManager_Ready_Channels(t *testing.T) { // See TryEvictNext for most tests, this just tests blocking behavior. 
func TestPeerManager_EvictNext(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} @@ -1477,7 +1477,7 @@ func TestPeerManager_EvictNext(t *testing.T) { } func TestPeerManager_EvictNext_WakeOnError(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} @@ -1506,7 +1506,7 @@ func TestPeerManager_EvictNext_WakeOnError(t *testing.T) { } func TestPeerManager_EvictNext_WakeOnUpgradeDialed(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} @@ -1547,7 +1547,7 @@ func TestPeerManager_EvictNext_WakeOnUpgradeDialed(t *testing.T) { } func TestPeerManager_EvictNext_WakeOnUpgradeAccepted(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} @@ -1581,7 +1581,7 @@ func TestPeerManager_EvictNext_WakeOnUpgradeAccepted(t *testing.T) { require.Equal(t, a.NodeID, evict) } func TestPeerManager_TryEvictNext(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} @@ -1625,7 +1625,7 @@ func TestPeerManager_Disconnected(t *testing.T) { peerManager, err := p2p.NewPeerManager(log.NewNopLogger(), selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{}, p2p.NopMetrics()) require.NoError(t, err) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() sub := peerManager.Subscribe(ctx) @@ -1676,7 +1676,7 @@ func TestPeerManager_Disconnected(t *testing.T) { } func TestPeerManager_Errored(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} @@ -1715,7 +1715,7 @@ func TestPeerManager_Errored(t *testing.T) { } func TestPeerManager_Subscribe(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} @@ -1778,7 +1778,7 @@ func TestPeerManager_Subscribe(t *testing.T) { } func TestPeerManager_Subscribe_Close(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} @@ -1805,7 +1805,7 @@ func TestPeerManager_Subscribe_Close(t *testing.T) { } func TestPeerManager_Subscribe_Broadcast(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() t.Cleanup(leaktest.Check(t)) @@ -1855,7 +1855,7 @@ func TestPeerManager_Close(t *testing.T) { // leaktest will check that spawned goroutines are closed. 
t.Cleanup(leaktest.CheckTimeout(t, 1*time.Second)) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} diff --git a/internal/p2p/pex/reactor_test.go b/internal/p2p/pex/reactor_test.go index 33778b19d..67d3f5a59 100644 --- a/internal/p2p/pex/reactor_test.go +++ b/internal/p2p/pex/reactor_test.go @@ -33,7 +33,7 @@ const ( ) func TestReactorBasic(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() // start a network with one mock reactor and one "real" reactor testNet := setupNetwork(ctx, t, testOptions{ @@ -53,7 +53,7 @@ func TestReactorBasic(t *testing.T) { } func TestReactorConnectFullNetwork(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() testNet := setupNetwork(ctx, t, testOptions{ @@ -72,7 +72,7 @@ func TestReactorConnectFullNetwork(t *testing.T) { } func TestReactorSendsRequestsTooOften(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() r := setupSingle(ctx, t) @@ -103,7 +103,7 @@ func TestReactorSendsRequestsTooOften(t *testing.T) { func TestReactorSendsResponseWithoutRequest(t *testing.T) { t.Skip("This test needs updated https://github.com/tendermint/tendermint/issue/7634") - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() testNet := setupNetwork(ctx, t, testOptions{ @@ -125,7 +125,7 @@ func TestReactorSendsResponseWithoutRequest(t *testing.T) { func TestReactorNeverSendsTooManyPeers(t *testing.T) { t.Skip("This test needs updated https://github.com/tendermint/tendermint/issue/7634") - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() testNet := setupNetwork(ctx, t, testOptions{ @@ -148,7 +148,7 @@ func TestReactorNeverSendsTooManyPeers(t *testing.T) { } func TestReactorErrorsOnReceivingTooManyPeers(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() r := setupSingle(ctx, t) @@ -195,7 +195,7 @@ func TestReactorErrorsOnReceivingTooManyPeers(t *testing.T) { } func TestReactorSmallPeerStoreInALargeNetwork(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() testNet := setupNetwork(ctx, t, testOptions{ @@ -219,7 +219,7 @@ func TestReactorSmallPeerStoreInALargeNetwork(t *testing.T) { } func TestReactorLargePeerStoreInASmallNetwork(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() testNet := setupNetwork(ctx, t, testOptions{ @@ -240,7 +240,7 @@ func TestReactorLargePeerStoreInASmallNetwork(t *testing.T) { func TestReactorWithNetworkGrowth(t *testing.T) { t.Skip("This test needs updated https://github.com/tendermint/tendermint/issue/7634") - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() testNet := setupNetwork(ctx, t, testOptions{ diff --git a/internal/p2p/pqueue_test.go b/internal/p2p/pqueue_test.go index d1057ac7e..a9c32e95e 100644 --- a/internal/p2p/pqueue_test.go +++ b/internal/p2p/pqueue_test.go @@ -26,7 
+26,7 @@ func TestCloseWhileDequeueFull(t *testing.T) { } } - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() go pqueue.process(ctx) diff --git a/internal/p2p/router_filter_test.go b/internal/p2p/router_filter_test.go index 217be8d32..1c879577a 100644 --- a/internal/p2p/router_filter_test.go +++ b/internal/p2p/router_filter_test.go @@ -13,7 +13,7 @@ import ( ) func TestConnectionFiltering(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() diff --git a/internal/p2p/router_init_test.go b/internal/p2p/router_init_test.go index f2ae28482..5daa5fbdf 100644 --- a/internal/p2p/router_init_test.go +++ b/internal/p2p/router_init_test.go @@ -12,7 +12,7 @@ import ( ) func TestRouter_ConstructQueueFactory(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() t.Run("ValidateOptionsPopulatesDefaultQueue", func(t *testing.T) { diff --git a/internal/p2p/router_test.go b/internal/p2p/router_test.go index c1674b8a0..e41df42c2 100644 --- a/internal/p2p/router_test.go +++ b/internal/p2p/router_test.go @@ -41,7 +41,7 @@ func echoReactor(ctx context.Context, channel *p2p.Channel) { } func TestRouter_Network(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() t.Cleanup(leaktest.Check(t)) @@ -98,7 +98,7 @@ func TestRouter_Network(t *testing.T) { func TestRouter_Channel_Basic(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() // Set up a router with no transports (so no peers). @@ -163,7 +163,7 @@ func TestRouter_Channel_Basic(t *testing.T) { // Channel tests are hairy to mock, so we use an in-memory network instead. func TestRouter_Channel_SendReceive(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() t.Cleanup(leaktest.Check(t)) @@ -227,7 +227,7 @@ func TestRouter_Channel_SendReceive(t *testing.T) { func TestRouter_Channel_Broadcast(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() // Create a test network and open a channel on all nodes. @@ -258,7 +258,7 @@ func TestRouter_Channel_Broadcast(t *testing.T) { func TestRouter_Channel_Wrapper(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() // Create a test network and open a channel on all nodes. @@ -328,7 +328,7 @@ func (w *wrapperMessage) Unwrap() (proto.Message, error) { func TestRouter_Channel_Error(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() // Create a test network and open a channel on all nodes. @@ -371,7 +371,7 @@ func TestRouter_AcceptPeers(t *testing.T) { }, } - bctx, bcancel := context.WithCancel(context.Background()) + bctx, bcancel := context.WithCancel(t.Context()) defer bcancel() for name, tc := range testcases { @@ -383,7 +383,7 @@ func TestRouter_AcceptPeers(t *testing.T) { t.Cleanup(leaktest.Check(t)) // Set up a mock transport that handshakes. 
- connCtx, connCancel := context.WithCancel(context.Background()) + connCtx, connCancel := context.WithCancel(t.Context()) mockConnection := &mocks.Connection{} mockConnection.On("String").Maybe().Return("mock") mockConnection.On("Handshake", mock.Anything, selfInfo, selfKey). @@ -450,7 +450,7 @@ func TestRouter_AcceptPeers_Errors(t *testing.T) { t.Run(err.Error(), func(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() // Set up a mock transport that returns io.EOF once, which should prevent @@ -492,7 +492,7 @@ func TestRouter_AcceptPeers_Errors(t *testing.T) { func TestRouter_AcceptPeers_HeadOfLineBlocking(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() // Set up a mock transport that returns a connection that blocks during the @@ -573,7 +573,7 @@ func TestRouter_DialPeers(t *testing.T) { }, } - bctx, bcancel := context.WithCancel(context.Background()) + bctx, bcancel := context.WithCancel(t.Context()) defer bcancel() for name, tc := range testcases { @@ -587,7 +587,7 @@ func TestRouter_DialPeers(t *testing.T) { endpoint := &p2p.Endpoint{Protocol: "mock", Path: string(tc.dialID)} // Set up a mock transport that handshakes. - connCtx, connCancel := context.WithCancel(context.Background()) + connCtx, connCancel := context.WithCancel(t.Context()) defer connCancel() mockConnection := &mocks.Connection{} mockConnection.On("String").Maybe().Return("mock") @@ -665,7 +665,7 @@ func TestRouter_DialPeers(t *testing.T) { func TestRouter_DialPeers_Parallel(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() a := p2p.NodeAddress{Protocol: "mock", NodeID: types.NodeID(strings.Repeat("a", 40))} @@ -754,7 +754,7 @@ func TestRouter_DialPeers_Parallel(t *testing.T) { func TestRouter_EvictPeers(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() // Set up a mock transport that we can evict. @@ -820,7 +820,7 @@ func TestRouter_EvictPeers(t *testing.T) { func TestRouter_ChannelCompatability(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() incompatiblePeer := types.NodeInfo{ @@ -872,7 +872,7 @@ func TestRouter_ChannelCompatability(t *testing.T) { func TestRouter_DontSendOnInvalidChannel(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() peer := types.NodeInfo{ @@ -940,7 +940,7 @@ func TestRouter_Channel_FilterByID(t *testing.T) { t.Cleanup(leaktest.Check(t)) // Set up a mock transport that handshakes. 
- connCtx, connCancel := context.WithCancel(context.Background()) + connCtx, connCancel := context.WithCancel(t.Context()) defer connCancel() peer := types.NodeInfo{ NodeID: peerID, diff --git a/internal/p2p/rqueue_test.go b/internal/p2p/rqueue_test.go index 43c4066e5..96ea4a77d 100644 --- a/internal/p2p/rqueue_test.go +++ b/internal/p2p/rqueue_test.go @@ -7,7 +7,7 @@ import ( ) func TestSimpleQueue(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() // set up a small queue with very small buffers so we can diff --git a/internal/p2p/transport_mconn_test.go b/internal/p2p/transport_mconn_test.go index c478dbe1d..f550832c6 100644 --- a/internal/p2p/transport_mconn_test.go +++ b/internal/p2p/transport_mconn_test.go @@ -50,7 +50,7 @@ func TestMConnTransport_AcceptBeforeListen(t *testing.T) { t.Cleanup(func() { _ = transport.Close() }) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() _, err := transport.Accept(ctx) @@ -59,7 +59,7 @@ func TestMConnTransport_AcceptBeforeListen(t *testing.T) { } func TestMConnTransport_AcceptMaxAcceptedConnections(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() transport := p2p.NewMConnTransport( @@ -129,7 +129,7 @@ func TestMConnTransport_AcceptMaxAcceptedConnections(t *testing.T) { } func TestMConnTransport_Listen(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() testcases := []struct { @@ -195,7 +195,7 @@ func TestMConnTransport_Listen(t *testing.T) { go func() { // Dialing the endpoint should work. var err error - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() peerConn, err = transport.Dial(ctx, endpoint) diff --git a/internal/p2p/transport_test.go b/internal/p2p/transport_test.go index b4edf9bc9..b461e2d0d 100644 --- a/internal/p2p/transport_test.go +++ b/internal/p2p/transport_test.go @@ -41,7 +41,7 @@ func withTransports(ctx context.Context, t *testing.T, tester func(context.Conte func TestTransport_AcceptClose(t *testing.T) { // Just test accept unblock on close, happy path is tested widely elsewhere. - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) { @@ -82,7 +82,7 @@ func TestTransport_DialEndpoints(t *testing.T) { {[]byte{1, 2, 3, 4, 5}, false}, } - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) { @@ -160,7 +160,7 @@ func TestTransport_DialEndpoints(t *testing.T) { } func TestTransport_Dial(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() // Most just tests dial failures, happy path is tested widely elsewhere. 
@@ -205,7 +205,7 @@ func TestTransport_Dial(t *testing.T) { } func TestTransport_Endpoints(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) { @@ -237,7 +237,7 @@ func TestTransport_Endpoints(t *testing.T) { } func TestTransport_Protocols(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) { @@ -253,7 +253,7 @@ func TestTransport_Protocols(t *testing.T) { } func TestTransport_String(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) { @@ -263,7 +263,7 @@ func TestTransport_String(t *testing.T) { } func TestConnection_Handshake(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) { @@ -317,7 +317,7 @@ func TestConnection_Handshake(t *testing.T) { } func TestConnection_HandshakeCancel(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) { @@ -347,7 +347,7 @@ func TestConnection_HandshakeCancel(t *testing.T) { } func TestConnection_FlushClose(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) { @@ -368,7 +368,7 @@ func TestConnection_FlushClose(t *testing.T) { } func TestConnection_LocalRemoteEndpoint(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) { @@ -385,7 +385,7 @@ func TestConnection_LocalRemoteEndpoint(t *testing.T) { } func TestConnection_SendReceive(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) { @@ -446,7 +446,7 @@ func TestConnection_SendReceive(t *testing.T) { } func TestConnection_String(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) { diff --git a/internal/proxy/client_test.go b/internal/proxy/client_test.go index 41a34bde7..d41b4604f 100644 --- a/internal/proxy/client_test.go +++ b/internal/proxy/client_test.go @@ -65,7 +65,7 @@ func TestEcho(t *testing.T) { t.Fatal(err) } - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() // Start server @@ -105,7 +105,7 @@ func BenchmarkEcho(b *testing.B) { b.Fatal(err) } - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := 
context.WithCancel(b.Context()) defer cancel() // Start server @@ -143,7 +143,7 @@ func BenchmarkEcho(b *testing.B) { } func TestInfo(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() sockPath := fmt.Sprintf("unix://%s/echo_%v.sock", t.TempDir(), tmrand.Str(6)) @@ -180,7 +180,7 @@ type noopStoppableClientImpl struct { func (c *noopStoppableClientImpl) Stop() { c.count++ } func TestAppConns_Start_Stop(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() clientMock := &abcimocks.Client{} @@ -209,7 +209,7 @@ func TestAppConns_Failure(t *testing.T) { c := make(chan os.Signal, 1) signal.Notify(c, syscall.SIGTERM, syscall.SIGABRT) - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) defer cancel() clientMock := &abcimocks.Client{} diff --git a/internal/pubsub/example_test.go b/internal/pubsub/example_test.go index c74f13903..e1873375a 100644 --- a/internal/pubsub/example_test.go +++ b/internal/pubsub/example_test.go @@ -13,7 +13,7 @@ import ( ) func TestExample(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() s := newTestServer(ctx, t, log.NewNopLogger()) diff --git a/internal/pubsub/pubsub_test.go b/internal/pubsub/pubsub_test.go index c5fd45d56..9a992a9b9 100644 --- a/internal/pubsub/pubsub_test.go +++ b/internal/pubsub/pubsub_test.go @@ -29,7 +29,7 @@ func (pubstring) TypeTag() string { return "pubstring" } func (e pubstring) ToLegacy() types.LegacyEventData { return e } func TestSubscribeWithArgs(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -59,7 +59,7 @@ func TestSubscribeWithArgs(t *testing.T) { } func TestObserver(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -80,7 +80,7 @@ func TestObserver(t *testing.T) { } func TestObserverErrors(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -93,7 +93,7 @@ func TestObserverErrors(t *testing.T) { } func TestPublishDoesNotBlock(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -123,7 +123,7 @@ func TestPublishDoesNotBlock(t *testing.T) { } func TestSubscribeErrors(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -140,7 +140,7 @@ func TestSubscribeErrors(t *testing.T) { } func TestSlowSubscriber(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -162,7 +162,7 @@ func TestSlowSubscriber(t *testing.T) { } func TestDifferentClients(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -217,7 +217,7 @@ func TestDifferentClients(t *testing.T) { } func TestSubscribeDuplicateKeys(t
*testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -273,7 +273,7 @@ func TestSubscribeDuplicateKeys(t *testing.T) { } func TestClientSubscribesTwice(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -309,7 +309,7 @@ func TestClientSubscribesTwice(t *testing.T) { } func TestUnsubscribe(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -334,7 +334,7 @@ func TestUnsubscribe(t *testing.T) { } func TestClientUnsubscribesTwice(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -356,7 +356,7 @@ func TestClientUnsubscribesTwice(t *testing.T) { } func TestResubscribe(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -380,7 +380,7 @@ func TestResubscribe(t *testing.T) { } func TestUnsubscribeAll(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -409,7 +409,7 @@ func TestBufferCapacity(t *testing.T) { require.Equal(t, 2, s.BufferCapacity()) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() require.NoError(t, s.Publish(pubstring("Nighthawk"))) diff --git a/internal/rpc/core/blocks_test.go b/internal/rpc/core/blocks_test.go index 6ec400069..2b44e9098 100644 --- a/internal/rpc/core/blocks_test.go +++ b/internal/rpc/core/blocks_test.go @@ -104,7 +104,7 @@ func TestBlockResults(t *testing.T) { }}, } - ctx := context.Background() + ctx := t.Context() for _, tc := range testCases { res, err := env.BlockResults(ctx, &coretypes.RequestBlockInfo{ Height: (*coretypes.Int64)(&tc.height), diff --git a/internal/state/execution_test.go b/internal/state/execution_test.go index 41d7ae27c..f6351dea2 100644 --- a/internal/state/execution_test.go +++ b/internal/state/execution_test.go @@ -43,7 +43,7 @@ func TestApplyBlock(t *testing.T) { cc := abciclient.NewLocalClient(logger, app) proxyApp := proxy.New(cc, logger, proxy.NopMetrics()) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() require.NoError(t, proxyApp.Start(ctx)) @@ -86,7 +86,7 @@ func TestApplyBlock(t *testing.T) { // DecidedLastCommit properly reflects which validators signed the preceding // block. func TestFinalizeBlockDecidedLastCommit(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -162,7 +162,7 @@ func TestFinalizeBlockDecidedLastCommit(t *testing.T) { // TestFinalizeBlockByzantineValidators ensures we send byzantine validators list. 
func TestFinalizeBlockByzantineValidators(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() app := &testApp{} @@ -283,7 +283,7 @@ func TestFinalizeBlockByzantineValidators(t *testing.T) { func TestProcessProposal(t *testing.T) { const height = 2 txs := factory.MakeNTxs(height, 10) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() app := abcimocks.NewApplication(t) @@ -506,7 +506,7 @@ func TestUpdateValidators(t *testing.T) { // TestFinalizeBlockValidatorUpdates ensures we update validator set and send an event. func TestFinalizeBlockValidatorUpdates(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() app := &testApp{} @@ -592,7 +592,7 @@ func TestFinalizeBlockValidatorUpdates(t *testing.T) { // TestFinalizeBlockValidatorUpdatesResultingInEmptySet checks that processing validator updates that // would result in empty set causes no panic, an error is raised and NextValidators is not updated func TestFinalizeBlockValidatorUpdatesResultingInEmptySet(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() app := &testApp{} @@ -640,7 +640,7 @@ func TestFinalizeBlockValidatorUpdatesResultingInEmptySet(t *testing.T) { func TestEmptyPrepareProposal(t *testing.T) { const height = 2 - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -694,7 +694,7 @@ func TestEmptyPrepareProposal(t *testing.T) { // a transaction as REMOVED that was not present in the original proposal. func TestPrepareProposalErrorOnNonExistingRemoved(t *testing.T) { const height = 2 - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -751,7 +751,7 @@ func TestPrepareProposalErrorOnNonExistingRemoved(t *testing.T) { // in the order matching the order they are returned from PrepareProposal. func TestPrepareProposalReorderTxs(t *testing.T) { const height = 2 - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -808,7 +808,7 @@ func TestPrepareProposalReorderTxs(t *testing.T) { // an error if the ResponsePrepareProposal returned from the application is invalid. func TestPrepareProposalErrorOnTooManyTxs(t *testing.T) { const height = 2 - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -865,7 +865,7 @@ func TestPrepareProposalErrorOnTooManyTxs(t *testing.T) { // upon calling PrepareProposal on it. 
func TestPrepareProposalErrorOnPrepareProposalError(t *testing.T) { const height = 2 - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -953,7 +953,7 @@ func TestCreateProposalAbsentVoteExtensions(t *testing.T) { }, } { t.Run(testCase.name, func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() diff --git a/internal/state/indexer/block/kv/kv_test.go b/internal/state/indexer/block/kv/kv_test.go index e2fcbc3cb..a915bee9d 100644 --- a/internal/state/indexer/block/kv/kv_test.go +++ b/internal/state/indexer/block/kv/kv_test.go @@ -133,7 +133,7 @@ func TestBlockIndexer(t *testing.T) { for name, tc := range testCases { tc := tc t.Run(name, func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() results, err := indexer.Search(ctx, tc.q) diff --git a/internal/state/indexer/indexer_service_test.go b/internal/state/indexer/indexer_service_test.go index 5524c9a20..bcd808ef3 100644 --- a/internal/state/indexer/indexer_service_test.go +++ b/internal/state/indexer/indexer_service_test.go @@ -40,7 +40,7 @@ var ( ) func TestIndexerServiceIndexesBlocks(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := tmlog.NewNopLogger() diff --git a/internal/state/indexer/sink/kv/kv_test.go b/internal/state/indexer/sink/kv/kv_test.go index 811b4c195..f49b7c379 100644 --- a/internal/state/indexer/sink/kv/kv_test.go +++ b/internal/state/indexer/sink/kv/kv_test.go @@ -144,7 +144,7 @@ func TestBlockFuncs(t *testing.T) { for name, tc := range testCases { tc := tc t.Run(name, func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() results, err := indexer.SearchBlockEvents(ctx, tc.q) @@ -169,7 +169,7 @@ func TestTxSearchWithCancelation(t *testing.T) { assert.Nil(t, e) assert.Equal(t, r, txResult) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) cancel() results, err := indexer.SearchTxEvents(ctx, query.MustCompile(`account.number = 1`)) assert.NoError(t, err) @@ -240,7 +240,7 @@ func TestTxSearchDeprecatedIndexing(t *testing.T) { {"sender = 'addr1'", []*abci.TxResult{txResult2}}, } - ctx := context.Background() + ctx := t.Context() for _, tc := range testCases { tc := tc @@ -267,7 +267,7 @@ func TestTxSearchOneTxWithMultipleSameTagsButDifferentValues(t *testing.T) { err := indexer.IndexTxEvents([]*abci.TxResult{txResult}) require.NoError(t, err) - ctx := context.Background() + ctx := t.Context() results, err := indexer.SearchTxEvents(ctx, query.MustCompile(`account.number >= 1`)) assert.NoError(t, err) @@ -324,7 +324,7 @@ func TestTxSearchMultipleTxs(t *testing.T) { err = indexer.IndexTxEvents([]*abci.TxResult{txResult4}) require.NoError(t, err) - ctx := context.Background() + ctx := t.Context() results, err := indexer.SearchTxEvents(ctx, query.MustCompile(`account.number >= 1`)) assert.NoError(t, err) diff --git a/internal/state/indexer/sink/null/null_test.go b/internal/state/indexer/sink/null/null_test.go index 9af66027f..8129143ef 100644 --- a/internal/state/indexer/sink/null/null_test.go +++ b/internal/state/indexer/sink/null/null_test.go @@ -11,7 +11,7 @@ import ( ) func 
TestNullEventSink(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() nullIndexer := NewEventSink() diff --git a/internal/state/indexer/sink/psql/psql_test.go b/internal/state/indexer/sink/psql/psql_test.go index 381d67b61..1f51b1b15 100644 --- a/internal/state/indexer/sink/psql/psql_test.go +++ b/internal/state/indexer/sink/psql/psql_test.go @@ -152,7 +152,7 @@ func TestType(t *testing.T) { } func TestIndexing(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() t.Run("IndexBlockEvents", func(t *testing.T) { diff --git a/internal/state/indexer/tx/kv/kv_bench_test.go b/internal/state/indexer/tx/kv/kv_bench_test.go index 219399b94..08c08e5ca 100644 --- a/internal/state/indexer/tx/kv/kv_bench_test.go +++ b/internal/state/indexer/tx/kv/kv_bench_test.go @@ -60,7 +60,7 @@ func BenchmarkTxSearch(b *testing.B) { b.ResetTimer() - ctx := context.Background() + ctx := b.Context() for i := 0; i < b.N; i++ { if _, err := indexer.Search(ctx, txQuery); err != nil { diff --git a/internal/state/indexer/tx/kv/kv_test.go b/internal/state/indexer/tx/kv/kv_test.go index 08ee0c80e..ef547b81a 100644 --- a/internal/state/indexer/tx/kv/kv_test.go +++ b/internal/state/indexer/tx/kv/kv_test.go @@ -131,7 +131,7 @@ func TestTxSearch(t *testing.T) { {"tx.height = 1", 1}, } - ctx := context.Background() + ctx := t.Context() for _, tc := range testCases { tc := tc @@ -160,7 +160,7 @@ func TestTxSearchWithCancelation(t *testing.T) { err := indexer.Index([]*abci.TxResult{txResult}) require.NoError(t, err) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) cancel() results, err := indexer.Search(ctx, query.MustCompile(`account.number = 1`)) assert.NoError(t, err) @@ -230,7 +230,7 @@ func TestTxSearchDeprecatedIndexing(t *testing.T) { {"sender = 'addr1'", []*abci.TxResult{txResult2}}, } - ctx := context.Background() + ctx := t.Context() for _, tc := range testCases { tc := tc @@ -257,7 +257,7 @@ func TestTxSearchOneTxWithMultipleSameTagsButDifferentValues(t *testing.T) { err := indexer.Index([]*abci.TxResult{txResult}) require.NoError(t, err) - ctx := context.Background() + ctx := t.Context() results, err := indexer.Search(ctx, query.MustCompile(`account.number >= 1`)) assert.NoError(t, err) @@ -314,7 +314,7 @@ func TestTxSearchMultipleTxs(t *testing.T) { err = indexer.Index([]*abci.TxResult{txResult4}) require.NoError(t, err) - ctx := context.Background() + ctx := t.Context() results, err := indexer.Search(ctx, query.MustCompile(`account.number >= 1`)) assert.NoError(t, err) diff --git a/internal/state/rollback_test.go b/internal/state/rollback_test.go index 7c2ba54b8..33d76e228 100644 --- a/internal/state/rollback_test.go +++ b/internal/state/rollback_test.go @@ -125,7 +125,7 @@ func TestRollbackDifferentStateHeight(t *testing.T) { func setupStateStore(t *testing.T, height int64) state.Store { stateStore := state.NewStore(dbm.NewMemDB()) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() valSet, _ := factory.ValidatorSet(ctx, t, 5, 10) diff --git a/internal/state/state_test.go b/internal/state/state_test.go index c30b14636..5e39cd949 100644 --- a/internal/state/state_test.go +++ b/internal/state/state_test.go @@ -316,7 +316,7 @@ func TestOneValidatorChangesSaveLoad(t *testing.T) { } func TestProposerFrequency(t *testing.T) { - 
ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() // some explicit test cases diff --git a/internal/state/store_test.go b/internal/state/store_test.go index f6982b5bd..d2de75553 100644 --- a/internal/state/store_test.go +++ b/internal/state/store_test.go @@ -28,7 +28,7 @@ const ( func TestStoreBootstrap(t *testing.T) { stateDB := dbm.NewMemDB() stateStore := sm.NewStore(stateDB) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() val, _, err := factory.Validator(ctx, 10+int64(rand.Uint32())) require.NoError(t, err) @@ -58,7 +58,7 @@ func TestStoreBootstrap(t *testing.T) { func TestStoreLoadValidators(t *testing.T) { stateDB := dbm.NewMemDB() stateStore := sm.NewStore(stateDB) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() val, _, err := factory.Validator(ctx, 10+int64(rand.Uint32())) require.NoError(t, err) @@ -148,7 +148,7 @@ func BenchmarkLoadValidators(b *testing.B) { } func TestStoreLoadConsensusParams(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() stateDB := dbm.NewMemDB() diff --git a/internal/state/validation_test.go b/internal/state/validation_test.go index 9a4afd8ad..fb746cee7 100644 --- a/internal/state/validation_test.go +++ b/internal/state/validation_test.go @@ -31,7 +31,7 @@ import ( const validationTestsStopHeight int64 = 10 func TestValidateBlockHeader(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() proxyApp := proxy.New(abciclient.NewLocalClient(logger, &testApp{}), logger, proxy.NopMetrics()) @@ -138,7 +138,7 @@ func TestValidateBlockHeader(t *testing.T) { } func TestValidateBlockCommit(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -288,7 +288,7 @@ func TestValidateBlockCommit(t *testing.T) { } func TestValidateBlockEvidence(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() diff --git a/internal/statesync/block_queue_test.go b/internal/statesync/block_queue_test.go index 364a7f5b2..8f6cdcd65 100644 --- a/internal/statesync/block_queue_test.go +++ b/internal/statesync/block_queue_test.go @@ -23,7 +23,7 @@ var ( ) func TestBlockQueueBasic(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() peerID, err := types.NewNodeID("0011223344556677889900112233445566778899") @@ -73,7 +73,7 @@ loop: // Test with spurious failures and retries func TestBlockQueueWithFailures(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() peerID, err := types.NewNodeID("0011223344556677889900112233445566778899") @@ -132,7 +132,7 @@ func TestBlockQueueBlocks(t *testing.T) { expectedHeight := startHeight retryHeight := stopHeight + 2 - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() loop: @@ -181,7 +181,7 @@ func TestBlockQueueAcceptsNoMoreBlocks(t *testing.T) { queue := newBlockQueue(startHeight, stopHeight, 
1, stopTime, 1) defer queue.close() - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() loop: @@ -210,7 +210,7 @@ func TestBlockQueueStopTime(t *testing.T) { queue := newBlockQueue(startHeight, stopHeight, 1, stopTime, 1) wg := &sync.WaitGroup{} - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() baseTime := stopTime.Add(-50 * time.Second) @@ -257,7 +257,7 @@ func TestBlockQueueInitialHeight(t *testing.T) { queue := newBlockQueue(startHeight, stopHeight, initialHeight, stopTime, 1) wg := &sync.WaitGroup{} - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() // asynchronously fetch blocks and add it to the queue diff --git a/internal/statesync/reactor_test.go b/internal/statesync/reactor_test.go index f73561e41..bb10df6fe 100644 --- a/internal/statesync/reactor_test.go +++ b/internal/statesync/reactor_test.go @@ -184,7 +184,7 @@ func setup( } func TestReactor_Sync(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + ctx, cancel := context.WithTimeout(t.Context(), 2*time.Minute) defer cancel() const snapshotHeight = 7 @@ -238,7 +238,7 @@ func TestReactor_Sync(t *testing.T) { } func TestReactor_ChunkRequest_InvalidRequest(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() rts := setup(ctx, t, nil, nil, 2) @@ -284,7 +284,7 @@ func TestReactor_ChunkRequest(t *testing.T) { }, } - bctx, bcancel := context.WithCancel(context.Background()) + bctx, bcancel := context.WithCancel(t.Context()) defer bcancel() for name, tc := range testcases { @@ -318,7 +318,7 @@ func TestReactor_ChunkRequest(t *testing.T) { } func TestReactor_SnapshotsRequest_InvalidRequest(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() rts := setup(ctx, t, nil, nil, 2) @@ -371,7 +371,7 @@ func TestReactor_SnapshotsRequest(t *testing.T) { }, }, } - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() for name, tc := range testcases { @@ -412,7 +412,7 @@ func TestReactor_SnapshotsRequest(t *testing.T) { } func TestReactor_LightBlockResponse(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() rts := setup(ctx, t, nil, nil, 2) @@ -470,7 +470,7 @@ func TestReactor_LightBlockResponse(t *testing.T) { } func TestReactor_BlockProviders(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() rts := setup(ctx, t, nil, nil, 2) @@ -537,7 +537,7 @@ func TestReactor_BlockProviders(t *testing.T) { } func TestReactor_StateProviderP2P(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() rts := setup(ctx, t, nil, nil, 3) @@ -633,7 +633,7 @@ func TestReactor_StateProviderP2P(t *testing.T) { } func TestReactor_Backfill(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() // test backfill algorithm with varying failure rates [0, 10] diff --git a/internal/statesync/syncer_test.go b/internal/statesync/syncer_test.go index 3fc3f0db4..ecac38aba 
100644 --- a/internal/statesync/syncer_test.go +++ b/internal/statesync/syncer_test.go @@ -22,7 +22,7 @@ import ( ) func TestSyncer_SyncAny(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() state := sm.State{ @@ -223,7 +223,7 @@ func TestSyncer_SyncAny_noSnapshots(t *testing.T) { stateProvider := &mocks.StateProvider{} stateProvider.On("AppHash", mock.Anything, mock.Anything).Return([]byte("app_hash"), nil) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() rts := setup(ctx, t, nil, stateProvider, 2) @@ -236,7 +236,7 @@ func TestSyncer_SyncAny_abort(t *testing.T) { stateProvider := &mocks.StateProvider{} stateProvider.On("AppHash", mock.Anything, mock.Anything).Return([]byte("app_hash"), nil) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() rts := setup(ctx, t, nil, stateProvider, 2) @@ -260,7 +260,7 @@ func TestSyncer_SyncAny_reject(t *testing.T) { stateProvider := &mocks.StateProvider{} stateProvider.On("AppHash", mock.Anything, mock.Anything).Return([]byte("app_hash"), nil) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() rts := setup(ctx, t, nil, stateProvider, 2) @@ -302,7 +302,7 @@ func TestSyncer_SyncAny_reject_format(t *testing.T) { stateProvider := &mocks.StateProvider{} stateProvider.On("AppHash", mock.Anything, mock.Anything).Return([]byte("app_hash"), nil) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() rts := setup(ctx, t, nil, stateProvider, 2) @@ -340,7 +340,7 @@ func TestSyncer_SyncAny_reject_sender(t *testing.T) { stateProvider := &mocks.StateProvider{} stateProvider.On("AppHash", mock.Anything, mock.Anything).Return([]byte("app_hash"), nil) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() rts := setup(ctx, t, nil, stateProvider, 2) @@ -389,7 +389,7 @@ func TestSyncer_SyncAny_abciError(t *testing.T) { stateProvider := &mocks.StateProvider{} stateProvider.On("AppHash", mock.Anything, mock.Anything).Return([]byte("app_hash"), nil) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() rts := setup(ctx, t, nil, stateProvider, 2) @@ -430,7 +430,7 @@ func TestSyncer_offerSnapshot(t *testing.T) { "unknown non-zero": {9, nil, unknownErr}, } - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() for name, tc := range testcases { @@ -483,7 +483,7 @@ func TestSyncer_applyChunks_Results(t *testing.T) { "unknown non-zero": {9, nil, unknownErr}, } - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() for name, tc := range testcases { @@ -543,7 +543,7 @@ func TestSyncer_applyChunks_RefetchChunks(t *testing.T) { "retry_snapshot": {abci.ResponseApplySnapshotChunk_RETRY_SNAPSHOT}, "reject_snapshot": {abci.ResponseApplySnapshotChunk_REJECT_SNAPSHOT}, } - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() for name, tc := range testcases { @@ -614,7 +614,7 @@ func TestSyncer_applyChunks_RejectSenders(t *testing.T) { "retry_snapshot": {abci.ResponseApplySnapshotChunk_RETRY_SNAPSHOT}, 
"reject_snapshot": {abci.ResponseApplySnapshotChunk_REJECT_SNAPSHOT}, } - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() for name, tc := range testcases { @@ -750,7 +750,7 @@ func TestSyncer_verifyApp(t *testing.T) { }, nil, errVerifyFailed}, "error": {nil, boom, boom}, } - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() for name, tc := range testcases { diff --git a/libs/cli/setup_test.go b/libs/cli/setup_test.go index 9198485ef..c7d2f05a1 100644 --- a/libs/cli/setup_test.go +++ b/libs/cli/setup_test.go @@ -18,7 +18,7 @@ import ( ) func TestSetupEnv(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() cases := []struct { @@ -73,7 +73,7 @@ func writeConfigVals(dir string, vals map[string]string) error { } func TestSetupConfig(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() // we pre-create two config files we can refer to in the rest of @@ -134,7 +134,7 @@ type DemoConfig struct { } func TestSetupUnmarshal(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() // we pre-create two config files we can refer to in the rest of @@ -208,7 +208,7 @@ func TestSetupUnmarshal(t *testing.T) { } func TestSetupTrace(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() cases := []struct { diff --git a/libs/events/events_test.go b/libs/events/events_test.go index 17f8c56d1..d21080ca1 100644 --- a/libs/events/events_test.go +++ b/libs/events/events_test.go @@ -14,7 +14,7 @@ import ( // TestAddListenerForEventFireOnce sets up an EventSwitch, subscribes a single // listener to an event, and sends a string "data". func TestAddListenerForEventFireOnce(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() evsw := NewEventSwitch() @@ -39,7 +39,7 @@ func TestAddListenerForEventFireOnce(t *testing.T) { // TestAddListenerForEventFireMany sets up an EventSwitch, subscribes a single // listener to an event, and sends a thousand integers. func TestAddListenerForEventFireMany(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() evsw := NewEventSwitch() @@ -73,7 +73,7 @@ func TestAddListenerForEventFireMany(t *testing.T) { // listener to three different events and sends a thousand integers for each // of the three events. func TestAddListenerForDifferentEvents(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() t.Cleanup(leaktest.Check(t)) @@ -135,7 +135,7 @@ func TestAddListenerForDifferentEvents(t *testing.T) { // listener to two of those three events, and then sends a thousand integers // for each of the three events. 
func TestAddDifferentListenerForDifferentEvents(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() t.Cleanup(leaktest.Check(t)) @@ -229,7 +229,7 @@ func TestAddDifferentListenerForDifferentEvents(t *testing.T) { // NOTE: it is important to run this test with race conditions tracking on, // `go test -race`, to examine for possible race conditions. func TestManageListenersAsync(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() evsw := NewEventSwitch() diff --git a/libs/service/service_test.go b/libs/service/service_test.go index d0b8ce57e..65f947ac2 100644 --- a/libs/service/service_test.go +++ b/libs/service/service_test.go @@ -57,7 +57,7 @@ func (t *testService) isMultiStopped() bool { func TestBaseService(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -114,7 +114,7 @@ func TestBaseService(t *testing.T) { require.False(t, ts.isMultiStopped()) }) t.Run("MultiThreaded", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() ts := &testService{} diff --git a/light/client_benchmark_test.go b/light/client_benchmark_test.go index 1888df795..b2b478f77 100644 --- a/light/client_benchmark_test.go +++ b/light/client_benchmark_test.go @@ -66,7 +66,7 @@ func (impl *providerBenchmarkImpl) ReportEvidence(_ context.Context, _ types.Evi func (impl *providerBenchmarkImpl) ID() string { return "ip-not-defined.com" } func BenchmarkSequence(b *testing.B) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() headers, vals, _ := genLightBlocksWithKeys(b, 1000, 100, 1, bTime) @@ -104,7 +104,7 @@ func BenchmarkSequence(b *testing.B) { } func BenchmarkBisection(b *testing.B) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() headers, vals, _ := genLightBlocksWithKeys(b, 1000, 100, 1, bTime) @@ -114,7 +114,7 @@ func BenchmarkBisection(b *testing.B) { logger := log.NewTestingLogger(b) c, err := light.NewClient( - context.Background(), + t.Context(), chainID, light.TrustOptions{ Period: 24 * time.Hour, @@ -141,7 +141,7 @@ func BenchmarkBisection(b *testing.B) { } func BenchmarkBackwards(b *testing.B) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() headers, vals, _ := genLightBlocksWithKeys(b, 1000, 100, 1, bTime) diff --git a/light/client_test.go b/light/client_test.go index 43e8a9ebc..55f0b877a 100644 --- a/light/client_test.go +++ b/light/client_test.go @@ -37,7 +37,7 @@ func init() { } func TestClient(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() var ( keys = genPrivKeys(4) @@ -228,7 +228,7 @@ func TestClient(t *testing.T) { for _, tc := range testCases { testCase := tc t.Run(testCase.name, func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -351,7 +351,7 @@ func TestClient(t *testing.T) { }, } - bctx, bcancel := context.WithCancel(context.Background()) + bctx, bcancel := 
context.WithCancel(t.Context()) defer bcancel() for _, tc := range testCases { @@ -411,7 +411,7 @@ func TestClient(t *testing.T) { mockNode.On("LightBlock", mock.Anything, int64(0)).Return(lastBlock, nil) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() trustedLightBlock, err := mockNode.LightBlock(ctx, int64(200)) @@ -442,7 +442,7 @@ func TestClient(t *testing.T) { mockNode.AssertExpectations(t) }) t.Run("BisectionBetweenTrustedHeaders", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() mockFullNode := mockNodeFromHeadersAndVals(headerSet, valSet) @@ -475,7 +475,7 @@ func TestClient(t *testing.T) { mockFullNode.AssertExpectations(t) }) t.Run("Cleanup", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -507,7 +507,7 @@ func TestClient(t *testing.T) { t.Run("RestoresTrustedHeaderAfterStartup", func(t *testing.T) { // trustedHeader.Height == options.Height - bctx, bcancel := context.WithCancel(context.Background()) + bctx, bcancel := context.WithCancel(t.Context()) defer bcancel() // 1. options.Hash == trustedHeader.Hash @@ -585,7 +585,7 @@ func TestClient(t *testing.T) { }) }) t.Run("Update", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() mockFullNode := &provider_mocks.Provider{} @@ -618,7 +618,7 @@ func TestClient(t *testing.T) { }) t.Run("Concurrency", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -667,7 +667,7 @@ func TestClient(t *testing.T) { mockFullNode.AssertExpectations(t) }) t.Run("AddProviders", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() mockFullNode := mockNodeFromHeadersAndVals(map[int64]*types.SignedHeader{ @@ -707,7 +707,7 @@ func TestClient(t *testing.T) { mockFullNode.AssertExpectations(t) }) t.Run("ReplacesPrimaryWithWitnessIfPrimaryIsUnavailable", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() mockFullNode := &provider_mocks.Provider{} @@ -744,7 +744,7 @@ func TestClient(t *testing.T) { mockFullNode.AssertExpectations(t) }) t.Run("TerminatesWitnessSearchAfterContextDeadlineExpires", func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), time.Duration(1*time.Second)) + ctx, cancel := context.WithTimeout(t.Context(), time.Duration(1*time.Second)) defer cancel() mockDeadNode := &provider_mocks.Provider{} @@ -771,7 +771,7 @@ func TestClient(t *testing.T) { mockSlowNode.AssertExpectations(t) }) t.Run("ReplacesPrimaryWithWitnessIfPrimaryDoesntHaveBlock", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() mockFullNode := &provider_mocks.Provider{} @@ -800,7 +800,7 @@ func TestClient(t *testing.T) { mockFullNode.AssertExpectations(t) }) t.Run("BackwardsVerification", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -950,7 +950,7 @@ func 
TestClient(t *testing.T) { mockFullNode := mockNodeFromHeadersAndVals(headerSet, valSet) mockFullNode.On("ID", mock.Anything, mock.Anything).Return(id3, nil) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() lb1, _ := mockBadNode1.LightBlock(ctx, 2) @@ -1020,7 +1020,7 @@ func TestClient(t *testing.T) { mockFullNode := mockNodeFromHeadersAndVals(headerSet, valSet) mockFullNode.On("ID", mock.Anything, mock.Anything).Return(id3, nil) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() lb1, _ := mockBadNode1.LightBlock(ctx, 2) @@ -1058,7 +1058,7 @@ func TestClient(t *testing.T) { mockBadNode2.AssertExpectations(t) }) t.Run("TrustedValidatorSet", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -1121,7 +1121,7 @@ func TestClient(t *testing.T) { 0: vals, }) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -1210,7 +1210,7 @@ func TestClient(t *testing.T) { for i, tc := range testCases { testCase := tc t.Run(fmt.Sprintf("case: %d", i), func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() mockBadNode := mockNodeFromHeadersAndVals(testCase.headers, testCase.vals) diff --git a/light/detector_test.go b/light/detector_test.go index 51585c46d..bc343108a 100644 --- a/light/detector_test.go +++ b/light/detector_test.go @@ -33,7 +33,7 @@ func TestLightClientAttackEvidence_Lunatic(t *testing.T) { primaryValidators = make(map[int64]*types.ValidatorSet, latestHeight) ) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() witnessHeaders, witnessValidators, chainKeys := genLightBlocksWithKeys(t, latestHeight, valSize, 2, bTime) @@ -135,7 +135,7 @@ func TestLightClientAttackEvidence_Equivocation(t *testing.T) { }, } - bctx, bcancel := context.WithCancel(context.Background()) + bctx, bcancel := context.WithCancel(t.Context()) defer bcancel() for _, tc := range cases { @@ -245,7 +245,7 @@ func TestLightClientAttackEvidence_ForwardLunatic(t *testing.T) { primaryValidators = make(map[int64]*types.ValidatorSet, forgedHeight) ) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -393,7 +393,7 @@ func TestLightClientAttackEvidence_ForwardLunatic(t *testing.T) { // => light client returns an error upon creation because primary and witness // have a different view. func TestClientDivergentTraces1(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() headers, vals, _ := genLightBlocksWithKeys(t, 1, 5, 2, bTime) @@ -429,7 +429,7 @@ func TestClientDivergentTraces1(t *testing.T) { // 2. 
Two out of three nodes don't respond but the third has a header that matches // => verification should be successful but two unresponsive witnesses should be blacklisted func TestClientDivergentTraces2(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -473,7 +473,7 @@ func TestClientDivergentTraces3(t *testing.T) { primaryHeaders, primaryVals, _ := genLightBlocksWithKeys(t, 2, 5, 2, bTime) mockPrimary := mockNodeFromHeadersAndVals(primaryHeaders, primaryVals) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() firstBlock, err := mockPrimary.LightBlock(ctx, 1) @@ -517,7 +517,7 @@ func TestClientDivergentTraces4(t *testing.T) { primaryHeaders, primaryVals, _ := genLightBlocksWithKeys(t, 2, 5, 2, bTime) mockPrimary := mockNodeFromHeadersAndVals(primaryHeaders, primaryVals) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() firstBlock, err := mockPrimary.LightBlock(ctx, 1) diff --git a/light/dispatcher_test.go b/light/dispatcher_test.go index 6fc21fb0d..f0594244f 100644 --- a/light/dispatcher_test.go +++ b/light/dispatcher_test.go @@ -39,7 +39,7 @@ func TestDispatcherBasic(t *testing.T) { t.Cleanup(leaktest.Check(t)) const numPeers = 5 - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() chans, ch := testChannel(100) @@ -75,7 +75,7 @@ func TestDispatcherBasic(t *testing.T) { func TestDispatcherReturnsNoBlock(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() chans, ch := testChannel(100) @@ -104,7 +104,7 @@ func TestDispatcherReturnsNoBlock(t *testing.T) { func TestDispatcherTimeOutWaitingOnLightBlock(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() _, ch := testChannel(100) @@ -130,7 +130,7 @@ func TestDispatcherProviders(t *testing.T) { chainID := "test-chain" - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() chans, ch := testChannel(100) @@ -160,7 +160,7 @@ func TestDispatcherProviders(t *testing.T) { func TestPeerListBasic(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() peerList := NewPeerList() @@ -207,7 +207,7 @@ func TestPeerListBlocksWhenEmpty(t *testing.T) { peerList := NewPeerList() require.Zero(t, peerList.Len()) doneCh := make(chan struct{}) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() go func() { peerList.Pop(ctx) @@ -226,7 +226,7 @@ func TestEmptyPeerListReturnsWhenContextCanceled(t *testing.T) { require.Zero(t, peerList.Len()) doneCh := make(chan struct{}) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() wrapped, cancel := context.WithCancel(ctx) @@ -251,7 +251,7 @@ func TestEmptyPeerListReturnsWhenContextCanceled(t *testing.T) { func TestPeerListConcurrent(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := 
context.WithCancel(t.Context()) defer cancel() peerList := NewPeerList() diff --git a/light/example_test.go b/light/example_test.go index 9858ca4d0..74b9f27b3 100644 --- a/light/example_test.go +++ b/light/example_test.go @@ -18,7 +18,7 @@ import ( // Manually getting light blocks and verifying them. func TestExampleClient(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() conf, err := rpctest.CreateConfig(t, "ExampleClient_VerifyLightBlockAtHeight") if err != nil { diff --git a/light/light_test.go b/light/light_test.go index f9cfc8b06..c7a877785 100644 --- a/light/light_test.go +++ b/light/light_test.go @@ -26,7 +26,7 @@ import ( func TestClientIntegration_Update(t *testing.T) { t.Parallel() - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() conf, err := rpctest.CreateConfig(t, t.Name()) require.NoError(t, err) @@ -90,7 +90,7 @@ func TestClientIntegration_Update(t *testing.T) { // Manually getting light blocks and verifying them. func TestClientIntegration_VerifyLightBlockAtHeight(t *testing.T) { t.Parallel() - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() conf, err := rpctest.CreateConfig(t, t.Name()) require.NoError(t, err) @@ -168,7 +168,7 @@ func waitForBlock(ctx context.Context, p provider.Provider, height int64) (*type } func TestClientStatusRPC(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() conf, err := rpctest.CreateConfig(t, t.Name()) require.NoError(t, err) diff --git a/light/provider/http/http_test.go b/light/provider/http/http_test.go index cb443caaf..8d6a04f71 100644 --- a/light/provider/http/http_test.go +++ b/light/provider/http/http_test.go @@ -33,7 +33,7 @@ func TestNewProvider(t *testing.T) { } func TestProvider(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() cfg, err := rpctest.CreateConfig(t, t.Name()) require.NoError(t, err) diff --git a/light/store/db/db_test.go b/light/store/db/db_test.go index cae9bbfc5..e647907c0 100644 --- a/light/store/db/db_test.go +++ b/light/store/db/db_test.go @@ -19,7 +19,7 @@ import ( func TestLast_FirstLightBlockHeight(t *testing.T) { dbStore := New(dbm.NewMemDB()) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() // Empty store @@ -46,7 +46,7 @@ func TestLast_FirstLightBlockHeight(t *testing.T) { func Test_SaveLightBlock(t *testing.T) { dbStore := New(dbm.NewMemDB()) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() // Empty store @@ -78,7 +78,7 @@ func Test_SaveLightBlock(t *testing.T) { func Test_LightBlockBefore(t *testing.T) { dbStore := New(dbm.NewMemDB()) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() assert.Panics(t, func() { @@ -101,7 +101,7 @@ func Test_LightBlockBefore(t *testing.T) { func Test_Prune(t *testing.T) { dbStore := New(dbm.NewMemDB()) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() // Empty store @@ -141,7 +141,7 @@ func Test_Prune(t *testing.T) { func Test_Concurrency(t *testing.T) { dbStore := New(dbm.NewMemDB()) - ctx, 
cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() var wg sync.WaitGroup diff --git a/node/node_test.go b/node/node_test.go index 38dc95380..b90be8ff7 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -45,7 +45,7 @@ func TestNodeStartStop(t *testing.T) { defer os.RemoveAll(cfg.RootDir) - ctx, bcancel := context.WithCancel(context.Background()) + ctx, bcancel := context.WithCancel(t.Context()) defer bcancel() logger := log.NewNopLogger() @@ -112,7 +112,7 @@ func TestNodeDelayedStart(t *testing.T) { defer os.RemoveAll(cfg.RootDir) now := tmtime.Now() - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -132,7 +132,7 @@ func TestNodeSetAppVersion(t *testing.T) { require.NoError(t, err) defer os.RemoveAll(cfg.RootDir) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -158,7 +158,7 @@ func TestNodeSetPrivValTCP(t *testing.T) { addr := "tcp://" + testFreeAddr(t) t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -195,7 +195,7 @@ func TestNodeSetPrivValTCP(t *testing.T) { // address without a protocol must result in error func TestPrivValidatorListenAddrNoProtocol(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() addrNoPrefix := testFreeAddr(t) @@ -221,7 +221,7 @@ func TestNodeSetPrivValIPC(t *testing.T) { tmpfile := "/tmp/kms." + tmrand.Str(6) + ".sock" defer os.Remove(tmpfile) // clean up - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() cfg, err := config.ResetTestRoot(t.TempDir(), "node_priv_val_tcp_test") @@ -268,7 +268,7 @@ func testFreeAddr(t *testing.T) string { // create a proposal block using real and full // mempool and evidence pool and validate it. 
func TestCreateProposalBlock(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() cfg, err := config.ResetTestRoot(t.TempDir(), "node_create_proposal") @@ -369,7 +369,7 @@ func TestCreateProposalBlock(t *testing.T) { } func TestMaxTxsProposalBlockSize(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() cfg, err := config.ResetTestRoot(t.TempDir(), "node_create_proposal") @@ -443,7 +443,7 @@ func TestMaxTxsProposalBlockSize(t *testing.T) { } func TestMaxProposalBlockSize(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() cfg, err := config.ResetTestRoot(t.TempDir(), "node_create_proposal") @@ -592,7 +592,7 @@ func TestNodeNewSeedNode(t *testing.T) { cfg.Mode = config.ModeSeed defer os.RemoveAll(cfg.RootDir) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() nodeKey, err := types.LoadOrGenNodeKey(cfg.NodeKeyFile()) @@ -634,7 +634,7 @@ func TestNodeSetEventSink(t *testing.T) { defer os.RemoveAll(cfg.RootDir) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -756,7 +756,7 @@ func state(t *testing.T, nVals int, height int64) (sm.State, dbm.DB, []types.Pri } func TestLoadStateFromGenesis(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() _ = loadStatefromGenesis(ctx, t) diff --git a/privval/file_test.go b/privval/file_test.go index a49bb321c..f824a4d9c 100644 --- a/privval/file_test.go +++ b/privval/file_test.go @@ -36,7 +36,7 @@ func TestGenLoadValidator(t *testing.T) { } func TestResetValidator(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() privVal, _, tempStateFileName := newTestFilePV(t) @@ -146,7 +146,7 @@ func TestUnmarshalValidatorKey(t *testing.T) { } func TestSignVote(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() privVal, _, _ := newTestFilePV(t) @@ -195,7 +195,7 @@ func TestSignVote(t *testing.T) { } func TestSignProposal(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() privVal, _, _ := newTestFilePV(t) @@ -237,7 +237,7 @@ func TestSignProposal(t *testing.T) { } func TestDifferByTimestamp(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() tempKeyFile, err := os.CreateTemp(t.TempDir(), "priv_validator_key_") @@ -278,7 +278,7 @@ func TestDifferByTimestamp(t *testing.T) { } func TestVoteExtensionsAreAlwaysSigned(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() privVal, _, _ := newTestFilePV(t) diff --git a/privval/grpc/client_test.go b/privval/grpc/client_test.go index bde66890e..827303a43 100644 --- a/privval/grpc/client_test.go +++ b/privval/grpc/client_test.go @@ -41,7 +41,7 @@ func dialer(t *testing.T, pv types.PrivValidator, logger log.Logger) (*grpc.Serv func TestSignerClient_GetPubKey(t *testing.T) { - ctx, cancel := 
context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() mockPV := types.NewMockPV() @@ -65,7 +65,7 @@ func TestSignerClient_GetPubKey(t *testing.T) { } func TestSignerClient_SignVote(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() mockPV := types.NewMockPV() @@ -120,7 +120,7 @@ func TestSignerClient_SignVote(t *testing.T) { } func TestSignerClient_SignProposal(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() mockPV := types.NewMockPV() diff --git a/privval/grpc/server_test.go b/privval/grpc/server_test.go index ad63b87a6..9e80b7534 100644 --- a/privval/grpc/server_test.go +++ b/privval/grpc/server_test.go @@ -33,7 +33,7 @@ func TestGetPubKey(t *testing.T) { for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewTestingLogger(t) @@ -108,7 +108,7 @@ func TestSignVote(t *testing.T) { for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewTestingLogger(t) @@ -179,7 +179,7 @@ func TestSignProposal(t *testing.T) { for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewTestingLogger(t) diff --git a/privval/signer_client_test.go b/privval/signer_client_test.go index fef16f2d6..2c5ade79b 100644 --- a/privval/signer_client_test.go +++ b/privval/signer_client_test.go @@ -65,7 +65,7 @@ func getSignerTestCases(ctx context.Context, t *testing.T, logger log.Logger) [] func TestSignerClose(t *testing.T) { t.Cleanup(leaktest.Check(t)) - bctx, bcancel := context.WithCancel(context.Background()) + bctx, bcancel := context.WithCancel(t.Context()) defer bcancel() logger := log.NewNopLogger() @@ -88,7 +88,7 @@ func TestSignerClose(t *testing.T) { func TestSignerPing(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -102,7 +102,7 @@ func TestSignerPing(t *testing.T) { func TestSignerGetPubKey(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -132,7 +132,7 @@ func TestSignerGetPubKey(t *testing.T) { func TestSignerProposal(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -172,7 +172,7 @@ func TestSignerProposal(t *testing.T) { func TestSignerVote(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -215,7 +215,7 @@ func TestSignerVote(t *testing.T) { func TestSignerVoteResetDeadline(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) 
defer cancel() logger := log.NewNopLogger() @@ -266,7 +266,7 @@ func TestSignerVoteResetDeadline(t *testing.T) { func TestSignerVoteKeepAlive(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -316,7 +316,7 @@ func TestSignerVoteKeepAlive(t *testing.T) { func TestSignerSignProposalErrors(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -357,7 +357,7 @@ func TestSignerSignProposalErrors(t *testing.T) { func TestSignerSignVoteErrors(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -423,7 +423,7 @@ func brokenHandler(ctx context.Context, privVal types.PrivValidator, request pri func TestSignerUnexpectedResponse(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() diff --git a/privval/signer_listener_endpoint_test.go b/privval/signer_listener_endpoint_test.go index 6049c6245..21f177013 100644 --- a/privval/signer_listener_endpoint_test.go +++ b/privval/signer_listener_endpoint_test.go @@ -42,7 +42,7 @@ func TestSignerRemoteRetryTCPOnly(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() @@ -94,7 +94,7 @@ func TestSignerRemoteRetryTCPOnly(t *testing.T) { func TestRetryConnToRemoteSigner(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewNopLogger() diff --git a/rpc/client/eventstream/eventstream_test.go b/rpc/client/eventstream/eventstream_test.go index 8cd9df30f..95f668272 100644 --- a/rpc/client/eventstream/eventstream_test.go +++ b/rpc/client/eventstream/eventstream_test.go @@ -101,7 +101,7 @@ func TestMinPollTime(t *testing.T) { // Waiting for an item on a log with no matching events incurs a minimum // wait time and reports no events. - ctx := context.Background() + ctx := t.Context() filter := &coretypes.EventFilter{Query: `tm.event = 'good'`} t.Run("NoneMatch", func(t *testing.T) { @@ -198,7 +198,7 @@ func newStreamTester(t *testing.T, query string, logOpts eventlog.LogSettings, s // start starts the stream receiver, which runs until it it terminated by // calling stop. 
func (s *streamTester) start() { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) s.errc = make(chan error, 1) s.recv = make(chan *coretypes.EventItem) s.stop = cancel diff --git a/rpc/client/examples_test.go b/rpc/client/examples_test.go index 163093c84..be514920a 100644 --- a/rpc/client/examples_test.go +++ b/rpc/client/examples_test.go @@ -18,7 +18,7 @@ import ( ) func TestHTTPSimple(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() // Start a tendermint node (and kvstore) in the background to test against @@ -68,7 +68,7 @@ func TestHTTPSimple(t *testing.T) { } func TestHTTPBatching(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() // Start a tendermint node (and kvstore) in the background to test against diff --git a/rpc/client/helpers_test.go b/rpc/client/helpers_test.go index a66becbd5..a521ea829 100644 --- a/rpc/client/helpers_test.go +++ b/rpc/client/helpers_test.go @@ -15,7 +15,7 @@ import ( ) func TestWaitForHeight(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() // test with error result - immediate failure diff --git a/rpc/client/mock/abci_test.go b/rpc/client/mock/abci_test.go index e35ccf29c..d218cbc75 100644 --- a/rpc/client/mock/abci_test.go +++ b/rpc/client/mock/abci_test.go @@ -19,7 +19,7 @@ import ( ) func TestABCIMock(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() key, value := []byte("foo"), []byte("bar") @@ -80,7 +80,7 @@ func TestABCIMock(t *testing.T) { } func TestABCIRecorder(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() // This mock returns errors on everything but Query @@ -165,7 +165,7 @@ func TestABCIApp(t *testing.T) { app := kvstore.NewApplication() m := mock.ABCIApp{app} - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() // get some info diff --git a/rpc/client/mock/status_test.go b/rpc/client/mock/status_test.go index fb70ca9d9..055a99c6a 100644 --- a/rpc/client/mock/status_test.go +++ b/rpc/client/mock/status_test.go @@ -14,7 +14,7 @@ import ( ) func TestStatus(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() m := &mock.StatusMock{ diff --git a/rpc/client/rpc_test.go b/rpc/client/rpc_test.go index 2316b5f58..c8ffaead4 100644 --- a/rpc/client/rpc_test.go +++ b/rpc/client/rpc_test.go @@ -40,7 +40,7 @@ func getHTTPClient(t *testing.T, logger log.Logger, conf *config.Config) *rpchtt rpcAddr := conf.RPC.ListenAddress c, err := rpchttp.NewWithClient(rpcAddr, http.DefaultClient) require.NoError(t, err) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) require.NoError(t, c.Start(ctx)) c.Logger = logger @@ -60,7 +60,7 @@ func getHTTPClientWithTimeout(t *testing.T, logger log.Logger, conf *config.Conf tclient := &http.Client{Timeout: timeout} c, err := rpchttp.NewWithClient(rpcAddr, tclient) require.NoError(t, err) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) require.NoError(t, c.Start(ctx)) 
c.Logger = logger @@ -90,7 +90,7 @@ func GetClients(t *testing.T, ns service.Service, conf *config.Config) []client. } func TestClientOperations(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewTestingLogger(t) @@ -189,7 +189,7 @@ func TestClientOperations(t *testing.T) { // Make sure info is correct (we connect properly) func TestClientMethodCalls(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewTestingLogger(t) @@ -409,7 +409,7 @@ func TestClientMethodCalls(t *testing.T) { // XXX Test proof }) t.Run("BlockchainInfo", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() err := client.WaitForHeight(ctx, c, 10, nil) @@ -523,7 +523,7 @@ func TestClientMethodCalls(t *testing.T) { }) t.Run("Evidence", func(t *testing.T) { t.Run("BroadcastDuplicateVote", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() chainID := conf.ChainID() @@ -589,7 +589,7 @@ func getMempool(t *testing.T, srv service.Service) mempool.Mempool { // so making a separate suite makes more sense, though isn't strictly // speaking desirable. func TestClientMethodCallsAdvanced(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() logger := log.NewTestingLogger(t) diff --git a/rpc/jsonrpc/client/integration_test.go b/rpc/jsonrpc/client/integration_test.go index f53b28802..5e7023ef8 100644 --- a/rpc/jsonrpc/client/integration_test.go +++ b/rpc/jsonrpc/client/integration_test.go @@ -23,7 +23,7 @@ func TestWSClientReconnectWithJitter(t *testing.T) { const maxReconnectAttempts = 3 const maxSleepTime = time.Duration(((1< Date: Mon, 18 Aug 2025 16:07:35 +0200 Subject: [PATCH 02/41] almost --- internal/mempool/mempool_bench_test.go | 2 +- internal/p2p/address_test.go | 60 +++++++++++-------- internal/proxy/client_test.go | 2 +- internal/rpc/core/blocks_test.go | 1 - internal/state/indexer/tx/kv/kv_bench_test.go | 1 - libs/os/os.go | 2 +- light/client_benchmark_test.go | 11 +--- rpc/client/eventstream/eventstream_test.go | 14 +++-- rpc/jsonrpc/server/http_json_handler.go | 2 +- rpc/jsonrpc/server/rpc_func.go | 4 +- rpc/jsonrpc/server/ws_handler.go | 2 +- scripts/keymigrate/migrate_test.go | 1 - types/proposal_test.go | 9 ++- types/validator_set_test.go | 20 +++---- types/vote_test.go | 2 +- 15 files changed, 67 insertions(+), 66 deletions(-) diff --git a/internal/mempool/mempool_bench_test.go b/internal/mempool/mempool_bench_test.go index a16b0a8e9..e1d1d5955 100644 --- a/internal/mempool/mempool_bench_test.go +++ b/internal/mempool/mempool_bench_test.go @@ -15,7 +15,7 @@ import ( ) func BenchmarkTxMempool_CheckTx(b *testing.B) { - ctx, cancel := context.WithCancel(t.Context()) + ctx, cancel := context.WithCancel(b.Context()) defer cancel() client := abciclient.NewLocalClient(log.NewNopLogger(), kvstore.NewApplication()) diff --git a/internal/p2p/address_test.go b/internal/p2p/address_test.go index f6681a399..468c046b2 100644 --- a/internal/p2p/address_test.go +++ b/internal/p2p/address_test.go @@ -1,7 +1,6 @@ package p2p_test import ( - "context" "net" "strings" "testing" @@ -28,7 +27,6 @@ func TestNewNodeID(t *testing.T) { {"00112233445566778899aabbccddeeff0011223g", 
"", false}, } for _, tc := range testcases { - tc := tc t.Run(tc.input, func(t *testing.T) { id, err := types.NewNodeID(tc.input) if !tc.ok { @@ -61,7 +59,6 @@ func TestNodeID_Bytes(t *testing.T) { {"01g0", nil, false}, } for _, tc := range testcases { - tc := tc t.Run(string(tc.nodeID), func(t *testing.T) { bz, err := tc.nodeID.Bytes() if tc.ok { @@ -87,7 +84,6 @@ func TestNodeID_Validate(t *testing.T) { {"00112233445566778899AABBCCDDEEFF00112233", false}, } for _, tc := range testcases { - tc := tc t.Run(string(tc.nodeID), func(t *testing.T) { err := tc.nodeID.Validate() if tc.ok { @@ -189,7 +185,6 @@ func TestParseNodeAddress(t *testing.T) { {"mconn://" + user + "@:80", p2p.NodeAddress{}, false}, } for _, tc := range testcases { - tc := tc t.Run(tc.url, func(t *testing.T) { address, err := p2p.ParseNodeAddress(tc.url) if !tc.ok { @@ -205,9 +200,6 @@ func TestParseNodeAddress(t *testing.T) { func TestNodeAddress_Resolve(t *testing.T) { id := types.NodeID("00112233445566778899aabbccddeeff00112233") - bctx, bcancel := context.WithCancel(t.Context()) - defer bcancel() - testcases := []struct { address p2p.NodeAddress expect *p2p.Endpoint @@ -219,16 +211,6 @@ func TestNodeAddress_Resolve(t *testing.T) { &p2p.Endpoint{Protocol: "tcp", IP: net.IPv4(127, 0, 0, 1), Port: 80, Path: "/path"}, true, }, - { - p2p.NodeAddress{Protocol: "tcp", Hostname: "localhost", Port: 80, Path: "/path"}, - &p2p.Endpoint{Protocol: "tcp", IP: net.IPv4(127, 0, 0, 1), Port: 80, Path: "/path"}, - true, - }, - { - p2p.NodeAddress{Protocol: "tcp", Hostname: "localhost", Port: 80, Path: "/path"}, - &p2p.Endpoint{Protocol: "tcp", IP: net.IPv6loopback, Port: 80, Path: "/path"}, - true, - }, { p2p.NodeAddress{Protocol: "tcp", Hostname: "127.0.0.1"}, &p2p.Endpoint{Protocol: "tcp", IP: net.IPv4(127, 0, 0, 1)}, @@ -277,19 +259,49 @@ func TestNodeAddress_Resolve(t *testing.T) { {p2p.NodeAddress{Protocol: "tcp", Hostname: "💥"}, &p2p.Endpoint{}, false}, } for _, tc := range testcases { - tc := tc t.Run(tc.address.String(), func(t *testing.T) { - ctx, cancel := context.WithCancel(bctx) - defer cancel() - - endpoints, err := tc.address.Resolve(ctx) + endpoints, err := tc.address.Resolve(t.Context()) if !tc.ok { require.Error(t, err) return } + + // Special handling for localhost tests - accept either IPv4 or IPv6 + if tc.address.Hostname == "localhost" && tc.address.Port == 80 && tc.address.Path == "/path" { + hasIPv4 := false + hasIPv6 := false + for _, ep := range endpoints { + if ep.Protocol == "tcp" && ep.Port == 80 && ep.Path == "/path" { + if ep.IP.Equal(net.IPv4(127, 0, 0, 1)) { + hasIPv4 = true + } + if ep.IP.Equal(net.IPv6loopback) { + hasIPv6 = true + } + } + } + require.True(t, hasIPv4 || hasIPv6, "localhost should resolve to either IPv4 or IPv6") + return + } + require.Contains(t, endpoints, tc.expect) }) } + t.Run("Resolve localhost", func(t *testing.T) { + addr := p2p.NodeAddress{Protocol: "tcp", Hostname: "localhost", Port: 80, Path: "/path"} + endpoints, err := addr.Resolve(t.Context()) + require.NoError(t, err) + + want := []*p2p.Endpoint{ + {Protocol: "tcp", IP: net.IPv6loopback, Port: 80, Path: "/path"}, + {Protocol: "tcp", IP: net.IPv4(127, 0, 0, 1), Port: 80, Path: "/path"}, + } + + require.True(t, len(endpoints)>0) + for _, got := range endpoints { + require.Contains(t, want, got) + } + }) } func TestNodeAddress_String(t *testing.T) { @@ -348,7 +360,6 @@ func TestNodeAddress_String(t *testing.T) { }, } for _, tc := range testcases { - tc := tc t.Run(tc.address.String(), func(t *testing.T) { require.Equal(t, 
tc.expect, tc.address.String()) }) @@ -375,7 +386,6 @@ func TestNodeAddress_Validate(t *testing.T) { {p2p.NodeAddress{Protocol: "mconn", NodeID: id, Port: 80, Path: "path"}, false}, } for _, tc := range testcases { - tc := tc t.Run(tc.address.String(), func(t *testing.T) { err := tc.address.Validate() if tc.ok { diff --git a/internal/proxy/client_test.go b/internal/proxy/client_test.go index d41b4604f..6d086b19d 100644 --- a/internal/proxy/client_test.go +++ b/internal/proxy/client_test.go @@ -105,7 +105,7 @@ func BenchmarkEcho(b *testing.B) { b.Fatal(err) } - ctx, cancel := context.WithCancel(t.Context()) + ctx, cancel := context.WithCancel(b.Context()) defer cancel() // Start server diff --git a/internal/rpc/core/blocks_test.go b/internal/rpc/core/blocks_test.go index 2b44e9098..06ecb6721 100644 --- a/internal/rpc/core/blocks_test.go +++ b/internal/rpc/core/blocks_test.go @@ -1,7 +1,6 @@ package core import ( - "context" "fmt" "testing" diff --git a/internal/state/indexer/tx/kv/kv_bench_test.go b/internal/state/indexer/tx/kv/kv_bench_test.go index 08c08e5ca..609772725 100644 --- a/internal/state/indexer/tx/kv/kv_bench_test.go +++ b/internal/state/indexer/tx/kv/kv_bench_test.go @@ -1,7 +1,6 @@ package kv import ( - "context" "crypto/rand" "fmt" "testing" diff --git a/libs/os/os.go b/libs/os/os.go index eb7cff38e..61733cbfe 100644 --- a/libs/os/os.go +++ b/libs/os/os.go @@ -82,7 +82,7 @@ func Kill() error { } func Exit(s string) { - fmt.Printf(s + "\n") + fmt.Println(s) os.Exit(1) } diff --git a/light/client_benchmark_test.go b/light/client_benchmark_test.go index b2b478f77..ef5ea0de8 100644 --- a/light/client_benchmark_test.go +++ b/light/client_benchmark_test.go @@ -66,8 +66,7 @@ func (impl *providerBenchmarkImpl) ReportEvidence(_ context.Context, _ types.Evi func (impl *providerBenchmarkImpl) ID() string { return "ip-not-defined.com" } func BenchmarkSequence(b *testing.B) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := b.Context() headers, vals, _ := genLightBlocksWithKeys(b, 1000, 100, 1, bTime) benchmarkFullNode := newProviderBenchmarkImpl(headers, vals) @@ -104,9 +103,6 @@ func BenchmarkSequence(b *testing.B) { } func BenchmarkBisection(b *testing.B) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() - headers, vals, _ := genLightBlocksWithKeys(b, 1000, 100, 1, bTime) benchmarkFullNode := newProviderBenchmarkImpl(headers, vals) genesisBlock, _ := benchmarkFullNode.LightBlock(ctx, 1) @@ -114,7 +110,7 @@ func BenchmarkBisection(b *testing.B) { logger := log.NewTestingLogger(b) c, err := light.NewClient( - t.Context(), + b.Context(), chainID, light.TrustOptions{ Period: 24 * time.Hour, @@ -141,8 +137,7 @@ func BenchmarkBisection(b *testing.B) { } func BenchmarkBackwards(b *testing.B) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := b.Context() headers, vals, _ := genLightBlocksWithKeys(b, 1000, 100, 1, bTime) benchmarkFullNode := newProviderBenchmarkImpl(headers, vals) diff --git a/rpc/client/eventstream/eventstream_test.go b/rpc/client/eventstream/eventstream_test.go index 95f668272..ddd05021a 100644 --- a/rpc/client/eventstream/eventstream_test.go +++ b/rpc/client/eventstream/eventstream_test.go @@ -19,6 +19,7 @@ import ( ) func TestStream_filterOrder(t *testing.T) { + ctx := t.Context() defer leaktest.Check(t) s := newStreamTester(t, `tm.event = 'good'`, eventlog.LogSettings{ @@ -44,7 +45,7 @@ func TestStream_filterOrder(t *testing.T) { } } - s.start() + s.start(ctx) for _, itm := range items { s.mustItem(t, 
itm) } @@ -52,6 +53,7 @@ func TestStream_filterOrder(t *testing.T) { } func TestStream_lostItem(t *testing.T) { + ctx := t.Context() defer leaktest.Check(t) s := newStreamTester(t, ``, eventlog.LogSettings{ @@ -60,7 +62,7 @@ func TestStream_lostItem(t *testing.T) { // Publish an item and let the client observe it. cur := s.publish("ok", "whatever") - s.start() + s.start(ctx) s.mustItem(t, makeTestItem(cur, "whatever")) s.stopWait() @@ -72,7 +74,7 @@ func TestStream_lostItem(t *testing.T) { // At this point, the oldest item in the log is newer than the point at // which we continued, we should get an error. - s.start() + s.start(ctx) var missed *eventstream.MissedItemsError if err := s.mustError(t); !errors.As(err, &missed) { t.Errorf("Wrong error: got %v, want %T", err, missed) @@ -83,7 +85,7 @@ func TestStream_lostItem(t *testing.T) { // If we reset the stream and continue from head, we should catch up. s.stopWait() s.stream.Reset() - s.start() + s.start(ctx) s.mustItem(t, makeTestItem(next1, "more stuff")) s.mustItem(t, makeTestItem(next2, "still more stuff")) @@ -197,8 +199,8 @@ func newStreamTester(t *testing.T, query string, logOpts eventlog.LogSettings, s // start starts the stream receiver, which runs until it it terminated by // calling stop. -func (s *streamTester) start() { - ctx, cancel := context.WithCancel(t.Context()) +func (s *streamTester) start(ctx context.Context) { + ctx, cancel := context.WithCancel(ctx) s.errc = make(chan error, 1) s.recv = make(chan *coretypes.EventItem) s.stop = cancel diff --git a/rpc/jsonrpc/server/http_json_handler.go b/rpc/jsonrpc/server/http_json_handler.go index bbaa0727d..3707cdddb 100644 --- a/rpc/jsonrpc/server/http_json_handler.go +++ b/rpc/jsonrpc/server/http_json_handler.go @@ -64,7 +64,7 @@ func makeJSONRPCHandler(funcMap map[string]*RPCFunc, logger log.Logger) http.Han rpcFunc, ok := funcMap[req.Method] if !ok || rpcFunc.ws { - responses = append(responses, req.MakeErrorf(rpctypes.CodeMethodNotFound, req.Method)) + responses = append(responses, req.MakeErrorf(rpctypes.CodeMethodNotFound, "method %s not found", req.Method)) continue } diff --git a/rpc/jsonrpc/server/rpc_func.go b/rpc/jsonrpc/server/rpc_func.go index 01dc64b24..825cdde12 100644 --- a/rpc/jsonrpc/server/rpc_func.go +++ b/rpc/jsonrpc/server/rpc_func.go @@ -108,11 +108,11 @@ func (rf *RPCFunc) parseParams(ctx context.Context, params json.RawMessage) ([]r } bits, err := rf.adjustParams(params) if err != nil { - return nil, invalidParamsError(err.Error()) + return nil, invalidParamsError("%s", err.Error()) } arg := reflect.New(rf.param) if err := json.Unmarshal(bits, arg.Interface()); err != nil { - return nil, invalidParamsError(err.Error()) + return nil, invalidParamsError("%s", err.Error()) } return []reflect.Value{reflect.ValueOf(ctx), arg}, nil } diff --git a/rpc/jsonrpc/server/ws_handler.go b/rpc/jsonrpc/server/ws_handler.go index d3c2b9153..d59d6c9d1 100644 --- a/rpc/jsonrpc/server/ws_handler.go +++ b/rpc/jsonrpc/server/ws_handler.go @@ -321,7 +321,7 @@ func (wsc *wsConnection) readRoutine(ctx context.Context) { rpcFunc := wsc.funcMap[request.Method] if rpcFunc == nil { if err := wsc.WriteRPCResponse(writeCtx, - request.MakeErrorf(rpctypes.CodeMethodNotFound, request.Method)); err != nil { + request.MakeErrorf(rpctypes.CodeMethodNotFound, "method %s not found", request.Method)); err != nil { wsc.Logger.Error("error writing RPC response", "err", err) } continue diff --git a/scripts/keymigrate/migrate_test.go b/scripts/keymigrate/migrate_test.go index f21617107..11f9bb38e 
100644 --- a/scripts/keymigrate/migrate_test.go +++ b/scripts/keymigrate/migrate_test.go @@ -1,7 +1,6 @@ package keymigrate import ( - "context" "errors" "fmt" "math" diff --git a/types/proposal_test.go b/types/proposal_test.go index b18b6d19b..d5f87d9c1 100644 --- a/types/proposal_test.go +++ b/types/proposal_test.go @@ -2,11 +2,12 @@ package types import ( "context" - "github.com/tendermint/tendermint/version" "math" "testing" "time" + "github.com/tendermint/tendermint/version" + "github.com/gogo/protobuf/proto" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -128,8 +129,7 @@ func BenchmarkProposalWriteSignBytes(b *testing.B) { } func BenchmarkProposalSign(b *testing.B) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := b.Context() privVal := NewMockPV() @@ -147,8 +147,7 @@ func BenchmarkProposalSign(b *testing.B) { func BenchmarkProposalVerifySignature(b *testing.B) { testProposal := getTestProposal(b) pbp := testProposal.ToProto() - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := b.Context() privVal := NewMockPV() err := privVal.SignProposal(ctx, "test_chain_id", pbp) diff --git a/types/validator_set_test.go b/types/validator_set_test.go index 52df6aeb4..41da55c0a 100644 --- a/types/validator_set_test.go +++ b/types/validator_set_test.go @@ -366,7 +366,7 @@ func TestProposerSelection3(t *testing.T) { got := vset.GetProposer().Address expected := proposerOrder[j%4].Address if !bytes.Equal(got, expected) { - t.Fatalf(fmt.Sprintf("vset.Proposer (%X) does not match expected proposer (%X) for (%d, %d)", got, expected, i, j)) + t.Fatalf("vset.Proposer (%X) does not match expected proposer (%X) for (%d, %d)", got, expected, i, j) } // serialize, deserialize, check proposer @@ -377,13 +377,11 @@ func TestProposerSelection3(t *testing.T) { if i != 0 { if !bytes.Equal(got, computed.Address) { t.Fatalf( - fmt.Sprintf( - "vset.Proposer (%X) does not match computed proposer (%X) for (%d, %d)", - got, - computed.Address, - i, - j, - ), + "vset.Proposer (%X) does not match computed proposer (%X) for (%d, %d)", + got, + computed.Address, + i, + j, ) } } @@ -1557,7 +1555,7 @@ func BenchmarkUpdates(b *testing.B) { } func BenchmarkValidatorSet_VerifyCommit_Ed25519(b *testing.B) { // nolint - ctx, cancel := context.WithCancel(t.Context()) + ctx, cancel := context.WithCancel(b.Context()) defer cancel() for _, n := range []int{1, 8, 64, 1024} { @@ -1585,7 +1583,7 @@ func BenchmarkValidatorSet_VerifyCommit_Ed25519(b *testing.B) { // nolint } func BenchmarkValidatorSet_VerifyCommitLight_Ed25519(b *testing.B) { // nolint - ctx, cancel := context.WithCancel(t.Context()) + ctx, cancel := context.WithCancel(b.Context()) defer cancel() for _, n := range []int{1, 8, 64, 1024} { @@ -1614,7 +1612,7 @@ func BenchmarkValidatorSet_VerifyCommitLight_Ed25519(b *testing.B) { // nolint } func BenchmarkValidatorSet_VerifyCommitLightTrusting_Ed25519(b *testing.B) { - ctx, cancel := context.WithCancel(t.Context()) + ctx, cancel := context.WithCancel(b.Context()) defer cancel() for _, n := range []int{1, 8, 64, 1024} { diff --git a/types/vote_test.go b/types/vote_test.go index a3119db25..ca364d389 100644 --- a/types/vote_test.go +++ b/types/vote_test.go @@ -545,7 +545,7 @@ func BenchmarkVoteSignBytes(b *testing.B) { } func BenchmarkCommitVoteSignBytes(b *testing.B) { - ctx, cancel := context.WithCancel(t.Context()) + ctx, cancel := context.WithCancel(b.Context()) defer cancel() sampleCommit := getSampleCommit(ctx, b) From 
41458d9887630ebb1e8d865a359690852ce4fc1c Mon Sep 17 00:00:00 2001 From: Grzegorz Prusak Date: Mon, 18 Aug 2025 16:16:57 +0200 Subject: [PATCH 03/41] seems fine --- internal/p2p/address_test.go | 11 +- libs/utils/channels.go | 74 +++++++++++ libs/utils/mutex.go | 206 ++++++++++++++++++++++++++++++ libs/utils/mutex_test.go | 39 ++++++ libs/utils/option.go | 73 +++++++++++ libs/utils/option_test.go | 32 +++++ libs/utils/proto.go | 143 +++++++++++++++++++++ libs/utils/require/require.go | 81 ++++++++++++ libs/utils/ringbuf.go | 83 ++++++++++++ libs/utils/scope/parallel.go | 41 ++++++ libs/utils/scope/parallel_test.go | 54 ++++++++ libs/utils/scope/start.go | 143 +++++++++++++++++++++ libs/utils/semaphore.go | 24 ++++ libs/utils/testonly.go | 152 ++++++++++++++++++++++ libs/utils/wait.go | 119 +++++++++++++++++ libs/utils/wait_test.go | 23 ++++ 16 files changed, 1292 insertions(+), 6 deletions(-) create mode 100644 libs/utils/channels.go create mode 100644 libs/utils/mutex.go create mode 100644 libs/utils/mutex_test.go create mode 100644 libs/utils/option.go create mode 100644 libs/utils/option_test.go create mode 100644 libs/utils/proto.go create mode 100644 libs/utils/require/require.go create mode 100644 libs/utils/ringbuf.go create mode 100644 libs/utils/scope/parallel.go create mode 100644 libs/utils/scope/parallel_test.go create mode 100644 libs/utils/scope/start.go create mode 100644 libs/utils/semaphore.go create mode 100644 libs/utils/testonly.go create mode 100644 libs/utils/wait.go create mode 100644 libs/utils/wait_test.go diff --git a/internal/p2p/address_test.go b/internal/p2p/address_test.go index 468c046b2..4c4bff025 100644 --- a/internal/p2p/address_test.go +++ b/internal/p2p/address_test.go @@ -292,14 +292,13 @@ func TestNodeAddress_Resolve(t *testing.T) { endpoints, err := addr.Resolve(t.Context()) require.NoError(t, err) - want := []*p2p.Endpoint{ - {Protocol: "tcp", IP: net.IPv6loopback, Port: 80, Path: "/path"}, - {Protocol: "tcp", IP: net.IPv4(127, 0, 0, 1), Port: 80, Path: "/path"}, - } - + want := &p2p.Endpoint{Protocol: "tcp", Port: 80, Path: "/path"} require.True(t, len(endpoints)>0) for _, got := range endpoints { - require.Contains(t, want, got) + require.True(t, got.IP.IsLoopback()) + // Any loopback address is acceptable, so ignore it in comparison. + want.IP = got.IP + require.Equal(t, want, got) } }) } diff --git a/libs/utils/channels.go b/libs/utils/channels.go new file mode 100644 index 000000000..9eed500ff --- /dev/null +++ b/libs/utils/channels.go @@ -0,0 +1,74 @@ +package utils + +import ( + "context" + + "github.com/pkg/errors" +) + +// Recv receives a value from a channel or returns an error if the context is canceled. +func Recv[T any](ctx context.Context, ch <-chan T) (zero T, err error) { + select { + case v, ok := <-ch: + if ok { + return v, nil + } + // We are not interested in channel closing, + // patiently wait for the context to be done instead. + <-ctx.Done() + return zero, ctx.Err() + case <-ctx.Done(): + return zero, ctx.Err() + } +} + +// RecvOrClosed receives a value from a channel, returns false if channel got closed, +// or returns an error if the context is canceled. +func RecvOrClosed[T any](ctx context.Context, ch <-chan T) (T, bool, error) { + select { + case v, ok := <-ch: + return v, ok, nil + case <-ctx.Done(): + var zero T + return zero, false, ctx.Err() + } +} + +// Send a value to channel or returns an error if the context is canceled. 
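+//
+// A minimal usage sketch (illustrative only; resultCh and result are assumed names):
+//
+//	if err := utils.Send(ctx, resultCh, result); err != nil {
+//		return err // context was canceled before the send could complete
+//	}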
+func Send[T any](ctx context.Context, ch chan<- T, v T) error { + select { + case ch <- v: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +// SendOrDrop send a value to channel if not full or drop the item if the channel is full. +func SendOrDrop[T any](ch chan<- T, v T) error { + select { + case ch <- v: + return nil + default: + // drop the item + return nil + } +} + +// ForEach is a helper function that reads from a channel and calls a handler for each item. +// this avoids needing a lot of for/select boilerplate everywhere. +func ForEach[T any](ctx context.Context, ch <-chan T, handler func(T) error) error { + for { + select { + case <-ctx.Done(): + return errors.WithStack(ctx.Err()) + case item, ok := <-ch: + if !ok { + return nil // Channel closed + } + if err := handler(item); err != nil { + return err // Stop on error + } + } + } +} diff --git a/libs/utils/mutex.go b/libs/utils/mutex.go new file mode 100644 index 000000000..b6f4a9a58 --- /dev/null +++ b/libs/utils/mutex.go @@ -0,0 +1,206 @@ +package utils + +import ( + "context" + "iter" + "sync" + "sync/atomic" + + "golang.org/x/sync/errgroup" +) + +// Mutex guards access to object of type T. +type Mutex[T any] struct { + mu sync.Mutex + value T +} + +// NewMutex creates a new Mutex with given object. +func NewMutex[T any](value T) (m Mutex[T]) { + m.value = value + // nolint:nakedret + return +} + +// Lock returns an iterator which locks the mutex and yields the guarded object. +// The mutex is unlocked when the iterator is done. +// If the mutex is nil, the iterator is a no-op. +func (m *Mutex[T]) Lock() iter.Seq[T] { + return func(yield func(val T) bool) { + m.mu.Lock() + defer m.mu.Unlock() + _ = yield(m.value) + } +} + +// version of the value stored in an atomic watch. +type version[T any] struct { + updated chan struct{} + value T +} + +// newVersion constructs a new active version. +func newVersion[T any](value T) *version[T] { + return &version[T]{make(chan struct{}), value} +} + +type atomicWatch[T any] struct { + ptr atomic.Pointer[version[T]] +} + +// AtomicWatch stores a pointer to an IMMUTABLE value. +// Loading and waiting for updates do NOT require locking. +// TODO(gprusak): remove mutex and rename to AtomicSend, +// this will allow for sharing a mutex across multiple AtomicSenders. +type AtomicWatch[T any] struct { + atomicWatch[T] + mu sync.Mutex +} + +// AtomicRecv is a read-only reference to AtomicWatch. +type AtomicRecv[T any] struct{ *atomicWatch[T] } + +// NewAtomicWatch creates a new AtomicWatch with the given initial value. +func NewAtomicWatch[T any](value T) (w AtomicWatch[T]) { + w.ptr.Store(newVersion(value)) + // nolint:nakedret + return +} + +// Subscribe returns a view-only API of the atomic watch. +func (w *AtomicWatch[T]) Subscribe() AtomicRecv[T] { + return AtomicRecv[T]{&w.atomicWatch} +} + +// Load returns the current value of the atomic watch. +// Does not do any locking. +func (w *atomicWatch[T]) Load() T { return w.ptr.Load().value } + +// Store updates the value of the atomic watch. +func (w *AtomicWatch[T]) Store(value T) { + w.mu.Lock() + defer w.mu.Unlock() + close(w.ptr.Swap(newVersion(value)).updated) +} + +// Update conditionally updates the value of the atomic watch. +func (w *AtomicWatch[T]) Update(f func(T) (T, bool)) { + w.mu.Lock() + defer w.mu.Unlock() + old := w.ptr.Load() + if value, ok := f(old.value); ok { + w.ptr.Store(newVersion(value)) + close(old.updated) + } +} + +// Wait waits for the value of the atomic watch to satisfy the predicate. 
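+// It re-checks the current value after every Store/Update until pred returns true.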
+// Does not do any locking. +func (w *atomicWatch[T]) Wait(ctx context.Context, pred func(T) bool) (T, error) { + for { + v := w.ptr.Load() + if pred(v.value) { + return v.value, nil + } + select { + case <-ctx.Done(): + return Zero[T](), ctx.Err() + case <-v.updated: + } + } +} + +// Iter executes sequentially the function f on each value of the atomic watch. +// Context passed to f is canceled when the next value is available. +// Exits when the returned error is different from nil and context.Canceled, +// or when the context passed to Iter is canceled (after f exits). +func (w *atomicWatch[T]) Iter(ctx context.Context, f func(ctx context.Context, v T) error) error { + for ctx.Err() == nil { + v := w.ptr.Load() + g, ctx := errgroup.WithContext(ctx) + g.Go(func() error { return f(ctx, v.value) }) + g.Go(func() error { + select { + case <-ctx.Done(): + case <-v.updated: + } + return context.Canceled + }) + if err := IgnoreCancel(g.Wait()); err != nil { + return err + } + } + return ctx.Err() +} + +// WatchCtrl controls the locked object in a Watch. +// It is provided only in the iterator returned by Lock(). +// Should NOT be stored anywhere. +type WatchCtrl struct { + mu sync.Mutex + updated chan struct{} +} + +// Watch stores a value of type T. +// Essentially a mutex, that can be awaited for updates. +type Watch[T any] struct { + ctrl WatchCtrl + val T +} + +// NewWatch constructs a new watch with the given value. +// Note that value in the watch cannot be changed, so T +// should be a pointer type if updates are required. +func NewWatch[T any](val T) Watch[T] { + return Watch[T]{ + WatchCtrl{updated: make(chan struct{})}, + val, + } +} + +// Wait waits for the value in the watch to be updated. +// Should be called only after locking the watch, i.e. within Lock() iterator. +// It unlocks -> waits for the update -> locks again. +func (c *WatchCtrl) Wait(ctx context.Context) error { + updated := c.updated + c.mu.Unlock() + defer c.mu.Lock() + select { + case <-ctx.Done(): + return ctx.Err() + case <-updated: + return nil + } +} + +// WaitUntil waits for the value in the watch to satisfy the predicate. +// Should be called only after locking the watch, i.e. within Lock() iterator. +// The predicate is evaluated under the lock, so it can access the guarded object. +func (c *WatchCtrl) WaitUntil(ctx context.Context, pred func() bool) error { + for !pred() { + if err := c.Wait(ctx); err != nil { + return err + } + } + return nil +} + +// Updated signals waiters that the value in the watch has been updated. +func (c *WatchCtrl) Updated() { + close(c.updated) + c.updated = make(chan struct{}) +} + +// Lock returns an iterator which locks the watch and yields the guarded object. +// The watch is unlocked when the iterator is done. +// If the watch is nil, the iterator is a no-op. +// Additionally the WatchCtrl object is provided to the yield function: +// * to unlock -> wait for the update -> lock again, call ctrl.Wait(ctx) +// * to signal an update, call ctrl.Updated(). 
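+//
+// A usage sketch (illustrative only; the queue fields are assumed):
+//
+//	for q, ctrl := range w.Lock() {
+//		q.pending = append(q.pending, item)
+//		ctrl.Updated()
+//	}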
+func (w *Watch[T]) Lock() iter.Seq2[T, *WatchCtrl] { + return func(yield func(val T, ctrl *WatchCtrl) bool) { + w.ctrl.mu.Lock() + defer w.ctrl.mu.Unlock() + _ = yield(w.val, &w.ctrl) + } +} diff --git a/libs/utils/mutex_test.go b/libs/utils/mutex_test.go new file mode 100644 index 000000000..b4a85abbc --- /dev/null +++ b/libs/utils/mutex_test.go @@ -0,0 +1,39 @@ +package utils_test + +import ( + "context" + "fmt" + "testing" + + "github.com/tendermint/tendermint/libs/utils/require" + "github.com/tendermint/tendermint/libs/utils/scope" + "github.com/tendermint/tendermint/libs/utils" +) + +func TestAtomicWatch(t *testing.T) { + ctx := t.Context() + v := 5 + w := utils.NewAtomicWatch(&v) + require.Equal(t, 5, *w.Load()) + + want := 10 + if err := scope.Run(ctx, func(ctx context.Context, s scope.Scope) error { + s.Spawn(func() error { + for i := 0; i <= want; i++ { + w.Store(&i) + } + return nil + }) + + got, err := w.Wait(ctx, func(v *int) bool { return *v >= want }) + if err != nil { + return err + } + if *got != want { + return fmt.Errorf("got %v, want %v", *got, want) + } + return nil + }); err != nil { + t.Fatal(err) + } +} diff --git a/libs/utils/option.go b/libs/utils/option.go new file mode 100644 index 000000000..85fd6a471 --- /dev/null +++ b/libs/utils/option.go @@ -0,0 +1,73 @@ +package utils + +import ( + "encoding/json" +) + +// Option type inspired https://pkg.go.dev/github.com/samber/mo. +type Option[T any] struct { + ReadOnly + isPresent bool + value T +} + +// Some creates an Option with a value. +func Some[T any](value T) Option[T] { + return Option[T]{isPresent: true, value: value} +} + +// None creates an Option without a value. +func None[T any]() (zero Option[T]) { return } + +// Get unpacks the value from the Option, returning true if it was present. +func (o Option[T]) Get() (T, bool) { + if o.isPresent { + return o.value, true + } + return Zero[T](), false +} + +// IsPresent checks if the Option contains a value. +func (o Option[T]) IsPresent() bool { + return o.isPresent +} + +// Or returns the value if present, otherwise returns the default value. +func (o *Option[T]) Or(def T) T { + if o.isPresent { + return o.value + } + return def +} + +// MapOpt applies a function to the value if present, returning a new Option. +func MapOpt[T, R any](o Option[T], f func(T) R) Option[R] { + if o.isPresent { + return Some(f(o.value)) + } + return None[R]() +} + +// MarshalJSON implements the json.Marshaler interface. +// Note that it is defined on value, not pointer, because +// json.Marshal cannot call pointer methods on fields +// (i.e. it is broken by design). +func (o Option[T]) MarshalJSON() ([]byte, error) { + if o.isPresent { + return json.Marshal(o.value) + } + return []byte("null"), nil +} + +// UnmarshalJSON implements the json.Unmarshaler interface. 
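+// A JSON null decodes to None; any other value decodes to Some.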
+func (o *Option[T]) UnmarshalJSON(data []byte) error { + if string(data) == "null" { + o.isPresent = false + return nil + } + if err := json.Unmarshal(data, &o.value); err != nil { + return err + } + o.isPresent = true + return nil +} diff --git a/libs/utils/option_test.go b/libs/utils/option_test.go new file mode 100644 index 000000000..04a55a1e1 --- /dev/null +++ b/libs/utils/option_test.go @@ -0,0 +1,32 @@ +package utils + +import ( + "encoding/json" + "testing" + + "github.com/tendermint/tendermint/libs/utils/require" +) + +func testJSON[T any](t *testing.T, want T) { + enc, err := json.Marshal(want) + require.NoError(t, err) + t.Logf("%s", enc) + var got T + require.NoError(t, json.Unmarshal(enc, &got)) + require.NoError(t, TestDiff(want, got)) +} + +func TestOptionJSON(t *testing.T) { + type a struct { + X Option[int] + Y Option[string] + } + type b struct { + X Option[int] `json:"X,omitzero"` + Y Option[string] `json:"Y,omitzero"` + } + testJSON(t, &a{}) + testJSON(t, &a{Some(1), Some("a")}) + testJSON(t, &b{}) + testJSON(t, &b{Some(1), Some("a")}) +} diff --git a/libs/utils/proto.go b/libs/utils/proto.go new file mode 100644 index 000000000..5f5ad7a41 --- /dev/null +++ b/libs/utils/proto.go @@ -0,0 +1,143 @@ +package utils + +import ( + "crypto/sha256" + "errors" + "fmt" + "sync" + + "google.golang.org/protobuf/proto" +) + +// Hash is a SHA-256 hash. +type Hash [sha256.Size]byte + +// GetHash computes a hash of the given data. +func GetHash(data []byte) Hash { + return sha256.Sum256(data) +} + +// ParseHash parses a Hash from bytes. +func ParseHash(raw []byte) (Hash, error) { + if got, want := len(raw), sha256.Size; got != want { + return Hash{}, fmt.Errorf("hash size = %v, want %v", got, want) + } + return Hash(raw), nil +} + +// ProtoClone clones a proto.Message object. +func ProtoClone[T proto.Message](item T) T { + return proto.Clone(item).(T) +} + +// ProtoEqual compares two proto.Message objects. +func ProtoEqual[T proto.Message](a, b T) bool { + return proto.Equal(a, b) +} + +// ProtoHash hashes a proto.Message object. +// TODO(gprusak): make it deterministic. +func ProtoHash(a proto.Message) Hash { + raw, err := proto.Marshal(a) + if err != nil { + panic(err) + } + return sha256.Sum256(raw) +} + +// ProtoMessage is comparable proto.Message. +type ProtoMessage interface { + comparable + proto.Message +} + +// ProtoConv is a pair of functions to encode and decode between a type and a ProtoMessage. +type ProtoConv[T any, P ProtoMessage] struct { + Encode func(T) P + Decode func(P) (T, error) +} + +// EncodeSlice encodes a slice of T into a slice of P. +func (c ProtoConv[T, P]) EncodeSlice(t []T) []P { + p := make([]P, len(t)) + for i := range t { + p[i] = c.Encode(t[i]) + } + return p +} + +// DecodeSlice decodes a slice of P into a slice of T. +func (c ProtoConv[T, P]) DecodeSlice(p []P) ([]T, error) { + t := make([]T, len(p)) + var err error + for i := range p { + if t[i], err = c.Decode(p[i]); err != nil { + return nil, fmt.Errorf("[%d]: %w", i, err) + } + } + return t, nil +} + +// Slice constructs a slice. +// It is a syntax sugar for `[]T{v...}`, which avoids +// spelling out T. Not very useful if you need to spell +// out T to construct the elements: in that case +// you might prefer the []T{{...},{...}} syntax instead. +func Slice[T any](v ...T) []T { return v } + +// Alloc moves value to heap. +func Alloc[T any](v T) *T { return &v } + +// Zero returns a zero value of type T. 
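+// Handy for returning an empty value from generic code, e.g. `return Zero[T](), ctx.Err()`.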
+func Zero[T any]() (zero T) { return } + +// NoCopy may be added to structs which must not be copied +// after the first use. +// +// See https://golang.org/issues/8005#issuecomment-190753527 +// for details. +// +// Note that it must not be embedded, otherwise Lock and Unlock methods +// will be exported. +type NoCopy struct{} + +// Lock implements sync.Locker. +func (*NoCopy) Lock() {} + +// Unlock implements sync.Locker. +func (*NoCopy) Unlock() {} + +var _ sync.Locker = (*NoCopy)(nil) + +// NoCompare may be added to structs which must not be used as +// map keys. +type NoCompare [0]func() + +// EncodeOpt encodes Option[T], mapping None to Zero[P](). +func (c ProtoConv[T, P]) EncodeOpt(mv Option[T]) P { + v, ok := mv.Get() + if !ok { + return Zero[P]() + } + return c.Encode(v) +} + +// DecodeReq decodes a ProtoMessage into a T, returning an error if p is nil. +func (c ProtoConv[T, P]) DecodeReq(p P) (T, error) { + if p == Zero[P]() { + return Zero[T](), errors.New("missing") + } + return c.Decode(p) +} + +// DecodeOpt decodes a ProtoMessage into a T, returning nil if p is nil. +func (c ProtoConv[T, P]) DecodeOpt(p P) (Option[T], error) { + if p == Zero[P]() { + return None[T](), nil + } + t, err := c.DecodeReq(p) + if err != nil { + return None[T](), err + } + return Some(t), nil +} diff --git a/libs/utils/require/require.go b/libs/utils/require/require.go new file mode 100644 index 000000000..66bb750d3 --- /dev/null +++ b/libs/utils/require/require.go @@ -0,0 +1,81 @@ +// Package require reexports strongly typed `testify/require` API. +// We don't reexport `New`, because methods cannot be generic. +package require + +import ( + "cmp" + + "github.com/stretchr/testify/require" +) + +// TestingT . +type TestingT = require.TestingT + +// False . +var False = require.False + +// True . +var True = require.True + +// Contains . +var Contains = require.Contains + +// EqualError . +// TODO: get rid of comparing errors by strings, +// use concrete error types instead. +var EqualError = require.EqualError + +// Error . +var Error = require.Error + +// ErrorIs . +var ErrorIs = require.ErrorIs + +// NoError . +var NoError = require.NoError + +// Empty . +var Empty = require.Empty + +// NotEmpty . +var NotEmpty = require.NotEmpty + +// Len . +var Len = require.Len + +// Nil . +var Nil = require.Nil + +// NotNil . +var NotNil = require.NotNil + +// Panics . +var Panics = require.Panics + +// Fail . +var Fail = require.Fail + +// Positive . +func Positive[T cmp.Ordered](t TestingT, e T, msgAndArgs ...any) { + require.Positive(t, e, msgAndArgs...) +} + +// Less . +func Less[T cmp.Ordered](t TestingT, e1, e2 T, msgAndArgs ...any) { + require.Less(t, e1, e2, msgAndArgs...) +} + +// Greater . +func Greater[T cmp.Ordered](t TestingT, e1, e2 T, msgAndArgs ...any) { + require.Greater(t, e1, e2, msgAndArgs...) +} + +// Equal . +func Equal[T any](t TestingT, expected, actual T, msgAndArgs ...any) { + require.Equal(t, expected, actual, msgAndArgs...) +} + +// NotEqual . +func NotEqual[T any](t TestingT, expected, actual T, msgAndArgs ...any) { + require.NotEqual(t, expected, actual, msgAndArgs...) +} diff --git a/libs/utils/ringbuf.go b/libs/utils/ringbuf.go new file mode 100644 index 000000000..5b81c3379 --- /dev/null +++ b/libs/utils/ringbuf.go @@ -0,0 +1,83 @@ +package utils + +import ( + "iter" +) + +// RingBuf is a ring buffer. +// NOT thread-safe. +type RingBuf[T any] struct { + first int + len int + buf []T +} + +// NewRingBuf creates a new ring buffer with the given capacity. 
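+//
+// A usage sketch (illustrative only):
+//
+//	rb := utils.NewRingBuf[int](2)
+//	rb.PushBack(1)
+//	rb.PushBack(2)         // rb.Full() == true
+//	first := rb.PopFront() // first == 1, rb.Len() == 1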
+func NewRingBuf[T any](capacity int) RingBuf[T] { + return RingBuf[T]{first: 0, len: 0, buf: make([]T, capacity)} +} + +// Len returns the number of elements in the ring buffer. +func (r *RingBuf[T]) Len() int { + return r.len +} + +// Full returns true if the ring buffer is full. +func (r *RingBuf[T]) Full() bool { + return r.len == len(r.buf) +} + +// Get returns the i-th element of the ring buffer. +// Panics if i is out of range. +func (r *RingBuf[T]) Get(i int) T { + if i < 0 || i >= r.len { + panic("index out of range") + } + return r.buf[(r.first+i)%len(r.buf)] +} + +// TryGet returns the i-th element of the ring buffer. +func (r *RingBuf[T]) TryGet(i int) (T, bool) { + if i < 0 || i >= r.len { + return Zero[T](), false + } + return r.buf[(r.first+i)%len(r.buf)], true +} + +// Last returns the last element of the ring buffer. +func (r *RingBuf[T]) Last() (T, bool) { + return r.TryGet(r.len - 1) +} + +// PushBack adds an element to the back of the ring buffer. +// Panics if the ring buffer is full. +func (r *RingBuf[T]) PushBack(x T) { + if r.len == len(r.buf) { + panic("ring buffer full") + } + r.buf[(r.first+r.len)%len(r.buf)] = x + r.len += 1 +} + +// PopFront removes and returns the first element of the ring buffer. +// Panics if the ring buffer is empty. +func (r *RingBuf[T]) PopFront() T { + if r.len == 0 { + panic("ring buffer empty") + } + x := r.buf[r.first] + r.first = (r.first + 1) % len(r.buf) + r.len -= 1 + return x +} + +// All iterates over all the elements in the ring buffer. +func (r *RingBuf[T]) All() iter.Seq[T] { + return func(y func(T) bool) { + for i := range r.len { + if !y(r.Get(i)) { + break + } + } + } +} diff --git a/libs/utils/scope/parallel.go b/libs/utils/scope/parallel.go new file mode 100644 index 000000000..1377184d5 --- /dev/null +++ b/libs/utils/scope/parallel.go @@ -0,0 +1,41 @@ +package scope + +import ( + "sync" + "sync/atomic" +) + +type parallelScope struct { + wg sync.WaitGroup + err atomic.Pointer[error] +} + +// ParallelScope is a scope which doesn't require cancellation token, +// just parallelization. +type ParallelScope struct{ *parallelScope } + +// Spawn spawns a new task in the scope. +func (s *parallelScope) Spawn(t func() error) { + s.wg.Add(1) + go func() { + if err := t(); err != nil { + s.err.CompareAndSwap(nil, &err) + } + s.wg.Done() + }() +} + +// Parallel executes a function in parallel scope. +// Compared to Run, it does not allow for early cancellation, +// therefore is suitable for non-blocking computations. +// Returns the first error returned by any of the spawned tasks. +// Waits until all the tasks complete, before returning. 
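+//
+// A usage sketch (illustrative only; inputs and compute are assumed names):
+//
+//	results := make([]int, len(inputs))
+//	err := Parallel(func(s ParallelScope) error {
+//		for i, in := range inputs {
+//			s.Spawn(func() error {
+//				results[i] = compute(in)
+//				return nil
+//			})
+//		}
+//		return nil
+//	})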
+func Parallel(main func(ParallelScope) error) error { + var s parallelScope + s.Spawn(func() error { return main(ParallelScope{&s}) }) + s.wg.Wait() + if perr := s.err.Load(); perr != nil { + return *perr + } + return nil +} diff --git a/libs/utils/scope/parallel_test.go b/libs/utils/scope/parallel_test.go new file mode 100644 index 000000000..7f98872ad --- /dev/null +++ b/libs/utils/scope/parallel_test.go @@ -0,0 +1,54 @@ +package scope + +import ( + "errors" + "testing" +) + +func TestParallelOk(t *testing.T) { + x := [10]int{} + if err := Parallel(func(s ParallelScope) error { + for i := range x { + s.Spawn(func() error { + x[i] = i + return nil + }) + } + return nil + }); err != nil { + t.Fatal(err) + } + for want, got := range x { + if want != got { + t.Fatalf("x[%d] = %d, want %d", want, got, want) + } + } +} + +func TestParallelFail(t *testing.T) { + var wantErr = errors.New("custom err") + x := [10]int{} + err := Parallel(func(s ParallelScope) error { + for i := range x { + s.Spawn(func() error { + if i%2 == 0 { + return wantErr + } + x[i] = i + return nil + }) + } + return nil + }) + if !errors.Is(err, wantErr) { + t.Fatalf("err = %v, want %v", err, wantErr) + } + for want, got := range x { + if want%2 == 0 { + want = 0 + } + if want != got { + t.Fatalf("x[%d] = %d, want %d", want, got, want) + } + } +} diff --git a/libs/utils/scope/start.go b/libs/utils/scope/start.go new file mode 100644 index 000000000..cba8d2e4d --- /dev/null +++ b/libs/utils/scope/start.go @@ -0,0 +1,143 @@ +package scope + +import ( + "context" + "fmt" + "log" + "sync" + "time" + + "golang.org/x/sync/errgroup" + + "github.com/tendermint/tendermint/libs/utils" +) + +// Scope of concurrenct tasks. +type Scope struct { + // scope is a concurrecy primitive, so no-ctx-in-struct rule does not apply + // nolint:containedctx + ctx context.Context + all *errgroup.Group + main *sync.WaitGroup +} + +// Spawn spawns a main task. +// Scope gets automatically canceled when all the main tasks return. +func (s Scope) Spawn(t func() error) { + s.main.Add(1) + s.all.Go(func() error { + defer s.main.Done() + return t() + }) +} + +// JoinHandle is a handle to an awaitable task. +type JoinHandle[R any] struct { + result *utils.AtomicWatch[*R] +} + +// Spawn1 is the same as Scope.Spawn, but allows awaiting completion of a task and getting its result. +func Spawn1[R any](s Scope, t func() (R, error)) JoinHandle[R] { + result := utils.NewAtomicWatch[*R](nil) + s.Spawn(func() error { + v, err := t() + if err != nil { + return err + } + result.Store(&v) + return nil + }) + return JoinHandle[R]{&result} +} + +// Join awaits completion of a task and returns its result. +// WARNING: it does NOT return the error of the task - error is returned from the Run() command. +// Join() can only fail when context is canceled. +func (h JoinHandle[R]) Join(ctx context.Context) (R, error) { + res, err := h.result.Wait(ctx, func(v *R) bool { return v != nil }) + if err != nil { + return utils.Zero[R](), err + } + return *res, nil +} + +// If true, tasks that do not respect context cancellation will be logged. +// This is useful for debugging, but causes unnecessary overhead. +// Since this is a constant, debug guard should be optimized out by the compiler. 
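+// Flip it to true locally when hunting for tasks that ignore cancellation.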
+const enableDebugGuard = false + +func (s Scope) debugGuard(name string, done chan struct{}) { + select { + case <-done: + return + case <-s.ctx.Done(): + } + for { + select { + case <-done: + return + case <-time.After(10 * time.Second): + } + log.Printf("task %q still running", name) + } +} + +// SpawnNamed spawns a named main task. +func (s Scope) SpawnNamed(name string, t func() error) { + done := make(chan struct{}) + s.Spawn(func() error { + defer close(done) + if err := t(); err != nil { + return fmt.Errorf("%s: %w", name, err) + } + return nil + }) + if enableDebugGuard { + go s.debugGuard(name, done) + } +} + +// SpawnBgNamed spawns a named background task. +func (s Scope) SpawnBgNamed(name string, t func() error) { + done := make(chan struct{}) + s.SpawnBg(func() error { + defer close(done) + if err := t(); err != nil { + return fmt.Errorf("%s: %w", name, err) + } + return nil + }) + if enableDebugGuard { + go s.debugGuard(name, done) + } +} + +// SpawnBg spawns a background task. +// Background tasks get canceled when all the main tasks return. +func (s Scope) SpawnBg(t func() error) { s.all.Go(t) } + +// Run runs a scope capable of spawning tasks. +// It is guaranteed that all the spawned tasks will be executed (even if spawned after the context is cancelled), +// and that `Run` will return only after all the tasks have completed. +// Context of the tasks will be automatically cancelled as soon as ANY task returns an error. +// Returns the first error returned by any task (main or background). +func Run(ctx context.Context, main func(context.Context, Scope) error) error { + ctx, cancel := context.WithCancel(ctx) + all, ctx := errgroup.WithContext(ctx) + s := Scope{ctx, all, &sync.WaitGroup{}} + s.Spawn(func() error { return main(ctx, s) }) + s.main.Wait() + cancel() + return s.all.Wait() +} + +// Run1 is the same as Run, but returns the result of the main task. +func Run1[R any](ctx context.Context, main func(context.Context, Scope) (R, error)) (res R, err error) { + err = Run(ctx, func(ctx context.Context, s Scope) error { + var err error + res, err = main(ctx, s) + return err + }) + //nolint:nakedret + return +} diff --git a/libs/utils/semaphore.go b/libs/utils/semaphore.go new file mode 100644 index 000000000..728c12a5c --- /dev/null +++ b/libs/utils/semaphore.go @@ -0,0 +1,24 @@ +package utils + +import ( + "context" +) + +// Semaphore provides a way to bound concurrenct access to a resource. +type Semaphore struct { + ch chan struct{} +} + +// NewSemaphore constructs a new semaphore with n permits. +func NewSemaphore(n int) *Semaphore { + return &Semaphore{ch: make(chan struct{}, n)} +} + +// Acquire acquires a permit from the semaphore. +// Blocks until a permit is available. +func (s *Semaphore) Acquire(ctx context.Context) (relase func(), err error) { + if err := Send(ctx, s.ch, struct{}{}); err != nil { + return nil, err + } + return func() { <-s.ch }, nil +} diff --git a/libs/utils/testonly.go b/libs/utils/testonly.go new file mode 100644 index 000000000..afd6b8aa8 --- /dev/null +++ b/libs/utils/testonly.go @@ -0,0 +1,152 @@ +package utils + +import ( + "fmt" + "math/big" + "math/rand" + "reflect" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/testing/protocmp" +) + +// ReadOnly - if a struct embeds ReadOnly, +// its private fields will be compared by TestEqual. +type ReadOnly struct{} + +// isReadOnly returns true if t embeds ReadOnly. 
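+//
+// For example (sketch), TestEqual/TestDiff will compare the unexported fields of:
+//
+//	type amount struct {
+//		ReadOnly
+//		units int64
+//	}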
+func isReadOnly(t reflect.Type) bool { + want := reflect.TypeOf(ReadOnly{}) + if t.Kind() != reflect.Struct { + return false + } + for i := range t.NumField() { + if f := t.Field(i); f.Anonymous || f.Type == want { + return true + } + } + return false +} + +func cmpComparer[T any, PT interface { + Cmp(b *T) int + *T +}](a PT, b PT) bool { + if a == nil || b == nil { + return a == b + } + return a.Cmp(b) == 0 +} + +var cmpOpts = []cmp.Option{ + protocmp.Transform(), + cmp.Exporter(isReadOnly), + cmpopts.EquateEmpty(), + cmp.Comparer(cmpComparer[big.Int]), +} + +// TestDiff generates a human-readable diff between two objects. +func TestDiff[T any](want, got T) error { + if diff := cmp.Diff(want, got, cmpOpts...); diff != "" { + return fmt.Errorf("want (-) got (+):\n%s", diff) + } + return nil +} + +// TestEqual is a more robust replacement for reflect.DeepEqual for tests. +func TestEqual[T any](a, b T) bool { + return cmp.Equal(a, b, cmpOpts...) +} + +// TestRngSplit returns a new random number splitted from the given one. +// This is a very primitive splitting, known to result with dependent randomness. +// If that ever causes a problem, we can switch to SplitMix. +func TestRngSplit(rng *rand.Rand) *rand.Rand { + return rand.New(rand.NewSource(rng.Int63())) +} + +// TestRng returns a deterministic random number generator. +func TestRng() *rand.Rand { + return rand.New(rand.NewSource(789345342)) +} + +var alphanum = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789") + +// GenString generates a random string of length n. +func GenString(rng *rand.Rand, n int) string { + s := make([]rune, n) + for i := range n { + s[i] = alphanum[rand.Intn(len(alphanum))] + } + return string(s) +} + +// GenBytes generates a random byte slice. +func GenBytes(rng *rand.Rand, n int) []byte { + s := make([]byte, n) + _, _ = rng.Read(s) + return s +} + +// GenF is a function which generates T. +type GenF[T any] = func(rng *rand.Rand) T + +// GenSlice generates a slice of small random length. +func GenSlice[T any](rng *rand.Rand, gen GenF[T]) []T { + return GenSliceN(rng, 2+rng.Intn(3), gen) +} + +// GenSliceN generates a slice of n elements. +func GenSliceN[T any](rng *rand.Rand, n int, gen GenF[T]) []T { + s := make([]T, n) + for i := range s { + s[i] = gen(rng) + } + return s +} + +// GenMap generates a map of small random length. +func GenMap[K comparable, V any](rng *rand.Rand, genK GenF[K], genV GenF[V]) map[K]V { + return GenMapN(rng, 2+rng.Intn(3), genK, genV) +} + +// GenMapN generates a map of n elements. +func GenMapN[K comparable, V any](rng *rand.Rand, n int, genK GenF[K], genV GenF[V]) map[K]V { + m := make(map[K]V, n) + for len(m) < n { + m[genK(rng)] = genV(rng) + } + return m +} + +// GenTimestamp generates a random timestamp. +func GenTimestamp(rng *rand.Rand) time.Time { + return time.Unix(0, rng.Int63()) +} + +// GenHash generates a random Hash. +func GenHash(rng *rand.Rand) Hash { + var h Hash + _, _ = rng.Read(h[:]) + return h +} + +// Test tests whether reencoding a value is an identity operation. 
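+// That is, it checks that Decode(Unmarshal(Marshal(Encode(want)))) reproduces want.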
+func (c *ProtoConv[T, P]) Test(want T) error { + p := c.Encode(want) + raw, err := proto.Marshal(p) + if err != nil { + return fmt.Errorf("Marshal(): %w", err) + } + if err := proto.Unmarshal(raw, p); err != nil { + return fmt.Errorf("Unmarshal(): %w", err) + } + got, err := c.Decode(p) + if err != nil { + return fmt.Errorf("Decode(Encode()): %w", err) + } + return TestDiff(want, got) +} diff --git a/libs/utils/wait.go b/libs/utils/wait.go new file mode 100644 index 000000000..4c8c6634f --- /dev/null +++ b/libs/utils/wait.go @@ -0,0 +1,119 @@ +package utils + +import ( + "context" + "encoding" + "errors" + "time" +) + +// IgnoreCancel returns nil if the error is context.Canceled, err otherwise. +func IgnoreCancel(err error) error { + if errors.Is(err, context.Canceled) { + return nil + } + return err +} + +// WithTimeout executes a function with a timeout. +func WithTimeout(ctx context.Context, d time.Duration, f func(ctx context.Context) error) error { + ctx, cancel := context.WithTimeout(ctx, d) + defer cancel() + return f(ctx) +} + +// WithTimeout1 executes a function with a timeout. +func WithTimeout1[R any](ctx context.Context, d time.Duration, f func(ctx context.Context) (R, error)) (R, error) { + ctx, cancel := context.WithTimeout(ctx, d) + defer cancel() + return f(ctx) +} + +// Sleep sleeps for a duration or until the context is canceled. +func Sleep(ctx context.Context, d time.Duration) error { + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(d): + return nil + } +} + +// SleepUntil sleeps until deadline t or until the context is canceled. +func SleepUntil(ctx context.Context, t time.Time) error { + return Sleep(ctx, time.Until(t)) +} + +// WaitFor polls a check function until it returns true or the context is canceled. +func WaitFor(ctx context.Context, interval time.Duration, check func() bool) error { + if check() { + return nil + } + ticker := time.NewTicker(interval) + for { + if check() { + return nil + } + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + } + } +} + +// WaitForWithTimeout polls a check function until it returns true, the context is canceled, or the timeout is reached. +func WaitForWithTimeout(ctx context.Context, interval, timeout time.Duration, check func() bool) error { + if check() { + return nil + } + + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + if check() { + return nil + } + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + } + } +} + +// Duration is a wrapper type around time.Duration that supports JSON marshaling/unmarshaling. +// nolint:recvcheck +type Duration time.Duration + +// MarshalText implements json.TextMarshaler interface to convert Duration to JSON string. +func (d Duration) MarshalText() ([]byte, error) { + return []byte(time.Duration(d).String()), nil +} + +// UnmarshalText implements json.TextUnmarshaler. +func (d *Duration) UnmarshalText(b []byte) error { + tmp, err := time.ParseDuration(string(b)) + if err != nil { + return err + } + *d = Duration(tmp) + return nil +} + +var _ encoding.TextMarshaler = Zero[Duration]() +var _ encoding.TextUnmarshaler = (*Duration)(nil) + +// Duration returns the underlying time.Duration value. +func (d Duration) Duration() time.Duration { + return time.Duration(d) +} + +// Seconds returns the underlying time.Duration value in seconds. 
+func (d Duration) Seconds() float64 { + return time.Duration(d).Seconds() +} diff --git a/libs/utils/wait_test.go b/libs/utils/wait_test.go new file mode 100644 index 000000000..91edc1267 --- /dev/null +++ b/libs/utils/wait_test.go @@ -0,0 +1,23 @@ +package utils + +import ( + "encoding/json" + "testing" + "time" +) + +func TestJSON(t *testing.T) { + var got, want struct{ X Duration } + want.X = Duration(100 * time.Millisecond) + j, err := json.Marshal(want) + if err != nil { + t.Fatal(err) + } + t.Logf("%s", j) + if err := json.Unmarshal(j, &got); err != nil { + t.Fatal(err) + } + if err := TestDiff(want, got); err != nil { + t.Fatal(err) + } +} From f787ad444d4cb927d9ccb25e4cc3bc464cfbfe27 Mon Sep 17 00:00:00 2001 From: Grzegorz Prusak Date: Mon, 18 Aug 2025 18:36:04 +0200 Subject: [PATCH 04/41] should work --- abci/example/example_test.go | 9 +- abci/example/kvstore/kvstore_test.go | 23 +- abci/tests/client_server_test.go | 4 +- cmd/tendermint/commands/reindex_event_test.go | 3 +- cmd/tendermint/commands/reset_test.go | 9 +- cmd/tendermint/commands/rollback_test.go | 15 +- cmd/tendermint/commands/root_test.go | 9 +- internal/blocksync/pool_test.go | 17 +- internal/blocksync/reactor_test.go | 6 +- internal/consensus/mempool_test.go | 15 +- internal/consensus/msgs_test.go | 4 +- internal/consensus/pbts_test.go | 15 +- internal/consensus/replay_test.go | 25 +-- internal/consensus/state_test.go | 151 +++++-------- .../consensus/types/height_vote_set_test.go | 3 +- internal/consensus/wal_test.go | 13 +- internal/eventbus/event_bus_test.go | 30 +-- internal/eventlog/eventlog_test.go | 1 - internal/evidence/pool_test.go | 30 +-- internal/evidence/reactor_test.go | 18 +- internal/evidence/verify_test.go | 18 +- internal/inspect/inspect_test.go | 115 +++++----- internal/libs/autofile/autofile_test.go | 7 +- internal/libs/autofile/group_test.go | 21 +- internal/libs/queue/queue_test.go | 6 +- internal/mempool/mempool_bench_test.go | 4 +- internal/mempool/mempool_test.go | 71 ++---- internal/mempool/reactor_test.go | 21 +- internal/p2p/address_test.go | 2 +- internal/p2p/channel_test.go | 44 ++-- internal/p2p/conn/connection_test.go | 42 ++-- internal/p2p/peermanager_scoring_test.go | 4 +- internal/p2p/peermanager_test.go | 102 ++++----- internal/p2p/pex/reactor_test.go | 27 +-- internal/p2p/pqueue_test.go | 4 +- internal/p2p/router_filter_test.go | 3 +- internal/p2p/router_init_test.go | 4 +- internal/p2p/router_test.go | 52 ++--- internal/p2p/rqueue_test.go | 4 +- internal/p2p/transport_mconn_test.go | 13 +- internal/p2p/transport_test.go | 75 ++----- internal/proxy/client_test.go | 27 +-- internal/pubsub/example_test.go | 4 +- internal/pubsub/pubsub_test.go | 44 ++-- internal/state/execution_test.go | 38 ++-- internal/state/indexer/block/kv/kv_test.go | 4 +- .../state/indexer/indexer_service_test.go | 4 +- internal/state/indexer/sink/kv/kv_test.go | 3 +- internal/state/indexer/sink/null/null_test.go | 4 +- internal/state/indexer/sink/psql/psql_test.go | 4 +- internal/state/rollback_test.go | 4 +- internal/state/state_test.go | 4 +- internal/state/store_test.go | 10 +- internal/state/validation_test.go | 10 +- internal/statesync/block_queue_test.go | 18 +- internal/statesync/reactor_test.go | 33 +-- internal/statesync/syncer_test.go | 56 +---- libs/cli/setup_test.go | 12 +- libs/events/events_test.go | 15 +- libs/service/service_test.go | 23 +- libs/utils/channels.go | 74 ------- libs/utils/mutex.go | 206 ------------------ libs/utils/mutex_test.go | 39 ---- libs/utils/option.go | 73 ------- 
libs/utils/option_test.go | 32 --- libs/utils/proto.go | 143 ------------ libs/utils/require/require.go | 81 ------- libs/utils/ringbuf.go | 83 ------- libs/utils/scope/parallel.go | 41 ---- libs/utils/scope/parallel_test.go | 54 ----- libs/utils/scope/start.go | 143 ------------ libs/utils/semaphore.go | 24 -- libs/utils/testonly.go | 152 ------------- libs/utils/wait.go | 119 ---------- libs/utils/wait_test.go | 23 -- light/client_benchmark_test.go | 1 + light/client_test.go | 64 ++---- light/detector_test.go | 58 ++--- light/dispatcher_test.go | 18 +- light/example_test.go | 4 +- light/light_test.go | 9 +- light/store/db/db_test.go | 15 +- node/node_test.go | 52 ++--- privval/file_test.go | 16 +- privval/grpc/client_test.go | 9 +- privval/grpc/server_test.go | 10 +- privval/signer_client_test.go | 30 +-- privval/signer_listener_endpoint_test.go | 6 +- rpc/client/examples_test.go | 7 +- rpc/client/helpers_test.go | 4 +- rpc/client/mock/abci_test.go | 10 +- rpc/client/mock/status_test.go | 4 +- rpc/client/rpc_test.go | 46 ++-- rpc/jsonrpc/client/integration_test.go | 3 +- rpc/jsonrpc/client/ws_client_test.go | 6 +- rpc/jsonrpc/jsonrpc_test.go | 3 +- rpc/jsonrpc/server/http_server_test.go | 7 +- scripts/confix/confix_test.go | 1 - scripts/scmigrate/migrate_test.go | 7 +- test/e2e/tests/block_test.go | 3 +- test/e2e/tests/e2e_test.go | 3 +- test/e2e/tests/evidence_test.go | 4 +- test/e2e/tests/validator_test.go | 6 +- types/block_test.go | 33 +-- types/evidence_test.go | 27 +-- types/light_test.go | 10 +- types/proposal_test.go | 8 +- types/validation_test.go | 20 +- types/validator_set_test.go | 18 +- types/validator_test.go | 6 +- types/vote_set_test.go | 21 +- types/vote_test.go | 30 +-- 112 files changed, 680 insertions(+), 2542 deletions(-) delete mode 100644 libs/utils/channels.go delete mode 100644 libs/utils/mutex.go delete mode 100644 libs/utils/mutex_test.go delete mode 100644 libs/utils/option.go delete mode 100644 libs/utils/option_test.go delete mode 100644 libs/utils/proto.go delete mode 100644 libs/utils/require/require.go delete mode 100644 libs/utils/ringbuf.go delete mode 100644 libs/utils/scope/parallel.go delete mode 100644 libs/utils/scope/parallel_test.go delete mode 100644 libs/utils/scope/start.go delete mode 100644 libs/utils/semaphore.go delete mode 100644 libs/utils/testonly.go delete mode 100644 libs/utils/wait.go delete mode 100644 libs/utils/wait_test.go diff --git a/abci/example/example_test.go b/abci/example/example_test.go index df514d8fa..9baa2b35b 100644 --- a/abci/example/example_test.go +++ b/abci/example/example_test.go @@ -28,8 +28,7 @@ func init() { } func TestKVStore(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewNopLogger() t.Log("### Testing KVStore") @@ -37,8 +36,7 @@ func TestKVStore(t *testing.T) { } func TestBaseApp(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewNopLogger() t.Log("### Testing BaseApp") @@ -46,8 +44,7 @@ func TestBaseApp(t *testing.T) { } func TestGRPC(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewNopLogger() diff --git a/abci/example/kvstore/kvstore_test.go b/abci/example/kvstore/kvstore_test.go index 357917e71..100dab705 100644 --- a/abci/example/kvstore/kvstore_test.go +++ b/abci/example/kvstore/kvstore_test.go @@ -67,8 +67,7 @@ func testKVStore(ctx context.Context, t *testing.T, app types.Application, tx [] } 
func TestKVStoreKV(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() kvstore := NewApplication() key := testKey @@ -82,8 +81,7 @@ func TestKVStoreKV(t *testing.T) { } func TestPersistentKVStoreKV(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() dir := t.TempDir() logger := log.NewNopLogger() @@ -100,8 +98,7 @@ func TestPersistentKVStoreKV(t *testing.T) { } func TestPersistentKVStoreInfo(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() dir := t.TempDir() logger := log.NewNopLogger() @@ -144,8 +141,7 @@ func TestPersistentKVStoreInfo(t *testing.T) { // add a validator, remove a validator, update a validator func TestValUpdates(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() kvstore := NewApplication() @@ -313,16 +309,15 @@ func makeGRPCClientServer( } func TestClientServer(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewNopLogger() // set up socket app kvstore := NewApplication() client, server, err := makeSocketClientServer(ctx, t, logger, kvstore, "kvstore-socket") require.NoError(t, err) - t.Cleanup(func() { cancel(); server.Wait() }) - t.Cleanup(func() { cancel(); client.Wait() }) + t.Cleanup(func() { server.Wait() }) + t.Cleanup(func() { client.Wait() }) runClientTests(ctx, t, client) @@ -331,8 +326,8 @@ func TestClientServer(t *testing.T) { gclient, gserver, err := makeGRPCClientServer(ctx, t, logger, kvstore, "/tmp/kvstore-grpc") require.NoError(t, err) - t.Cleanup(func() { cancel(); gserver.Wait() }) - t.Cleanup(func() { cancel(); gclient.Wait() }) + t.Cleanup(func() { gserver.Wait() }) + t.Cleanup(func() { gclient.Wait() }) runClientTests(ctx, t, gclient) } diff --git a/abci/tests/client_server_test.go b/abci/tests/client_server_test.go index b3a66804f..a5e749dbb 100644 --- a/abci/tests/client_server_test.go +++ b/abci/tests/client_server_test.go @@ -1,7 +1,6 @@ package tests import ( - "context" "testing" "github.com/fortytw2/leaktest" @@ -16,8 +15,7 @@ import ( func TestClientServerNoAddrPrefix(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() const ( addr = "localhost:26658" diff --git a/cmd/tendermint/commands/reindex_event_test.go b/cmd/tendermint/commands/reindex_event_test.go index c845dd3ab..000f6cad8 100644 --- a/cmd/tendermint/commands/reindex_event_test.go +++ b/cmd/tendermint/commands/reindex_event_test.go @@ -175,8 +175,7 @@ func TestReIndexEvent(t *testing.T) { {height, height, false}, } - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewNopLogger() conf := config.DefaultConfig() diff --git a/cmd/tendermint/commands/reset_test.go b/cmd/tendermint/commands/reset_test.go index eec7f3a64..ff27736b4 100644 --- a/cmd/tendermint/commands/reset_test.go +++ b/cmd/tendermint/commands/reset_test.go @@ -13,12 +13,13 @@ import ( ) func Test_ResetAll(t *testing.T) { + ctx := t.Context() config := cfg.TestConfig() dir := t.TempDir() config.SetRoot(dir) logger := log.NewNopLogger() cfg.EnsureRoot(dir) - require.NoError(t, initFilesWithConfig(t.Context(), config, logger, types.ABCIPubKeyTypeEd25519)) + require.NoError(t, initFilesWithConfig(ctx, config, logger, types.ABCIPubKeyTypeEd25519)) pv, err := privval.LoadFilePV(config.PrivValidator.KeyFile(), 
config.PrivValidator.StateFile()) require.NoError(t, err) pv.LastSignState.Height = 10 @@ -37,12 +38,13 @@ func Test_ResetAll(t *testing.T) { } func Test_ResetState(t *testing.T) { + ctx := t.Context() config := cfg.TestConfig() dir := t.TempDir() config.SetRoot(dir) logger := log.NewNopLogger() cfg.EnsureRoot(dir) - require.NoError(t, initFilesWithConfig(t.Context(), config, logger, types.ABCIPubKeyTypeEd25519)) + require.NoError(t, initFilesWithConfig(ctx, config, logger, types.ABCIPubKeyTypeEd25519)) pv, err := privval.LoadFilePV(config.PrivValidator.KeyFile(), config.PrivValidator.StateFile()) require.NoError(t, err) pv.LastSignState.Height = 10 @@ -61,12 +63,13 @@ func Test_ResetState(t *testing.T) { } func Test_UnsafeResetAll(t *testing.T) { + ctx := t.Context() config := cfg.TestConfig() dir := t.TempDir() config.SetRoot(dir) logger := log.NewNopLogger() cfg.EnsureRoot(dir) - require.NoError(t, initFilesWithConfig(t.Context(), config, logger, types.ABCIPubKeyTypeEd25519)) + require.NoError(t, initFilesWithConfig(ctx, config, logger, types.ABCIPubKeyTypeEd25519)) pv, err := privval.LoadFilePV(config.PrivValidator.KeyFile(), config.PrivValidator.StateFile()) require.NoError(t, err) pv.LastSignState.Height = 10 diff --git a/cmd/tendermint/commands/rollback_test.go b/cmd/tendermint/commands/rollback_test.go index bdbd06d46..bbf52050a 100644 --- a/cmd/tendermint/commands/rollback_test.go +++ b/cmd/tendermint/commands/rollback_test.go @@ -17,8 +17,6 @@ import ( func TestRollbackIntegration(t *testing.T) { var height int64 dir := t.TempDir() - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() cfg, err := rpctest.CreateConfig(t, t.Name()) require.NoError(t, err) cfg.BaseConfig.DBBackend = "goleveldb" @@ -27,18 +25,17 @@ func TestRollbackIntegration(t *testing.T) { require.NoError(t, err) t.Run("First run", func(t *testing.T) { - ctx, cancel := context.WithCancel(ctx) - defer cancel() + ctx := t.Context() require.NoError(t, err) node, _, err := rpctest.StartTendermint(ctx, cfg, app, rpctest.SuppressStdout) require.NoError(t, err) require.True(t, node.IsRunning()) time.Sleep(3 * time.Second) - cancel() - node.Wait() - - require.False(t, node.IsRunning()) + t.Cleanup(func() { + node.Wait() + require.False(t, node.IsRunning()) + }) }) t.Run("Rollback", func(t *testing.T) { time.Sleep(time.Second) @@ -49,7 +46,7 @@ func TestRollbackIntegration(t *testing.T) { t.Run("Restart", func(t *testing.T) { require.True(t, height > 0, "%d", height) - ctx, cancel := context.WithTimeout(ctx, 10*time.Second) + ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second) defer cancel() node2, _, err2 := rpctest.StartTendermint(ctx, cfg, app, rpctest.SuppressStdout) require.NoError(t, err2) diff --git a/cmd/tendermint/commands/root_test.go b/cmd/tendermint/commands/root_test.go index 99315823d..91c5d294f 100644 --- a/cmd/tendermint/commands/root_test.go +++ b/cmd/tendermint/commands/root_test.go @@ -77,8 +77,7 @@ func TestRootHome(t *testing.T) { {nil, map[string]string{"TMHOME": newRoot}, newRoot}, } - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() for i, tc := range cases { t.Run(fmt.Sprint(i), func(t *testing.T) { @@ -114,8 +113,7 @@ func TestRootFlagsEnv(t *testing.T) { {nil, map[string]string{"TM_LOG_LEVEL": "debug"}, "debug"}, // right env } - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() for i, tc := range cases { t.Run(fmt.Sprint(i), func(t *testing.T) { @@ -131,8 +129,7 @@ func TestRootFlagsEnv(t *testing.T) { 
} func TestRootConfig(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() // write non-default config nonDefaultLogLvl := "debug" diff --git a/internal/blocksync/pool_test.go b/internal/blocksync/pool_test.go index cf1dc8a25..77c1166bd 100644 --- a/internal/blocksync/pool_test.go +++ b/internal/blocksync/pool_test.go @@ -1,7 +1,6 @@ package blocksync import ( - "context" "crypto/rand" "encoding/hex" "fmt" @@ -107,8 +106,7 @@ func makePeerManager(peers map[types.NodeID]testPeer) *p2p.PeerManager { return peerManager } func TestBlockPoolBasic(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() start := int64(42) peers := makePeers(10, start, 1000) @@ -120,7 +118,7 @@ func TestBlockPoolBasic(t *testing.T) { t.Error(err) } - t.Cleanup(func() { cancel(); pool.Wait() }) + t.Cleanup(func() { pool.Wait() }) peers.start() defer peers.stop() @@ -163,8 +161,7 @@ func TestBlockPoolBasic(t *testing.T) { } func TestBlockPoolTimeout(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() start := int64(42) peers := makePeers(10, start, 1000) @@ -175,9 +172,6 @@ func TestBlockPoolTimeout(t *testing.T) { if err != nil { t.Error(err) } - t.Cleanup(func() { - cancel() - }) // Introduce each peer. go func() { @@ -213,8 +207,7 @@ func TestBlockPoolTimeout(t *testing.T) { } func TestBlockPoolRemovePeer(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() peers := make(testPeers, 10) for i := 0; i < 10; i++ { @@ -233,7 +226,7 @@ func TestBlockPoolRemovePeer(t *testing.T) { pool := NewBlockPool(log.NewNopLogger(), 1, requestsCh, errorsCh, makePeerManager(peers)) err := pool.Start(ctx) require.NoError(t, err) - t.Cleanup(func() { cancel(); pool.Wait() }) + t.Cleanup(func() { pool.Wait() }) // add peers for peerID, peer := range peers { diff --git a/internal/blocksync/reactor_test.go b/internal/blocksync/reactor_test.go index 6f6be7fca..1816ad925 100644 --- a/internal/blocksync/reactor_test.go +++ b/internal/blocksync/reactor_test.go @@ -277,8 +277,7 @@ func (rts *reactorTestSuite) start(ctx context.Context, t *testing.T) { } func TestReactor_AbruptDisconnect(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() cfg, err := config.ResetTestRoot(t.TempDir(), "block_sync_reactor_test") require.NoError(t, err) @@ -317,8 +316,7 @@ func TestReactor_AbruptDisconnect(t *testing.T) { } func TestReactor_SyncTime(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() cfg, err := config.ResetTestRoot(t.TempDir(), "block_sync_reactor_test") require.NoError(t, err) diff --git a/internal/consensus/mempool_test.go b/internal/consensus/mempool_test.go index 28f7199c9..abc67a331 100644 --- a/internal/consensus/mempool_test.go +++ b/internal/consensus/mempool_test.go @@ -32,8 +32,7 @@ func assertMempool(t *testing.T, txn txNotifier) mempool.Mempool { } func TestMempoolNoProgressUntilTxsAvailable(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() baseConfig := configSetup(t) @@ -61,9 +60,8 @@ func TestMempoolNoProgressUntilTxsAvailable(t *testing.T) { } func TestMempoolProgressAfterCreateEmptyBlocksInterval(t *testing.T) { + ctx := t.Context() baseConfig := configSetup(t) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() config, err := ResetConfig(t.TempDir(), 
"consensus_mempool_txs_available_test") require.NoError(t, err) @@ -87,9 +85,8 @@ func TestMempoolProgressAfterCreateEmptyBlocksInterval(t *testing.T) { } func TestMempoolProgressInHigherRound(t *testing.T) { + ctx := t.Context() baseConfig := configSetup(t) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() config, err := ResetConfig(t.TempDir(), "consensus_mempool_txs_available_test") require.NoError(t, err) @@ -144,8 +141,7 @@ func checkTxsRange(ctx context.Context, t *testing.T, cs *State, start, end int) } func TestMempoolTxConcurrentWithCommit(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() config := configSetup(t) logger := log.NewNopLogger() @@ -183,9 +179,8 @@ func TestMempoolTxConcurrentWithCommit(t *testing.T) { } func TestMempoolRmBadTx(t *testing.T) { + ctx := t.Context() config := configSetup(t) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() state, privVals := makeGenesisState(ctx, t, config, genesisStateArgs{ Validators: 1, diff --git a/internal/consensus/msgs_test.go b/internal/consensus/msgs_test.go index 3c2790764..33ca4c496 100644 --- a/internal/consensus/msgs_test.go +++ b/internal/consensus/msgs_test.go @@ -1,7 +1,6 @@ package consensus import ( - "context" "encoding/hex" "fmt" "github.com/tendermint/tendermint/version" @@ -26,8 +25,7 @@ import ( ) func TestMsgToProto(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() psh := types.PartSetHeader{ Total: 1, diff --git a/internal/consensus/pbts_test.go b/internal/consensus/pbts_test.go index 5123df930..0ed1dfd9c 100644 --- a/internal/consensus/pbts_test.go +++ b/internal/consensus/pbts_test.go @@ -338,8 +338,7 @@ func (hr heightResult) isComplete() bool { // until after the genesis time has passed. The test sets the genesis time in the // future and then ensures that the observed validator waits to propose a block. func TestProposerWaitsForGenesisTime(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() // create a genesis time far (enough) in the future. initialTime := time.Now().Add(800 * time.Millisecond) @@ -369,8 +368,7 @@ func TestProposerWaitsForGenesisTime(t *testing.T) { // and then verifies that the observed validator waits until after the block time // of height 4 to propose a block at height 5. 
func TestProposerWaitsForPreviousBlock(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() initialTime := time.Now().Add(time.Millisecond * 50) cfg := pbtsTestConfiguration{ synchronyParams: types.SynchronyParams{ @@ -436,8 +434,7 @@ func TestProposerWaitTime(t *testing.T) { } func TestTimelyProposal(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() initialTime := time.Now() @@ -458,8 +455,7 @@ func TestTimelyProposal(t *testing.T) { } func TestTooFarInThePastProposal(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() // localtime > proposedBlockTime + MsgDelay + Precision cfg := pbtsTestConfiguration{ @@ -479,8 +475,7 @@ func TestTooFarInThePastProposal(t *testing.T) { } func TestTooFarInTheFutureProposal(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() // localtime < proposedBlockTime - Precision cfg := pbtsTestConfiguration{ diff --git a/internal/consensus/replay_test.go b/internal/consensus/replay_test.go index 3f20da2d1..3f9b94935 100644 --- a/internal/consensus/replay_test.go +++ b/internal/consensus/replay_test.go @@ -136,8 +136,7 @@ func TestWALCrash(t *testing.T) { for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() consensusReplayConfig, err := ResetConfig(t.TempDir(), tc.name) require.NoError(t, err) @@ -596,8 +595,7 @@ func setupSimulator(ctx context.Context, t *testing.T) *simulatorTestSuite { // Sync from scratch func TestHandshakeReplayAll(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() sim := setupSimulator(ctx, t) @@ -613,8 +611,7 @@ func TestHandshakeReplayAll(t *testing.T) { // Sync many, not from scratch func TestHandshakeReplaySome(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() sim := setupSimulator(ctx, t) @@ -630,8 +627,7 @@ func TestHandshakeReplaySome(t *testing.T) { // Sync from lagging by one func TestHandshakeReplayOne(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() sim := setupSimulator(ctx, t) @@ -645,8 +641,7 @@ func TestHandshakeReplayOne(t *testing.T) { // Sync from caught up func TestHandshakeReplayNone(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() sim := setupSimulator(ctx, t) @@ -956,8 +951,7 @@ func TestHandshakeErrorsIfAppReturnsWrongAppHash(t *testing.T) { // - 0x02 // - 0x03 - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() cfg, err := ResetConfig(t.TempDir(), "handshake_test_") require.NoError(t, err) @@ -992,7 +986,7 @@ func TestHandshakeErrorsIfAppReturnsWrongAppHash(t *testing.T) { proxyApp := proxy.New(client, logger, proxy.NopMetrics()) err := proxyApp.Start(ctx) require.NoError(t, err) - t.Cleanup(func() { cancel(); proxyApp.Wait() }) + t.Cleanup(func() { proxyApp.Wait() }) h := NewHandshaker(logger, stateStore, state, store, eventBus, genDoc) assert.Error(t, h.Handshake(ctx, proxyApp)) @@ -1008,7 +1002,7 @@ func TestHandshakeErrorsIfAppReturnsWrongAppHash(t *testing.T) { proxyApp := proxy.New(client, logger, proxy.NopMetrics()) err := proxyApp.Start(ctx) require.NoError(t, err) - t.Cleanup(func() { cancel(); proxyApp.Wait() }) + t.Cleanup(func() { proxyApp.Wait() }) h := 
NewHandshaker(logger, stateStore, state, store, eventBus, genDoc) require.Error(t, h.Handshake(ctx, proxyApp)) @@ -1244,8 +1238,7 @@ func (bs *mockBlockStore) DeleteLatestBlock() error { return nil } // Test handshake/init chain func TestHandshakeUpdatesValidators(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewNopLogger() votePower := 10 + int64(rand.Uint32()) diff --git a/internal/consensus/state_test.go b/internal/consensus/state_test.go index 5d9d8c405..97f02e5df 100644 --- a/internal/consensus/state_test.go +++ b/internal/consensus/state_test.go @@ -69,8 +69,7 @@ x * TestHalt1 - if we see +2/3 precommits after timing out into new round, we sh // ProposeSuite func TestStateProposerSelection0(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() config := configSetup(t) cs1, vss := makeState(ctx, t, makeStateArgs{config: config}) @@ -114,8 +113,7 @@ func TestStateProposerSelection0(t *testing.T) { func TestStateProposerSelection2(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() cs1, vss := makeState(ctx, t, makeStateArgs{config: config}) // test needs more work for more than 3 validators height := cs1.roundState.Height() newRoundCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryNewRound) @@ -151,8 +149,7 @@ func TestStateProposerSelection2(t *testing.T) { // a non-validator should timeout into the prevote round func TestStateEnterProposeNoPrivValidator(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() cs, _ := makeState(ctx, t, makeStateArgs{config: config, validators: 1}) cs.SetPrivValidator(ctx, nil) @@ -174,8 +171,7 @@ func TestStateEnterProposeNoPrivValidator(t *testing.T) { // a validator should not timeout of the prevote round (TODO: unless the block is really big!) 
func TestStateEnterProposeYesPrivValidator(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() cs, _ := makeState(ctx, t, makeStateArgs{config: config, validators: 1}) height, round := cs.roundState.Height(), cs.roundState.Round() @@ -208,8 +204,7 @@ func TestStateEnterProposeYesPrivValidator(t *testing.T) { func TestStateBadProposal(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() cs1, vss := makeState(ctx, t, makeStateArgs{config: config, validators: 2}) height, round := cs1.roundState.Height(), cs1.roundState.Round() @@ -271,8 +266,7 @@ func TestStateBadProposal(t *testing.T) { func TestStateOversizedBlock(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() cs1, vss := makeState(ctx, t, makeStateArgs{config: config, validators: 2}) cs1.state.ConsensusParams.Block.MaxBytes = 2000 @@ -336,8 +330,7 @@ func TestStateOversizedBlock(t *testing.T) { // propose, prevote, and precommit a block func TestStateFullRound1(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() cs, vss := makeState(ctx, t, makeStateArgs{config: config, validators: 1}) height, round := cs.roundState.Height(), cs.roundState.Round() @@ -366,8 +359,7 @@ func TestStateFullRound1(t *testing.T) { // nil is proposed, so prevote and precommit nil func TestStateFullRoundNil(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() cs, _ := makeState(ctx, t, makeStateArgs{config: config, validators: 1}) height, round := cs.roundState.Height(), cs.roundState.Round() @@ -385,8 +377,7 @@ func TestStateFullRoundNil(t *testing.T) { // where the first validator has to wait for votes from the second func TestStateFullRound2(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() cs1, vss := makeState(ctx, t, makeStateArgs{config: config, validators: 2}) vs2 := vss[1] @@ -429,8 +420,7 @@ func TestStateFullRound2(t *testing.T) { // two vals take turns proposing. val1 locks on first one, precommits nil on everything else func TestStateLock_NoPOL(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() cs1, vss := makeState(ctx, t, makeStateArgs{config: config, validators: 2}) vs2 := vss[1] @@ -635,8 +625,7 @@ func TestStateLock_POLUpdateLock(t *testing.T) { config := configSetup(t) logger := log.NewNopLogger() - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() cs1, vss := makeState(ctx, t, makeStateArgs{config: config, logger: logger}) vs2, vs3, vs4 := vss[1], vss[2], vss[3] @@ -741,8 +730,7 @@ func TestStateLock_POLUpdateLock(t *testing.T) { // it receives votes representing over 2/3 of the voting power on the network // for a block that it is already locked in. func TestStateLock_POLRelock(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() config := configSetup(t) cs1, vss := makeState(ctx, t, makeStateArgs{config: config}) @@ -840,8 +828,7 @@ func TestStateLock_POLRelock(t *testing.T) { // TestStateLock_PrevoteNilWhenLockedAndMissProposal tests that a validator prevotes nil // if it is locked on a block and misses the proposal in a round. 
func TestStateLock_PrevoteNilWhenLockedAndMissProposal(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() config := configSetup(t) cs1, vss := makeState(ctx, t, makeStateArgs{config: config}) @@ -850,7 +837,7 @@ func TestStateLock_PrevoteNilWhenLockedAndMissProposal(t *testing.T) { timeoutWaitCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryTimeoutWait) proposalCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryCompleteProposal) - pv1, err := cs1.privValidator.GetPubKey(t.Context()) + pv1, err := cs1.privValidator.GetPubKey(ctx) require.NoError(t, err) addr := pv1.Address() voteCh := subscribeToVoter(ctx, t, cs1, addr) @@ -920,8 +907,7 @@ func TestStateLock_PrevoteNilWhenLockedAndMissProposal(t *testing.T) { // TestStateLock_PrevoteNilWhenLockedAndMissProposal tests that a validator prevotes nil // if it is locked on a block and misses the proposal in a round. func TestStateLock_PrevoteNilWhenLockedAndDifferentProposal(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewNopLogger() config := configSetup(t) /* @@ -937,7 +923,7 @@ func TestStateLock_PrevoteNilWhenLockedAndDifferentProposal(t *testing.T) { timeoutWaitCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryTimeoutWait) proposalCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryCompleteProposal) - pv1, err := cs1.privValidator.GetPubKey(t.Context()) + pv1, err := cs1.privValidator.GetPubKey(ctx) require.NoError(t, err) addr := pv1.Address() voteCh := subscribeToVoter(ctx, t, cs1, addr) @@ -1022,8 +1008,7 @@ func TestStateLock_PrevoteNilWhenLockedAndDifferentProposal(t *testing.T) { func TestStateLock_POLDoesNotUnlock(t *testing.T) { config := configSetup(t) logger := log.NewNopLogger() - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() /* All of the assertions in this test occur on the `cs1` validator. The test sends signed votes from the other validators to cs1 and @@ -1039,7 +1024,7 @@ func TestStateLock_POLDoesNotUnlock(t *testing.T) { timeoutWaitCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryTimeoutWait) newRoundCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryNewRound) lockCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryLock) - pv1, err := cs1.privValidator.GetPubKey(t.Context()) + pv1, err := cs1.privValidator.GetPubKey(ctx) require.NoError(t, err) addr := pv1.Address() voteCh := subscribeToVoter(ctx, t, cs1, addr) @@ -1159,8 +1144,7 @@ func TestStateLock_POLDoesNotUnlock(t *testing.T) { func TestStateLock_MissingProposalWhenPOLSeenDoesNotUpdateLock(t *testing.T) { config := configSetup(t) logger := log.NewNopLogger() - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() cs1, vss := makeState(ctx, t, makeStateArgs{config: config, logger: logger}) vs2, vs3, vs4 := vss[1], vss[2], vss[3] @@ -1246,8 +1230,7 @@ func TestStateLock_MissingProposalWhenPOLSeenDoesNotUpdateLock(t *testing.T) { // block if a proposal was not seen for that block in the current round, but // was seen in a previous round. 
func TestStateLock_DoesNotLockOnOldProposal(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() config := configSetup(t) cs1, vss := makeState(ctx, t, makeStateArgs{config: config}) @@ -1256,7 +1239,7 @@ func TestStateLock_DoesNotLockOnOldProposal(t *testing.T) { timeoutWaitCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryTimeoutWait) proposalCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryCompleteProposal) - pv1, err := cs1.privValidator.GetPubKey(t.Context()) + pv1, err := cs1.privValidator.GetPubKey(ctx) require.NoError(t, err) addr := pv1.Address() voteCh := subscribeToVoter(ctx, t, cs1, addr) @@ -1323,8 +1306,7 @@ func TestStateLock_DoesNotLockOnOldProposal(t *testing.T) { func TestStateLock_POLSafety1(t *testing.T) { config := configSetup(t) logger := log.NewNopLogger() - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() cs1, vss := makeState(ctx, t, makeStateArgs{config: config, logger: logger}) vs2, vs3, vs4 := vss[1], vss[2], vss[3] @@ -1439,8 +1421,7 @@ func TestStateLock_POLSafety1(t *testing.T) { // dont see P0, lock on P1 at R1, dont unlock using P0 at R2 func TestStateLock_POLSafety2(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() cs1, vss := makeState(ctx, t, makeStateArgs{config: config}) vs2, vs3, vs4 := vss[1], vss[2], vss[3] @@ -1535,8 +1516,7 @@ func TestStateLock_POLSafety2(t *testing.T) { // for a block if it is locked on a different block but saw a POL for the block // it is not locked on in a previous round. func TestState_PrevotePOLFromPreviousRound(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() config := configSetup(t) logger := log.NewNopLogger() @@ -1548,7 +1528,7 @@ func TestState_PrevotePOLFromPreviousRound(t *testing.T) { timeoutWaitCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryTimeoutWait) proposalCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryCompleteProposal) - pv1, err := cs1.privValidator.GetPubKey(t.Context()) + pv1, err := cs1.privValidator.GetPubKey(ctx) require.NoError(t, err) addr := pv1.Address() voteCh := subscribeToVoter(ctx, t, cs1, addr) @@ -1681,8 +1661,7 @@ func TestState_PrevotePOLFromPreviousRound(t *testing.T) { // P0 proposes B0 at R3. func TestProposeValidBlock(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() cs1, vss := makeState(ctx, t, makeStateArgs{config: config}) vs2, vs3, vs4 := vss[1], vss[2], vss[3] @@ -1773,8 +1752,7 @@ func TestProposeValidBlock(t *testing.T) { // P0 miss to lock B but set valid block to B after receiving delayed prevote. func TestSetValidBlockOnDelayedPrevote(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() cs1, vss := makeState(ctx, t, makeStateArgs{config: config}) vs2, vs3, vs4 := vss[1], vss[2], vss[3] @@ -1842,8 +1820,7 @@ func TestSetValidBlockOnDelayedPrevote(t *testing.T) { // receiving delayed Block Proposal. 
func TestSetValidBlockOnDelayedProposal(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() cs1, vss := makeState(ctx, t, makeStateArgs{config: config}) vs2, vs3, vs4 := vss[1], vss[2], vss[3] @@ -1921,8 +1898,7 @@ func TestProcessProposalAccept(t *testing.T) { } { t.Run(testCase.name, func(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() m := abcimocks.NewApplication(t) status := abci.ResponseProcessProposal_REJECT @@ -1974,8 +1950,7 @@ func TestFinalizeBlockCalled(t *testing.T) { } { t.Run(testCase.name, func(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() m := abcimocks.NewApplication(t) m.On("ProcessProposal", mock.Anything, mock.Anything).Return(&abci.ResponseProcessProposal{ @@ -2057,8 +2032,7 @@ func TestExtendVoteCalledWhenEnabled(t *testing.T) { } { t.Run(testCase.name, func(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() m := abcimocks.NewApplication(t) m.On("ProcessProposal", mock.Anything, mock.Anything).Return(&abci.ResponseProcessProposal{Status: abci.ResponseProcessProposal_ACCEPT}, nil) @@ -2144,8 +2118,7 @@ func TestExtendVoteCalledWhenEnabled(t *testing.T) { // method is not called for a validator's vote that is never delivered. func TestVerifyVoteExtensionNotCalledOnAbsentPrecommit(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() m := abcimocks.NewApplication(t) m.On("ProcessProposal", mock.Anything, mock.Anything).Return(&abci.ResponseProcessProposal{Status: abci.ResponseProcessProposal_ACCEPT}, nil) @@ -2214,8 +2187,7 @@ func TestVerifyVoteExtensionNotCalledOnAbsentPrecommit(t *testing.T) { // is the proposer again and ensures that the mock application receives the set of // vote extensions from the previous consensus instance. 
func TestPrepareProposalReceivesVoteExtensions(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() config := configSetup(t) @@ -2353,8 +2325,7 @@ func TestVoteExtensionEnableHeight(t *testing.T) { } { t.Run(testCase.name, func(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() numValidators := 3 m := abcimocks.NewApplication(t) @@ -2431,8 +2402,7 @@ func TestVoteExtensionEnableHeight(t *testing.T) { // What we want: // P0 waits for timeoutPrecommit before starting next round func TestWaitingTimeoutOnNilPolka(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() config := configSetup(t) cs1, vss := makeState(ctx, t, makeStateArgs{config: config}) @@ -2457,8 +2427,7 @@ func TestWaitingTimeoutOnNilPolka(t *testing.T) { // P0 waits for timeoutPropose in the next round before entering prevote func TestWaitingTimeoutProposeOnNewRound(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() cs1, vss := makeState(ctx, t, makeStateArgs{config: config}) vs2, vs3, vs4 := vss[1], vss[2], vss[3] @@ -2495,8 +2464,7 @@ func TestWaitingTimeoutProposeOnNewRound(t *testing.T) { // P0 jump to higher round, precommit and start precommit wait func TestRoundSkipOnNilPolkaFromHigherRound(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() cs1, vss := makeState(ctx, t, makeStateArgs{config: config}) vs2, vs3, vs4 := vss[1], vss[2], vss[3] @@ -2535,8 +2503,7 @@ func TestRoundSkipOnNilPolkaFromHigherRound(t *testing.T) { // P0 wait for timeoutPropose to expire before sending prevote. func TestWaitTimeoutProposeOnNilPolkaForTheCurrentRound(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() cs1, vss := makeState(ctx, t, makeStateArgs{config: config}) vs2, vs3, vs4 := vss[1], vss[2], vss[3] @@ -2565,8 +2532,7 @@ func TestWaitTimeoutProposeOnNilPolkaForTheCurrentRound(t *testing.T) { // P0 emit NewValidBlock event upon receiving 2/3+ Precommit for B but hasn't received block B yet func TestEmitNewValidBlockEventOnCommitWithoutBlock(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() cs1, vss := makeState(ctx, t, makeStateArgs{config: config}) vs2, vs3, vs4 := vss[1], vss[2], vss[3] @@ -2607,8 +2573,7 @@ func TestEmitNewValidBlockEventOnCommitWithoutBlock(t *testing.T) { // After receiving block, it executes block and moves to the next height. 
func TestCommitFromPreviousRound(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() cs1, vss := makeState(ctx, t, makeStateArgs{config: config}) vs2, vs3, vs4 := vss[1], vss[2], vss[3] @@ -2668,8 +2633,7 @@ func (n *fakeTxNotifier) Notify() { // start of the next round func TestStartNextHeightCorrectlyAfterTimeout(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() cs1, vss := makeState(ctx, t, makeStateArgs{config: config}) cs1.state.ConsensusParams.Timeout.BypassCommitTimeout = false @@ -2734,8 +2698,7 @@ func TestStartNextHeightCorrectlyAfterTimeout(t *testing.T) { func TestResetTimeoutPrecommitUponNewHeight(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() cs1, vss := makeState(ctx, t, makeStateArgs{config: config}) cs1.state.ConsensusParams.Timeout.BypassCommitTimeout = false @@ -2804,8 +2767,7 @@ func TestResetTimeoutPrecommitUponNewHeight(t *testing.T) { // we receive a final precommit after going into next round, but others might have gone to commit already! func TestStateHalt1(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() cs1, vss := makeState(ctx, t, makeStateArgs{config: config}) vs2, vs3, vs4 := vss[1], vss[2], vss[3] @@ -2877,8 +2839,7 @@ func TestStateHalt1(t *testing.T) { func TestStateOutputsBlockPartsStats(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() // create dummy peer cs, _ := makeState(ctx, t, makeStateArgs{config: config, validators: 1}) @@ -2925,8 +2886,7 @@ func TestStateOutputsBlockPartsStats(t *testing.T) { func TestGossipTransactionKeyOnlyConfig(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() cs1, vss := makeState(ctx, t, makeStateArgs{config: config, validators: 2}) vs2 := vss[1] @@ -2968,8 +2928,7 @@ func TestGossipTransactionKeyOnlyConfig(t *testing.T) { func TestStateOutputVoteStats(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() cs, vss := makeState(ctx, t, makeStateArgs{config: config, validators: 2}) // create dummy peer @@ -3009,8 +2968,7 @@ func TestStateOutputVoteStats(t *testing.T) { func TestSignSameVoteTwice(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() _, vss := makeState(ctx, t, makeStateArgs{config: config, validators: 2}) @@ -3049,8 +3007,7 @@ func TestSignSameVoteTwice(t *testing.T) { // corresponding proposal message. func TestStateTimestamp_ProposalNotMatch(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() cs1, vss := makeState(ctx, t, makeStateArgs{config: config}) height, round := cs1.roundState.Height(), cs1.roundState.Round() @@ -3099,8 +3056,7 @@ func TestStateTimestamp_ProposalNotMatch(t *testing.T) { // corresponding proposal message. 
func TestStateTimestamp_ProposalMatch(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() cs1, vss := makeState(ctx, t, makeStateArgs{config: config}) height, round := cs1.roundState.Height(), cs1.roundState.Round() @@ -3192,8 +3148,7 @@ func signAddPrecommitWithExtension(ctx context.Context, func TestAddProposalBlockPartMemoryLimit(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() cs1, _ := makeState(ctx, t, makeStateArgs{config: config}) height, round := cs1.roundState.Height(), cs1.roundState.Round() @@ -3237,8 +3192,7 @@ func TestAddProposalBlockPartMemoryLimit(t *testing.T) { func TestAddProposalBlockPartWrongHeight(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() cs1, _ := makeState(ctx, t, makeStateArgs{config: config}) height, round := cs1.roundState.Height(), cs1.roundState.Round() @@ -3267,8 +3221,7 @@ func TestAddProposalBlockPartWrongHeight(t *testing.T) { func TestAddProposalBlockPartNilProposalBlockParts(t *testing.T) { config := configSetup(t) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() cs1, _ := makeState(ctx, t, makeStateArgs{config: config}) height, round := cs1.roundState.Height(), cs1.roundState.Round() diff --git a/internal/consensus/types/height_vote_set_test.go b/internal/consensus/types/height_vote_set_test.go index 1b0d70b89..959b946c3 100644 --- a/internal/consensus/types/height_vote_set_test.go +++ b/internal/consensus/types/height_vote_set_test.go @@ -21,8 +21,7 @@ func TestPeerCatchupRounds(t *testing.T) { t.Fatal(err) } - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() valSet, privVals := factory.ValidatorSet(ctx, t, 10, 1) diff --git a/internal/consensus/wal_test.go b/internal/consensus/wal_test.go index fe7573606..74df7185b 100644 --- a/internal/consensus/wal_test.go +++ b/internal/consensus/wal_test.go @@ -2,7 +2,6 @@ package consensus import ( "bytes" - "context" "os" "path/filepath" @@ -28,8 +27,7 @@ func TestWALTruncate(t *testing.T) { walFile := filepath.Join(walDir, "wal") logger := log.NewNopLogger() - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() // this magic number 4K can truncate the content when RotateFile. // defaultHeadSizeLimit(10M) is hard to simulate. 
@@ -105,8 +103,7 @@ func TestWALWrite(t *testing.T) { walDir := t.TempDir() walFile := filepath.Join(walDir, "wal") - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() wal, err := NewWAL(ctx, log.NewNopLogger(), walFile) require.NoError(t, err) @@ -138,8 +135,7 @@ func TestWALWrite(t *testing.T) { } func TestWALSearchForEndHeight(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewNopLogger() @@ -170,8 +166,7 @@ func TestWALSearchForEndHeight(t *testing.T) { } func TestWALPeriodicSync(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() walDir := t.TempDir() walFile := filepath.Join(walDir, "wal") diff --git a/internal/eventbus/event_bus_test.go b/internal/eventbus/event_bus_test.go index c1f335fa2..22b29b263 100644 --- a/internal/eventbus/event_bus_test.go +++ b/internal/eventbus/event_bus_test.go @@ -19,8 +19,7 @@ import ( ) func TestEventBusPublishEventTx(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() eventBus := eventbus.NewDefault(log.NewNopLogger()) err := eventBus.Start(ctx) @@ -73,8 +72,7 @@ func TestEventBusPublishEventTx(t *testing.T) { } func TestEventBusPublishEventNewBlock(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() eventBus := eventbus.NewDefault(log.NewNopLogger()) err := eventBus.Start(ctx) require.NoError(t, err) @@ -127,8 +125,7 @@ func TestEventBusPublishEventNewBlock(t *testing.T) { } func TestEventBusPublishEventTxDuplicateKeys(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() eventBus := eventbus.NewDefault(log.NewNopLogger()) err := eventBus.Start(ctx) require.NoError(t, err) @@ -200,7 +197,7 @@ func TestEventBusPublishEventTxDuplicateKeys(t *testing.T) { } t.Run(name, func(t *testing.T) { - + ctx := t.Context() sub, err := eventBus.SubscribeWithArgs(ctx, tmpubsub.SubscribeArgs{ ClientID: fmt.Sprintf("client-%d", i), Query: tmquery.MustCompile(tc.query), @@ -244,8 +241,7 @@ func TestEventBusPublishEventTxDuplicateKeys(t *testing.T) { } func TestEventBusPublishEventNewBlockHeader(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() eventBus := eventbus.NewDefault(log.NewNopLogger()) err := eventBus.Start(ctx) @@ -294,8 +290,7 @@ func TestEventBusPublishEventNewBlockHeader(t *testing.T) { } func TestEventBusPublishEventEvidenceValidated(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() eventBus := eventbus.NewDefault(log.NewNopLogger()) err := eventBus.Start(ctx) @@ -336,8 +331,7 @@ func TestEventBusPublishEventEvidenceValidated(t *testing.T) { } func TestEventBusPublishEventNewEvidence(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() eventBus := eventbus.NewDefault(log.NewNopLogger()) err := eventBus.Start(ctx) @@ -378,8 +372,7 @@ func TestEventBusPublishEventNewEvidence(t *testing.T) { } func TestEventBusPublish(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() eventBus := eventbus.NewDefault(log.NewNopLogger()) err := eventBus.Start(ctx) @@ -455,18 +448,15 @@ func BenchmarkEventBus(b *testing.B) { for _, bm := range benchmarks { bm := bm b.Run(bm.name, func(b *testing.B) { - benchmarkEventBus(bm.numClients, bm.randQueries, bm.randEvents, b) + 
benchmarkEventBus(b.Context(), bm.numClients, bm.randQueries, bm.randEvents, b) }) } } -func benchmarkEventBus(numClients int, randQueries bool, randEvents bool, b *testing.B) { +func benchmarkEventBus(ctx context.Context, numClients int, randQueries bool, randEvents bool, b *testing.B) { // for random* functions mrand.Seed(time.Now().Unix()) - ctx, cancel := context.WithCancel(b.Context()) - defer cancel() - eventBus := eventbus.NewDefault(log.NewNopLogger()) // set buffer capacity to 0 so we are not testing cache err := eventBus.Start(ctx) if err != nil { diff --git a/internal/eventlog/eventlog_test.go b/internal/eventlog/eventlog_test.go index 8db70e966..19d91daf0 100644 --- a/internal/eventlog/eventlog_test.go +++ b/internal/eventlog/eventlog_test.go @@ -105,7 +105,6 @@ func TestConcurrent(t *testing.T) { } ctx, cancel := context.WithCancel(t.Context()) - defer cancel() var wg sync.WaitGroup // Publisher: Add events and handle expirations. diff --git a/internal/evidence/pool_test.go b/internal/evidence/pool_test.go index 9a8e820ae..1e0ec8fa6 100644 --- a/internal/evidence/pool_test.go +++ b/internal/evidence/pool_test.go @@ -54,8 +54,7 @@ func TestEvidencePoolBasic(t *testing.T) { blockStore = &mocks.BlockStore{} ) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() valSet, privVals := factory.ValidatorSet(ctx, t, 1, 10) blockStore.On("LoadBlockMeta", mock.AnythingOfType("int64")).Return( &types.BlockMeta{Header: types.Header{Time: defaultEvidenceTime}}, @@ -111,8 +110,7 @@ func TestEvidencePoolBasic(t *testing.T) { // Tests inbound evidence for the right time and height func TestAddExpiredEvidence(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() var ( val = types.NewMockPV() @@ -156,8 +154,7 @@ func TestAddExpiredEvidence(t *testing.T) { tc := tc t.Run(tc.evDescription, func(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() ev, err := types.NewMockDuplicateVoteEvidenceWithValidator(ctx, tc.evHeight, tc.evTime, val, evidenceChainID) require.NoError(t, err) @@ -174,8 +171,7 @@ func TestAddExpiredEvidence(t *testing.T) { func TestReportConflictingVotes(t *testing.T) { var height int64 = 10 - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() pool, pv, _ := defaultTestPool(ctx, t, height) @@ -214,8 +210,7 @@ func TestReportConflictingVotes(t *testing.T) { func TestEvidencePoolUpdate(t *testing.T) { height := int64(21) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() pool, val, _ := defaultTestPool(ctx, t, height) @@ -284,8 +279,7 @@ func TestEvidencePoolUpdate(t *testing.T) { func TestVerifyPendingEvidencePasses(t *testing.T) { var height int64 = 1 - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() pool, val, _ := defaultTestPool(ctx, t, height) @@ -304,8 +298,7 @@ func TestVerifyPendingEvidencePasses(t *testing.T) { func TestVerifyDuplicatedEvidenceFails(t *testing.T) { var height int64 = 1 - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() pool, val, _ := defaultTestPool(ctx, t, height) @@ -328,8 +321,7 @@ func TestVerifyDuplicatedEvidenceFails(t *testing.T) { func TestEventOnEvidenceValidated(t *testing.T) { const height = 1 - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() pool, val, eventBus := defaultTestPool(ctx, t, height) @@ -379,8 +371,7 @@ func 
TestLightClientAttackEvidenceLifecycle(t *testing.T) { height int64 = 100 commonHeight int64 = 90 ) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() ev, trusted, common := makeLunaticEvidence(ctx, t, height, commonHeight, 10, 5, 5, defaultEvidenceTime, defaultEvidenceTime.Add(1*time.Hour)) @@ -443,8 +434,7 @@ func TestLightClientAttackEvidenceLifecycle(t *testing.T) { // Tests that restarting the evidence pool after a potential failure will recover the // pending evidence and continue to gossip it func TestRecoverPendingEvidence(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() height := int64(10) val := types.NewMockPV() diff --git a/internal/evidence/reactor_test.go b/internal/evidence/reactor_test.go index 97c7bb03c..8846ebed5 100644 --- a/internal/evidence/reactor_test.go +++ b/internal/evidence/reactor_test.go @@ -231,8 +231,7 @@ func createEvidenceList( } func TestReactorMultiDisconnect(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() val := types.NewMockPV() height := int64(numEvidence) + 10 @@ -271,8 +270,7 @@ func TestReactorMultiDisconnect(t *testing.T) { func TestReactorBroadcastEvidence(t *testing.T) { numPeers := 7 - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() // create a stateDB for all test suites (nodes) stateDBs := make([]sm.Store, numPeers) @@ -335,8 +333,7 @@ func TestReactorBroadcastEvidence_Lagging(t *testing.T) { height1 := int64(numEvidence) + 10 height2 := int64(numEvidence) / 2 - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() // stateDB1 is ahead of stateDB2, where stateDB1 has all heights (1-20) and // stateDB2 only has heights 1-5. @@ -371,8 +368,7 @@ func TestReactorBroadcastEvidence_Pending(t *testing.T) { val := types.NewMockPV() height := int64(10) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() stateDB1 := initializeValidatorState(ctx, t, val, height) stateDB2 := initializeValidatorState(ctx, t, val, height) @@ -412,8 +408,7 @@ func TestReactorBroadcastEvidence_Committed(t *testing.T) { val := types.NewMockPV() height := int64(10) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() stateDB1 := initializeValidatorState(ctx, t, val, height) stateDB2 := initializeValidatorState(ctx, t, val, height) @@ -467,8 +462,7 @@ func TestReactorBroadcastEvidence_FullyConnected(t *testing.T) { stateDBs := make([]sm.Store, numPeers) val := types.NewMockPV() - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() // We need all validators saved for heights at least as high as we have // evidence for. 
diff --git a/internal/evidence/verify_test.go b/internal/evidence/verify_test.go index c83ed90c5..ee1385abe 100644 --- a/internal/evidence/verify_test.go +++ b/internal/evidence/verify_test.go @@ -33,8 +33,7 @@ func TestVerifyLightClientAttack_Lunatic(t *testing.T) { totalVals = 10 byzVals = 4 ) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() attackTime := defaultEvidenceTime.Add(1 * time.Hour) // create valid lunatic evidence @@ -74,8 +73,7 @@ func TestVerify_LunaticAttackAgainstState(t *testing.T) { totalVals = 10 byzVals = 4 ) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewNopLogger() attackTime := defaultEvidenceTime.Add(1 * time.Hour) @@ -149,8 +147,7 @@ func TestVerify_ForwardLunaticAttack(t *testing.T) { ) attackTime := defaultEvidenceTime.Add(1 * time.Hour) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewNopLogger() @@ -205,8 +202,7 @@ func TestVerify_ForwardLunaticAttack(t *testing.T) { } func TestVerifyLightClientAttack_Equivocation(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewNopLogger() @@ -307,8 +303,7 @@ func TestVerifyLightClientAttack_Equivocation(t *testing.T) { } func TestVerifyLightClientAttack_Amnesia(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewNopLogger() @@ -407,8 +402,7 @@ type voteData struct { } func TestVerifyDuplicateVoteEvidence(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewNopLogger() val := types.NewMockPV() diff --git a/internal/inspect/inspect_test.go b/internal/inspect/inspect_test.go index 670da2ecf..67450c01a 100644 --- a/internal/inspect/inspect_test.go +++ b/internal/inspect/inspect_test.go @@ -86,7 +86,7 @@ func TestBlock(t *testing.T) { rpcConfig := config.TestRPCConfig() l := log.NewNopLogger() d := inspect.New(rpcConfig, blockStoreMock, stateStoreMock, []indexer.EventSink{eventSinkMock}, l) - ctx, cancel := context.WithCancel(t.Context()) + ctx := t.Context() wg := &sync.WaitGroup{} wg.Add(1) @@ -104,11 +104,11 @@ func TestBlock(t *testing.T) { require.NoError(t, err) require.Equal(t, testBlock.Height, resultBlock.Block.Height) require.Equal(t, testBlock.LastCommitHash, resultBlock.Block.LastCommitHash) - cancel() - wg.Wait() - - blockStoreMock.AssertExpectations(t) - stateStoreMock.AssertExpectations(t) + t.Cleanup(func() { + wg.Wait() + blockStoreMock.AssertExpectations(t) + stateStoreMock.AssertExpectations(t) + }) } func TestTxSearch(t *testing.T) { @@ -133,7 +133,7 @@ func TestTxSearch(t *testing.T) { rpcConfig := config.TestRPCConfig() l := log.NewNopLogger() d := inspect.New(rpcConfig, blockStoreMock, stateStoreMock, []indexer.EventSink{eventSinkMock}, l) - ctx, cancel := context.WithCancel(t.Context()) + ctx := t.Context() wg := &sync.WaitGroup{} wg.Add(1) @@ -157,12 +157,13 @@ func TestTxSearch(t *testing.T) { require.Len(t, resultTxSearch.Txs, 1) require.Equal(t, types.Tx(testTx), resultTxSearch.Txs[0].Tx) - cancel() - wg.Wait() + t.Cleanup(func() { + wg.Wait() - eventSinkMock.AssertExpectations(t) - stateStoreMock.AssertExpectations(t) - blockStoreMock.AssertExpectations(t) + eventSinkMock.AssertExpectations(t) + stateStoreMock.AssertExpectations(t) + blockStoreMock.AssertExpectations(t) + }) } func TestTx(t *testing.T) { testHash := []byte("test") @@ -180,7 +181,7 @@ 
 func TestTx(t *testing.T) {
 rpcConfig := config.TestRPCConfig()
 l := log.NewNopLogger()
 d := inspect.New(rpcConfig, blockStoreMock, stateStoreMock, []indexer.EventSink{eventSinkMock}, l)
- ctx, cancel := context.WithCancel(t.Context())
+ ctx := t.Context()
 wg := &sync.WaitGroup{}
 wg.Add(1)
@@ -202,12 +203,13 @@ func TestTx(t *testing.T) {
 require.NoError(t, err)
 require.Equal(t, types.Tx(testTx), res.Tx)
- cancel()
- wg.Wait()
+ t.Cleanup(func() {
+ wg.Wait()
- eventSinkMock.AssertExpectations(t)
- stateStoreMock.AssertExpectations(t)
- blockStoreMock.AssertExpectations(t)
+ eventSinkMock.AssertExpectations(t)
+ stateStoreMock.AssertExpectations(t)
+ blockStoreMock.AssertExpectations(t)
+ })
 }
 func TestConsensusParams(t *testing.T) {
 testHeight := int64(1)
@@ -229,7 +231,7 @@ func TestConsensusParams(t *testing.T) {
 l := log.NewNopLogger()
 d := inspect.New(rpcConfig, blockStoreMock, stateStoreMock, []indexer.EventSink{eventSinkMock}, l)
- ctx, cancel := context.WithCancel(t.Context())
+ ctx := t.Context()
 wg := &sync.WaitGroup{}
 wg.Add(1)
@@ -250,11 +252,12 @@ func TestConsensusParams(t *testing.T) {
 require.NoError(t, err)
 require.Equal(t, params.ConsensusParams.Block.MaxGas, testMaxGas)
- cancel()
- wg.Wait()
+ t.Cleanup(func() {
+ wg.Wait()
- blockStoreMock.AssertExpectations(t)
- stateStoreMock.AssertExpectations(t)
+ blockStoreMock.AssertExpectations(t)
+ stateStoreMock.AssertExpectations(t)
+ })
 }
 func TestBlockResults(t *testing.T) {
@@ -280,7 +283,7 @@ func TestBlockResults(t *testing.T) {
 l := log.NewNopLogger()
 d := inspect.New(rpcConfig, blockStoreMock, stateStoreMock, []indexer.EventSink{eventSinkMock}, l)
- ctx, cancel := context.WithCancel(t.Context())
+ ctx := t.Context()
 wg := &sync.WaitGroup{}
 wg.Add(1)
@@ -301,11 +304,12 @@ func TestBlockResults(t *testing.T) {
 require.NoError(t, err)
 require.Equal(t, res.TotalGasUsed, testGasUsed)
- cancel()
- wg.Wait()
+ t.Cleanup(func() {
+ wg.Wait()
- blockStoreMock.AssertExpectations(t)
- stateStoreMock.AssertExpectations(t)
+ blockStoreMock.AssertExpectations(t)
+ stateStoreMock.AssertExpectations(t)
+ })
 }
 func TestCommit(t *testing.T) {
@@ -328,7 +332,7 @@ func TestCommit(t *testing.T) {
 l := log.NewNopLogger()
 d := inspect.New(rpcConfig, blockStoreMock, stateStoreMock, []indexer.EventSink{eventSinkMock}, l)
- ctx, cancel := context.WithCancel(t.Context())
+ ctx := t.Context()
 wg := &sync.WaitGroup{}
 wg.Add(1)
@@ -350,11 +354,12 @@ func TestCommit(t *testing.T) {
 require.NoError(t, err)
 require.NotNil(t, res)
 require.Equal(t, res.SignedHeader.Commit.Round, testRound)
- cancel()
- wg.Wait()
+ t.Cleanup(func() {
+ wg.Wait()
- blockStoreMock.AssertExpectations(t)
- stateStoreMock.AssertExpectations(t)
+ blockStoreMock.AssertExpectations(t)
+ stateStoreMock.AssertExpectations(t)
+ })
 }
 func TestBlockByHash(t *testing.T) {
@@ -382,7 +387,7 @@ func TestBlockByHash(t *testing.T) {
 l := log.NewNopLogger()
 d := inspect.New(rpcConfig, blockStoreMock, stateStoreMock, []indexer.EventSink{eventSinkMock}, l)
- ctx, cancel := context.WithCancel(t.Context())
+ ctx := t.Context()
 wg := &sync.WaitGroup{}
 wg.Add(1)
@@ -404,11 +409,12 @@ func TestBlockByHash(t *testing.T) {
 require.NoError(t, err)
 require.NotNil(t, res)
 require.Equal(t, []byte(res.BlockID.Hash), testHash)
- cancel()
- wg.Wait()
+ t.Cleanup(func() {
+ wg.Wait()
- blockStoreMock.AssertExpectations(t)
- stateStoreMock.AssertExpectations(t)
+ blockStoreMock.AssertExpectations(t)
+ stateStoreMock.AssertExpectations(t)
+ })
 }
 func TestBlockchain(t *testing.T) {
@@ -435,7 +441,7 @@ func TestBlockchain(t *testing.T) {
 l := log.NewNopLogger()
 d := inspect.New(rpcConfig, blockStoreMock, stateStoreMock, []indexer.EventSink{eventSinkMock}, l)
- ctx, cancel := context.WithCancel(t.Context())
+ ctx := t.Context()
 wg := &sync.WaitGroup{}
 wg.Add(1)
@@ -457,11 +463,12 @@ func TestBlockchain(t *testing.T) {
 require.NoError(t, err)
 require.NotNil(t, res)
 require.Equal(t, testBlockHash, []byte(res.BlockMetas[0].BlockID.Hash))
- cancel()
- wg.Wait()
+ t.Cleanup(func() {
+ wg.Wait()
- blockStoreMock.AssertExpectations(t)
- stateStoreMock.AssertExpectations(t)
+ blockStoreMock.AssertExpectations(t)
+ stateStoreMock.AssertExpectations(t)
+ })
 }
 func TestValidators(t *testing.T) {
@@ -488,7 +495,7 @@ func TestValidators(t *testing.T) {
 l := log.NewNopLogger()
 d := inspect.New(rpcConfig, blockStoreMock, stateStoreMock, []indexer.EventSink{eventSinkMock}, l)
- ctx, cancel := context.WithCancel(t.Context())
+ ctx := t.Context()
 wg := &sync.WaitGroup{}
 wg.Add(1)
@@ -513,11 +520,12 @@ func TestValidators(t *testing.T) {
 require.NoError(t, err)
 require.NotNil(t, res)
 require.Equal(t, testVotingPower, res.Validators[0].VotingPower)
- cancel()
- wg.Wait()
+ t.Cleanup(func() {
+ wg.Wait()
- blockStoreMock.AssertExpectations(t)
- stateStoreMock.AssertExpectations(t)
+ blockStoreMock.AssertExpectations(t)
+ stateStoreMock.AssertExpectations(t)
+ })
 }
 func TestBlockSearch(t *testing.T) {
@@ -547,7 +555,7 @@ func TestBlockSearch(t *testing.T) {
 l := log.NewNopLogger()
 d := inspect.New(rpcConfig, blockStoreMock, stateStoreMock, []indexer.EventSink{eventSinkMock}, l)
- ctx, cancel := context.WithCancel(t.Context())
+ ctx := t.Context()
 wg := &sync.WaitGroup{}
 wg.Add(1)
@@ -573,11 +581,12 @@ func TestBlockSearch(t *testing.T) {
 require.NoError(t, err)
 require.NotNil(t, res)
 require.Equal(t, testBlockHash, []byte(res.Blocks[0].BlockID.Hash))
- cancel()
- wg.Wait()
+ t.Cleanup(func() {
+ wg.Wait()
- blockStoreMock.AssertExpectations(t)
- stateStoreMock.AssertExpectations(t)
+ blockStoreMock.AssertExpectations(t)
+ stateStoreMock.AssertExpectations(t)
+ })
 }
 func requireConnect(t testing.TB, addr string, retries int) {
diff --git a/internal/libs/autofile/autofile_test.go b/internal/libs/autofile/autofile_test.go
index 34c265eab..13936a3a9 100644
--- a/internal/libs/autofile/autofile_test.go
+++ b/internal/libs/autofile/autofile_test.go
@@ -1,7 +1,6 @@
 package autofile
 import (
- "context"
 "os"
 "path/filepath"
 "syscall"
@@ -13,8 +12,7 @@ import (
 )
 func TestSIGHUP(t *testing.T) {
- ctx, cancel := context.WithCancel(t.Context())
- defer cancel()
+ ctx := t.Context()
 origDir, err := os.Getwd()
 require.NoError(t, err)
@@ -102,8 +100,7 @@ func TestSIGHUP(t *testing.T) {
 // }
 func TestAutoFileSize(t *testing.T) {
- ctx, cancel := context.WithCancel(t.Context())
- defer cancel()
+ ctx := t.Context()
 // First, create an AutoFile writing to a tempfile dir
 f, err := os.CreateTemp(t.TempDir(), "sighup_test")
diff --git a/internal/libs/autofile/group_test.go b/internal/libs/autofile/group_test.go
index c67b8ac52..a57eb0b41 100644
--- a/internal/libs/autofile/group_test.go
+++ b/internal/libs/autofile/group_test.go
@@ -44,8 +44,7 @@ func assertGroupInfo(t *testing.T, gInfo GroupInfo, minIndex, maxIndex int, tota
 }
 func TestCheckHeadSizeLimit(t *testing.T) {
- ctx, cancel := context.WithCancel(t.Context())
- defer cancel()
+ ctx := t.Context()
 logger := log.NewNopLogger()
@@ -118,8 +117,7 @@ func TestCheckHeadSizeLimit(t *testing.T) {
 func TestRotateFile(t *testing.T) {
 logger := log.NewNopLogger()
- ctx, cancel := context.WithCancel(t.Context())
- defer cancel()
+ ctx := t.Context()
 g := createTestGroupWithHeadSizeLimit(ctx, t, logger, 0)
 // Create a different temporary directory and move into it, to make sure
@@ -183,8 +181,7 @@ func TestRotateFile(t *testing.T) {
 func TestWrite(t *testing.T) {
 logger := log.NewNopLogger()
- ctx, cancel := context.WithCancel(t.Context())
- defer cancel()
+ ctx := t.Context()
 g := createTestGroupWithHeadSizeLimit(ctx, t, logger, 0)
@@ -211,8 +208,7 @@ func TestWrite(t *testing.T) {
 func TestGroupReaderRead(t *testing.T) {
 logger := log.NewNopLogger()
- ctx, cancel := context.WithCancel(t.Context())
- defer cancel()
+ ctx := t.Context()
 g := createTestGroupWithHeadSizeLimit(ctx, t, logger, 0)
@@ -249,8 +245,7 @@ func TestGroupReaderRead(t *testing.T) {
 func TestGroupReaderRead2(t *testing.T) {
 logger := log.NewNopLogger()
- ctx, cancel := context.WithCancel(t.Context())
- defer cancel()
+ ctx := t.Context()
 g := createTestGroupWithHeadSizeLimit(ctx, t, logger, 0)
@@ -288,8 +283,7 @@ func TestGroupReaderRead2(t *testing.T) {
 func TestMinIndex(t *testing.T) {
 logger := log.NewNopLogger()
- ctx, cancel := context.WithCancel(t.Context())
- defer cancel()
+ ctx := t.Context()
 g := createTestGroupWithHeadSizeLimit(ctx, t, logger, 0)
@@ -301,8 +295,7 @@ func TestMinIndex(t *testing.T) {
 func TestMaxIndex(t *testing.T) {
 logger := log.NewNopLogger()
- ctx, cancel := context.WithCancel(t.Context())
- defer cancel()
+ ctx := t.Context()
 g := createTestGroupWithHeadSizeLimit(ctx, t, logger, 0)
diff --git a/internal/libs/queue/queue_test.go b/internal/libs/queue/queue_test.go
index f0c08bc62..224a8514f 100644
--- a/internal/libs/queue/queue_test.go
+++ b/internal/libs/queue/queue_test.go
@@ -125,8 +125,7 @@ func TestClose(t *testing.T) {
 }
 func TestWait(t *testing.T) {
- ctx, cancel := context.WithCancel(t.Context())
- defer cancel()
+ ctx := t.Context()
 q := mustQueue(t, Options{SoftQuota: 2, HardLimit: 2})
@@ -144,8 +143,7 @@ func TestWait(t *testing.T) {
 // A wait on a non-empty queue should report an item.
t.Run("WaitNonEmpty", func(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() const input = "figgy pudding" q.mustAdd(input) diff --git a/internal/mempool/mempool_bench_test.go b/internal/mempool/mempool_bench_test.go index e1d1d5955..4eac0d51d 100644 --- a/internal/mempool/mempool_bench_test.go +++ b/internal/mempool/mempool_bench_test.go @@ -1,7 +1,6 @@ package mempool import ( - "context" "fmt" "math/rand" "testing" @@ -15,8 +14,7 @@ import ( ) func BenchmarkTxMempool_CheckTx(b *testing.B) { - ctx, cancel := context.WithCancel(b.Context()) - defer cancel() + ctx := b.Context() client := abciclient.NewLocalClient(log.NewNopLogger(), kvstore.NewApplication()) if err := client.Start(ctx); err != nil { diff --git a/internal/mempool/mempool_test.go b/internal/mempool/mempool_test.go index c9795151d..e4987a82d 100644 --- a/internal/mempool/mempool_test.go +++ b/internal/mempool/mempool_test.go @@ -233,8 +233,7 @@ func (e *TestPeerEvictor) Errored(peerID types.NodeID, err error) { } func TestTxMempool_TxsAvailable(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() client := abciclient.NewLocalClient(log.NewNopLogger(), &application{Application: kvstore.NewApplication()}) if err := client.Start(ctx); err != nil { @@ -296,8 +295,7 @@ func TestTxMempool_TxsAvailable(t *testing.T) { } func TestTxMempool_Size(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() client := abciclient.NewLocalClient(log.NewNopLogger(), &application{Application: kvstore.NewApplication()}) if err := client.Start(ctx); err != nil { @@ -330,8 +328,7 @@ func TestTxMempool_Size(t *testing.T) { } func TestTxMempool_Flush(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() client := abciclient.NewLocalClient(log.NewNopLogger(), &application{Application: kvstore.NewApplication()}) if err := client.Start(ctx); err != nil { @@ -364,8 +361,7 @@ func TestTxMempool_Flush(t *testing.T) { } func TestTxMempool_ReapMaxBytesMaxGas(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() gasEstimated := int64(1) // gas estimated set to 1 client := abciclient.NewLocalClient(log.NewNopLogger(), &application{Application: kvstore.NewApplication(), gasEstimated: &gasEstimated}) @@ -460,8 +456,7 @@ func TestTxMempool_ReapMaxBytesMaxGas(t *testing.T) { } func TestTxMempool_ReapMaxBytesMaxGas_FallbackToGasWanted(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() gasEstimated := int64(0) // gas estimated not set so fallback to gas wanted client := abciclient.NewLocalClient(log.NewNopLogger(), &application{Application: kvstore.NewApplication(), gasEstimated: &gasEstimated}) @@ -509,8 +504,7 @@ func TestTxMempool_ReapMaxBytesMaxGas_FallbackToGasWanted(t *testing.T) { } func TestTxMempool_ReapMaxTxs(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() client := abciclient.NewLocalClient(log.NewNopLogger(), &application{Application: kvstore.NewApplication()}) if err := client.Start(ctx); err != nil { @@ -583,8 +577,7 @@ func TestTxMempool_ReapMaxTxs(t *testing.T) { } func TestTxMempool_CheckTxExceedsMaxSize(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() client := abciclient.NewLocalClient(log.NewNopLogger(), &application{Application: 
kvstore.NewApplication()}) if err := client.Start(ctx); err != nil { @@ -608,8 +601,7 @@ func TestTxMempool_CheckTxExceedsMaxSize(t *testing.T) { } func TestTxMempool_Prioritization(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() client := abciclient.NewLocalClient(log.NewNopLogger(), &application{Application: kvstore.NewApplication()}) if err := client.Start(ctx); err != nil { @@ -677,8 +669,7 @@ func TestTxMempool_Prioritization(t *testing.T) { } func TestTxMempool_PendingStoreSize(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() client := abciclient.NewLocalClient(log.NewNopLogger(), &application{Application: kvstore.NewApplication()}) if err := client.Start(ctx); err != nil { @@ -699,8 +690,7 @@ func TestTxMempool_PendingStoreSize(t *testing.T) { } func TestTxMempool_RemoveCacheWhenPendingTxIsFull(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() client := abciclient.NewLocalClient(log.NewNopLogger(), &application{Application: kvstore.NewApplication()}) if err := client.Start(ctx); err != nil { @@ -721,8 +711,7 @@ func TestTxMempool_RemoveCacheWhenPendingTxIsFull(t *testing.T) { } func TestTxMempool_EVMEviction(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() client := abciclient.NewLocalClient(log.NewNopLogger(), &application{Application: kvstore.NewApplication()}) if err := client.Start(ctx); err != nil { @@ -764,8 +753,7 @@ func TestTxMempool_EVMEviction(t *testing.T) { } func TestTxMempool_CheckTxSamePeer(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() client := abciclient.NewLocalClient(log.NewNopLogger(), &application{Application: kvstore.NewApplication()}) if err := client.Start(ctx); err != nil { @@ -788,8 +776,7 @@ func TestTxMempool_CheckTxSamePeer(t *testing.T) { } func TestTxMempool_CheckTxSameSender(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() client := abciclient.NewLocalClient(log.NewNopLogger(), &application{Application: kvstore.NewApplication()}) if err := client.Start(ctx); err != nil { @@ -819,8 +806,7 @@ func TestTxMempool_CheckTxSameSender(t *testing.T) { } func TestTxMempool_ConcurrentTxs(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() client := abciclient.NewLocalClient(log.NewNopLogger(), &application{Application: kvstore.NewApplication()}) if err := client.Start(ctx); err != nil { @@ -892,8 +878,7 @@ func TestTxMempool_ConcurrentTxs(t *testing.T) { } func TestTxMempool_ExpiredTxs_NumBlocks(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() client := abciclient.NewLocalClient(log.NewNopLogger(), &application{Application: kvstore.NewApplication()}) if err := client.Start(ctx); err != nil { @@ -951,9 +936,6 @@ func TestTxMempool_ExpiredTxs_NumBlocks(t *testing.T) { } func TestTxMempool_CheckTxPostCheckError(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() - cases := []struct { name string err error @@ -968,10 +950,8 @@ func TestTxMempool_CheckTxPostCheckError(t *testing.T) { }, } for _, tc := range cases { - testCase := tc - t.Run(testCase.name, func(t *testing.T) { - ctx, cancel := context.WithCancel(ctx) - defer cancel() + t.Run(tc.name, func(t *testing.T) { + ctx := t.Context() client := 
abciclient.NewLocalClient(log.NewNopLogger(), &application{Application: kvstore.NewApplication()}) if err := client.Start(ctx); err != nil { @@ -980,7 +960,7 @@ func TestTxMempool_CheckTxPostCheckError(t *testing.T) { t.Cleanup(client.Wait) postCheckFn := func(_ types.Tx, _ *abci.ResponseCheckTx) error { - return testCase.err + return tc.err } txmp := setup(t, client, 0, WithPostCheck(postCheckFn)) rng := rand.New(rand.NewSource(time.Now().UnixNano())) @@ -990,14 +970,14 @@ func TestTxMempool_CheckTxPostCheckError(t *testing.T) { callback := func(res *abci.ResponseCheckTx) { expectedErrString := "" - if testCase.err != nil { - expectedErrString = testCase.err.Error() + if tc.err != nil { + expectedErrString = tc.err.Error() require.Equal(t, expectedErrString, txmp.postCheck(tx, res).Error()) } else { require.Equal(t, nil, txmp.postCheck(tx, res)) } } - if testCase.err == nil { + if tc.err == nil { require.NoError(t, txmp.CheckTx(ctx, tx, callback, TxInfo{SenderID: 0})) } else { err = txmp.CheckTx(ctx, tx, callback, TxInfo{SenderID: 0}) @@ -1008,8 +988,7 @@ func TestTxMempool_CheckTxPostCheckError(t *testing.T) { } func TestTxMempool_FailedCheckTxCount(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() client := abciclient.NewLocalClient(log.NewNopLogger(), &application{Application: kvstore.NewApplication()}) if err := client.Start(ctx); err != nil { @@ -1050,8 +1029,7 @@ func TestTxMempool_FailedCheckTxCount(t *testing.T) { func TestAppendCheckTxErr(t *testing.T) { // Setup - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() client := abciclient.NewLocalClient(log.NewNopLogger(), &application{Application: kvstore.NewApplication()}) if err := client.Start(ctx); err != nil { @@ -1075,8 +1053,7 @@ func TestAppendCheckTxErr(t *testing.T) { } func TestMempoolExpiration(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() client := abciclient.NewLocalClient(log.NewNopLogger(), &application{Application: kvstore.NewApplication()}) if err := client.Start(ctx); err != nil { diff --git a/internal/mempool/reactor_test.go b/internal/mempool/reactor_test.go index ef3d259d2..d26f2bec3 100644 --- a/internal/mempool/reactor_test.go +++ b/internal/mempool/reactor_test.go @@ -145,8 +145,7 @@ func (rts *reactorTestSuite) waitForTxns(t *testing.T, txs []types.Tx, ids ...ty } func TestReactorBroadcastDoesNotPanic(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() const numNodes = 2 @@ -191,8 +190,7 @@ func TestReactorBroadcastDoesNotPanic(t *testing.T) { func TestReactorBroadcastTxs(t *testing.T) { numTxs := 512 numNodes := 4 - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewNopLogger() @@ -217,8 +215,7 @@ func TestReactorConcurrency(t *testing.T) { numTxs := 10 numNodes := 2 - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewNopLogger() rts := setupReactors(ctx, t, logger, numNodes, 0) @@ -276,8 +273,7 @@ func TestReactorNoBroadcastToSender(t *testing.T) { numTxs := 1000 numNodes := 2 - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewNopLogger() rts := setupReactors(ctx, t, logger, numNodes, uint(numTxs)) @@ -301,8 +297,7 @@ func TestReactor_MaxTxBytes(t *testing.T) { numNodes := 2 cfg := config.TestConfig() - ctx, cancel := context.WithCancel(t.Context()) - defer 
cancel() + ctx := t.Context() logger := log.NewNopLogger() @@ -339,8 +334,7 @@ func TestDontExhaustMaxActiveIDs(t *testing.T) { // we're creating a single node network, but not starting the // network. - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewNopLogger() rts := setupReactors(ctx, t, logger, 1, MaxActiveIDs+1) @@ -392,8 +386,7 @@ func TestBroadcastTxForPeerStopsWhenPeerStops(t *testing.T) { t.Skip("skipping test in short mode") } - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewNopLogger() diff --git a/internal/p2p/address_test.go b/internal/p2p/address_test.go index 4c4bff025..7c6fdb9bc 100644 --- a/internal/p2p/address_test.go +++ b/internal/p2p/address_test.go @@ -293,7 +293,7 @@ func TestNodeAddress_Resolve(t *testing.T) { require.NoError(t, err) want := &p2p.Endpoint{Protocol: "tcp", Port: 80, Path: "/path"} - require.True(t, len(endpoints)>0) + require.True(t, len(endpoints) > 0) for _, got := range endpoints { require.True(t, got.IP.IsLoopback()) // Any loopback address is acceptable, so ignore it in comparison. diff --git a/internal/p2p/channel_test.go b/internal/p2p/channel_test.go index 35e178670..4bbe178ac 100644 --- a/internal/p2p/channel_test.go +++ b/internal/p2p/channel_test.go @@ -33,16 +33,14 @@ func testChannel(size int) (*channelInternal, *Channel) { func TestChannel(t *testing.T) { t.Cleanup(leaktest.Check(t)) - bctx, bcancel := context.WithCancel(t.Context()) - defer bcancel() - testCases := []struct { Name string - Case func(context.Context, *testing.T) + Case func(*testing.T) }{ { Name: "Send", - Case: func(ctx context.Context, t *testing.T) { + Case: func(t *testing.T) { + ctx := t.Context() ins, ch := testChannel(1) require.NoError(t, ch.Send(ctx, Envelope{From: "kip", To: "merlin"})) @@ -54,7 +52,8 @@ func TestChannel(t *testing.T) { }, { Name: "SendError", - Case: func(ctx context.Context, t *testing.T) { + Case: func(t *testing.T) { + ctx := t.Context() ins, ch := testChannel(1) require.NoError(t, ch.SendError(ctx, PeerError{NodeID: "kip", Err: errors.New("merlin")})) @@ -66,7 +65,8 @@ func TestChannel(t *testing.T) { }, { Name: "SendWithCanceledContext", - Case: func(ctx context.Context, t *testing.T) { + Case: func(t *testing.T) { + ctx := t.Context() _, ch := testChannel(0) cctx, ccancel := context.WithCancel(ctx) ccancel() @@ -75,7 +75,8 @@ func TestChannel(t *testing.T) { }, { Name: "SendErrorWithCanceledContext", - Case: func(ctx context.Context, t *testing.T) { + Case: func(t *testing.T) { + ctx := t.Context() _, ch := testChannel(0) cctx, ccancel := context.WithCancel(ctx) ccancel() @@ -85,7 +86,8 @@ func TestChannel(t *testing.T) { }, { Name: "ReceiveEmptyIteratorBlocks", - Case: func(ctx context.Context, t *testing.T) { + Case: func(t *testing.T) { + ctx := t.Context() _, ch := testChannel(1) iter := ch.Receive(ctx) require.NotNil(t, iter) @@ -107,7 +109,8 @@ func TestChannel(t *testing.T) { }, { Name: "ReceiveWithData", - Case: func(ctx context.Context, t *testing.T) { + Case: func(t *testing.T) { + ctx := t.Context() ins, ch := testChannel(1) ins.In <- Envelope{From: "kip", To: "merlin"} iter := ch.Receive(ctx) @@ -121,7 +124,8 @@ func TestChannel(t *testing.T) { }, { Name: "ReceiveWithCanceledContext", - Case: func(ctx context.Context, t *testing.T) { + Case: func(t *testing.T) { + ctx := t.Context() _, ch := testChannel(0) cctx, ccancel := context.WithCancel(ctx) ccancel() @@ -134,7 +138,8 @@ func TestChannel(t *testing.T) { }, { 
Name: "IteratorWithCanceledContext", - Case: func(ctx context.Context, t *testing.T) { + Case: func(t *testing.T) { + ctx := t.Context() _, ch := testChannel(0) iter := ch.Receive(ctx) @@ -148,7 +153,8 @@ func TestChannel(t *testing.T) { }, { Name: "IteratorCanceledAfterFirstUseBecomesNil", - Case: func(ctx context.Context, t *testing.T) { + Case: func(t *testing.T) { + ctx := t.Context() ins, ch := testChannel(1) ins.In <- Envelope{From: "kip", To: "merlin"} @@ -170,7 +176,8 @@ func TestChannel(t *testing.T) { }, { Name: "IteratorMultipleNextCalls", - Case: func(ctx context.Context, t *testing.T) { + Case: func(t *testing.T) { + ctx := t.Context() ins, ch := testChannel(1) ins.In <- Envelope{From: "kip", To: "merlin"} @@ -189,7 +196,8 @@ func TestChannel(t *testing.T) { }, { Name: "IteratorProducesNilObjectBeforeNext", - Case: func(ctx context.Context, t *testing.T) { + Case: func(t *testing.T) { + ctx := t.Context() ins, ch := testChannel(1) iter := ch.Receive(ctx) @@ -211,11 +219,7 @@ func TestChannel(t *testing.T) { for _, tc := range testCases { t.Run(tc.Name, func(t *testing.T) { t.Cleanup(leaktest.Check(t)) - - ctx, cancel := context.WithCancel(bctx) - defer cancel() - - tc.Case(ctx, t) + tc.Case(t) }) } } diff --git a/internal/p2p/conn/connection_test.go b/internal/p2p/conn/connection_test.go index efac9c519..ca1e7159b 100644 --- a/internal/p2p/conn/connection_test.go +++ b/internal/p2p/conn/connection_test.go @@ -51,8 +51,7 @@ func TestMConnectionSendFlushStop(t *testing.T) { server, client := net.Pipe() t.Cleanup(closeAll(t, client, server)) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() clientConn := createTestMConnection(log.NewNopLogger(), client) err := clientConn.Start(ctx) @@ -88,8 +87,7 @@ func TestMConnectionSend(t *testing.T) { server, client := net.Pipe() t.Cleanup(closeAll(t, client, server)) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() mconn := createTestMConnection(log.NewNopLogger(), client) err := mconn.Start(ctx) @@ -135,8 +133,7 @@ func TestMConnectionReceive(t *testing.T) { } logger := log.NewNopLogger() - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() mconn1 := createMConnectionWithCallbacks(logger, client, onReceive, onError) err := mconn1.Start(ctx) @@ -165,8 +162,7 @@ func TestMConnectionWillEventuallyTimeout(t *testing.T) { server, client := net.Pipe() t.Cleanup(closeAll(t, client, server)) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() mconn := createMConnectionWithCallbacks(log.NewNopLogger(), client, nil, nil) err := mconn.Start(ctx) @@ -221,8 +217,7 @@ func TestMConnectionMultiplePongsInTheBeginning(t *testing.T) { } } - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() mconn := createMConnectionWithCallbacks(log.NewNopLogger(), client, onReceive, onError) err := mconn.Start(ctx) @@ -279,8 +274,7 @@ func TestMConnectionMultiplePings(t *testing.T) { case <-ctx.Done(): } } - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() mconn := createMConnectionWithCallbacks(log.NewNopLogger(), client, onReceive, onError) err := mconn.Start(ctx) @@ -336,8 +330,7 @@ func TestMConnectionPingPongs(t *testing.T) { } } - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() mconn := createMConnectionWithCallbacks(log.NewNopLogger(), client, onReceive, onError) err := mconn.Start(ctx) @@ -395,8 +388,7 @@ func 
TestMConnectionStopsAndReturnsError(t *testing.T) { case <-ctx.Done(): } } - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() mconn := createMConnectionWithCallbacks(log.NewNopLogger(), client, onReceive, onError) err := mconn.Start(ctx) @@ -466,8 +458,7 @@ func expectSend(ch chan struct{}) bool { } func TestMConnectionReadErrorBadEncoding(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() chOnErr := make(chan struct{}) mconnClient, mconnServer := newClientAndServerConnsForReadErrors(ctx, t, chOnErr) @@ -482,8 +473,7 @@ func TestMConnectionReadErrorBadEncoding(t *testing.T) { } func TestMConnectionReadErrorUnknownChannel(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() chOnErr := make(chan struct{}) mconnClient, mconnServer := newClientAndServerConnsForReadErrors(ctx, t, chOnErr) @@ -504,8 +494,7 @@ func TestMConnectionReadErrorLongMessage(t *testing.T) { chOnErr := make(chan struct{}) chOnRcv := make(chan struct{}) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() mconnClient, mconnServer := newClientAndServerConnsForReadErrors(ctx, t, chOnErr) t.Cleanup(waitAll(mconnClient, mconnServer)) @@ -544,8 +533,7 @@ func TestMConnectionReadErrorLongMessage(t *testing.T) { } func TestMConnectionReadErrorUnknownMsgType(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() chOnErr := make(chan struct{}) mconnClient, mconnServer := newClientAndServerConnsForReadErrors(ctx, t, chOnErr) @@ -560,8 +548,7 @@ func TestMConnectionReadErrorUnknownMsgType(t *testing.T) { func TestMConnectionTrySend(t *testing.T) { server, client := net.Pipe() t.Cleanup(closeAll(t, client, server)) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() mconn := createTestMConnection(log.NewNopLogger(), client) err := mconn.Start(ctx) @@ -609,8 +596,7 @@ func TestMConnectionChannelOverflow(t *testing.T) { chOnErr := make(chan struct{}) chOnRcv := make(chan struct{}) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() mconnClient, mconnServer := newClientAndServerConnsForReadErrors(ctx, t, chOnErr) t.Cleanup(waitAll(mconnClient, mconnServer)) diff --git a/internal/p2p/peermanager_scoring_test.go b/internal/p2p/peermanager_scoring_test.go index b8034ca47..1b3048f6a 100644 --- a/internal/p2p/peermanager_scoring_test.go +++ b/internal/p2p/peermanager_scoring_test.go @@ -1,7 +1,6 @@ package p2p import ( - "context" "strings" "testing" "time" @@ -31,8 +30,7 @@ func TestPeerScoring(t *testing.T) { require.NoError(t, err) require.True(t, added) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() t.Run("Synchronous", func(t *testing.T) { // update the manager and make sure it's correct diff --git a/internal/p2p/peermanager_test.go b/internal/p2p/peermanager_test.go index e0040ab73..2933d5afa 100644 --- a/internal/p2p/peermanager_test.go +++ b/internal/p2p/peermanager_test.go @@ -194,7 +194,7 @@ func TestNewPeerManager_Persistence(t *testing.T) { }, peerManager.Scores()) // Introduce a dial failure and persistent peer score should be reduced by one - ctx, _ := context.WithCancel(t.Context()) + ctx := t.Context() peerManager.DialFailed(ctx, bAddresses[0]) require.Equal(t, map[types.NodeID]p2p.PeerScore{ aID: 0, @@ -344,8 +344,7 @@ func TestPeerManager_Add(t *testing.T) { } func TestPeerManager_DialNext(t 
*testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} @@ -370,8 +369,7 @@ func TestPeerManager_DialNext(t *testing.T) { } func TestPeerManager_DialNext_Retry(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} @@ -388,7 +386,7 @@ func TestPeerManager_DialNext_Retry(t *testing.T) { // Do five dial retries (six dials total). The retry time should double for // each failure. At the forth retry, MaxRetryTime should kick in. - ctx, cancel = context.WithTimeout(ctx, 5*time.Second) + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() for i := 0; i < 3; i++ { @@ -405,8 +403,7 @@ func TestPeerManager_DialNext_Retry(t *testing.T) { } func TestPeerManagerDeleteOnMaxRetries(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} @@ -423,7 +420,7 @@ func TestPeerManagerDeleteOnMaxRetries(t *testing.T) { // Do five dial retries (six dials total). The retry time should double for // each failure. At the forth retry, MaxRetryTime should kick in. - ctx, cancel = context.WithTimeout(ctx, 5*time.Second) + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() for i := 0; i < 4; i++ { @@ -444,8 +441,7 @@ func TestPeerManagerDeleteOnMaxRetries(t *testing.T) { } func TestPeerManager_DialNext_WakeOnDialFailed(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() peerManager, err := p2p.NewPeerManager(log.NewNopLogger(), selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{ MaxConnected: 1, @@ -489,8 +485,7 @@ func TestPeerManager_DialNext_WakeOnDialFailed(t *testing.T) { } func TestPeerManager_DialNext_WakeOnDialFailedRetry(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() options := p2p.PeerManagerOptions{MinRetryTime: 200 * time.Millisecond} peerManager, err := p2p.NewPeerManager(log.NewNopLogger(), selfID, dbm.NewMemDB(), options, p2p.NopMetrics()) @@ -510,7 +505,7 @@ func TestPeerManager_DialNext_WakeOnDialFailedRetry(t *testing.T) { // The retry timer should unblock DialNext and make a available again after // the retry time passes. 
- ctx, cancel = context.WithTimeout(ctx, 3*time.Second) + ctx, cancel := context.WithTimeout(ctx, 3*time.Second) defer cancel() dial, err = peerManager.DialNext(ctx) require.NoError(t, err) @@ -519,8 +514,7 @@ func TestPeerManager_DialNext_WakeOnDialFailedRetry(t *testing.T) { } func TestPeerManager_DialNext_WakeOnDisconnected(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} @@ -544,7 +538,7 @@ func TestPeerManager_DialNext_WakeOnDisconnected(t *testing.T) { peerManager.Disconnected(dctx, a.NodeID) }() - ctx, cancel = context.WithTimeout(ctx, 3*time.Second) + ctx, cancel := context.WithTimeout(ctx, 3*time.Second) defer cancel() dial, err = peerManager.DialNext(ctx) require.NoError(t, err) @@ -588,8 +582,7 @@ func TestPeerManager_TryDialNext_MaxConnected(t *testing.T) { } func TestPeerManager_TryDialNext_MaxConnectedUpgrade(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} b := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("b", 40))} @@ -759,8 +752,7 @@ func TestPeerManager_TryDialNext_DialingConnected(t *testing.T) { } func TestPeerManager_TryDialNext_Multiple(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() aID := types.NodeID(strings.Repeat("a", 40)) bID := types.NodeID(strings.Repeat("b", 40)) @@ -811,8 +803,7 @@ func TestPeerManager_DialFailed(t *testing.T) { require.NoError(t, err) require.True(t, added) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() // Dialing and then calling DialFailed with a different address (same // NodeID) should unmark as dialing and allow us to dial the other address @@ -838,8 +829,7 @@ func TestPeerManager_DialFailed(t *testing.T) { } func TestPeerManager_DialFailed_UnreservePeer(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} b := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("b", 40))} @@ -1059,8 +1049,7 @@ func TestPeerManager_Dialed_Upgrade(t *testing.T) { } func TestPeerManager_Dialed_UpgradeEvenLower(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} b := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("b", 40))} @@ -1116,8 +1105,7 @@ func TestPeerManager_Dialed_UpgradeEvenLower(t *testing.T) { } func TestPeerManager_Dialed_UpgradeNoEvict(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} b := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("b", 40))} @@ -1287,8 +1275,7 @@ func TestPeerManager_Accepted_MaxConnectedUpgrade(t *testing.T) { } func TestPeerManager_Accepted_Upgrade(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} b := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("b", 40))} @@ -1383,8 +1370,7 
@@ func TestPeerManager_Ready(t *testing.T) { a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} b := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("b", 40))} - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() peerManager, err := p2p.NewPeerManager(log.NewNopLogger(), selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{}, p2p.NopMetrics()) require.NoError(t, err) @@ -1417,8 +1403,7 @@ func TestPeerManager_Ready(t *testing.T) { } func TestPeerManager_Ready_Channels(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() pm, err := p2p.NewPeerManager(log.NewNopLogger(), selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{}, p2p.NopMetrics()) require.NoError(t, err) @@ -1441,8 +1426,7 @@ func TestPeerManager_Ready_Channels(t *testing.T) { // See TryEvictNext for most tests, this just tests blocking behavior. func TestPeerManager_EvictNext(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} @@ -1470,15 +1454,13 @@ func TestPeerManager_EvictNext(t *testing.T) { // Since there are no more peers to evict, the next call should block. timeoutCtx, cancel = context.WithTimeout(ctx, 100*time.Millisecond) - defer cancel() _, err = peerManager.EvictNext(timeoutCtx) require.Error(t, err) require.Equal(t, context.DeadlineExceeded, err) } func TestPeerManager_EvictNext_WakeOnError(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} @@ -1498,7 +1480,7 @@ func TestPeerManager_EvictNext_WakeOnError(t *testing.T) { }() // This will block until peer errors above. - ctx, cancel = context.WithTimeout(ctx, 3*time.Second) + ctx, cancel := context.WithTimeout(ctx, 3*time.Second) defer cancel() evict, err := peerManager.EvictNext(ctx) require.NoError(t, err) @@ -1506,8 +1488,7 @@ func TestPeerManager_EvictNext_WakeOnError(t *testing.T) { } func TestPeerManager_EvictNext_WakeOnUpgradeDialed(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} b := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("b", 40))} @@ -1539,7 +1520,7 @@ func TestPeerManager_EvictNext_WakeOnUpgradeDialed(t *testing.T) { }() // This will block until peer is upgraded above. - ctx, cancel = context.WithTimeout(ctx, 3*time.Second) + ctx, cancel := context.WithTimeout(ctx, 3*time.Second) defer cancel() evict, err := peerManager.EvictNext(ctx) require.NoError(t, err) @@ -1547,8 +1528,7 @@ func TestPeerManager_EvictNext_WakeOnUpgradeDialed(t *testing.T) { } func TestPeerManager_EvictNext_WakeOnUpgradeAccepted(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} b := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("b", 40))} @@ -1574,15 +1554,14 @@ func TestPeerManager_EvictNext_WakeOnUpgradeAccepted(t *testing.T) { }() // This will block until peer is upgraded above. 
- ctx, cancel = context.WithTimeout(ctx, 3*time.Second) + ctx, cancel := context.WithTimeout(ctx, 3*time.Second) defer cancel() evict, err := peerManager.EvictNext(ctx) require.NoError(t, err) require.Equal(t, a.NodeID, evict) } func TestPeerManager_TryEvictNext(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} @@ -1625,8 +1604,7 @@ func TestPeerManager_Disconnected(t *testing.T) { peerManager, err := p2p.NewPeerManager(log.NewNopLogger(), selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{}, p2p.NopMetrics()) require.NoError(t, err) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() sub := peerManager.Subscribe(ctx) @@ -1676,8 +1654,7 @@ func TestPeerManager_Disconnected(t *testing.T) { } func TestPeerManager_Errored(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} @@ -1715,8 +1692,7 @@ func TestPeerManager_Errored(t *testing.T) { } func TestPeerManager_Subscribe(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} @@ -1778,8 +1754,7 @@ func TestPeerManager_Subscribe(t *testing.T) { } func TestPeerManager_Subscribe_Close(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} @@ -1799,14 +1774,14 @@ func TestPeerManager_Subscribe_Close(t *testing.T) { require.Equal(t, p2p.PeerUpdate{NodeID: a.NodeID, Status: p2p.PeerStatusUp}, <-sub.Updates()) // Closing the subscription should not send us the disconnected update. - cancel() - peerManager.Disconnected(ctx, a.NodeID) - require.Empty(t, sub.Updates()) + t.Cleanup(func() { + peerManager.Disconnected(ctx, a.NodeID) + require.Empty(t, sub.Updates()) + }) } func TestPeerManager_Subscribe_Broadcast(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() t.Cleanup(leaktest.Check(t)) @@ -1855,8 +1830,7 @@ func TestPeerManager_Close(t *testing.T) { // leaktest will check that spawned goroutines are closed. 
t.Cleanup(leaktest.CheckTimeout(t, 1*time.Second)) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} diff --git a/internal/p2p/pex/reactor_test.go b/internal/p2p/pex/reactor_test.go index 67d3f5a59..49e860130 100644 --- a/internal/p2p/pex/reactor_test.go +++ b/internal/p2p/pex/reactor_test.go @@ -33,8 +33,7 @@ const ( ) func TestReactorBasic(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() // start a network with one mock reactor and one "real" reactor testNet := setupNetwork(ctx, t, testOptions{ MockNodes: 1, @@ -53,8 +52,7 @@ func TestReactorBasic(t *testing.T) { } func TestReactorConnectFullNetwork(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() testNet := setupNetwork(ctx, t, testOptions{ TotalNodes: 4, @@ -72,8 +70,7 @@ func TestReactorConnectFullNetwork(t *testing.T) { } func TestReactorSendsRequestsTooOften(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() r := setupSingle(ctx, t) @@ -103,8 +100,7 @@ func TestReactorSendsRequestsTooOften(t *testing.T) { func TestReactorSendsResponseWithoutRequest(t *testing.T) { t.Skip("This test needs updated https://github.com/tendermint/tendermint/issue/7634") - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() testNet := setupNetwork(ctx, t, testOptions{ MockNodes: 1, @@ -125,8 +121,7 @@ func TestReactorSendsResponseWithoutRequest(t *testing.T) { func TestReactorNeverSendsTooManyPeers(t *testing.T) { t.Skip("This test needs updated https://github.com/tendermint/tendermint/issue/7634") - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() testNet := setupNetwork(ctx, t, testOptions{ MockNodes: 1, @@ -148,8 +143,7 @@ func TestReactorNeverSendsTooManyPeers(t *testing.T) { } func TestReactorErrorsOnReceivingTooManyPeers(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() r := setupSingle(ctx, t) peer := p2p.NodeAddress{Protocol: p2p.MemoryProtocol, NodeID: randomNodeID()} @@ -195,8 +189,7 @@ func TestReactorErrorsOnReceivingTooManyPeers(t *testing.T) { } func TestReactorSmallPeerStoreInALargeNetwork(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() testNet := setupNetwork(ctx, t, testOptions{ TotalNodes: 8, @@ -219,8 +212,7 @@ func TestReactorSmallPeerStoreInALargeNetwork(t *testing.T) { } func TestReactorLargePeerStoreInASmallNetwork(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() testNet := setupNetwork(ctx, t, testOptions{ TotalNodes: 3, @@ -240,8 +232,7 @@ func TestReactorLargePeerStoreInASmallNetwork(t *testing.T) { func TestReactorWithNetworkGrowth(t *testing.T) { t.Skip("This test needs updated https://github.com/tendermint/tendermint/issue/7634") - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() testNet := setupNetwork(ctx, t, testOptions{ TotalNodes: 5, diff --git a/internal/p2p/pqueue_test.go b/internal/p2p/pqueue_test.go index a9c32e95e..614954589 100644 --- a/internal/p2p/pqueue_test.go +++ b/internal/p2p/pqueue_test.go @@ -1,7 +1,6 @@ package p2p import ( - "context" "testing" "time" @@ -26,8 +25,7 @@ func TestCloseWhileDequeueFull(t *testing.T) { } } - ctx, cancel := 
context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() go pqueue.process(ctx) diff --git a/internal/p2p/router_filter_test.go b/internal/p2p/router_filter_test.go index 1c879577a..5b1d7219a 100644 --- a/internal/p2p/router_filter_test.go +++ b/internal/p2p/router_filter_test.go @@ -13,8 +13,7 @@ import ( ) func TestConnectionFiltering(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewNopLogger() filterByIPCount := 0 diff --git a/internal/p2p/router_init_test.go b/internal/p2p/router_init_test.go index 5daa5fbdf..31d06338f 100644 --- a/internal/p2p/router_init_test.go +++ b/internal/p2p/router_init_test.go @@ -1,7 +1,6 @@ package p2p import ( - "context" "os" "testing" @@ -12,8 +11,7 @@ import ( ) func TestRouter_ConstructQueueFactory(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() t.Run("ValidateOptionsPopulatesDefaultQueue", func(t *testing.T) { opts := RouterOptions{} diff --git a/internal/p2p/router_test.go b/internal/p2p/router_test.go index e41df42c2..8cd85abec 100644 --- a/internal/p2p/router_test.go +++ b/internal/p2p/router_test.go @@ -41,8 +41,7 @@ func echoReactor(ctx context.Context, channel *p2p.Channel) { } func TestRouter_Network(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() t.Cleanup(leaktest.Check(t)) @@ -98,8 +97,7 @@ func TestRouter_Network(t *testing.T) { func TestRouter_Channel_Basic(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() // Set up a router with no transports (so no peers). peerManager, err := p2p.NewPeerManager(log.NewNopLogger(), selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{}, p2p.NopMetrics()) @@ -163,8 +161,7 @@ func TestRouter_Channel_Basic(t *testing.T) { // Channel tests are hairy to mock, so we use an in-memory network instead. func TestRouter_Channel_SendReceive(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() t.Cleanup(leaktest.Check(t)) @@ -227,8 +224,7 @@ func TestRouter_Channel_SendReceive(t *testing.T) { func TestRouter_Channel_Broadcast(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() // Create a test network and open a channel on all nodes. network := p2ptest.MakeNetwork(ctx, t, p2ptest.NetworkOptions{NumNodes: 4}) @@ -258,8 +254,7 @@ func TestRouter_Channel_Broadcast(t *testing.T) { func TestRouter_Channel_Wrapper(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() // Create a test network and open a channel on all nodes. network := p2ptest.MakeNetwork(ctx, t, p2ptest.NetworkOptions{NumNodes: 2}) @@ -328,8 +323,7 @@ func (w *wrapperMessage) Unwrap() (proto.Message, error) { func TestRouter_Channel_Error(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() // Create a test network and open a channel on all nodes. 
network := p2ptest.MakeNetwork(ctx, t, p2ptest.NetworkOptions{NumNodes: 3}) @@ -371,19 +365,15 @@ func TestRouter_AcceptPeers(t *testing.T) { }, } - bctx, bcancel := context.WithCancel(t.Context()) - defer bcancel() - for name, tc := range testcases { tc := tc t.Run(name, func(t *testing.T) { - ctx, cancel := context.WithCancel(bctx) - defer cancel() + ctx := t.Context() t.Cleanup(leaktest.Check(t)) // Set up a mock transport that handshakes. - connCtx, connCancel := context.WithCancel(t.Context()) + connCtx, connCancel := context.WithCancel(ctx) mockConnection := &mocks.Connection{} mockConnection.On("String").Maybe().Return("mock") mockConnection.On("Handshake", mock.Anything, selfInfo, selfKey). @@ -450,8 +440,7 @@ func TestRouter_AcceptPeers_Errors(t *testing.T) { t.Run(err.Error(), func(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() // Set up a mock transport that returns io.EOF once, which should prevent // the router from calling Accept again. @@ -492,8 +481,7 @@ func TestRouter_AcceptPeers_Errors(t *testing.T) { func TestRouter_AcceptPeers_HeadOfLineBlocking(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() // Set up a mock transport that returns a connection that blocks during the // handshake. It should be able to accept several of these in parallel, i.e. @@ -573,21 +561,17 @@ func TestRouter_DialPeers(t *testing.T) { }, } - bctx, bcancel := context.WithCancel(t.Context()) - defer bcancel() - for name, tc := range testcases { tc := tc t.Run(name, func(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(bctx) - defer cancel() + ctx := t.Context() address := p2p.NodeAddress{Protocol: "mock", NodeID: tc.dialID} endpoint := &p2p.Endpoint{Protocol: "mock", Path: string(tc.dialID)} // Set up a mock transport that handshakes. - connCtx, connCancel := context.WithCancel(t.Context()) + connCtx, connCancel := context.WithCancel(ctx) defer connCancel() mockConnection := &mocks.Connection{} mockConnection.On("String").Maybe().Return("mock") @@ -665,8 +649,7 @@ func TestRouter_DialPeers(t *testing.T) { func TestRouter_DialPeers_Parallel(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() a := p2p.NodeAddress{Protocol: "mock", NodeID: types.NodeID(strings.Repeat("a", 40))} b := p2p.NodeAddress{Protocol: "mock", NodeID: types.NodeID(strings.Repeat("b", 40))} @@ -754,8 +737,7 @@ func TestRouter_DialPeers_Parallel(t *testing.T) { func TestRouter_EvictPeers(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() // Set up a mock transport that we can evict. 
closeCh := make(chan time.Time) @@ -820,8 +802,7 @@ func TestRouter_EvictPeers(t *testing.T) { func TestRouter_ChannelCompatability(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() incompatiblePeer := types.NodeInfo{ NodeID: peerID, @@ -872,8 +853,7 @@ func TestRouter_ChannelCompatability(t *testing.T) { func TestRouter_DontSendOnInvalidChannel(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() peer := types.NodeInfo{ NodeID: peerID, diff --git a/internal/p2p/rqueue_test.go b/internal/p2p/rqueue_test.go index 96ea4a77d..4e5e4c8bf 100644 --- a/internal/p2p/rqueue_test.go +++ b/internal/p2p/rqueue_test.go @@ -1,14 +1,12 @@ package p2p import ( - "context" "testing" "time" ) func TestSimpleQueue(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() // set up a small queue with very small buffers so we can // watch it shed load, then send a bunch of messages to the diff --git a/internal/p2p/transport_mconn_test.go b/internal/p2p/transport_mconn_test.go index f550832c6..8830902f9 100644 --- a/internal/p2p/transport_mconn_test.go +++ b/internal/p2p/transport_mconn_test.go @@ -1,7 +1,6 @@ package p2p_test import ( - "context" "io" "net" "testing" @@ -50,8 +49,7 @@ func TestMConnTransport_AcceptBeforeListen(t *testing.T) { t.Cleanup(func() { _ = transport.Close() }) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() _, err := transport.Accept(ctx) require.Error(t, err) @@ -59,8 +57,7 @@ func TestMConnTransport_AcceptBeforeListen(t *testing.T) { } func TestMConnTransport_AcceptMaxAcceptedConnections(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() transport := p2p.NewMConnTransport( log.NewNopLogger(), @@ -129,8 +126,7 @@ func TestMConnTransport_AcceptMaxAcceptedConnections(t *testing.T) { } func TestMConnTransport_Listen(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() testcases := []struct { endpoint *p2p.Endpoint @@ -195,8 +191,7 @@ func TestMConnTransport_Listen(t *testing.T) { go func() { // Dialing the endpoint should work. var err error - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() peerConn, err = transport.Dial(ctx, endpoint) require.NoError(t, err) diff --git a/internal/p2p/transport_test.go b/internal/p2p/transport_test.go index b461e2d0d..e939ca116 100644 --- a/internal/p2p/transport_test.go +++ b/internal/p2p/transport_test.go @@ -25,26 +25,21 @@ var testTransports = map[string]transportFactory{} // withTransports is a test helper that runs a test against all transports // registered in testTransports. -func withTransports(ctx context.Context, t *testing.T, tester func(context.Context, *testing.T, transportFactory)) { +func withTransports(t *testing.T, tester func(*testing.T, transportFactory)) { t.Helper() for name, transportFactory := range testTransports { transportFactory := transportFactory t.Run(name, func(t *testing.T) { t.Cleanup(leaktest.Check(t)) - tctx, cancel := context.WithCancel(ctx) - defer cancel() - - tester(tctx, t, transportFactory) + tester(t, transportFactory) }) } } func TestTransport_AcceptClose(t *testing.T) { // Just test accept unblock on close, happy path is tested widely elsewhere. 
- ctx, cancel := context.WithCancel(t.Context()) - defer cancel() - - withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) { + withTransports(t, func(t *testing.T, makeTransport transportFactory) { + ctx := t.Context() a := makeTransport(t) opctx, opcancel := context.WithTimeout(ctx, 200*time.Millisecond) defer opcancel() @@ -82,10 +77,8 @@ func TestTransport_DialEndpoints(t *testing.T) { {[]byte{1, 2, 3, 4, 5}, false}, } - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() - - withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) { + withTransports(t, func(t *testing.T, makeTransport transportFactory) { + ctx := t.Context() a := makeTransport(t) endpoint, err := a.Endpoint() require.NoError(t, err) @@ -160,11 +153,9 @@ func TestTransport_DialEndpoints(t *testing.T) { } func TestTransport_Dial(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() - // Most just tests dial failures, happy path is tested widely elsewhere. - withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) { + withTransports(t, func(t *testing.T, makeTransport transportFactory) { + ctx := t.Context() a := makeTransport(t) b := makeTransport(t) @@ -205,10 +196,8 @@ func TestTransport_Dial(t *testing.T) { } func TestTransport_Endpoints(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() - withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) { + withTransports(t, func(t *testing.T, makeTransport transportFactory) { a := makeTransport(t) b := makeTransport(t) @@ -237,10 +226,7 @@ func TestTransport_Endpoints(t *testing.T) { } func TestTransport_Protocols(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() - - withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) { + withTransports(t, func(t *testing.T, makeTransport transportFactory) { a := makeTransport(t) protocols := a.Protocols() endpoint, err := a.Endpoint() @@ -253,20 +239,15 @@ func TestTransport_Protocols(t *testing.T) { } func TestTransport_String(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() - - withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) { + withTransports(t, func(t *testing.T, makeTransport transportFactory) { a := makeTransport(t) require.NotEmpty(t, a.String()) }) } func TestConnection_Handshake(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() - - withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) { + withTransports(t, func(t *testing.T, makeTransport transportFactory) { + ctx := t.Context() a := makeTransport(t) b := makeTransport(t) ab, ba := dialAccept(ctx, t, a, b) @@ -317,10 +298,8 @@ func TestConnection_Handshake(t *testing.T) { } func TestConnection_HandshakeCancel(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() - - withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) { + withTransports(t, func(t *testing.T, makeTransport transportFactory) { + ctx := t.Context() a := makeTransport(t) b := makeTransport(t) @@ -337,7 +316,6 @@ func TestConnection_HandshakeCancel(t *testing.T) { // Handshake should error on context timeout. 
ab, ba = dialAccept(ctx, t, a, b) timeoutCtx, cancel = context.WithTimeout(ctx, 200*time.Millisecond) - defer cancel() _, _, err = ab.Handshake(timeoutCtx, types.NodeInfo{}, ed25519.GenPrivKey()) require.Error(t, err) require.Equal(t, context.DeadlineExceeded, err) @@ -347,10 +325,8 @@ func TestConnection_HandshakeCancel(t *testing.T) { } func TestConnection_FlushClose(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() - - withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) { + withTransports(t, func(t *testing.T, makeTransport transportFactory) { + ctx := t.Context() a := makeTransport(t) b := makeTransport(t) ab, _ := dialAcceptHandshake(ctx, t, a, b) @@ -368,10 +344,8 @@ func TestConnection_FlushClose(t *testing.T) { } func TestConnection_LocalRemoteEndpoint(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() - - withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) { + withTransports(t, func(t *testing.T, makeTransport transportFactory) { + ctx := t.Context() a := makeTransport(t) b := makeTransport(t) ab, ba := dialAcceptHandshake(ctx, t, a, b) @@ -385,10 +359,9 @@ func TestConnection_LocalRemoteEndpoint(t *testing.T) { } func TestConnection_SendReceive(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() - withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) { + withTransports(t, func(t *testing.T, makeTransport transportFactory) { + ctx := t.Context() a := makeTransport(t) b := makeTransport(t) ab, ba := dialAcceptHandshake(ctx, t, a, b) @@ -446,10 +419,8 @@ func TestConnection_SendReceive(t *testing.T) { } func TestConnection_String(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() - - withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) { + withTransports(t, func(t *testing.T, makeTransport transportFactory) { + ctx := t.Context() a := makeTransport(t) b := makeTransport(t) ab, _ := dialAccept(ctx, t, a, b) diff --git a/internal/proxy/client_test.go b/internal/proxy/client_test.go index 6d086b19d..08e2540a1 100644 --- a/internal/proxy/client_test.go +++ b/internal/proxy/client_test.go @@ -65,13 +65,12 @@ func TestEcho(t *testing.T) { t.Fatal(err) } - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() // Start server s := server.NewSocketServer(logger.With("module", "abci-server"), sockPath, kvstore.NewApplication()) require.NoError(t, s.Start(ctx), "error starting socket server") - t.Cleanup(func() { cancel(); s.Wait() }) + t.Cleanup(func() { s.Wait() }) // Start client require.NoError(t, client.Start(ctx), "Error starting ABCI client") @@ -105,13 +104,12 @@ func BenchmarkEcho(b *testing.B) { b.Fatal(err) } - ctx, cancel := context.WithCancel(b.Context()) - defer cancel() + ctx := b.Context() // Start server s := server.NewSocketServer(logger.With("module", "abci-server"), sockPath, kvstore.NewApplication()) require.NoError(b, s.Start(ctx), "Error starting socket server") - b.Cleanup(func() { cancel(); s.Wait() }) + b.Cleanup(func() { s.Wait() }) // Start client require.NoError(b, client.Start(ctx), "Error starting ABCI client") @@ -143,8 +141,7 @@ func BenchmarkEcho(b *testing.B) { } func TestInfo(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() sockPath := fmt.Sprintf("unix://%s/echo_%v.sock", t.TempDir(), 
tmrand.Str(6)) logger := log.NewNopLogger() @@ -156,7 +153,7 @@ func TestInfo(t *testing.T) { // Start server s := server.NewSocketServer(logger.With("module", "abci-server"), sockPath, kvstore.NewApplication()) require.NoError(t, s.Start(ctx), "Error starting socket server") - t.Cleanup(func() { cancel(); s.Wait() }) + t.Cleanup(func() { s.Wait() }) // Start client require.NoError(t, client.Start(ctx), "Error starting ABCI client") @@ -180,8 +177,7 @@ type noopStoppableClientImpl struct { func (c *noopStoppableClientImpl) Stop() { c.count++ } func TestAppConns_Start_Stop(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() clientMock := &abcimocks.Client{} clientMock.On("Start", mock.Anything).Return(nil) @@ -197,11 +193,12 @@ func TestAppConns_Start_Stop(t *testing.T) { time.Sleep(200 * time.Millisecond) - cancel() - appConns.Wait() + t.Cleanup(func() { + appConns.Wait() - clientMock.AssertExpectations(t) - assert.Equal(t, 1, cl.count) + clientMock.AssertExpectations(t) + assert.Equal(t, 1, cl.count) + }) } // Upon failure, we call tmos.Kill diff --git a/internal/pubsub/example_test.go b/internal/pubsub/example_test.go index e1873375a..703de1c52 100644 --- a/internal/pubsub/example_test.go +++ b/internal/pubsub/example_test.go @@ -1,7 +1,6 @@ package pubsub_test import ( - "context" "testing" "github.com/stretchr/testify/require" @@ -13,8 +12,7 @@ import ( ) func TestExample(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() s := newTestServer(ctx, t, log.NewNopLogger()) diff --git a/internal/pubsub/pubsub_test.go b/internal/pubsub/pubsub_test.go index 9a992a9b9..f71e97215 100644 --- a/internal/pubsub/pubsub_test.go +++ b/internal/pubsub/pubsub_test.go @@ -29,8 +29,7 @@ func (pubstring) TypeTag() string { return "pubstring" } func (e pubstring) ToLegacy() types.LegacyEventData { return e } func TestSubscribeWithArgs(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewNopLogger() s := newTestServer(ctx, t, logger) @@ -59,8 +58,7 @@ func TestSubscribeWithArgs(t *testing.T) { } func TestObserver(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewNopLogger() s := newTestServer(ctx, t, logger) @@ -80,8 +78,7 @@ func TestObserver(t *testing.T) { } func TestObserverErrors(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewNopLogger() @@ -93,8 +90,7 @@ func TestObserverErrors(t *testing.T) { } func TestPublishDoesNotBlock(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewNopLogger() @@ -123,8 +119,7 @@ func TestPublishDoesNotBlock(t *testing.T) { } func TestSubscribeErrors(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewNopLogger() s := newTestServer(ctx, t, logger) @@ -140,8 +135,7 @@ func TestSubscribeErrors(t *testing.T) { } func TestSlowSubscriber(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewNopLogger() s := newTestServer(ctx, t, logger) @@ -162,8 +156,7 @@ func TestSlowSubscriber(t *testing.T) { } func TestDifferentClients(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewNopLogger() s := newTestServer(ctx, 
@@ -162,8 +156,7 @@ func TestSlowSubscriber(t *testing.T) {
 }
 
 func TestDifferentClients(t *testing.T) {
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 
 	logger := log.NewNopLogger()
 	s := newTestServer(ctx, t, logger)
@@ -217,8 +210,7 @@ func TestDifferentClients(t *testing.T) {
 }
 
 func TestSubscribeDuplicateKeys(t *testing.T) {
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 
 	logger := log.NewNopLogger()
 	s := newTestServer(ctx, t, logger)
@@ -273,8 +265,7 @@ func TestSubscribeDuplicateKeys(t *testing.T) {
 }
 
 func TestClientSubscribesTwice(t *testing.T) {
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 
 	logger := log.NewNopLogger()
 	s := newTestServer(ctx, t, logger)
@@ -309,8 +300,7 @@ func TestClientSubscribesTwice(t *testing.T) {
 }
 
 func TestUnsubscribe(t *testing.T) {
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 
 	logger := log.NewNopLogger()
 	s := newTestServer(ctx, t, logger)
@@ -334,8 +324,7 @@ func TestUnsubscribe(t *testing.T) {
 }
 
 func TestClientUnsubscribesTwice(t *testing.T) {
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 
 	logger := log.NewNopLogger()
 	s := newTestServer(ctx, t, logger)
@@ -356,8 +345,7 @@ func TestClientUnsubscribesTwice(t *testing.T) {
 }
 
 func TestResubscribe(t *testing.T) {
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 
 	logger := log.NewNopLogger()
 	s := newTestServer(ctx, t, logger)
@@ -380,8 +368,7 @@ func TestResubscribe(t *testing.T) {
 }
 
 func TestUnsubscribeAll(t *testing.T) {
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 
 	logger := log.NewNopLogger()
 	s := newTestServer(ctx, t, logger)
@@ -409,13 +396,12 @@ func TestBufferCapacity(t *testing.T) {
 
 	require.Equal(t, 2, s.BufferCapacity())
 
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 
 	require.NoError(t, s.Publish(pubstring("Nighthawk")))
 	require.NoError(t, s.Publish(pubstring("Sage")))
 
-	ctx, cancel = context.WithTimeout(ctx, 100*time.Millisecond)
+	ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
 	defer cancel()
 
 	sig := make(chan struct{})
diff --git a/internal/state/execution_test.go b/internal/state/execution_test.go
index f6351dea2..d7ab0c582 100644
--- a/internal/state/execution_test.go
+++ b/internal/state/execution_test.go
@@ -43,8 +43,7 @@ func TestApplyBlock(t *testing.T) {
 	cc := abciclient.NewLocalClient(logger, app)
 	proxyApp := proxy.New(cc, logger, proxy.NopMetrics())
 
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 
 	require.NoError(t, proxyApp.Start(ctx))
 
@@ -86,8 +85,7 @@
 // DecidedLastCommit properly reflects which validators signed the preceding
 // block.
 func TestFinalizeBlockDecidedLastCommit(t *testing.T) {
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 
 	logger := log.NewNopLogger()
 	app := &testApp{}
@@ -162,8 +160,7 @@ func TestFinalizeBlockDecidedLastCommit(t *testing.T) {
 
 // TestFinalizeBlockByzantineValidators ensures we send byzantine validators list.
 func TestFinalizeBlockByzantineValidators(t *testing.T) {
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 
 	app := &testApp{}
 	logger := log.NewNopLogger()
@@ -283,8 +280,7 @@ func TestFinalizeBlockByzantineValidators(t *testing.T) {
 func TestProcessProposal(t *testing.T) {
 	const height = 2
 	txs := factory.MakeNTxs(height, 10)
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 
 	app := abcimocks.NewApplication(t)
 	logger := log.NewNopLogger()
@@ -506,8 +502,7 @@ func TestUpdateValidators(t *testing.T) {
 
 // TestFinalizeBlockValidatorUpdates ensures we update validator set and send an event.
 func TestFinalizeBlockValidatorUpdates(t *testing.T) {
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 
 	app := &testApp{}
 	logger := log.NewNopLogger()
@@ -577,7 +572,7 @@ func TestFinalizeBlockValidatorUpdates(t *testing.T) {
 	}
 
 	// test we threw an event
-	ctx, cancel = context.WithTimeout(ctx, 1*time.Second)
+	ctx, cancel := context.WithTimeout(ctx, 1*time.Second)
 	defer cancel()
 	msg, err := updatesSub.Next(ctx)
 	require.NoError(t, err)
@@ -592,8 +587,7 @@ func TestFinalizeBlockValidatorUpdates(t *testing.T) {
 
 // TestFinalizeBlockValidatorUpdatesResultingInEmptySet checks that processing validator updates that
 // would result in empty set causes no panic, an error is raised and NextValidators is not updated
 func TestFinalizeBlockValidatorUpdatesResultingInEmptySet(t *testing.T) {
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 
 	app := &testApp{}
 	logger := log.NewNopLogger()
@@ -640,8 +634,7 @@ func TestFinalizeBlockValidatorUpdatesResultingInEmptySet(t *testing.T) {
 
 func TestEmptyPrepareProposal(t *testing.T) {
 	const height = 2
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 
 	logger := log.NewNopLogger()
 
@@ -694,8 +687,7 @@
 // a transaction as REMOVED that was not present in the original proposal.
 func TestPrepareProposalErrorOnNonExistingRemoved(t *testing.T) {
 	const height = 2
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 
 	logger := log.NewNopLogger()
 	eventBus := eventbus.NewDefault(logger)
@@ -751,8 +743,7 @@
 // in the order matching the order they are returned from PrepareProposal.
 func TestPrepareProposalReorderTxs(t *testing.T) {
 	const height = 2
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 
 	logger := log.NewNopLogger()
 	eventBus := eventbus.NewDefault(logger)
@@ -808,8 +799,7 @@
 // an error if the ResponsePrepareProposal returned from the application is invalid.
 func TestPrepareProposalErrorOnTooManyTxs(t *testing.T) {
 	const height = 2
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 
 	logger := log.NewNopLogger()
 	eventBus := eventbus.NewDefault(logger)
@@ -865,8 +855,7 @@
 // upon calling PrepareProposal on it.
 func TestPrepareProposalErrorOnPrepareProposalError(t *testing.T) {
 	const height = 2
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 
 	logger := log.NewNopLogger()
 	eventBus := eventbus.NewDefault(logger)
@@ -953,8 +942,7 @@ func TestCreateProposalAbsentVoteExtensions(t *testing.T) {
 		},
 	} {
 		t.Run(testCase.name, func(t *testing.T) {
-			ctx, cancel := context.WithCancel(t.Context())
-			defer cancel()
+			ctx := t.Context()
 
 			logger := log.NewNopLogger()
 
diff --git a/internal/state/indexer/block/kv/kv_test.go b/internal/state/indexer/block/kv/kv_test.go
index a915bee9d..8409da946 100644
--- a/internal/state/indexer/block/kv/kv_test.go
+++ b/internal/state/indexer/block/kv/kv_test.go
@@ -1,7 +1,6 @@
 package kv_test
 
 import (
-	"context"
 	"fmt"
 	"testing"
 
@@ -133,8 +132,7 @@ func TestBlockIndexer(t *testing.T) {
 	for name, tc := range testCases {
 		tc := tc
 		t.Run(name, func(t *testing.T) {
-			ctx, cancel := context.WithCancel(t.Context())
-			defer cancel()
+			ctx := t.Context()
 
 			results, err := indexer.Search(ctx, tc.q)
 			require.NoError(t, err)
diff --git a/internal/state/indexer/indexer_service_test.go b/internal/state/indexer/indexer_service_test.go
index bcd808ef3..d3c16f9c2 100644
--- a/internal/state/indexer/indexer_service_test.go
+++ b/internal/state/indexer/indexer_service_test.go
@@ -1,7 +1,6 @@
 package indexer_test
 
 import (
-	"context"
 	"database/sql"
 	"fmt"
 	"os"
@@ -40,8 +39,7 @@ var (
 )
 
 func TestIndexerServiceIndexesBlocks(t *testing.T) {
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 
 	logger := tmlog.NewNopLogger()
 	// event bus
diff --git a/internal/state/indexer/sink/kv/kv_test.go b/internal/state/indexer/sink/kv/kv_test.go
index f49b7c379..373431c0a 100644
--- a/internal/state/indexer/sink/kv/kv_test.go
+++ b/internal/state/indexer/sink/kv/kv_test.go
@@ -144,8 +144,7 @@ func TestBlockFuncs(t *testing.T) {
 	for name, tc := range testCases {
 		tc := tc
 		t.Run(name, func(t *testing.T) {
-			ctx, cancel := context.WithCancel(t.Context())
-			defer cancel()
+			ctx := t.Context()
 
 			results, err := indexer.SearchBlockEvents(ctx, tc.q)
 			require.NoError(t, err)
diff --git a/internal/state/indexer/sink/null/null_test.go b/internal/state/indexer/sink/null/null_test.go
index 8129143ef..ba71b488f 100644
--- a/internal/state/indexer/sink/null/null_test.go
+++ b/internal/state/indexer/sink/null/null_test.go
@@ -1,7 +1,6 @@
 package null
 
 import (
-	"context"
 	"testing"
 
 	"github.com/stretchr/testify/assert"
@@ -11,8 +10,7 @@ import (
 )
 
 func TestNullEventSink(t *testing.T) {
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 
 	nullIndexer := NewEventSink()
 
diff --git a/internal/state/indexer/sink/psql/psql_test.go b/internal/state/indexer/sink/psql/psql_test.go
index 1f51b1b15..37d35bfd2 100644
--- a/internal/state/indexer/sink/psql/psql_test.go
+++ b/internal/state/indexer/sink/psql/psql_test.go
@@ -1,7 +1,6 @@
 package psql
 
 import (
-	"context"
 	"database/sql"
 	"flag"
 	"fmt"
@@ -152,8 +151,7 @@ func TestType(t *testing.T) {
 }
 
 func TestIndexing(t *testing.T) {
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 
 	t.Run("IndexBlockEvents", func(t *testing.T) {
 		indexer := &EventSink{store: testDB(), chainID: chainID}
diff --git a/internal/state/rollback_test.go b/internal/state/rollback_test.go
index 33d76e228..1fc3528b4 100644
--- a/internal/state/rollback_test.go
+++ b/internal/state/rollback_test.go
@@ -1,7 +1,6 @@
 package state_test
 
 import (
-	"context"
 	"math/rand"
 	"testing"
 	"time"
@@ -125,8 +124,7 @@ func TestRollbackDifferentStateHeight(t *testing.T) {
 
 func setupStateStore(t *testing.T, height int64) state.Store {
 	stateStore := state.NewStore(dbm.NewMemDB())
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 
 	valSet, _ := factory.ValidatorSet(ctx, t, 5, 10)
 	params := types.DefaultConsensusParams()
diff --git a/internal/state/state_test.go b/internal/state/state_test.go
index 5e39cd949..5dbe6b52f 100644
--- a/internal/state/state_test.go
+++ b/internal/state/state_test.go
@@ -2,7 +2,6 @@ package state_test
 
 import (
 	"bytes"
-	"context"
 	"fmt"
 	"math"
 	"math/big"
@@ -316,8 +315,7 @@ func TestOneValidatorChangesSaveLoad(t *testing.T) {
 }
 
 func TestProposerFrequency(t *testing.T) {
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 
 	// some explicit test cases
 	testCases := []struct {
diff --git a/internal/state/store_test.go b/internal/state/store_test.go
index d2de75553..722676c06 100644
--- a/internal/state/store_test.go
+++ b/internal/state/store_test.go
@@ -1,7 +1,6 @@
 package state_test
 
 import (
-	"context"
 	"fmt"
 	"math/rand"
 	"os"
@@ -28,8 +27,7 @@ const (
 func TestStoreBootstrap(t *testing.T) {
 	stateDB := dbm.NewMemDB()
 	stateStore := sm.NewStore(stateDB)
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 	val, _, err := factory.Validator(ctx, 10+int64(rand.Uint32()))
 	require.NoError(t, err)
 	val2, _, err := factory.Validator(ctx, 10+int64(rand.Uint32()))
@@ -58,8 +56,7 @@ func TestStoreBootstrap(t *testing.T) {
 func TestStoreLoadValidators(t *testing.T) {
 	stateDB := dbm.NewMemDB()
 	stateStore := sm.NewStore(stateDB)
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 	val, _, err := factory.Validator(ctx, 10+int64(rand.Uint32()))
 	require.NoError(t, err)
 	val2, _, err := factory.Validator(ctx, 10+int64(rand.Uint32()))
@@ -148,8 +145,7 @@ func BenchmarkLoadValidators(b *testing.B) {
 }
 
 func TestStoreLoadConsensusParams(t *testing.T) {
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 
 	stateDB := dbm.NewMemDB()
 	stateStore := sm.NewStore(stateDB)
diff --git a/internal/state/validation_test.go b/internal/state/validation_test.go
index fb746cee7..faec27145 100644
--- a/internal/state/validation_test.go
+++ b/internal/state/validation_test.go
@@ -1,7 +1,6 @@
 package state_test
 
 import (
-	"context"
 	"testing"
 	"time"
 
@@ -31,8 +30,7 @@ import (
 const validationTestsStopHeight int64 = 10
 
 func TestValidateBlockHeader(t *testing.T) {
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 	logger := log.NewNopLogger()
 	proxyApp := proxy.New(abciclient.NewLocalClient(logger, &testApp{}), logger, proxy.NopMetrics())
 	require.NoError(t, proxyApp.Start(ctx))
@@ -138,8 +136,7 @@ func TestValidateBlockHeader(t *testing.T) {
 }
 
 func TestValidateBlockCommit(t *testing.T) {
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 	logger := log.NewNopLogger()
 	proxyApp := proxy.New(abciclient.NewLocalClient(logger, &testApp{}), logger, proxy.NopMetrics())
 
@@ -288,8 +285,7 @@ func TestValidateBlockCommit(t *testing.T) {
 }
 
 func TestValidateBlockEvidence(t *testing.T) {
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 	logger := log.NewNopLogger()
 	proxyApp := proxy.New(abciclient.NewLocalClient(logger, &testApp{}), logger, proxy.NopMetrics())
 
diff --git a/internal/statesync/block_queue_test.go b/internal/statesync/block_queue_test.go
index 8f6cdcd65..eea17881c 100644
--- a/internal/statesync/block_queue_test.go
+++ b/internal/statesync/block_queue_test.go
@@ -23,8 +23,7 @@ var (
 )
 
 func TestBlockQueueBasic(t *testing.T) {
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 
 	peerID, err := types.NewNodeID("0011223344556677889900112233445566778899")
 	require.NoError(t, err)
@@ -73,8 +72,7 @@ loop:
 
 // Test with spurious failures and retries
 func TestBlockQueueWithFailures(t *testing.T) {
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 
 	peerID, err := types.NewNodeID("0011223344556677889900112233445566778899")
 	require.NoError(t, err)
@@ -132,8 +130,7 @@ func TestBlockQueueBlocks(t *testing.T) {
 	expectedHeight := startHeight
 	retryHeight := stopHeight + 2
 
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 
loop:
 	for {
@@ -181,8 +178,7 @@ func TestBlockQueueAcceptsNoMoreBlocks(t *testing.T) {
 	queue := newBlockQueue(startHeight, stopHeight, 1, stopTime, 1)
 	defer queue.close()
 
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 
loop:
 	for {
@@ -210,8 +206,7 @@ func TestBlockQueueStopTime(t *testing.T) {
 	queue := newBlockQueue(startHeight, stopHeight, 1, stopTime, 1)
 	wg := &sync.WaitGroup{}
 
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 
 	baseTime := stopTime.Add(-50 * time.Second)
 
@@ -257,8 +252,7 @@ func TestBlockQueueInitialHeight(t *testing.T) {
 	queue := newBlockQueue(startHeight, stopHeight, initialHeight, stopTime, 1)
 	wg := &sync.WaitGroup{}
 
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 
 	// asynchronously fetch blocks and add it to the queue
 	for i := 0; i <= numWorkers; i++ {
diff --git a/internal/statesync/reactor_test.go b/internal/statesync/reactor_test.go
index bb10df6fe..a27fdac89 100644
--- a/internal/statesync/reactor_test.go
+++ b/internal/statesync/reactor_test.go
@@ -238,8 +238,7 @@ func TestReactor_Sync(t *testing.T) {
 }
 
 func TestReactor_ChunkRequest_InvalidRequest(t *testing.T) {
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 
 	rts := setup(ctx, t, nil, nil, 2)
 
@@ -284,13 +283,9 @@ func TestReactor_ChunkRequest(t *testing.T) {
 		},
 	}
 
-	bctx, bcancel := context.WithCancel(t.Context())
-	defer bcancel()
-
 	for name, tc := range testcases {
 		t.Run(name, func(t *testing.T) {
-			ctx, cancel := context.WithCancel(bctx)
-			defer cancel()
+			ctx := t.Context()
 
 			// mock ABCI connection to return local snapshots
 			conn := &clientmocks.Client{}
@@ -318,8 +313,7 @@ func TestReactor_ChunkRequest(t *testing.T) {
 }
 
 func TestReactor_SnapshotsRequest_InvalidRequest(t *testing.T) {
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 
 	rts := setup(ctx, t, nil, nil, 2)
 
@@ -371,15 +365,9 @@ func TestReactor_SnapshotsRequest(t *testing.T) {
 		},
 	}
 
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
-
 	for name, tc := range testcases {
-		tc := tc
-
 		t.Run(name, func(t *testing.T) {
-			ctx, cancel := context.WithCancel(ctx)
-			defer cancel()
+			ctx := t.Context()
 
 			// mock ABCI connection to return local snapshots
 			conn := &clientmocks.Client{}
@@ -412,8 +400,7 @@ func TestReactor_SnapshotsRequest(t *testing.T) {
 }
 
 func TestReactor_LightBlockResponse(t *testing.T) {
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 
 	rts := setup(ctx, t, nil, nil, 2)
 
@@ -537,8 +524,7 @@ func TestReactor_BlockProviders(t *testing.T) {
 }
 
 func TestReactor_StateProviderP2P(t *testing.T) {
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 
 	rts := setup(ctx, t, nil, nil, 3)
 	// make syncer non nil else test won't think we are state syncing
@@ -633,17 +619,12 @@ func TestReactor_StateProviderP2P(t *testing.T) {
 }
 
 func TestReactor_Backfill(t *testing.T) {
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
-
 	// test backfill algorithm with varying failure rates [0, 10]
 	failureRates := []int{0, 2, 9}
 	for _, failureRate := range failureRates {
 		failureRate := failureRate
 		t.Run(fmt.Sprintf("failure rate: %d", failureRate), func(t *testing.T) {
-			ctx, cancel := context.WithCancel(ctx)
-			defer cancel()
-
+			ctx := t.Context()
 			t.Cleanup(leaktest.CheckTimeout(t, 1*time.Minute))
 
 			rts := setup(ctx, t, nil, nil, 21)
diff --git a/internal/statesync/syncer_test.go b/internal/statesync/syncer_test.go
index ecac38aba..2f2623d90 100644
--- a/internal/statesync/syncer_test.go
+++ b/internal/statesync/syncer_test.go
@@ -1,7 +1,6 @@
 package statesync
 
 import (
-	"context"
 	"errors"
 	"sync"
 	"testing"
@@ -22,8 +21,7 @@ import (
 )
 
 func TestSyncer_SyncAny(t *testing.T) {
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 
 	state := sm.State{
 		ChainID: "chain",
@@ -223,8 +221,7 @@ func TestSyncer_SyncAny_noSnapshots(t *testing.T) {
 	stateProvider := &mocks.StateProvider{}
 	stateProvider.On("AppHash", mock.Anything, mock.Anything).Return([]byte("app_hash"), nil)
 
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 
 	rts := setup(ctx, t, nil, stateProvider, 2)
 
@@ -236,8 +233,7 @@ func TestSyncer_SyncAny_abort(t *testing.T) {
 	stateProvider := &mocks.StateProvider{}
 	stateProvider.On("AppHash", mock.Anything, mock.Anything).Return([]byte("app_hash"), nil)
 
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 
 	rts := setup(ctx, t, nil, stateProvider, 2)
 
@@ -260,8 +256,7 @@ func TestSyncer_SyncAny_reject(t *testing.T) {
 	stateProvider := &mocks.StateProvider{}
 	stateProvider.On("AppHash", mock.Anything, mock.Anything).Return([]byte("app_hash"), nil)
 
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 
 	rts := setup(ctx, t, nil, stateProvider, 2)
 
@@ -302,8 +297,7 @@ func TestSyncer_SyncAny_reject_format(t *testing.T) {
 	stateProvider := &mocks.StateProvider{}
 	stateProvider.On("AppHash", mock.Anything, mock.Anything).Return([]byte("app_hash"), nil)
 
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 
 	rts := setup(ctx, t, nil, stateProvider, 2)
 
@@ -340,8 +334,7 @@ func TestSyncer_SyncAny_reject_sender(t *testing.T) {
 	stateProvider := &mocks.StateProvider{}
 	stateProvider.On("AppHash", mock.Anything, mock.Anything).Return([]byte("app_hash"), nil)
 
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 
 	rts := setup(ctx, t, nil, stateProvider, 2)
 
@@ -389,8 +382,7 @@ func TestSyncer_SyncAny_abciError(t *testing.T) {
 	stateProvider := &mocks.StateProvider{}
 	stateProvider.On("AppHash", mock.Anything, mock.Anything).Return([]byte("app_hash"), nil)
 
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 
 	rts := setup(ctx, t, nil, stateProvider, 2)
 
@@ -430,14 +422,9 @@ func TestSyncer_offerSnapshot(t *testing.T) {
 		"unknown non-zero": {9, nil, unknownErr},
 	}
 
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
-
 	for name, tc := range testcases {
-		tc := tc
 		t.Run(name, func(t *testing.T) {
-			ctx, cancel := context.WithCancel(ctx)
-			defer cancel()
+			ctx := t.Context()
 
 			stateProvider := &mocks.StateProvider{}
 			stateProvider.On("AppHash", mock.Anything, mock.Anything).Return([]byte("app_hash"), nil)
@@ -483,15 +470,9 @@ func TestSyncer_applyChunks_Results(t *testing.T) {
 		"unknown non-zero": {9, nil, unknownErr},
 	}
 
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
-
 	for name, tc := range testcases {
-		tc := tc
 		t.Run(name, func(t *testing.T) {
-			ctx, cancel := context.WithCancel(ctx)
-			defer cancel()
-
+			ctx := t.Context()
 			stateProvider := &mocks.StateProvider{}
 			stateProvider.On("AppHash", mock.Anything, mock.Anything).Return([]byte("app_hash"), nil)
@@ -543,14 +524,10 @@ func TestSyncer_applyChunks_RefetchChunks(t *testing.T) {
 		"retry_snapshot":  {abci.ResponseApplySnapshotChunk_RETRY_SNAPSHOT},
 		"reject_snapshot": {abci.ResponseApplySnapshotChunk_REJECT_SNAPSHOT},
 	}
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
 
 	for name, tc := range testcases {
-		tc := tc
 		t.Run(name, func(t *testing.T) {
-			ctx, cancel := context.WithCancel(ctx)
-			defer cancel()
+			ctx := t.Context()
 
 			stateProvider := &mocks.StateProvider{}
 			stateProvider.On("AppHash", mock.Anything, mock.Anything).Return([]byte("app_hash"), nil)
@@ -614,14 +591,9 @@ func TestSyncer_applyChunks_RejectSenders(t *testing.T) {
 		"retry_snapshot":  {abci.ResponseApplySnapshotChunk_RETRY_SNAPSHOT},
 		"reject_snapshot": {abci.ResponseApplySnapshotChunk_REJECT_SNAPSHOT},
 	}
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
-
 	for name, tc := range testcases {
-		tc := tc
 		t.Run(name, func(t *testing.T) {
-			ctx, cancel := context.WithCancel(ctx)
-			defer cancel()
+			ctx := t.Context()
 
 			stateProvider := &mocks.StateProvider{}
 			stateProvider.On("AppHash", mock.Anything, mock.Anything).Return([]byte("app_hash"), nil)
@@ -750,14 +722,10 @@ func TestSyncer_verifyApp(t *testing.T) {
 		}, nil, errVerifyFailed},
 		"error": {nil, boom, boom},
 	}
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
 
 	for name, tc := range testcases {
-		tc := tc
 		t.Run(name, func(t *testing.T) {
-			ctx, cancel := context.WithCancel(ctx)
-			defer cancel()
+			ctx := t.Context()
 
 			rts := setup(ctx, t, nil, nil, 2)
 
diff --git a/libs/cli/setup_test.go b/libs/cli/setup_test.go
index c7d2f05a1..c5748747c 100644
--- a/libs/cli/setup_test.go
+++ b/libs/cli/setup_test.go
@@ -18,8 +17,7 @@ import (
 )
 
 func TestSetupEnv(t *testing.T) {
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 
 	cases := []struct {
 		args []string
@@ -73,8 +72,7 @@ func writeConfigVals(dir string, vals map[string]string) error {
 }
 
 func TestSetupConfig(t *testing.T) {
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 
 	// we pre-create two config files we can refer to in the rest of
 	// the test cases.
@@ -134,8 +132,7 @@ type DemoConfig struct {
 }
 
 func TestSetupUnmarshal(t *testing.T) {
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 
 	// we pre-create two config files we can refer to in the rest of
 	// the test cases.
@@ -208,8 +205,7 @@ func TestSetupUnmarshal(t *testing.T) {
 }
 
 func TestSetupTrace(t *testing.T) {
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 
 	cases := []struct {
 		args []string
diff --git a/libs/events/events_test.go b/libs/events/events_test.go
index d21080ca1..d30c9c6ad 100644
--- a/libs/events/events_test.go
+++ b/libs/events/events_test.go
@@ -14,8 +13,7 @@ import (
 // TestAddListenerForEventFireOnce sets up an EventSwitch, subscribes a single
 // listener to an event, and sends a string "data".
 func TestAddListenerForEventFireOnce(t *testing.T) {
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 
 	evsw := NewEventSwitch()
 
@@ -39,8 +38,7 @@ func TestAddListenerForEventFireOnce(t *testing.T) {
 // TestAddListenerForEventFireMany sets up an EventSwitch, subscribes a single
 // listener to an event, and sends a thousand integers.
 func TestAddListenerForEventFireMany(t *testing.T) {
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 
 	evsw := NewEventSwitch()
 
@@ -73,8 +71,7 @@ func TestAddListenerForEventFireMany(t *testing.T) {
 // listener to three different events and sends a thousand integers for each
 // of the three events.
 func TestAddListenerForDifferentEvents(t *testing.T) {
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 
 	t.Cleanup(leaktest.Check(t))
 
@@ -135,8 +132,7 @@ func TestAddListenerForDifferentEvents(t *testing.T) {
 // listener to two of those three events, and then sends a thousand integers
 // for each of the three events.
 func TestAddDifferentListenerForDifferentEvents(t *testing.T) {
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 
 	t.Cleanup(leaktest.Check(t))
 
@@ -229,8 +225,7 @@ func TestAddDifferentListenerForDifferentEvents(t *testing.T) {
 // NOTE: it is important to run this test with race conditions tracking on,
 // `go test -race`, to examine for possible race conditions.
 func TestManageListenersAsync(t *testing.T) {
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
+	ctx := t.Context()
 
 	evsw := NewEventSwitch()
 
diff --git a/libs/service/service_test.go b/libs/service/service_test.go
index 65f947ac2..98735149f 100644
--- a/libs/service/service_test.go
+++ b/libs/service/service_test.go
@@ -57,14 +57,10 @@ func (t *testService) isMultiStopped() bool {
 func TestBaseService(t *testing.T) {
 	t.Cleanup(leaktest.Check(t))
 
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
-
 	logger := log.NewNopLogger()
 
 	t.Run("Wait", func(t *testing.T) {
-		wctx, wcancel := context.WithCancel(ctx)
-		defer wcancel()
+		wctx, wcancel := context.WithCancel(t.Context())
 		ts := &testService{}
 		ts.BaseService = *NewBaseService(logger, t.Name(), ts)
 		err := ts.Start(wctx)
@@ -88,6 +84,7 @@ func TestBaseService(t *testing.T) {
 		}
 	})
 	t.Run("ManualStop", func(t *testing.T) {
+		ctx := t.Context()
 		ts := &testService{}
 		ts.BaseService = *NewBaseService(logger, t.Name(), ts)
 		require.False(t, ts.IsRunning())
@@ -102,6 +99,7 @@ func TestBaseService(t *testing.T) {
 	})
 	t.Run("MultiStop", func(t *testing.T) {
 		t.Run("SingleThreaded", func(t *testing.T) {
+			ctx := t.Context()
 			ts := &testService{}
 			ts.BaseService = *NewBaseService(logger, t.Name(), ts)
 
@@ -114,8 +112,7 @@ func TestBaseService(t *testing.T) {
 			require.False(t, ts.isMultiStopped())
 		})
 		t.Run("MultiThreaded", func(t *testing.T) {
-			ctx, cancel := context.WithCancel(t.Context())
-			defer cancel()
+			ctx := t.Context()
 
 			ts := &testService{}
 			ts.BaseService = *NewBaseService(logger, t.Name(), ts)
@@ -123,13 +120,13 @@ func TestBaseService(t *testing.T) {
 			require.NoError(t, ts.Start(ctx))
 			require.True(t, ts.isStarted())
 
-			go ts.Stop()
-			go cancel()
-
-			ts.Wait()
+			t.Cleanup(func() {
+				ts.Stop()
+				ts.Wait()
 
-			require.True(t, ts.isStopped())
-			require.False(t, ts.isMultiStopped())
+				require.True(t, ts.isStopped())
+				require.False(t, ts.isMultiStopped())
+			})
 		})
 	})
 
diff --git a/libs/utils/channels.go b/libs/utils/channels.go
deleted file mode 100644
index 9eed500ff..000000000
--- a/libs/utils/channels.go
+++ /dev/null
@@ -1,74 +0,0 @@
-package utils
-
-import (
-	"context"
-
-	"github.com/pkg/errors"
-)
-
-// Recv receives a value from a channel or returns an error if the context is canceled.
-func Recv[T any](ctx context.Context, ch <-chan T) (zero T, err error) {
-	select {
-	case v, ok := <-ch:
-		if ok {
-			return v, nil
-		}
-		// We are not interested in channel closing;
-		// patiently wait for the context to be done instead.
-		<-ctx.Done()
-		return zero, ctx.Err()
-	case <-ctx.Done():
-		return zero, ctx.Err()
-	}
-}
-
-// RecvOrClosed receives a value from a channel, returns false if the channel got closed,
-// or returns an error if the context is canceled.
-func RecvOrClosed[T any](ctx context.Context, ch <-chan T) (T, bool, error) {
-	select {
-	case v, ok := <-ch:
-		return v, ok, nil
-	case <-ctx.Done():
-		var zero T
-		return zero, false, ctx.Err()
-	}
-}
-
-// Send sends a value to a channel or returns an error if the context is canceled.
-func Send[T any](ctx context.Context, ch chan<- T, v T) error {
-	select {
-	case ch <- v:
-		return nil
-	case <-ctx.Done():
-		return ctx.Err()
-	}
-}
-
-// SendOrDrop sends a value to the channel if it is not full, or drops the item otherwise.
-func SendOrDrop[T any](ch chan<- T, v T) error {
-	select {
-	case ch <- v:
-		return nil
-	default:
-		// drop the item
-		return nil
-	}
-}
-
-// ForEach is a helper function that reads from a channel and calls a handler for each item.
-// This avoids needing a lot of for/select boilerplate everywhere.
-func ForEach[T any](ctx context.Context, ch <-chan T, handler func(T) error) error {
-	for {
-		select {
-		case <-ctx.Done():
-			return errors.WithStack(ctx.Err())
-		case item, ok := <-ch:
-			if !ok {
-				return nil // Channel closed
-			}
-			if err := handler(item); err != nil {
-				return err // Stop on error
-			}
-		}
-	}
-}
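Since channels.go is removed here without an in-tree usage example, a sketch of the call-site pattern these helpers wrapped may help; pump and its channels are hypothetical, reconstructed only from the signatures above:

package example

import (
	"context"

	"github.com/tendermint/tendermint/libs/utils"
)

// Hypothetical pipeline stage using the deleted helpers (illustration only).
func pump(ctx context.Context, in <-chan int, out chan<- int) error {
	// ForEach hides the usual for/select loop over `in`.
	return utils.ForEach(ctx, in, func(v int) error {
		// Send blocks until `out` accepts the value or ctx is canceled.
		return utils.Send(ctx, out, v*2)
	})
}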
diff --git a/libs/utils/mutex.go b/libs/utils/mutex.go
deleted file mode 100644
index b6f4a9a58..000000000
--- a/libs/utils/mutex.go
+++ /dev/null
@@ -1,206 +0,0 @@
-package utils
-
-import (
-	"context"
-	"iter"
-	"sync"
-	"sync/atomic"
-
-	"golang.org/x/sync/errgroup"
-)
-
-// Mutex guards access to an object of type T.
-type Mutex[T any] struct {
-	mu    sync.Mutex
-	value T
-}
-
-// NewMutex creates a new Mutex with the given object.
-func NewMutex[T any](value T) (m Mutex[T]) {
-	m.value = value
-	// nolint:nakedret
-	return
-}
-
-// Lock returns an iterator which locks the mutex and yields the guarded object.
-// The mutex is unlocked when the iterator is done.
-// If the mutex is nil, the iterator is a no-op.
-func (m *Mutex[T]) Lock() iter.Seq[T] {
-	return func(yield func(val T) bool) {
-		m.mu.Lock()
-		defer m.mu.Unlock()
-		_ = yield(m.value)
-	}
-}
-
-// version of the value stored in an atomic watch.
-type version[T any] struct {
-	updated chan struct{}
-	value   T
-}
-
-// newVersion constructs a new active version.
-func newVersion[T any](value T) *version[T] {
-	return &version[T]{make(chan struct{}), value}
-}
-
-type atomicWatch[T any] struct {
-	ptr atomic.Pointer[version[T]]
-}
-
-// AtomicWatch stores a pointer to an IMMUTABLE value.
-// Loading and waiting for updates do NOT require locking.
-// TODO(gprusak): remove the mutex and rename to AtomicSend;
-// this will allow sharing a mutex across multiple AtomicSenders.
-type AtomicWatch[T any] struct {
-	atomicWatch[T]
-	mu sync.Mutex
-}
-
-// AtomicRecv is a read-only reference to an AtomicWatch.
-type AtomicRecv[T any] struct{ *atomicWatch[T] }
-
-// NewAtomicWatch creates a new AtomicWatch with the given initial value.
-func NewAtomicWatch[T any](value T) (w AtomicWatch[T]) {
-	w.ptr.Store(newVersion(value))
-	// nolint:nakedret
-	return
-}
-
-// Subscribe returns a view-only API of the atomic watch.
-func (w *AtomicWatch[T]) Subscribe() AtomicRecv[T] {
-	return AtomicRecv[T]{&w.atomicWatch}
-}
-
-// Load returns the current value of the atomic watch.
-// Does not do any locking.
-func (w *atomicWatch[T]) Load() T { return w.ptr.Load().value }
-
-// Store updates the value of the atomic watch.
-func (w *AtomicWatch[T]) Store(value T) {
-	w.mu.Lock()
-	defer w.mu.Unlock()
-	close(w.ptr.Swap(newVersion(value)).updated)
-}
-
-// Update conditionally updates the value of the atomic watch.
-func (w *AtomicWatch[T]) Update(f func(T) (T, bool)) {
-	w.mu.Lock()
-	defer w.mu.Unlock()
-	old := w.ptr.Load()
-	if value, ok := f(old.value); ok {
-		w.ptr.Store(newVersion(value))
-		close(old.updated)
-	}
-}
-
-// Wait waits for the value of the atomic watch to satisfy the predicate.
-// Does not do any locking.
-func (w *atomicWatch[T]) Wait(ctx context.Context, pred func(T) bool) (T, error) {
-	for {
-		v := w.ptr.Load()
-		if pred(v.value) {
-			return v.value, nil
-		}
-		select {
-		case <-ctx.Done():
-			return Zero[T](), ctx.Err()
-		case <-v.updated:
-		}
-	}
-}
-
-// Iter executes the function f sequentially on each value of the atomic watch.
-// The context passed to f is canceled when the next value is available.
-// It exits when the returned error is different from nil and context.Canceled,
-// or when the context passed to Iter is canceled (after f exits).
-func (w *atomicWatch[T]) Iter(ctx context.Context, f func(ctx context.Context, v T) error) error {
-	for ctx.Err() == nil {
-		v := w.ptr.Load()
-		g, ctx := errgroup.WithContext(ctx)
-		g.Go(func() error { return f(ctx, v.value) })
-		g.Go(func() error {
-			select {
-			case <-ctx.Done():
-			case <-v.updated:
-			}
-			return context.Canceled
-		})
-		if err := IgnoreCancel(g.Wait()); err != nil {
-			return err
-		}
-	}
-	return ctx.Err()
-}
-
-// WatchCtrl controls the locked object in a Watch.
-// It is provided only in the iterator returned by Lock().
-// It should NOT be stored anywhere.
-type WatchCtrl struct {
-	mu      sync.Mutex
-	updated chan struct{}
-}
-
-// Watch stores a value of type T.
-// It is essentially a mutex that can be awaited for updates.
-type Watch[T any] struct {
-	ctrl WatchCtrl
-	val  T
-}
-
-// NewWatch constructs a new watch with the given value.
-// Note that the value in the watch cannot be changed, so T
-// should be a pointer type if updates are required.
-func NewWatch[T any](val T) Watch[T] {
-	return Watch[T]{
-		WatchCtrl{updated: make(chan struct{})},
-		val,
-	}
-}
-
-// Wait waits for the value in the watch to be updated.
-// It should be called only after locking the watch, i.e. within the Lock() iterator.
-// It unlocks -> waits for the update -> locks again.
-func (c *WatchCtrl) Wait(ctx context.Context) error {
-	updated := c.updated
-	c.mu.Unlock()
-	defer c.mu.Lock()
-	select {
-	case <-ctx.Done():
-		return ctx.Err()
-	case <-updated:
-		return nil
-	}
-}
-
-// WaitUntil waits for the value in the watch to satisfy the predicate.
-// It should be called only after locking the watch, i.e. within the Lock() iterator.
-// The predicate is evaluated under the lock, so it can access the guarded object.
-func (c *WatchCtrl) WaitUntil(ctx context.Context, pred func() bool) error {
-	for !pred() {
-		if err := c.Wait(ctx); err != nil {
-			return err
-		}
-	}
-	return nil
-}
-
-// Updated signals waiters that the value in the watch has been updated.
-func (c *WatchCtrl) Updated() {
-	close(c.updated)
-	c.updated = make(chan struct{})
-}
-
-// Lock returns an iterator which locks the watch and yields the guarded object.
-// The watch is unlocked when the iterator is done.
-// If the watch is nil, the iterator is a no-op.
-// Additionally, the WatchCtrl object is provided to the yield function:
-// * to unlock -> wait for the update -> lock again, call ctrl.Wait(ctx)
-// * to signal an update, call ctrl.Updated().
-func (w *Watch[T]) Lock() iter.Seq2[T, *WatchCtrl] {
-	return func(yield func(val T, ctrl *WatchCtrl) bool) {
-		w.ctrl.mu.Lock()
-		defer w.ctrl.mu.Unlock()
-		_ = yield(w.val, &w.ctrl)
-	}
-}
diff --git a/libs/utils/mutex_test.go b/libs/utils/mutex_test.go
deleted file mode 100644
index b4a85abbc..000000000
--- a/libs/utils/mutex_test.go
+++ /dev/null
@@ -1,39 +0,0 @@
-package utils_test
-
-import (
-	"context"
-	"fmt"
-	"testing"
-
-	"github.com/tendermint/tendermint/libs/utils"
-	"github.com/tendermint/tendermint/libs/utils/require"
-	"github.com/tendermint/tendermint/libs/utils/scope"
-)
-
-func TestAtomicWatch(t *testing.T) {
-	ctx := t.Context()
-	v := 5
-	w := utils.NewAtomicWatch(&v)
-	require.Equal(t, 5, *w.Load())
-
-	want := 10
-	if err := scope.Run(ctx, func(ctx context.Context, s scope.Scope) error {
-		s.Spawn(func() error {
-			for i := 0; i <= want; i++ {
-				w.Store(&i)
-			}
-			return nil
-		})
-
-		got, err := w.Wait(ctx, func(v *int) bool { return *v >= want })
-		if err != nil {
-			return err
-		}
-		if *got != want {
-			return fmt.Errorf("got %v, want %v", *got, want)
-		}
-		return nil
-	}); err != nil {
-		t.Fatal(err)
-	}
-}
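Mutex.Lock and Watch.Lock above return Go 1.23 range-over-func iterators, which is easy to misread on first contact. A hypothetical caller looks like this (illustration only, not code from the deleted tests):

package example

import "github.com/tendermint/tendermint/libs/utils"

// bump increments a counter guarded by the deleted Mutex type.
func bump(m *utils.Mutex[*int]) {
	// The loop body runs exactly once while the lock is held;
	// the mutex is released when the iteration ends.
	for v := range m.Lock() {
		*v++
	}
}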
diff --git a/libs/utils/option.go b/libs/utils/option.go
deleted file mode 100644
index 85fd6a471..000000000
--- a/libs/utils/option.go
+++ /dev/null
@@ -1,73 +0,0 @@
-package utils
-
-import (
-	"encoding/json"
-)
-
-// Option type, inspired by https://pkg.go.dev/github.com/samber/mo.
-type Option[T any] struct {
-	ReadOnly
-	isPresent bool
-	value     T
-}
-
-// Some creates an Option with a value.
-func Some[T any](value T) Option[T] {
-	return Option[T]{isPresent: true, value: value}
-}
-
-// None creates an Option without a value.
-func None[T any]() (zero Option[T]) { return }
-
-// Get unpacks the value from the Option, returning true if it was present.
-func (o Option[T]) Get() (T, bool) {
-	if o.isPresent {
-		return o.value, true
-	}
-	return Zero[T](), false
-}
-
-// IsPresent checks if the Option contains a value.
-func (o Option[T]) IsPresent() bool {
-	return o.isPresent
-}
-
-// Or returns the value if present, otherwise the default value.
-func (o *Option[T]) Or(def T) T {
-	if o.isPresent {
-		return o.value
-	}
-	return def
-}
-
-// MapOpt applies a function to the value if present, returning a new Option.
-func MapOpt[T, R any](o Option[T], f func(T) R) Option[R] {
-	if o.isPresent {
-		return Some(f(o.value))
-	}
-	return None[R]()
-}
-
-// MarshalJSON implements the json.Marshaler interface.
-// Note that it is defined on the value, not the pointer, because
-// json.Marshal cannot call pointer methods on fields
-// (i.e. it is broken by design).
-func (o Option[T]) MarshalJSON() ([]byte, error) {
-	if o.isPresent {
-		return json.Marshal(o.value)
-	}
-	return []byte("null"), nil
-}
-
-// UnmarshalJSON implements the json.Unmarshaler interface.
-func (o *Option[T]) UnmarshalJSON(data []byte) error {
-	if string(data) == "null" {
-		o.isPresent = false
-		return nil
-	}
-	if err := json.Unmarshal(data, &o.value); err != nil {
-		return err
-	}
-	o.isPresent = true
-	return nil
-}
diff --git a/libs/utils/option_test.go b/libs/utils/option_test.go
deleted file mode 100644
index 04a55a1e1..000000000
--- a/libs/utils/option_test.go
+++ /dev/null
@@ -1,32 +0,0 @@
-package utils
-
-import (
-	"encoding/json"
-	"testing"
-
-	"github.com/tendermint/tendermint/libs/utils/require"
-)
-
-func testJSON[T any](t *testing.T, want T) {
-	enc, err := json.Marshal(want)
-	require.NoError(t, err)
-	t.Logf("%s", enc)
-	var got T
-	require.NoError(t, json.Unmarshal(enc, &got))
-	require.NoError(t, TestDiff(want, got))
-}
-
-func TestOptionJSON(t *testing.T) {
-	type a struct {
-		X Option[int]
-		Y Option[string]
-	}
-	type b struct {
-		X Option[int]    `json:"X,omitzero"`
-		Y Option[string] `json:"Y,omitzero"`
-	}
-	testJSON(t, &a{})
-	testJSON(t, &a{Some(1), Some("a")})
-	testJSON(t, &b{})
-	testJSON(t, &b{Some(1), Some("a")})
-}
diff --git a/libs/utils/proto.go b/libs/utils/proto.go
deleted file mode 100644
index 5f5ad7a41..000000000
--- a/libs/utils/proto.go
+++ /dev/null
@@ -1,143 +0,0 @@
-package utils
-
-import (
-	"crypto/sha256"
-	"errors"
-	"fmt"
-	"sync"
-
-	"google.golang.org/protobuf/proto"
-)
-
-// Hash is a SHA-256 hash.
-type Hash [sha256.Size]byte
-
-// GetHash computes the hash of the given data.
-func GetHash(data []byte) Hash {
-	return sha256.Sum256(data)
-}
-
-// ParseHash parses a Hash from bytes.
-func ParseHash(raw []byte) (Hash, error) {
-	if got, want := len(raw), sha256.Size; got != want {
-		return Hash{}, fmt.Errorf("hash size = %v, want %v", got, want)
-	}
-	return Hash(raw), nil
-}
-
-// ProtoClone clones a proto.Message object.
-func ProtoClone[T proto.Message](item T) T {
-	return proto.Clone(item).(T)
-}
-
-// ProtoEqual compares two proto.Message objects.
-func ProtoEqual[T proto.Message](a, b T) bool {
-	return proto.Equal(a, b)
-}
-
-// ProtoHash hashes a proto.Message object.
-// TODO(gprusak): make it deterministic.
-func ProtoHash(a proto.Message) Hash {
-	raw, err := proto.Marshal(a)
-	if err != nil {
-		panic(err)
-	}
-	return sha256.Sum256(raw)
-}
-
-// ProtoMessage is a comparable proto.Message.
-type ProtoMessage interface {
-	comparable
-	proto.Message
-}
-
-// ProtoConv is a pair of functions to encode and decode between a type and a ProtoMessage.
-type ProtoConv[T any, P ProtoMessage] struct {
-	Encode func(T) P
-	Decode func(P) (T, error)
-}
-
-// EncodeSlice encodes a slice of T into a slice of P.
-func (c ProtoConv[T, P]) EncodeSlice(t []T) []P {
-	p := make([]P, len(t))
-	for i := range t {
-		p[i] = c.Encode(t[i])
-	}
-	return p
-}
-
-// DecodeSlice decodes a slice of P into a slice of T.
-func (c ProtoConv[T, P]) DecodeSlice(p []P) ([]T, error) {
-	t := make([]T, len(p))
-	var err error
-	for i := range p {
-		if t[i], err = c.Decode(p[i]); err != nil {
-			return nil, fmt.Errorf("[%d]: %w", i, err)
-		}
-	}
-	return t, nil
-}
-
-// Slice constructs a slice.
-// It is syntactic sugar for `[]T{v...}`, which avoids
-// spelling out T. Not very useful if you need to spell
-// out T to construct the elements: in that case
-// you might prefer the []T{{...},{...}} syntax instead.
-func Slice[T any](v ...T) []T { return v }
-
-// Alloc moves a value to the heap.
-func Alloc[T any](v T) *T { return &v }
-
-// Zero returns the zero value of type T.
-func Zero[T any]() (zero T) { return }
-
-// NoCopy may be added to structs which must not be copied
-// after the first use.
-//
-// See https://golang.org/issues/8005#issuecomment-190753527
-// for details.
-//
-// Note that it must not be embedded, otherwise the Lock and Unlock methods
-// will be exported.
-type NoCopy struct{}
-
-// Lock implements sync.Locker.
-func (*NoCopy) Lock() {}
-
-// Unlock implements sync.Locker.
-func (*NoCopy) Unlock() {}
-
-var _ sync.Locker = (*NoCopy)(nil)
-
-// NoCompare may be added to structs which must not be used as
-// map keys.
-type NoCompare [0]func()
-
-// EncodeOpt encodes an Option[T], mapping None to Zero[P]().
-func (c ProtoConv[T, P]) EncodeOpt(mv Option[T]) P {
-	v, ok := mv.Get()
-	if !ok {
-		return Zero[P]()
-	}
-	return c.Encode(v)
-}
-
-// DecodeReq decodes a ProtoMessage into a T, returning an error if p is nil.
-func (c ProtoConv[T, P]) DecodeReq(p P) (T, error) {
-	if p == Zero[P]() {
-		return Zero[T](), errors.New("missing")
-	}
-	return c.Decode(p)
-}
-
-// DecodeOpt decodes a ProtoMessage into a T, returning None if p is nil.
-func (c ProtoConv[T, P]) DecodeOpt(p P) (Option[T], error) {
-	if p == Zero[P]() {
-		return None[T](), nil
-	}
-	t, err := c.DecodeReq(p)
-	if err != nil {
-		return None[T](), err
-	}
-	return Some(t), nil
-}
diff --git a/libs/utils/require/require.go b/libs/utils/require/require.go
deleted file mode 100644
index 66bb750d3..000000000
--- a/libs/utils/require/require.go
+++ /dev/null
@@ -1,81 +0,0 @@
-// Package require reexports a strongly typed `testify/require` API.
-// We don't reexport `New`, because methods cannot be generic.
-package require
-
-import (
-	"cmp"
-
-	"github.com/stretchr/testify/require"
-)
-
-// TestingT .
-type TestingT = require.TestingT
-
-// False .
-var False = require.False
-
-// True .
-var True = require.True
-
-// Contains .
-var Contains = require.Contains
-
-// EqualError .
-// TODO: get rid of comparing errors by strings;
-// use concrete error types instead.
-var EqualError = require.EqualError
-
-// Error .
-var Error = require.Error
-
-// ErrorIs .
-var ErrorIs = require.ErrorIs
-
-// NoError .
-var NoError = require.NoError
-
-// Empty .
-var Empty = require.Empty
-
-// NotEmpty .
-var NotEmpty = require.NotEmpty
-
-// Len .
-var Len = require.Len
-
-// Nil .
-var Nil = require.Nil
-
-// NotNil .
-var NotNil = require.NotNil
-
-// Panics .
-var Panics = require.Panics
-
-// Fail .
-var Fail = require.Fail
-
-// Positive .
-func Positive[T cmp.Ordered](t TestingT, e T, msgAndArgs ...any) {
-	require.Positive(t, e, msgAndArgs...)
-}
-
-// Less .
-func Less[T cmp.Ordered](t TestingT, e1, e2 T, msgAndArgs ...any) {
-	require.Less(t, e1, e2, msgAndArgs...)
-}
-
-// Greater .
-func Greater[T cmp.Ordered](t TestingT, e1, e2 T, msgAndArgs ...any) {
-	require.Greater(t, e1, e2, msgAndArgs...)
-}
-
-// Equal .
-func Equal[T any](t TestingT, expected, actual T, msgAndArgs ...any) {
-	require.Equal(t, expected, actual, msgAndArgs...)
-}
-
-// NotEqual .
-func NotEqual[T any](t TestingT, expected, actual T, msgAndArgs ...any) {
-	require.NotEqual(t, expected, actual, msgAndArgs...)
-}
diff --git a/libs/utils/ringbuf.go b/libs/utils/ringbuf.go
deleted file mode 100644
index 5b81c3379..000000000
--- a/libs/utils/ringbuf.go
+++ /dev/null
@@ -1,83 +0,0 @@
-package utils
-
-import (
-	"iter"
-)
-
-// RingBuf is a ring buffer.
-// It is NOT thread-safe.
-type RingBuf[T any] struct {
-	first int
-	len   int
-	buf   []T
-}
-
-// NewRingBuf creates a new ring buffer with the given capacity.
-func NewRingBuf[T any](capacity int) RingBuf[T] {
-	return RingBuf[T]{first: 0, len: 0, buf: make([]T, capacity)}
-}
-
-// Len returns the number of elements in the ring buffer.
-func (r *RingBuf[T]) Len() int {
-	return r.len
-}
-
-// Full returns true if the ring buffer is full.
-func (r *RingBuf[T]) Full() bool {
-	return r.len == len(r.buf)
-}
-
-// Get returns the i-th element of the ring buffer.
-// Panics if i is out of range.
-func (r *RingBuf[T]) Get(i int) T {
-	if i < 0 || i >= r.len {
-		panic("index out of range")
-	}
-	return r.buf[(r.first+i)%len(r.buf)]
-}
-
-// TryGet returns the i-th element of the ring buffer.
-func (r *RingBuf[T]) TryGet(i int) (T, bool) {
-	if i < 0 || i >= r.len {
-		return Zero[T](), false
-	}
-	return r.buf[(r.first+i)%len(r.buf)], true
-}
-
-// Last returns the last element of the ring buffer.
-func (r *RingBuf[T]) Last() (T, bool) {
-	return r.TryGet(r.len - 1)
-}
-
-// PushBack adds an element to the back of the ring buffer.
-// Panics if the ring buffer is full.
-func (r *RingBuf[T]) PushBack(x T) {
-	if r.len == len(r.buf) {
-		panic("ring buffer full")
-	}
-	r.buf[(r.first+r.len)%len(r.buf)] = x
-	r.len += 1
-}
-
-// PopFront removes and returns the first element of the ring buffer.
-// Panics if the ring buffer is empty.
-func (r *RingBuf[T]) PopFront() T {
-	if r.len == 0 {
-		panic("ring buffer empty")
-	}
-	x := r.buf[r.first]
-	r.first = (r.first + 1) % len(r.buf)
-	r.len -= 1
-	return x
-}
-
-// All iterates over all the elements in the ring buffer.
-func (r *RingBuf[T]) All() iter.Seq[T] {
-	return func(y func(T) bool) {
-		for i := range r.len {
-			if !y(r.Get(i)) {
-				break
-			}
-		}
-	}
-}
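A quick sketch of the RingBuf contract defined above (hypothetical values, illustration only):

package example

import (
	"fmt"

	"github.com/tendermint/tendermint/libs/utils"
)

func ringBufDemo() {
	buf := utils.NewRingBuf[int](2)
	buf.PushBack(1)
	buf.PushBack(2)
	fmt.Println(buf.Full())     // true: capacity reached, another PushBack would panic
	fmt.Println(buf.PopFront()) // 1: FIFO order
	buf.PushBack(3)             // reuses the freed slot, wrapping around
	for v := range buf.All() {  // iterates 2, then 3
		fmt.Println(v)
	}
}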
diff --git a/libs/utils/scope/parallel.go b/libs/utils/scope/parallel.go
deleted file mode 100644
index 1377184d5..000000000
--- a/libs/utils/scope/parallel.go
+++ /dev/null
@@ -1,41 +0,0 @@
-package scope
-
-import (
-	"sync"
-	"sync/atomic"
-)
-
-type parallelScope struct {
-	wg  sync.WaitGroup
-	err atomic.Pointer[error]
-}
-
-// ParallelScope is a scope which doesn't require a cancellation token,
-// just parallelization.
-type ParallelScope struct{ *parallelScope }
-
-// Spawn spawns a new task in the scope.
-func (s *parallelScope) Spawn(t func() error) {
-	s.wg.Add(1)
-	go func() {
-		if err := t(); err != nil {
-			s.err.CompareAndSwap(nil, &err)
-		}
-		s.wg.Done()
-	}()
-}
-
-// Parallel executes a function in a parallel scope.
-// Compared to Run, it does not allow for early cancellation and
-// is therefore suitable for non-blocking computations.
-// It returns the first error returned by any of the spawned tasks,
-// and waits until all the tasks complete before returning.
-func Parallel(main func(ParallelScope) error) error {
-	var s parallelScope
-	s.Spawn(func() error { return main(ParallelScope{&s}) })
-	s.wg.Wait()
-	if perr := s.err.Load(); perr != nil {
-		return *perr
-	}
-	return nil
-}
diff --git a/libs/utils/scope/parallel_test.go b/libs/utils/scope/parallel_test.go
deleted file mode 100644
index 7f98872ad..000000000
--- a/libs/utils/scope/parallel_test.go
+++ /dev/null
@@ -1,54 +0,0 @@
-package scope
-
-import (
-	"errors"
-	"testing"
-)
-
-func TestParallelOk(t *testing.T) {
-	x := [10]int{}
-	if err := Parallel(func(s ParallelScope) error {
-		for i := range x {
-			s.Spawn(func() error {
-				x[i] = i
-				return nil
-			})
-		}
-		return nil
-	}); err != nil {
-		t.Fatal(err)
-	}
-	for want, got := range x {
-		if want != got {
-			t.Fatalf("x[%d] = %d, want %d", want, got, want)
-		}
-	}
-}
-
-func TestParallelFail(t *testing.T) {
-	var wantErr = errors.New("custom err")
-	x := [10]int{}
-	err := Parallel(func(s ParallelScope) error {
-		for i := range x {
-			s.Spawn(func() error {
-				if i%2 == 0 {
-					return wantErr
-				}
-				x[i] = i
-				return nil
-			})
-		}
-		return nil
-	})
-	if !errors.Is(err, wantErr) {
-		t.Fatalf("err = %v, want %v", err, wantErr)
-	}
-	for want, got := range x {
-		if want%2 == 0 {
-			want = 0
-		}
-		if want != got {
-			t.Fatalf("x[%d] = %d, want %d", want, got, want)
-		}
-	}
-}
diff --git a/libs/utils/scope/start.go b/libs/utils/scope/start.go
deleted file mode 100644
index cba8d2e4d..000000000
--- a/libs/utils/scope/start.go
+++ /dev/null
@@ -1,143 +0,0 @@
-package scope
-
-import (
-	"context"
-	"fmt"
-	"log"
-	"sync"
-	"time"
-
-	"golang.org/x/sync/errgroup"
-
-	"github.com/tendermint/tendermint/libs/utils"
-)
-
-// Scope of concurrent tasks.
-type Scope struct {
-	// scope is a concurrency primitive, so the no-ctx-in-struct rule does not apply
-	// nolint:containedctx
-	ctx  context.Context
-	all  *errgroup.Group
-	main *sync.WaitGroup
-}
-
-// Spawn spawns a main task.
-// The scope gets automatically canceled when all the main tasks return.
-func (s Scope) Spawn(t func() error) {
-	s.main.Add(1)
-	s.all.Go(func() error {
-		defer s.main.Done()
-		return t()
-	})
-}
-
-// JoinHandle is a handle to an awaitable task.
-type JoinHandle[R any] struct {
-	result *utils.AtomicWatch[*R]
-}
-
-// Spawn1 is the same as Scope.Spawn, but allows awaiting completion of a task and getting its result.
-func Spawn1[R any](s Scope, t func() (R, error)) JoinHandle[R] {
-	result := utils.NewAtomicWatch[*R](nil)
-	s.Spawn(func() error {
-		v, err := t()
-		if err != nil {
-			return err
-		}
-		result.Store(&v)
-		return nil
-	})
-	return JoinHandle[R]{&result}
-}
-
-// Join awaits completion of a task and returns its result.
-// WARNING: it does NOT return the error of the task - the error is returned from the Run() command.
-// Join() can only fail when the context is canceled.
-func (h JoinHandle[R]) Join(ctx context.Context) (R, error) {
-	res, err := h.result.Wait(ctx, func(v *R) bool { return v != nil })
-	if err != nil {
-		return utils.Zero[R](), err
-	}
-	return *res, nil
-}
-
-// If true, tasks that do not respect context cancellation will be logged.
-// This is useful for debugging, but causes unnecessary overhead.
-// Since this is a constant, the debug guard should be optimized out by the compiler.
-const enableDebugGuard = false
-
-func (s Scope) debugGuard(name string, done chan struct{}) {
-	select {
-	case <-done:
-		return
-	case <-s.ctx.Done():
-	}
-	for {
-		select {
-		case <-done:
-			return
-		case <-time.After(10 * time.Second):
-		}
-		log.Printf("task %q still running", name)
-	}
-}
-
-// SpawnNamed spawns a named main task.
-func (s Scope) SpawnNamed(name string, t func() error) {
-	done := make(chan struct{})
-	s.Spawn(func() error {
-		defer close(done)
-		if err := t(); err != nil {
-			return fmt.Errorf("%s: %w", name, err)
-		}
-		return nil
-	})
-	if enableDebugGuard {
-		go s.debugGuard(name, done)
-	}
-}
-
-// SpawnBgNamed spawns a named background task.
-func (s Scope) SpawnBgNamed(name string, t func() error) {
-	done := make(chan struct{})
-	s.SpawnBg(func() error {
-		defer close(done)
-		if err := t(); err != nil {
-			return fmt.Errorf("%s: %w", name, err)
-		}
-		return nil
-	})
-	if enableDebugGuard {
-		go s.debugGuard(name, done)
-	}
-}
-
-// SpawnBg spawns a background task.
-// Background tasks get canceled when all the main tasks return.
-func (s Scope) SpawnBg(t func() error) { s.all.Go(t) }
-
-// Run runs a scope capable of spawning tasks.
-// It is guaranteed that all the spawned tasks will be executed (even if spawned after the context is canceled),
-// and that `Run` will return only after all the tasks have completed.
-// The context of the tasks will be automatically canceled as soon as ANY task returns an error.
-// Returns the first error returned by any task (main or background).
-func Run(ctx context.Context, main func(context.Context, Scope) error) error {
-	ctx, cancel := context.WithCancel(ctx)
-	all, ctx := errgroup.WithContext(ctx)
-	s := Scope{ctx, all, &sync.WaitGroup{}}
-	s.Spawn(func() error { return main(ctx, s) })
-	s.main.Wait()
-	cancel()
-	return s.all.Wait()
-}
-
-// Run1 is the same as Run, but returns the result of the main task.
-func Run1[R any](ctx context.Context, main func(context.Context, Scope) (R, error)) (res R, err error) {
-	err = Run(ctx, func(ctx context.Context, s Scope) error {
-		var err error
-		res, err = main(ctx, s)
-		return err
-	})
-	//nolint:nakedret
-	return
-}
diff --git a/libs/utils/semaphore.go b/libs/utils/semaphore.go
deleted file mode 100644
index 728c12a5c..000000000
--- a/libs/utils/semaphore.go
+++ /dev/null
@@ -1,24 +0,0 @@
-package utils
-
-import (
-	"context"
-)
-
-// Semaphore provides a way to bound concurrent access to a resource.
-type Semaphore struct {
-	ch chan struct{}
-}
-
-// NewSemaphore constructs a new semaphore with n permits.
-func NewSemaphore(n int) *Semaphore {
-	return &Semaphore{ch: make(chan struct{}, n)}
-}
-
-// Acquire acquires a permit from the semaphore.
-// It blocks until a permit is available.
-func (s *Semaphore) Acquire(ctx context.Context) (release func(), err error) {
-	if err := Send(ctx, s.ch, struct{}{}); err != nil {
-		return nil, err
-	}
-	return func() { <-s.ch }, nil
-}
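scope.Run and Semaphore compose naturally for bounded fan-out; a hypothetical sketch follows (fetchAll, fetch, and urls are assumptions for illustration, not code from this repo):

package example

import (
	"context"

	"github.com/tendermint/tendermint/libs/utils"
	"github.com/tendermint/tendermint/libs/utils/scope"
)

// fetchAll runs fetch for every URL, at most 4 at a time; the first
// error cancels the context shared by all remaining tasks.
func fetchAll(ctx context.Context, urls []string, fetch func(context.Context, string) error) error {
	sem := utils.NewSemaphore(4)
	return scope.Run(ctx, func(ctx context.Context, s scope.Scope) error {
		for _, u := range urls {
			s.Spawn(func() error {
				release, err := sem.Acquire(ctx)
				if err != nil {
					return err
				}
				defer release()
				return fetch(ctx, u)
			})
		}
		return nil
	})
}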
diff --git a/libs/utils/testonly.go b/libs/utils/testonly.go
deleted file mode 100644
index afd6b8aa8..000000000
--- a/libs/utils/testonly.go
+++ /dev/null
@@ -1,152 +0,0 @@
-package utils
-
-import (
-	"fmt"
-	"math/big"
-	"math/rand"
-	"reflect"
-	"time"
-
-	"github.com/google/go-cmp/cmp"
-	"github.com/google/go-cmp/cmp/cmpopts"
-	"google.golang.org/protobuf/proto"
-	"google.golang.org/protobuf/testing/protocmp"
-)
-
-// ReadOnly - if a struct embeds ReadOnly,
-// its private fields will be compared by TestEqual.
-type ReadOnly struct{}
-
-// isReadOnly returns true if t embeds ReadOnly.
-func isReadOnly(t reflect.Type) bool {
-	want := reflect.TypeOf(ReadOnly{})
-	if t.Kind() != reflect.Struct {
-		return false
-	}
-	for i := range t.NumField() {
-		if f := t.Field(i); f.Anonymous || f.Type == want {
-			return true
-		}
-	}
-	return false
-}
-
-func cmpComparer[T any, PT interface {
-	Cmp(b *T) int
-	*T
-}](a PT, b PT) bool {
-	if a == nil || b == nil {
-		return a == b
-	}
-	return a.Cmp(b) == 0
-}
-
-var cmpOpts = []cmp.Option{
-	protocmp.Transform(),
-	cmp.Exporter(isReadOnly),
-	cmpopts.EquateEmpty(),
-	cmp.Comparer(cmpComparer[big.Int]),
-}
-
-// TestDiff generates a human-readable diff between two objects.
-func TestDiff[T any](want, got T) error {
-	if diff := cmp.Diff(want, got, cmpOpts...); diff != "" {
-		return fmt.Errorf("want (-) got (+):\n%s", diff)
-	}
-	return nil
-}
-
-// TestEqual is a more robust replacement for reflect.DeepEqual for tests.
-func TestEqual[T any](a, b T) bool {
-	return cmp.Equal(a, b, cmpOpts...)
-}
-
-// TestRngSplit returns a new random number generator split from the given one.
-// This is a very primitive split, known to result in dependent randomness.
-// If that ever causes a problem, we can switch to SplitMix.
-func TestRngSplit(rng *rand.Rand) *rand.Rand {
-	return rand.New(rand.NewSource(rng.Int63()))
-}
-
-// TestRng returns a deterministic random number generator.
-func TestRng() *rand.Rand {
-	return rand.New(rand.NewSource(789345342))
-}
-
-var alphanum = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
-
-// GenString generates a random string of length n.
-func GenString(rng *rand.Rand, n int) string {
-	s := make([]rune, n)
-	for i := range n {
-		s[i] = alphanum[rng.Intn(len(alphanum))]
-	}
-	return string(s)
-}
-
-// GenBytes generates a random byte slice of length n.
-func GenBytes(rng *rand.Rand, n int) []byte {
-	s := make([]byte, n)
-	_, _ = rng.Read(s)
-	return s
-}
-
-// GenF is a function which generates a T.
-type GenF[T any] = func(rng *rand.Rand) T
-
-// GenSlice generates a slice of small random length.
-func GenSlice[T any](rng *rand.Rand, gen GenF[T]) []T {
-	return GenSliceN(rng, 2+rng.Intn(3), gen)
-}
-
-// GenSliceN generates a slice of n elements.
-func GenSliceN[T any](rng *rand.Rand, n int, gen GenF[T]) []T {
-	s := make([]T, n)
-	for i := range s {
-		s[i] = gen(rng)
-	}
-	return s
-}
-
-// GenMap generates a map of small random length.
-func GenMap[K comparable, V any](rng *rand.Rand, genK GenF[K], genV GenF[V]) map[K]V {
-	return GenMapN(rng, 2+rng.Intn(3), genK, genV)
-}
-
-// GenMapN generates a map of n elements.
-func GenMapN[K comparable, V any](rng *rand.Rand, n int, genK GenF[K], genV GenF[V]) map[K]V {
-	m := make(map[K]V, n)
-	for len(m) < n {
-		m[genK(rng)] = genV(rng)
-	}
-	return m
-}
-
-// GenTimestamp generates a random timestamp.
-func GenTimestamp(rng *rand.Rand) time.Time {
-	return time.Unix(0, rng.Int63())
-}
-
-// GenHash generates a random Hash.
-func GenHash(rng *rand.Rand) Hash {
-	var h Hash
-	_, _ = rng.Read(h[:])
-	return h
-}
-
-// Test checks whether re-encoding a value is an identity operation.
-func (c *ProtoConv[T, P]) Test(want T) error { - p := c.Encode(want) - raw, err := proto.Marshal(p) - if err != nil { - return fmt.Errorf("Marshal(): %w", err) - } - if err := proto.Unmarshal(raw, p); err != nil { - return fmt.Errorf("Unmarshal(): %w", err) - } - got, err := c.Decode(p) - if err != nil { - return fmt.Errorf("Decode(Encode()): %w", err) - } - return TestDiff(want, got) -} diff --git a/libs/utils/wait.go b/libs/utils/wait.go deleted file mode 100644 index 4c8c6634f..000000000 --- a/libs/utils/wait.go +++ /dev/null @@ -1,119 +0,0 @@ -package utils - -import ( - "context" - "encoding" - "errors" - "time" -) - -// IgnoreCancel returns nil if the error is context.Canceled, err otherwise. -func IgnoreCancel(err error) error { - if errors.Is(err, context.Canceled) { - return nil - } - return err -} - -// WithTimeout executes a function with a timeout. -func WithTimeout(ctx context.Context, d time.Duration, f func(ctx context.Context) error) error { - ctx, cancel := context.WithTimeout(ctx, d) - defer cancel() - return f(ctx) -} - -// WithTimeout1 executes a function with a timeout. -func WithTimeout1[R any](ctx context.Context, d time.Duration, f func(ctx context.Context) (R, error)) (R, error) { - ctx, cancel := context.WithTimeout(ctx, d) - defer cancel() - return f(ctx) -} - -// Sleep sleeps for a duration or until the context is canceled. -func Sleep(ctx context.Context, d time.Duration) error { - select { - case <-ctx.Done(): - return ctx.Err() - case <-time.After(d): - return nil - } -} - -// SleepUntil sleeps until deadline t or until the context is canceled. -func SleepUntil(ctx context.Context, t time.Time) error { - return Sleep(ctx, time.Until(t)) -} - -// WaitFor polls a check function until it returns true or the context is canceled. -func WaitFor(ctx context.Context, interval time.Duration, check func() bool) error { - if check() { - return nil - } - ticker := time.NewTicker(interval) - for { - if check() { - return nil - } - select { - case <-ctx.Done(): - return ctx.Err() - case <-ticker.C: - } - } -} - -// WaitForWithTimeout polls a check function until it returns true, the context is canceled, or the timeout is reached. -func WaitForWithTimeout(ctx context.Context, interval, timeout time.Duration, check func() bool) error { - if check() { - return nil - } - - ctx, cancel := context.WithTimeout(ctx, timeout) - defer cancel() - - ticker := time.NewTicker(interval) - defer ticker.Stop() - - for { - if check() { - return nil - } - select { - case <-ctx.Done(): - return ctx.Err() - case <-ticker.C: - } - } -} - -// Duration is a wrapper type around time.Duration that supports JSON marshaling/unmarshaling. -// nolint:recvcheck -type Duration time.Duration - -// MarshalText implements json.TextMarshaler interface to convert Duration to JSON string. -func (d Duration) MarshalText() ([]byte, error) { - return []byte(time.Duration(d).String()), nil -} - -// UnmarshalText implements json.TextUnmarshaler. -func (d *Duration) UnmarshalText(b []byte) error { - tmp, err := time.ParseDuration(string(b)) - if err != nil { - return err - } - *d = Duration(tmp) - return nil -} - -var _ encoding.TextMarshaler = Zero[Duration]() -var _ encoding.TextUnmarshaler = (*Duration)(nil) - -// Duration returns the underlying time.Duration value. -func (d Duration) Duration() time.Duration { - return time.Duration(d) -} - -// Seconds returns the underlying time.Duration value in seconds. 
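The wait helpers and the Duration wrapper removed above also have a documented contract: the polling helpers return ctx.Err() when the context ends first, and Duration marshals to and from a human-readable string. A short sketch of typical call sites, using only the deleted signatures and assuming the usual encoding/json, time, and testing imports plus the removed utils package; the isReady condition is a placeholder.

func TestWaitHelpersSketch(t *testing.T) {
    ctx := t.Context()

    // Poll every 50ms until the condition holds, giving up after five seconds
    // or when ctx is canceled, whichever comes first.
    isReady := func() bool { return true } // placeholder condition
    if err := utils.WaitForWithTimeout(ctx, 50*time.Millisecond, 5*time.Second, isReady); err != nil {
        t.Fatalf("condition never became true: %v", err)
    }

    // Duration round-trips through JSON as a human-readable string.
    cfg := struct{ Timeout utils.Duration }{Timeout: utils.Duration(90 * time.Second)}
    b, err := json.Marshal(cfg)
    if err != nil {
        t.Fatal(err)
    }
    t.Logf("%s", b) // {"Timeout":"1m30s"}
}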
-func (d Duration) Seconds() float64 { - return time.Duration(d).Seconds() -} diff --git a/libs/utils/wait_test.go b/libs/utils/wait_test.go deleted file mode 100644 index 91edc1267..000000000 --- a/libs/utils/wait_test.go +++ /dev/null @@ -1,23 +0,0 @@ -package utils - -import ( - "encoding/json" - "testing" - "time" -) - -func TestJSON(t *testing.T) { - var got, want struct{ X Duration } - want.X = Duration(100 * time.Millisecond) - j, err := json.Marshal(want) - if err != nil { - t.Fatal(err) - } - t.Logf("%s", j) - if err := json.Unmarshal(j, &got); err != nil { - t.Fatal(err) - } - if err := TestDiff(want, got); err != nil { - t.Fatal(err) - } -} diff --git a/light/client_benchmark_test.go b/light/client_benchmark_test.go index ef5ea0de8..c07e0183b 100644 --- a/light/client_benchmark_test.go +++ b/light/client_benchmark_test.go @@ -103,6 +103,7 @@ func BenchmarkSequence(b *testing.B) { } func BenchmarkBisection(b *testing.B) { + ctx := b.Context() headers, vals, _ := genLightBlocksWithKeys(b, 1000, 100, 1, bTime) benchmarkFullNode := newProviderBenchmarkImpl(headers, vals) genesisBlock, _ := benchmarkFullNode.LightBlock(ctx, 1) diff --git a/light/client_test.go b/light/client_test.go index 55f0b877a..fee4bd841 100644 --- a/light/client_test.go +++ b/light/client_test.go @@ -37,8 +37,7 @@ func init() { } func TestClient(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() var ( keys = genPrivKeys(4) vals = keys.ToValidators(20, 10) @@ -228,8 +227,7 @@ func TestClient(t *testing.T) { for _, tc := range testCases { testCase := tc t.Run(testCase.name, func(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewNopLogger() @@ -351,14 +349,10 @@ func TestClient(t *testing.T) { }, } - bctx, bcancel := context.WithCancel(t.Context()) - defer bcancel() - for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { - ctx, cancel := context.WithCancel(bctx) - defer cancel() + ctx := t.Context() logger := log.NewNopLogger() mockNode := mockNodeFromHeadersAndVals(tc.otherHeaders, tc.vals) @@ -411,8 +405,7 @@ func TestClient(t *testing.T) { mockNode.On("LightBlock", mock.Anything, int64(0)).Return(lastBlock, nil) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() trustedLightBlock, err := mockNode.LightBlock(ctx, int64(200)) require.NoError(t, err) @@ -442,8 +435,7 @@ func TestClient(t *testing.T) { mockNode.AssertExpectations(t) }) t.Run("BisectionBetweenTrustedHeaders", func(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() mockFullNode := mockNodeFromHeadersAndVals(headerSet, valSet) c, err := light.NewClient( @@ -475,8 +467,7 @@ func TestClient(t *testing.T) { mockFullNode.AssertExpectations(t) }) t.Run("Cleanup", func(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewNopLogger() mockFullNode := &provider_mocks.Provider{} @@ -507,13 +498,9 @@ func TestClient(t *testing.T) { t.Run("RestoresTrustedHeaderAfterStartup", func(t *testing.T) { // trustedHeader.Height == options.Height - bctx, bcancel := context.WithCancel(t.Context()) - defer bcancel() - // 1. options.Hash == trustedHeader.Hash t.Run("hashes should match", func(t *testing.T) { - ctx, cancel := context.WithCancel(bctx) - defer cancel() + ctx := t.Context() logger := log.NewNopLogger() @@ -544,8 +531,7 @@ func TestClient(t *testing.T) { // 2. 
options.Hash != trustedHeader.Hash t.Run("hashes should not match", func(t *testing.T) { - ctx, cancel := context.WithCancel(bctx) - defer cancel() + ctx := t.Context() trustedStore := dbs.New(dbm.NewMemDB()) err := trustedStore.SaveLightBlock(l1) @@ -585,8 +571,7 @@ func TestClient(t *testing.T) { }) }) t.Run("Update", func(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() mockFullNode := &provider_mocks.Provider{} mockFullNode.On("LightBlock", mock.Anything, int64(0)).Return(l3, nil) @@ -618,8 +603,7 @@ func TestClient(t *testing.T) { }) t.Run("Concurrency", func(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewNopLogger() mockFullNode := &provider_mocks.Provider{} @@ -667,8 +651,7 @@ func TestClient(t *testing.T) { mockFullNode.AssertExpectations(t) }) t.Run("AddProviders", func(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() mockFullNode := mockNodeFromHeadersAndVals(map[int64]*types.SignedHeader{ 1: h1, @@ -707,8 +690,7 @@ func TestClient(t *testing.T) { mockFullNode.AssertExpectations(t) }) t.Run("ReplacesPrimaryWithWitnessIfPrimaryIsUnavailable", func(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() mockFullNode := &provider_mocks.Provider{} mockFullNode.On("LightBlock", mock.Anything, mock.Anything).Return(l1, nil) @@ -746,7 +728,6 @@ func TestClient(t *testing.T) { t.Run("TerminatesWitnessSearchAfterContextDeadlineExpires", func(t *testing.T) { ctx, cancel := context.WithTimeout(t.Context(), time.Duration(1*time.Second)) defer cancel() - mockDeadNode := &provider_mocks.Provider{} mockDeadNode.On("LightBlock", mock.Anything, mock.Anything).Return(nil, provider.ErrNoResponse) mockSlowNode := &provider_mocks.Provider{} @@ -771,8 +752,7 @@ func TestClient(t *testing.T) { mockSlowNode.AssertExpectations(t) }) t.Run("ReplacesPrimaryWithWitnessIfPrimaryDoesntHaveBlock", func(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() mockFullNode := &provider_mocks.Provider{} mockFullNode.On("LightBlock", mock.Anything, mock.Anything).Return(l1, nil) @@ -800,8 +780,7 @@ func TestClient(t *testing.T) { mockFullNode.AssertExpectations(t) }) t.Run("BackwardsVerification", func(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewNopLogger() { @@ -950,8 +929,7 @@ func TestClient(t *testing.T) { mockFullNode := mockNodeFromHeadersAndVals(headerSet, valSet) mockFullNode.On("ID", mock.Anything, mock.Anything).Return(id3, nil) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() lb1, _ := mockBadNode1.LightBlock(ctx, 2) require.NotEqual(t, lb1.Hash(), l1.Hash()) @@ -1020,8 +998,7 @@ func TestClient(t *testing.T) { mockFullNode := mockNodeFromHeadersAndVals(headerSet, valSet) mockFullNode.On("ID", mock.Anything, mock.Anything).Return(id3, nil) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() lb1, _ := mockBadNode1.LightBlock(ctx, 2) require.NotEqual(t, lb1.Hash(), l1.Hash()) @@ -1058,8 +1035,7 @@ func TestClient(t *testing.T) { mockBadNode2.AssertExpectations(t) }) t.Run("TrustedValidatorSet", func(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewNopLogger() @@ -1121,8 +1097,7 @@ func TestClient(t *testing.T) { 0: vals, 
}) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewNopLogger() c, err := light.NewClient( @@ -1210,8 +1185,7 @@ func TestClient(t *testing.T) { for i, tc := range testCases { testCase := tc t.Run(fmt.Sprintf("case: %d", i), func(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() mockBadNode := mockNodeFromHeadersAndVals(testCase.headers, testCase.vals) if testCase.errorToThrow != nil { diff --git a/light/detector_test.go b/light/detector_test.go index bc343108a..02142b833 100644 --- a/light/detector_test.go +++ b/light/detector_test.go @@ -2,7 +2,6 @@ package light_test import ( "bytes" - "context" "testing" "time" @@ -33,8 +32,7 @@ func TestLightClientAttackEvidence_Lunatic(t *testing.T) { primaryValidators = make(map[int64]*types.ValidatorSet, latestHeight) ) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() witnessHeaders, witnessValidators, chainKeys := genLightBlocksWithKeys(t, latestHeight, valSize, 2, bTime) @@ -135,28 +133,23 @@ func TestLightClientAttackEvidence_Equivocation(t *testing.T) { }, } - bctx, bcancel := context.WithCancel(t.Context()) - defer bcancel() - for _, tc := range cases { - testCase := tc - t.Run(testCase.name, func(t *testing.T) { - ctx, cancel := context.WithCancel(bctx) - defer cancel() + t.Run(tc.name, func(t *testing.T) { + ctx := t.Context() logger := log.NewNopLogger() // primary performs an equivocation attack var ( valSize = 5 - primaryHeaders = make(map[int64]*types.SignedHeader, testCase.latestHeight) + primaryHeaders = make(map[int64]*types.SignedHeader, tc.latestHeight) // validators don't change in this network (however we still use a map just for convenience) - primaryValidators = make(map[int64]*types.ValidatorSet, testCase.latestHeight) + primaryValidators = make(map[int64]*types.ValidatorSet, tc.latestHeight) ) witnessHeaders, witnessValidators, chainKeys := genLightBlocksWithKeys(t, - testCase.latestHeight+1, valSize, 2, bTime) - for height := int64(1); height <= testCase.latestHeight; height++ { - if height < testCase.divergenceHeight { + tc.latestHeight+1, valSize, 2, bTime) + for height := int64(1); height <= tc.latestHeight; height++ { + if height < tc.divergenceHeight { primaryHeaders[height] = witnessHeaders[height] primaryValidators[height] = witnessValidators[height] continue @@ -170,12 +163,12 @@ func TestLightClientAttackEvidence_Equivocation(t *testing.T) { primaryValidators[height] = witnessValidators[height] } - for _, height := range testCase.unusedWitnessBlockHeights { + for _, height := range tc.unusedWitnessBlockHeights { delete(witnessHeaders, height) } mockWitness := mockNodeFromHeadersAndVals(witnessHeaders, witnessValidators) - for _, height := range testCase.unusedPrimaryBlockHeights { + for _, height := range tc.unusedPrimaryBlockHeights { delete(primaryHeaders, height) } mockPrimary := mockNodeFromHeadersAndVals(primaryHeaders, primaryValidators) @@ -186,20 +179,20 @@ func TestLightClientAttackEvidence_Equivocation(t *testing.T) { mockWitness.On("ReportEvidence", mock.Anything, mock.MatchedBy(func(evidence types.Evidence) bool { evAgainstPrimary := &types.LightClientAttackEvidence{ ConflictingBlock: &types.LightBlock{ - SignedHeader: primaryHeaders[testCase.divergenceHeight], - ValidatorSet: primaryValidators[testCase.divergenceHeight], + SignedHeader: primaryHeaders[tc.divergenceHeight], + ValidatorSet: primaryValidators[tc.divergenceHeight], }, - CommonHeight: 
testCase.divergenceHeight, + CommonHeight: tc.divergenceHeight, } return bytes.Equal(evidence.Hash(), evAgainstPrimary.Hash()) })).Return(nil) mockPrimary.On("ReportEvidence", mock.Anything, mock.MatchedBy(func(evidence types.Evidence) bool { evAgainstWitness := &types.LightClientAttackEvidence{ ConflictingBlock: &types.LightBlock{ - SignedHeader: witnessHeaders[testCase.divergenceHeight], - ValidatorSet: witnessValidators[testCase.divergenceHeight], + SignedHeader: witnessHeaders[tc.divergenceHeight], + ValidatorSet: witnessValidators[tc.divergenceHeight], }, - CommonHeight: testCase.divergenceHeight, + CommonHeight: tc.divergenceHeight, } return bytes.Equal(evidence.Hash(), evAgainstWitness.Hash()) })).Return(nil) @@ -217,12 +210,12 @@ func TestLightClientAttackEvidence_Equivocation(t *testing.T) { dbs.New(dbm.NewMemDB()), 5*time.Minute, light.Logger(logger), - testCase.lightOption, + tc.lightOption, ) require.NoError(t, err) // Check verification returns an error. - _, err = c.VerifyLightBlockAtHeight(ctx, testCase.latestHeight, bTime.Add(300*time.Second)) + _, err = c.VerifyLightBlockAtHeight(ctx, tc.latestHeight, bTime.Add(300*time.Second)) if assert.Error(t, err) { assert.Equal(t, light.ErrLightClientAttack, err) } @@ -245,8 +238,7 @@ func TestLightClientAttackEvidence_ForwardLunatic(t *testing.T) { primaryValidators = make(map[int64]*types.ValidatorSet, forgedHeight) ) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewNopLogger() witnessHeaders, witnessValidators, chainKeys := genLightBlocksWithKeys(t, latestHeight, valSize, 2, bTime) @@ -393,8 +385,7 @@ func TestLightClientAttackEvidence_ForwardLunatic(t *testing.T) { // => light client returns an error upon creation because primary and witness // have a different view. func TestClientDivergentTraces1(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() headers, vals, _ := genLightBlocksWithKeys(t, 1, 5, 2, bTime) mockPrimary := mockNodeFromHeadersAndVals(headers, vals) @@ -429,8 +420,7 @@ func TestClientDivergentTraces1(t *testing.T) { // 2. 
Two out of three nodes don't respond but the third has a header that matches // => verification should be successful but two unresponsive witnesses should be blacklisted func TestClientDivergentTraces2(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewNopLogger() headers, vals, _ := genLightBlocksWithKeys(t, 2, 5, 2, bTime) @@ -473,8 +463,7 @@ func TestClientDivergentTraces3(t *testing.T) { primaryHeaders, primaryVals, _ := genLightBlocksWithKeys(t, 2, 5, 2, bTime) mockPrimary := mockNodeFromHeadersAndVals(primaryHeaders, primaryVals) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() firstBlock, err := mockPrimary.LightBlock(ctx, 1) require.NoError(t, err) @@ -517,8 +506,7 @@ func TestClientDivergentTraces4(t *testing.T) { primaryHeaders, primaryVals, _ := genLightBlocksWithKeys(t, 2, 5, 2, bTime) mockPrimary := mockNodeFromHeadersAndVals(primaryHeaders, primaryVals) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() firstBlock, err := mockPrimary.LightBlock(ctx, 1) require.NoError(t, err) diff --git a/light/dispatcher_test.go b/light/dispatcher_test.go index f0594244f..57fc042f8 100644 --- a/light/dispatcher_test.go +++ b/light/dispatcher_test.go @@ -39,8 +39,7 @@ func TestDispatcherBasic(t *testing.T) { t.Cleanup(leaktest.Check(t)) const numPeers = 5 - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() chans, ch := testChannel(100) @@ -104,8 +103,7 @@ func TestDispatcherReturnsNoBlock(t *testing.T) { func TestDispatcherTimeOutWaitingOnLightBlock(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() _, ch := testChannel(100) d := NewDispatcher(ch, func(height uint64) proto.Message { @@ -130,8 +128,7 @@ func TestDispatcherProviders(t *testing.T) { chainID := "test-chain" - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() chans, ch := testChannel(100) @@ -160,8 +157,7 @@ func TestDispatcherProviders(t *testing.T) { func TestPeerListBasic(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() peerList := NewPeerList() assert.Zero(t, peerList.Len()) @@ -207,8 +203,7 @@ func TestPeerListBlocksWhenEmpty(t *testing.T) { peerList := NewPeerList() require.Zero(t, peerList.Len()) doneCh := make(chan struct{}) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() go func() { peerList.Pop(ctx) close(doneCh) @@ -226,8 +221,7 @@ func TestEmptyPeerListReturnsWhenContextCanceled(t *testing.T) { require.Zero(t, peerList.Len()) doneCh := make(chan struct{}) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() wrapped, cancel := context.WithCancel(ctx) go func() { diff --git a/light/example_test.go b/light/example_test.go index 74b9f27b3..1cae469b9 100644 --- a/light/example_test.go +++ b/light/example_test.go @@ -1,7 +1,6 @@ package light_test import ( - "context" "github.com/tendermint/tendermint/light/provider" "testing" "time" @@ -18,8 +17,7 @@ import ( // Manually getting light blocks and verifying them. 
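The edits in this and the surrounding test files all apply one mechanical pattern, shown below as a sketch rather than an excerpt from any single file: t.Context, added in Go 1.24, returns a context that is canceled automatically just before the test's Cleanup functions run, so the manual WithCancel plus defer cancel() boilerplate can be dropped wherever the context only needs to live for the duration of the test. Imports assumed: context and testing.

// Before: every test managed its own context lifetime.
func testWithManualContext(t *testing.T) {
    ctx, cancel := context.WithCancel(context.Background())
    defer cancel()
    _ = ctx // ... test body ...
}

// After: the test-scoped context is canceled for us when the test finishes.
func testWithTestContext(t *testing.T) {
    ctx := t.Context()
    _ = ctx // ... test body ...
}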
func TestExampleClient(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() conf, err := rpctest.CreateConfig(t, "ExampleClient_VerifyLightBlockAtHeight") if err != nil { t.Fatal(err) diff --git a/light/light_test.go b/light/light_test.go index c7a877785..e4fc42990 100644 --- a/light/light_test.go +++ b/light/light_test.go @@ -26,8 +26,7 @@ import ( func TestClientIntegration_Update(t *testing.T) { t.Parallel() - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() conf, err := rpctest.CreateConfig(t, t.Name()) require.NoError(t, err) @@ -90,8 +89,7 @@ func TestClientIntegration_Update(t *testing.T) { // Manually getting light blocks and verifying them. func TestClientIntegration_VerifyLightBlockAtHeight(t *testing.T) { t.Parallel() - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() conf, err := rpctest.CreateConfig(t, t.Name()) require.NoError(t, err) @@ -168,8 +166,7 @@ func waitForBlock(ctx context.Context, p provider.Provider, height int64) (*type } func TestClientStatusRPC(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() conf, err := rpctest.CreateConfig(t, t.Name()) require.NoError(t, err) diff --git a/light/store/db/db_test.go b/light/store/db/db_test.go index e647907c0..efb329e7f 100644 --- a/light/store/db/db_test.go +++ b/light/store/db/db_test.go @@ -19,8 +19,7 @@ import ( func TestLast_FirstLightBlockHeight(t *testing.T) { dbStore := New(dbm.NewMemDB()) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() // Empty store height, err := dbStore.LastLightBlockHeight() @@ -46,8 +45,7 @@ func TestLast_FirstLightBlockHeight(t *testing.T) { func Test_SaveLightBlock(t *testing.T) { dbStore := New(dbm.NewMemDB()) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() // Empty store h, err := dbStore.LightBlock(1) @@ -78,8 +76,7 @@ func Test_SaveLightBlock(t *testing.T) { func Test_LightBlockBefore(t *testing.T) { dbStore := New(dbm.NewMemDB()) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() assert.Panics(t, func() { _, _ = dbStore.LightBlockBefore(0) @@ -101,8 +98,7 @@ func Test_LightBlockBefore(t *testing.T) { func Test_Prune(t *testing.T) { dbStore := New(dbm.NewMemDB()) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() // Empty store assert.EqualValues(t, 0, dbStore.Size()) @@ -141,8 +137,7 @@ func Test_Prune(t *testing.T) { func Test_Concurrency(t *testing.T) { dbStore := New(dbm.NewMemDB()) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() var wg sync.WaitGroup for i := 1; i <= 100; i++ { diff --git a/node/node_test.go b/node/node_test.go index b90be8ff7..1992546a1 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -45,8 +45,7 @@ func TestNodeStartStop(t *testing.T) { defer os.RemoveAll(cfg.RootDir) - ctx, bcancel := context.WithCancel(t.Context()) - defer bcancel() + ctx := t.Context() logger := log.NewNopLogger() // create & start node @@ -56,7 +55,6 @@ func TestNodeStartStop(t *testing.T) { n, ok := ns.(*nodeImpl) require.True(t, ok) t.Cleanup(func() { - bcancel() n.Wait() }) t.Cleanup(leaktest.CheckTimeout(t, time.Second)) @@ -75,17 +73,14 @@ func TestNodeStartStop(t *testing.T) { _, err = blocksSub.Next(tctx) require.NoError(t, err, "waiting for event") - cancel() // stop the subscription context - bcancel() // stop the base 
context - n.Wait() - - require.False(t, n.IsRunning(), "node must shut down") + t.Cleanup(func() { + n.Wait() + require.False(t, n.IsRunning(), "node must shut down") + }) } func getTestNode(ctx context.Context, t *testing.T, conf *config.Config, logger log.Logger) *nodeImpl { t.Helper() - ctx, cancel := context.WithCancel(ctx) - defer cancel() ns, err := newDefaultNode(ctx, conf, logger, make(chan struct{})) require.NoError(t, err) @@ -94,7 +89,6 @@ func getTestNode(ctx context.Context, t *testing.T, conf *config.Config, logger require.True(t, ok) t.Cleanup(func() { - cancel() if n.IsRunning() { ns.Wait() } @@ -112,8 +106,7 @@ func TestNodeDelayedStart(t *testing.T) { defer os.RemoveAll(cfg.RootDir) now := tmtime.Now() - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewNopLogger() @@ -132,8 +125,7 @@ func TestNodeSetAppVersion(t *testing.T) { require.NoError(t, err) defer os.RemoveAll(cfg.RootDir) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewNopLogger() @@ -158,8 +150,7 @@ func TestNodeSetPrivValTCP(t *testing.T) { addr := "tcp://" + testFreeAddr(t) t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewNopLogger() @@ -195,8 +186,7 @@ func TestNodeSetPrivValTCP(t *testing.T) { // address without a protocol must result in error func TestPrivValidatorListenAddrNoProtocol(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() addrNoPrefix := testFreeAddr(t) @@ -212,7 +202,6 @@ func TestPrivValidatorListenAddrNoProtocol(t *testing.T) { assert.Error(t, err) if n != nil && n.IsRunning() { - cancel() n.Wait() } } @@ -221,8 +210,7 @@ func TestNodeSetPrivValIPC(t *testing.T) { tmpfile := "/tmp/kms." + tmrand.Str(6) + ".sock" defer os.Remove(tmpfile) // clean up - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() cfg, err := config.ResetTestRoot(t.TempDir(), "node_priv_val_tcp_test") require.NoError(t, err) @@ -268,8 +256,7 @@ func testFreeAddr(t *testing.T) string { // create a proposal block using real and full // mempool and evidence pool and validate it. 
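One consequence the node tests above rely on: with t.Context() there is no explicit cancel() left to call mid-test, and the context is canceled only just before Cleanup-registered functions run. Shutdown waits and the "node must shut down" assertion therefore move into t.Cleanup. A condensed sketch of that pattern follows; getTestNode and cfg are the names used in the diff, the helper body is omitted, and the log/require imports are assumed.

func TestNodeStartStopSketch(t *testing.T) {
    ctx := t.Context()

    n := getTestNode(ctx, t, cfg, log.NewNopLogger()) // helper from the diff, body omitted

    // ... exercise the running node while ctx is still live ...

    t.Cleanup(func() {
        // By the time Cleanup runs, ctx has been canceled, which is what
        // triggers the node shutdown in these tests; Wait blocks until it completes.
        n.Wait()
        require.False(t, n.IsRunning(), "node must shut down")
    })
}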
func TestCreateProposalBlock(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() cfg, err := config.ResetTestRoot(t.TempDir(), "node_create_proposal") require.NoError(t, err) @@ -369,8 +356,7 @@ func TestCreateProposalBlock(t *testing.T) { } func TestMaxTxsProposalBlockSize(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() cfg, err := config.ResetTestRoot(t.TempDir(), "node_create_proposal") require.NoError(t, err) @@ -443,8 +429,7 @@ func TestMaxTxsProposalBlockSize(t *testing.T) { } func TestMaxProposalBlockSize(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() cfg, err := config.ResetTestRoot(t.TempDir(), "node_create_proposal") require.NoError(t, err) @@ -592,8 +577,7 @@ func TestNodeNewSeedNode(t *testing.T) { cfg.Mode = config.ModeSeed defer os.RemoveAll(cfg.RootDir) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() nodeKey, err := types.LoadOrGenNodeKey(cfg.NodeKeyFile()) require.NoError(t, err) @@ -622,7 +606,6 @@ func TestNodeNewSeedNode(t *testing.T) { require.NoError(t, err) assert.True(t, n.pexReactor.IsRunning()) - cancel() n.Wait() assert.False(t, n.pexReactor.IsRunning()) @@ -634,8 +617,7 @@ func TestNodeSetEventSink(t *testing.T) { defer os.RemoveAll(cfg.RootDir) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewNopLogger() @@ -664,7 +646,6 @@ func TestNodeSetEventSink(t *testing.T) { if !n.IsRunning() { return } - cancel() n.Wait() } } @@ -756,8 +737,7 @@ func state(t *testing.T, nVals int, height int64) (sm.State, dbm.DB, []types.Pri } func TestLoadStateFromGenesis(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() _ = loadStatefromGenesis(ctx, t) } diff --git a/privval/file_test.go b/privval/file_test.go index f824a4d9c..8f412e917 100644 --- a/privval/file_test.go +++ b/privval/file_test.go @@ -1,7 +1,6 @@ package privval import ( - "context" "encoding/base64" "encoding/json" "fmt" @@ -36,8 +35,7 @@ func TestGenLoadValidator(t *testing.T) { } func TestResetValidator(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() privVal, _, tempStateFileName := newTestFilePV(t) emptyState := FilePVLastSignState{filePath: tempStateFileName} @@ -146,8 +144,7 @@ func TestUnmarshalValidatorKey(t *testing.T) { } func TestSignVote(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() privVal, _, _ := newTestFilePV(t) @@ -195,8 +192,7 @@ func TestSignVote(t *testing.T) { } func TestSignProposal(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() privVal, _, _ := newTestFilePV(t) @@ -237,8 +233,7 @@ func TestSignProposal(t *testing.T) { } func TestDifferByTimestamp(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() tempKeyFile, err := os.CreateTemp(t.TempDir(), "priv_validator_key_") require.NoError(t, err) @@ -278,8 +273,7 @@ func TestDifferByTimestamp(t *testing.T) { } func TestVoteExtensionsAreAlwaysSigned(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() privVal, _, _ := newTestFilePV(t) pubKey, err := privVal.GetPubKey(ctx) diff --git a/privval/grpc/client_test.go b/privval/grpc/client_test.go index 827303a43..cd2f6ba3b 100644 --- 
a/privval/grpc/client_test.go +++ b/privval/grpc/client_test.go @@ -41,8 +41,7 @@ func dialer(t *testing.T, pv types.PrivValidator, logger log.Logger) (*grpc.Serv func TestSignerClient_GetPubKey(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() mockPV := types.NewMockPV() logger := log.NewTestingLogger(t) @@ -65,8 +64,7 @@ func TestSignerClient_GetPubKey(t *testing.T) { } func TestSignerClient_SignVote(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() mockPV := types.NewMockPV() logger := log.NewTestingLogger(t) @@ -120,8 +118,7 @@ func TestSignerClient_SignVote(t *testing.T) { } func TestSignerClient_SignProposal(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() mockPV := types.NewMockPV() logger := log.NewTestingLogger(t) diff --git a/privval/grpc/server_test.go b/privval/grpc/server_test.go index 9e80b7534..f3d75ffa8 100644 --- a/privval/grpc/server_test.go +++ b/privval/grpc/server_test.go @@ -1,7 +1,6 @@ package grpc_test import ( - "context" "testing" "time" @@ -33,8 +32,7 @@ func TestGetPubKey(t *testing.T) { for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewTestingLogger(t) s := tmgrpc.NewSignerServer(logger, ChainID, tc.pv) @@ -108,8 +106,7 @@ func TestSignVote(t *testing.T) { for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewTestingLogger(t) s := tmgrpc.NewSignerServer(logger, ChainID, tc.pv) @@ -179,8 +176,7 @@ func TestSignProposal(t *testing.T) { for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewTestingLogger(t) s := tmgrpc.NewSignerServer(logger, ChainID, tc.pv) diff --git a/privval/signer_client_test.go b/privval/signer_client_test.go index 2c5ade79b..a293b6378 100644 --- a/privval/signer_client_test.go +++ b/privval/signer_client_test.go @@ -65,8 +65,7 @@ func getSignerTestCases(ctx context.Context, t *testing.T, logger log.Logger) [] func TestSignerClose(t *testing.T) { t.Cleanup(leaktest.Check(t)) - bctx, bcancel := context.WithCancel(t.Context()) - defer bcancel() + bctx := t.Context() logger := log.NewNopLogger() @@ -88,8 +87,7 @@ func TestSignerClose(t *testing.T) { func TestSignerPing(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewNopLogger() @@ -102,8 +100,7 @@ func TestSignerPing(t *testing.T) { func TestSignerGetPubKey(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewNopLogger() @@ -132,8 +129,7 @@ func TestSignerGetPubKey(t *testing.T) { func TestSignerProposal(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewNopLogger() @@ -172,8 +168,7 @@ func TestSignerProposal(t *testing.T) { func TestSignerVote(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewNopLogger() @@ -215,8 +210,7 @@ func TestSignerVote(t *testing.T) { func 
TestSignerVoteResetDeadline(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewNopLogger() @@ -266,8 +260,7 @@ func TestSignerVoteResetDeadline(t *testing.T) { func TestSignerVoteKeepAlive(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewNopLogger() @@ -316,8 +309,7 @@ func TestSignerVoteKeepAlive(t *testing.T) { func TestSignerSignProposalErrors(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewNopLogger() @@ -357,8 +349,7 @@ func TestSignerSignProposalErrors(t *testing.T) { func TestSignerSignVoteErrors(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewNopLogger() @@ -423,8 +414,7 @@ func brokenHandler(ctx context.Context, privVal types.PrivValidator, request pri func TestSignerUnexpectedResponse(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewNopLogger() diff --git a/privval/signer_listener_endpoint_test.go b/privval/signer_listener_endpoint_test.go index 21f177013..8419530da 100644 --- a/privval/signer_listener_endpoint_test.go +++ b/privval/signer_listener_endpoint_test.go @@ -42,8 +42,7 @@ func TestSignerRemoteRetryTCPOnly(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewNopLogger() @@ -94,8 +93,7 @@ func TestSignerRemoteRetryTCPOnly(t *testing.T) { func TestRetryConnToRemoteSigner(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewNopLogger() diff --git a/rpc/client/examples_test.go b/rpc/client/examples_test.go index be514920a..5d989b678 100644 --- a/rpc/client/examples_test.go +++ b/rpc/client/examples_test.go @@ -2,7 +2,6 @@ package client_test import ( "bytes" - "context" "log" "net/http" "testing" @@ -18,8 +17,7 @@ import ( ) func TestHTTPSimple(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() // Start a tendermint node (and kvstore) in the background to test against app := kvstore.NewApplication() @@ -68,8 +66,7 @@ func TestHTTPSimple(t *testing.T) { } func TestHTTPBatching(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() // Start a tendermint node (and kvstore) in the background to test against app := kvstore.NewApplication() diff --git a/rpc/client/helpers_test.go b/rpc/client/helpers_test.go index a521ea829..d643fccb2 100644 --- a/rpc/client/helpers_test.go +++ b/rpc/client/helpers_test.go @@ -1,7 +1,6 @@ package client_test import ( - "context" "errors" "strings" "testing" @@ -15,8 +14,7 @@ import ( ) func TestWaitForHeight(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() // test with error result - immediate failure m := &mock.StatusMock{ diff --git a/rpc/client/mock/abci_test.go b/rpc/client/mock/abci_test.go index d218cbc75..e3e7c0a90 100644 --- a/rpc/client/mock/abci_test.go +++ b/rpc/client/mock/abci_test.go @@ -1,7 +1,6 @@ package mock_test import ( - "context" "errors" "fmt" "testing" @@ -19,8 +18,7 @@ import ( ) func 
TestABCIMock(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() key, value := []byte("foo"), []byte("bar") height := int64(10) @@ -80,8 +78,7 @@ func TestABCIMock(t *testing.T) { } func TestABCIRecorder(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() // This mock returns errors on everything but Query m := mock.ABCIMock{ @@ -165,8 +162,7 @@ func TestABCIApp(t *testing.T) { app := kvstore.NewApplication() m := mock.ABCIApp{app} - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() // get some info info, err := m.ABCIInfo(ctx) diff --git a/rpc/client/mock/status_test.go b/rpc/client/mock/status_test.go index 055a99c6a..a7b76417f 100644 --- a/rpc/client/mock/status_test.go +++ b/rpc/client/mock/status_test.go @@ -1,7 +1,6 @@ package mock_test import ( - "context" "testing" "time" @@ -14,8 +13,7 @@ import ( ) func TestStatus(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() m := &mock.StatusMock{ Call: mock.Call{ diff --git a/rpc/client/rpc_test.go b/rpc/client/rpc_test.go index c8ffaead4..dd14d9071 100644 --- a/rpc/client/rpc_test.go +++ b/rpc/client/rpc_test.go @@ -40,12 +40,11 @@ func getHTTPClient(t *testing.T, logger log.Logger, conf *config.Config) *rpchtt rpcAddr := conf.RPC.ListenAddress c, err := rpchttp.NewWithClient(rpcAddr, http.DefaultClient) require.NoError(t, err) - ctx, cancel := context.WithCancel(t.Context()) + ctx := t.Context() require.NoError(t, c.Start(ctx)) c.Logger = logger t.Cleanup(func() { - cancel() require.NoError(t, c.Stop()) }) @@ -60,12 +59,11 @@ func getHTTPClientWithTimeout(t *testing.T, logger log.Logger, conf *config.Conf tclient := &http.Client{Timeout: timeout} c, err := rpchttp.NewWithClient(rpcAddr, tclient) require.NoError(t, err) - ctx, cancel := context.WithCancel(t.Context()) + ctx := t.Context() require.NoError(t, c.Start(ctx)) c.Logger = logger t.Cleanup(func() { - cancel() require.NoError(t, c.Stop()) }) @@ -90,8 +88,7 @@ func GetClients(t *testing.T, ns service.Service, conf *config.Config) []client. 
} func TestClientOperations(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewTestingLogger(t) @@ -189,11 +186,9 @@ func TestClientOperations(t *testing.T) { // Make sure info is correct (we connect properly) func TestClientMethodCalls(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() logger := log.NewTestingLogger(t) - n, conf := NodeSuite(ctx, t, logger) + n, conf := NodeSuite(t.Context(), t, logger) // for broadcast tx tests pool := getMempool(t, n) @@ -205,11 +200,12 @@ func TestClientMethodCalls(t *testing.T) { for i, c := range GetClients(t, n, conf) { t.Run(fmt.Sprintf("%T", c), func(t *testing.T) { t.Run("Status", func(t *testing.T) { - status, err := c.Status(ctx) + status, err := c.Status(t.Context()) require.NoError(t, err, "%d: %+v", i, err) assert.Equal(t, conf.Moniker, status.NodeInfo.Moniker) }) t.Run("Info", func(t *testing.T) { + ctx := t.Context() info, err := c.ABCIInfo(ctx) require.NoError(t, err) @@ -222,7 +218,7 @@ func TestClientMethodCalls(t *testing.T) { t.Run("NetInfo", func(t *testing.T) { nc, ok := c.(client.NetworkClient) require.True(t, ok, "%d", i) - netinfo, err := nc.NetInfo(ctx) + netinfo, err := nc.NetInfo(t.Context()) require.NoError(t, err, "%d: %+v", i, err) assert.True(t, netinfo.Listening) assert.Equal(t, 0, len(netinfo.Peers)) @@ -231,7 +227,7 @@ func TestClientMethodCalls(t *testing.T) { // FIXME: fix server so it doesn't panic on invalid input nc, ok := c.(client.NetworkClient) require.True(t, ok, "%d", i) - cons, err := nc.DumpConsensusState(ctx) + cons, err := nc.DumpConsensusState(t.Context()) require.NoError(t, err, "%d: %+v", i, err) assert.NotEmpty(t, cons.RoundState) assert.Empty(t, cons.Peers) @@ -240,17 +236,18 @@ func TestClientMethodCalls(t *testing.T) { // FIXME: fix server so it doesn't panic on invalid input nc, ok := c.(client.NetworkClient) require.True(t, ok, "%d", i) - cons, err := nc.ConsensusState(ctx) + cons, err := nc.ConsensusState(t.Context()) require.NoError(t, err, "%d: %+v", i, err) assert.NotEmpty(t, cons.RoundState) }) t.Run("Health", func(t *testing.T) { nc, ok := c.(client.NetworkClient) require.True(t, ok, "%d", i) - _, err := nc.Health(ctx) + _, err := nc.Health(t.Context()) require.NoError(t, err, "%d: %+v", i, err) }) t.Run("GenesisAndValidators", func(t *testing.T) { + ctx := t.Context() // make sure this is the right genesis file gen, err := c.Genesis(ctx) require.NoError(t, err, "%d: %+v", i, err) @@ -272,6 +269,7 @@ func TestClientMethodCalls(t *testing.T) { assert.Equal(t, gval.PubKey, val.PubKey) }) t.Run("GenesisChunked", func(t *testing.T) { + ctx := t.Context() first, err := c.GenesisChunked(ctx, 0) require.NoError(t, err) @@ -291,6 +289,7 @@ func TestClientMethodCalls(t *testing.T) { "first: %+v, doc: %s", first, string(doc)) }) t.Run("ABCIQuery", func(t *testing.T) { + ctx := t.Context() // write something k, v, tx := MakeTxKV() status, err := c.Status(ctx) @@ -309,6 +308,7 @@ func TestClientMethodCalls(t *testing.T) { } }) t.Run("AppCalls", func(t *testing.T) { + ctx := t.Context() // get an offset of height to avoid racing and guessing s, err := c.Status(ctx) require.NoError(t, err) @@ -409,8 +409,7 @@ func TestClientMethodCalls(t *testing.T) { // XXX Test proof }) t.Run("BlockchainInfo", func(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() err := client.WaitForHeight(ctx, c, 10, nil) require.NoError(t, err) @@ -439,6 +438,7 @@ func 
TestClientMethodCalls(t *testing.T) { assert.Contains(t, err.Error(), "can't be greater than max") }) t.Run("BroadcastTxCommit", func(t *testing.T) { + ctx := t.Context() _, _, tx := MakeTxKV() bres, err := c.BroadcastTxCommit(ctx, tx) require.NoError(t, err, "%d: %+v", i, err) @@ -448,6 +448,7 @@ func TestClientMethodCalls(t *testing.T) { require.Equal(t, 0, pool.Size()) }) t.Run("BroadcastTxSync", func(t *testing.T) { + ctx := t.Context() _, _, tx := MakeTxKV() initMempoolSize := pool.Size() bres, err := c.BroadcastTxSync(ctx, tx) @@ -461,6 +462,7 @@ func TestClientMethodCalls(t *testing.T) { pool.Flush() }) t.Run("CheckTx", func(t *testing.T) { + ctx := t.Context() _, _, tx := MakeTxKV() res, err := c.CheckTx(ctx, tx) @@ -471,7 +473,7 @@ func TestClientMethodCalls(t *testing.T) { }) t.Run("Events", func(t *testing.T) { t.Run("Header", func(t *testing.T) { - ctx, cancel := context.WithTimeout(ctx, waitForEventTimeout) + ctx, cancel := context.WithTimeout(t.Context(), waitForEventTimeout) defer cancel() query := types.QueryForEvent(types.EventNewBlockHeaderValue).String() evt, err := client.WaitForOneEvent(ctx, c, query) @@ -481,6 +483,7 @@ func TestClientMethodCalls(t *testing.T) { // TODO: more checks... }) t.Run("Block", func(t *testing.T) { + ctx := t.Context() const subscriber = "TestBlockEvents" eventCh, err := c.Subscribe(ctx, subscriber, types.QueryForEvent(types.EventNewBlockValue).String()) @@ -515,16 +518,17 @@ func TestClientMethodCalls(t *testing.T) { } }) t.Run("BroadcastTxAsync", func(t *testing.T) { + ctx := t.Context() testTxEventsSent(ctx, t, "async", c) }) t.Run("BroadcastTxSync", func(t *testing.T) { + ctx := t.Context() testTxEventsSent(ctx, t, "sync", c) }) }) t.Run("Evidence", func(t *testing.T) { t.Run("BroadcastDuplicateVote", func(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() chainID := conf.ChainID() @@ -567,6 +571,7 @@ func TestClientMethodCalls(t *testing.T) { } }) t.Run("BroadcastEmpty", func(t *testing.T) { + ctx := t.Context() _, err := c.BroadcastEvidence(ctx, nil) require.Error(t, err) }) @@ -589,8 +594,7 @@ func getMempool(t *testing.T, srv service.Service) mempool.Mempool { // so making a separate suite makes more sense, though isn't strictly // speaking desirable. 
func TestClientMethodCallsAdvanced(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx := t.Context() logger := log.NewTestingLogger(t) diff --git a/rpc/jsonrpc/client/integration_test.go b/rpc/jsonrpc/client/integration_test.go index 5e7023ef8..75e8937b6 100644 --- a/rpc/jsonrpc/client/integration_test.go +++ b/rpc/jsonrpc/client/integration_test.go @@ -23,8 +23,7 @@ func TestWSClientReconnectWithJitter(t *testing.T) { const maxReconnectAttempts = 3 const maxSleepTime = time.Duration(((1< Date: Mon, 18 Aug 2025 19:22:17 +0200 Subject: [PATCH 05/41] fix --- node/node_test.go | 6 ++++-- rpc/jsonrpc/server/http_server_test.go | 3 +++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/node/node_test.go b/node/node_test.go index 1992546a1..a550020d3 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -606,9 +606,11 @@ func TestNodeNewSeedNode(t *testing.T) { require.NoError(t, err) assert.True(t, n.pexReactor.IsRunning()) - n.Wait() + t.Cleanup(func() { + n.Wait() - assert.False(t, n.pexReactor.IsRunning()) + assert.False(t, n.pexReactor.IsRunning()) + }) } func TestNodeSetEventSink(t *testing.T) { diff --git a/rpc/jsonrpc/server/http_server_test.go b/rpc/jsonrpc/server/http_server_test.go index 7781ceb03..bb72fd0ec 100644 --- a/rpc/jsonrpc/server/http_server_test.go +++ b/rpc/jsonrpc/server/http_server_test.go @@ -111,6 +111,9 @@ func TestServeTLS(t *testing.T) { TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, } c := &http.Client{Transport: tr} + // We need this, because http.Transport is trying to be smart and keeps the connection open + // after res.Body.Close(). + defer c.CloseIdleConnections() res, err := c.Get("https://" + ln.Addr().String()) require.NoError(t, err) defer res.Body.Close() From 4cfdd5514e2f169d4b241560798ed29b37f8ca36 Mon Sep 17 00:00:00 2001 From: Grzegorz Prusak Date: Mon, 18 Aug 2025 19:39:44 +0200 Subject: [PATCH 06/41] tc := tc --- config/config_test.go | 1 - crypto/merkle/proof_test.go | 1 - crypto/merkle/rfc6962_test.go | 1 - crypto/merkle/tree_test.go | 1 - internal/consensus/msgs_test.go | 9 --------- internal/consensus/replay_test.go | 1 - internal/evidence/pool_test.go | 1 - internal/evidence/reactor_test.go | 1 - internal/mempool/priority_queue_test.go | 1 - internal/p2p/conn/connection_test.go | 1 - internal/p2p/conn/evil_secret_connection_test.go | 1 - internal/p2p/peermanager_test.go | 1 - internal/p2p/router_test.go | 2 -- internal/p2p/transport_mconn_test.go | 1 - internal/p2p/transport_test.go | 4 ---- internal/state/execution_test.go | 2 -- internal/state/indexer/block/kv/kv_test.go | 1 - internal/state/indexer/sink/kv/kv_test.go | 2 -- internal/state/indexer/tx/kv/kv_test.go | 2 -- internal/state/store_test.go | 1 - internal/statesync/chunks_test.go | 1 - internal/statesync/snapshots_test.go | 1 - libs/bits/bit_array_test.go | 1 - libs/bytes/bytes_test.go | 1 - libs/json/decoder_test.go | 1 - libs/json/encoder_test.go | 1 - libs/log/default_test.go | 2 -- light/client_test.go | 1 - light/verifier_test.go | 2 -- privval/grpc/server_test.go | 3 --- privval/msgs_test.go | 1 - proto/tendermint/blocksync/message_test.go | 5 ----- proto/tendermint/statesync/message_test.go | 3 --- types/block_test.go | 9 --------- types/evidence_test.go | 3 --- types/light_test.go | 1 - types/node_info_test.go | 1 - types/part_set_test.go | 2 -- 38 files changed, 74 deletions(-) diff --git a/config/config_test.go b/config/config_test.go index c5e8a56be..82fdd6606 100644 --- a/config/config_test.go +++ 
b/config/config_test.go @@ -124,7 +124,6 @@ func TestConsensusConfig_ValidateBasic(t *testing.T) { "DoubleSignCheckHeight negative": {func(c *ConsensusConfig) { c.DoubleSignCheckHeight = -1 }, true}, } for desc, tc := range testcases { - tc := tc // appease linter t.Run(desc, func(t *testing.T) { cfg := DefaultConsensusConfig() tc.modify(cfg) diff --git a/crypto/merkle/proof_test.go b/crypto/merkle/proof_test.go index 05a5ca369..f77b65314 100644 --- a/crypto/merkle/proof_test.go +++ b/crypto/merkle/proof_test.go @@ -155,7 +155,6 @@ func TestProofValidateBasic(t *testing.T) { } for _, tc := range testCases { - tc := tc t.Run(tc.testName, func(t *testing.T) { _, proofs := ProofsFromByteSlices([][]byte{ []byte("apple"), diff --git a/crypto/merkle/rfc6962_test.go b/crypto/merkle/rfc6962_test.go index 7a70dbb91..6fb1b30f0 100644 --- a/crypto/merkle/rfc6962_test.go +++ b/crypto/merkle/rfc6962_test.go @@ -63,7 +63,6 @@ func TestRFC6962Hasher(t *testing.T) { got: innerHash([]byte("N123"), []byte("N456")), }, } { - tc := tc t.Run(tc.desc, func(t *testing.T) { wantBytes, err := hex.DecodeString(tc.want) if err != nil { diff --git a/crypto/merkle/tree_test.go b/crypto/merkle/tree_test.go index 56ed4f9bd..a9de1ce54 100644 --- a/crypto/merkle/tree_test.go +++ b/crypto/merkle/tree_test.go @@ -38,7 +38,6 @@ func TestHashFromByteSlices(t *testing.T) { }, } for name, tc := range testcases { - tc := tc t.Run(name, func(t *testing.T) { hash := HashFromByteSlices(tc.slices) assert.Equal(t, tc.expectHash, hex.EncodeToString(hash)) diff --git a/internal/consensus/msgs_test.go b/internal/consensus/msgs_test.go index 33ca4c496..0f2d75cf4 100644 --- a/internal/consensus/msgs_test.go +++ b/internal/consensus/msgs_test.go @@ -431,7 +431,6 @@ func TestConsMsgsVectors(t *testing.T) { } for _, tc := range testCases { - tc := tc t.Run(tc.testName, func(t *testing.T) { bz, err := proto.Marshal(tc.cMsg) require.NoError(t, err) @@ -472,7 +471,6 @@ func TestVoteSetMaj23MessageValidateBasic(t *testing.T) { } for _, tc := range testCases { - tc := tc t.Run(tc.testName, func(t *testing.T) { message := VoteSetMaj23Message{ Height: tc.messageHeight, @@ -508,7 +506,6 @@ func TestVoteSetBitsMessageValidateBasic(t *testing.T) { } for i, tc := range testCases { - tc := tc t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { msg := &VoteSetBitsMessage{ Height: 1, @@ -546,7 +543,6 @@ func TestNewRoundStepMessageValidateBasic(t *testing.T) { } for _, tc := range testCases { - tc := tc t.Run(tc.testName, func(t *testing.T) { message := NewRoundStepMessage{ Height: tc.messageHeight, @@ -581,7 +577,6 @@ func TestNewRoundStepMessageValidateHeight(t *testing.T) { } for _, tc := range testCases { - tc := tc t.Run(tc.testName, func(t *testing.T) { message := NewRoundStepMessage{ Height: tc.messageHeight, @@ -626,7 +621,6 @@ func TestNewValidBlockMessageValidateBasic(t *testing.T) { } for i, tc := range testCases { - tc := tc t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { msg := &NewValidBlockMessage{ Height: 1, @@ -660,7 +654,6 @@ func TestProposalPOLMessageValidateBasic(t *testing.T) { } for i, tc := range testCases { - tc := tc t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { msg := &ProposalPOLMessage{ Height: 1, @@ -693,7 +686,6 @@ func TestBlockPartMessageValidateBasic(t *testing.T) { } for _, tc := range testCases { - tc := tc t.Run(tc.testName, func(t *testing.T) { message := BlockPartMessage{ Height: tc.messageHeight, @@ -733,7 +725,6 @@ func TestHasVoteMessageValidateBasic(t *testing.T) { } for _, tc := range testCases { - tc := tc 
t.Run(tc.testName, func(t *testing.T) { message := HasVoteMessage{ Height: tc.messageHeight, diff --git a/internal/consensus/replay_test.go b/internal/consensus/replay_test.go index 3f9b94935..86dcd1957 100644 --- a/internal/consensus/replay_test.go +++ b/internal/consensus/replay_test.go @@ -134,7 +134,6 @@ func TestWALCrash(t *testing.T) { } for _, tc := range testCases { - tc := tc t.Run(tc.name, func(t *testing.T) { ctx := t.Context() diff --git a/internal/evidence/pool_test.go b/internal/evidence/pool_test.go index 1e0ec8fa6..e9334a365 100644 --- a/internal/evidence/pool_test.go +++ b/internal/evidence/pool_test.go @@ -151,7 +151,6 @@ func TestAddExpiredEvidence(t *testing.T) { } for _, tc := range testCases { - tc := tc t.Run(tc.evDescription, func(t *testing.T) { ctx := t.Context() diff --git a/internal/evidence/reactor_test.go b/internal/evidence/reactor_test.go index 8846ebed5..3d898dc59 100644 --- a/internal/evidence/reactor_test.go +++ b/internal/evidence/reactor_test.go @@ -549,7 +549,6 @@ func TestEvidenceListSerialization(t *testing.T) { } for name, tc := range testCases { - tc := tc t.Run(name, func(t *testing.T) { protoEv := make([]tmproto.Evidence, len(tc.evidenceList)) diff --git a/internal/mempool/priority_queue_test.go b/internal/mempool/priority_queue_test.go index 0c28b4fa3..03ec647af 100644 --- a/internal/mempool/priority_queue_test.go +++ b/internal/mempool/priority_queue_test.go @@ -300,7 +300,6 @@ func TestTxPriorityQueue_GetEvictableTxs(t *testing.T) { } for _, tc := range testCases { - tc := tc t.Run(tc.name, func(t *testing.T) { evictTxs := pq.GetEvictableTxs(tc.priority, tc.txSize, tc.totalSize, tc.cap) diff --git a/internal/p2p/conn/connection_test.go b/internal/p2p/conn/connection_test.go index ca1e7159b..72e65a1a4 100644 --- a/internal/p2p/conn/connection_test.go +++ b/internal/p2p/conn/connection_test.go @@ -582,7 +582,6 @@ func TestConnVectors(t *testing.T) { } for _, tc := range testCases { - tc := tc pm := mustWrapPacket(tc.msg) bz, err := pm.Marshal() diff --git a/internal/p2p/conn/evil_secret_connection_test.go b/internal/p2p/conn/evil_secret_connection_test.go index 05e88cd85..3ad61bb1c 100644 --- a/internal/p2p/conn/evil_secret_connection_test.go +++ b/internal/p2p/conn/evil_secret_connection_test.go @@ -255,7 +255,6 @@ func TestMakeSecretConnection(t *testing.T) { } for _, tc := range testCases { - tc := tc t.Run(tc.name, func(t *testing.T) { privKey := ed25519.GenPrivKey() _, err := MakeSecretConnection(tc.conn, privKey) diff --git a/internal/p2p/peermanager_test.go b/internal/p2p/peermanager_test.go index 2933d5afa..6e60889a0 100644 --- a/internal/p2p/peermanager_test.go +++ b/internal/p2p/peermanager_test.go @@ -100,7 +100,6 @@ func TestPeerManagerOptions_Validate(t *testing.T) { }, false}, } for name, tc := range testcases { - tc := tc t.Run(name, func(t *testing.T) { err := tc.options.Validate() if tc.ok { diff --git a/internal/p2p/router_test.go b/internal/p2p/router_test.go index 8cd85abec..0172bb114 100644 --- a/internal/p2p/router_test.go +++ b/internal/p2p/router_test.go @@ -366,7 +366,6 @@ func TestRouter_AcceptPeers(t *testing.T) { } for name, tc := range testcases { - tc := tc t.Run(name, func(t *testing.T) { ctx := t.Context() @@ -562,7 +561,6 @@ func TestRouter_DialPeers(t *testing.T) { } for name, tc := range testcases { - tc := tc t.Run(name, func(t *testing.T) { t.Cleanup(leaktest.Check(t)) ctx := t.Context() diff --git a/internal/p2p/transport_mconn_test.go b/internal/p2p/transport_mconn_test.go index 8830902f9..18d7f4fb3 100644 
--- a/internal/p2p/transport_mconn_test.go +++ b/internal/p2p/transport_mconn_test.go @@ -145,7 +145,6 @@ func TestMConnTransport_Listen(t *testing.T) { {&p2p.Endpoint{Protocol: p2p.MConnProtocol, IP: net.IPv4zero, Path: "foo"}, false}, } for _, tc := range testcases { - tc := tc t.Run(tc.endpoint.String(), func(t *testing.T) { t.Cleanup(leaktest.Check(t)) diff --git a/internal/p2p/transport_test.go b/internal/p2p/transport_test.go index e939ca116..865215b22 100644 --- a/internal/p2p/transport_test.go +++ b/internal/p2p/transport_test.go @@ -119,7 +119,6 @@ func TestTransport_DialEndpoints(t *testing.T) { // Tests for networked endpoints (with IP). if len(endpoint.IP) > 0 && endpoint.Protocol != p2p.MemoryProtocol { for _, tc := range ipTestCases { - tc := tc t.Run(tc.ip.String(), func(t *testing.T) { e := endpoint require.NotNil(t, e) @@ -470,7 +469,6 @@ func TestEndpoint_NodeAddress(t *testing.T) { {p2p.Endpoint{Path: "path"}, p2p.NodeAddress{Path: "path"}}, } for _, tc := range testcases { - tc := tc t.Run(tc.endpoint.String(), func(t *testing.T) { // Without NodeID. expect := tc.expect @@ -522,7 +520,6 @@ func TestEndpoint_String(t *testing.T) { {p2p.Endpoint{Path: "foo"}, "/foo"}, } for _, tc := range testcases { - tc := tc t.Run(tc.expect, func(t *testing.T) { require.Equal(t, tc.expect, tc.endpoint.String()) }) @@ -556,7 +553,6 @@ func TestEndpoint_Validate(t *testing.T) { {p2p.Endpoint{Protocol: "tcp", Port: 8080, Path: "path"}, false}, } for _, tc := range testcases { - tc := tc t.Run(tc.endpoint.String(), func(t *testing.T) { err := tc.endpoint.Validate() if tc.expectValid { diff --git a/internal/state/execution_test.go b/internal/state/execution_test.go index d7ab0c582..5b80bc9bd 100644 --- a/internal/state/execution_test.go +++ b/internal/state/execution_test.go @@ -415,7 +415,6 @@ func TestValidateValidatorUpdates(t *testing.T) { } for _, tc := range testCases { - tc := tc t.Run(tc.name, func(t *testing.T) { err := sm.ValidateValidatorUpdates(tc.abciUpdates, tc.validatorParams) if tc.shouldErr { @@ -478,7 +477,6 @@ func TestUpdateValidators(t *testing.T) { } for _, tc := range testCases { - tc := tc t.Run(tc.name, func(t *testing.T) { updates, err := types.PB2TM.ValidatorUpdates(tc.abciUpdates) assert.NoError(t, err) diff --git a/internal/state/indexer/block/kv/kv_test.go b/internal/state/indexer/block/kv/kv_test.go index 8409da946..c1100afe4 100644 --- a/internal/state/indexer/block/kv/kv_test.go +++ b/internal/state/indexer/block/kv/kv_test.go @@ -130,7 +130,6 @@ func TestBlockIndexer(t *testing.T) { } for name, tc := range testCases { - tc := tc t.Run(name, func(t *testing.T) { ctx := t.Context() diff --git a/internal/state/indexer/sink/kv/kv_test.go b/internal/state/indexer/sink/kv/kv_test.go index 373431c0a..d658c62b6 100644 --- a/internal/state/indexer/sink/kv/kv_test.go +++ b/internal/state/indexer/sink/kv/kv_test.go @@ -142,7 +142,6 @@ func TestBlockFuncs(t *testing.T) { } for name, tc := range testCases { - tc := tc t.Run(name, func(t *testing.T) { ctx := t.Context() @@ -242,7 +241,6 @@ func TestTxSearchDeprecatedIndexing(t *testing.T) { ctx := t.Context() for _, tc := range testCases { - tc := tc t.Run(tc.q, func(t *testing.T) { results, err := indexer.SearchTxEvents(ctx, query.MustCompile(tc.q)) require.NoError(t, err) diff --git a/internal/state/indexer/tx/kv/kv_test.go b/internal/state/indexer/tx/kv/kv_test.go index ef547b81a..d2c7daec5 100644 --- a/internal/state/indexer/tx/kv/kv_test.go +++ b/internal/state/indexer/tx/kv/kv_test.go @@ -134,7 +134,6 @@ func 
TestTxSearch(t *testing.T) { ctx := t.Context() for _, tc := range testCases { - tc := tc t.Run(tc.q, func(t *testing.T) { results, err := indexer.Search(ctx, query.MustCompile(tc.q)) assert.NoError(t, err) @@ -233,7 +232,6 @@ func TestTxSearchDeprecatedIndexing(t *testing.T) { ctx := t.Context() for _, tc := range testCases { - tc := tc t.Run(tc.q, func(t *testing.T) { results, err := indexer.Search(ctx, query.MustCompile(tc.q)) require.NoError(t, err) diff --git a/internal/state/store_test.go b/internal/state/store_test.go index 722676c06..7c0a1ad9e 100644 --- a/internal/state/store_test.go +++ b/internal/state/store_test.go @@ -187,7 +187,6 @@ func TestPruneStates(t *testing.T) { "prune across checkpoint": {99900, 100002, 100002, false, 100000, 99995}, } for name, tc := range testcases { - tc := tc t.Run(name, func(t *testing.T) { db := dbm.NewMemDB() diff --git a/internal/statesync/chunks_test.go b/internal/statesync/chunks_test.go index 8480b4dd8..3213a487c 100644 --- a/internal/statesync/chunks_test.go +++ b/internal/statesync/chunks_test.go @@ -125,7 +125,6 @@ func TestChunkQueue_Add_ChunkErrors(t *testing.T) { "invalid index": {&chunk{Height: 3, Format: 1, Index: 5, Chunk: []byte{3, 1, 0}}}, } for name, tc := range testcases { - tc := tc t.Run(name, func(t *testing.T) { queue, teardown := setupChunkQueue(t) defer teardown() diff --git a/internal/statesync/snapshots_test.go b/internal/statesync/snapshots_test.go index 08cb08269..8fcf4d76c 100644 --- a/internal/statesync/snapshots_test.go +++ b/internal/statesync/snapshots_test.go @@ -19,7 +19,6 @@ func TestSnapshot_Key(t *testing.T) { "no metadata": {func(s *snapshot) { s.Metadata = nil }}, } for name, tc := range testcases { - tc := tc t.Run(name, func(t *testing.T) { s := snapshot{ Height: 3, diff --git a/libs/bits/bit_array_test.go b/libs/bits/bit_array_test.go index 613f672a1..dbb1a8d97 100644 --- a/libs/bits/bit_array_test.go +++ b/libs/bits/bit_array_test.go @@ -241,7 +241,6 @@ func TestJSONMarshalUnmarshal(t *testing.T) { } for _, tc := range testCases { - tc := tc t.Run(tc.bA.String(), func(t *testing.T) { bz, err := json.Marshal(tc.bA) require.NoError(t, err) diff --git a/libs/bytes/bytes_test.go b/libs/bytes/bytes_test.go index 3dcd08100..4bf31925a 100644 --- a/libs/bytes/bytes_test.go +++ b/libs/bytes/bytes_test.go @@ -41,7 +41,6 @@ func TestJSONMarshal(t *testing.T) { } for i, tc := range cases { - tc := tc t.Run(fmt.Sprintf("Case %d", i), func(t *testing.T) { ts := TestStruct{B1: tc.input, B2: tc.input} diff --git a/libs/json/decoder_test.go b/libs/json/decoder_test.go index 41faa1062..97b2bf710 100644 --- a/libs/json/decoder_test.go +++ b/libs/json/decoder_test.go @@ -131,7 +131,6 @@ func TestUnmarshal(t *testing.T) { "invalid type": {`"foo"`, Struct{}, true}, } for name, tc := range testcases { - tc := tc t.Run(name, func(t *testing.T) { // Create a target variable as a pointer to the zero value of the tc.value type, // and wrap it in an empty interface. Decode into that interface. 
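Note: the `tc := tc` deletions throughout this patch are safe because, as of Go 1.22, each iteration of a `for` range loop binds a fresh copy of the loop variable, so subtest closures no longer share a single `tc`. A minimal sketch of the pattern being deleted; the names are illustrative, not from this repository:

    package example

    import "testing"

    func TestLoopVar(t *testing.T) {
        testCases := []struct{ name, want string }{{"a", "a"}, {"b", "b"}}
        for _, tc := range testCases {
            // Before Go 1.22 this rebinding was required so that each
            // parallel subtest captured its own copy of tc:
            //     tc := tc
            // With Go 1.22 semantics every iteration already gets a
            // fresh tc, so the line can simply be removed.
            t.Run(tc.name, func(t *testing.T) {
                t.Parallel()
                if tc.name != tc.want {
                    t.Fatalf("got %q, want %q", tc.name, tc.want)
                }
            })
        }
    }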
diff --git a/libs/json/encoder_test.go b/libs/json/encoder_test.go index 88eb56f85..8de611c9e 100644 --- a/libs/json/encoder_test.go +++ b/libs/json/encoder_test.go @@ -94,7 +94,6 @@ func TestMarshal(t *testing.T) { }, } for name, tc := range testcases { - tc := tc t.Run(name, func(t *testing.T) { bz, err := json.Marshal(tc.value) require.NoError(t, err) diff --git a/libs/log/default_test.go b/libs/log/default_test.go index 6ea723c51..3ca1c40a8 100644 --- a/libs/log/default_test.go +++ b/libs/log/default_test.go @@ -32,8 +32,6 @@ func TestNewDefaultLogger(t *testing.T) { } for name, tc := range testCases { - tc := tc - t.Run(name, func(t *testing.T) { _, err := log.NewDefaultLogger(tc.format, tc.level) if tc.expectErr { diff --git a/light/client_test.go b/light/client_test.go index fee4bd841..93368f942 100644 --- a/light/client_test.go +++ b/light/client_test.go @@ -350,7 +350,6 @@ func TestClient(t *testing.T) { } for _, tc := range testCases { - tc := tc t.Run(tc.name, func(t *testing.T) { ctx := t.Context() logger := log.NewNopLogger() diff --git a/light/verifier_test.go b/light/verifier_test.go index 5a2019e21..6033b87f6 100644 --- a/light/verifier_test.go +++ b/light/verifier_test.go @@ -153,7 +153,6 @@ func TestVerifyAdjacentHeaders(t *testing.T) { } for i, tc := range testCases { - tc := tc t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { err := light.VerifyAdjacent(header, tc.newHeader, tc.newVals, tc.trustingPeriod, tc.now, maxClockDrift) switch { @@ -267,7 +266,6 @@ func TestVerifyNonAdjacentHeaders(t *testing.T) { } for i, tc := range testCases { - tc := tc t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { err := light.VerifyNonAdjacent(header, vals, tc.newHeader, tc.newVals, tc.trustingPeriod, tc.now, maxClockDrift, diff --git a/privval/grpc/server_test.go b/privval/grpc/server_test.go index f3d75ffa8..92bb42d75 100644 --- a/privval/grpc/server_test.go +++ b/privval/grpc/server_test.go @@ -30,7 +30,6 @@ func TestGetPubKey(t *testing.T) { } for _, tc := range testCases { - tc := tc t.Run(tc.name, func(t *testing.T) { ctx := t.Context() logger := log.NewTestingLogger(t) @@ -104,7 +103,6 @@ func TestSignVote(t *testing.T) { } for _, tc := range testCases { - tc := tc t.Run(tc.name, func(t *testing.T) { ctx := t.Context() logger := log.NewTestingLogger(t) @@ -174,7 +172,6 @@ func TestSignProposal(t *testing.T) { } for _, tc := range testCases { - tc := tc t.Run(tc.name, func(t *testing.T) { ctx := t.Context() logger := log.NewTestingLogger(t) diff --git a/privval/msgs_test.go b/privval/msgs_test.go index 93efa55a9..88d6ec668 100644 --- a/privval/msgs_test.go +++ b/privval/msgs_test.go @@ -86,7 +86,6 @@ func TestPrivvalVectors(t *testing.T) { } for _, tc := range testCases { - tc := tc pm := mustWrapMsg(tc.msg) bz, err := pm.Marshal() diff --git a/proto/tendermint/blocksync/message_test.go b/proto/tendermint/blocksync/message_test.go index 3406c6dff..b5d6a79e1 100644 --- a/proto/tendermint/blocksync/message_test.go +++ b/proto/tendermint/blocksync/message_test.go @@ -24,7 +24,6 @@ func TestBlockRequest_Validate(t *testing.T) { } for _, tc := range testCases { - tc := tc t.Run(tc.testName, func(t *testing.T) { msg := &bcproto.Message{} require.NoError(t, msg.Wrap(&bcproto.BlockRequest{Height: tc.requestHeight})) @@ -44,9 +43,7 @@ func TestNoBlockResponse_Validate(t *testing.T) { {"Valid Non-Response Message", 1, false}, {"Invalid Non-Response Message", -1, true}, } - for _, tc := range testCases { - tc := tc t.Run(tc.testName, func(t *testing.T) { msg := &bcproto.Message{} 
require.NoError(t, msg.Wrap(&bcproto.NoBlockResponse{Height: tc.nonResponseHeight})) @@ -74,7 +71,6 @@ func TestStatusResponse_Validate(t *testing.T) { } for _, tc := range testCases { - tc := tc t.Run(tc.testName, func(t *testing.T) { msg := &bcproto.Message{} require.NoError(t, msg.Wrap(&bcproto.StatusResponse{Height: tc.responseHeight})) @@ -120,7 +116,6 @@ func TestBlockchainMessageVectors(t *testing.T) { } for _, tc := range testCases { - tc := tc t.Run(tc.testName, func(t *testing.T) { bz, err := proto.Marshal(tc.bmsg) require.NoError(t, err) diff --git a/proto/tendermint/statesync/message_test.go b/proto/tendermint/statesync/message_test.go index 1db421aca..744ac235f 100644 --- a/proto/tendermint/statesync/message_test.go +++ b/proto/tendermint/statesync/message_test.go @@ -103,7 +103,6 @@ func TestValidateMsg(t *testing.T) { } for name, tc := range testcases { - tc := tc t.Run(name, func(t *testing.T) { msg := new(ssproto.Message) @@ -215,8 +214,6 @@ func TestStateSyncVectors(t *testing.T) { } for _, tc := range testCases { - tc := tc - msg := new(ssproto.Message) require.NoError(t, msg.Wrap(tc.msg)) diff --git a/types/block_test.go b/types/block_test.go index ba626fefb..b6a47b3bf 100644 --- a/types/block_test.go +++ b/types/block_test.go @@ -113,7 +113,6 @@ func TestBlockValidateBasic(t *testing.T) { }, true}, } for i, tc := range testCases { - tc := tc i := i t.Run(tc.testName, func(t *testing.T) { block := MakeBlock(h, txs, commit, evList) @@ -286,7 +285,6 @@ func TestCommitValidateBasic(t *testing.T) { {"Incorrect round", func(com *Commit) { com.Round = -100 }, true}, } for _, tc := range testCases { - tc := tc t.Run(tc.testName, func(t *testing.T) { ctx := t.Context() @@ -384,7 +382,6 @@ func TestHeaderHash(t *testing.T) { }, nil}, } for _, tc := range testCases { - tc := tc t.Run(tc.desc, func(t *testing.T) { assert.Equal(t, tc.expectHash, tc.header.Hash()) @@ -507,7 +504,6 @@ func TestBlockMaxDataBytes(t *testing.T) { } for i, tc := range testCases { - tc := tc if tc.panics { assert.Panics(t, func() { MaxDataBytes(tc.maxBytes, tc.evidenceBytes, tc.valsCount) @@ -536,7 +532,6 @@ func TestBlockMaxDataBytesNoEvidence(t *testing.T) { } for i, tc := range testCases { - tc := tc if tc.panics { assert.Panics(t, func() { MaxDataBytesNoEvidence(tc.maxBytes, tc.valsCount) @@ -770,7 +765,6 @@ func TestBlockIDValidateBasic(t *testing.T) { } for _, tc := range testCases { - tc := tc t.Run(tc.testName, func(t *testing.T) { blockID := BlockID{ Hash: tc.blockIDHash, @@ -1049,7 +1043,6 @@ func TestCommitSig_ValidateBasic(t *testing.T) { } for _, tc := range testCases { - tc := tc t.Run(tc.name, func(t *testing.T) { err := tc.cs.ValidateBasic() @@ -1310,7 +1303,6 @@ func TestHeader_ValidateBasic(t *testing.T) { } for _, tc := range testCases { - tc := tc t.Run(tc.name, func(t *testing.T) { err := tc.header.ValidateBasic() @@ -1409,7 +1401,6 @@ func TestCommit_ValidateBasic(t *testing.T) { } for _, tc := range testCases { - tc := tc t.Run(tc.name, func(t *testing.T) { err := tc.commit.ValidateBasic() diff --git a/types/evidence_test.go b/types/evidence_test.go index 8def300fc..47e1a9f5a 100644 --- a/types/evidence_test.go +++ b/types/evidence_test.go @@ -124,7 +124,6 @@ func TestDuplicateVoteEvidenceValidation(t *testing.T) { }, true}, } for _, tc := range testCases { - tc := tc t.Run(tc.testName, func(t *testing.T) { vote1 := makeVote(ctx, t, val, chainID, math.MaxInt32, math.MaxInt64, math.MaxInt32, 0x02, blockID, defaultVoteTime) vote2 := makeVote(ctx, t, val, chainID, math.MaxInt32, 
math.MaxInt64, math.MaxInt32, 0x02, blockID2, defaultVoteTime) @@ -255,7 +254,6 @@ func TestLightClientAttackEvidenceValidation(t *testing.T) { }, true}, } for _, tc := range testCases { - tc := tc t.Run(tc.testName, func(t *testing.T) { lcae := &LightClientAttackEvidence{ ConflictingBlock: &LightBlock{ @@ -456,7 +454,6 @@ func TestEvidenceVectors(t *testing.T) { } for _, tc := range testCases { - tc := tc hash := tc.evList.Hash() require.Equal(t, tc.expBytes, hex.EncodeToString(hash), tc.testName) } diff --git a/types/light_test.go b/types/light_test.go index f399f3c59..e2cced02c 100644 --- a/types/light_test.go +++ b/types/light_test.go @@ -155,7 +155,6 @@ func TestSignedHeaderValidateBasic(t *testing.T) { } for _, tc := range testCases { - tc := tc t.Run(tc.testName, func(t *testing.T) { sh := SignedHeader{ Header: tc.shHeader, diff --git a/types/node_info_test.go b/types/node_info_test.go index 110c67fc3..1f8480d02 100644 --- a/types/node_info_test.go +++ b/types/node_info_test.go @@ -241,7 +241,6 @@ func TestParseAddressString(t *testing.T) { } for _, tc := range testCases { - tc := tc t.Run(tc.name, func(t *testing.T) { addr, port, err := ParseAddressString(tc.addr) if tc.correct { diff --git a/types/part_set_test.go b/types/part_set_test.go index 9deea5741..cb777cf0a 100644 --- a/types/part_set_test.go +++ b/types/part_set_test.go @@ -98,7 +98,6 @@ func TestPartSetHeaderValidateBasic(t *testing.T) { {"Invalid Hash", func(psHeader *PartSetHeader) { psHeader.Hash = make([]byte, 1) }, true}, } for _, tc := range testCases { - tc := tc t.Run(tc.testName, func(t *testing.T) { data := tmrand.Bytes(testPartSize * 100) ps := NewPartSetFromData(data, testPartSize) @@ -127,7 +126,6 @@ func TestPartValidateBasic(t *testing.T) { } for _, tc := range testCases { - tc := tc t.Run(tc.testName, func(t *testing.T) { data := tmrand.Bytes(testPartSize * 100) ps := NewPartSetFromData(data, testPartSize) From 170eb67fafc713cec9dca790624aae767f6c04e6 Mon Sep 17 00:00:00 2001 From: Grzegorz Prusak Date: Mon, 18 Aug 2025 19:56:41 +0200 Subject: [PATCH 07/41] fixed a cleanup in test --- rpc/client/rpc_test.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/rpc/client/rpc_test.go b/rpc/client/rpc_test.go index dd14d9071..17316d412 100644 --- a/rpc/client/rpc_test.go +++ b/rpc/client/rpc_test.go @@ -489,7 +489,8 @@ func TestClientMethodCalls(t *testing.T) { eventCh, err := c.Subscribe(ctx, subscriber, types.QueryForEvent(types.EventNewBlockValue).String()) require.NoError(t, err) t.Cleanup(func() { - if err := c.UnsubscribeAll(ctx, subscriber); err != nil { + // At this point the ctx is cancelled, so the cleanup needs to run with a background context. + if err := c.UnsubscribeAll(context.Background(), subscriber); err != nil { t.Error(err) } }) From f099af6dc3532af1ff9b4f4a8be5c6f56d5bd30e Mon Sep 17 00:00:00 2001 From: Grzegorz Prusak Date: Wed, 20 Aug 2025 13:25:40 +0200 Subject: [PATCH 08/41] refactored baseService --- libs/service/service.go | 121 ++++++++++++++-------------------------- 1 file changed, 41 insertions(+), 80 deletions(-) diff --git a/libs/service/service.go b/libs/service/service.go index fdd9299fd..9c6c9497d 100644 --- a/libs/service/service.go +++ b/libs/service/service.go @@ -2,17 +2,12 @@ package service import ( "context" - "errors" - "sync" + "sync/atomic" "github.com/tendermint/tendermint/libs/log" ) var ( - // errAlreadyStopped is returned when somebody tries to stop an already - // stopped service (without resetting it). 
- errAlreadyStopped = errors.New("already stopped") - _ Service = (*BaseService)(nil) ) @@ -43,6 +38,14 @@ type Implementation interface { OnStop() } +type baseService struct { + // This is the context that (structured concurrency) service tasks will be executed with. + // It is canceled when outer context is canceled or when the service is stopped. + ctx context.Context + cancel context.CancelFunc + done chan struct{} +} + /* Classical-inheritance-style service declarations. Services can be started, then stopped, but cannot be restarted. @@ -81,12 +84,9 @@ Typical usage: type BaseService struct { logger log.Logger name string - mtx sync.Mutex - quit <-chan (struct{}) - cancel context.CancelFunc - // The "subclass" of BaseService - impl Implementation + impl Implementation + inner atomic.Pointer[baseService] } // NewBaseService creates a new BaseService. @@ -102,102 +102,63 @@ func NewBaseService(logger log.Logger, name string, impl Implementation) *BaseSe // will be returned if the service is stopped, but not if it is // already running. func (bs *BaseService) Start(ctx context.Context) error { - bs.mtx.Lock() - defer bs.mtx.Unlock() - - if bs.quit != nil { + sCtx, cancel := context.WithCancel(ctx) + inner := &baseService{sCtx, cancel, make(chan struct{})} + if !bs.inner.CompareAndSwap(nil, inner) { + cancel() return nil } - select { - case <-bs.quit: - return errAlreadyStopped - default: - bs.logger.Debug("starting service", "service", bs.name, "impl", bs.name) - if err := bs.impl.OnStart(ctx); err != nil { - return err - } - - // we need a separate context to ensure that we start - // a thread that will get cleaned up and that the - // Stop/Wait functions work as expected. - srvCtx, cancel := context.WithCancel(context.Background()) - bs.cancel = cancel - bs.quit = srvCtx.Done() - - go func(ctx context.Context) { - select { - case <-srvCtx.Done(): - // this means stop was called manually - return - case <-ctx.Done(): - bs.Stop() - } - - bs.logger.Info("stopped service", - "service", bs.name) - }(ctx) - - return nil + bs.logger.Debug("starting service", "service", bs.name, "impl", bs.name) + // Currently sei-tendermint services (and tests) rely on the fact that OnStart is called with + // exactly the same context as Start. + if err := bs.impl.OnStart(ctx); err != nil { + cancel() + return err } + + go func() { + <-inner.ctx.Done() + inner.cancel() // make sure that ctx memory is released + bs.logger.Debug("stopping service", "service", bs.name) + bs.impl.OnStop() + bs.logger.Info("stopped service", "service", bs.name) + close(inner.done) + }() + return nil } // Stop manually terminates the service by calling OnStop method from // the implementation and releases all resources related to the // service. func (bs *BaseService) Stop() { - bs.mtx.Lock() - defer bs.mtx.Unlock() - - if bs.quit == nil { - return - } - - select { - case <-bs.quit: - return - default: - bs.logger.Debug("stopping service", "service", bs.name) - bs.impl.OnStop() - bs.cancel() - - return + if inner := bs.inner.Load(); inner != nil { + inner.cancel() + <-inner.done } } // IsRunning implements Service by returning true or false depending on the // service's state. 
func (bs *BaseService) IsRunning() bool { - bs.mtx.Lock() - defer bs.mtx.Unlock() - - if bs.quit == nil { + inner := bs.inner.Load() + if inner == nil { return false } - select { - case <-bs.quit: + case <-inner.done: return false default: return true } } -func (bs *BaseService) getWait() <-chan struct{} { - bs.mtx.Lock() - defer bs.mtx.Unlock() - - if bs.quit == nil { - out := make(chan struct{}) - close(out) - return out +// Wait blocks until the service is stopped. +func (bs *BaseService) Wait() { + if inner := bs.inner.Load(); inner != nil { + <-inner.done } - - return bs.quit } -// Wait blocks until the service is stopped. -func (bs *BaseService) Wait() { <-bs.getWait() } - // String provides a human-friendly representation of the service. func (bs *BaseService) String() string { return bs.name } From fe5aa9819232189a4bf44931a9778aa69e0c5fb6 Mon Sep 17 00:00:00 2001 From: Grzegorz Prusak Date: Wed, 20 Aug 2025 13:45:58 +0200 Subject: [PATCH 09/41] added task spawning --- go.mod | 2 +- libs/service/service.go | 36 +++++- libs/utils/channels.go | 74 +++++++++++ libs/utils/mutex.go | 206 ++++++++++++++++++++++++++++++ libs/utils/mutex_test.go | 39 ++++++ libs/utils/option.go | 73 +++++++++++ libs/utils/option_test.go | 32 +++++ libs/utils/proto.go | 143 +++++++++++++++++++++ libs/utils/require/require.go | 81 ++++++++++++ libs/utils/scope/parallel.go | 41 ++++++ libs/utils/scope/parallel_test.go | 54 ++++++++ libs/utils/scope/start.go | 143 +++++++++++++++++++++ libs/utils/semaphore.go | 24 ++++ libs/utils/testonly.go | 152 ++++++++++++++++++++++ libs/utils/wait.go | 119 +++++++++++++++++ libs/utils/wait_test.go | 23 ++++ 16 files changed, 1235 insertions(+), 7 deletions(-) create mode 100644 libs/utils/channels.go create mode 100644 libs/utils/mutex.go create mode 100644 libs/utils/mutex_test.go create mode 100644 libs/utils/option.go create mode 100644 libs/utils/option_test.go create mode 100644 libs/utils/proto.go create mode 100644 libs/utils/require/require.go create mode 100644 libs/utils/scope/parallel.go create mode 100644 libs/utils/scope/parallel_test.go create mode 100644 libs/utils/scope/start.go create mode 100644 libs/utils/semaphore.go create mode 100644 libs/utils/testonly.go create mode 100644 libs/utils/wait.go create mode 100644 libs/utils/wait_test.go diff --git a/go.mod b/go.mod index e028c1a8b..b295d9103 100644 --- a/go.mod +++ b/go.mod @@ -231,7 +231,6 @@ require ( golang.org/x/text v0.21.0 // indirect golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect google.golang.org/genproto v0.0.0-20220519153652-3a47de7e79bd // indirect - google.golang.org/protobuf v1.28.0 // indirect gopkg.in/ini.v1 v1.66.6 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect @@ -254,4 +253,5 @@ require ( go.opentelemetry.io/otel v1.9.0 go.opentelemetry.io/otel/sdk v1.9.0 go.opentelemetry.io/otel/trace v1.9.0 + google.golang.org/protobuf v1.28.0 ) diff --git a/libs/service/service.go b/libs/service/service.go index 9c6c9497d..425da0447 100644 --- a/libs/service/service.go +++ b/libs/service/service.go @@ -2,9 +2,9 @@ package service import ( "context" - "sync/atomic" - "github.com/tendermint/tendermint/libs/log" + "sync" + "sync/atomic" ) var ( @@ -43,6 +43,7 @@ type baseService struct { // It is canceled when outer context is canceled or when the service is stopped. 
ctx context.Context cancel context.CancelFunc + wg sync.WaitGroup done chan struct{} } @@ -103,9 +104,9 @@ func NewBaseService(logger log.Logger, name string, impl Implementation) *BaseSe // already running. func (bs *BaseService) Start(ctx context.Context) error { sCtx, cancel := context.WithCancel(ctx) - inner := &baseService{sCtx, cancel, make(chan struct{})} + inner := &baseService{sCtx, cancel, sync.WaitGroup{}, make(chan struct{})} if !bs.inner.CompareAndSwap(nil, inner) { - cancel() + cancel() // free the context. return nil } @@ -113,15 +114,16 @@ func (bs *BaseService) Start(ctx context.Context) error { // Currently sei-tendermint services (and tests) rely on the fact that OnStart is called with // exactly the same context as Start. if err := bs.impl.OnStart(ctx); err != nil { - cancel() + cancel() // free the context. return err } go func() { <-inner.ctx.Done() - inner.cancel() // make sure that ctx memory is released + inner.cancel() // free the context. bs.logger.Debug("stopping service", "service", bs.name) bs.impl.OnStop() + inner.wg.Wait() // wait for all spawned tasks to finish bs.logger.Info("stopped service", "service", bs.name) close(inner.done) }() @@ -138,6 +140,28 @@ func (bs *BaseService) Stop() { } } +// Spawn spawns a new goroutine executing task, which will be cancelled +// when outer context is cancelled or when the service is stopped. +// Both Wait and Stop calls will block until the spawned task is finished. +// It should be called ONLY from within OnStart(). +// NOTE that the task is provided with a narrower context than the context +// provided to OnStart(). This is intentional. +// Panics if the service has not been started yet. +func (bs *BaseService) Spawn(task func(ctx context.Context) error) { + inner := bs.inner.Load() + if inner == nil { + panic("service is not started yet") + } + + inner.wg.Add(1) + go func() { + defer inner.wg.Done() + if err := task(inner.ctx); err != nil { + bs.logger.Error("task failed", "service", bs.name, "error", err) + } + }() +} + // IsRunning implements Service by returning true or false depending on the // service's state. func (bs *BaseService) IsRunning() bool { diff --git a/libs/utils/channels.go b/libs/utils/channels.go new file mode 100644 index 000000000..9eed500ff --- /dev/null +++ b/libs/utils/channels.go @@ -0,0 +1,74 @@ +package utils + +import ( + "context" + + "github.com/pkg/errors" +) + +// Recv receives a value from a channel or returns an error if the context is canceled. +func Recv[T any](ctx context.Context, ch <-chan T) (zero T, err error) { + select { + case v, ok := <-ch: + if ok { + return v, nil + } + // We are not interested in channel closing, + // patiently wait for the context to be done instead. + <-ctx.Done() + return zero, ctx.Err() + case <-ctx.Done(): + return zero, ctx.Err() + } +} + +// RecvOrClosed receives a value from a channel, returns false if the channel got closed, +// or returns an error if the context is canceled. +func RecvOrClosed[T any](ctx context.Context, ch <-chan T) (T, bool, error) { + select { + case v, ok := <-ch: + return v, ok, nil + case <-ctx.Done(): + var zero T + return zero, false, ctx.Err() + } +} + +// Send sends a value to the channel or returns an error if the context is canceled. +func Send[T any](ctx context.Context, ch chan<- T, v T) error { + select { + case ch <- v: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +// SendOrDrop sends a value to the channel if it is not full, or drops the item if the channel is full.
+func SendOrDrop[T any](ch chan<- T, v T) error { + select { + case ch <- v: + return nil + default: + // drop the item + return nil + } +} + +// ForEach is a helper function that reads from a channel and calls a handler for each item. +// this avoids needing a lot of for/select boilerplate everywhere. +func ForEach[T any](ctx context.Context, ch <-chan T, handler func(T) error) error { + for { + select { + case <-ctx.Done(): + return errors.WithStack(ctx.Err()) + case item, ok := <-ch: + if !ok { + return nil // Channel closed + } + if err := handler(item); err != nil { + return err // Stop on error + } + } + } +} diff --git a/libs/utils/mutex.go b/libs/utils/mutex.go new file mode 100644 index 000000000..b6f4a9a58 --- /dev/null +++ b/libs/utils/mutex.go @@ -0,0 +1,206 @@ +package utils + +import ( + "context" + "iter" + "sync" + "sync/atomic" + + "golang.org/x/sync/errgroup" +) + +// Mutex guards access to object of type T. +type Mutex[T any] struct { + mu sync.Mutex + value T +} + +// NewMutex creates a new Mutex with given object. +func NewMutex[T any](value T) (m Mutex[T]) { + m.value = value + // nolint:nakedret + return +} + +// Lock returns an iterator which locks the mutex and yields the guarded object. +// The mutex is unlocked when the iterator is done. +// If the mutex is nil, the iterator is a no-op. +func (m *Mutex[T]) Lock() iter.Seq[T] { + return func(yield func(val T) bool) { + m.mu.Lock() + defer m.mu.Unlock() + _ = yield(m.value) + } +} + +// version of the value stored in an atomic watch. +type version[T any] struct { + updated chan struct{} + value T +} + +// newVersion constructs a new active version. +func newVersion[T any](value T) *version[T] { + return &version[T]{make(chan struct{}), value} +} + +type atomicWatch[T any] struct { + ptr atomic.Pointer[version[T]] +} + +// AtomicWatch stores a pointer to an IMMUTABLE value. +// Loading and waiting for updates do NOT require locking. +// TODO(gprusak): remove mutex and rename to AtomicSend, +// this will allow for sharing a mutex across multiple AtomicSenders. +type AtomicWatch[T any] struct { + atomicWatch[T] + mu sync.Mutex +} + +// AtomicRecv is a read-only reference to AtomicWatch. +type AtomicRecv[T any] struct{ *atomicWatch[T] } + +// NewAtomicWatch creates a new AtomicWatch with the given initial value. +func NewAtomicWatch[T any](value T) (w AtomicWatch[T]) { + w.ptr.Store(newVersion(value)) + // nolint:nakedret + return +} + +// Subscribe returns a view-only API of the atomic watch. +func (w *AtomicWatch[T]) Subscribe() AtomicRecv[T] { + return AtomicRecv[T]{&w.atomicWatch} +} + +// Load returns the current value of the atomic watch. +// Does not do any locking. +func (w *atomicWatch[T]) Load() T { return w.ptr.Load().value } + +// Store updates the value of the atomic watch. +func (w *AtomicWatch[T]) Store(value T) { + w.mu.Lock() + defer w.mu.Unlock() + close(w.ptr.Swap(newVersion(value)).updated) +} + +// Update conditionally updates the value of the atomic watch. +func (w *AtomicWatch[T]) Update(f func(T) (T, bool)) { + w.mu.Lock() + defer w.mu.Unlock() + old := w.ptr.Load() + if value, ok := f(old.value); ok { + w.ptr.Store(newVersion(value)) + close(old.updated) + } +} + +// Wait waits for the value of the atomic watch to satisfy the predicate. +// Does not do any locking. 
+func (w *atomicWatch[T]) Wait(ctx context.Context, pred func(T) bool) (T, error) { + for { + v := w.ptr.Load() + if pred(v.value) { + return v.value, nil + } + select { + case <-ctx.Done(): + return Zero[T](), ctx.Err() + case <-v.updated: + } + } +} + +// Iter executes sequentially the function f on each value of the atomic watch. +// Context passed to f is canceled when the next value is available. +// Exits when the returned error is different from nil and context.Canceled, +// or when the context passed to Iter is canceled (after f exits). +func (w *atomicWatch[T]) Iter(ctx context.Context, f func(ctx context.Context, v T) error) error { + for ctx.Err() == nil { + v := w.ptr.Load() + g, ctx := errgroup.WithContext(ctx) + g.Go(func() error { return f(ctx, v.value) }) + g.Go(func() error { + select { + case <-ctx.Done(): + case <-v.updated: + } + return context.Canceled + }) + if err := IgnoreCancel(g.Wait()); err != nil { + return err + } + } + return ctx.Err() +} + +// WatchCtrl controls the locked object in a Watch. +// It is provided only in the iterator returned by Lock(). +// Should NOT be stored anywhere. +type WatchCtrl struct { + mu sync.Mutex + updated chan struct{} +} + +// Watch stores a value of type T. +// Essentially a mutex, that can be awaited for updates. +type Watch[T any] struct { + ctrl WatchCtrl + val T +} + +// NewWatch constructs a new watch with the given value. +// Note that value in the watch cannot be changed, so T +// should be a pointer type if updates are required. +func NewWatch[T any](val T) Watch[T] { + return Watch[T]{ + WatchCtrl{updated: make(chan struct{})}, + val, + } +} + +// Wait waits for the value in the watch to be updated. +// Should be called only after locking the watch, i.e. within Lock() iterator. +// It unlocks -> waits for the update -> locks again. +func (c *WatchCtrl) Wait(ctx context.Context) error { + updated := c.updated + c.mu.Unlock() + defer c.mu.Lock() + select { + case <-ctx.Done(): + return ctx.Err() + case <-updated: + return nil + } +} + +// WaitUntil waits for the value in the watch to satisfy the predicate. +// Should be called only after locking the watch, i.e. within Lock() iterator. +// The predicate is evaluated under the lock, so it can access the guarded object. +func (c *WatchCtrl) WaitUntil(ctx context.Context, pred func() bool) error { + for !pred() { + if err := c.Wait(ctx); err != nil { + return err + } + } + return nil +} + +// Updated signals waiters that the value in the watch has been updated. +func (c *WatchCtrl) Updated() { + close(c.updated) + c.updated = make(chan struct{}) +} + +// Lock returns an iterator which locks the watch and yields the guarded object. +// The watch is unlocked when the iterator is done. +// If the watch is nil, the iterator is a no-op. +// Additionally the WatchCtrl object is provided to the yield function: +// * to unlock -> wait for the update -> lock again, call ctrl.Wait(ctx) +// * to signal an update, call ctrl.Updated(). 
+func (w *Watch[T]) Lock() iter.Seq2[T, *WatchCtrl] { + return func(yield func(val T, ctrl *WatchCtrl) bool) { + w.ctrl.mu.Lock() + defer w.ctrl.mu.Unlock() + _ = yield(w.val, &w.ctrl) + } +} diff --git a/libs/utils/mutex_test.go b/libs/utils/mutex_test.go new file mode 100644 index 000000000..1da1ac372 --- /dev/null +++ b/libs/utils/mutex_test.go @@ -0,0 +1,39 @@ +package utils_test + +import ( + "context" + "fmt" + "testing" + + "github.com/tendermint/tendermint/libs/utils" + "github.com/tendermint/tendermint/libs/utils/require" + "github.com/tendermint/tendermint/libs/utils/scope" +) + +func TestAtomicWatch(t *testing.T) { + ctx := t.Context() + v := 5 + w := utils.NewAtomicWatch(&v) + require.Equal(t, 5, *w.Load()) + + want := 10 + if err := scope.Run(ctx, func(ctx context.Context, s scope.Scope) error { + s.Spawn(func() error { + for i := 0; i <= want; i++ { + w.Store(&i) + } + return nil + }) + + got, err := w.Wait(ctx, func(v *int) bool { return *v >= want }) + if err != nil { + return err + } + if *got != want { + return fmt.Errorf("got %v, want %v", *got, want) + } + return nil + }); err != nil { + t.Fatal(err) + } +} diff --git a/libs/utils/option.go b/libs/utils/option.go new file mode 100644 index 000000000..85fd6a471 --- /dev/null +++ b/libs/utils/option.go @@ -0,0 +1,73 @@ +package utils + +import ( + "encoding/json" +) + +// Option type inspired by https://pkg.go.dev/github.com/samber/mo. +type Option[T any] struct { + ReadOnly + isPresent bool + value T +} + +// Some creates an Option with a value. +func Some[T any](value T) Option[T] { + return Option[T]{isPresent: true, value: value} +} + +// None creates an Option without a value. +func None[T any]() (zero Option[T]) { return } + +// Get unpacks the value from the Option, returning true if it was present. +func (o Option[T]) Get() (T, bool) { + if o.isPresent { + return o.value, true + } + return Zero[T](), false +} + +// IsPresent checks if the Option contains a value. +func (o Option[T]) IsPresent() bool { + return o.isPresent +} + +// Or returns the value if present, otherwise returns the default value. +func (o *Option[T]) Or(def T) T { + if o.isPresent { + return o.value + } + return def +} + +// MapOpt applies a function to the value if present, returning a new Option. +func MapOpt[T, R any](o Option[T], f func(T) R) Option[R] { + if o.isPresent { + return Some(f(o.value)) + } + return None[R]() +} + +// MarshalJSON implements the json.Marshaler interface. +// Note that it is defined on value, not pointer, because +// json.Marshal cannot call pointer methods on fields +// (i.e. it is broken by design). +func (o Option[T]) MarshalJSON() ([]byte, error) { + if o.isPresent { + return json.Marshal(o.value) + } + return []byte("null"), nil +} + +// UnmarshalJSON implements the json.Unmarshaler interface.
+func (o *Option[T]) UnmarshalJSON(data []byte) error { + if string(data) == "null" { + o.isPresent = false + return nil + } + if err := json.Unmarshal(data, &o.value); err != nil { + return err + } + o.isPresent = true + return nil +} diff --git a/libs/utils/option_test.go b/libs/utils/option_test.go new file mode 100644 index 000000000..04a55a1e1 --- /dev/null +++ b/libs/utils/option_test.go @@ -0,0 +1,32 @@ +package utils + +import ( + "encoding/json" + "testing" + + "github.com/tendermint/tendermint/libs/utils/require" +) + +func testJSON[T any](t *testing.T, want T) { + enc, err := json.Marshal(want) + require.NoError(t, err) + t.Logf("%s", enc) + var got T + require.NoError(t, json.Unmarshal(enc, &got)) + require.NoError(t, TestDiff(want, got)) +} + +func TestOptionJSON(t *testing.T) { + type a struct { + X Option[int] + Y Option[string] + } + type b struct { + X Option[int] `json:"X,omitzero"` + Y Option[string] `json:"Y,omitzero"` + } + testJSON(t, &a{}) + testJSON(t, &a{Some(1), Some("a")}) + testJSON(t, &b{}) + testJSON(t, &b{Some(1), Some("a")}) +} diff --git a/libs/utils/proto.go b/libs/utils/proto.go new file mode 100644 index 000000000..5f5ad7a41 --- /dev/null +++ b/libs/utils/proto.go @@ -0,0 +1,143 @@ +package utils + +import ( + "crypto/sha256" + "errors" + "fmt" + "sync" + + "google.golang.org/protobuf/proto" +) + +// Hash is a SHA-256 hash. +type Hash [sha256.Size]byte + +// GetHash computes a hash of the given data. +func GetHash(data []byte) Hash { + return sha256.Sum256(data) +} + +// ParseHash parses a Hash from bytes. +func ParseHash(raw []byte) (Hash, error) { + if got, want := len(raw), sha256.Size; got != want { + return Hash{}, fmt.Errorf("hash size = %v, want %v", got, want) + } + return Hash(raw), nil +} + +// ProtoClone clones a proto.Message object. +func ProtoClone[T proto.Message](item T) T { + return proto.Clone(item).(T) +} + +// ProtoEqual compares two proto.Message objects. +func ProtoEqual[T proto.Message](a, b T) bool { + return proto.Equal(a, b) +} + +// ProtoHash hashes a proto.Message object. +// TODO(gprusak): make it deterministic. +func ProtoHash(a proto.Message) Hash { + raw, err := proto.Marshal(a) + if err != nil { + panic(err) + } + return sha256.Sum256(raw) +} + +// ProtoMessage is comparable proto.Message. +type ProtoMessage interface { + comparable + proto.Message +} + +// ProtoConv is a pair of functions to encode and decode between a type and a ProtoMessage. +type ProtoConv[T any, P ProtoMessage] struct { + Encode func(T) P + Decode func(P) (T, error) +} + +// EncodeSlice encodes a slice of T into a slice of P. +func (c ProtoConv[T, P]) EncodeSlice(t []T) []P { + p := make([]P, len(t)) + for i := range t { + p[i] = c.Encode(t[i]) + } + return p +} + +// DecodeSlice decodes a slice of P into a slice of T. +func (c ProtoConv[T, P]) DecodeSlice(p []P) ([]T, error) { + t := make([]T, len(p)) + var err error + for i := range p { + if t[i], err = c.Decode(p[i]); err != nil { + return nil, fmt.Errorf("[%d]: %w", i, err) + } + } + return t, nil +} + +// Slice constructs a slice. +// It is a syntax sugar for `[]T{v...}`, which avoids +// spelling out T. Not very useful if you need to spell +// out T to construct the elements: in that case +// you might prefer the []T{{...},{...}} syntax instead. +func Slice[T any](v ...T) []T { return v } + +// Alloc moves value to heap. +func Alloc[T any](v T) *T { return &v } + +// Zero returns a zero value of type T. 
+func Zero[T any]() (zero T) { return } + +// NoCopy may be added to structs which must not be copied +// after the first use. +// +// See https://golang.org/issues/8005#issuecomment-190753527 +// for details. +// +// Note that it must not be embedded, otherwise Lock and Unlock methods +// will be exported. +type NoCopy struct{} + +// Lock implements sync.Locker. +func (*NoCopy) Lock() {} + +// Unlock implements sync.Locker. +func (*NoCopy) Unlock() {} + +var _ sync.Locker = (*NoCopy)(nil) + +// NoCompare may be added to structs which must not be used as +// map keys. +type NoCompare [0]func() + +// EncodeOpt encodes Option[T], mapping None to Zero[P](). +func (c ProtoConv[T, P]) EncodeOpt(mv Option[T]) P { + v, ok := mv.Get() + if !ok { + return Zero[P]() + } + return c.Encode(v) +} + +// DecodeReq decodes a ProtoMessage into a T, returning an error if p is nil. +func (c ProtoConv[T, P]) DecodeReq(p P) (T, error) { + if p == Zero[P]() { + return Zero[T](), errors.New("missing") + } + return c.Decode(p) +} + +// DecodeOpt decodes a ProtoMessage into a T, returning nil if p is nil. +func (c ProtoConv[T, P]) DecodeOpt(p P) (Option[T], error) { + if p == Zero[P]() { + return None[T](), nil + } + t, err := c.DecodeReq(p) + if err != nil { + return None[T](), err + } + return Some(t), nil +} diff --git a/libs/utils/require/require.go b/libs/utils/require/require.go new file mode 100644 index 000000000..66bb750d3 --- /dev/null +++ b/libs/utils/require/require.go @@ -0,0 +1,81 @@ +// Package require reexports strongly typed `testify/require` API. +// We don't reexport `New`, because methods cannot be generic. +package require + +import ( + "cmp" + + "github.com/stretchr/testify/require" +) + +// TestingT . +type TestingT = require.TestingT + +// False . +var False = require.False + +// True . +var True = require.True + +// Contains . +var Contains = require.Contains + +// EqualError . +// TODO: get rid of comparing errors by strings, +// use concrete error types instead. +var EqualError = require.EqualError + +// Error . +var Error = require.Error + +// ErrorIs . +var ErrorIs = require.ErrorIs + +// NoError . +var NoError = require.NoError + +// Empty . +var Empty = require.Empty + +// NotEmpty . +var NotEmpty = require.NotEmpty + +// Len . +var Len = require.Len + +// Nil . +var Nil = require.Nil + +// NotNil . +var NotNil = require.NotNil + +// Panics . +var Panics = require.Panics + +// Fail . +var Fail = require.Fail + +// Positive . +func Positive[T cmp.Ordered](t TestingT, e T, msgAndArgs ...any) { + require.Positive(t, e, msgAndArgs...) +} + +// Less . +func Less[T cmp.Ordered](t TestingT, e1, e2 T, msgAndArgs ...any) { + require.Less(t, e1, e2, msgAndArgs...) +} + +// Greater . +func Greater[T cmp.Ordered](t TestingT, e1, e2 T, msgAndArgs ...any) { + require.Greater(t, e1, e2, msgAndArgs...) +} + +// Equal . +func Equal[T any](t TestingT, expected, actual T, msgAndArgs ...any) { + require.Equal(t, expected, actual, msgAndArgs...) +} + +// NotEqual . +func NotEqual[T any](t TestingT, expected, actual T, msgAndArgs ...any) { + require.NotEqual(t, expected, actual, msgAndArgs...) 
+} diff --git a/libs/utils/scope/parallel.go b/libs/utils/scope/parallel.go new file mode 100644 index 000000000..1377184d5 --- /dev/null +++ b/libs/utils/scope/parallel.go @@ -0,0 +1,41 @@ +package scope + +import ( + "sync" + "sync/atomic" +) + +type parallelScope struct { + wg sync.WaitGroup + err atomic.Pointer[error] +} + +// ParallelScope is a scope which doesn't require a cancellation token, +// just parallelization. +type ParallelScope struct{ *parallelScope } + +// Spawn spawns a new task in the scope. +func (s *parallelScope) Spawn(t func() error) { + s.wg.Add(1) + go func() { + if err := t(); err != nil { + s.err.CompareAndSwap(nil, &err) + } + s.wg.Done() + }() +} + +// Parallel executes a function in a parallel scope. +// Compared to Run, it does not allow for early cancellation, +// therefore is suitable for non-blocking computations. +// Returns the first error returned by any of the spawned tasks. +// Waits until all the tasks complete, before returning. +func Parallel(main func(ParallelScope) error) error { + var s parallelScope + s.Spawn(func() error { return main(ParallelScope{&s}) }) + s.wg.Wait() + if perr := s.err.Load(); perr != nil { + return *perr + } + return nil +} diff --git a/libs/utils/scope/parallel_test.go b/libs/utils/scope/parallel_test.go new file mode 100644 index 000000000..7f98872ad --- /dev/null +++ b/libs/utils/scope/parallel_test.go @@ -0,0 +1,54 @@ +package scope + +import ( + "errors" + "testing" +) + +func TestParallelOk(t *testing.T) { + x := [10]int{} + if err := Parallel(func(s ParallelScope) error { + for i := range x { + s.Spawn(func() error { + x[i] = i + return nil + }) + } + return nil + }); err != nil { + t.Fatal(err) + } + for want, got := range x { + if want != got { + t.Fatalf("x[%d] = %d, want %d", want, got, want) + } + } +} + +func TestParallelFail(t *testing.T) { + var wantErr = errors.New("custom err") + x := [10]int{} + err := Parallel(func(s ParallelScope) error { + for i := range x { + s.Spawn(func() error { + if i%2 == 0 { + return wantErr + } + x[i] = i + return nil + }) + } + return nil + }) + if !errors.Is(err, wantErr) { + t.Fatalf("err = %v, want %v", err, wantErr) + } + for want, got := range x { + if want%2 == 0 { + want = 0 + } + if want != got { + t.Fatalf("x[%d] = %d, want %d", want, got, want) + } + } +} diff --git a/libs/utils/scope/start.go b/libs/utils/scope/start.go new file mode 100644 index 000000000..cba8d2e4d --- /dev/null +++ b/libs/utils/scope/start.go @@ -0,0 +1,143 @@ +package scope + +import ( + "context" + "fmt" + "log" + "sync" + "time" + + "golang.org/x/sync/errgroup" + + "github.com/tendermint/tendermint/libs/utils" +) + +// Scope of concurrent tasks. +type Scope struct { + // scope is a concurrency primitive, so the no-ctx-in-struct rule does not apply + // nolint:containedctx + ctx context.Context + all *errgroup.Group + main *sync.WaitGroup +} + +// Spawn spawns a main task. +// Scope gets automatically canceled when all the main tasks return. +func (s Scope) Spawn(t func() error) { + s.main.Add(1) + s.all.Go(func() error { + defer s.main.Done() + return t() + }) +} + +// JoinHandle is a handle to an awaitable task. +type JoinHandle[R any] struct { + result *utils.AtomicWatch[*R] +} + +// Spawn1 is the same as Scope.Spawn, but allows awaiting completion of a task and getting its result.
+func Spawn1[R any](s Scope, t func() (R, error)) JoinHandle[R] { + result := utils.NewAtomicWatch[*R](nil) + s.Spawn(func() error { + v, err := t() + if err != nil { + return err + } + result.Store(&v) + return nil + }) + return JoinHandle[R]{&result} +} + +// Join awaits completion of a task and returns its result. +// WARNING: it does NOT return the error of the task - error is returned from the Run() command. +// Join() can only fail when context is canceled. +func (h JoinHandle[R]) Join(ctx context.Context) (R, error) { + res, err := h.result.Wait(ctx, func(v *R) bool { return v != nil }) + if err != nil { + return utils.Zero[R](), err + } + return *res, nil +} + +// If true, tasks that do not respect context cancellation will be logged. +// This is useful for debugging, but causes unnecessary overhead. +// Since this is a constant, debug guard should be optimized out by the compiler. +const enableDebugGuard = false + +func (s Scope) debugGuard(name string, done chan struct{}) { + select { + case <-done: + return + case <-s.ctx.Done(): + } + for { + select { + case <-done: + return + case <-time.After(10 * time.Second): + } + log.Printf("task %q still running", name) + } +} + +// SpawnNamed spawns a named main task. +func (s Scope) SpawnNamed(name string, t func() error) { + done := make(chan struct{}) + s.Spawn(func() error { + defer close(done) + if err := t(); err != nil { + return fmt.Errorf("%s: %w", name, err) + } + return nil + }) + if enableDebugGuard { + go s.debugGuard(name, done) + } +} + +// SpawnBgNamed spawns a named background task. +func (s Scope) SpawnBgNamed(name string, t func() error) { + done := make(chan struct{}) + s.SpawnBg(func() error { + defer close(done) + if err := t(); err != nil { + return fmt.Errorf("%s: %w", name, err) + } + return nil + }) + if enableDebugGuard { + go s.debugGuard(name, done) + } +} + +// SpawnBg spawns a background task. +// Background tasks get canceled when all the main tasks return. +func (s Scope) SpawnBg(t func() error) { s.all.Go(t) } + +// Run runs a scope capable of spawning tasks. +// It is guaranteed that all the spawned tasks will be executed (even if spawned after the context is cancelled), +// and that `Run` will return only after all the tasks have completed. +// Context of the tasks will be automatically cancelled as soon as ANY task returns an error. +// Returns the first error returned by any task (main or background). +func Run(ctx context.Context, main func(context.Context, Scope) error) error { + ctx, cancel := context.WithCancel(ctx) + all, ctx := errgroup.WithContext(ctx) + s := Scope{ctx, all, &sync.WaitGroup{}} + s.Spawn(func() error { return main(ctx, s) }) + s.main.Wait() + cancel() + return s.all.Wait() +} + +// Run1 is the same as Run, but returns the result of the main task. +func Run1[R any](ctx context.Context, main func(context.Context, Scope) (R, error)) (res R, err error) { + err = Run(ctx, func(ctx context.Context, s Scope) error { + var err error + res, err = main(ctx, s) + return err + }) + //nolint:nakedret + return +} diff --git a/libs/utils/semaphore.go b/libs/utils/semaphore.go new file mode 100644 index 000000000..728c12a5c --- /dev/null +++ b/libs/utils/semaphore.go @@ -0,0 +1,24 @@ +package utils + +import ( + "context" +) + +// Semaphore provides a way to bound concurrent access to a resource. +type Semaphore struct { + ch chan struct{} +} + +// NewSemaphore constructs a new semaphore with n permits.
+func NewSemaphore(n int) *Semaphore { + return &Semaphore{ch: make(chan struct{}, n)} +} + +// Acquire acquires a permit from the semaphore. +// Blocks until a permit is available. +func (s *Semaphore) Acquire(ctx context.Context) (release func(), err error) { + if err := Send(ctx, s.ch, struct{}{}); err != nil { + return nil, err + } + return func() { <-s.ch }, nil +} diff --git a/libs/utils/testonly.go b/libs/utils/testonly.go new file mode 100644 index 000000000..afd6b8aa8 --- /dev/null +++ b/libs/utils/testonly.go @@ -0,0 +1,152 @@ +package utils + +import ( + "fmt" + "math/big" + "math/rand" + "reflect" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/testing/protocmp" +) + +// ReadOnly - if a struct embeds ReadOnly, +// its private fields will be compared by TestEqual. +type ReadOnly struct{} + +// isReadOnly returns true if t embeds ReadOnly. +func isReadOnly(t reflect.Type) bool { + want := reflect.TypeOf(ReadOnly{}) + if t.Kind() != reflect.Struct { + return false + } + for i := range t.NumField() { + if f := t.Field(i); f.Anonymous || f.Type == want { + return true + } + } + return false +} + +func cmpComparer[T any, PT interface { + Cmp(b *T) int + *T +}](a PT, b PT) bool { + if a == nil || b == nil { + return a == b + } + return a.Cmp(b) == 0 +} + +var cmpOpts = []cmp.Option{ + protocmp.Transform(), + cmp.Exporter(isReadOnly), + cmpopts.EquateEmpty(), + cmp.Comparer(cmpComparer[big.Int]), +} + +// TestDiff generates a human-readable diff between two objects. +func TestDiff[T any](want, got T) error { + if diff := cmp.Diff(want, got, cmpOpts...); diff != "" { + return fmt.Errorf("want (-) got (+):\n%s", diff) + } + return nil +} + +// TestEqual is a more robust replacement for reflect.DeepEqual for tests. +func TestEqual[T any](a, b T) bool { + return cmp.Equal(a, b, cmpOpts...) +} + +// TestRngSplit returns a new random number generator split from the given one. +// This is a very primitive splitting, known to result in dependent randomness. +// If that ever causes a problem, we can switch to SplitMix. +func TestRngSplit(rng *rand.Rand) *rand.Rand { + return rand.New(rand.NewSource(rng.Int63())) +} + +// TestRng returns a deterministic random number generator. +func TestRng() *rand.Rand { + return rand.New(rand.NewSource(789345342)) +} + +var alphanum = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789") + +// GenString generates a random string of length n. +func GenString(rng *rand.Rand, n int) string { + s := make([]rune, n) + for i := range n { + s[i] = alphanum[rng.Intn(len(alphanum))] + } + return string(s) +} + +// GenBytes generates a random byte slice. +func GenBytes(rng *rand.Rand, n int) []byte { + s := make([]byte, n) + _, _ = rng.Read(s) + return s +} + +// GenF is a function which generates T. +type GenF[T any] = func(rng *rand.Rand) T + +// GenSlice generates a slice of small random length. +func GenSlice[T any](rng *rand.Rand, gen GenF[T]) []T { + return GenSliceN(rng, 2+rng.Intn(3), gen) +} + +// GenSliceN generates a slice of n elements. +func GenSliceN[T any](rng *rand.Rand, n int, gen GenF[T]) []T { + s := make([]T, n) + for i := range s { + s[i] = gen(rng) + } + return s +} + +// GenMap generates a map of small random length. +func GenMap[K comparable, V any](rng *rand.Rand, genK GenF[K], genV GenF[V]) map[K]V { + return GenMapN(rng, 2+rng.Intn(3), genK, genV) +} + +// GenMapN generates a map of n elements.
+func GenMapN[K comparable, V any](rng *rand.Rand, n int, genK GenF[K], genV GenF[V]) map[K]V { + m := make(map[K]V, n) + for len(m) < n { + m[genK(rng)] = genV(rng) + } + return m +} + +// GenTimestamp generates a random timestamp. +func GenTimestamp(rng *rand.Rand) time.Time { + return time.Unix(0, rng.Int63()) +} + +// GenHash generates a random Hash. +func GenHash(rng *rand.Rand) Hash { + var h Hash + _, _ = rng.Read(h[:]) + return h +} + +// Test tests whether reencoding a value is an identity operation. +func (c *ProtoConv[T, P]) Test(want T) error { + p := c.Encode(want) + raw, err := proto.Marshal(p) + if err != nil { + return fmt.Errorf("Marshal(): %w", err) + } + if err := proto.Unmarshal(raw, p); err != nil { + return fmt.Errorf("Unmarshal(): %w", err) + } + got, err := c.Decode(p) + if err != nil { + return fmt.Errorf("Decode(Encode()): %w", err) + } + return TestDiff(want, got) +} diff --git a/libs/utils/wait.go b/libs/utils/wait.go new file mode 100644 index 000000000..4c8c6634f --- /dev/null +++ b/libs/utils/wait.go @@ -0,0 +1,119 @@ +package utils + +import ( + "context" + "encoding" + "errors" + "time" +) + +// IgnoreCancel returns nil if the error is context.Canceled, err otherwise. +func IgnoreCancel(err error) error { + if errors.Is(err, context.Canceled) { + return nil + } + return err +} + +// WithTimeout executes a function with a timeout. +func WithTimeout(ctx context.Context, d time.Duration, f func(ctx context.Context) error) error { + ctx, cancel := context.WithTimeout(ctx, d) + defer cancel() + return f(ctx) +} + +// WithTimeout1 executes a function with a timeout. +func WithTimeout1[R any](ctx context.Context, d time.Duration, f func(ctx context.Context) (R, error)) (R, error) { + ctx, cancel := context.WithTimeout(ctx, d) + defer cancel() + return f(ctx) +} + +// Sleep sleeps for a duration or until the context is canceled. +func Sleep(ctx context.Context, d time.Duration) error { + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(d): + return nil + } +} + +// SleepUntil sleeps until deadline t or until the context is canceled. +func SleepUntil(ctx context.Context, t time.Time) error { + return Sleep(ctx, time.Until(t)) +} + +// WaitFor polls a check function until it returns true or the context is canceled. +func WaitFor(ctx context.Context, interval time.Duration, check func() bool) error { + if check() { + return nil + } + ticker := time.NewTicker(interval) + for { + if check() { + return nil + } + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + } + } +} + +// WaitForWithTimeout polls a check function until it returns true, the context is canceled, or the timeout is reached. +func WaitForWithTimeout(ctx context.Context, interval, timeout time.Duration, check func() bool) error { + if check() { + return nil + } + + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + if check() { + return nil + } + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + } + } +} + +// Duration is a wrapper type around time.Duration that supports JSON marshaling/unmarshaling. +// nolint:recvcheck +type Duration time.Duration + +// MarshalText implements json.TextMarshaler interface to convert Duration to JSON string. +func (d Duration) MarshalText() ([]byte, error) { + return []byte(time.Duration(d).String()), nil +} + +// UnmarshalText implements json.TextUnmarshaler. 
+func (d *Duration) UnmarshalText(b []byte) error { + tmp, err := time.ParseDuration(string(b)) + if err != nil { + return err + } + *d = Duration(tmp) + return nil +} + +var _ encoding.TextMarshaler = Zero[Duration]() +var _ encoding.TextUnmarshaler = (*Duration)(nil) + +// Duration returns the underlying time.Duration value. +func (d Duration) Duration() time.Duration { + return time.Duration(d) +} + +// Seconds returns the underlying time.Duration value in seconds. +func (d Duration) Seconds() float64 { + return time.Duration(d).Seconds() +} diff --git a/libs/utils/wait_test.go b/libs/utils/wait_test.go new file mode 100644 index 000000000..91edc1267 --- /dev/null +++ b/libs/utils/wait_test.go @@ -0,0 +1,23 @@ +package utils + +import ( + "encoding/json" + "testing" + "time" +) + +func TestJSON(t *testing.T) { + var got, want struct{ X Duration } + want.X = Duration(100 * time.Millisecond) + j, err := json.Marshal(want) + if err != nil { + t.Fatal(err) + } + t.Logf("%s", j) + if err := json.Unmarshal(j, &got); err != nil { + t.Fatal(err) + } + if err := TestDiff(want, got); err != nil { + t.Fatal(err) + } +} From 7cdf7f99e864cc2ecf149133750699fb41c22b9d Mon Sep 17 00:00:00 2001 From: Grzegorz Prusak Date: Wed, 20 Aug 2025 14:04:34 +0200 Subject: [PATCH 10/41] example refactor: ticker --- internal/consensus/common_test.go | 4 +- internal/consensus/state.go | 33 +++++---- internal/consensus/ticker.go | 114 ++++++++---------------------- libs/service/service.go | 12 ++-- 4 files changed, 56 insertions(+), 107 deletions(-) diff --git a/internal/consensus/common_test.go b/internal/consensus/common_test.go index fc7e4af31..d74d023c9 100644 --- a/internal/consensus/common_test.go +++ b/internal/consensus/common_test.go @@ -971,9 +971,7 @@ type mockTicker struct { fired bool } -func (m *mockTicker) Start(context.Context) error { return nil } -func (m *mockTicker) Stop() {} -func (m *mockTicker) IsRunning() bool { return false } +func (m *mockTicker) Run(context.Context) error { return nil } func (m *mockTicker) ScheduleTimeout(ti timeoutInfo) { m.mtx.Lock() diff --git a/internal/consensus/state.go b/internal/consensus/state.go index 20472c805..f052cf1c5 100644 --- a/internal/consensus/state.go +++ b/internal/consensus/state.go @@ -95,6 +95,20 @@ type timeoutInfo struct { Step cstypes.RoundStepType `json:"step"` } +func (ti *timeoutInfo) Less(b *timeoutInfo) bool { + // sort by height, then round, then step + if ti.Height != b.Height { + return ti.Height < b.Height + } + if ti.Round != b.Round { + return ti.Round < b.Round + } + // This is copy-pasted logic, supposedly allowing for updating the timeout with Step 0 without incrementing the step. + // Note that because of this Less is NOT a strict order. + // TODO(gprusak): Figure out why we special case step 0 and fix it. + return ti.Step <= 0 || ti.Step < b.Step +} + func (timeoutInfo) TypeTag() string { return "tendermint/wal/TimeoutInfo" } func (ti *timeoutInfo) String() string { @@ -420,9 +434,7 @@ func (cs *State) OnStart(ctx context.Context) error { // NOTE: we will get a build up of garbage go routines // firing on the tockChan until the receiveRoutine is started // to deal with them (by that point, at most one will be valid) - if err := cs.timeoutTicker.Start(ctx); err != nil { - return err - } + cs.Spawn("timeoutTicker", cs.timeoutTicker.Run) // We may have lost some votes if the process crashed reload from consensus // log to catchup. 
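Note: as the comment added to Less above warns, the Step <= 0 special case makes Less a non-strict order; two timeouts at the same height and round with step 0 each compare Less than the other. A quick sketch with hypothetical values, inside the consensus package:

    // Both calls report true, so Less is only meaningful as the
    // "should the new timeout replace the current one" predicate
    // used by the ticker's ScheduleTimeout below, never as a sort key.
    a := timeoutInfo{Height: 5, Round: 1, Step: 0}
    b := timeoutInfo{Height: 5, Round: 1, Step: 0}
    fmt.Println(a.Less(&b), b.Less(&a)) // true true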
@@ -496,12 +508,11 @@ func (cs *State) OnStart(ctx context.Context) error { // // this is only used in tests. func (cs *State) startRoutines(ctx context.Context, maxSteps int) { - err := cs.timeoutTicker.Start(ctx) - if err != nil { - cs.logger.Error("failed to start timeout ticker", "err", err) - return - } - + go func() { + if err := cs.timeoutTicker.Run(ctx); err != nil { + cs.logger.Error("cs.timeoutTicker.Run()", "err", err) + } + }() go cs.receiveRoutine(ctx, maxSteps) } @@ -537,10 +548,6 @@ func (cs *State) OnStop() { cs.logger.Error("OnStop: timeout waiting for commit to finish", "time", commitTimeout) } } - - if cs.timeoutTicker.IsRunning() { - cs.timeoutTicker.Stop() - } // WAL is stopped in receiveRoutine. } diff --git a/internal/consensus/ticker.go b/internal/consensus/ticker.go index c47af635c..a362249df 100644 --- a/internal/consensus/ticker.go +++ b/internal/consensus/ticker.go @@ -2,10 +2,9 @@ package consensus import ( "context" - "time" - "github.com/tendermint/tendermint/libs/log" - "github.com/tendermint/tendermint/libs/service" + "github.com/tendermint/tendermint/libs/utils" + "github.com/tendermint/tendermint/libs/utils/scope" ) var ( @@ -16,9 +15,7 @@ var ( // conditional on the height/round/step in the timeoutInfo. // The timeoutInfo.Duration may be non-positive. type TimeoutTicker interface { - Start(context.Context) error - Stop() - IsRunning() bool + Run(context.Context) error Chan() <-chan timeoutInfo // on which to receive a timeout ScheduleTimeout(ti timeoutInfo) // reset the timer } @@ -29,108 +26,53 @@ type TimeoutTicker interface { // Timeouts are scheduled along the tickChan, // and fired on the tockChan. type timeoutTicker struct { - service.BaseService - logger log.Logger - - timer *time.Timer - tickChan chan timeoutInfo // for scheduling timeouts - tockChan chan timeoutInfo // for notifying about them + logger log.Logger + tick utils.AtomicWatch[utils.Option[timeoutInfo]] // for scheduling timeouts + tockChan chan timeoutInfo // for notifying about them } // NewTimeoutTicker returns a new TimeoutTicker. func NewTimeoutTicker(logger log.Logger) TimeoutTicker { tt := &timeoutTicker{ logger: logger, - timer: time.NewTimer(0), - tickChan: make(chan timeoutInfo, tickTockBufferSize), + tick: utils.NewAtomicWatch(utils.None[timeoutInfo]()), tockChan: make(chan timeoutInfo, tickTockBufferSize), } - tt.BaseService = *service.NewBaseService(logger, "TimeoutTicker", tt) - tt.stopTimer() // don't want to fire until the first scheduled timeout return tt } -// OnStart implements service.Service. It starts the timeout routine. -func (t *timeoutTicker) OnStart(ctx context.Context) error { - go t.timeoutRoutine(ctx) - - return nil -} - -// OnStop implements service.Service. It stops the timeout routine. -func (t *timeoutTicker) OnStop() { t.stopTimer() } - // Chan returns a channel on which timeouts are sent. func (t *timeoutTicker) Chan() <-chan timeoutInfo { return t.tockChan } -// ScheduleTimeout schedules a new timeout by sending on the internal tickChan. -// The timeoutRoutine is always available to read from tickChan, so this won't block. -// The scheduling may fail if the timeoutRoutine has already scheduled a timeout for a later height/round/step. 
-func (t *timeoutTicker) ScheduleTimeout(ti timeoutInfo) { - t.tickChan <- ti -} - -//------------------------------------------------------------- - -// stop the timer and drain if necessary -func (t *timeoutTicker) stopTimer() { - // Stop() returns false if it was already fired or was stopped - if !t.timer.Stop() { - select { - case <-t.timer.C: - default: +// ScheduleTimeout schedules a new timeout, which replaces the previous one. +// Noop if a timeout for a later height/round/step has been already scheduled. +func (t *timeoutTicker) ScheduleTimeout(newti timeoutInfo) { + t.tick.Update(func(old utils.Option[timeoutInfo]) (utils.Option[timeoutInfo], bool) { + if oldti, ok := old.Get(); !ok || oldti.Less(&newti) { + return utils.Some(newti), true } - } + return old, false + }) } -// send on tickChan to start a new timer. // timers are interupted and replaced by new ticks from later steps // timeouts of 0 on the tickChan will be immediately relayed to the tockChan -func (t *timeoutTicker) timeoutRoutine(ctx context.Context) { - var ti timeoutInfo - for { - select { - case newti := <-t.tickChan: - t.logger.Debug("Received tick", "old_ti", ti, "new_ti", newti) - - // ignore tickers for old height/round/step - if newti.Height < ti.Height { - continue - } else if newti.Height == ti.Height { - if newti.Round < ti.Round { - continue - } else if newti.Round == ti.Round { - if ti.Step > 0 && newti.Step <= ti.Step { - continue - } - } +func (t *timeoutTicker) Run(ctx context.Context) error { + return scope.Run(ctx, func(ctx context.Context, s scope.Scope) error { + return t.tick.Iter(ctx, func(ctx context.Context, mti utils.Option[timeoutInfo]) error { + ti, ok := mti.Get() + if !ok { + return nil } - - // stop the last timer - t.stopTimer() - - // update timeoutInfo and reset timer - // NOTE time.Timer allows duration to be non-positive - ti = newti - t.timer.Stop() - t.timer.Reset(ti.Duration) t.logger.Debug("Internal state machine timeout scheduled", "duration", ti.Duration, "height", ti.Height, "round", ti.Round, "step", ti.Step) - case <-t.timer.C: + if err := utils.Sleep(ctx, ti.Duration); err != nil { + return err + } t.logger.Debug("Internal state machine timeout elapsed ", "duration", ti.Duration, "height", ti.Height, "round", ti.Round, "step", ti.Step) - // go routine here guarantees timeoutRoutine doesn't block. - // Determinism comes from playback in the receiveRoutine. - // We can eliminate it by merging the timeoutRoutine into receiveRoutine - // and managing the timeouts ourselves with a millisecond ticker - go func(toi timeoutInfo) { - select { - case t.tockChan <- toi: - case <-ctx.Done(): - } - }(ti) - case <-ctx.Done(): - return - } - } + s.Spawn(func() error { return utils.Send(ctx, t.tockChan, ti) }) + return nil + }) + }) } diff --git a/libs/service/service.go b/libs/service/service.go index 425da0447..685b267c1 100644 --- a/libs/service/service.go +++ b/libs/service/service.go @@ -3,6 +3,7 @@ package service import ( "context" "github.com/tendermint/tendermint/libs/log" + "github.com/tendermint/tendermint/libs/utils" "sync" "sync/atomic" ) @@ -140,14 +141,15 @@ func (bs *BaseService) Stop() { } } -// Spawn spawns a new goroutine executing task, which will be cancelled +// Spawn spawns a new goroutine executing the task, which will be cancelled // when outer context is cancelled or when the service is stopped. +// Error (other than ctx.Canceled) is logged after the task finishes. // Both Wait and Stop calls will block until the spawned task is finished. 
// It should be called ONLY from within OnStart(). -// NOTE that the task is provided with a narrower context than the context +// Note that the task is provided with a narrower context than the context // provided to OnStart(). This is intentional. // Panics if the service has not been started yet. -func (bs *BaseService) Spawn(task func(ctx context.Context) error) { +func (bs *BaseService) Spawn(name string, task func(ctx context.Context) error) { inner := bs.inner.Load() if inner == nil { panic("service is not started yet") @@ -156,8 +158,8 @@ func (bs *BaseService) Spawn(task func(ctx context.Context) error) { inner.wg.Add(1) go func() { defer inner.wg.Done() - if err := task(inner.ctx); err != nil { - bs.logger.Error("task failed", "service", bs.name, "error", err) + if err := utils.IgnoreCancel(task(inner.ctx)); err != nil { + bs.logger.Error("task failed", "name", name, "service", bs.name, "error", err) } }() } From f200d99041a8c82c6da5c7ef3019cc07981103c0 Mon Sep 17 00:00:00 2001 From: Grzegorz Prusak Date: Wed, 20 Aug 2025 15:09:19 +0200 Subject: [PATCH 11/41] go vet --- internal/p2p/peermanager_test.go | 1 + internal/p2p/transport_test.go | 1 + 2 files changed, 2 insertions(+) diff --git a/internal/p2p/peermanager_test.go b/internal/p2p/peermanager_test.go index 6e60889a0..048affb54 100644 --- a/internal/p2p/peermanager_test.go +++ b/internal/p2p/peermanager_test.go @@ -1453,6 +1453,7 @@ func TestPeerManager_EvictNext(t *testing.T) { // Since there are no more peers to evict, the next call should block. timeoutCtx, cancel = context.WithTimeout(ctx, 100*time.Millisecond) + defer cancel() _, err = peerManager.EvictNext(timeoutCtx) require.Error(t, err) require.Equal(t, context.DeadlineExceeded, err) diff --git a/internal/p2p/transport_test.go b/internal/p2p/transport_test.go index 865215b22..ccb783f1d 100644 --- a/internal/p2p/transport_test.go +++ b/internal/p2p/transport_test.go @@ -315,6 +315,7 @@ func TestConnection_HandshakeCancel(t *testing.T) { // Handshake should error on context timeout. ab, ba = dialAccept(ctx, t, a, b) timeoutCtx, cancel = context.WithTimeout(ctx, 200*time.Millisecond) + defer cancel() _, _, err = ab.Handshake(timeoutCtx, types.NodeInfo{}, ed25519.GenPrivKey()) require.Error(t, err) require.Equal(t, context.DeadlineExceeded, err) From f609e961adfcbc2b4b92190ab6f1ebd5fc0803bd Mon Sep 17 00:00:00 2001 From: Grzegorz Prusak Date: Tue, 26 Aug 2025 15:33:28 +0200 Subject: [PATCH 12/41] 2-heap queue --- internal/p2p/pqueue.go | 58 ----------- internal/p2p/router.go | 62 +----------- internal/p2p/rqueue.go | 218 +++++++++++++++++++++++++---------------- 3 files changed, 133 insertions(+), 205 deletions(-) diff --git a/internal/p2p/pqueue.go b/internal/p2p/pqueue.go index 3cd1c897a..095b414c5 100644 --- a/internal/p2p/pqueue.go +++ b/internal/p2p/pqueue.go @@ -13,64 +13,6 @@ import ( "github.com/tendermint/tendermint/libs/log" ) -// pqEnvelope defines a wrapper around an Envelope with priority to be inserted -// into a priority queue used for Envelope scheduling. -type pqEnvelope struct { - envelope Envelope - priority uint - size uint - timestamp time.Time - - index int -} - -// priorityQueue defines a type alias for a priority queue implementation. 
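Back in libs/service above, Spawn now takes a task name and drops context-cancellation errors via utils.IgnoreCancel before logging. A usage sketch from inside a hypothetical OnStart; exampleReactor is assumed to embed service.BaseService, and gossipOnce is a placeholder:

func (r *exampleReactor) OnStart(ctx context.Context) error {
	// Naming the task makes the failure log line attributable:
	// "task failed" name=gossip service=exampleReactor error=...
	r.Spawn("gossip", func(ctx context.Context) error {
		for {
			if err := r.gossipOnce(ctx); err != nil {
				return err // logged unless it is a cancellation error
			}
		}
	})
	return nil
}
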
-type priorityQueue []*pqEnvelope - -func (pq priorityQueue) get(i int) *pqEnvelope { return pq[i] } -func (pq priorityQueue) Len() int { return len(pq) } - -func (pq priorityQueue) Less(i, j int) bool { - // if both elements have the same priority, prioritize based - // on most recent and largest - if pq[i].priority == pq[j].priority { - diff := pq[i].timestamp.Sub(pq[j].timestamp) - if diff < 0 { - diff *= -1 - } - if diff < 10*time.Millisecond { - return pq[i].size > pq[j].size - } - return pq[i].timestamp.After(pq[j].timestamp) - } - - // otherwise, pick the pqEnvelope with the higher priority - return pq[i].priority > pq[j].priority -} - -func (pq priorityQueue) Swap(i, j int) { - pq[i], pq[j] = pq[j], pq[i] - pq[i].index = i - pq[j].index = j -} - -func (pq *priorityQueue) Push(x interface{}) { - n := len(*pq) - pqEnv := x.(*pqEnvelope) - pqEnv.index = n - *pq = append(*pq, pqEnv) -} - -func (pq *priorityQueue) Pop() interface{} { - old := *pq - n := len(old) - pqEnv := old[n-1] - old[n-1] = nil - pqEnv.index = -1 - *pq = old[:n-1] - return pqEnv -} - // Assert the priority queue scheduler implements the queue interface at // compile-time. var _ queue = (*pqScheduler)(nil) diff --git a/internal/p2p/router.go b/internal/p2p/router.go index 978f4812e..4eef81b5a 100644 --- a/internal/p2p/router.go +++ b/internal/p2p/router.go @@ -35,10 +35,6 @@ type RouterOptions struct { // no timeout. HandshakeTimeout time.Duration - // QueueType must be, "priority", or "fifo". Defaults to - // "fifo". - QueueType string - // MaxIncomingConnectionAttempts rate limits the number of incoming connection // attempts per IP address. Defaults to 100. MaxIncomingConnectionAttempts uint @@ -75,23 +71,8 @@ type RouterOptions struct { NumConcurrentDials func() int } -const ( - queueTypeFifo = "fifo" - queueTypePriority = "priority" - queueTypeSimplePriority = "simple-priority" -) - // Validate validates router options. func (o *RouterOptions) Validate() error { - switch o.QueueType { - case "": - o.QueueType = queueTypeFifo - case queueTypeFifo, queueTypePriority, queueTypeSimplePriority: - // pass - default: - return fmt.Errorf("queue type %q is not supported", o.QueueType) - } - switch { case o.IncomingConnectionWindow == 0: o.IncomingConnectionWindow = 100 * time.Millisecond @@ -165,7 +146,6 @@ type Router struct { peerQueues map[types.NodeID]queue // outbound messages per peer for all channels // the channels that the peer queue has open peerChannels map[types.NodeID]ChannelIDSet - queueFactory func(int) queue nodeInfoProducer func() *types.NodeInfo // FIXME: We don't strictly need to use a mutex for this if we seal the @@ -231,30 +211,6 @@ func NewRouter( return router, nil } -func (r *Router) createQueueFactory(ctx context.Context) (func(int) queue, error) { - switch r.options.QueueType { - case queueTypeFifo: - return newFIFOQueue, nil - - case queueTypePriority: - return func(size int) queue { - if size%2 != 0 { - size++ - } - - q := newPQScheduler(r.logger, r.metrics, r.lc, r.chDescs, uint(size)/2, uint(size)/2, defaultCapacity) - q.start(ctx) - return q - }, nil - - case queueTypeSimplePriority: - return func(size int) queue { return newSimplePriorityQueue(ctx, size, r.chDescs) }, nil - - default: - return nil, fmt.Errorf("cannot construct queue of type %q", r.options.QueueType) - } -} - // ChannelCreator allows routers to construct their own channels, // either by receiving a reference to Router.OpenChannel or using some // kind shim for testing purposes. 
@@ -278,7 +234,7 @@ func (r *Router) OpenChannel(ctx context.Context, chDesc *ChannelDescriptor) (*C messageType := chDesc.MessageType - queue := r.queueFactory(chDesc.RecvBufferCapacity) + queue := newSimplePriorityQueue(chDesc.RecvBufferCapacity, r.chDescs) outCh := make(chan Envelope, chDesc.RecvBufferCapacity) errCh := make(chan PeerError, chDesc.RecvBufferCapacity) channel := NewChannel(id, queue.dequeue(), outCh, errCh) @@ -688,7 +644,7 @@ func (r *Router) getOrMakeQueue(peerID types.NodeID, channels ChannelIDSet) queu return peerQueue } - peerQueue := r.queueFactory(queueBufferDefault) + peerQueue := newSimplePriorityQueue(queueBufferDefault, r.chDescs) r.peerQueues[peerID] = peerQueue r.peerChannels[peerID] = channels return peerQueue @@ -982,16 +938,6 @@ func (r *Router) evictPeers(ctx context.Context) { } } -func (r *Router) setupQueueFactory(ctx context.Context) error { - qf, err := r.createQueueFactory(ctx) - if err != nil { - return err - } - - r.queueFactory = qf - return nil -} - func (r *Router) AddChDescToBeAdded(chDesc *ChannelDescriptor, callback func(*Channel)) { r.chDescsToBeAdded = append(r.chDescsToBeAdded, chDescAdderWithCallback{ chDesc: chDesc, @@ -1001,10 +947,6 @@ func (r *Router) AddChDescToBeAdded(chDesc *ChannelDescriptor, callback func(*Ch // OnStart implements service.Service. func (r *Router) OnStart(ctx context.Context) error { - if err := r.setupQueueFactory(ctx); err != nil { - return err - } - if err := r.transport.Listen(r.endpoint); err != nil { return err } diff --git a/internal/p2p/rqueue.go b/internal/p2p/rqueue.go index afeed3a65..5c83530ea 100644 --- a/internal/p2p/rqueue.go +++ b/internal/p2p/rqueue.go @@ -3,111 +3,155 @@ package p2p import ( "container/heap" "context" - "sort" "time" "github.com/gogo/protobuf/proto" + "github.com/tendermint/tendermint/libs/utils" ) -type simpleQueue struct { - input chan Envelope - output chan Envelope - closeFn func() - closeCh <-chan struct{} +type ord[T any] interface { + Less(T) bool +} + +type withIdx[T any] struct { + v T + minIdx int // index in byMin + maxIdx int // index in byMax +} + +func newWithIdx[T any](v T) *withIdx[T] { + return &withIdx[T] { v: v } +} + +// Heap returning minimal elements. +type byMin[T ord[T]] struct { a []*withIdx[T] } +func newByMin[T ord[T]](capacity int) byMin[T] { return byMin[T]{make([]*withIdx[T],0,capacity)} } +func (x *byMin[T]) Less(i, j int) bool { return x.a[i].v.Less(x.a[j].v) } +func (x *byMin[T]) Len() int { return len(x.a) } +func (x *byMin[T]) Swap(i, j int) { + x.a[i],x.a[j] = x.a[j],x.a[i] + x.a[i].minIdx = i + x.a[j].minIdx = j +} +func (x *byMin[T]) Push(v any) { + w := v.(*withIdx[T]) + w.minIdx = len(x.a) + x.a = append(x.a,w) +} +func (x *byMin[T]) Pop() any { + n := len(x.a)-1 + w := x.a[n] + x.a = x.a[:n] + return w +} + +// Heap returning maximal elements. 
+type byMax[T ord[T]] struct { a []*withIdx[T] } +func newByMax[T ord[T]](capacity int) byMax[T] { return byMax[T]{make([]*withIdx[T],0,capacity)} } +func (x *byMax[T]) Less(i, j int) bool { return x.a[j].v.Less(x.a[i].v) } +func (x *byMax[T]) Len() int { return len(x.a) } +func (x *byMax[T]) Swap(i, j int) { + x.a[i],x.a[j] = x.a[j],x.a[i] + x.a[i].maxIdx = i + x.a[j].maxIdx = j +} +func (x *byMax[T]) Push(v any) { + w := v.(*withIdx[T]) + w.maxIdx = len(x.a) + x.a = append(x.a,w) +} +func (x *byMax[T]) Pop() any { + n := len(x.a)-1 + w := x.a[n] + x.a = x.a[:n] + return w +} - maxSize int - chDescs []*ChannelDescriptor +// pqEnvelope defines a wrapper around an Envelope with priority to be inserted +// into a priority queue used for Envelope scheduling. +type pqEnvelope struct { + envelope Envelope + priority uint + size int + timestamp time.Time } -func newSimplePriorityQueue(ctx context.Context, size int, chDescs []*ChannelDescriptor) *simpleQueue { - if size%2 != 0 { - size++ +// true <=> a has higher priority than b +func (a *pqEnvelope) Less(b *pqEnvelope) bool { + // higher base priority wins + if a,b := a.priority,b.priority; a!=b { + return a > b } + // newer timestamp wins + if a,b := a.timestamp,b.timestamp; a.Sub(b).Abs() >= 10*time.Millisecond { + return a.After(b) + } + // larger first + return a.size > b.size +} + +type inner struct { + capacity int + byMin byMin[*pqEnvelope] + byMax byMax[*pqEnvelope] +} + +func newInner(capacity int) *inner { + return &inner { + capacity: capacity, + // We prune the maximal elements whenever capacity is exceeded. + // Therefore to avoid reallocation we need the heaps to have capacity+1. + byMin: newByMin[*pqEnvelope](capacity+1), + byMax: newByMax[*pqEnvelope](capacity+1), + } +} + +func (i *inner) Len() int { return i.byMin.Len() } - ctx, cancel := context.WithCancel(ctx) - q := &simpleQueue{ - input: make(chan Envelope, size*2), - output: make(chan Envelope, size/2), - maxSize: size * size, - closeCh: ctx.Done(), - closeFn: cancel, +func (i *inner) Push(e *pqEnvelope) { + w := newWithIdx(e) + heap.Push(&i.byMin,w) + heap.Push(&i.byMax,w) + if i.byMin.Len()>i.capacity { + w := heap.Pop(&i.byMax).(*withIdx[*pqEnvelope]) + heap.Remove(&i.byMin,w.minIdx) } +} - go q.run(ctx) - return q +func (i *inner) Pop() *pqEnvelope { + w := heap.Pop(&i.byMin).(*withIdx[*pqEnvelope]) + heap.Remove(&i.byMax,w.maxIdx) + return w.v } -func (q *simpleQueue) enqueue() chan<- Envelope { return q.input } -func (q *simpleQueue) dequeue() <-chan Envelope { return q.output } -func (q *simpleQueue) close() { q.closeFn() } -func (q *simpleQueue) closed() <-chan struct{} { return q.closeCh } +type simpleQueue struct { inner utils.Watch[*inner] } -func (q *simpleQueue) run(ctx context.Context) { - defer q.closeFn() +func newSimplePriorityQueue(size int) *simpleQueue { + return &simpleQueue{inner: utils.NewWatch(newInner(size))} +} - var chPriorities = make(map[ChannelID]uint, len(q.chDescs)) - for _, chDesc := range q.chDescs { - chID := chDesc.ID - chPriorities[chID] = uint(chDesc.Priority) +// Non-blocking send. +func (q *simpleQueue) Send(e Envelope, priority uint) { + // We construct the pqEnvelope without holding the lock to avoid contention. 
+ pqe := &pqEnvelope{ + envelope: e, + size: proto.Size(e.Message), + priority: priority, + timestamp: time.Now().UTC(), + } + for inner,ctrl := range q.inner.Lock() { + inner.Push(pqe) + ctrl.Updated() } +} - pq := make(priorityQueue, 0, q.maxSize) - heap.Init(&pq) - ticker := time.NewTicker(10 * time.Millisecond) - defer ticker.Stop() - // must have a buffer of exactly one because both sides of - // this channel are used in this loop, and simply signals adds - // to the heap - signal := make(chan struct{}, 1) - for { - select { - case <-ctx.Done(): - return - case <-q.closeCh: - return - case e := <-q.input: - // enqueue the incoming Envelope - heap.Push(&pq, &pqEnvelope{ - envelope: e, - size: uint(proto.Size(e.Message)), - priority: chPriorities[e.ChannelID], - timestamp: time.Now().UTC(), - }) - - select { - case signal <- struct{}{}: - default: - if len(pq) > q.maxSize { - sort.Sort(pq) - pq = pq[:q.maxSize] - } - } - - case <-ticker.C: - if len(pq) > q.maxSize { - sort.Sort(pq) - pq = pq[:q.maxSize] - } - if len(pq) > 0 { - select { - case signal <- struct{}{}: - default: - } - } - case <-signal: - SEND: - for len(pq) > 0 { - select { - case <-ctx.Done(): - return - case <-q.closeCh: - return - case q.output <- heap.Pop(&pq).(*pqEnvelope).envelope: - continue SEND - default: - break SEND - } - } +// Blocking recv. +func (q *simpleQueue) Recv(ctx context.Context) (Envelope,error) { + for inner,ctrl := range q.inner.Lock() { + if err:=ctrl.WaitUntil(ctx,func() bool { return inner.Len()>0 }); err!=nil { + return Envelope{},err } + return inner.Pop().envelope,nil } + panic("unreachable") } From e58e0ee66a3e8400be496a8017a4b9e8e2ad8905 Mon Sep 17 00:00:00 2001 From: Grzegorz Prusak Date: Tue, 26 Aug 2025 15:56:26 +0200 Subject: [PATCH 13/41] almost done --- internal/p2p/channel.go | 18 +-- internal/p2p/pqueue.go | 235 ------------------------------------ internal/p2p/pqueue_test.go | 45 ------- internal/p2p/queue.go | 53 -------- internal/p2p/router.go | 153 ++++++++--------------- internal/p2p/rqueue.go | 12 +- node/node.go | 4 +- 7 files changed, 61 insertions(+), 459 deletions(-) delete mode 100644 internal/p2p/pqueue.go delete mode 100644 internal/p2p/pqueue_test.go delete mode 100644 internal/p2p/queue.go diff --git a/internal/p2p/channel.go b/internal/p2p/channel.go index 82d7a6b98..388d05e70 100644 --- a/internal/p2p/channel.go +++ b/internal/p2p/channel.go @@ -8,6 +8,7 @@ import ( "github.com/gogo/protobuf/proto" "github.com/tendermint/tendermint/types" + "github.com/tendermint/tendermint/libs/utils" ) // Envelope contains a message with sender/receiver routing info. @@ -60,7 +61,7 @@ func (pe PeerError) Unwrap() error { return pe.Err } // Each message is wrapped in an Envelope to specify its sender and receiver. type Channel struct { ID ChannelID - inCh <-chan Envelope // inbound messages (peers to reactors) + inCh *queue // inbound messages (peers to reactors) outCh chan<- Envelope // outbound messages (reactors to peers) errCh chan<- PeerError // peer error reporting @@ -69,7 +70,7 @@ type Channel struct { // NewChannel creates a new channel. It is primarily for internal and test // use, reactors should use Router.OpenChannel(). 
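A test-style sketch of the eviction behaviour of the bounded two-heap container introduced in rqueue.go above. It assumes the unexported names from that file (package p2p); the priorities are arbitrary, and the timestamps never matter here because the priorities are distinct:

func TestInnerEvictsLowestPriority(t *testing.T) {
	q := newInner(2) // capacity 2

	q.Push(&pqEnvelope{priority: 5, timestamp: time.Now()})
	q.Push(&pqEnvelope{priority: 1, timestamp: time.Now()})
	// Exceeding the capacity prunes via the byMax heap, i.e. the envelope
	// with the lowest priority (here: 1) is dropped.
	q.Push(&pqEnvelope{priority: 3, timestamp: time.Now()})

	if got := q.Pop().priority; got != 5 {
		t.Fatalf("expected priority 5 first, got %d", got)
	}
	if got := q.Pop().priority; got != 3 {
		t.Fatalf("expected priority 3 next, got %d", got)
	}
	if q.Len() != 0 {
		t.Fatalf("queue should be empty, has %d elements", q.Len())
	}
}
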
-func NewChannel(id ChannelID, inCh <-chan Envelope, outCh chan<- Envelope, errCh chan<- PeerError) *Channel { +func NewChannel(id ChannelID, inCh *queue, outCh chan<- Envelope, errCh chan<- PeerError) *Channel { return &Channel{ ID: id, inCh: inCh, @@ -128,16 +129,9 @@ type ChannelIterator struct { func iteratorWorker(ctx context.Context, ch *Channel, pipe chan Envelope) { for { - select { - case <-ctx.Done(): - return - case envelope := <-ch.inCh: - select { - case <-ctx.Done(): - return - case pipe <- envelope: - } - } + e,err:=ch.inCh.Recv(ctx) + if err!=nil { return } + if err:=utils.Send(ctx, pipe, e); err!=nil { return } } } diff --git a/internal/p2p/pqueue.go b/internal/p2p/pqueue.go deleted file mode 100644 index 095b414c5..000000000 --- a/internal/p2p/pqueue.go +++ /dev/null @@ -1,235 +0,0 @@ -package p2p - -import ( - "container/heap" - "context" - "sort" - "strconv" - "sync" - "time" - - "github.com/gogo/protobuf/proto" - - "github.com/tendermint/tendermint/libs/log" -) - -// Assert the priority queue scheduler implements the queue interface at -// compile-time. -var _ queue = (*pqScheduler)(nil) - -type pqScheduler struct { - logger log.Logger - metrics *Metrics - lc *metricsLabelCache - size uint - sizes map[uint]uint // cumulative priority sizes - pq *priorityQueue - chDescs []*ChannelDescriptor - capacity uint - chPriorities map[ChannelID]uint - - enqueueCh chan Envelope - dequeueCh chan Envelope - - closeFn func() - closeCh <-chan struct{} - done chan struct{} -} - -func newPQScheduler( - logger log.Logger, - m *Metrics, - lc *metricsLabelCache, - chDescs []*ChannelDescriptor, - enqueueBuf, dequeueBuf, capacity uint, -) *pqScheduler { - - // copy each ChannelDescriptor and sort them by ascending channel priority - chDescsCopy := make([]*ChannelDescriptor, len(chDescs)) - copy(chDescsCopy, chDescs) - sort.Slice(chDescsCopy, func(i, j int) bool { return chDescsCopy[i].Priority < chDescsCopy[j].Priority }) - - var ( - chPriorities = make(map[ChannelID]uint) - sizes = make(map[uint]uint) - ) - - for _, chDesc := range chDescsCopy { - chID := chDesc.ID - chPriorities[chID] = uint(chDesc.Priority) - sizes[uint(chDesc.Priority)] = 0 - } - - pq := make(priorityQueue, 0) - heap.Init(&pq) - - closeCh := make(chan struct{}) - once := &sync.Once{} - - return &pqScheduler{ - logger: logger.With("router", "scheduler"), - metrics: m, - lc: lc, - chDescs: chDescsCopy, - capacity: capacity, - chPriorities: chPriorities, - pq: &pq, - sizes: sizes, - enqueueCh: make(chan Envelope, enqueueBuf), - dequeueCh: make(chan Envelope, dequeueBuf), - closeFn: func() { once.Do(func() { close(closeCh) }) }, - closeCh: closeCh, - done: make(chan struct{}), - } -} - -// start starts non-blocking process that starts the priority queue scheduler. -func (s *pqScheduler) start(ctx context.Context) { go s.process(ctx) } -func (s *pqScheduler) enqueue() chan<- Envelope { return s.enqueueCh } -func (s *pqScheduler) dequeue() <-chan Envelope { return s.dequeueCh } -func (s *pqScheduler) close() { s.closeFn() } -func (s *pqScheduler) closed() <-chan struct{} { return s.done } - -// process starts a block process where we listen for Envelopes to enqueue. If -// there is sufficient capacity, it will be enqueued into the priority queue, -// otherwise, we attempt to dequeue enough elements from the priority queue to -// make room for the incoming Envelope by dropping lower priority elements. If -// there isn't sufficient capacity at lower priorities for the incoming Envelope, -// it is dropped. 
-// -// After we attempt to enqueue the incoming Envelope, if the priority queue is -// non-empty, we pop the top Envelope and send it on the dequeueCh. -func (s *pqScheduler) process(ctx context.Context) { - defer close(s.done) - - for { - select { - case e := <-s.enqueueCh: - chIDStr := strconv.Itoa(int(e.ChannelID)) - pqEnv := &pqEnvelope{ - envelope: e, - size: uint(proto.Size(e.Message)), - priority: s.chPriorities[e.ChannelID], - timestamp: time.Now().UTC(), - } - - // enqueue - - // Check if we have sufficient capacity to simply enqueue the incoming - // Envelope. - if s.size+pqEnv.size <= s.capacity { - s.metrics.PeerPendingSendBytes.With("peer_id", string(pqEnv.envelope.To)).Add(float64(pqEnv.size)) - // enqueue the incoming Envelope - s.push(pqEnv) - } else { - // There is not sufficient capacity to simply enqueue the incoming - // Envelope. So we have to attempt to make room for it by dropping lower - // priority Envelopes or drop the incoming Envelope otherwise. - - // The cumulative size of all enqueue envelopes at the incoming envelope's - // priority or lower. - total := s.sizes[pqEnv.priority] - - if total >= pqEnv.size { - // There is room for the incoming Envelope, so we drop as many lower - // priority Envelopes as we need to. - var ( - canEnqueue bool - tmpSize = s.size - i = s.pq.Len() - 1 - ) - - // Drop lower priority Envelopes until sufficient capacity exists for - // the incoming Envelope - for i >= 0 && !canEnqueue { - pqEnvTmp := s.pq.get(i) - - if pqEnvTmp.priority < pqEnv.priority { - if tmpSize+pqEnv.size <= s.capacity { - canEnqueue = true - } else { - pqEnvTmpChIDStr := strconv.Itoa(int(pqEnvTmp.envelope.ChannelID)) - s.metrics.PeerQueueDroppedMsgs.With("ch_id", pqEnvTmpChIDStr).Add(1) - s.logger.Debug( - "dropped envelope", - "ch_id", pqEnvTmpChIDStr, - "priority", pqEnvTmp.priority, - "msg_size", pqEnvTmp.size, - "capacity", s.capacity, - ) - - s.metrics.PeerPendingSendBytes.With("peer_id", string(pqEnvTmp.envelope.To)).Add(float64(-pqEnvTmp.size)) - - // dequeue/drop from the priority queue - heap.Remove(s.pq, pqEnvTmp.index) - - // update the size tracker - tmpSize -= pqEnvTmp.size - - // start from the end again - i = s.pq.Len() - 1 - } - } else { - i-- - } - } - - // enqueue the incoming Envelope - s.push(pqEnv) - } else { - // There is not sufficient capacity to drop lower priority Envelopes, - // so we drop the incoming Envelope. 
- s.metrics.PeerQueueDroppedMsgs.With("ch_id", chIDStr).Add(1) - s.logger.Debug( - "dropped envelope", - "ch_id", chIDStr, - "priority", pqEnv.priority, - "msg_size", pqEnv.size, - "capacity", s.capacity, - ) - } - } - - // dequeue - - for s.pq.Len() > 0 { - pqEnv = heap.Pop(s.pq).(*pqEnvelope) - s.size -= pqEnv.size - - // deduct the Envelope size from all the relevant cumulative sizes - for i := 0; i < len(s.chDescs) && pqEnv.priority <= uint(s.chDescs[i].Priority); i++ { - s.sizes[uint(s.chDescs[i].Priority)] -= pqEnv.size - } - - s.metrics.PeerSendBytesTotal.With( - "chID", chIDStr, - "peer_id", string(pqEnv.envelope.To), - "message_type", s.lc.ValueToMetricLabel(pqEnv.envelope.Message)).Add(float64(pqEnv.size)) - s.metrics.PeerPendingSendBytes.With( - "peer_id", string(pqEnv.envelope.To)).Add(float64(-pqEnv.size)) - select { - case s.dequeueCh <- pqEnv.envelope: - case <-s.closeCh: - return - } - } - case <-ctx.Done(): - return - case <-s.closeCh: - return - } - } -} - -func (s *pqScheduler) push(pqEnv *pqEnvelope) { - // enqueue the incoming Envelope - heap.Push(s.pq, pqEnv) - s.size += pqEnv.size - s.metrics.PeerQueueMsgSize.With("ch_id", strconv.Itoa(int(pqEnv.envelope.ChannelID))).Add(float64(pqEnv.size)) - - // Update the cumulative sizes by adding the Envelope's size to every - // priority less than or equal to it. - for i := 0; i < len(s.chDescs) && pqEnv.priority <= uint(s.chDescs[i].Priority); i++ { - s.sizes[uint(s.chDescs[i].Priority)] += pqEnv.size - } -} diff --git a/internal/p2p/pqueue_test.go b/internal/p2p/pqueue_test.go deleted file mode 100644 index 614954589..000000000 --- a/internal/p2p/pqueue_test.go +++ /dev/null @@ -1,45 +0,0 @@ -package p2p - -import ( - "testing" - "time" - - gogotypes "github.com/gogo/protobuf/types" - - "github.com/tendermint/tendermint/libs/log" -) - -type testMessage = gogotypes.StringValue - -func TestCloseWhileDequeueFull(t *testing.T) { - enqueueLength := 5 - chDescs := []*ChannelDescriptor{ - {ID: 0x01, Priority: 1}, - } - pqueue := newPQScheduler(log.NewNopLogger(), NopMetrics(), newMetricsLabelCache(), chDescs, uint(enqueueLength), 1, 120) - - for i := 0; i < enqueueLength; i++ { - pqueue.enqueue() <- Envelope{ - ChannelID: 0x01, - Message: &testMessage{Value: "foo"}, // 5 bytes - } - } - - ctx := t.Context() - - go pqueue.process(ctx) - - // sleep to allow context switch for process() to run - time.Sleep(10 * time.Millisecond) - doneCh := make(chan struct{}) - go func() { - pqueue.close() - close(doneCh) - }() - - select { - case <-doneCh: - case <-time.After(2 * time.Second): - t.Fatal("pqueue failed to close") - } -} diff --git a/internal/p2p/queue.go b/internal/p2p/queue.go deleted file mode 100644 index 2ce2f23fe..000000000 --- a/internal/p2p/queue.go +++ /dev/null @@ -1,53 +0,0 @@ -package p2p - -import ( - "sync" -) - -// default capacity for the size of a queue -const defaultCapacity uint = 16e6 // ~16MB - -// queue does QoS scheduling for Envelopes, enqueueing and dequeueing according -// to some policy. Queues are used at contention points, i.e.: -// -// - Receiving inbound messages to a single channel from all peers. -// - Sending outbound messages to a single peer from all channels. -type queue interface { - // enqueue returns a channel for submitting envelopes. - enqueue() chan<- Envelope - - // dequeue returns a channel ordered according to some queueing policy. - dequeue() <-chan Envelope - - // close closes the queue. 
After this call enqueue() will block, so the - // caller must select on closed() as well to avoid blocking forever. The - // enqueue() and dequeue() channels will not be closed. - close() - - // closed returns a channel that's closed when the scheduler is closed. - closed() <-chan struct{} -} - -// fifoQueue is a simple unbuffered lossless queue that passes messages through -// in the order they were received, and blocks until message is received. -type fifoQueue struct { - queueCh chan Envelope - closeFn func() - closeCh <-chan struct{} -} - -func newFIFOQueue(size int) queue { - closeCh := make(chan struct{}) - once := &sync.Once{} - - return &fifoQueue{ - queueCh: make(chan Envelope, size), - closeFn: func() { once.Do(func() { close(closeCh) }) }, - closeCh: closeCh, - } -} - -func (q *fifoQueue) enqueue() chan<- Envelope { return q.queueCh } -func (q *fifoQueue) dequeue() <-chan Envelope { return q.queueCh } -func (q *fifoQueue) close() { q.closeFn() } -func (q *fifoQueue) closed() <-chan struct{} { return q.closeCh } diff --git a/internal/p2p/router.go b/internal/p2p/router.go index 4eef81b5a..6a0496d19 100644 --- a/internal/p2p/router.go +++ b/internal/p2p/router.go @@ -16,6 +16,7 @@ import ( "github.com/tendermint/tendermint/crypto" "github.com/tendermint/tendermint/libs/log" + "github.com/tendermint/tendermint/libs/utils" "github.com/tendermint/tendermint/libs/service" "github.com/tendermint/tendermint/types" ) @@ -143,7 +144,7 @@ type Router struct { connTracker connectionTracker peerMtx sync.RWMutex - peerQueues map[types.NodeID]queue // outbound messages per peer for all channels + peerQueues map[types.NodeID]*queue // outbound messages per peer for all channels // the channels that the peer queue has open peerChannels map[types.NodeID]ChannelIDSet nodeInfoProducer func() *types.NodeInfo @@ -152,7 +153,7 @@ type Router struct { // channels on router start. This depends on whether we want to allow // dynamic channels in the future. 
channelMtx sync.RWMutex - channelQueues map[ChannelID]queue // inbound messages from all peers to a single channel + channelQueues map[ChannelID]*queue // inbound messages from all peers to a single channel channelMessages map[ChannelID]proto.Message chDescsToBeAdded []chDescAdderWithCallback @@ -199,9 +200,9 @@ func NewRouter( endpoint: endpoint, peerManager: peerManager, options: options, - channelQueues: map[ChannelID]queue{}, + channelQueues: map[ChannelID]*queue{}, channelMessages: map[ChannelID]proto.Message{}, - peerQueues: map[types.NodeID]queue{}, + peerQueues: map[types.NodeID]*queue{}, peerChannels: make(map[types.NodeID]ChannelIDSet), dynamicIDFilterer: dynamicIDFilterer, } @@ -234,10 +235,10 @@ func (r *Router) OpenChannel(ctx context.Context, chDesc *ChannelDescriptor) (*C messageType := chDesc.MessageType - queue := newSimplePriorityQueue(chDesc.RecvBufferCapacity, r.chDescs) + queue := newQueue(chDesc.RecvBufferCapacity) outCh := make(chan Envelope, chDesc.RecvBufferCapacity) errCh := make(chan PeerError, chDesc.RecvBufferCapacity) - channel := NewChannel(id, queue.dequeue(), outCh, errCh) + channel := NewChannel(id, queue, outCh, errCh) channel.name = chDesc.Name var wrapper Wrapper @@ -259,10 +260,9 @@ func (r *Router) OpenChannel(ctx context.Context, chDesc *ChannelDescriptor) (*C delete(r.channelQueues, id) delete(r.channelMessages, id) r.channelMtx.Unlock() - queue.close() }() - r.routeChannel(ctx, id, outCh, errCh, wrapper) + r.routeChannel(ctx, chDesc, outCh, errCh, wrapper) }() return channel, nil @@ -275,7 +275,7 @@ func (r *Router) OpenChannel(ctx context.Context, chDesc *ChannelDescriptor) (*C // for messages, see Wrapper for details. func (r *Router) routeChannel( ctx context.Context, - chID ChannelID, + chDesc *ChannelDescriptor, outCh <-chan Envelope, errCh <-chan PeerError, wrapper Wrapper, @@ -292,13 +292,13 @@ func (r *Router) routeChannel( // Mark the envelope with the channel ID to allow sendPeer() to pass // it on to Transport.SendMessage(). 
- envelope.ChannelID = chID + envelope.ChannelID = chDesc.ID // wrap the message in a wrapper message, if requested if wrapper != nil { msg := proto.Clone(wrapper) if err := msg.(Wrapper).Wrap(envelope.Message); err != nil { - r.logger.Error("failed to wrap message", "channel", chID, "err", err) + r.logger.Error("failed to wrap message", "channel", chDesc.ID, "err", err) continue } @@ -306,16 +306,16 @@ func (r *Router) routeChannel( } // collect peer queues to pass the message via - var queues []queue + var queues []*queue if envelope.Broadcast { r.peerMtx.RLock() - queues = make([]queue, 0, len(r.peerQueues)) + queues = make([]*queue, 0, len(r.peerQueues)) for nodeID, q := range r.peerQueues { peerChs := r.peerChannels[nodeID] // check whether the peer is receiving on that channel - if _, ok := peerChs[chID]; ok { + if _, ok := peerChs[chDesc.ID]; ok { queues = append(queues, q) } } @@ -330,12 +330,12 @@ func (r *Router) routeChannel( peerChs := r.peerChannels[envelope.To] // check whether the peer is receiving on that channel - _, contains = peerChs[chID] + _, contains = peerChs[chDesc.ID] } r.peerMtx.RUnlock() if !ok { - r.logger.Debug("dropping message for unconnected peer", "peer", envelope.To, "channel", chID) + r.logger.Debug("dropping message for unconnected peer", "peer", envelope.To, "channel", chDesc.ID) continue } @@ -347,23 +347,12 @@ func (r *Router) routeChannel( continue } - queues = []queue{q} + queues = []*queue{q} } // send message to peers for _, q := range queues { - start := time.Now().UTC() - - select { - case q.enqueue() <- envelope: - r.metrics.RouterPeerQueueSend.Observe(time.Since(start).Seconds()) - - case <-q.closed(): - r.logger.Debug("dropping message for unconnected peer", "peer", envelope.To, "channel", chID) - - case <-ctx.Done(): - return - } + q.Send(envelope, chDesc.Priority) } case peerError, ok := <-errCh: @@ -636,7 +625,7 @@ func (r *Router) connectPeer(ctx context.Context, address NodeAddress) { go r.routePeer(ctx, address.NodeID, conn, toChannelIDs(peerInfo.Channels)) } -func (r *Router) getOrMakeQueue(peerID types.NodeID, channels ChannelIDSet) queue { +func (r *Router) getOrMakeQueue(peerID types.NodeID, channels ChannelIDSet) *queue { r.peerMtx.Lock() defer r.peerMtx.Unlock() @@ -644,7 +633,7 @@ func (r *Router) getOrMakeQueue(peerID types.NodeID, channels ChannelIDSet) queu return peerQueue } - peerQueue := newSimplePriorityQueue(queueBufferDefault, r.chDescs) + peerQueue := newQueue(queueBufferDefault) r.peerQueues[peerID] = peerQueue r.peerChannels[peerID] = channels return peerQueue @@ -769,8 +758,6 @@ func (r *Router) routePeer(ctx context.Context, peerID types.NodeID, conn Connec delete(r.peerChannels, peerID) r.peerMtx.Unlock() - sendQueue.close() - r.peerManager.Disconnected(ctx, peerID) r.metrics.Peers.Add(-1) }() @@ -788,7 +775,7 @@ func (r *Router) routePeer(ctx context.Context, peerID types.NodeID, conn Connec go func() { select { - case errCh <- r.sendPeer(ctx, peerID, conn, sendQueue): + case errCh <- utils.IgnoreCancel(r.sendPeer(ctx, peerID, conn, sendQueue)): case <-ctx.Done(): } }() @@ -800,7 +787,6 @@ func (r *Router) routePeer(ctx context.Context, peerID types.NodeID, conn Connec } _ = conn.Close() - sendQueue.close() select { case <-ctx.Done(): @@ -859,57 +845,41 @@ func (r *Router) receivePeer(ctx context.Context, peerID types.NodeID, conn Conn } start := time.Now().UTC() - - select { - case queue.enqueue() <- Envelope{From: peerID, Message: msg, ChannelID: chID}: - r.metrics.PeerReceiveBytesTotal.With( - "chID", 
fmt.Sprint(chID), - "peer_id", string(peerID), - "message_type", r.lc.ValueToMetricLabel(msg)).Add(float64(proto.Size(msg))) - r.metrics.RouterChannelQueueSend.Observe(time.Since(start).Seconds()) - r.logger.Debug("received message", "peer", peerID, "message", msg) - - case <-queue.closed(): - r.logger.Debug("channel closed, dropping message", "peer", peerID, "channel", chID) - - case <-ctx.Done(): - return nil - } + // Priority is not used since all messages in this queue are from the same channel. + queue.Send(Envelope{From: peerID, Message: msg, ChannelID: chID},0) + r.metrics.PeerReceiveBytesTotal.With( + "chID", fmt.Sprint(chID), + "peer_id", string(peerID), + "message_type", r.lc.ValueToMetricLabel(msg)).Add(float64(proto.Size(msg))) + r.metrics.RouterChannelQueueSend.Observe(time.Since(start).Seconds()) + r.logger.Debug("received message", "peer", peerID, "message", msg) } } // sendPeer sends queued messages to a peer. -func (r *Router) sendPeer(ctx context.Context, peerID types.NodeID, conn Connection, peerQueue queue) error { +func (r *Router) sendPeer(ctx context.Context, peerID types.NodeID, conn Connection, peerQueue *queue) error { for { start := time.Now().UTC() + envelope,err := peerQueue.Recv(ctx) + if err!=nil { return err } + r.metrics.RouterPeerQueueRecv.Observe(time.Since(start).Seconds()) + if envelope.Message == nil { + r.logger.Error("dropping nil message", "peer", peerID) + continue + } - select { - case envelope := <-peerQueue.dequeue(): - r.metrics.RouterPeerQueueRecv.Observe(time.Since(start).Seconds()) - if envelope.Message == nil { - r.logger.Error("dropping nil message", "peer", peerID) - continue - } - - bz, err := proto.Marshal(envelope.Message) - if err != nil { - r.logger.Error("failed to marshal message", "peer", peerID, "err", err) - continue - } - - if err = conn.SendMessage(ctx, envelope.ChannelID, bz); err != nil { - r.logger.Error("failed to send message", "peer", peerID, "err", err) - return err - } - - r.logger.Debug("sent message", "peer", envelope.To, "message", envelope.Message) - - case <-peerQueue.closed(): - return nil + bz, err := proto.Marshal(envelope.Message) + if err != nil { + r.logger.Error("failed to marshal message", "peer", peerID, "err", err) + continue + } - case <-ctx.Done(): - return nil + if err = conn.SendMessage(ctx, envelope.ChannelID, bz); err != nil { + r.logger.Error("failed to send message", "peer", peerID, "err", err) + return err } + + r.logger.Debug("sent message", "peer", envelope.To, "message", envelope.Message) } } @@ -927,14 +897,7 @@ func (r *Router) evictPeers(ctx context.Context) { } r.logger.Info("evicting peer", "peer", peerID) - - r.peerMtx.RLock() - queue, ok := r.peerQueues[peerID] - r.peerMtx.RUnlock() - - if ok { - queue.close() - } + // TODO: we need to cancel the peer here. } } @@ -977,26 +940,6 @@ func (r *Router) OnStop() { if err := r.transport.Close(); err != nil { r.logger.Error("failed to close transport", "err", err) } - - // Collect all remaining queues, and wait for them to close. 
- queues := []queue{} - - r.channelMtx.RLock() - for _, q := range r.channelQueues { - queues = append(queues, q) - } - r.channelMtx.RUnlock() - - r.peerMtx.RLock() - for _, q := range r.peerQueues { - queues = append(queues, q) - } - r.peerMtx.RUnlock() - - for _, q := range queues { - q.close() - <-q.closed() - } } type ChannelIDSet map[ChannelID]struct{} diff --git a/internal/p2p/rqueue.go b/internal/p2p/rqueue.go index 5c83530ea..647bb3d28 100644 --- a/internal/p2p/rqueue.go +++ b/internal/p2p/rqueue.go @@ -71,7 +71,7 @@ func (x *byMax[T]) Pop() any { // into a priority queue used for Envelope scheduling. type pqEnvelope struct { envelope Envelope - priority uint + priority int size int timestamp time.Time } @@ -124,14 +124,14 @@ func (i *inner) Pop() *pqEnvelope { return w.v } -type simpleQueue struct { inner utils.Watch[*inner] } +type queue struct { inner utils.Watch[*inner] } -func newSimplePriorityQueue(size int) *simpleQueue { - return &simpleQueue{inner: utils.NewWatch(newInner(size))} +func newQueue(size int) *queue { + return &queue{inner: utils.NewWatch(newInner(size))} } // Non-blocking send. -func (q *simpleQueue) Send(e Envelope, priority uint) { +func (q *queue) Send(e Envelope, priority int) { // We construct the pqEnvelope without holding the lock to avoid contention. pqe := &pqEnvelope{ envelope: e, @@ -146,7 +146,7 @@ func (q *simpleQueue) Send(e Envelope, priority uint) { } // Blocking recv. -func (q *simpleQueue) Recv(ctx context.Context) (Envelope,error) { +func (q *queue) Recv(ctx context.Context) (Envelope,error) { for inner,ctrl := range q.inner.Lock() { if err:=ctrl.WaitUntil(ctx,func() bool { return inner.Len()>0 }); err!=nil { return Envelope{},err diff --git a/node/node.go b/node/node.go index d6e43987a..6ed7ef5c1 100644 --- a/node/node.go +++ b/node/node.go @@ -809,9 +809,7 @@ func LoadStateFromDBOrGenesisDocProvider(stateStore sm.Store, genDoc *types.Gene } func getRouterConfig(conf *config.Config, appClient abciclient.Client) p2p.RouterOptions { - opts := p2p.RouterOptions{ - QueueType: conf.P2P.QueueType, - } + opts := p2p.RouterOptions{} if conf.FilterPeers && appClient != nil { opts.FilterPeerByID = func(ctx context.Context, id types.NodeID) error { From 257fa17731cd85ff97b85480b15d9549570526e4 Mon Sep 17 00:00:00 2001 From: Grzegorz Prusak Date: Tue, 26 Aug 2025 17:00:21 +0200 Subject: [PATCH 14/41] binary compiles --- internal/p2p/router.go | 312 ++++++++++++++++------------------------- internal/p2p/rqueue.go | 3 +- libs/utils/mutex.go | 35 +++++ libs/utils/proto.go | 2 +- libs/utils/testonly.go | 2 +- 5 files changed, 157 insertions(+), 197 deletions(-) diff --git a/internal/p2p/router.go b/internal/p2p/router.go index 6a0496d19..59b691080 100644 --- a/internal/p2p/router.go +++ b/internal/p2p/router.go @@ -17,6 +17,7 @@ import ( "github.com/tendermint/tendermint/crypto" "github.com/tendermint/tendermint/libs/log" "github.com/tendermint/tendermint/libs/utils" + "github.com/tendermint/tendermint/libs/utils/scope" "github.com/tendermint/tendermint/libs/service" "github.com/tendermint/tendermint/types" ) @@ -89,6 +90,12 @@ func (o *RouterOptions) Validate() error { return nil } +type peerState struct { + cancel context.CancelFunc + queue *queue // outbound messages per peer for all channels + channels ChannelIDSet // the channels that the peer queue has open +} + // Router manages peer connections and routes messages between peers and reactor // channels. It takes a PeerManager for peer lifecycle management (e.g. 
which // peers to dial and when) and a set of Transports for connecting and @@ -143,10 +150,7 @@ type Router struct { endpoint *Endpoint connTracker connectionTracker - peerMtx sync.RWMutex - peerQueues map[types.NodeID]*queue // outbound messages per peer for all channels - // the channels that the peer queue has open - peerChannels map[types.NodeID]ChannelIDSet + peerStates utils.RWMutex[map[types.NodeID]*peerState] nodeInfoProducer func() *types.NodeInfo // FIXME: We don't strictly need to use a mutex for this if we seal the @@ -202,8 +206,7 @@ func NewRouter( options: options, channelQueues: map[ChannelID]*queue{}, channelMessages: map[ChannelID]proto.Message{}, - peerQueues: map[types.NodeID]*queue{}, - peerChannels: make(map[types.NodeID]ChannelIDSet), + peerStates: utils.NewRWMutex(map[types.NodeID]*peerState{}), dynamicIDFilterer: dynamicIDFilterer, } @@ -254,17 +257,29 @@ func (r *Router) OpenChannel(ctx context.Context, chDesc *ChannelDescriptor) (*C r.transport.AddChannelDescriptors([]*ChannelDescriptor{chDesc}) - go func() { - defer func() { - r.channelMtx.Lock() - delete(r.channelQueues, id) - delete(r.channelMessages, id) - r.channelMtx.Unlock() - }() - - r.routeChannel(ctx, chDesc, outCh, errCh, wrapper) - }() - + r.Spawn("channel", func(ctx context.Context) error { + return scope.Run(ctx, func(ctx context.Context, s scope.Scope) error { + s.Spawn(func() error { return r.routeChannel(ctx, chDesc, outCh, wrapper) }) + for { + peerError,err := utils.Recv(ctx,errCh) + if err!=nil { return err } + shouldEvict := peerError.Fatal || r.peerManager.HasMaxPeerCapacity() + r.logger.Error("peer error", + "peer", peerError.NodeID, + "err", peerError.Err, + "evicting", shouldEvict, + ) + if shouldEvict { + r.peerManager.Errored(peerError.NodeID, peerError.Err) + } else { + r.peerManager.processPeerEvent(ctx, PeerUpdate{ + NodeID: peerError.NodeID, + Status: PeerStatusBad, + }) + } + } + }) + }) return channel, nil } @@ -277,106 +292,63 @@ func (r *Router) routeChannel( ctx context.Context, chDesc *ChannelDescriptor, outCh <-chan Envelope, - errCh <-chan PeerError, wrapper Wrapper, -) { +) error { for { - select { - case envelope, ok := <-outCh: - if !ok { - return - } - if envelope.IsZero() { - continue - } - - // Mark the envelope with the channel ID to allow sendPeer() to pass - // it on to Transport.SendMessage(). - envelope.ChannelID = chDesc.ID + envelope,err := utils.Recv(ctx, outCh) + if err!=nil { return err } + if envelope.IsZero() { + continue + } - // wrap the message in a wrapper message, if requested - if wrapper != nil { - msg := proto.Clone(wrapper) - if err := msg.(Wrapper).Wrap(envelope.Message); err != nil { - r.logger.Error("failed to wrap message", "channel", chDesc.ID, "err", err) - continue - } + // Mark the envelope with the channel ID to allow sendPeer() to pass + // it on to Transport.SendMessage(). 
+ envelope.ChannelID = chDesc.ID - envelope.Message = msg + // wrap the message in a wrapper message, if requested + if wrapper != nil { + msg := proto.Clone(wrapper) + if err := msg.(Wrapper).Wrap(envelope.Message); err != nil { + r.logger.Error("failed to wrap message", "channel", chDesc.ID, "err", err) + continue } - // collect peer queues to pass the message via - var queues []*queue - if envelope.Broadcast { - r.peerMtx.RLock() - - queues = make([]*queue, 0, len(r.peerQueues)) - for nodeID, q := range r.peerQueues { - peerChs := r.peerChannels[nodeID] + envelope.Message = msg + } - // check whether the peer is receiving on that channel - if _, ok := peerChs[chDesc.ID]; ok { - queues = append(queues, q) + // collect peer queues to pass the message via + var queues []*queue + if envelope.Broadcast { + for states := range r.peerStates.RLock() { + queues = make([]*queue, 0, len(states)) + for _, s := range states { + if _, ok := s.channels[chDesc.ID]; ok { + queues = append(queues, s.queue) } } - - r.peerMtx.RUnlock() - } else { - r.peerMtx.RLock() - - q, ok := r.peerQueues[envelope.To] - contains := false - if ok { - peerChs := r.peerChannels[envelope.To] - - // check whether the peer is receiving on that channel - _, contains = peerChs[chDesc.ID] - } - r.peerMtx.RUnlock() - - if !ok { - r.logger.Debug("dropping message for unconnected peer", "peer", envelope.To, "channel", chDesc.ID) - continue - } - - if !contains { - // reactor tried to send a message across a channel that the - // peer doesn't have available. This is a known issue due to - // how peer subscriptions work: - // https://github.com/tendermint/tendermint/issues/6598 - continue - } - - queues = []*queue{q} } - - // send message to peers - for _, q := range queues { - q.Send(envelope, chDesc.Priority) + } else { + ok := false + var s *peerState + for states := range r.peerStates.RLock() { + s,ok = states[envelope.To] } - - case peerError, ok := <-errCh: if !ok { - return + r.logger.Debug("dropping message for unconnected peer", "peer", envelope.To, "channel", chDesc.ID) + continue } - - shouldEvict := peerError.Fatal || r.peerManager.HasMaxPeerCapacity() - r.logger.Error("peer error", - "peer", peerError.NodeID, - "err", peerError.Err, - "evicting", shouldEvict, - ) - if shouldEvict { - r.peerManager.Errored(peerError.NodeID, peerError.Err) - } else { - r.peerManager.processPeerEvent(ctx, PeerUpdate{ - NodeID: peerError.NodeID, - Status: PeerStatusBad, - }) + if _, contains := s.channels[chDesc.ID]; !contains { + // reactor tried to send a message across a channel that the + // peer doesn't have available. This is a known issue due to + // how peer subscriptions work: + // https://github.com/tendermint/tendermint/issues/6598 + continue } - - case <-ctx.Done(): - return + queues = []*queue{s.queue} + } + // send message to peers + for _, q := range queues { + q.Send(envelope, chDesc.Priority) } } } @@ -467,12 +439,11 @@ func (r *Router) acceptPeers(ctx context.Context, transport Transport) { } // Spawn a goroutine for the handshake, to avoid head-of-line blocking. 
- go r.openConnection(ctx, conn) - + r.Spawn("openConnection",func(ctx context.Context) error { return r.openConnection(ctx, conn) }) } } -func (r *Router) openConnection(ctx context.Context, conn Connection) { +func (r *Router) openConnection(ctx context.Context, conn Connection) error { defer conn.Close() defer r.connTracker.RemoveConn(conn.RemoteEndpoint().IP) @@ -481,7 +452,7 @@ func (r *Router) openConnection(ctx context.Context, conn Connection) { if err := r.filterPeersIP(ctx, incomingIP, re.Port); err != nil { r.logger.Debug("peer filtered by IP", "ip", incomingIP.String(), "err", err) - return + return nil } // FIXME: The peer manager may reject the peer during Accepted() @@ -500,16 +471,12 @@ func (r *Router) openConnection(ctx context.Context, conn Connection) { // message to make sure both ends have accepted the connection, such // that it can be coordinated with the peer manager. peerInfo, err := r.handshakePeer(ctx, conn, "") - switch { - case errors.Is(err, context.Canceled): - return - case err != nil: - r.logger.Error("peer handshake failed", "endpoint", conn, "err", err) - return + if err!=nil { + return fmt.Errorf("peer handshake failed: endpoint=%v: %w", conn, err) } if err := r.filterPeersID(ctx, peerInfo.NodeID); err != nil { r.logger.Debug("peer filtered by node ID", "node", peerInfo.NodeID, "err", err) - return + return nil } if err := r.runWithPeerMutex(func() error { return r.peerManager.Accepted(peerInfo.NodeID) }); err != nil { @@ -517,12 +484,9 @@ func (r *Router) openConnection(ctx context.Context, conn Connection) { if strings.Contains(err.Error(), "is already connected") { r.peerManager.Errored(peerInfo.NodeID, err) } - r.logger.Error("failed to accept connection", - "op", "incoming/accepted", "peer", peerInfo.NodeID, "err", err) - return + return fmt.Errorf("failed to accept connection: op=incoming/accepted, peer=%v: %w",peerInfo.NodeID,err) } - - r.routePeer(ctx, peerInfo.NodeID, conn, toChannelIDs(peerInfo.Channels)) + return r.routePeer(ctx, peerInfo.NodeID, conn, toChannelIDs(peerInfo.Channels)) } // dialPeers maintains outbound connections to peers by dialing them. @@ -621,22 +585,10 @@ func (r *Router) connectPeer(ctx context.Context, address NodeAddress) { return } - // routePeer (also) calls connection close - go r.routePeer(ctx, address.NodeID, conn, toChannelIDs(peerInfo.Channels)) -} - -func (r *Router) getOrMakeQueue(peerID types.NodeID, channels ChannelIDSet) *queue { - r.peerMtx.Lock() - defer r.peerMtx.Unlock() - - if peerQueue, ok := r.peerQueues[peerID]; ok { - return peerQueue - } - - peerQueue := newQueue(queueBufferDefault) - r.peerQueues[peerID] = peerQueue - r.peerChannels[peerID] = channels - return peerQueue + r.Spawn("routePeer",func(ctx context.Context) error { + defer conn.Close() + return r.routePeer(ctx, address.NodeID, conn, toChannelIDs(peerInfo.Channels)) + }) } // dialPeer connects to a peer by dialing it. @@ -739,76 +691,44 @@ func (r *Router) handshakePeer( } func (r *Router) runWithPeerMutex(fn func() error) error { - r.peerMtx.Lock() - defer r.peerMtx.Unlock() - return fn() + for range r.peerStates.Lock() { + return fn() + } + panic("unreachable") } // routePeer routes inbound and outbound messages between a peer and the reactor // channels. It will close the given connection and send queue when done, or if // they are closed elsewhere it will cause this method to shut down and return. 
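The routePeer rewrite below leans on libs/utils/scope, which is not part of this diff. From its call sites, scope.Run appears to behave like an errgroup with cancellation: the first failing child cancels the shared context and Run waits for all children. The analogy below uses golang.org/x/sync/errgroup as a stand-in; it is an assumption about the contract, not the actual implementation:

// Rough equivalent of:
//
//	scope.Run(ctx, func(ctx context.Context, s scope.Scope) error {
//		s.Spawn(func() error { return recv(ctx) })
//		s.Spawn(func() error { return send(ctx) })
//		return nil
//	})
func runSendRecv(ctx context.Context, recv, send func(context.Context) error) error {
	g, ctx := errgroup.WithContext(ctx)
	g.Go(func() error { return recv(ctx) }) // first error cancels the shared ctx
	g.Go(func() error { return send(ctx) })
	return g.Wait() // blocks until both goroutines have returned
}
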
-func (r *Router) routePeer(ctx context.Context, peerID types.NodeID, conn Connection, channels ChannelIDSet) { +func (r *Router) routePeer(ctx context.Context, peerID types.NodeID, conn Connection, channels ChannelIDSet) error { r.metrics.Peers.Add(1) r.peerManager.Ready(ctx, peerID, channels) - sendQueue := r.getOrMakeQueue(peerID, channels) - defer func() { - r.peerMtx.Lock() - delete(r.peerQueues, peerID) - delete(r.peerChannels, peerID) - r.peerMtx.Unlock() - - r.peerManager.Disconnected(ctx, peerID) - r.metrics.Peers.Add(-1) - }() - - r.logger.Debug("peer connected", "peer", peerID, "endpoint", conn) - - errCh := make(chan error, 2) - - go func() { - select { - case errCh <- r.receivePeer(ctx, peerID, conn): - case <-ctx.Done(): - } - }() - - go func() { - select { - case errCh <- utils.IgnoreCancel(r.sendPeer(ctx, peerID, conn, sendQueue)): - case <-ctx.Done(): - } - }() - - var err error - select { - case err = <-errCh: - case <-ctx.Done(): + ctx,cancel := context.WithCancel(ctx) + state := &peerState{ + cancel: cancel, + queue: newQueue(queueBufferDefault), + channels: channels, } - - _ = conn.Close() - - select { - case <-ctx.Done(): - case e := <-errCh: - // The first err was nil, so we update it with the second err, which may - // or may not be nil. - if err == nil { - err = e + for states := range r.peerStates.Lock() { + if old,ok := states[peerID]; ok { + old.cancel() } + states[peerID] = state } - - // if the context was canceled - if e := ctx.Err(); err == nil && e != nil { - err = e - } - - switch err { - case nil, io.EOF: - r.logger.Debug("peer disconnected", "peer", peerID, "endpoint", conn) - default: - r.logger.Error("peer failure", "peer", peerID, "endpoint", conn, "err", err) + r.logger.Debug("peer connected", "peer", peerID, "endpoint", conn) + err := scope.Run(ctx, func(ctx context.Context, s scope.Scope) error { + s.Spawn(func() error { return r.receivePeer(ctx, peerID, conn) }) + s.Spawn(func() error { return r.sendPeer(ctx, peerID, conn, state.queue) }) + return nil + }) + r.logger.Info("peer disconnected", "peer", peerID, "endpoint", conn, "err", err) + for states := range r.peerStates.Lock() { + delete(states, peerID) } + r.peerManager.Disconnected(ctx, peerID) + r.metrics.Peers.Add(-1) + return err } // receivePeer receives inbound messages from a peer, deserializes them and @@ -897,7 +817,11 @@ func (r *Router) evictPeers(ctx context.Context) { } r.logger.Info("evicting peer", "peer", peerID) - // TODO: we need to cancel the peer here. + for states := range r.peerStates.Lock() { + if s,ok := states[peerID]; ok { + s.cancel() + } + } } } diff --git a/internal/p2p/rqueue.go b/internal/p2p/rqueue.go index 647bb3d28..5b042381e 100644 --- a/internal/p2p/rqueue.go +++ b/internal/p2p/rqueue.go @@ -127,7 +127,8 @@ func (i *inner) Pop() *pqEnvelope { type queue struct { inner utils.Watch[*inner] } func newQueue(size int) *queue { - return &queue{inner: utils.NewWatch(newInner(size))} + // TODO(gprusak): this size*size looks ridiculous. Fix it. + return &queue{inner: utils.NewWatch(newInner(size*size))} } // Non-blocking send. diff --git a/libs/utils/mutex.go b/libs/utils/mutex.go index b6f4a9a58..0bd9bb722 100644 --- a/libs/utils/mutex.go +++ b/libs/utils/mutex.go @@ -33,6 +33,41 @@ func (m *Mutex[T]) Lock() iter.Seq[T] { } } +// Mutex guards access to object of type T. +type RWMutex[T any] struct { + mu sync.RWMutex + value T +} + +// NewMutex creates a new Mutex with given object. 
+func NewRWMutex[T any](value T) (m RWMutex[T]) { + m.value = value + // nolint:nakedret + return +} + +// Lock returns an iterator which locks the mutex and yields the guarded object. +// The mutex is unlocked when the iterator is done. +// If the mutex is nil, the iterator is a no-op. +func (m *RWMutex[T]) Lock() iter.Seq[T] { + return func(yield func(val T) bool) { + m.mu.Lock() + defer m.mu.Unlock() + _ = yield(m.value) + } +} + +// RLock returns an iterator which locks the mutex FOR READ and yields the guarded object. +// The mutex is unlocked when the iterator is done. +// If the mutex is nil, the iterator is a no-op. +func (m *RWMutex[T]) RLock() iter.Seq[T] { + return func(yield func(val T) bool) { + m.mu.RLock() + defer m.mu.RUnlock() + _ = yield(m.value) + } +} + // version of the value stored in an atomic watch. type version[T any] struct { updated chan struct{} diff --git a/libs/utils/proto.go b/libs/utils/proto.go index 5f5ad7a41..4593c9634 100644 --- a/libs/utils/proto.go +++ b/libs/utils/proto.go @@ -6,7 +6,7 @@ import ( "fmt" "sync" - "google.golang.org/protobuf/proto" + "github.com/gogo/protobuf/proto" ) // Hash is a SHA-256 hash. diff --git a/libs/utils/testonly.go b/libs/utils/testonly.go index afd6b8aa8..ca5884cb6 100644 --- a/libs/utils/testonly.go +++ b/libs/utils/testonly.go @@ -9,7 +9,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "google.golang.org/protobuf/proto" + "github.com/gogo/protobuf/proto" "google.golang.org/protobuf/testing/protocmp" ) From d42991dd5efe01e917ebe74cb03d5f9560b2d75e Mon Sep 17 00:00:00 2001 From: Grzegorz Prusak Date: Tue, 26 Aug 2025 17:14:24 +0200 Subject: [PATCH 15/41] tests wip --- internal/p2p/channel.go | 4 +-- internal/p2p/channel_test.go | 12 ++++---- internal/p2p/pex/reactor_test.go | 16 +++++----- internal/p2p/router.go | 18 ++++++------ internal/p2p/router_init_test.go | 50 -------------------------------- internal/p2p/rqueue.go | 12 ++++---- internal/p2p/rqueue_test.go | 32 +++++++------------- 7 files changed, 41 insertions(+), 103 deletions(-) diff --git a/internal/p2p/channel.go b/internal/p2p/channel.go index 388d05e70..7a096453f 100644 --- a/internal/p2p/channel.go +++ b/internal/p2p/channel.go @@ -61,7 +61,7 @@ func (pe PeerError) Unwrap() error { return pe.Err } // Each message is wrapped in an Envelope to specify its sender and receiver. type Channel struct { ID ChannelID - inCh *queue // inbound messages (peers to reactors) + inCh *Queue // inbound messages (peers to reactors) outCh chan<- Envelope // outbound messages (reactors to peers) errCh chan<- PeerError // peer error reporting @@ -70,7 +70,7 @@ type Channel struct { // NewChannel creates a new channel. It is primarily for internal and test // use, reactors should use Router.OpenChannel(). 
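A usage sketch of the range-based locking idiom that the utils.RWMutex above enables; the hitCounter type is invented for illustration and mirrors how router.go iterates over r.peerStates:

type hitCounter struct {
	mu utils.RWMutex[map[string]int]
}

func newHitCounter() *hitCounter {
	return &hitCounter{mu: utils.NewRWMutex(map[string]int{})}
}

func (c *hitCounter) Inc(key string) {
	// The body runs exactly once while the write lock is held; the lock is
	// released when the single-iteration loop finishes.
	for m := range c.mu.Lock() {
		m[key]++
	}
}

func (c *hitCounter) Get(key string) int {
	for m := range c.mu.RLock() {
		return m[key] // returning from the body still releases the lock
	}
	panic("unreachable") // the iterator always yields exactly once
}
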
-func NewChannel(id ChannelID, inCh *queue, outCh chan<- Envelope, errCh chan<- PeerError) *Channel { +func NewChannel(id ChannelID, inCh *Queue, outCh chan<- Envelope, errCh chan<- PeerError) *Channel { return &Channel{ ID: id, inCh: inCh, diff --git a/internal/p2p/channel_test.go b/internal/p2p/channel_test.go index 4bbe178ac..4b5164b1c 100644 --- a/internal/p2p/channel_test.go +++ b/internal/p2p/channel_test.go @@ -11,14 +11,14 @@ import ( ) type channelInternal struct { - In chan Envelope + In *Queue Out chan Envelope Error chan PeerError } func testChannel(size int) (*channelInternal, *Channel) { in := &channelInternal{ - In: make(chan Envelope, size), + In: NewQueue(size), Out: make(chan Envelope, size), Error: make(chan PeerError, size), } @@ -112,7 +112,7 @@ func TestChannel(t *testing.T) { Case: func(t *testing.T) { ctx := t.Context() ins, ch := testChannel(1) - ins.In <- Envelope{From: "kip", To: "merlin"} + ins.In.Send(Envelope{From: "kip", To: "merlin"},0) iter := ch.Receive(ctx) require.NotNil(t, iter) require.True(t, iter.Next(ctx)) @@ -157,7 +157,7 @@ func TestChannel(t *testing.T) { ctx := t.Context() ins, ch := testChannel(1) - ins.In <- Envelope{From: "kip", To: "merlin"} + ins.In.Send(Envelope{From: "kip", To: "merlin"},0) iter := ch.Receive(ctx) require.NotNil(t, iter) @@ -180,7 +180,7 @@ func TestChannel(t *testing.T) { ctx := t.Context() ins, ch := testChannel(1) - ins.In <- Envelope{From: "kip", To: "merlin"} + ins.In.Send(Envelope{From: "kip", To: "merlin"},0) iter := ch.Receive(ctx) require.NotNil(t, iter) @@ -204,7 +204,7 @@ func TestChannel(t *testing.T) { require.NotNil(t, iter) require.Nil(t, iter.Envelope()) - ins.In <- Envelope{From: "kip", To: "merlin"} + ins.In.Send(Envelope{From: "kip", To: "merlin"},0) require.NotNil(t, iter) require.True(t, iter.Next(ctx)) diff --git a/internal/p2p/pex/reactor_test.go b/internal/p2p/pex/reactor_test.go index 49e860130..19d1c155e 100644 --- a/internal/p2p/pex/reactor_test.go +++ b/internal/p2p/pex/reactor_test.go @@ -76,20 +76,20 @@ func TestReactorSendsRequestsTooOften(t *testing.T) { badNode := newNodeID(t, "b") - r.pexInCh <- p2p.Envelope{ + r.pexInCh.Send(p2p.Envelope{ From: badNode, Message: &p2pproto.PexRequest{}, - } + },0) resp := <-r.pexOutCh msg, ok := resp.Message.(*p2pproto.PexResponse) require.True(t, ok) require.Empty(t, msg.Addresses) - r.pexInCh <- p2p.Envelope{ + r.pexInCh.Send(p2p.Envelope{ From: badNode, Message: &p2pproto.PexRequest{}, - } + },0) peerErr := <-r.pexErrCh require.Error(t, peerErr.Err) @@ -170,12 +170,12 @@ func TestReactorErrorsOnReceivingTooManyPeers(t *testing.T) { if _, ok := req.Message.(*p2pproto.PexRequest); !ok { t.Fatal("expected v2 pex request") } - r.pexInCh <- p2p.Envelope{ + r.pexInCh.Send(p2p.Envelope{ From: peer.NodeID, Message: &p2pproto.PexResponse{ Addresses: addresses, }, - } + },0) case <-time.After(10 * time.Second): t.Fatal("pex failed to send a request within 10 seconds") @@ -266,7 +266,7 @@ func TestReactorWithNetworkGrowth(t *testing.T) { type singleTestReactor struct { reactor *pex.Reactor - pexInCh chan p2p.Envelope + pexInCh *p2p.Queue pexOutCh chan p2p.Envelope pexErrCh chan p2p.PeerError pexCh *p2p.Channel @@ -278,7 +278,7 @@ func setupSingle(ctx context.Context, t *testing.T) *singleTestReactor { t.Helper() nodeID := newNodeID(t, "a") chBuf := 2 - pexInCh := make(chan p2p.Envelope, chBuf) + pexInCh := p2p.NewQueue(chBuf) pexOutCh := make(chan p2p.Envelope, chBuf) pexErrCh := make(chan p2p.PeerError, chBuf) pexCh := p2p.NewChannel( diff --git 
a/internal/p2p/router.go b/internal/p2p/router.go index 59b691080..8c4824ad3 100644 --- a/internal/p2p/router.go +++ b/internal/p2p/router.go @@ -92,7 +92,7 @@ func (o *RouterOptions) Validate() error { type peerState struct { cancel context.CancelFunc - queue *queue // outbound messages per peer for all channels + queue *Queue // outbound messages per peer for all channels channels ChannelIDSet // the channels that the peer queue has open } @@ -157,7 +157,7 @@ type Router struct { // channels on router start. This depends on whether we want to allow // dynamic channels in the future. channelMtx sync.RWMutex - channelQueues map[ChannelID]*queue // inbound messages from all peers to a single channel + channelQueues map[ChannelID]*Queue // inbound messages from all peers to a single channel channelMessages map[ChannelID]proto.Message chDescsToBeAdded []chDescAdderWithCallback @@ -204,7 +204,7 @@ func NewRouter( endpoint: endpoint, peerManager: peerManager, options: options, - channelQueues: map[ChannelID]*queue{}, + channelQueues: map[ChannelID]*Queue{}, channelMessages: map[ChannelID]proto.Message{}, peerStates: utils.NewRWMutex(map[types.NodeID]*peerState{}), dynamicIDFilterer: dynamicIDFilterer, @@ -238,7 +238,7 @@ func (r *Router) OpenChannel(ctx context.Context, chDesc *ChannelDescriptor) (*C messageType := chDesc.MessageType - queue := newQueue(chDesc.RecvBufferCapacity) + queue := NewQueue(chDesc.RecvBufferCapacity) outCh := make(chan Envelope, chDesc.RecvBufferCapacity) errCh := make(chan PeerError, chDesc.RecvBufferCapacity) channel := NewChannel(id, queue, outCh, errCh) @@ -317,10 +317,10 @@ func (r *Router) routeChannel( } // collect peer queues to pass the message via - var queues []*queue + var queues []*Queue if envelope.Broadcast { for states := range r.peerStates.RLock() { - queues = make([]*queue, 0, len(states)) + queues = make([]*Queue, 0, len(states)) for _, s := range states { if _, ok := s.channels[chDesc.ID]; ok { queues = append(queues, s.queue) @@ -344,7 +344,7 @@ func (r *Router) routeChannel( // https://github.com/tendermint/tendermint/issues/6598 continue } - queues = []*queue{s.queue} + queues = []*Queue{s.queue} } // send message to peers for _, q := range queues { @@ -707,7 +707,7 @@ func (r *Router) routePeer(ctx context.Context, peerID types.NodeID, conn Connec ctx,cancel := context.WithCancel(ctx) state := &peerState{ cancel: cancel, - queue: newQueue(queueBufferDefault), + queue: NewQueue(queueBufferDefault), channels: channels, } for states := range r.peerStates.Lock() { @@ -777,7 +777,7 @@ func (r *Router) receivePeer(ctx context.Context, peerID types.NodeID, conn Conn } // sendPeer sends queued messages to a peer. 
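// Roughly, sendPeer now drains the per-peer Queue with a blocking Recv that
// only fails once ctx is canceled (sketch; marshaling, metrics and the actual
// write to conn are elided):
//
//    for {
//        envelope, err := peerQueue.Recv(ctx)
//        if err != nil {
//            return err
//        }
//        // ... marshal envelope.Message and write it to conn ...
//    }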
-func (r *Router) sendPeer(ctx context.Context, peerID types.NodeID, conn Connection, peerQueue *queue) error { +func (r *Router) sendPeer(ctx context.Context, peerID types.NodeID, conn Connection, peerQueue *Queue) error { for { start := time.Now().UTC() envelope,err := peerQueue.Recv(ctx) diff --git a/internal/p2p/router_init_test.go b/internal/p2p/router_init_test.go index 31d06338f..f2750b153 100644 --- a/internal/p2p/router_init_test.go +++ b/internal/p2p/router_init_test.go @@ -1,64 +1,14 @@ package p2p import ( - "os" "testing" "github.com/stretchr/testify/require" - - "github.com/tendermint/tendermint/libs/log" - "github.com/tendermint/tendermint/types" ) func TestRouter_ConstructQueueFactory(t *testing.T) { - ctx := t.Context() - t.Run("ValidateOptionsPopulatesDefaultQueue", func(t *testing.T) { opts := RouterOptions{} require.NoError(t, opts.Validate()) - require.Equal(t, "fifo", opts.QueueType) - }) - t.Run("Default", func(t *testing.T) { - require.Zero(t, os.Getenv("TM_P2P_QUEUE")) - opts := RouterOptions{} - r, err := NewRouter(log.NewNopLogger(), nil, nil, nil, func() *types.NodeInfo { return &types.NodeInfo{} }, nil, nil, nil, opts) - require.NoError(t, err) - require.NoError(t, r.setupQueueFactory(ctx)) - - _, ok := r.queueFactory(1).(*fifoQueue) - require.True(t, ok) - }) - t.Run("Fifo", func(t *testing.T) { - opts := RouterOptions{QueueType: queueTypeFifo} - r, err := NewRouter(log.NewNopLogger(), nil, nil, nil, func() *types.NodeInfo { return &types.NodeInfo{} }, nil, nil, nil, opts) - require.NoError(t, err) - require.NoError(t, r.setupQueueFactory(ctx)) - - _, ok := r.queueFactory(1).(*fifoQueue) - require.True(t, ok) - }) - t.Run("Priority", func(t *testing.T) { - opts := RouterOptions{QueueType: queueTypePriority} - r, err := NewRouter(log.NewNopLogger(), nil, nil, nil, func() *types.NodeInfo { return &types.NodeInfo{} }, nil, nil, nil, opts) - require.NoError(t, err) - require.NoError(t, r.setupQueueFactory(ctx)) - - q, ok := r.queueFactory(1).(*pqScheduler) - require.True(t, ok) - defer q.close() - }) - t.Run("NonExistant", func(t *testing.T) { - opts := RouterOptions{QueueType: "fast"} - _, err := NewRouter(log.NewNopLogger(), nil, nil, nil, func() *types.NodeInfo { return &types.NodeInfo{} }, nil, nil, nil, opts) - require.Error(t, err) - require.Contains(t, err.Error(), "fast") - }) - t.Run("InternalsSafeWhenUnspecified", func(t *testing.T) { - r := &Router{} - require.Zero(t, r.options.QueueType) - - fn, err := r.createQueueFactory(ctx) - require.Error(t, err) - require.Nil(t, fn) }) } diff --git a/internal/p2p/rqueue.go b/internal/p2p/rqueue.go index 5b042381e..40f700735 100644 --- a/internal/p2p/rqueue.go +++ b/internal/p2p/rqueue.go @@ -68,7 +68,7 @@ func (x *byMax[T]) Pop() any { } // pqEnvelope defines a wrapper around an Envelope with priority to be inserted -// into a priority queue used for Envelope scheduling. +// into a priority Queue used for Envelope scheduling. type pqEnvelope struct { envelope Envelope priority int @@ -124,15 +124,15 @@ func (i *inner) Pop() *pqEnvelope { return w.v } -type queue struct { inner utils.Watch[*inner] } +type Queue struct { inner utils.Watch[*inner] } -func newQueue(size int) *queue { +func NewQueue(size int) *Queue { // TODO(gprusak): this size*size looks ridiculous. Fix it. - return &queue{inner: utils.NewWatch(newInner(size*size))} + return &Queue{inner: utils.NewWatch(newInner(size*size))} } // Non-blocking send. 
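// Send never blocks the caller: the envelope is inserted into the bounded
// priority structure and, once over capacity, excess envelopes are shed rather
// than back-pressuring the sender (the load-shedding behaviour TestSimpleQueue
// exercises). Rough usage, with chDesc, peerID and msg standing in for the
// caller's values:
//
//    q := NewQueue(chDesc.RecvBufferCapacity)
//    q.Send(Envelope{From: peerID, Message: msg}, chDesc.Priority)
//    envelope, err := q.Recv(ctx) // blocking counterpart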
-func (q *queue) Send(e Envelope, priority int) { +func (q *Queue) Send(e Envelope, priority int) { // We construct the pqEnvelope without holding the lock to avoid contention. pqe := &pqEnvelope{ envelope: e, @@ -147,7 +147,7 @@ func (q *queue) Send(e Envelope, priority int) { } // Blocking recv. -func (q *queue) Recv(ctx context.Context) (Envelope,error) { +func (q *Queue) Recv(ctx context.Context) (Envelope,error) { for inner,ctrl := range q.inner.Lock() { if err:=ctrl.WaitUntil(ctx,func() bool { return inner.Len()>0 }); err!=nil { return Envelope{},err diff --git a/internal/p2p/rqueue_test.go b/internal/p2p/rqueue_test.go index 4e5e4c8bf..47bd146e4 100644 --- a/internal/p2p/rqueue_test.go +++ b/internal/p2p/rqueue_test.go @@ -1,6 +1,7 @@ package p2p import ( + "context" "testing" "time" ) @@ -11,35 +12,22 @@ func TestSimpleQueue(t *testing.T) { // set up a small queue with very small buffers so we can // watch it shed load, then send a bunch of messages to the // queue, most of which we'll watch it drop. - sq := newSimplePriorityQueue(ctx, 1, nil) - for i := 0; i < 100; i++ { - sq.enqueue() <- Envelope{From: "merlin"} + sq := NewQueue(1) + for range 100 { + sq.Send(Envelope{From: "merlin"},0) } seen := 0 -RETRY: for seen <= 2 { - select { - case e := <-sq.dequeue(): - if e.From != "merlin" { - continue - } - seen++ - case <-time.After(10 * time.Millisecond): - break RETRY + ctx,cancel := context.WithTimeout(ctx,10 * time.Millisecond) + defer cancel() + if _,err:=sq.Recv(ctx); err!=nil { + break } } // if we don't see any messages, then it's just broken. - if seen == 0 { - t.Errorf("seen %d messages, should have seen more than one", seen) + if seen != 1 { + t.Errorf("seen %d messages, should have seen %v", seen, 1) } - // ensure that load shedding happens: there can be at most 3 - // messages that we get out of this, one that was buffered - // plus 2 that were under the cap, everything else gets - // dropped. 
- if seen > 3 { - t.Errorf("saw %d messages, should have seen 5 or fewer", seen) - } - } From 910fe80b51e4a1ca39ab0f52b1dee851dd4bd8de Mon Sep 17 00:00:00 2001 From: Grzegorz Prusak Date: Tue, 26 Aug 2025 18:44:41 +0200 Subject: [PATCH 16/41] before making OpenChannel private --- internal/p2p/router.go | 19 +++---- internal/p2p/router_test.go | 24 ++++---- internal/statesync/reactor_test.go | 91 ++++++++++++------------------ light/dispatcher_test.go | 10 ++-- 4 files changed, 61 insertions(+), 83 deletions(-) diff --git a/internal/p2p/router.go b/internal/p2p/router.go index 8c4824ad3..c75d3a4ef 100644 --- a/internal/p2p/router.go +++ b/internal/p2p/router.go @@ -479,7 +479,7 @@ func (r *Router) openConnection(ctx context.Context, conn Connection) error { return nil } - if err := r.runWithPeerMutex(func() error { return r.peerManager.Accepted(peerInfo.NodeID) }); err != nil { + if err := r.peerManager.Accepted(peerInfo.NodeID); err != nil { // If peer is trying to reconnect, error and let it reconnect if strings.Contains(err.Error(), "is already connected") { r.peerManager.Errored(peerInfo.NodeID, err) @@ -572,7 +572,7 @@ func (r *Router) connectPeer(ctx context.Context, address NodeAddress) { return } - if err := r.runWithPeerMutex(func() error { return r.peerManager.Dialed(address) }); err != nil { + if err := r.peerManager.Dialed(address); err != nil { // If peer is trying to reconnect, fail it and let it reconnect if strings.Contains(err.Error(), "is already connected") { r.logger.Error(fmt.Sprintf("Disconnecting %s because of %s", address.NodeID, err)) @@ -690,21 +690,13 @@ func (r *Router) handshakePeer( return peerInfo, nil } -func (r *Router) runWithPeerMutex(fn func() error) error { - for range r.peerStates.Lock() { - return fn() - } - panic("unreachable") -} - // routePeer routes inbound and outbound messages between a peer and the reactor // channels. It will close the given connection and send queue when done, or if // they are closed elsewhere it will cause this method to shut down and return. func (r *Router) routePeer(ctx context.Context, peerID types.NodeID, conn Connection, channels ChannelIDSet) error { r.metrics.Peers.Add(1) - r.peerManager.Ready(ctx, peerID, channels) - ctx,cancel := context.WithCancel(ctx) + peerCtx,cancel := context.WithCancel(ctx) state := &peerState{ cancel: cancel, queue: NewQueue(queueBufferDefault), @@ -717,9 +709,12 @@ func (r *Router) routePeer(ctx context.Context, peerID types.NodeID, conn Connec states[peerID] = state } r.logger.Debug("peer connected", "peer", peerID, "endpoint", conn) - err := scope.Run(ctx, func(ctx context.Context, s scope.Scope) error { + r.peerManager.Ready(ctx, peerID, channels) + err := scope.Run(peerCtx, func(ctx context.Context, s scope.Scope) error { s.Spawn(func() error { return r.receivePeer(ctx, peerID, conn) }) s.Spawn(func() error { return r.sendPeer(ctx, peerID, conn, state.queue) }) + <-ctx.Done() + _ = conn.Close() // TODO: conn.ReceiveMessage() (either in inmem or mconn) does not respect the context. Fix that. 
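// Closing conn here is what actually unblocks the receive/send goroutines
// spawned above once the peer context ends, since (per the TODO) the
// underlying ReceiveMessage does not yet watch the context; scope.Run is then
// expected to join both goroutines and surface the first error.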
return nil }) r.logger.Info("peer disconnected", "peer", peerID, "endpoint", conn, "err", err) diff --git a/internal/p2p/router_test.go b/internal/p2p/router_test.go index 0172bb114..05b0cce20 100644 --- a/internal/p2p/router_test.go +++ b/internal/p2p/router_test.go @@ -96,17 +96,18 @@ func TestRouter_Network(t *testing.T) { func TestRouter_Channel_Basic(t *testing.T) { t.Cleanup(leaktest.Check(t)) + logger,_ := log.NewDefaultLogger("plain","debug") ctx := t.Context() // Set up a router with no transports (so no peers). - peerManager, err := p2p.NewPeerManager(log.NewNopLogger(), selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{}, p2p.NopMetrics()) + peerManager, err := p2p.NewPeerManager(logger, selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{}, p2p.NopMetrics()) require.NoError(t, err) testnet := p2ptest.MakeNetwork(ctx, t, p2ptest.NetworkOptions{NumNodes: 1}) router, err := p2p.NewRouter( - log.NewNopLogger(), + logger, p2p.NopMetrics(), selfKey, peerManager, @@ -121,7 +122,7 @@ func TestRouter_Channel_Basic(t *testing.T) { require.NoError(t, router.Start(ctx)) t.Cleanup(router.Wait) - // Opening a channel should work. + t.Logf("Opening a channel should work.") chctx, chcancel := context.WithCancel(ctx) defer chcancel() @@ -129,29 +130,29 @@ func TestRouter_Channel_Basic(t *testing.T) { require.NoError(t, err) require.NotNil(t, channel) - // Opening the same channel again should fail. + t.Logf("Opening the same channel again should fail.") _, err = router.OpenChannel(ctx, chDesc) require.Error(t, err) - // Opening a different channel should work. + t.Logf("Opening a different channel should work.") chDesc2 := &p2p.ChannelDescriptor{ID: 2, MessageType: &p2ptest.Message{}} _, err = router.OpenChannel(ctx, chDesc2) require.NoError(t, err) - // Closing the channel, then opening it again should be fine. + t.Logf("Closing the channel, then opening it again should be fine.") chcancel() time.Sleep(200 * time.Millisecond) // yes yes, but Close() is async... channel, err = router.OpenChannel(ctx, chDesc) require.NoError(t, err) - // We should be able to send on the channel, even though there are no peers. + t.Logf("We should be able to send on the channel, even though there are no peers.") p2ptest.RequireSend(ctx, t, channel, p2p.Envelope{ To: types.NodeID(strings.Repeat("a", 40)), Message: &p2ptest.Message{Value: "foo"}, }) - // A message to ourselves should be dropped. + t.Logf("A message to ourselves should be dropped.") p2ptest.RequireSend(ctx, t, channel, p2p.Envelope{ To: selfID, Message: &p2ptest.Message{Value: "self"}, @@ -734,6 +735,7 @@ func TestRouter_DialPeers_Parallel(t *testing.T) { func TestRouter_EvictPeers(t *testing.T) { t.Cleanup(leaktest.Check(t)) + logger,_ := log.NewDefaultLogger("plain","debug") ctx := t.Context() @@ -761,13 +763,13 @@ func TestRouter_EvictPeers(t *testing.T) { mockTransport.On("Listen", mock.Anything).Return(nil) // Set up and start the router. 
- peerManager, err := p2p.NewPeerManager(log.NewNopLogger(), selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{}, p2p.NopMetrics()) + peerManager, err := p2p.NewPeerManager(logger, selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{}, p2p.NopMetrics()) require.NoError(t, err) sub := peerManager.Subscribe(ctx) router, err := p2p.NewRouter( - log.NewNopLogger(), + logger, p2p.NopMetrics(), selfKey, peerManager, @@ -785,7 +787,7 @@ func TestRouter_EvictPeers(t *testing.T) { NodeID: peerInfo.NodeID, Status: p2p.PeerStatusUp, }) - + t.Logf("node is up") peerManager.Errored(peerInfo.NodeID, errors.New("boom")) p2ptest.RequireUpdate(t, sub, p2p.PeerUpdate{ diff --git a/internal/statesync/reactor_test.go b/internal/statesync/reactor_test.go index a27fdac89..19edb0c2a 100644 --- a/internal/statesync/reactor_test.go +++ b/internal/statesync/reactor_test.go @@ -42,22 +42,22 @@ type reactorTestSuite struct { stateProvider *mocks.StateProvider snapshotChannel *p2p.Channel - snapshotInCh chan p2p.Envelope + snapshotInCh *p2p.Queue snapshotOutCh chan p2p.Envelope snapshotPeerErrCh chan p2p.PeerError chunkChannel *p2p.Channel - chunkInCh chan p2p.Envelope + chunkInCh *p2p.Queue chunkOutCh chan p2p.Envelope chunkPeerErrCh chan p2p.PeerError blockChannel *p2p.Channel - blockInCh chan p2p.Envelope + blockInCh *p2p.Queue blockOutCh chan p2p.Envelope blockPeerErrCh chan p2p.PeerError paramsChannel *p2p.Channel - paramsInCh chan p2p.Envelope + paramsInCh *p2p.Queue paramsOutCh chan p2p.Envelope paramsPeerErrCh chan p2p.PeerError @@ -73,7 +73,7 @@ func setup( t *testing.T, conn *clientmocks.Client, stateProvider *mocks.StateProvider, - chBuf uint, + chBuf int, ) *reactorTestSuite { t.Helper() @@ -82,16 +82,16 @@ func setup( } rts := &reactorTestSuite{ - snapshotInCh: make(chan p2p.Envelope, chBuf), + snapshotInCh: p2p.NewQueue(chBuf), snapshotOutCh: make(chan p2p.Envelope, chBuf), snapshotPeerErrCh: make(chan p2p.PeerError, chBuf), - chunkInCh: make(chan p2p.Envelope, chBuf), + chunkInCh: p2p.NewQueue(chBuf), chunkOutCh: make(chan p2p.Envelope, chBuf), chunkPeerErrCh: make(chan p2p.PeerError, chBuf), - blockInCh: make(chan p2p.Envelope, chBuf), + blockInCh: p2p.NewQueue(chBuf), blockOutCh: make(chan p2p.Envelope, chBuf), blockPeerErrCh: make(chan p2p.PeerError, chBuf), - paramsInCh: make(chan p2p.Envelope, chBuf), + paramsInCh: p2p.NewQueue(chBuf), paramsOutCh: make(chan p2p.Envelope, chBuf), paramsPeerErrCh: make(chan p2p.PeerError, chBuf), conn: conn, @@ -242,11 +242,11 @@ func TestReactor_ChunkRequest_InvalidRequest(t *testing.T) { rts := setup(ctx, t, nil, nil, 2) - rts.chunkInCh <- p2p.Envelope{ + rts.chunkInCh.Send(p2p.Envelope{ From: types.NodeID("aa"), ChannelID: ChunkChannel, Message: &ssproto.SnapshotsRequest{}, - } + },0) response := <-rts.chunkPeerErrCh require.Error(t, response.Err) @@ -297,11 +297,11 @@ func TestReactor_ChunkRequest(t *testing.T) { rts := setup(ctx, t, conn, nil, 2) - rts.chunkInCh <- p2p.Envelope{ + rts.chunkInCh.Send(p2p.Envelope{ From: types.NodeID("aa"), ChannelID: ChunkChannel, Message: tc.request, - } + },0) response := <-rts.chunkOutCh require.Equal(t, tc.expectResponse, response.Message) @@ -317,11 +317,11 @@ func TestReactor_SnapshotsRequest_InvalidRequest(t *testing.T) { rts := setup(ctx, t, nil, nil, 2) - rts.snapshotInCh <- p2p.Envelope{ + rts.snapshotInCh.Send(p2p.Envelope{ From: types.NodeID("aa"), ChannelID: SnapshotChannel, Message: &ssproto.ChunkRequest{}, - } + },0) response := <-rts.snapshotPeerErrCh require.Error(t, response.Err) @@ -377,11 +377,11 @@ func 
TestReactor_SnapshotsRequest(t *testing.T) { rts := setup(ctx, t, conn, nil, 100) - rts.snapshotInCh <- p2p.Envelope{ + rts.snapshotInCh.Send(p2p.Envelope{ From: types.NodeID("aa"), ChannelID: SnapshotChannel, Message: &ssproto.SnapshotsRequest{}, - } + },0) if len(tc.expectResponses) > 0 { retryUntil(ctx, t, func() bool { return len(rts.snapshotOutCh) == len(tc.expectResponses) }, time.Second) @@ -434,13 +434,13 @@ func TestReactor_LightBlockResponse(t *testing.T) { rts.stateStore.On("LoadValidators", height).Return(vals, nil) - rts.blockInCh <- p2p.Envelope{ + rts.blockInCh.Send(p2p.Envelope{ From: types.NodeID("aa"), ChannelID: LightBlockChannel, Message: &ssproto.LightBlockRequest{ Height: 10, }, - } + },0) require.Empty(t, rts.blockPeerErrCh) select { @@ -622,7 +622,6 @@ func TestReactor_Backfill(t *testing.T) { // test backfill algorithm with varying failure rates [0, 10] failureRates := []int{0, 2, 9} for _, failureRate := range failureRates { - failureRate := failureRate t.Run(fmt.Sprintf("failure rate: %d", failureRate), func(t *testing.T) { ctx := t.Context() t.Cleanup(leaktest.CheckTimeout(t, 1*time.Minute)) @@ -718,7 +717,7 @@ func handleLightBlockRequests( t *testing.T, chain map[int64]*types.LightBlock, receiving chan p2p.Envelope, - sending chan p2p.Envelope, + sending *p2p.Queue, close chan struct{}, failureRate int) { requests := 0 @@ -732,17 +731,13 @@ func handleLightBlockRequests( if requests%10 >= failureRate { lb, err := chain[int64(msg.Height)].ToProto() require.NoError(t, err) - select { - case sending <- p2p.Envelope{ + sending.Send(p2p.Envelope{ From: envelope.To, ChannelID: LightBlockChannel, Message: &ssproto.LightBlockResponse{ LightBlock: lb, }, - }: - case <-ctx.Done(): - return - } + },0) } else { switch errorCount % 3 { case 0: // send a different block @@ -750,29 +745,21 @@ func handleLightBlockRequests( _, _, lb := mockLB(ctx, t, int64(msg.Height), factory.DefaultTestTime, factory.MakeBlockID(), vals, pv) differntLB, err := lb.ToProto() require.NoError(t, err) - select { - case sending <- p2p.Envelope{ + sending.Send(p2p.Envelope{ From: envelope.To, ChannelID: LightBlockChannel, Message: &ssproto.LightBlockResponse{ LightBlock: differntLB, }, - }: - case <-ctx.Done(): - return - } + },0) case 1: // send nil block i.e. 
pretend we don't have it - select { - case sending <- p2p.Envelope{ + sending.Send(p2p.Envelope{ From: envelope.To, ChannelID: LightBlockChannel, Message: &ssproto.LightBlockResponse{ LightBlock: nil, }, - }: - case <-ctx.Done(): - return - } + },0) case 2: // don't do anything } errorCount++ @@ -788,7 +775,8 @@ func handleLightBlockRequests( func handleConsensusParamsRequest( ctx context.Context, t *testing.T, - receiving, sending chan p2p.Envelope, + receiving chan p2p.Envelope, + sending *p2p.Queue, closeCh chan struct{}, ) { t.Helper() @@ -804,21 +792,14 @@ func handleConsensusParamsRequest( t.Errorf("message was %T which is not a params request", envelope.Message) return } - select { - case sending <- p2p.Envelope{ + sending.Send(p2p.Envelope{ From: envelope.To, ChannelID: ParamsChannel, Message: &ssproto.ParamsResponse{ Height: msg.Height, ConsensusParams: paramsProto, }, - }: - case <-ctx.Done(): - return - case <-closeCh: - return - } - + },0) case <-closeCh: return } @@ -902,7 +883,7 @@ func handleSnapshotRequests( ctx context.Context, t *testing.T, receivingCh chan p2p.Envelope, - sendingCh chan p2p.Envelope, + sendingCh *p2p.Queue, closeCh chan struct{}, snapshots []snapshot, ) { @@ -917,7 +898,7 @@ func handleSnapshotRequests( _, ok := envelope.Message.(*ssproto.SnapshotsRequest) require.True(t, ok) for _, snapshot := range snapshots { - sendingCh <- p2p.Envelope{ + sendingCh.Send(p2p.Envelope{ From: envelope.To, ChannelID: SnapshotChannel, Message: &ssproto.SnapshotsResponse{ @@ -927,7 +908,7 @@ func handleSnapshotRequests( Hash: snapshot.Hash, Metadata: snapshot.Metadata, }, - } + },0) } } } @@ -937,7 +918,7 @@ func handleChunkRequests( ctx context.Context, t *testing.T, receivingCh chan p2p.Envelope, - sendingCh chan p2p.Envelope, + sendingCh *p2p.Queue, closeCh chan struct{}, chunk []byte, ) { @@ -951,7 +932,7 @@ func handleChunkRequests( case envelope := <-receivingCh: msg, ok := envelope.Message.(*ssproto.ChunkRequest) require.True(t, ok) - sendingCh <- p2p.Envelope{ + sendingCh.Send(p2p.Envelope{ From: envelope.To, ChannelID: ChunkChannel, Message: &ssproto.ChunkResponse{ @@ -961,7 +942,7 @@ func handleChunkRequests( Chunk: chunk, Missing: false, }, - } + },0) } } diff --git a/light/dispatcher_test.go b/light/dispatcher_test.go index 57fc042f8..f73c768dd 100644 --- a/light/dispatcher_test.go +++ b/light/dispatcher_test.go @@ -21,14 +21,14 @@ import ( ) type channelInternal struct { - In chan p2p.Envelope + In *p2p.Queue Out chan p2p.Envelope Error chan p2p.PeerError } func testChannel(size int) (*channelInternal, *p2p.Channel) { in := &channelInternal{ - In: make(chan p2p.Envelope, size), + In: p2p.NewQueue(size), Out: make(chan p2p.Envelope, size), Error: make(chan p2p.PeerError, size), } @@ -55,7 +55,7 @@ func TestDispatcherBasic(t *testing.T) { // make a bunch of async requests and require that the correct responses are // given - for i := 0; i < numPeers; i++ { + for i := range numPeers { wg.Add(1) go func(height int64) { defer wg.Done() @@ -175,7 +175,7 @@ func TestPeerListBasic(t *testing.T) { assert.Equal(t, numPeers, peerList.Len()) half := numPeers / 2 - for i := 0; i < half; i++ { + for i := range half { assert.Equal(t, peerSet[i], peerList.Pop(ctx)) } assert.Equal(t, half, peerList.Len()) @@ -330,7 +330,7 @@ func handleRequests(ctx context.Context, t *testing.T, d *Dispatcher, ch chan p2 func createPeerSet(num int) []types.NodeID { peers := make([]types.NodeID, num) - for i := 0; i < num; i++ { + for i := range num { peers[i], _ = 
types.NewNodeID(strings.Repeat(fmt.Sprintf("%d", i), 2*types.NodeIDByteLength)) } return peers From 69d692fa690b1cb4002ed62403816d82ab07dffc Mon Sep 17 00:00:00 2001 From: Grzegorz Prusak Date: Tue, 26 Aug 2025 19:26:39 +0200 Subject: [PATCH 17/41] made the channels non-cancellable instead --- internal/p2p/channel.go | 2 + internal/p2p/p2ptest/network.go | 16 ++-- internal/p2p/p2ptest/require.go | 137 +++++++++---------------------- internal/p2p/pex/reactor_test.go | 14 ++-- internal/p2p/router.go | 11 +-- internal/p2p/router_test.go | 124 +++++++++++++--------------- internal/p2p/rqueue.go | 7 ++ 7 files changed, 118 insertions(+), 193 deletions(-) diff --git a/internal/p2p/channel.go b/internal/p2p/channel.go index 7a096453f..cffff9191 100644 --- a/internal/p2p/channel.go +++ b/internal/p2p/channel.go @@ -103,6 +103,8 @@ func (ch *Channel) SendError(ctx context.Context, pe PeerError) error { func (ch *Channel) String() string { return fmt.Sprintf("p2p.Channel<%d:%s>", ch.ID, ch.name) } +func (ch *Channel) ReceiveLen() int { return ch.inCh.Len() } + // Receive returns a new unbuffered iterator to receive messages from ch. // The iterator runs until ctx ends. func (ch *Channel) Receive(ctx context.Context) *ChannelIterator { diff --git a/internal/p2p/p2ptest/network.go b/internal/p2p/p2ptest/network.go index 7069ac850..60a4311da 100644 --- a/internal/p2p/p2ptest/network.go +++ b/internal/p2p/p2ptest/network.go @@ -144,13 +144,12 @@ func (n *Network) NodeIDs() []types.NodeID { // MakeChannels makes a channel on all nodes and returns them, automatically // doing error checks and cleanups. func (n *Network) MakeChannels( - ctx context.Context, t *testing.T, chDesc *p2p.ChannelDescriptor, ) map[types.NodeID]*p2p.Channel { channels := map[types.NodeID]*p2p.Channel{} for _, node := range n.Nodes { - channels[node.NodeID] = node.MakeChannel(ctx, t, chDesc) + channels[node.NodeID] = node.MakeChannel(t, chDesc) } return channels } @@ -159,13 +158,12 @@ func (n *Network) MakeChannels( // automatically doing error checks. The caller must ensure proper cleanup of // all the channels. func (n *Network) MakeChannelsNoCleanup( - ctx context.Context, t *testing.T, chDesc *p2p.ChannelDescriptor, ) map[types.NodeID]*p2p.Channel { channels := map[types.NodeID]*p2p.Channel{} for _, node := range n.Nodes { - channels[node.NodeID] = node.MakeChannelNoCleanup(ctx, t, chDesc) + channels[node.NodeID] = node.MakeChannelNoCleanup(t, chDesc) } return channels } @@ -306,16 +304,13 @@ func (n *Network) MakeNode(ctx context.Context, t *testing.T, opts NodeOptions) // test cleanup, it also checks that the channel is empty, to make sure // all expected messages have been asserted. func (n *Node) MakeChannel( - ctx context.Context, t *testing.T, chDesc *p2p.ChannelDescriptor, ) *p2p.Channel { - ctx, cancel := context.WithCancel(ctx) - channel, err := n.Router.OpenChannel(ctx, chDesc) + channel, err := n.Router.OpenChannel(chDesc) require.NoError(t, err) t.Cleanup(func() { - RequireEmpty(ctx, t, channel) - cancel() + RequireEmpty(t, channel) }) return channel } @@ -323,11 +318,10 @@ func (n *Node) MakeChannel( // MakeChannelNoCleanup opens a channel, with automatic error handling. The // caller must ensure proper cleanup of the channel. 
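// Used where the test suite manages the channel lifetime itself, e.g. the pex
// tests (sketch):
//
//    r.pexChannels[nodeID] = node.MakeChannelNoCleanup(t, pex.ChannelDescriptor())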
func (n *Node) MakeChannelNoCleanup( - ctx context.Context, t *testing.T, chDesc *p2p.ChannelDescriptor, ) *p2p.Channel { - channel, err := n.Router.OpenChannel(ctx, chDesc) + channel, err := n.Router.OpenChannel(chDesc) require.NoError(t, err) return channel } diff --git a/internal/p2p/p2ptest/require.go b/internal/p2p/p2ptest/require.go index 885e080d4..ea8d98132 100644 --- a/internal/p2p/p2ptest/require.go +++ b/internal/p2p/p2ptest/require.go @@ -2,64 +2,40 @@ package p2ptest import ( "context" - "errors" "testing" - "time" "github.com/gogo/protobuf/proto" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/tendermint/tendermint/internal/p2p" "github.com/tendermint/tendermint/types" + "github.com/tendermint/tendermint/libs/utils" ) // RequireEmpty requires that the given channel is empty. -func RequireEmpty(ctx context.Context, t *testing.T, channels ...*p2p.Channel) { +func RequireEmpty(t *testing.T, channels ...*p2p.Channel) { t.Helper() - - ctx, cancel := context.WithTimeout(ctx, 10*time.Millisecond) - defer cancel() - - iter := p2p.MergedChannelIterator(ctx, channels...) - count := 0 - for iter.Next(ctx) { - count++ - require.Nil(t, iter.Envelope()) + for _,ch := range channels { + if ch.ReceiveLen() != 0 { + t.Errorf("nonempty channel %v", ch) + } } - require.Zero(t, count) - require.Error(t, ctx.Err()) } // RequireReceive requires that the given envelope is received on the channel. -func RequireReceive(ctx context.Context, t *testing.T, channel *p2p.Channel, expect p2p.Envelope) { +func RequireReceive(t *testing.T, channel *p2p.Channel, expect p2p.Envelope) { t.Helper() - - ctx, cancel := context.WithTimeout(ctx, time.Second) - defer cancel() - - iter := channel.Receive(ctx) - count := 0 - for iter.Next(ctx) { - count++ - envelope := iter.Envelope() - require.Equal(t, expect.From, envelope.From) - require.Equal(t, expect.Message, envelope.Message) - } - - if !assert.True(t, count >= 1) { - require.NoError(t, ctx.Err(), "timed out waiting for message %v", expect) - } + RequireReceiveUnordered(t, channel, []*p2p.Envelope{&expect}) } // RequireReceiveUnordered requires that the given envelopes are all received on // the channel, ignoring order. -func RequireReceiveUnordered(ctx context.Context, t *testing.T, channel *p2p.Channel, expect []*p2p.Envelope) { - ctx, cancel := context.WithTimeout(ctx, time.Second) - defer cancel() - +func RequireReceiveUnordered(t *testing.T, channel *p2p.Channel, expect []*p2p.Envelope) { + t.Helper() + t.Logf("awaiting %d messages", len(expect)) actual := []*p2p.Envelope{} - + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() iter := channel.Receive(ctx) for iter.Next(ctx) { actual = append(actual, iter.Envelope()) @@ -68,103 +44,68 @@ func RequireReceiveUnordered(ctx context.Context, t *testing.T, channel *p2p.Cha return } } - - if errors.Is(ctx.Err(), context.DeadlineExceeded) { - require.ElementsMatch(t, expect, actual) - } + require.FailNow(t,"not enough messages") } // RequireSend requires that the given envelope is sent on the channel. 
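// Typical round-trip assertion built from these helpers, as used in
// router_test.go (sketch):
//
//    p2ptest.RequireSend(t, a, p2p.Envelope{To: bID, Message: &p2ptest.Message{Value: "foo"}})
//    p2ptest.RequireReceive(t, b, p2p.Envelope{From: aID, Message: &p2ptest.Message{Value: "foo"}})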
-func RequireSend(ctx context.Context, t *testing.T, channel *p2p.Channel, envelope p2p.Envelope) { - tctx, cancel := context.WithTimeout(ctx, time.Second) - defer cancel() - - err := channel.Send(tctx, envelope) - switch { - case errors.Is(err, context.DeadlineExceeded): - require.Fail(t, "timed out sending message to %q", envelope.To) - default: - require.NoError(t, err, "unexpected error") - } +func RequireSend(t *testing.T, channel *p2p.Channel, envelope p2p.Envelope) { + t.Logf("sending message %v", envelope) + require.NoError(t,channel.Send(t.Context(), envelope)) } // RequireSendReceive requires that a given Protobuf message is sent to the // given peer, and then that the given response is received back. func RequireSendReceive( - ctx context.Context, t *testing.T, channel *p2p.Channel, peerID types.NodeID, send proto.Message, receive proto.Message, ) { - RequireSend(ctx, t, channel, p2p.Envelope{To: peerID, Message: send}) - RequireReceive(ctx, t, channel, p2p.Envelope{From: peerID, Message: send}) + RequireSend(t, channel, p2p.Envelope{To: peerID, Message: send}) + RequireReceive(t, channel, p2p.Envelope{From: peerID, Message: send}) } // RequireNoUpdates requires that a PeerUpdates subscription is empty. func RequireNoUpdates(ctx context.Context, t *testing.T, peerUpdates *p2p.PeerUpdates) { t.Helper() - select { - case update := <-peerUpdates.Updates(): - if ctx.Err() == nil { - require.Fail(t, "unexpected peer updates", "got %v", update) - } - case <-ctx.Done(): - default: + if len(peerUpdates.Updates())!=0 { + require.FailNow(t, "unexpected peer updates") } } // RequireError requires that the given peer error is submitted for a peer. -func RequireError(ctx context.Context, t *testing.T, channel *p2p.Channel, peerError p2p.PeerError) { - tctx, tcancel := context.WithTimeout(ctx, time.Second) - defer tcancel() - - err := channel.SendError(tctx, peerError) - switch { - case errors.Is(err, context.DeadlineExceeded): - require.Fail(t, "timed out reporting error", "%v for %q", peerError, channel.String()) - default: - require.NoError(t, err, "unexpected error") - } +func RequireSendError(t *testing.T, channel *p2p.Channel, peerError p2p.PeerError) { + require.NoError(t, channel.SendError(t.Context(), peerError)) } // RequireUpdate requires that a PeerUpdates subscription yields the given update. func RequireUpdate(t *testing.T, peerUpdates *p2p.PeerUpdates, expect p2p.PeerUpdate) { - timer := time.NewTimer(time.Second) // not time.After due to goroutine leaks - defer timer.Stop() - - select { - case update := <-peerUpdates.Updates(): - require.Equal(t, expect.NodeID, update.NodeID, "node id did not match") - require.Equal(t, expect.Status, update.Status, "statuses did not match") - case <-timer.C: - require.Fail(t, "timed out waiting for peer update", "expected %v", expect) + t.Logf("awaiting update %v", expect) + update,err := utils.Recv(t.Context(),peerUpdates.Updates()) + if err!=nil { + require.FailNow(t, "utils.Recv(): %w", err) } + require.Equal(t, expect.NodeID, update.NodeID, "node id did not match") + require.Equal(t, expect.Status, update.Status, "statuses did not match") } // RequireUpdates requires that a PeerUpdates subscription yields the given updates // in the given order. 
func RequireUpdates(t *testing.T, peerUpdates *p2p.PeerUpdates, expect []p2p.PeerUpdate) { - timer := time.NewTimer(time.Second) // not time.After due to goroutine leaks - defer timer.Stop() - + t.Logf("awaiting %d updates", len(expect)) actual := []p2p.PeerUpdate{} for { - select { - case update := <-peerUpdates.Updates(): - actual = append(actual, update) - if len(actual) == len(expect) { - for idx := range expect { - require.Equal(t, expect[idx].NodeID, actual[idx].NodeID) - require.Equal(t, expect[idx].Status, actual[idx].Status) - } - - return + update,err := utils.Recv(t.Context(),peerUpdates.Updates()) + if err!=nil { + require.FailNow(t,"utils.Recv(): %v",err) + } + actual = append(actual, update) + if len(actual) == len(expect) { + for idx := range expect { + require.Equal(t, expect[idx].NodeID, actual[idx].NodeID) + require.Equal(t, expect[idx].Status, actual[idx].Status) } - - case <-timer.C: - require.Equal(t, expect, actual, "did not receive expected peer updates") return } } diff --git a/internal/p2p/pex/reactor_test.go b/internal/p2p/pex/reactor_test.go index 19d1c155e..4b0c67093 100644 --- a/internal/p2p/pex/reactor_test.go +++ b/internal/p2p/pex/reactor_test.go @@ -132,7 +132,7 @@ func TestReactorNeverSendsTooManyPeers(t *testing.T) { testNet.addNodes(ctx, t, 110) nodes := make([]int, 110) - for i := 0; i < len(nodes); i++ { + for i := range nodes { nodes[i] = i + 2 } testNet.addAddresses(t, secondNode, nodes) @@ -375,14 +375,10 @@ func setupNetwork(ctx context.Context, t *testing.T, opts testOptions) *reactorT // NOTE: we don't assert that the channels get drained after stopping the // reactor - rts.pexChannels = rts.network.MakeChannelsNoCleanup(ctx, t, pex.ChannelDescriptor()) + rts.pexChannels = rts.network.MakeChannelsNoCleanup(t, pex.ChannelDescriptor()) idx := 0 for nodeID := range rts.network.Nodes { - // make a copy to avoid getting hit by the range ref - // confusion: - nodeID := nodeID - rts.peerChans[nodeID] = make(chan p2p.PeerUpdate, chBuf) rts.peerUpdates[nodeID] = p2p.NewPeerUpdates(rts.peerChans[nodeID], chBuf) rts.network.Nodes[nodeID].PeerManager.Register(ctx, rts.peerUpdates[nodeID]) @@ -433,7 +429,7 @@ func (r *reactorTestSuite) start(ctx context.Context, t *testing.T) { func (r *reactorTestSuite) addNodes(ctx context.Context, t *testing.T, nodes int) { t.Helper() - for i := 0; i < nodes; i++ { + for range nodes { node := r.network.MakeNode(ctx, t, p2ptest.NodeOptions{ MaxPeers: r.opts.MaxPeers, MaxConnected: r.opts.MaxConnected, @@ -441,7 +437,7 @@ func (r *reactorTestSuite) addNodes(ctx context.Context, t *testing.T, nodes int }) r.network.Nodes[node.NodeID] = node nodeID := node.NodeID - r.pexChannels[nodeID] = node.MakeChannelNoCleanup(ctx, t, pex.ChannelDescriptor()) + r.pexChannels[nodeID] = node.MakeChannelNoCleanup(t, pex.ChannelDescriptor()) r.peerChans[nodeID] = make(chan p2p.PeerUpdate, r.opts.BufferSize) r.peerUpdates[nodeID] = p2p.NewPeerUpdates(r.peerChans[nodeID], r.opts.BufferSize) r.network.Nodes[nodeID].PeerManager.Register(ctx, r.peerUpdates[nodeID]) @@ -642,7 +638,7 @@ func (r *reactorTestSuite) connectN(ctx context.Context, t *testing.T, n int) { } for i := 0; i < r.total; i++ { - for j := 0; j < n; j++ { + for j := range n { r.connectPeers(ctx, t, i, (i+j+1)%r.total) } } diff --git a/internal/p2p/router.go b/internal/p2p/router.go index c75d3a4ef..ecc72a468 100644 --- a/internal/p2p/router.go +++ b/internal/p2p/router.go @@ -220,13 +220,8 @@ func NewRouter( // kind shim for testing purposes. 
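// With OpenChannel no longer taking a context, a reactor opens its channel
// once and keeps it for the Router's lifetime (sketch; mypb.Message and handle
// are hypothetical stand-ins for the reactor's own message type and callback):
//
//    ch, err := router.OpenChannel(&ChannelDescriptor{ID: 1, MessageType: &mypb.Message{}})
//    if err != nil {
//        return err
//    }
//    iter := ch.Receive(ctx)
//    for iter.Next(ctx) {
//        handle(iter.Envelope())
//    }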
type ChannelCreator func(context.Context, *ChannelDescriptor) (*Channel, error) -// OpenChannel opens a new channel for the given message type. The caller must -// close the channel when done, before stopping the Router. messageType is the -// type of message passed through the channel (used for unmarshaling), which can -// implement Wrapper to automatically (un)wrap multiple message types in a -// wrapper message. The caller may provide a size to make the channel buffered, -// which internally makes the inbound, outbound, and error channel buffered. -func (r *Router) OpenChannel(ctx context.Context, chDesc *ChannelDescriptor) (*Channel, error) { +// OpenChannel opens a new channel for the given message type. +func (r *Router) OpenChannel(chDesc *ChannelDescriptor) (*Channel, error) { r.channelMtx.Lock() defer r.channelMtx.Unlock() @@ -834,7 +829,7 @@ func (r *Router) OnStart(ctx context.Context) error { } for _, chDescWithCb := range r.chDescsToBeAdded { - if ch, err := r.OpenChannel(ctx, chDescWithCb.chDesc); err != nil { + if ch, err := r.OpenChannel(chDescWithCb.chDesc); err != nil { return err } else { chDescWithCb.cb(ch) diff --git a/internal/p2p/router_test.go b/internal/p2p/router_test.go index 05b0cce20..970528034 100644 --- a/internal/p2p/router_test.go +++ b/internal/p2p/router_test.go @@ -45,11 +45,11 @@ func TestRouter_Network(t *testing.T) { t.Cleanup(leaktest.Check(t)) - // Create a test network and open a channel where all peers run echoReactor. + t.Logf("Create a test network and open a channel where all peers run echoReactor.") network := p2ptest.MakeNetwork(ctx, t, p2ptest.NetworkOptions{NumNodes: 8}) local := network.RandomNode() peers := network.Peers(local.NodeID) - channels := network.MakeChannels(ctx, t, chDesc) + channels := network.MakeChannels(t, chDesc) network.Start(ctx, t) @@ -58,16 +58,16 @@ func TestRouter_Network(t *testing.T) { go echoReactor(ctx, channels[peer.NodeID]) } - // Sending a message to each peer should work. + t.Logf("Sending a message to each peer should work.") for _, peer := range peers { - p2ptest.RequireSendReceive(ctx, t, channel, peer.NodeID, + p2ptest.RequireSendReceive(t, channel, peer.NodeID, &p2ptest.Message{Value: "foo"}, &p2ptest.Message{Value: "foo"}, ) } - // Sending a broadcast should return back a message from all peers. - p2ptest.RequireSend(ctx, t, channel, p2p.Envelope{ + t.Logf("Sending a broadcast should return back a message from all peers.") + p2ptest.RequireSend(t, channel, p2p.Envelope{ Broadcast: true, Message: &p2ptest.Message{Value: "bar"}, }) @@ -79,10 +79,10 @@ func TestRouter_Network(t *testing.T) { Message: &p2ptest.Message{Value: "bar"}, }) } - p2ptest.RequireReceiveUnordered(ctx, t, channel, expect) + p2ptest.RequireReceiveUnordered(t, channel, expect) - // We then submit an error for a peer, and watch it get disconnected and - // then reconnected as the router retries it. 
+ t.Logf("We then submit an error for a peer, and watch it get disconnected and") + t.Logf("then reconnected as the router retries it.") peerUpdates := local.MakePeerUpdatesNoRequireEmpty(ctx, t) require.NoError(t, channel.SendError(ctx, p2p.PeerError{ NodeID: peers[0].NodeID, @@ -123,41 +123,31 @@ func TestRouter_Channel_Basic(t *testing.T) { t.Cleanup(router.Wait) t.Logf("Opening a channel should work.") - chctx, chcancel := context.WithCancel(ctx) - defer chcancel() - - channel, err := router.OpenChannel(chctx, chDesc) + channel, err := router.OpenChannel(chDesc) require.NoError(t, err) require.NotNil(t, channel) t.Logf("Opening the same channel again should fail.") - _, err = router.OpenChannel(ctx, chDesc) + _, err = router.OpenChannel(chDesc) require.Error(t, err) t.Logf("Opening a different channel should work.") chDesc2 := &p2p.ChannelDescriptor{ID: 2, MessageType: &p2ptest.Message{}} - _, err = router.OpenChannel(ctx, chDesc2) - require.NoError(t, err) - - t.Logf("Closing the channel, then opening it again should be fine.") - chcancel() - time.Sleep(200 * time.Millisecond) // yes yes, but Close() is async... - - channel, err = router.OpenChannel(ctx, chDesc) + _, err = router.OpenChannel(chDesc2) require.NoError(t, err) t.Logf("We should be able to send on the channel, even though there are no peers.") - p2ptest.RequireSend(ctx, t, channel, p2p.Envelope{ + p2ptest.RequireSend(t, channel, p2p.Envelope{ To: types.NodeID(strings.Repeat("a", 40)), Message: &p2ptest.Message{Value: "foo"}, }) t.Logf("A message to ourselves should be dropped.") - p2ptest.RequireSend(ctx, t, channel, p2p.Envelope{ + p2ptest.RequireSend(t, channel, p2p.Envelope{ To: selfID, Message: &p2ptest.Message{Value: "self"}, }) - p2ptest.RequireEmpty(ctx, t, channel) + p2ptest.RequireEmpty(t, channel) } // Channel tests are hairy to mock, so we use an in-memory network instead. @@ -166,59 +156,59 @@ func TestRouter_Channel_SendReceive(t *testing.T) { t.Cleanup(leaktest.Check(t)) - // Create a test network and open a channel on all nodes. + t.Logf("Create a test network and open a channel on all nodes.") network := p2ptest.MakeNetwork(ctx, t, p2ptest.NetworkOptions{NumNodes: 3}) ids := network.NodeIDs() aID, bID, cID := ids[0], ids[1], ids[2] - channels := network.MakeChannels(ctx, t, chDesc) + channels := network.MakeChannels(t, chDesc) a, b, c := channels[aID], channels[bID], channels[cID] - otherChannels := network.MakeChannels(ctx, t, p2ptest.MakeChannelDesc(9)) + otherChannels := network.MakeChannels(t, p2ptest.MakeChannelDesc(9)) network.Start(ctx, t) // Sending a message a->b should work, and not send anything // further to a, b, or c. - p2ptest.RequireSend(ctx, t, a, p2p.Envelope{To: bID, Message: &p2ptest.Message{Value: "foo"}}) - p2ptest.RequireReceive(ctx, t, b, p2p.Envelope{From: aID, Message: &p2ptest.Message{Value: "foo"}}) - p2ptest.RequireEmpty(ctx, t, a, b, c) + p2ptest.RequireSend(t, a, p2p.Envelope{To: bID, Message: &p2ptest.Message{Value: "foo"}}) + p2ptest.RequireReceive(t, b, p2p.Envelope{From: aID, Message: &p2ptest.Message{Value: "foo"}}) + p2ptest.RequireEmpty(t, a, b, c) // Sending a nil message a->b should be dropped. - p2ptest.RequireSend(ctx, t, a, p2p.Envelope{To: bID, Message: nil}) - p2ptest.RequireEmpty(ctx, t, a, b, c) + p2ptest.RequireSend(t, a, p2p.Envelope{To: bID, Message: nil}) + p2ptest.RequireEmpty(t, a, b, c) // Sending a different message type should be dropped. 
- p2ptest.RequireSend(ctx, t, a, p2p.Envelope{To: bID, Message: &gogotypes.BoolValue{Value: true}}) - p2ptest.RequireEmpty(ctx, t, a, b, c) + p2ptest.RequireSend(t, a, p2p.Envelope{To: bID, Message: &gogotypes.BoolValue{Value: true}}) + p2ptest.RequireEmpty(t, a, b, c) // Sending to an unknown peer should be dropped. - p2ptest.RequireSend(ctx, t, a, p2p.Envelope{ + p2ptest.RequireSend(t, a, p2p.Envelope{ To: types.NodeID(strings.Repeat("a", 40)), Message: &p2ptest.Message{Value: "a"}, }) - p2ptest.RequireEmpty(ctx, t, a, b, c) + p2ptest.RequireEmpty(t, a, b, c) // Sending without a recipient should be dropped. - p2ptest.RequireSend(ctx, t, a, p2p.Envelope{Message: &p2ptest.Message{Value: "noto"}}) - p2ptest.RequireEmpty(ctx, t, a, b, c) + p2ptest.RequireSend(t, a, p2p.Envelope{Message: &p2ptest.Message{Value: "noto"}}) + p2ptest.RequireEmpty(t, a, b, c) // Sending to self should be dropped. - p2ptest.RequireSend(ctx, t, a, p2p.Envelope{To: aID, Message: &p2ptest.Message{Value: "self"}}) - p2ptest.RequireEmpty(ctx, t, a, b, c) + p2ptest.RequireSend(t, a, p2p.Envelope{To: aID, Message: &p2ptest.Message{Value: "self"}}) + p2ptest.RequireEmpty(t, a, b, c) // Removing b and sending to it should be dropped. network.Remove(ctx, t, bID) - p2ptest.RequireSend(ctx, t, a, p2p.Envelope{To: bID, Message: &p2ptest.Message{Value: "nob"}}) - p2ptest.RequireEmpty(ctx, t, a, b, c) + p2ptest.RequireSend(t, a, p2p.Envelope{To: bID, Message: &p2ptest.Message{Value: "nob"}}) + p2ptest.RequireEmpty(t, a, b, c) // After all this, sending a message c->a should work. - p2ptest.RequireSend(ctx, t, c, p2p.Envelope{To: aID, Message: &p2ptest.Message{Value: "bar"}}) - p2ptest.RequireReceive(ctx, t, a, p2p.Envelope{From: cID, Message: &p2ptest.Message{Value: "bar"}}) - p2ptest.RequireEmpty(ctx, t, a, b, c) + p2ptest.RequireSend(t, c, p2p.Envelope{To: aID, Message: &p2ptest.Message{Value: "bar"}}) + p2ptest.RequireReceive(t, a, p2p.Envelope{From: cID, Message: &p2ptest.Message{Value: "bar"}}) + p2ptest.RequireEmpty(t, a, b, c) // None of these messages should have made it onto the other channels. for _, other := range otherChannels { - p2ptest.RequireEmpty(ctx, t, other) + p2ptest.RequireEmpty(t, other) } } @@ -232,24 +222,24 @@ func TestRouter_Channel_Broadcast(t *testing.T) { ids := network.NodeIDs() aID, bID, cID, dID := ids[0], ids[1], ids[2], ids[3] - channels := network.MakeChannels(ctx, t, chDesc) + channels := network.MakeChannels(t, chDesc) a, b, c, d := channels[aID], channels[bID], channels[cID], channels[dID] network.Start(ctx, t) // Sending a broadcast from b should work. 
- p2ptest.RequireSend(ctx, t, b, p2p.Envelope{Broadcast: true, Message: &p2ptest.Message{Value: "foo"}}) - p2ptest.RequireReceive(ctx, t, a, p2p.Envelope{From: bID, Message: &p2ptest.Message{Value: "foo"}}) - p2ptest.RequireReceive(ctx, t, c, p2p.Envelope{From: bID, Message: &p2ptest.Message{Value: "foo"}}) - p2ptest.RequireReceive(ctx, t, d, p2p.Envelope{From: bID, Message: &p2ptest.Message{Value: "foo"}}) - p2ptest.RequireEmpty(ctx, t, a, b, c, d) + p2ptest.RequireSend(t, b, p2p.Envelope{Broadcast: true, Message: &p2ptest.Message{Value: "foo"}}) + p2ptest.RequireReceive(t, a, p2p.Envelope{From: bID, Message: &p2ptest.Message{Value: "foo"}}) + p2ptest.RequireReceive(t, c, p2p.Envelope{From: bID, Message: &p2ptest.Message{Value: "foo"}}) + p2ptest.RequireReceive(t, d, p2p.Envelope{From: bID, Message: &p2ptest.Message{Value: "foo"}}) + p2ptest.RequireEmpty(t, a, b, c, d) // Removing one node from the network shouldn't prevent broadcasts from working. network.Remove(ctx, t, dID) - p2ptest.RequireSend(ctx, t, a, p2p.Envelope{Broadcast: true, Message: &p2ptest.Message{Value: "bar"}}) - p2ptest.RequireReceive(ctx, t, b, p2p.Envelope{From: aID, Message: &p2ptest.Message{Value: "bar"}}) - p2ptest.RequireReceive(ctx, t, c, p2p.Envelope{From: aID, Message: &p2ptest.Message{Value: "bar"}}) - p2ptest.RequireEmpty(ctx, t, a, b, c, d) + p2ptest.RequireSend(t, a, p2p.Envelope{Broadcast: true, Message: &p2ptest.Message{Value: "bar"}}) + p2ptest.RequireReceive(t, b, p2p.Envelope{From: aID, Message: &p2ptest.Message{Value: "bar"}}) + p2ptest.RequireReceive(t, c, p2p.Envelope{From: aID, Message: &p2ptest.Message{Value: "bar"}}) + p2ptest.RequireEmpty(t, a, b, c, d) } func TestRouter_Channel_Wrapper(t *testing.T) { @@ -270,7 +260,7 @@ func TestRouter_Channel_Wrapper(t *testing.T) { RecvMessageCapacity: 10, } - channels := network.MakeChannels(ctx, t, chDesc) + channels := network.MakeChannels(t, chDesc) a, b := channels[aID], channels[bID] network.Start(ctx, t) @@ -278,20 +268,20 @@ func TestRouter_Channel_Wrapper(t *testing.T) { // Since wrapperMessage implements p2p.Wrapper and handles Message, it // should automatically wrap and unwrap sent messages -- we prepend the // wrapper actions to the message value to signal this. - p2ptest.RequireSend(ctx, t, a, p2p.Envelope{To: bID, Message: &p2ptest.Message{Value: "foo"}}) - p2ptest.RequireReceive(ctx, t, b, p2p.Envelope{From: aID, Message: &p2ptest.Message{Value: "unwrap:wrap:foo"}}) + p2ptest.RequireSend(t, a, p2p.Envelope{To: bID, Message: &p2ptest.Message{Value: "foo"}}) + p2ptest.RequireReceive(t, b, p2p.Envelope{From: aID, Message: &p2ptest.Message{Value: "unwrap:wrap:foo"}}) // If we send a different message that can't be wrapped, it should be dropped. - p2ptest.RequireSend(ctx, t, a, p2p.Envelope{To: bID, Message: &gogotypes.BoolValue{Value: true}}) - p2ptest.RequireEmpty(ctx, t, b) + p2ptest.RequireSend(t, a, p2p.Envelope{To: bID, Message: &gogotypes.BoolValue{Value: true}}) + p2ptest.RequireEmpty(t, b) // If we send the wrapper message itself, it should also be passed through // since WrapperMessage supports it, and should only be unwrapped at the receiver. 
- p2ptest.RequireSend(ctx, t, a, p2p.Envelope{ + p2ptest.RequireSend(t, a, p2p.Envelope{ To: bID, Message: &wrapperMessage{Message: p2ptest.Message{Value: "foo"}}, }) - p2ptest.RequireReceive(ctx, t, b, p2p.Envelope{ + p2ptest.RequireReceive(t, b, p2p.Envelope{ From: aID, Message: &p2ptest.Message{Value: "unwrap:foo"}, }) @@ -332,12 +322,12 @@ func TestRouter_Channel_Error(t *testing.T) { ids := network.NodeIDs() aID, bID := ids[0], ids[1] - channels := network.MakeChannels(ctx, t, chDesc) + channels := network.MakeChannels(t, chDesc) a := channels[aID] // Erroring b should cause it to be disconnected. It will reconnect shortly after. sub := network.Nodes[aID].MakePeerUpdates(ctx, t) - p2ptest.RequireError(ctx, t, a, p2p.PeerError{NodeID: bID, Err: errors.New("boom")}) + p2ptest.RequireSendError(t, a, p2p.PeerError{NodeID: bID, Err: errors.New("boom")}) p2ptest.RequireUpdates(t, sub, []p2p.PeerUpdate{ {NodeID: bID, Status: p2p.PeerStatusDown}, {NodeID: bID, Status: p2p.PeerStatusUp}, @@ -904,7 +894,7 @@ func TestRouter_DontSendOnInvalidChannel(t *testing.T) { Status: p2p.PeerStatusUp, }) - channel, err := router.OpenChannel(ctx, chDesc) + channel, err := router.OpenChannel(chDesc) require.NoError(t, err) require.NoError(t, channel.Send(ctx, p2p.Envelope{ diff --git a/internal/p2p/rqueue.go b/internal/p2p/rqueue.go index 40f700735..d63bf90f0 100644 --- a/internal/p2p/rqueue.go +++ b/internal/p2p/rqueue.go @@ -131,6 +131,13 @@ func NewQueue(size int) *Queue { return &Queue{inner: utils.NewWatch(newInner(size*size))} } +func (q *Queue) Len() int { + for inner := range q.inner.Lock() { + return inner.Len() + } + panic("unreachable") +} + // Non-blocking send. func (q *Queue) Send(e Envelope, priority int) { // We construct the pqEnvelope without holding the lock to avoid contention. From b68a5329e7b0f6da62030c76643e45921dec222d Mon Sep 17 00:00:00 2001 From: Grzegorz Prusak Date: Tue, 26 Aug 2025 22:51:19 +0200 Subject: [PATCH 18/41] p2p tests pass, need more tests for the queue --- internal/p2p/p2p_test.go | 1 + internal/p2p/p2ptest/network.go | 7 +-- internal/p2p/p2ptest/require.go | 17 +------ internal/p2p/peermanager.go | 7 ++- internal/p2p/peermanager_test.go | 13 ++--- internal/p2p/router.go | 24 ++++++--- internal/p2p/router_test.go | 83 +++++++++++++++++--------------- internal/p2p/rqueue.go | 3 ++ internal/p2p/rqueue_test.go | 18 ++----- 9 files changed, 88 insertions(+), 85 deletions(-) diff --git a/internal/p2p/p2p_test.go b/internal/p2p/p2p_test.go index d8657b774..1fcd46c94 100644 --- a/internal/p2p/p2p_test.go +++ b/internal/p2p/p2p_test.go @@ -17,6 +17,7 @@ var ( MessageType: &p2ptest.Message{}, Priority: 5, SendQueueCapacity: 10, + RecvBufferCapacity: 10, RecvMessageCapacity: 10, } diff --git a/internal/p2p/p2ptest/network.go b/internal/p2p/p2ptest/network.go index 60a4311da..48b363475 100644 --- a/internal/p2p/p2ptest/network.go +++ b/internal/p2p/p2ptest/network.go @@ -51,7 +51,7 @@ func (opts *NetworkOptions) setDefaults() { // connects them to each other. 
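// A typical test wiring with these helpers, as the router tests use it
// (sketch; chDesc is the descriptor defined in p2p_test.go):
//
//    network := p2ptest.MakeNetwork(ctx, t, p2ptest.NetworkOptions{NumNodes: 3})
//    channels := network.MakeChannels(t, chDesc)
//    network.Start(ctx, t)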
func MakeNetwork(ctx context.Context, t *testing.T, opts NetworkOptions) *Network { opts.setDefaults() - logger := log.NewNopLogger() + logger,_ := log.NewDefaultLogger("plain","info") network := &Network{ Nodes: map[types.NodeID]*Node{}, logger: logger, @@ -255,7 +255,8 @@ func (n *Network) MakeNode(ctx context.Context, t *testing.T, opts NodeOptions) maxRetryTime = opts.MaxRetryTime } - peerManager, err := p2p.NewPeerManager(n.logger, nodeID, dbm.NewMemDB(), p2p.PeerManagerOptions{ + logger := n.logger.With("node", nodeID[:5]) + peerManager, err := p2p.NewPeerManager(logger, nodeID, dbm.NewMemDB(), p2p.PeerManagerOptions{ MinRetryTime: 10 * time.Millisecond, MaxRetryTime: maxRetryTime, RetryTimeJitter: time.Millisecond, @@ -265,7 +266,7 @@ func (n *Network) MakeNode(ctx context.Context, t *testing.T, opts NodeOptions) require.NoError(t, err) router, err := p2p.NewRouter( - n.logger, + logger, p2p.NopMetrics(), privKey, peerManager, diff --git a/internal/p2p/p2ptest/require.go b/internal/p2p/p2ptest/require.go index ea8d98132..bbf78fc47 100644 --- a/internal/p2p/p2ptest/require.go +++ b/internal/p2p/p2ptest/require.go @@ -4,11 +4,9 @@ import ( "context" "testing" - "github.com/gogo/protobuf/proto" "github.com/stretchr/testify/require" "github.com/tendermint/tendermint/internal/p2p" - "github.com/tendermint/tendermint/types" "github.com/tendermint/tendermint/libs/utils" ) @@ -25,7 +23,7 @@ func RequireEmpty(t *testing.T, channels ...*p2p.Channel) { // RequireReceive requires that the given envelope is received on the channel. func RequireReceive(t *testing.T, channel *p2p.Channel, expect p2p.Envelope) { t.Helper() - RequireReceiveUnordered(t, channel, []*p2p.Envelope{&expect}) + RequireReceiveUnordered(t, channel, utils.Slice(&expect)) } // RequireReceiveUnordered requires that the given envelopes are all received on @@ -53,19 +51,6 @@ func RequireSend(t *testing.T, channel *p2p.Channel, envelope p2p.Envelope) { require.NoError(t,channel.Send(t.Context(), envelope)) } -// RequireSendReceive requires that a given Protobuf message is sent to the -// given peer, and then that the given response is received back. -func RequireSendReceive( - t *testing.T, - channel *p2p.Channel, - peerID types.NodeID, - send proto.Message, - receive proto.Message, -) { - RequireSend(t, channel, p2p.Envelope{To: peerID, Message: send}) - RequireReceive(t, channel, p2p.Envelope{From: peerID, Message: send}) -} - // RequireNoUpdates requires that a PeerUpdates subscription is empty. func RequireNoUpdates(ctx context.Context, t *testing.T, peerUpdates *p2p.PeerUpdates) { t.Helper() diff --git a/internal/p2p/peermanager.go b/internal/p2p/peermanager.go index f90e52521..907776fc7 100644 --- a/internal/p2p/peermanager.go +++ b/internal/p2p/peermanager.go @@ -1326,7 +1326,12 @@ func (s *peerStore) Ranked() []*peerInfo { sort.Slice(s.ranked, func(i, j int) bool { // FIXME: If necessary, consider precomputing scores before sorting, // to reduce the number of Score() calls. - return s.ranked[i].Score() > s.ranked[j].Score() + if a,b := s.ranked[i].Score(),s.ranked[j].Score(); a != b { + return a > b + } + // TODO(gprusak): we don't allow ties because tests require deterministic order. + // If not necessary in prod, then fix the tests instaed. 
+ return s.ranked[i].ID < s.ranked[j].ID }) for _, peer := range s.ranked { s.metrics.PeerScore.With("peer_id", string(peer.ID)).Set(float64(int(peer.Score()))) diff --git a/internal/p2p/peermanager_test.go b/internal/p2p/peermanager_test.go index 048affb54..d346376cd 100644 --- a/internal/p2p/peermanager_test.go +++ b/internal/p2p/peermanager_test.go @@ -834,7 +834,8 @@ func TestPeerManager_DialFailed_UnreservePeer(t *testing.T) { b := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("b", 40))} c := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("c", 40))} - peerManager, err := p2p.NewPeerManager(log.NewNopLogger(), selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{ + logger,_ := log.NewDefaultLogger("plain","debug") + peerManager, err := p2p.NewPeerManager(logger, selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{ PeerScores: map[types.NodeID]p2p.PeerScore{ a.NodeID: p2p.DefaultMutableScore - 1, // Set lower score for a to make it upgradeable b.NodeID: p2p.DefaultMutableScore + 1, // Higher score for b to attempt upgrade @@ -845,7 +846,7 @@ func TestPeerManager_DialFailed_UnreservePeer(t *testing.T) { }, p2p.NopMetrics()) require.NoError(t, err) - // Add and connect to peer a (lower scored) + t.Logf("Add and connect to peer a (lower scored)") added, err := peerManager.Add(a) require.NoError(t, err) require.True(t, added) @@ -854,7 +855,7 @@ func TestPeerManager_DialFailed_UnreservePeer(t *testing.T) { require.Equal(t, a, dial) require.NoError(t, peerManager.Dialed(a)) - // Add both higher scored peers b and c + t.Logf("Add both higher scored peers b and c") added, err = peerManager.Add(b) require.NoError(t, err) require.True(t, added) @@ -862,13 +863,13 @@ func TestPeerManager_DialFailed_UnreservePeer(t *testing.T) { require.NoError(t, err) require.True(t, added) - // Attempt to dial b for upgrade + t.Logf("Attempt to dial b for upgrade") dial, err = peerManager.TryDialNext() require.NoError(t, err) require.Equal(t, b, dial) - // When b's dial fails, the upgrade slot should be freed - // allowing c to attempt upgrade of the same peer (a) + t.Logf("When b's dial fails, the upgrade slot should be freed") + t.Logf("allowing c to attempt upgrade of the same peer (a)") require.NoError(t, peerManager.DialFailed(ctx, b)) dial, err = peerManager.TryDialNext() require.NoError(t, err) diff --git a/internal/p2p/router.go b/internal/p2p/router.go index ecc72a468..6bf6a561f 100644 --- a/internal/p2p/router.go +++ b/internal/p2p/router.go @@ -302,8 +302,8 @@ func (r *Router) routeChannel( // wrap the message in a wrapper message, if requested if wrapper != nil { - msg := proto.Clone(wrapper) - if err := msg.(Wrapper).Wrap(envelope.Message); err != nil { + msg := utils.ProtoClone(wrapper) + if err := msg.Wrap(envelope.Message); err != nil { r.logger.Error("failed to wrap message", "channel", chDesc.ID, "err", err) continue } @@ -474,7 +474,7 @@ func (r *Router) openConnection(ctx context.Context, conn Connection) error { return nil } - if err := r.peerManager.Accepted(peerInfo.NodeID); err != nil { + if err := r.runWithPeerMutex(func() error { return r.peerManager.Accepted(peerInfo.NodeID) }); err != nil { // If peer is trying to reconnect, error and let it reconnect if strings.Contains(err.Error(), "is already connected") { r.peerManager.Errored(peerInfo.NodeID, err) @@ -567,7 +567,7 @@ func (r *Router) connectPeer(ctx context.Context, address NodeAddress) { return } - if err := r.peerManager.Dialed(address); err != nil { + if err := r.runWithPeerMutex(func() 
error { return r.peerManager.Dialed(address) }); err != nil { // If peer is trying to reconnect, fail it and let it reconnect if strings.Contains(err.Error(), "is already connected") { r.logger.Error(fmt.Sprintf("Disconnecting %s because of %s", address.NodeID, err)) @@ -685,11 +685,18 @@ func (r *Router) handshakePeer( return peerInfo, nil } +func (r *Router) runWithPeerMutex(fn func() error) error { + for range r.peerStates.Lock() { + return fn() + } + panic("unreachable") +} // routePeer routes inbound and outbound messages between a peer and the reactor // channels. It will close the given connection and send queue when done, or if // they are closed elsewhere it will cause this method to shut down and return. func (r *Router) routePeer(ctx context.Context, peerID types.NodeID, conn Connection, channels ChannelIDSet) error { r.metrics.Peers.Add(1) + r.peerManager.Ready(ctx, peerID, channels) peerCtx,cancel := context.WithCancel(ctx) state := &peerState{ @@ -704,7 +711,6 @@ func (r *Router) routePeer(ctx context.Context, peerID types.NodeID, conn Connec states[peerID] = state } r.logger.Debug("peer connected", "peer", peerID, "endpoint", conn) - r.peerManager.Ready(ctx, peerID, channels) err := scope.Run(peerCtx, func(ctx context.Context, s scope.Scope) error { s.Spawn(func() error { return r.receivePeer(ctx, peerID, conn) }) s.Spawn(func() error { return r.sendPeer(ctx, peerID, conn, state.queue) }) @@ -714,10 +720,16 @@ func (r *Router) routePeer(ctx context.Context, peerID types.NodeID, conn Connec }) r.logger.Info("peer disconnected", "peer", peerID, "endpoint", conn, "err", err) for states := range r.peerStates.Lock() { - delete(states, peerID) + if states[peerID]==state { + delete(states, peerID) + } } + // TODO(gprusak): investigate if peerManager handles overlapping connetions correctly r.peerManager.Disconnected(ctx, peerID) r.metrics.Peers.Add(-1) + if errors.Is(err, io.EOF) { + return nil + } return err } diff --git a/internal/p2p/router_test.go b/internal/p2p/router_test.go index 970528034..0824577aa 100644 --- a/internal/p2p/router_test.go +++ b/internal/p2p/router_test.go @@ -10,6 +10,7 @@ import ( "sync" "testing" "time" + slog "log" "github.com/fortytw2/leaktest" "github.com/gogo/protobuf/proto" @@ -31,6 +32,7 @@ func echoReactor(ctx context.Context, channel *p2p.Channel) { for iter.Next(ctx) { envelope := iter.Envelope() value := envelope.Message.(*p2ptest.Message).Value + slog.Printf("sending back %v", value) if err := channel.Send(ctx, p2p.Envelope{ To: envelope.From, Message: &p2ptest.Message{Value: value}, @@ -38,6 +40,7 @@ func echoReactor(ctx context.Context, channel *p2p.Channel) { return } } + slog.Printf("echoReactor done") } func TestRouter_Network(t *testing.T) { @@ -60,10 +63,9 @@ func TestRouter_Network(t *testing.T) { t.Logf("Sending a message to each peer should work.") for _, peer := range peers { - p2ptest.RequireSendReceive(t, channel, peer.NodeID, - &p2ptest.Message{Value: "foo"}, - &p2ptest.Message{Value: "foo"}, - ) + msg := &p2ptest.Message{Value: "foo"} + p2ptest.RequireSend(t, channel, p2p.Envelope{To: peer.NodeID, Message: msg, ChannelID: chDesc.ID}) + p2ptest.RequireReceive(t, channel, p2p.Envelope{From: peer.NodeID, Message: msg, ChannelID: chDesc.ID}) } t.Logf("Sending a broadcast should return back a message from all peers.") @@ -167,46 +169,46 @@ func TestRouter_Channel_SendReceive(t *testing.T) { network.Start(ctx, t) - // Sending a message a->b should work, and not send anything - // further to a, b, or c. 
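// routePeer above drives its receive and send loops through scope.Run and
// s.Spawn. The libs/utils/scope package itself is not part of this series;
// judging from the call sites it behaves like errgroup with a derived
// context. A rough analogue of the assumed semantics (an assumption, not the
// real API; s.SpawnBg in the later queue test presumably marks background
// tasks cancelled once the main callback returns, which is omitted here):

import (
	"context"

	"golang.org/x/sync/errgroup"
)

// runScope calls fn with a cancellable context and a spawn helper. The first
// spawned task (or fn itself) that returns a non-nil error cancels the
// context, and runScope waits for all tasks before returning that error.
func runScope(ctx context.Context, fn func(ctx context.Context, spawn func(func() error)) error) error {
	g, ctx := errgroup.WithContext(ctx)
	g.Go(func() error { return fn(ctx, func(task func() error) { g.Go(task) }) })
	return g.Wait()
}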
- p2ptest.RequireSend(t, a, p2p.Envelope{To: bID, Message: &p2ptest.Message{Value: "foo"}}) - p2ptest.RequireReceive(t, b, p2p.Envelope{From: aID, Message: &p2ptest.Message{Value: "foo"}}) + t.Logf("Sending a message a->b should work, and not send anything further to a, b, or c.") + p2ptest.RequireSend(t, a, p2p.Envelope{To: bID, Message: &p2ptest.Message{Value: "foo"}, ChannelID: chDesc.ID}) + p2ptest.RequireReceive(t, b, p2p.Envelope{From: aID, Message: &p2ptest.Message{Value: "foo"}, ChannelID: chDesc.ID}) p2ptest.RequireEmpty(t, a, b, c) - // Sending a nil message a->b should be dropped. - p2ptest.RequireSend(t, a, p2p.Envelope{To: bID, Message: nil}) + t.Logf("Sending a nil message a->b should be dropped.") + p2ptest.RequireSend(t, a, p2p.Envelope{To: bID, Message: nil, ChannelID: chDesc.ID}) p2ptest.RequireEmpty(t, a, b, c) - // Sending a different message type should be dropped. - p2ptest.RequireSend(t, a, p2p.Envelope{To: bID, Message: &gogotypes.BoolValue{Value: true}}) + t.Logf("Sending a different message type should be dropped.") + p2ptest.RequireSend(t, a, p2p.Envelope{To: bID, Message: &gogotypes.BoolValue{Value: true}, ChannelID: chDesc.ID}) p2ptest.RequireEmpty(t, a, b, c) - // Sending to an unknown peer should be dropped. + t.Logf("Sending to an unknown peer should be dropped.") p2ptest.RequireSend(t, a, p2p.Envelope{ To: types.NodeID(strings.Repeat("a", 40)), Message: &p2ptest.Message{Value: "a"}, + ChannelID: chDesc.ID, }) p2ptest.RequireEmpty(t, a, b, c) - // Sending without a recipient should be dropped. - p2ptest.RequireSend(t, a, p2p.Envelope{Message: &p2ptest.Message{Value: "noto"}}) + t.Logf("Sending without a recipient should be dropped.") + p2ptest.RequireSend(t, a, p2p.Envelope{Message: &p2ptest.Message{Value: "noto"}, ChannelID: chDesc.ID}) p2ptest.RequireEmpty(t, a, b, c) - // Sending to self should be dropped. - p2ptest.RequireSend(t, a, p2p.Envelope{To: aID, Message: &p2ptest.Message{Value: "self"}}) + t.Logf("Sending to self should be dropped.") + p2ptest.RequireSend(t, a, p2p.Envelope{To: aID, Message: &p2ptest.Message{Value: "self"}, ChannelID: chDesc.ID}) p2ptest.RequireEmpty(t, a, b, c) - // Removing b and sending to it should be dropped. + t.Logf("Removing b and sending to it should be dropped.") network.Remove(ctx, t, bID) - p2ptest.RequireSend(t, a, p2p.Envelope{To: bID, Message: &p2ptest.Message{Value: "nob"}}) + p2ptest.RequireSend(t, a, p2p.Envelope{To: bID, Message: &p2ptest.Message{Value: "nob"}, ChannelID: chDesc.ID}) p2ptest.RequireEmpty(t, a, b, c) - // After all this, sending a message c->a should work. - p2ptest.RequireSend(t, c, p2p.Envelope{To: aID, Message: &p2ptest.Message{Value: "bar"}}) - p2ptest.RequireReceive(t, a, p2p.Envelope{From: cID, Message: &p2ptest.Message{Value: "bar"}}) + t.Logf("After all this, sending a message c->a should work.") + p2ptest.RequireSend(t, c, p2p.Envelope{To: aID, Message: &p2ptest.Message{Value: "bar"}, ChannelID: chDesc.ID}) + p2ptest.RequireReceive(t, a, p2p.Envelope{From: cID, Message: &p2ptest.Message{Value: "bar"}, ChannelID: chDesc.ID}) p2ptest.RequireEmpty(t, a, b, c) - // None of these messages should have made it onto the other channels. + t.Logf("None of these messages should have made it onto the other channels.") for _, other := range otherChannels { p2ptest.RequireEmpty(t, other) } @@ -217,7 +219,7 @@ func TestRouter_Channel_Broadcast(t *testing.T) { ctx := t.Context() - // Create a test network and open a channel on all nodes. 
+ t.Logf("Create a test network and open a channel on all nodes.") network := p2ptest.MakeNetwork(ctx, t, p2ptest.NetworkOptions{NumNodes: 4}) ids := network.NodeIDs() @@ -227,18 +229,18 @@ func TestRouter_Channel_Broadcast(t *testing.T) { network.Start(ctx, t) - // Sending a broadcast from b should work. - p2ptest.RequireSend(t, b, p2p.Envelope{Broadcast: true, Message: &p2ptest.Message{Value: "foo"}}) - p2ptest.RequireReceive(t, a, p2p.Envelope{From: bID, Message: &p2ptest.Message{Value: "foo"}}) - p2ptest.RequireReceive(t, c, p2p.Envelope{From: bID, Message: &p2ptest.Message{Value: "foo"}}) - p2ptest.RequireReceive(t, d, p2p.Envelope{From: bID, Message: &p2ptest.Message{Value: "foo"}}) + t.Logf("Sending a broadcast from b should work.") + p2ptest.RequireSend(t, b, p2p.Envelope{Broadcast: true, Message: &p2ptest.Message{Value: "foo"}, ChannelID: chDesc.ID}) + p2ptest.RequireReceive(t, a, p2p.Envelope{From: bID, Message: &p2ptest.Message{Value: "foo"}, ChannelID: chDesc.ID}) + p2ptest.RequireReceive(t, c, p2p.Envelope{From: bID, Message: &p2ptest.Message{Value: "foo"}, ChannelID: chDesc.ID}) + p2ptest.RequireReceive(t, d, p2p.Envelope{From: bID, Message: &p2ptest.Message{Value: "foo"}, ChannelID: chDesc.ID}) p2ptest.RequireEmpty(t, a, b, c, d) - // Removing one node from the network shouldn't prevent broadcasts from working. + t.Logf("Removing one node from the network shouldn't prevent broadcasts from working.") network.Remove(ctx, t, dID) - p2ptest.RequireSend(t, a, p2p.Envelope{Broadcast: true, Message: &p2ptest.Message{Value: "bar"}}) - p2ptest.RequireReceive(t, b, p2p.Envelope{From: aID, Message: &p2ptest.Message{Value: "bar"}}) - p2ptest.RequireReceive(t, c, p2p.Envelope{From: aID, Message: &p2ptest.Message{Value: "bar"}}) + p2ptest.RequireSend(t, a, p2p.Envelope{Broadcast: true, Message: &p2ptest.Message{Value: "bar"}, ChannelID: chDesc.ID}) + p2ptest.RequireReceive(t, b, p2p.Envelope{From: aID, Message: &p2ptest.Message{Value: "bar"}, ChannelID: chDesc.ID}) + p2ptest.RequireReceive(t, c, p2p.Envelope{From: aID, Message: &p2ptest.Message{Value: "bar"}, ChannelID: chDesc.ID}) p2ptest.RequireEmpty(t, a, b, c, d) } @@ -247,7 +249,7 @@ func TestRouter_Channel_Wrapper(t *testing.T) { ctx := t.Context() - // Create a test network and open a channel on all nodes. + t.Logf("Create a test network and open a channel on all nodes.") network := p2ptest.MakeNetwork(ctx, t, p2ptest.NetworkOptions{NumNodes: 2}) ids := network.NodeIDs() @@ -257,6 +259,7 @@ func TestRouter_Channel_Wrapper(t *testing.T) { MessageType: &wrapperMessage{}, Priority: 5, SendQueueCapacity: 10, + RecvBufferCapacity: 10, RecvMessageCapacity: 10, } @@ -268,11 +271,11 @@ func TestRouter_Channel_Wrapper(t *testing.T) { // Since wrapperMessage implements p2p.Wrapper and handles Message, it // should automatically wrap and unwrap sent messages -- we prepend the // wrapper actions to the message value to signal this. - p2ptest.RequireSend(t, a, p2p.Envelope{To: bID, Message: &p2ptest.Message{Value: "foo"}}) - p2ptest.RequireReceive(t, b, p2p.Envelope{From: aID, Message: &p2ptest.Message{Value: "unwrap:wrap:foo"}}) + p2ptest.RequireSend(t, a, p2p.Envelope{To: bID, Message: &p2ptest.Message{Value: "foo"}, ChannelID: chDesc.ID}) + p2ptest.RequireReceive(t, b, p2p.Envelope{From: aID, Message: &p2ptest.Message{Value: "unwrap:wrap:foo"}, ChannelID: chDesc.ID}) // If we send a different message that can't be wrapped, it should be dropped. 
- p2ptest.RequireSend(t, a, p2p.Envelope{To: bID, Message: &gogotypes.BoolValue{Value: true}}) + p2ptest.RequireSend(t, a, p2p.Envelope{To: bID, Message: &gogotypes.BoolValue{Value: true}, ChannelID: chDesc.ID}) p2ptest.RequireEmpty(t, b) // If we send the wrapper message itself, it should also be passed through @@ -280,10 +283,12 @@ func TestRouter_Channel_Wrapper(t *testing.T) { p2ptest.RequireSend(t, a, p2p.Envelope{ To: bID, Message: &wrapperMessage{Message: p2ptest.Message{Value: "foo"}}, + ChannelID: chDesc.ID, }) p2ptest.RequireReceive(t, b, p2p.Envelope{ From: aID, Message: &p2ptest.Message{Value: "unwrap:foo"}, + ChannelID: chDesc.ID, }) } @@ -316,7 +321,7 @@ func TestRouter_Channel_Error(t *testing.T) { ctx := t.Context() - // Create a test network and open a channel on all nodes. + t.Logf("Create a test network and open a channel on all nodes.") network := p2ptest.MakeNetwork(ctx, t, p2ptest.NetworkOptions{NumNodes: 3}) network.Start(ctx, t) @@ -325,7 +330,7 @@ func TestRouter_Channel_Error(t *testing.T) { channels := network.MakeChannels(t, chDesc) a := channels[aID] - // Erroring b should cause it to be disconnected. It will reconnect shortly after. + t.Logf("Erroring b should cause it to be disconnected. It will reconnect shortly after.") sub := network.Nodes[aID].MakePeerUpdates(ctx, t) p2ptest.RequireSendError(t, a, p2p.PeerError{NodeID: bID, Err: errors.New("boom")}) p2ptest.RequireUpdates(t, sub, []p2p.PeerUpdate{ diff --git a/internal/p2p/rqueue.go b/internal/p2p/rqueue.go index d63bf90f0..8af401839 100644 --- a/internal/p2p/rqueue.go +++ b/internal/p2p/rqueue.go @@ -127,6 +127,9 @@ func (i *inner) Pop() *pqEnvelope { type Queue struct { inner utils.Watch[*inner] } func NewQueue(size int) *Queue { + if size<=0 { + size = 1 + } // TODO(gprusak): this size*size looks ridiculous. Fix it. return &Queue{inner: utils.NewWatch(newInner(size*size))} } diff --git a/internal/p2p/rqueue_test.go b/internal/p2p/rqueue_test.go index 47bd146e4..8c01a13e9 100644 --- a/internal/p2p/rqueue_test.go +++ b/internal/p2p/rqueue_test.go @@ -1,9 +1,7 @@ package p2p import ( - "context" "testing" - "time" ) func TestSimpleQueue(t *testing.T) { @@ -16,18 +14,10 @@ func TestSimpleQueue(t *testing.T) { for range 100 { sq.Send(Envelope{From: "merlin"},0) } - - seen := 0 - - for seen <= 2 { - ctx,cancel := context.WithTimeout(ctx,10 * time.Millisecond) - defer cancel() - if _,err:=sq.Recv(ctx); err!=nil { - break - } + if _,err:=sq.Recv(ctx); err!=nil { + t.Fatal(err) } - // if we don't see any messages, then it's just broken. 
- if seen != 1 { - t.Errorf("seen %d messages, should have seen %v", seen, 1) + if sq.Len()!=0 { + t.Fatalf("queue length is %d, should be 0", sq.Len()) } } From f840a268fe5da8077d614cb27a98534185de247c Mon Sep 17 00:00:00 2001 From: Grzegorz Prusak Date: Tue, 26 Aug 2025 22:54:49 +0200 Subject: [PATCH 19/41] internal tests compile --- internal/blocksync/reactor_test.go | 12 +----------- internal/consensus/reactor_test.go | 8 ++++---- internal/evidence/reactor_test.go | 2 +- internal/mempool/reactor_test.go | 2 +- 4 files changed, 7 insertions(+), 17 deletions(-) diff --git a/internal/blocksync/reactor_test.go b/internal/blocksync/reactor_test.go index 1816ad925..f81dc69c7 100644 --- a/internal/blocksync/reactor_test.go +++ b/internal/blocksync/reactor_test.go @@ -73,7 +73,7 @@ func setup( } chDesc := &p2p.ChannelDescriptor{ID: BlockSyncChannel, MessageType: new(bcproto.Message)} - rts.blockSyncChannels = rts.network.MakeChannelsNoCleanup(ctx, t, chDesc) + rts.blockSyncChannels = rts.network.MakeChannelsNoCleanup(t, chDesc) i := 0 for nodeID := range rts.network.Nodes { @@ -101,10 +101,7 @@ func setup( func makeReactor( ctx context.Context, t *testing.T, - nodeID types.NodeID, genDoc *types.GenesisDoc, - privVal types.PrivValidator, - channelCreator p2p.ChannelCreator, peerEvents p2p.PeerEventSubscriber, peerManager *p2p.PeerManager, restartChan chan struct{}, @@ -188,10 +185,6 @@ func (rts *reactorTestSuite) addNode( rts.peerUpdates[nodeID] = p2p.NewPeerUpdates(rts.peerChans[nodeID], 1) rts.network.Nodes[nodeID].PeerManager.Register(ctx, rts.peerUpdates[nodeID]) - chCreator := func(ctx context.Context, chdesc *p2p.ChannelDescriptor) (*p2p.Channel, error) { - return rts.blockSyncChannels[nodeID], nil - } - peerEvents := func(ctx context.Context) *p2p.PeerUpdates { return rts.peerUpdates[nodeID] } restartChan := make(chan struct{}) remediationConfig := config.DefaultSelfRemediationConfig() @@ -200,10 +193,7 @@ func (rts *reactorTestSuite) addNode( reactor := makeReactor( ctx, t, - nodeID, genDoc, - privVal, - chCreator, peerEvents, rts.network.Nodes[nodeID].PeerManager, restartChan, diff --git a/internal/consensus/reactor_test.go b/internal/consensus/reactor_test.go index dc867a722..2ee657a41 100644 --- a/internal/consensus/reactor_test.go +++ b/internal/consensus/reactor_test.go @@ -78,10 +78,10 @@ func setup( blocksyncSubs: make(map[types.NodeID]eventbus.Subscription, numNodes), } - rts.stateChannels = rts.network.MakeChannelsNoCleanup(ctx, t, chDesc(StateChannel, size)) - rts.dataChannels = rts.network.MakeChannelsNoCleanup(ctx, t, chDesc(DataChannel, size)) - rts.voteChannels = rts.network.MakeChannelsNoCleanup(ctx, t, chDesc(VoteChannel, size)) - rts.voteSetBitsChannels = rts.network.MakeChannelsNoCleanup(ctx, t, chDesc(VoteSetBitsChannel, size)) + rts.stateChannels = rts.network.MakeChannelsNoCleanup(t, chDesc(StateChannel, size)) + rts.dataChannels = rts.network.MakeChannelsNoCleanup(t, chDesc(DataChannel, size)) + rts.voteChannels = rts.network.MakeChannelsNoCleanup(t, chDesc(VoteChannel, size)) + rts.voteSetBitsChannels = rts.network.MakeChannelsNoCleanup(t, chDesc(VoteSetBitsChannel, size)) ctx, cancel := context.WithCancel(ctx) t.Cleanup(cancel) diff --git a/internal/evidence/reactor_test.go b/internal/evidence/reactor_test.go index 3d898dc59..ac95bee16 100644 --- a/internal/evidence/reactor_test.go +++ b/internal/evidence/reactor_test.go @@ -64,7 +64,7 @@ func setup(ctx context.Context, t *testing.T, stateStores []sm.Store) *reactorTe } chDesc := &p2p.ChannelDescriptor{ID: 
evidence.EvidenceChannel, MessageType: new(tmproto.Evidence)} - rts.evidenceChannels = rts.network.MakeChannelsNoCleanup(ctx, t, chDesc) + rts.evidenceChannels = rts.network.MakeChannelsNoCleanup(t, chDesc) require.Len(t, rts.network.RandomNode().PeerManager.Peers(), 0) idx := 0 diff --git a/internal/mempool/reactor_test.go b/internal/mempool/reactor_test.go index d26f2bec3..3387bcd9c 100644 --- a/internal/mempool/reactor_test.go +++ b/internal/mempool/reactor_test.go @@ -59,7 +59,7 @@ func setupReactors(ctx context.Context, t *testing.T, logger log.Logger, numNode } chDesc := GetChannelDescriptor(cfg.Mempool) - rts.mempoolChannels = rts.network.MakeChannelsNoCleanup(ctx, t, chDesc) + rts.mempoolChannels = rts.network.MakeChannelsNoCleanup(t, chDesc) for nodeID := range rts.network.Nodes { rts.kvstores[nodeID] = kvstore.NewApplication() From 9d210f55cff10a3518ff393b4c2f32d432db3068 Mon Sep 17 00:00:00 2001 From: Grzegorz Prusak Date: Tue, 26 Aug 2025 23:15:58 +0200 Subject: [PATCH 20/41] compensated for stupid descriptors --- internal/consensus/reactor_test.go | 3 ++- internal/evidence/reactor_test.go | 6 +++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/internal/consensus/reactor_test.go b/internal/consensus/reactor_test.go index 2ee657a41..ca004283b 100644 --- a/internal/consensus/reactor_test.go +++ b/internal/consensus/reactor_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "math" "math/rand" "os" "sync" @@ -57,7 +58,7 @@ func chDesc(chID p2p.ChannelID, size int) *p2p.ChannelDescriptor { return &p2p.ChannelDescriptor{ ID: chID, MessageType: new(tmcons.Message), - RecvBufferCapacity: size, + RecvBufferCapacity: int(math.Sqrt(float64(size))+1), } } diff --git a/internal/evidence/reactor_test.go b/internal/evidence/reactor_test.go index ac95bee16..f19f8276b 100644 --- a/internal/evidence/reactor_test.go +++ b/internal/evidence/reactor_test.go @@ -63,7 +63,11 @@ func setup(ctx context.Context, t *testing.T, stateStores []sm.Store) *reactorTe peerChans: make(map[types.NodeID]chan p2p.PeerUpdate, numStateStores), } - chDesc := &p2p.ChannelDescriptor{ID: evidence.EvidenceChannel, MessageType: new(tmproto.Evidence)} + chDesc := &p2p.ChannelDescriptor{ + ID: evidence.EvidenceChannel, + MessageType: new(tmproto.Evidence), + RecvBufferCapacity: 10, + } rts.evidenceChannels = rts.network.MakeChannelsNoCleanup(t, chDesc) require.Len(t, rts.network.RandomNode().PeerManager.Peers(), 0) From 0fefcc846f7963a9a05bd93ec84e6f4d21d74606 Mon Sep 17 00:00:00 2001 From: Grzegorz Prusak Date: Wed, 27 Aug 2025 09:50:35 +0200 Subject: [PATCH 21/41] formatting --- internal/consensus/reactor_test.go | 2 +- internal/evidence/reactor_test.go | 4 +- internal/p2p/channel.go | 14 ++++-- internal/p2p/channel_test.go | 8 ++-- internal/p2p/p2ptest/network.go | 2 +- internal/p2p/p2ptest/require.go | 18 +++---- internal/p2p/peermanager.go | 2 +- internal/p2p/peermanager_test.go | 2 +- internal/p2p/pex/reactor_test.go | 6 +-- internal/p2p/router.go | 62 +++++++++++++++--------- internal/p2p/router_test.go | 20 ++++---- internal/p2p/rqueue.go | 76 +++++++++++++++--------------- internal/p2p/rqueue_test.go | 6 +-- internal/statesync/reactor_test.go | 22 ++++----- libs/utils/testonly.go | 2 +- 15 files changed, 135 insertions(+), 111 deletions(-) diff --git a/internal/consensus/reactor_test.go b/internal/consensus/reactor_test.go index ca004283b..c612d9bb0 100644 --- a/internal/consensus/reactor_test.go +++ b/internal/consensus/reactor_test.go @@ -58,7 +58,7 @@ func chDesc(chID p2p.ChannelID, size 
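// Why the square root above: at this point in the series the router still
// squares a channel's RecvBufferCapacity when sizing its queue (the
// size*size quirk flagged in the NewQueue/OpenChannel TODOs), so passing
// roughly sqrt(size) keeps the effective buffer close to the requested size.
// Worked example, size = 100:
//   RecvBufferCapacity = int(math.Sqrt(100) + 1) = 11
//   effective queue capacity = 11 * 11 = 121 >= 100
// A later patch moves the squaring out of NewQueue, but the channel queue
// built by OpenChannel keeps squaring, so the compensation still applies.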
int) *p2p.ChannelDescriptor { return &p2p.ChannelDescriptor{ ID: chID, MessageType: new(tmcons.Message), - RecvBufferCapacity: int(math.Sqrt(float64(size))+1), + RecvBufferCapacity: int(math.Sqrt(float64(size)) + 1), } } diff --git a/internal/evidence/reactor_test.go b/internal/evidence/reactor_test.go index f19f8276b..a5084a30e 100644 --- a/internal/evidence/reactor_test.go +++ b/internal/evidence/reactor_test.go @@ -64,8 +64,8 @@ func setup(ctx context.Context, t *testing.T, stateStores []sm.Store) *reactorTe } chDesc := &p2p.ChannelDescriptor{ - ID: evidence.EvidenceChannel, - MessageType: new(tmproto.Evidence), + ID: evidence.EvidenceChannel, + MessageType: new(tmproto.Evidence), RecvBufferCapacity: 10, } rts.evidenceChannels = rts.network.MakeChannelsNoCleanup(t, chDesc) diff --git a/internal/p2p/channel.go b/internal/p2p/channel.go index cffff9191..e44445f93 100644 --- a/internal/p2p/channel.go +++ b/internal/p2p/channel.go @@ -7,8 +7,8 @@ import ( "github.com/gogo/protobuf/proto" - "github.com/tendermint/tendermint/types" "github.com/tendermint/tendermint/libs/utils" + "github.com/tendermint/tendermint/types" ) // Envelope contains a message with sender/receiver routing info. @@ -61,7 +61,7 @@ func (pe PeerError) Unwrap() error { return pe.Err } // Each message is wrapped in an Envelope to specify its sender and receiver. type Channel struct { ID ChannelID - inCh *Queue // inbound messages (peers to reactors) + inCh *Queue // inbound messages (peers to reactors) outCh chan<- Envelope // outbound messages (reactors to peers) errCh chan<- PeerError // peer error reporting @@ -131,9 +131,13 @@ type ChannelIterator struct { func iteratorWorker(ctx context.Context, ch *Channel, pipe chan Envelope) { for { - e,err:=ch.inCh.Recv(ctx) - if err!=nil { return } - if err:=utils.Send(ctx, pipe, e); err!=nil { return } + e, err := ch.inCh.Recv(ctx) + if err != nil { + return + } + if err := utils.Send(ctx, pipe, e); err != nil { + return + } } } diff --git a/internal/p2p/channel_test.go b/internal/p2p/channel_test.go index 4b5164b1c..f889f3f0c 100644 --- a/internal/p2p/channel_test.go +++ b/internal/p2p/channel_test.go @@ -112,7 +112,7 @@ func TestChannel(t *testing.T) { Case: func(t *testing.T) { ctx := t.Context() ins, ch := testChannel(1) - ins.In.Send(Envelope{From: "kip", To: "merlin"},0) + ins.In.Send(Envelope{From: "kip", To: "merlin"}, 0) iter := ch.Receive(ctx) require.NotNil(t, iter) require.True(t, iter.Next(ctx)) @@ -157,7 +157,7 @@ func TestChannel(t *testing.T) { ctx := t.Context() ins, ch := testChannel(1) - ins.In.Send(Envelope{From: "kip", To: "merlin"},0) + ins.In.Send(Envelope{From: "kip", To: "merlin"}, 0) iter := ch.Receive(ctx) require.NotNil(t, iter) @@ -180,7 +180,7 @@ func TestChannel(t *testing.T) { ctx := t.Context() ins, ch := testChannel(1) - ins.In.Send(Envelope{From: "kip", To: "merlin"},0) + ins.In.Send(Envelope{From: "kip", To: "merlin"}, 0) iter := ch.Receive(ctx) require.NotNil(t, iter) @@ -204,7 +204,7 @@ func TestChannel(t *testing.T) { require.NotNil(t, iter) require.Nil(t, iter.Envelope()) - ins.In.Send(Envelope{From: "kip", To: "merlin"},0) + ins.In.Send(Envelope{From: "kip", To: "merlin"}, 0) require.NotNil(t, iter) require.True(t, iter.Next(ctx)) diff --git a/internal/p2p/p2ptest/network.go b/internal/p2p/p2ptest/network.go index 48b363475..a8de3ff63 100644 --- a/internal/p2p/p2ptest/network.go +++ b/internal/p2p/p2ptest/network.go @@ -51,7 +51,7 @@ func (opts *NetworkOptions) setDefaults() { // connects them to each other. 
func MakeNetwork(ctx context.Context, t *testing.T, opts NetworkOptions) *Network { opts.setDefaults() - logger,_ := log.NewDefaultLogger("plain","info") + logger, _ := log.NewDefaultLogger("plain", "info") network := &Network{ Nodes: map[types.NodeID]*Node{}, logger: logger, diff --git a/internal/p2p/p2ptest/require.go b/internal/p2p/p2ptest/require.go index bbf78fc47..31042eda3 100644 --- a/internal/p2p/p2ptest/require.go +++ b/internal/p2p/p2ptest/require.go @@ -13,7 +13,7 @@ import ( // RequireEmpty requires that the given channel is empty. func RequireEmpty(t *testing.T, channels ...*p2p.Channel) { t.Helper() - for _,ch := range channels { + for _, ch := range channels { if ch.ReceiveLen() != 0 { t.Errorf("nonempty channel %v", ch) } @@ -42,19 +42,19 @@ func RequireReceiveUnordered(t *testing.T, channel *p2p.Channel, expect []*p2p.E return } } - require.FailNow(t,"not enough messages") + require.FailNow(t, "not enough messages") } // RequireSend requires that the given envelope is sent on the channel. func RequireSend(t *testing.T, channel *p2p.Channel, envelope p2p.Envelope) { t.Logf("sending message %v", envelope) - require.NoError(t,channel.Send(t.Context(), envelope)) + require.NoError(t, channel.Send(t.Context(), envelope)) } // RequireNoUpdates requires that a PeerUpdates subscription is empty. func RequireNoUpdates(ctx context.Context, t *testing.T, peerUpdates *p2p.PeerUpdates) { t.Helper() - if len(peerUpdates.Updates())!=0 { + if len(peerUpdates.Updates()) != 0 { require.FailNow(t, "unexpected peer updates") } } @@ -67,8 +67,8 @@ func RequireSendError(t *testing.T, channel *p2p.Channel, peerError p2p.PeerErro // RequireUpdate requires that a PeerUpdates subscription yields the given update. func RequireUpdate(t *testing.T, peerUpdates *p2p.PeerUpdates, expect p2p.PeerUpdate) { t.Logf("awaiting update %v", expect) - update,err := utils.Recv(t.Context(),peerUpdates.Updates()) - if err!=nil { + update, err := utils.Recv(t.Context(), peerUpdates.Updates()) + if err != nil { require.FailNow(t, "utils.Recv(): %w", err) } require.Equal(t, expect.NodeID, update.NodeID, "node id did not match") @@ -81,9 +81,9 @@ func RequireUpdates(t *testing.T, peerUpdates *p2p.PeerUpdates, expect []p2p.Pee t.Logf("awaiting %d updates", len(expect)) actual := []p2p.PeerUpdate{} for { - update,err := utils.Recv(t.Context(),peerUpdates.Updates()) - if err!=nil { - require.FailNow(t,"utils.Recv(): %v",err) + update, err := utils.Recv(t.Context(), peerUpdates.Updates()) + if err != nil { + require.FailNow(t, "utils.Recv(): %v", err) } actual = append(actual, update) if len(actual) == len(expect) { diff --git a/internal/p2p/peermanager.go b/internal/p2p/peermanager.go index 907776fc7..9070afb90 100644 --- a/internal/p2p/peermanager.go +++ b/internal/p2p/peermanager.go @@ -1326,7 +1326,7 @@ func (s *peerStore) Ranked() []*peerInfo { sort.Slice(s.ranked, func(i, j int) bool { // FIXME: If necessary, consider precomputing scores before sorting, // to reduce the number of Score() calls. - if a,b := s.ranked[i].Score(),s.ranked[j].Score(); a != b { + if a, b := s.ranked[i].Score(), s.ranked[j].Score(); a != b { return a > b } // TODO(gprusak): we don't allow ties because tests require deterministic order. 
diff --git a/internal/p2p/peermanager_test.go b/internal/p2p/peermanager_test.go index d346376cd..19e51bdf7 100644 --- a/internal/p2p/peermanager_test.go +++ b/internal/p2p/peermanager_test.go @@ -834,7 +834,7 @@ func TestPeerManager_DialFailed_UnreservePeer(t *testing.T) { b := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("b", 40))} c := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("c", 40))} - logger,_ := log.NewDefaultLogger("plain","debug") + logger, _ := log.NewDefaultLogger("plain", "debug") peerManager, err := p2p.NewPeerManager(logger, selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{ PeerScores: map[types.NodeID]p2p.PeerScore{ a.NodeID: p2p.DefaultMutableScore - 1, // Set lower score for a to make it upgradeable diff --git a/internal/p2p/pex/reactor_test.go b/internal/p2p/pex/reactor_test.go index 4b0c67093..a1c2d20af 100644 --- a/internal/p2p/pex/reactor_test.go +++ b/internal/p2p/pex/reactor_test.go @@ -79,7 +79,7 @@ func TestReactorSendsRequestsTooOften(t *testing.T) { r.pexInCh.Send(p2p.Envelope{ From: badNode, Message: &p2pproto.PexRequest{}, - },0) + }, 0) resp := <-r.pexOutCh msg, ok := resp.Message.(*p2pproto.PexResponse) @@ -89,7 +89,7 @@ func TestReactorSendsRequestsTooOften(t *testing.T) { r.pexInCh.Send(p2p.Envelope{ From: badNode, Message: &p2pproto.PexRequest{}, - },0) + }, 0) peerErr := <-r.pexErrCh require.Error(t, peerErr.Err) @@ -175,7 +175,7 @@ func TestReactorErrorsOnReceivingTooManyPeers(t *testing.T) { Message: &p2pproto.PexResponse{ Addresses: addresses, }, - },0) + }, 0) case <-time.After(10 * time.Second): t.Fatal("pex failed to send a request within 10 seconds") diff --git a/internal/p2p/router.go b/internal/p2p/router.go index 6bf6a561f..f6cc8dbec 100644 --- a/internal/p2p/router.go +++ b/internal/p2p/router.go @@ -16,9 +16,9 @@ import ( "github.com/tendermint/tendermint/crypto" "github.com/tendermint/tendermint/libs/log" + "github.com/tendermint/tendermint/libs/service" "github.com/tendermint/tendermint/libs/utils" "github.com/tendermint/tendermint/libs/utils/scope" - "github.com/tendermint/tendermint/libs/service" "github.com/tendermint/tendermint/types" ) @@ -91,8 +91,8 @@ func (o *RouterOptions) Validate() error { } type peerState struct { - cancel context.CancelFunc - queue *Queue // outbound messages per peer for all channels + cancel context.CancelFunc + queue *Queue // outbound messages per peer for all channels channels ChannelIDSet // the channels that the peer queue has open } @@ -150,7 +150,7 @@ type Router struct { endpoint *Endpoint connTracker connectionTracker - peerStates utils.RWMutex[map[types.NodeID]*peerState] + peerStates utils.RWMutex[map[types.NodeID]*peerState] nodeInfoProducer func() *types.NodeInfo // FIXME: We don't strictly need to use a mutex for this if we seal the @@ -256,8 +256,10 @@ func (r *Router) OpenChannel(chDesc *ChannelDescriptor) (*Channel, error) { return scope.Run(ctx, func(ctx context.Context, s scope.Scope) error { s.Spawn(func() error { return r.routeChannel(ctx, chDesc, outCh, wrapper) }) for { - peerError,err := utils.Recv(ctx,errCh) - if err!=nil { return err } + peerError, err := utils.Recv(ctx, errCh) + if err != nil { + return err + } shouldEvict := peerError.Fatal || r.peerManager.HasMaxPeerCapacity() r.logger.Error("peer error", "peer", peerError.NodeID, @@ -290,8 +292,10 @@ func (r *Router) routeChannel( wrapper Wrapper, ) error { for { - envelope,err := utils.Recv(ctx, outCh) - if err!=nil { return err } + envelope, err := utils.Recv(ctx, outCh) + 
if err != nil { + return err + } if envelope.IsZero() { continue } @@ -326,7 +330,7 @@ func (r *Router) routeChannel( ok := false var s *peerState for states := range r.peerStates.RLock() { - s,ok = states[envelope.To] + s, ok = states[envelope.To] } if !ok { r.logger.Debug("dropping message for unconnected peer", "peer", envelope.To, "channel", chDesc.ID) @@ -434,7 +438,7 @@ func (r *Router) acceptPeers(ctx context.Context, transport Transport) { } // Spawn a goroutine for the handshake, to avoid head-of-line blocking. - r.Spawn("openConnection",func(ctx context.Context) error { return r.openConnection(ctx, conn) }) + r.Spawn("openConnection", func(ctx context.Context) error { return r.openConnection(ctx, conn) }) } } @@ -466,7 +470,7 @@ func (r *Router) openConnection(ctx context.Context, conn Connection) error { // message to make sure both ends have accepted the connection, such // that it can be coordinated with the peer manager. peerInfo, err := r.handshakePeer(ctx, conn, "") - if err!=nil { + if err != nil { return fmt.Errorf("peer handshake failed: endpoint=%v: %w", conn, err) } if err := r.filterPeersID(ctx, peerInfo.NodeID); err != nil { @@ -474,12 +478,15 @@ func (r *Router) openConnection(ctx context.Context, conn Connection) error { return nil } + // TODO(gprusak): this is fragile that updating peerManager requires a lock on peerStates. + // If this is intended, they should just share the same mutex. + // Also currently the pattern of keeping the mutex locked for peerManager accesses is inconsistent. if err := r.runWithPeerMutex(func() error { return r.peerManager.Accepted(peerInfo.NodeID) }); err != nil { // If peer is trying to reconnect, error and let it reconnect if strings.Contains(err.Error(), "is already connected") { r.peerManager.Errored(peerInfo.NodeID, err) } - return fmt.Errorf("failed to accept connection: op=incoming/accepted, peer=%v: %w",peerInfo.NodeID,err) + return fmt.Errorf("failed to accept connection: op=incoming/accepted, peer=%v: %w", peerInfo.NodeID, err) } return r.routePeer(ctx, peerInfo.NodeID, conn, toChannelIDs(peerInfo.Channels)) } @@ -569,6 +576,11 @@ func (r *Router) connectPeer(ctx context.Context, address NodeAddress) { if err := r.runWithPeerMutex(func() error { return r.peerManager.Dialed(address) }); err != nil { // If peer is trying to reconnect, fail it and let it reconnect + // TODO(gprusak): this symmetric logic for handling duplicate connections is a source of race conditions: + // if 2 nodes try to establish a connection to each other at the same time, both connections will be dropped. + // Instead either: + // * break the symmetry by favoring incoming connection iff my.NodeID > peer.NodeID + // * keep incoming and outcoming connection pools separate to avoid the collision (recommended) if strings.Contains(err.Error(), "is already connected") { r.logger.Error(fmt.Sprintf("Disconnecting %s because of %s", address.NodeID, err)) r.peerManager.Disconnected(ctx, address.NodeID) @@ -580,7 +592,7 @@ func (r *Router) connectPeer(ctx context.Context, address NodeAddress) { return } - r.Spawn("routePeer",func(ctx context.Context) error { + r.Spawn("routePeer", func(ctx context.Context) error { defer conn.Close() return r.routePeer(ctx, address.NodeID, conn, toChannelIDs(peerInfo.Channels)) }) @@ -691,6 +703,7 @@ func (r *Router) runWithPeerMutex(fn func() error) error { } panic("unreachable") } + // routePeer routes inbound and outbound messages between a peer and the reactor // channels. 
It will close the given connection and send queue when done, or if // they are closed elsewhere it will cause this method to shut down and return. @@ -698,14 +711,14 @@ func (r *Router) routePeer(ctx context.Context, peerID types.NodeID, conn Connec r.metrics.Peers.Add(1) r.peerManager.Ready(ctx, peerID, channels) - peerCtx,cancel := context.WithCancel(ctx) + peerCtx, cancel := context.WithCancel(ctx) state := &peerState{ cancel: cancel, - queue: NewQueue(queueBufferDefault), + queue: NewQueue(queueBufferDefault), channels: channels, } for states := range r.peerStates.Lock() { - if old,ok := states[peerID]; ok { + if old, ok := states[peerID]; ok { old.cancel() } states[peerID] = state @@ -715,12 +728,15 @@ func (r *Router) routePeer(ctx context.Context, peerID types.NodeID, conn Connec s.Spawn(func() error { return r.receivePeer(ctx, peerID, conn) }) s.Spawn(func() error { return r.sendPeer(ctx, peerID, conn, state.queue) }) <-ctx.Done() - _ = conn.Close() // TODO: conn.ReceiveMessage() (either in inmem or mconn) does not respect the context. Fix that. + // TODO(gprusak): we need to close the connection here, because + // the mock connection used in tests does not respect the context. + // Get rid of these stupid mocks. + _ = conn.Close() return nil }) r.logger.Info("peer disconnected", "peer", peerID, "endpoint", conn, "err", err) for states := range r.peerStates.Lock() { - if states[peerID]==state { + if states[peerID] == state { delete(states, peerID) } } @@ -768,7 +784,7 @@ func (r *Router) receivePeer(ctx context.Context, peerID types.NodeID, conn Conn start := time.Now().UTC() // Priority is not used since all messages in this queue are from the same channel. - queue.Send(Envelope{From: peerID, Message: msg, ChannelID: chID},0) + queue.Send(Envelope{From: peerID, Message: msg, ChannelID: chID}, 0) r.metrics.PeerReceiveBytesTotal.With( "chID", fmt.Sprint(chID), "peer_id", string(peerID), @@ -782,8 +798,10 @@ func (r *Router) receivePeer(ctx context.Context, peerID types.NodeID, conn Conn func (r *Router) sendPeer(ctx context.Context, peerID types.NodeID, conn Connection, peerQueue *Queue) error { for { start := time.Now().UTC() - envelope,err := peerQueue.Recv(ctx) - if err!=nil { return err } + envelope, err := peerQueue.Recv(ctx) + if err != nil { + return err + } r.metrics.RouterPeerQueueRecv.Observe(time.Since(start).Seconds()) if envelope.Message == nil { r.logger.Error("dropping nil message", "peer", peerID) @@ -820,7 +838,7 @@ func (r *Router) evictPeers(ctx context.Context) { r.logger.Info("evicting peer", "peer", peerID) for states := range r.peerStates.Lock() { - if s,ok := states[peerID]; ok { + if s, ok := states[peerID]; ok { s.cancel() } } diff --git a/internal/p2p/router_test.go b/internal/p2p/router_test.go index 0824577aa..efb153714 100644 --- a/internal/p2p/router_test.go +++ b/internal/p2p/router_test.go @@ -5,12 +5,12 @@ import ( "errors" "fmt" "io" + slog "log" "runtime" "strings" "sync" "testing" "time" - slog "log" "github.com/fortytw2/leaktest" "github.com/gogo/protobuf/proto" @@ -98,7 +98,7 @@ func TestRouter_Network(t *testing.T) { func TestRouter_Channel_Basic(t *testing.T) { t.Cleanup(leaktest.Check(t)) - logger,_ := log.NewDefaultLogger("plain","debug") + logger, _ := log.NewDefaultLogger("plain", "debug") ctx := t.Context() @@ -184,8 +184,8 @@ func TestRouter_Channel_SendReceive(t *testing.T) { t.Logf("Sending to an unknown peer should be dropped.") p2ptest.RequireSend(t, a, p2p.Envelope{ - To: types.NodeID(strings.Repeat("a", 40)), - Message: 
&p2ptest.Message{Value: "a"}, + To: types.NodeID(strings.Repeat("a", 40)), + Message: &p2ptest.Message{Value: "a"}, ChannelID: chDesc.ID, }) p2ptest.RequireEmpty(t, a, b, c) @@ -259,7 +259,7 @@ func TestRouter_Channel_Wrapper(t *testing.T) { MessageType: &wrapperMessage{}, Priority: 5, SendQueueCapacity: 10, - RecvBufferCapacity: 10, + RecvBufferCapacity: 10, RecvMessageCapacity: 10, } @@ -281,13 +281,13 @@ func TestRouter_Channel_Wrapper(t *testing.T) { // If we send the wrapper message itself, it should also be passed through // since WrapperMessage supports it, and should only be unwrapped at the receiver. p2ptest.RequireSend(t, a, p2p.Envelope{ - To: bID, - Message: &wrapperMessage{Message: p2ptest.Message{Value: "foo"}}, + To: bID, + Message: &wrapperMessage{Message: p2ptest.Message{Value: "foo"}}, ChannelID: chDesc.ID, }) p2ptest.RequireReceive(t, b, p2p.Envelope{ - From: aID, - Message: &p2ptest.Message{Value: "unwrap:foo"}, + From: aID, + Message: &p2ptest.Message{Value: "unwrap:foo"}, ChannelID: chDesc.ID, }) @@ -730,7 +730,7 @@ func TestRouter_DialPeers_Parallel(t *testing.T) { func TestRouter_EvictPeers(t *testing.T) { t.Cleanup(leaktest.Check(t)) - logger,_ := log.NewDefaultLogger("plain","debug") + logger, _ := log.NewDefaultLogger("plain", "debug") ctx := t.Context() diff --git a/internal/p2p/rqueue.go b/internal/p2p/rqueue.go index 8af401839..a5ce4175e 100644 --- a/internal/p2p/rqueue.go +++ b/internal/p2p/rqueue.go @@ -14,54 +14,56 @@ type ord[T any] interface { } type withIdx[T any] struct { - v T + v T minIdx int // index in byMin maxIdx int // index in byMax } func newWithIdx[T any](v T) *withIdx[T] { - return &withIdx[T] { v: v } + return &withIdx[T]{v: v} } // Heap returning minimal elements. -type byMin[T ord[T]] struct { a []*withIdx[T] } -func newByMin[T ord[T]](capacity int) byMin[T] { return byMin[T]{make([]*withIdx[T],0,capacity)} } -func (x *byMin[T]) Less(i, j int) bool { return x.a[i].v.Less(x.a[j].v) } -func (x *byMin[T]) Len() int { return len(x.a) } +type byMin[T ord[T]] struct{ a []*withIdx[T] } + +func newByMin[T ord[T]](capacity int) byMin[T] { return byMin[T]{make([]*withIdx[T], 0, capacity)} } +func (x *byMin[T]) Less(i, j int) bool { return x.a[i].v.Less(x.a[j].v) } +func (x *byMin[T]) Len() int { return len(x.a) } func (x *byMin[T]) Swap(i, j int) { - x.a[i],x.a[j] = x.a[j],x.a[i] + x.a[i], x.a[j] = x.a[j], x.a[i] x.a[i].minIdx = i x.a[j].minIdx = j } func (x *byMin[T]) Push(v any) { w := v.(*withIdx[T]) w.minIdx = len(x.a) - x.a = append(x.a,w) + x.a = append(x.a, w) } func (x *byMin[T]) Pop() any { - n := len(x.a)-1 + n := len(x.a) - 1 w := x.a[n] x.a = x.a[:n] return w } // Heap returning maximal elements. 
-type byMax[T ord[T]] struct { a []*withIdx[T] } -func newByMax[T ord[T]](capacity int) byMax[T] { return byMax[T]{make([]*withIdx[T],0,capacity)} } -func (x *byMax[T]) Less(i, j int) bool { return x.a[j].v.Less(x.a[i].v) } -func (x *byMax[T]) Len() int { return len(x.a) } +type byMax[T ord[T]] struct{ a []*withIdx[T] } + +func newByMax[T ord[T]](capacity int) byMax[T] { return byMax[T]{make([]*withIdx[T], 0, capacity)} } +func (x *byMax[T]) Less(i, j int) bool { return x.a[j].v.Less(x.a[i].v) } +func (x *byMax[T]) Len() int { return len(x.a) } func (x *byMax[T]) Swap(i, j int) { - x.a[i],x.a[j] = x.a[j],x.a[i] + x.a[i], x.a[j] = x.a[j], x.a[i] x.a[i].maxIdx = i x.a[j].maxIdx = j } func (x *byMax[T]) Push(v any) { w := v.(*withIdx[T]) w.maxIdx = len(x.a) - x.a = append(x.a,w) + x.a = append(x.a, w) } func (x *byMax[T]) Pop() any { - n := len(x.a)-1 + n := len(x.a) - 1 w := x.a[n] x.a = x.a[:n] return w @@ -79,11 +81,11 @@ type pqEnvelope struct { // true <=> a has higher priority than b func (a *pqEnvelope) Less(b *pqEnvelope) bool { // higher base priority wins - if a,b := a.priority,b.priority; a!=b { + if a, b := a.priority, b.priority; a != b { return a > b } // newer timestamp wins - if a,b := a.timestamp,b.timestamp; a.Sub(b).Abs() >= 10*time.Millisecond { + if a, b := a.timestamp, b.timestamp; a.Sub(b).Abs() >= 10*time.Millisecond { return a.After(b) } // larger first @@ -92,17 +94,17 @@ func (a *pqEnvelope) Less(b *pqEnvelope) bool { type inner struct { capacity int - byMin byMin[*pqEnvelope] - byMax byMax[*pqEnvelope] + byMin byMin[*pqEnvelope] + byMax byMax[*pqEnvelope] } func newInner(capacity int) *inner { - return &inner { + return &inner{ capacity: capacity, // We prune the maximal elements whenever capacity is exceeded. // Therefore to avoid reallocation we need the heaps to have capacity+1. - byMin: newByMin[*pqEnvelope](capacity+1), - byMax: newByMax[*pqEnvelope](capacity+1), + byMin: newByMin[*pqEnvelope](capacity + 1), + byMax: newByMax[*pqEnvelope](capacity + 1), } } @@ -110,28 +112,28 @@ func (i *inner) Len() int { return i.byMin.Len() } func (i *inner) Push(e *pqEnvelope) { w := newWithIdx(e) - heap.Push(&i.byMin,w) - heap.Push(&i.byMax,w) - if i.byMin.Len()>i.capacity { + heap.Push(&i.byMin, w) + heap.Push(&i.byMax, w) + if i.byMin.Len() > i.capacity { w := heap.Pop(&i.byMax).(*withIdx[*pqEnvelope]) - heap.Remove(&i.byMin,w.minIdx) + heap.Remove(&i.byMin, w.minIdx) } } func (i *inner) Pop() *pqEnvelope { w := heap.Pop(&i.byMin).(*withIdx[*pqEnvelope]) - heap.Remove(&i.byMax,w.maxIdx) + heap.Remove(&i.byMax, w.maxIdx) return w.v } -type Queue struct { inner utils.Watch[*inner] } +type Queue struct{ inner utils.Watch[*inner] } func NewQueue(size int) *Queue { - if size<=0 { + if size <= 0 { size = 1 } // TODO(gprusak): this size*size looks ridiculous. Fix it. - return &Queue{inner: utils.NewWatch(newInner(size*size))} + return &Queue{inner: utils.NewWatch(newInner(size * size))} } func (q *Queue) Len() int { @@ -150,19 +152,19 @@ func (q *Queue) Send(e Envelope, priority int) { priority: priority, timestamp: time.Now().UTC(), } - for inner,ctrl := range q.inner.Lock() { + for inner, ctrl := range q.inner.Lock() { inner.Push(pqe) ctrl.Updated() } } // Blocking recv. 
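// Illustrative sketch (not part of the patch) of the pruning semantics that
// pqEnvelope.Less and inner.Push above imply, using Recv as defined just
// below. Priorities are echoed in ChannelID purely to make the effect
// visible, the same trick the queue tests later in this series use. Assumes
// `context` and `fmt` imports.
func ExampleQueue_pruning() {
	ctx := context.Background()
	q := NewQueue(1) // effective capacity 1, with or without the size*size quirk

	q.Send(Envelope{ChannelID: 1}, 1)
	q.Send(Envelope{ChannelID: 5}, 5) // over capacity: the priority-1 envelope is pruned
	q.Send(Envelope{ChannelID: 3}, 3) // pruned immediately, since 3 < 5

	e, _ := q.Recv(ctx) // highest remaining priority comes out first
	fmt.Println(e.ChannelID, q.Len())
	// Output: 5 0
}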
-func (q *Queue) Recv(ctx context.Context) (Envelope,error) { - for inner,ctrl := range q.inner.Lock() { - if err:=ctrl.WaitUntil(ctx,func() bool { return inner.Len()>0 }); err!=nil { - return Envelope{},err +func (q *Queue) Recv(ctx context.Context) (Envelope, error) { + for inner, ctrl := range q.inner.Lock() { + if err := ctrl.WaitUntil(ctx, func() bool { return inner.Len() > 0 }); err != nil { + return Envelope{}, err } - return inner.Pop().envelope,nil + return inner.Pop().envelope, nil } panic("unreachable") } diff --git a/internal/p2p/rqueue_test.go b/internal/p2p/rqueue_test.go index 8c01a13e9..3187c69ec 100644 --- a/internal/p2p/rqueue_test.go +++ b/internal/p2p/rqueue_test.go @@ -12,12 +12,12 @@ func TestSimpleQueue(t *testing.T) { // queue, most of which we'll watch it drop. sq := NewQueue(1) for range 100 { - sq.Send(Envelope{From: "merlin"},0) + sq.Send(Envelope{From: "merlin"}, 0) } - if _,err:=sq.Recv(ctx); err!=nil { + if _, err := sq.Recv(ctx); err != nil { t.Fatal(err) } - if sq.Len()!=0 { + if sq.Len() != 0 { t.Fatalf("queue length is %d, should be 0", sq.Len()) } } diff --git a/internal/statesync/reactor_test.go b/internal/statesync/reactor_test.go index 19edb0c2a..b61cb31ac 100644 --- a/internal/statesync/reactor_test.go +++ b/internal/statesync/reactor_test.go @@ -246,7 +246,7 @@ func TestReactor_ChunkRequest_InvalidRequest(t *testing.T) { From: types.NodeID("aa"), ChannelID: ChunkChannel, Message: &ssproto.SnapshotsRequest{}, - },0) + }, 0) response := <-rts.chunkPeerErrCh require.Error(t, response.Err) @@ -301,7 +301,7 @@ func TestReactor_ChunkRequest(t *testing.T) { From: types.NodeID("aa"), ChannelID: ChunkChannel, Message: tc.request, - },0) + }, 0) response := <-rts.chunkOutCh require.Equal(t, tc.expectResponse, response.Message) @@ -321,7 +321,7 @@ func TestReactor_SnapshotsRequest_InvalidRequest(t *testing.T) { From: types.NodeID("aa"), ChannelID: SnapshotChannel, Message: &ssproto.ChunkRequest{}, - },0) + }, 0) response := <-rts.snapshotPeerErrCh require.Error(t, response.Err) @@ -381,7 +381,7 @@ func TestReactor_SnapshotsRequest(t *testing.T) { From: types.NodeID("aa"), ChannelID: SnapshotChannel, Message: &ssproto.SnapshotsRequest{}, - },0) + }, 0) if len(tc.expectResponses) > 0 { retryUntil(ctx, t, func() bool { return len(rts.snapshotOutCh) == len(tc.expectResponses) }, time.Second) @@ -440,7 +440,7 @@ func TestReactor_LightBlockResponse(t *testing.T) { Message: &ssproto.LightBlockRequest{ Height: 10, }, - },0) + }, 0) require.Empty(t, rts.blockPeerErrCh) select { @@ -737,7 +737,7 @@ func handleLightBlockRequests( Message: &ssproto.LightBlockResponse{ LightBlock: lb, }, - },0) + }, 0) } else { switch errorCount % 3 { case 0: // send a different block @@ -751,7 +751,7 @@ func handleLightBlockRequests( Message: &ssproto.LightBlockResponse{ LightBlock: differntLB, }, - },0) + }, 0) case 1: // send nil block i.e. 
pretend we don't have it sending.Send(p2p.Envelope{ From: envelope.To, @@ -759,7 +759,7 @@ func handleLightBlockRequests( Message: &ssproto.LightBlockResponse{ LightBlock: nil, }, - },0) + }, 0) case 2: // don't do anything } errorCount++ @@ -799,7 +799,7 @@ func handleConsensusParamsRequest( Height: msg.Height, ConsensusParams: paramsProto, }, - },0) + }, 0) case <-closeCh: return } @@ -908,7 +908,7 @@ func handleSnapshotRequests( Hash: snapshot.Hash, Metadata: snapshot.Metadata, }, - },0) + }, 0) } } } @@ -942,7 +942,7 @@ func handleChunkRequests( Chunk: chunk, Missing: false, }, - },0) + }, 0) } } diff --git a/libs/utils/testonly.go b/libs/utils/testonly.go index ca5884cb6..4ef001ead 100644 --- a/libs/utils/testonly.go +++ b/libs/utils/testonly.go @@ -7,9 +7,9 @@ import ( "reflect" "time" + "github.com/gogo/protobuf/proto" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "github.com/gogo/protobuf/proto" "google.golang.org/protobuf/testing/protocmp" ) From dc494a726f58d60a3dfa4ed91c203e686696ee98 Mon Sep 17 00:00:00 2001 From: Grzegorz Prusak Date: Wed, 27 Aug 2025 10:24:08 +0200 Subject: [PATCH 22/41] moved size quirk outside of simple-queue --- internal/p2p/router.go | 6 ++++-- internal/p2p/rqueue.go | 4 ++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/internal/p2p/router.go b/internal/p2p/router.go index f6cc8dbec..737cbec8a 100644 --- a/internal/p2p/router.go +++ b/internal/p2p/router.go @@ -22,7 +22,7 @@ import ( "github.com/tendermint/tendermint/types" ) -const queueBufferDefault = 32 +const queueBufferDefault = 1024 // RouterOptions specifies options for a Router. type RouterOptions struct { @@ -233,7 +233,9 @@ func (r *Router) OpenChannel(chDesc *ChannelDescriptor) (*Channel, error) { messageType := chDesc.MessageType - queue := NewQueue(chDesc.RecvBufferCapacity) + // TODO(gprusak): get rid of this random cap*cap value once we understand + // what the sizes per channel really should be. + queue := NewQueue(chDesc.RecvBufferCapacity * chDesc.RecvBufferCapacity) outCh := make(chan Envelope, chDesc.RecvBufferCapacity) errCh := make(chan PeerError, chDesc.RecvBufferCapacity) channel := NewChannel(id, queue, outCh, errCh) diff --git a/internal/p2p/rqueue.go b/internal/p2p/rqueue.go index a5ce4175e..f1000cc32 100644 --- a/internal/p2p/rqueue.go +++ b/internal/p2p/rqueue.go @@ -130,10 +130,10 @@ type Queue struct{ inner utils.Watch[*inner] } func NewQueue(size int) *Queue { if size <= 0 { + // prevent caller from shooting self in the foot. size = 1 } - // TODO(gprusak): this size*size looks ridiculous. Fix it. 
- return &Queue{inner: utils.NewWatch(newInner(size * size))} + return &Queue{inner: utils.NewWatch(newInner(size))} } func (q *Queue) Len() int { From 1c76fb69bddf4ce20e48c6916ac72ff3611ce9f9 Mon Sep 17 00:00:00 2001 From: Grzegorz Prusak Date: Wed, 27 Aug 2025 10:49:51 +0200 Subject: [PATCH 23/41] queue tests --- internal/p2p/rqueue_test.go | 79 +++++++++++++++++++++++++++++++------ 1 file changed, 68 insertions(+), 11 deletions(-) diff --git a/internal/p2p/rqueue_test.go b/internal/p2p/rqueue_test.go index 3187c69ec..2b40d8d1b 100644 --- a/internal/p2p/rqueue_test.go +++ b/internal/p2p/rqueue_test.go @@ -1,23 +1,80 @@ package p2p import ( + "context" + "github.com/tendermint/tendermint/libs/utils" + "github.com/tendermint/tendermint/libs/utils/scope" + "slices" "testing" ) -func TestSimpleQueue(t *testing.T) { +func TestQueuePruning(t *testing.T) { ctx := t.Context() - - // set up a small queue with very small buffers so we can - // watch it shed load, then send a bunch of messages to the - // queue, most of which we'll watch it drop. - sq := NewQueue(1) + rng := utils.TestRng() + n := 20 + var want []int + sq := NewQueue(n) for range 100 { - sq.Send(Envelope{From: "merlin"}, 0) + // Send a bunch of messages. + for range 30 { + // priority is not part of the envelope currently, + // so we hack it by encoding it as a ChannelID. + v := ChannelID(rng.Int()) + sq.Send(Envelope{From: "merlin", ChannelID: v}, int(v)) + want = append(want, int(v)) + } + + // Low priority messages should be dropped. + slices.Sort(want) + l := len(want) + want = want[l-n:] + if len(want) != sq.Len() { + t.Fatalf("expected len %d, got %d", len(want), sq.Len()) + } + + // Receive a bunch of messages. + for range 5 { + got, err := sq.Recv(ctx) + if err != nil { + t.Fatal(err) + } + l := len(want) + if got, want := int(got.ChannelID), want[l-1]; got != want { + t.Fatalf("sq.Recv() = %d, want %d", got, want) + } + want = want[:l-1] + } + if len(want) != sq.Len() { + t.Fatalf("expected len %d, got %d", len(want), sq.Len()) + } } - if _, err := sq.Recv(ctx); err != nil { +} + +// Test that receivers are notified when a message is available. +func TestQueueConcurrency(t *testing.T) { + ctx := t.Context() + q1, q2 := NewQueue(1), NewQueue(1) + + if err := utils.IgnoreCancel(scope.Run(ctx, func(ctx context.Context, s scope.Scope) error { + s.SpawnBg(func() error { + // Echo task. + for { + msg, err := q1.Recv(ctx) + if err != nil { + return err + } + q2.Send(msg, 0) + } + }) + // Send and receive a bunch of messages. 
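// The concurrency test above depends on Recv blocking inside ctrl.WaitUntil
// until a sender calls ctrl.Updated(). utils.Watch is not shown in this
// series; the following is a minimal, assumption-labelled sketch of how such
// a watch could work (value guarded by a mutex, updates signalled over a
// replaceable broadcast channel so waiting stays cancellable). The real
// implementation may differ.

import (
	"context"
	"sync"
)

type Watch[T any] struct {
	mu      sync.Mutex
	value   T
	changed chan struct{} // closed and replaced on every Updated()
}

func NewWatch[T any](v T) *Watch[T] {
	return &Watch[T]{value: v, changed: make(chan struct{})}
}

// Updated wakes all current waiters; callers hold the mutex, as the
// range-based Lock() in the real code guarantees.
func (w *Watch[T]) Updated() {
	close(w.changed)
	w.changed = make(chan struct{})
}

// WaitUntil re-checks pred whenever the value is updated and returns early if
// the context is cancelled. The mutex is released while waiting.
func (w *Watch[T]) WaitUntil(ctx context.Context, pred func() bool) error {
	for !pred() {
		ch := w.changed
		w.mu.Unlock()
		select {
		case <-ctx.Done():
			w.mu.Lock()
			return ctx.Err()
		case <-ch:
		}
		w.mu.Lock()
	}
	return nil
}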
+ for range 100 { + q1.Send(Envelope{From: "merlin"}, 0) + if _, err := q2.Recv(ctx); err != nil { + return err + } + } + return nil + })); err != nil { t.Fatal(err) } - if sq.Len() != 0 { - t.Fatalf("queue length is %d, should be 0", sq.Len()) - } } From b8f730c7c2c2f1e6b285e2975315006d25291c9d Mon Sep 17 00:00:00 2001 From: Grzegorz Prusak Date: Wed, 27 Aug 2025 11:11:58 +0200 Subject: [PATCH 24/41] metrics for dropped messages --- internal/p2p/metrics.gen.go | 17 +++++------------ internal/p2p/metrics.go | 12 +++--------- internal/p2p/router.go | 8 ++++++-- internal/p2p/rqueue.go | 11 ++++++++--- 4 files changed, 22 insertions(+), 26 deletions(-) diff --git a/internal/p2p/metrics.gen.go b/internal/p2p/metrics.gen.go index 101e652f7..d233d770a 100644 --- a/internal/p2p/metrics.gen.go +++ b/internal/p2p/metrics.gen.go @@ -62,18 +62,12 @@ func PrometheusMetrics(namespace string, labelsAndValues ...string) *Metrics { Name: "router_channel_queue_send", Help: "The time taken to send on a p2p channel's queue which will later be consued by the corresponding reactor/service.", }, labels).With(labelsAndValues...), - PeerQueueDroppedMsgs: prometheus.NewCounterFrom(stdprometheus.CounterOpts{ + QueueDroppedMsgs: prometheus.NewCounterFrom(stdprometheus.CounterOpts{ Namespace: namespace, Subsystem: MetricsSubsystem, - Name: "router_channel_queue_dropped_msgs", - Help: "The number of messages dropped from a peer's queue for a specific p2p Channel.", - }, append(labels, "ch_id")).With(labelsAndValues...), - PeerQueueMsgSize: prometheus.NewGaugeFrom(stdprometheus.GaugeOpts{ - Namespace: namespace, - Subsystem: MetricsSubsystem, - Name: "peer_queue_msg_size", - Help: "The size of messages sent over a peer's queue for a specific p2p Channel.", - }, append(labels, "ch_id")).With(labelsAndValues...), + Name: "queue_dropped_msgs", + Help: "The number of messages dropped from router's queues.", + }, append(labels, "ch_id", "direction")).With(labelsAndValues...), } } @@ -87,7 +81,6 @@ func NopMetrics() *Metrics { RouterPeerQueueRecv: discard.NewHistogram(), RouterPeerQueueSend: discard.NewHistogram(), RouterChannelQueueSend: discard.NewHistogram(), - PeerQueueDroppedMsgs: discard.NewCounter(), - PeerQueueMsgSize: discard.NewGauge(), + QueueDroppedMsgs: discard.NewCounter(), } } diff --git a/internal/p2p/metrics.go b/internal/p2p/metrics.go index bc9678414..41513d032 100644 --- a/internal/p2p/metrics.go +++ b/internal/p2p/metrics.go @@ -52,15 +52,9 @@ type Metrics struct { //metrics:The time taken to send on a p2p channel's queue which will later be consued by the corresponding reactor/service. RouterChannelQueueSend metrics.Histogram - // PeerQueueDroppedMsgs defines the number of messages dropped from a peer's - // queue for a specific flow (i.e. Channel). - //metrics:The number of messages dropped from a peer's queue for a specific p2p Channel. - PeerQueueDroppedMsgs metrics.Counter `metrics_labels:"ch_id" metrics_name:"router_channel_queue_dropped_msgs"` - - // PeerQueueMsgSize defines the average size of messages sent over a peer's - // queue for a specific flow (i.e. Channel). - //metrics:The size of messages sent over a peer's queue for a specific p2p Channel. - PeerQueueMsgSize metrics.Gauge `metrics_labels:"ch_id" metric_name:"router_channel_queue_msg_size"` + // QueueDroppedMsgs counts the messages dropped from the router's queues. + //metrics:The number of messages dropped from router's queues. 
+ QueueDroppedMsgs metrics.Counter `metrics_labels:"ch_id, direction"` } type metricsLabelCache struct { diff --git a/internal/p2p/router.go b/internal/p2p/router.go index 737cbec8a..e20a70445 100644 --- a/internal/p2p/router.go +++ b/internal/p2p/router.go @@ -349,7 +349,9 @@ func (r *Router) routeChannel( } // send message to peers for _, q := range queues { - q.Send(envelope, chDesc.Priority) + if pruned, ok := q.Send(envelope, chDesc.Priority).Get(); ok { + r.metrics.QueueDroppedMsgs.With("ch_id", string(pruned.ChannelID), "direction", "out").Add(float64(1)) + } } } } @@ -786,7 +788,9 @@ func (r *Router) receivePeer(ctx context.Context, peerID types.NodeID, conn Conn start := time.Now().UTC() // Priority is not used since all messages in this queue are from the same channel. - queue.Send(Envelope{From: peerID, Message: msg, ChannelID: chID}, 0) + if pruned, ok := queue.Send(Envelope{From: peerID, Message: msg, ChannelID: chID}, 0).Get(); ok { + r.metrics.QueueDroppedMsgs.With("ch_id", string(pruned.ChannelID), "direction", "in").Add(float64(1)) + } r.metrics.PeerReceiveBytesTotal.With( "chID", fmt.Sprint(chID), "peer_id", string(peerID), diff --git a/internal/p2p/rqueue.go b/internal/p2p/rqueue.go index f1000cc32..00cefc4d8 100644 --- a/internal/p2p/rqueue.go +++ b/internal/p2p/rqueue.go @@ -110,14 +110,16 @@ func newInner(capacity int) *inner { func (i *inner) Len() int { return i.byMin.Len() } -func (i *inner) Push(e *pqEnvelope) { +func (i *inner) Push(e *pqEnvelope) utils.Option[Envelope] { w := newWithIdx(e) heap.Push(&i.byMin, w) heap.Push(&i.byMax, w) if i.byMin.Len() > i.capacity { w := heap.Pop(&i.byMax).(*withIdx[*pqEnvelope]) heap.Remove(&i.byMin, w.minIdx) + return utils.Some(w.v.envelope) } + return utils.None[Envelope]() } func (i *inner) Pop() *pqEnvelope { @@ -144,7 +146,8 @@ func (q *Queue) Len() int { } // Non-blocking send. -func (q *Queue) Send(e Envelope, priority int) { +// Returns the pruned message if any. +func (q *Queue) Send(e Envelope, priority int) utils.Option[Envelope] { // We construct the pqEnvelope without holding the lock to avoid contention. pqe := &pqEnvelope{ envelope: e, @@ -153,9 +156,11 @@ func (q *Queue) Send(e Envelope, priority int) { timestamp: time.Now().UTC(), } for inner, ctrl := range q.inner.Lock() { - inner.Push(pqe) + pruned := inner.Push(pqe) ctrl.Updated() + return pruned } + panic("unreachable") } // Blocking recv. From 4f3f0564804e8dcef3dd57e27dfd7941c56fd7c6 Mon Sep 17 00:00:00 2001 From: Grzegorz Prusak Date: Wed, 27 Aug 2025 11:18:28 +0200 Subject: [PATCH 25/41] fix --- internal/p2p/router.go | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/internal/p2p/router.go b/internal/p2p/router.go index e20a70445..872e896d2 100644 --- a/internal/p2p/router.go +++ b/internal/p2p/router.go @@ -350,7 +350,7 @@ func (r *Router) routeChannel( // send message to peers for _, q := range queues { if pruned, ok := q.Send(envelope, chDesc.Priority).Get(); ok { - r.metrics.QueueDroppedMsgs.With("ch_id", string(pruned.ChannelID), "direction", "out").Add(float64(1)) + r.metrics.QueueDroppedMsgs.With("ch_id", fmt.Sprint(pruned.ChannelID), "direction", "out").Add(float64(1)) } } } @@ -786,16 +786,14 @@ func (r *Router) receivePeer(ctx context.Context, peerID types.NodeID, conn Conn } } - start := time.Now().UTC() // Priority is not used since all messages in this queue are from the same channel. 
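The Option-returning Send below is what lets the router count drops without an extra callback. A rough standalone sketch of that pattern follows; the real type lives in libs/utils, so the struct layout here is an assumption and only Some/None/Get mirror the call sites in the diff.

package main

import "fmt"

type Option[T any] struct {
	value T
	ok    bool
}

func Some[T any](v T) Option[T] { return Option[T]{value: v, ok: true} }
func None[T any]() Option[T]    { return Option[T]{} }

// Get returns the value and whether it is present, mirroring the
// `if pruned, ok := q.Send(...).Get(); ok { ... }` call sites.
func (o Option[T]) Get() (T, bool) { return o.value, o.ok }

func main() {
	pruned := Some("envelope on channel 0x30")
	if v, ok := pruned.Get(); ok {
		fmt.Println("dropped:", v) // this is where the counter would be bumped
	}
	if _, ok := None[string]().Get(); !ok {
		fmt.Println("nothing was pruned")
	}
}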
if pruned, ok := queue.Send(Envelope{From: peerID, Message: msg, ChannelID: chID}, 0).Get(); ok { - r.metrics.QueueDroppedMsgs.With("ch_id", string(pruned.ChannelID), "direction", "in").Add(float64(1)) + r.metrics.QueueDroppedMsgs.With("ch_id", fmt.Sprint(pruned.ChannelID), "direction", "in").Add(float64(1)) } r.metrics.PeerReceiveBytesTotal.With( "chID", fmt.Sprint(chID), "peer_id", string(peerID), "message_type", r.lc.ValueToMetricLabel(msg)).Add(float64(proto.Size(msg))) - r.metrics.RouterChannelQueueSend.Observe(time.Since(start).Seconds()) r.logger.Debug("received message", "peer", peerID, "message", msg) } } From e8589df0f0f364a9660c212142e3d427ef4e6f8c Mon Sep 17 00:00:00 2001 From: Grzegorz Prusak Date: Thu, 28 Aug 2025 11:58:38 +0200 Subject: [PATCH 26/41] race condition fix --- internal/p2p/p2ptest/network.go | 2 +- internal/p2p/router.go | 143 ++++++++++++-------------------- internal/p2p/router_test.go | 2 +- 3 files changed, 54 insertions(+), 93 deletions(-) diff --git a/internal/p2p/p2ptest/network.go b/internal/p2p/p2ptest/network.go index a8de3ff63..1e2c1732d 100644 --- a/internal/p2p/p2ptest/network.go +++ b/internal/p2p/p2ptest/network.go @@ -274,7 +274,7 @@ func (n *Network) MakeNode(ctx context.Context, t *testing.T, opts NodeOptions) transport, ep, nil, - p2p.RouterOptions{DialSleep: func(_ context.Context) {}}, + p2p.RouterOptions{DialSleep: func(_ context.Context) error { return nil }}, ) require.NoError(t, err) diff --git a/internal/p2p/router.go b/internal/p2p/router.go index 872e896d2..bee03d1ec 100644 --- a/internal/p2p/router.go +++ b/internal/p2p/router.go @@ -65,7 +65,7 @@ type RouterOptions struct { // sleeps between dialing peers. If not set, a default value // is used that sleeps for a (random) amount of time up to 3 // seconds between submitting each peer to be dialed. - DialSleep func(context.Context) + DialSleep func(context.Context) error // NumConcrruentDials controls how many parallel go routines // are used to dial peers. This defaults to the value of @@ -387,48 +387,28 @@ func (r *Router) filterPeersID(ctx context.Context, id types.NodeID) error { return r.options.FilterPeerByID(ctx, id) } -func (r *Router) dialSleep(ctx context.Context) { - if r.options.DialSleep == nil { - const ( - maxDialerInterval = 3000 - minDialerInterval = 250 - ) - - // nolint:gosec // G404: Use of weak random number generator - dur := time.Duration(rand.Int63n(maxDialerInterval-minDialerInterval+1) + minDialerInterval) - - timer := time.NewTimer(dur * time.Millisecond) - defer timer.Stop() - - select { - case <-ctx.Done(): - case <-timer.C: - } - - return +func (r *Router) dialSleep(ctx context.Context) error { + if r.options.DialSleep != nil { + return r.options.DialSleep(ctx) } + const ( + maxDialerInterval = 3000 + minDialerInterval = 250 + ) - r.options.DialSleep(ctx) + // nolint:gosec // G404: Use of weak random number generator + dur := time.Duration(rand.Int63n(maxDialerInterval-minDialerInterval+1) + minDialerInterval) + return utils.Sleep(ctx, dur*time.Millisecond) } // acceptPeers accepts inbound connections from peers on the given transport, // and spawns goroutines that route messages to/from them. 
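dialSleep now delegates to a context-aware sleep and simply propagates its error. Assuming utils.Sleep behaves like the usual timer/ctx select (which is not spelled out in the diff), a standalone equivalent looks like this:

package main

import (
	"context"
	"fmt"
	"time"
)

// sleep waits for d or returns early with ctx.Err() if the context is canceled.
func sleep(ctx context.Context, d time.Duration) error {
	timer := time.NewTimer(d)
	defer timer.Stop()
	select {
	case <-timer.C:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()
	// The sleep outlives the context deadline, so an error is returned.
	fmt.Println(sleep(ctx, time.Second)) // prints "context deadline exceeded"
}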
-func (r *Router) acceptPeers(ctx context.Context, transport Transport) { +func (r *Router) acceptPeers(ctx context.Context, transport Transport) error { for { conn, err := transport.Accept(ctx) - switch { - case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded): - r.logger.Debug("stopping accept routine", "transport", transport, "err", "context canceled") - return - case errors.Is(err, io.EOF): - r.logger.Debug("stopping accept routine", "transport", transport, "err", "EOF") - return - case err != nil: - // in this case we got an error from the net.Listener. - r.logger.Error("failed to accept connection", "transport", transport, "err", err) - continue + if err != nil { + return fmt.Errorf("failed to accept connection: %w", err) } - incomingIP := conn.RemoteEndpoint().IP if err := r.connTracker.AddConn(incomingIP); err != nil { closeErr := conn.Close() @@ -496,59 +476,44 @@ func (r *Router) openConnection(ctx context.Context, conn Connection) error { } // dialPeers maintains outbound connections to peers by dialing them. -func (r *Router) dialPeers(ctx context.Context) { - addresses := make(chan NodeAddress) - wg := &sync.WaitGroup{} - - // Start a limited number of goroutines to dial peers in - // parallel. the goal is to avoid starting an unbounded number - // of goroutines thereby spamming the network, but also being - // able to add peers at a reasonable pace, though the number - // is somewhat arbitrary. The action is further throttled by a - // sleep after sending to the addresses channel. - for i := 0; i < r.numConccurentDials(); i++ { - wg.Add(1) - go func() { - defer wg.Done() - - for { - select { - case <-ctx.Done(): - return - case address := <-addresses: +func (r *Router) dialPeers(ctx context.Context) error { + return scope.Run(ctx, func(ctx context.Context, s scope.Scope) error { + addresses := make(chan NodeAddress) + // Start a limited number of goroutines to dial peers in + // parallel. the goal is to avoid starting an unbounded number + // of goroutines thereby spamming the network, but also being + // able to add peers at a reasonable pace, though the number + // is somewhat arbitrary. The action is further throttled by a + // sleep after sending to the addresses channel. + for range r.numConccurentDials() { + s.Spawn(func() error { + for { + address, err := utils.Recv(ctx, addresses) + if err != nil { + return err + } r.logger.Debug(fmt.Sprintf("Going to dial next peer %s", address.NodeID)) r.connectPeer(ctx, address) } - } - }() - } - -LOOP: - for { - address, err := r.peerManager.DialNext(ctx) - switch { - case errors.Is(err, context.Canceled): - break LOOP - case err != nil: - r.logger.Error("failed to find next peer to dial", "err", err) - break LOOP + }) } - select { - case addresses <- address: + for { + address, err := r.peerManager.DialNext(ctx) + if err != nil { + return fmt.Errorf("failed to find next peer to dial: %w", err) + } + if err := utils.Send(ctx, addresses, address); err != nil { + return err + } // this jitters the frequency that we call // DialNext and prevents us from attempting to // create connections too quickly. 
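The rewritten dialPeers below is structured concurrency: a fixed pool of dialer goroutines drains an unbuffered address channel and the whole group unwinds on the first error. scope is an internal helper, so its exact semantics are assumed here; as a rough standalone analogue, the same shape with errgroup:

package main

import (
	"context"
	"fmt"

	"golang.org/x/sync/errgroup"
)

func main() {
	g, ctx := errgroup.WithContext(context.Background())
	addresses := make(chan string) // unbuffered, like the dial queue

	// Fixed-size pool of dialer workers.
	for range 3 {
		g.Go(func() error {
			for {
				select {
				case <-ctx.Done():
					return ctx.Err()
				case addr, ok := <-addresses:
					if !ok {
						return nil
					}
					fmt.Println("dialing", addr)
				}
			}
		})
	}

	// Producer: feeds addresses and closes the channel when done.
	g.Go(func() error {
		defer close(addresses)
		for _, addr := range []string{"1.2.3.4:26656", "5.6.7.8:26656"} {
			select {
			case <-ctx.Done():
				return ctx.Err()
			case addresses <- addr:
			}
		}
		return nil
	})

	if err := g.Wait(); err != nil {
		fmt.Println("dial loop stopped:", err)
	}
}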
- - r.dialSleep(ctx) - continue - case <-ctx.Done(): - close(addresses) - break LOOP + if err := r.dialSleep(ctx); err != nil { + return err + } } - } - - wg.Wait() + }) } func (r *Router) connectPeer(ctx context.Context, address NodeAddress) { @@ -828,16 +793,11 @@ func (r *Router) sendPeer(ctx context.Context, peerID types.NodeID, conn Connect } // evictPeers evicts connected peers as requested by the peer manager. -func (r *Router) evictPeers(ctx context.Context) { +func (r *Router) evictPeers(ctx context.Context) error { for { peerID, err := r.peerManager.EvictNext(ctx) - - switch { - case errors.Is(err, context.Canceled): - return - case err != nil: - r.logger.Error("failed to find next peer to evict", "err", err) - return + if err != nil { + return fmt.Errorf("failed to find next peer to evict: %w", err) } r.logger.Info("evicting peer", "peer", peerID) @@ -870,11 +830,12 @@ func (r *Router) OnStart(ctx context.Context) error { } } - go r.dialPeers(ctx) - go r.evictPeers(ctx) - go r.acceptPeers(ctx, r.transport) - - return nil + return scope.Run(ctx, func(ctx context.Context, s scope.Scope) error { + s.SpawnNamed("dialPeers", func() error { return r.dialPeers(ctx) }) + s.SpawnNamed("evictPeers", func() error { return r.evictPeers(ctx) }) + s.SpawnNamed("acceptPeers", func() error { return r.acceptPeers(ctx, r.transport) }) + return nil + }) } // OnStop implements service.Service. diff --git a/internal/p2p/router_test.go b/internal/p2p/router_test.go index efb153714..caf60f8e5 100644 --- a/internal/p2p/router_test.go +++ b/internal/p2p/router_test.go @@ -698,7 +698,7 @@ func TestRouter_DialPeers_Parallel(t *testing.T) { nil, nil, p2p.RouterOptions{ - DialSleep: func(_ context.Context) {}, + DialSleep: func(_ context.Context) error { return nil }, NumConcurrentDials: func() int { ncpu := runtime.NumCPU() if ncpu <= 3 { From 7592ac8152f488e7c61e8d5064ccca26830f3962 Mon Sep 17 00:00:00 2001 From: Grzegorz Prusak Date: Thu, 28 Aug 2025 12:02:11 +0200 Subject: [PATCH 27/41] fix --- internal/p2p/router.go | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/internal/p2p/router.go b/internal/p2p/router.go index bee03d1ec..cdf415e8a 100644 --- a/internal/p2p/router.go +++ b/internal/p2p/router.go @@ -830,12 +830,10 @@ func (r *Router) OnStart(ctx context.Context) error { } } - return scope.Run(ctx, func(ctx context.Context, s scope.Scope) error { - s.SpawnNamed("dialPeers", func() error { return r.dialPeers(ctx) }) - s.SpawnNamed("evictPeers", func() error { return r.evictPeers(ctx) }) - s.SpawnNamed("acceptPeers", func() error { return r.acceptPeers(ctx, r.transport) }) - return nil - }) + r.Spawn("dialPeers", func(ctx context.Context) error { return r.dialPeers(ctx) }) + r.Spawn("evictPeers", func(ctx context.Context) error { return r.evictPeers(ctx) }) + r.Spawn("acceptPeers", func(ctx context.Context) error { return r.acceptPeers(ctx, r.transport) }) + return nil } // OnStop implements service.Service. 
From 4f0b9375f121e9d146479f40a5cc832f5d783edb Mon Sep 17 00:00:00 2001 From: Grzegorz Prusak Date: Fri, 29 Aug 2025 12:18:17 +0200 Subject: [PATCH 28/41] snapshot --- internal/p2p/address.go | 6 +- internal/p2p/conn/connection.go | 42 +++---- internal/p2p/conn/connection_test.go | 50 ++++---- internal/p2p/conn_tracker.go | 22 ++-- internal/p2p/metrics.gen.go | 7 ++ internal/p2p/metrics.go | 2 + internal/p2p/mocks/transport.go | 74 ++--------- internal/p2p/p2ptest/network.go | 2 - internal/p2p/peermanager.go | 43 ++++--- internal/p2p/router.go | 52 ++++---- internal/p2p/transport.go | 64 ++++------ internal/p2p/transport_mconn.go | 177 ++++++++------------------- internal/p2p/transport_mconn_test.go | 61 ++------- internal/p2p/transport_memory.go | 37 ++---- internal/p2p/transport_test.go | 9 +- libs/service/service.go | 16 +++ libs/utils/tcp/tcp.go | 79 ++++++++++++ node/node.go | 6 +- types/node_info.go | 41 ++----- 19 files changed, 336 insertions(+), 454 deletions(-) create mode 100644 libs/utils/tcp/tcp.go diff --git a/internal/p2p/address.go b/internal/p2p/address.go index 0f4066faf..9a034a4ad 100644 --- a/internal/p2p/address.go +++ b/internal/p2p/address.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net" + "net/netip" "net/url" "regexp" "strconv" @@ -121,10 +122,11 @@ func (a NodeAddress) Resolve(ctx context.Context) ([]*Endpoint, error) { } endpoints := make([]*Endpoint, len(ips)) for i, ip := range ips { + ip,ok := netip.AddrFromSlice(ip) + if !ok { return nil, fmt.Errorf("LookupIP returned invalid IP %q", ip) } endpoints[i] = &Endpoint{ Protocol: a.Protocol, - IP: ip, - Port: a.Port, + Addr: netip.AddrPortFrom(ip, a.Port), Path: a.Path, } } diff --git a/internal/p2p/conn/connection.go b/internal/p2p/conn/connection.go index 4bd85d6e2..8977dd6c1 100644 --- a/internal/p2p/conn/connection.go +++ b/internal/p2p/conn/connection.go @@ -19,6 +19,7 @@ import ( "github.com/tendermint/tendermint/internal/libs/flowrate" "github.com/tendermint/tendermint/internal/libs/protoio" "github.com/tendermint/tendermint/internal/libs/timer" + "github.com/tendermint/tendermint/libs/utils" "github.com/tendermint/tendermint/libs/log" tmmath "github.com/tendermint/tendermint/libs/math" "github.com/tendermint/tendermint/libs/service" @@ -302,9 +303,9 @@ func (c *MConnection) stopForError(ctx context.Context, r interface{}) { } // Queues a message to be sent to channel. -func (c *MConnection) Send(chID ChannelID, msgBytes []byte) bool { +func (c *MConnection) Send(ctx context.Context, chID ChannelID, msgBytes []byte) error { if !c.IsRunning() { - return false + return errors.New("not running") } c.logger.Debug("Send", "channel", chID, "conn", c, "msgBytes", msgBytes) @@ -312,21 +313,18 @@ func (c *MConnection) Send(chID ChannelID, msgBytes []byte) bool { // Send message to channel. channel, ok := c.channelsIdx[chID] if !ok { - c.logger.Error(fmt.Sprintf("Cannot send bytes, unknown channel %X", chID)) - return false + return fmt.Errorf("Cannot send bytes, unknown channel %X", chID) } - success := channel.sendBytes(msgBytes) - if success { - // Wake up sendRoutine if necessary - select { - case c.send <- struct{}{}: - default: - } - } else { - c.logger.Debug("Send failed", "channel", chID, "conn", c, "msgBytes", msgBytes) + if err := channel.sendBytes(ctx, msgBytes); err != nil { + return fmt.Errorf("channel.sendBytes(): %v",err) } - return success + // Wake up sendRoutine if necessary + select { + case c.send <- struct{}{}: + default: + } + return nil } // sendRoutine polls for packets to send from channels. 
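The Resolve change earlier in this patch converts net.LookupIP results into netip-based endpoints via netip.AddrFromSlice. A minimal standalone version of that conversion is below; the Unmap call is an extra precaution to keep IPv4 addresses in 4-byte form and is not something the diff itself does.

package main

import (
	"fmt"
	"net"
	"net/netip"
)

func resolve(host string, port uint16) ([]netip.AddrPort, error) {
	ips, err := net.LookupIP(host)
	if err != nil {
		return nil, err
	}
	out := make([]netip.AddrPort, 0, len(ips))
	for _, ip := range ips {
		addr, ok := netip.AddrFromSlice(ip)
		if !ok {
			return nil, fmt.Errorf("LookupIP returned invalid IP %q", ip)
		}
		// Unmap avoids carrying an IPv4-mapped IPv6 form around.
		out = append(out, netip.AddrPortFrom(addr.Unmap(), port))
	}
	return out, nil
}

func main() {
	eps, err := resolve("localhost", 26657)
	fmt.Println(eps, err)
}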
@@ -648,7 +646,6 @@ type channel struct { conn *MConnection desc ChannelDescriptor sendQueue chan []byte - sendQueueSize int32 // atomic. recving []byte sending []byte @@ -675,16 +672,10 @@ func newChannel(conn *MConnection, desc ChannelDescriptor) *channel { // Queues message to send to this channel. // Goroutine-safe // Times out (and returns false) after defaultSendTimeout -func (ch *channel) sendBytes(bytes []byte) bool { - timer := time.NewTimer(defaultSendTimeout) - defer timer.Stop() - select { - case ch.sendQueue <- bytes: - atomic.AddInt32(&ch.sendQueueSize, 1) - return true - case <-timer.C: - return false - } +func (ch *channel) sendBytes(ctx context.Context, bytes []byte) error { + ctx, cancel := context.WithTimeout(ctx, defaultSendTimeout) + defer cancel() + return utils.Send(ctx, ch.sendQueue, bytes) } // Returns true if any PacketMsgs are pending to be sent. @@ -709,7 +700,6 @@ func (ch *channel) nextPacketMsg() tmp2p.PacketMsg { if len(ch.sending) <= maxSize { packet.EOF = true ch.sending = nil - atomic.AddInt32(&ch.sendQueueSize, -1) // decrement sendQueueSize } else { packet.EOF = false ch.sending = ch.sending[tmmath.MinInt(maxSize, len(ch.sending)):] diff --git a/internal/p2p/conn/connection_test.go b/internal/p2p/conn/connection_test.go index 72e65a1a4..3e148a708 100644 --- a/internal/p2p/conn/connection_test.go +++ b/internal/p2p/conn/connection_test.go @@ -29,7 +29,7 @@ func createTestMConnection(logger log.Logger, conn net.Conn) *MConnection { func(ctx context.Context, chID ChannelID, msgBytes []byte) { }, // onError - func(ctx context.Context, r interface{}) { + func(ctx context.Context, r any) { }) } @@ -37,7 +37,7 @@ func createMConnectionWithCallbacks( logger log.Logger, conn net.Conn, onReceive func(ctx context.Context, chID ChannelID, msgBytes []byte), - onError func(ctx context.Context, r interface{}), + onError func(ctx context.Context, r any), ) *MConnection { cfg := DefaultMConnConfig() cfg.PingInterval = 250 * time.Millisecond @@ -59,7 +59,7 @@ func TestMConnectionSendFlushStop(t *testing.T) { t.Cleanup(waitAll(clientConn)) msg := []byte("abc") - assert.True(t, clientConn.Send(0x01, msg)) + assert.NoError(t, clientConn.Send(ctx, 0x01, msg)) msgLength := 14 @@ -95,7 +95,7 @@ func TestMConnectionSend(t *testing.T) { t.Cleanup(waitAll(mconn)) msg := []byte("Ant-Man") - assert.True(t, mconn.Send(0x01, msg)) + assert.NoError(t, mconn.Send(ctx, 0x01, msg)) // Note: subsequent Send/TrySend calls could pass because we are reading from // the send queue in a separate goroutine. 
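sendBytes now bounds the enqueue with a context deadline and a generic channel send. Assuming utils.Send and utils.Recv are the usual select-based helpers (the diff does not show their bodies), they reduce to roughly the following sketch:

package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// send blocks until the value is accepted or ctx is done.
func send[T any](ctx context.Context, ch chan<- T, v T) error {
	select {
	case ch <- v:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

// recv blocks until a value is available or ctx is done.
func recv[T any](ctx context.Context, ch <-chan T) (T, error) {
	select {
	case v := <-ch:
		return v, nil
	case <-ctx.Done():
		var zero T
		return zero, ctx.Err()
	}
}

func main() {
	queue := make(chan []byte) // unbuffered and nobody is receiving yet
	ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
	defer cancel()

	err := send(ctx, queue, []byte("msg"))
	fmt.Println(errors.Is(err, context.DeadlineExceeded)) // true: enqueue timed out

	go func() { _ = send(context.Background(), queue, []byte("msg")) }()
	v, err := recv(context.Background(), queue)
	fmt.Println(string(v), err) // msg <nil>
}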
_, err = server.Read(make([]byte, len(msg))) @@ -104,13 +104,13 @@ func TestMConnectionSend(t *testing.T) { } msg = []byte("Spider-Man") - assert.True(t, mconn.Send(0x01, msg)) + assert.NoError(t, mconn.Send(ctx, 0x01, msg)) _, err = server.Read(make([]byte, len(msg))) if err != nil { t.Error(err) } - assert.False(t, mconn.Send(0x05, []byte("Absorbing Man")), "Send should return false because channel is unknown") + assert.Error(t, mconn.Send(ctx, 0x05, []byte("Absorbing Man")), "Send should fail because channel is unknown") } func TestMConnectionReceive(t *testing.T) { @@ -118,14 +118,14 @@ func TestMConnectionReceive(t *testing.T) { t.Cleanup(closeAll(t, client, server)) receivedCh := make(chan []byte) - errorsCh := make(chan interface{}) + errorsCh := make(chan any) onReceive := func(ctx context.Context, chID ChannelID, msgBytes []byte) { select { case receivedCh <- msgBytes: case <-ctx.Done(): } } - onError := func(ctx context.Context, r interface{}) { + onError := func(ctx context.Context, r any) { select { case errorsCh <- r: case <-ctx.Done(): @@ -146,7 +146,7 @@ func TestMConnectionReceive(t *testing.T) { t.Cleanup(waitAll(mconn2)) msg := []byte("Cyclops") - assert.True(t, mconn2.Send(0x01, msg)) + assert.NoError(t, mconn2.Send(ctx, 0x01, msg)) select { case receivedBytes := <-receivedCh: @@ -203,14 +203,14 @@ func TestMConnectionMultiplePongsInTheBeginning(t *testing.T) { t.Cleanup(closeAll(t, client, server)) receivedCh := make(chan []byte) - errorsCh := make(chan interface{}) + errorsCh := make(chan any) onReceive := func(ctx context.Context, chID ChannelID, msgBytes []byte) { select { case receivedCh <- msgBytes: case <-ctx.Done(): } } - onError := func(ctx context.Context, r interface{}) { + onError := func(ctx context.Context, r any) { select { case errorsCh <- r: case <-ctx.Done(): @@ -261,14 +261,14 @@ func TestMConnectionMultiplePings(t *testing.T) { t.Cleanup(closeAll(t, client, server)) receivedCh := make(chan []byte) - errorsCh := make(chan interface{}) + errorsCh := make(chan any) onReceive := func(ctx context.Context, chID ChannelID, msgBytes []byte) { select { case receivedCh <- msgBytes: case <-ctx.Done(): } } - onError := func(ctx context.Context, r interface{}) { + onError := func(ctx context.Context, r any) { select { case errorsCh <- r: case <-ctx.Done(): @@ -316,14 +316,14 @@ func TestMConnectionPingPongs(t *testing.T) { t.Cleanup(closeAll(t, client, server)) receivedCh := make(chan []byte) - errorsCh := make(chan interface{}) + errorsCh := make(chan any) onReceive := func(ctx context.Context, chID ChannelID, msgBytes []byte) { select { case receivedCh <- msgBytes: case <-ctx.Done(): } } - onError := func(ctx context.Context, r interface{}) { + onError := func(ctx context.Context, r any) { select { case errorsCh <- r: case <-ctx.Done(): @@ -375,14 +375,14 @@ func TestMConnectionStopsAndReturnsError(t *testing.T) { t.Cleanup(closeAll(t, client, server)) receivedCh := make(chan []byte) - errorsCh := make(chan interface{}) + errorsCh := make(chan any) onReceive := func(ctx context.Context, chID ChannelID, msgBytes []byte) { select { case receivedCh <- msgBytes: case <-ctx.Done(): } } - onError := func(ctx context.Context, r interface{}) { + onError := func(ctx context.Context, r any) { select { case errorsCh <- r: case <-ctx.Done(): @@ -418,7 +418,7 @@ func newClientAndServerConnsForReadErrors( server, client := net.Pipe() onReceive := func(context.Context, ChannelID, []byte) {} - onError := func(context.Context, interface{}) {} + onError := func(context.Context, 
any) {} // create client conn with two channels chDescs := []*ChannelDescriptor{ @@ -434,7 +434,7 @@ func newClientAndServerConnsForReadErrors( // create server conn with 1 channel // it fires on chOnErr when there's an error serverLogger := logger.With("module", "server") - onError = func(ctx context.Context, r interface{}) { + onError = func(ctx context.Context, r any) { select { case <-ctx.Done(): case chOnErr <- struct{}{}: @@ -481,11 +481,11 @@ func TestMConnectionReadErrorUnknownChannel(t *testing.T) { msg := []byte("Ant-Man") // fail to send msg on channel unknown by client - assert.False(t, mconnClient.Send(0x03, msg)) + assert.Error(t, mconnClient.Send(ctx, 0x03, msg)) // send msg on channel unknown by the server. // should cause an error - assert.True(t, mconnClient.Send(0x02, msg)) + assert.NoError(t, mconnClient.Send(ctx, 0x02, msg)) assert.True(t, expectSend(chOnErr), "unknown channel") t.Cleanup(waitAll(mconnClient, mconnServer)) } @@ -557,15 +557,15 @@ func TestMConnectionTrySend(t *testing.T) { msg := []byte("Semicolon-Woman") resultCh := make(chan string, 2) - assert.True(t, mconn.Send(0x01, msg)) + assert.NoError(t, mconn.Send(ctx, 0x01, msg)) _, err = server.Read(make([]byte, len(msg))) require.NoError(t, err) - assert.True(t, mconn.Send(0x01, msg)) + assert.NoError(t, mconn.Send(ctx, 0x01, msg)) go func() { - mconn.Send(0x01, msg) + mconn.Send(ctx, 0x01, msg) resultCh <- "TrySend" }() - assert.False(t, mconn.Send(0x01, msg)) + assert.Error(t, mconn.Send(ctx, 0x01, msg)) assert.Equal(t, "TrySend", <-resultCh) } diff --git a/internal/p2p/conn_tracker.go b/internal/p2p/conn_tracker.go index 54f9c8980..385e734e9 100644 --- a/internal/p2p/conn_tracker.go +++ b/internal/p2p/conn_tracker.go @@ -2,20 +2,20 @@ package p2p import ( "fmt" - "net" + "net/netip" "sync" "time" ) type connectionTracker interface { - AddConn(net.IP) error - RemoveConn(net.IP) + AddConn(netip.AddrPort) error + RemoveConn(netip.AddrPort) Len() int } type connTrackerImpl struct { - cache map[string]uint - lastConnect map[string]time.Time + cache map[netip.Addr]uint + lastConnect map[netip.Addr]time.Time mutex sync.RWMutex max uint window time.Duration @@ -23,8 +23,8 @@ type connTrackerImpl struct { func newConnTracker(max uint, window time.Duration) connectionTracker { return &connTrackerImpl{ - cache: make(map[string]uint), - lastConnect: make(map[string]time.Time), + cache: map[netip.Addr]uint{}, + lastConnect: map[netip.Addr]time.Time{}, max: max, window: window, } @@ -36,8 +36,8 @@ func (rat *connTrackerImpl) Len() int { return len(rat.cache) } -func (rat *connTrackerImpl) AddConn(addr net.IP) error { - address := addr.String() +func (rat *connTrackerImpl) AddConn(addrPort netip.AddrPort) error { + address := addrPort.Addr() rat.mutex.Lock() defer rat.mutex.Unlock() @@ -58,8 +58,8 @@ func (rat *connTrackerImpl) AddConn(addr net.IP) error { return nil } -func (rat *connTrackerImpl) RemoveConn(addr net.IP) { - address := addr.String() +func (rat *connTrackerImpl) RemoveConn(addrPort netip.AddrPort) { + address := addrPort.Addr() rat.mutex.Lock() defer rat.mutex.Unlock() diff --git a/internal/p2p/metrics.gen.go b/internal/p2p/metrics.gen.go index d233d770a..d07febed5 100644 --- a/internal/p2p/metrics.gen.go +++ b/internal/p2p/metrics.gen.go @@ -44,6 +44,12 @@ func PrometheusMetrics(namespace string, labelsAndValues ...string) *Metrics { Name: "peer_pending_send_bytes", Help: "Number of bytes pending being sent to a given peer.", }, append(labels, "peer_id")).With(labelsAndValues...), + NewConnections: 
prometheus.NewCounterFrom(stdprometheus.CounterOpts{ + Namespace: namespace, + Subsystem: MetricsSubsystem, + Name: "new_connections", + Help: "Number of newly established connections.", + }, append(labels, "direction")).With(labelsAndValues...), RouterPeerQueueRecv: prometheus.NewHistogramFrom(stdprometheus.HistogramOpts{ Namespace: namespace, Subsystem: MetricsSubsystem, @@ -78,6 +84,7 @@ func NopMetrics() *Metrics { PeerReceiveBytesTotal: discard.NewCounter(), PeerSendBytesTotal: discard.NewCounter(), PeerPendingSendBytes: discard.NewGauge(), + NewConnections: discard.NewCounter(), RouterPeerQueueRecv: discard.NewHistogram(), RouterPeerQueueSend: discard.NewHistogram(), RouterChannelQueueSend: discard.NewHistogram(), diff --git a/internal/p2p/metrics.go b/internal/p2p/metrics.go index 41513d032..52fe5b1a3 100644 --- a/internal/p2p/metrics.go +++ b/internal/p2p/metrics.go @@ -36,6 +36,8 @@ type Metrics struct { PeerSendBytesTotal metrics.Counter `metrics_labels:"peer_id, chID, message_type"` // Number of bytes pending being sent to a given peer. PeerPendingSendBytes metrics.Gauge `metrics_labels:"peer_id"` + // Number of newly established connections. + NewConnections metrics.Counter `metrics_labels:"direction"` // RouterPeerQueueRecv defines the time taken to read off of a peer's queue // before sending on the connection. diff --git a/internal/p2p/mocks/transport.go b/internal/p2p/mocks/transport.go index cd9b7ae8c..7a789f8dd 100644 --- a/internal/p2p/mocks/transport.go +++ b/internal/p2p/mocks/transport.go @@ -52,24 +52,6 @@ func (_m *Transport) AddChannelDescriptors(_a0 []*conn.ChannelDescriptor) { _m.Called(_a0) } -// Close provides a mock function with no fields -func (_m *Transport) Close() error { - ret := _m.Called() - - if len(ret) == 0 { - panic("no return value specified for Close") - } - - var r0 error - if rf, ok := ret.Get(0).(func() error); ok { - r0 = rf() - } else { - r0 = ret.Error(0) - } - - return r0 -} - // Dial provides a mock function with given fields: _a0, _a1 func (_m *Transport) Dial(_a0 context.Context, _a1 *p2p.Endpoint) (p2p.Connection, error) { ret := _m.Called(_a0, _a1) @@ -100,47 +82,37 @@ func (_m *Transport) Dial(_a0 context.Context, _a1 *p2p.Endpoint) (p2p.Connectio return r0, r1 } -// Endpoint provides a mock function with no fields -func (_m *Transport) Endpoint() (*p2p.Endpoint, error) { +// Protocols provides a mock function with no fields +func (_m *Transport) Protocols() []p2p.Protocol { ret := _m.Called() if len(ret) == 0 { - panic("no return value specified for Endpoint") + panic("no return value specified for Protocols") } - var r0 *p2p.Endpoint - var r1 error - if rf, ok := ret.Get(0).(func() (*p2p.Endpoint, error)); ok { - return rf() - } - if rf, ok := ret.Get(0).(func() *p2p.Endpoint); ok { + var r0 []p2p.Protocol + if rf, ok := ret.Get(0).(func() []p2p.Protocol); ok { r0 = rf() } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*p2p.Endpoint) + r0 = ret.Get(0).([]p2p.Protocol) } } - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() - } else { - r1 = ret.Error(1) - } - - return r0, r1 + return r0 } -// Listen provides a mock function with given fields: _a0 -func (_m *Transport) Listen(_a0 *p2p.Endpoint) error { - ret := _m.Called(_a0) +// Run provides a mock function with given fields: ctx, endpoint +func (_m *Transport) Run(ctx context.Context, endpoint *p2p.Endpoint) error { + ret := _m.Called(ctx, endpoint) if len(ret) == 0 { - panic("no return value specified for Listen") + panic("no return value specified for Run") } var r0 
error - if rf, ok := ret.Get(0).(func(*p2p.Endpoint) error); ok { - r0 = rf(_a0) + if rf, ok := ret.Get(0).(func(context.Context, *p2p.Endpoint) error); ok { + r0 = rf(ctx, endpoint) } else { r0 = ret.Error(0) } @@ -148,26 +120,6 @@ func (_m *Transport) Listen(_a0 *p2p.Endpoint) error { return r0 } -// Protocols provides a mock function with no fields -func (_m *Transport) Protocols() []p2p.Protocol { - ret := _m.Called() - - if len(ret) == 0 { - panic("no return value specified for Protocols") - } - - var r0 []p2p.Protocol - if rf, ok := ret.Get(0).(func() []p2p.Protocol); ok { - r0 = rf() - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]p2p.Protocol) - } - } - - return r0 -} - // String provides a mock function with no fields func (_m *Transport) String() string { ret := _m.Called() diff --git a/internal/p2p/p2ptest/network.go b/internal/p2p/p2ptest/network.go index 1e2c1732d..940394e01 100644 --- a/internal/p2p/p2ptest/network.go +++ b/internal/p2p/p2ptest/network.go @@ -203,7 +203,6 @@ func (n *Network) Remove(ctx context.Context, t *testing.T, id types.NodeID) { subs = append(subs, sub) } - require.NoError(t, node.Transport.Close()) node.cancel() if node.Router.IsRunning() { node.Router.Stop() @@ -285,7 +284,6 @@ func (n *Network) MakeNode(ctx context.Context, t *testing.T, opts NodeOptions) router.Stop() router.Wait() } - require.NoError(t, transport.Close()) cancel() }) diff --git a/internal/p2p/peermanager.go b/internal/p2p/peermanager.go index 9070afb90..45d0d8e70 100644 --- a/internal/p2p/peermanager.go +++ b/internal/p2p/peermanager.go @@ -20,6 +20,7 @@ import ( tmsync "github.com/tendermint/tendermint/internal/libs/sync" p2pproto "github.com/tendermint/tendermint/proto/tendermint/p2p" "github.com/tendermint/tendermint/types" + "github.com/tendermint/tendermint/libs/utils" ) const ( @@ -315,7 +316,7 @@ type PeerManager struct { upgrading map[types.NodeID]types.NodeID // peers claimed for upgrade (DialNext → Dialed/DialFail) connected map[types.NodeID]bool // connected peers (Dialed/Accepted → Disconnected) ready map[types.NodeID]bool // ready peers (Ready → Disconnected) - evict map[types.NodeID]bool // peers scheduled for eviction (Connected → EvictNext) + evict map[types.NodeID]error // peers scheduled for eviction (Connected → EvictNext) evicting map[types.NodeID]bool // peers being evicted (EvictNext → Disconnected) metrics *Metrics } @@ -355,7 +356,7 @@ func NewPeerManager( upgrading: map[types.NodeID]types.NodeID{}, connected: map[types.NodeID]bool{}, ready: map[types.NodeID]bool{}, - evict: map[types.NodeID]bool{}, + evict: map[types.NodeID]error{}, evicting: map[types.NodeID]bool{}, subscriptions: map[*PeerUpdates]*PeerUpdates{}, metrics: metrics, @@ -689,7 +690,7 @@ func (m *PeerManager) Dialed(address NodeAddress) error { upgradeFromPeer = u } } - m.evict[upgradeFromPeer] = true + m.evict[upgradeFromPeer] = errors.New("too many peers") } m.connected[peer.ID] = true m.evictWaker.Wake() @@ -758,7 +759,7 @@ func (m *PeerManager) Accepted(peerID types.NodeID) error { m.connected[peerID] = true if upgradeFromPeer != "" { - m.evict[upgradeFromPeer] = true + m.evict[upgradeFromPeer] = errors.New("found better peer") } m.evictWaker.Wake() return nil @@ -787,40 +788,48 @@ func (m *PeerManager) Ready(ctx context.Context, peerID types.NodeID, channels C // EvictNext returns the next peer to evict (i.e. disconnect). If no evictable // peers are found, the call will block until one becomes available. 
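That blocking behaviour — try, then sleep until woken or canceled — together with the new map of scheduled evictions keyed by cause can be condensed into a small standalone sketch. The names below are made up; only the shape (waker channel, try/sleep loop, eviction carrying a reason) mirrors the peer manager changes that follow.

package main

import (
	"context"
	"errors"
	"fmt"
	"sync"
	"time"
)

type eviction struct {
	ID    string
	Cause error
}

type evictSet struct {
	mu    sync.Mutex
	evict map[string]error // peers scheduled for eviction, with the reason
	wake  chan struct{}
}

func newEvictSet() *evictSet {
	return &evictSet{evict: map[string]error{}, wake: make(chan struct{}, 1)}
}

func (s *evictSet) Schedule(id string, cause error) {
	s.mu.Lock()
	s.evict[id] = cause
	s.mu.Unlock()
	select {
	case s.wake <- struct{}{}: // wake a sleeping Next, if any
	default:
	}
}

func (s *evictSet) tryNext() (eviction, bool) {
	s.mu.Lock()
	defer s.mu.Unlock()
	for id, cause := range s.evict {
		delete(s.evict, id)
		return eviction{ID: id, Cause: cause}, true
	}
	return eviction{}, false
}

// Next blocks until a peer is scheduled for eviction or ctx is canceled.
func (s *evictSet) Next(ctx context.Context) (eviction, error) {
	for {
		if ev, ok := s.tryNext(); ok {
			return ev, nil
		}
		select {
		case <-s.wake:
		case <-ctx.Done():
			return eviction{}, ctx.Err()
		}
	}
}

func main() {
	s := newEvictSet()
	go func() {
		time.Sleep(10 * time.Millisecond)
		s.Schedule("peer1", errors.New("too many peers"))
	}()
	ev, err := s.Next(context.Background())
	fmt.Println(ev.ID, ev.Cause, err) // peer1 too many peers <nil>
}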
-func (m *PeerManager) EvictNext(ctx context.Context) (types.NodeID, error) { +func (m *PeerManager) EvictNext(ctx context.Context) (Eviction, error) { for { - id, err := m.TryEvictNext() - if err != nil || id != "" { - return id, err + ev, err := m.TryEvictNext() + if err != nil { + return Eviction{}, err + } + if ev,ok := ev.Get(); ok { + return ev,nil } select { case <-m.evictWaker.Sleep(): case <-ctx.Done(): - return "", ctx.Err() + return Eviction{}, ctx.Err() } } } +type Eviction struct { + ID types.NodeID + Cause error +} + // TryEvictNext is equivalent to EvictNext, but immediately returns an empty // node ID if no evictable peers are found. -func (m *PeerManager) TryEvictNext() (types.NodeID, error) { +func (m *PeerManager) TryEvictNext() (utils.Option[Eviction], error) { m.mtx.Lock() defer m.mtx.Unlock() // If any connected peers are explicitly scheduled for eviction, we return a // random one. - for peerID := range m.evict { + for peerID,cause := range m.evict { delete(m.evict, peerID) if m.connected[peerID] && !m.evicting[peerID] { m.evicting[peerID] = true - return peerID, nil + return utils.Some(Eviction{peerID,cause}), nil } } // If we're below capacity, we don't need to evict anything. if m.options.MaxConnected == 0 || m.NumConnected()-len(m.evicting) <= int(m.options.MaxConnected) { - return "", nil + return utils.None[Eviction](), nil } // If we're above capacity (shouldn't really happen), just pick the @@ -830,11 +839,11 @@ func (m *PeerManager) TryEvictNext() (types.NodeID, error) { peer := ranked[i] if m.connected[peer.ID] && !m.evicting[peer.ID] { m.evicting[peer.ID] = true - return peer.ID, nil + return utils.Some(Eviction{peer.ID,errors.New("too many peers")}), nil } } - return "", nil + return utils.None[Eviction](), nil } // Disconnected unmarks a peer as connected, allowing it to be dialed or @@ -888,7 +897,7 @@ func (m *PeerManager) Errored(peerID types.NodeID, err error) { defer m.mtx.Unlock() if m.connected[peerID] { - m.evict[peerID] = true + m.evict[peerID] = err } m.evictWaker.Wake() @@ -1144,7 +1153,7 @@ func (m *PeerManager) findUpgradeCandidate(id types.NodeID, score PeerScore) typ case candidate.Score() >= score: return "" // no further peers can be scored lower, due to sorting case !m.connected[candidate.ID]: - case m.evict[candidate.ID]: + case m.evict[candidate.ID]!=nil: case m.evicting[candidate.ID]: case m.upgrading[candidate.ID] != "": default: diff --git a/internal/p2p/router.go b/internal/p2p/router.go index cdf415e8a..6f67e06f4 100644 --- a/internal/p2p/router.go +++ b/internal/p2p/router.go @@ -6,7 +6,7 @@ import ( "fmt" "io" "math/rand" - "net" + "net/netip" "runtime" "strings" "sync" @@ -51,7 +51,7 @@ type RouterOptions struct { // the remote IP of the incoming connection the port number as // arguments. Functions should return an error to reject the // peer. - FilterPeerByIP func(context.Context, net.IP, uint16) error + FilterPeerByIP func(context.Context, netip.AddrPort) error // FilterPeerByID is used by the router to inject filtering // behavior for new incoming connections. 
The router passes @@ -364,12 +364,12 @@ func (r *Router) numConccurentDials() int { return r.options.NumConcurrentDials() } -func (r *Router) filterPeersIP(ctx context.Context, ip net.IP, port uint16) error { +func (r *Router) filterPeersIP(ctx context.Context, addrPort netip.AddrPort) error { if r.options.FilterPeerByIP == nil { return nil } - return r.options.FilterPeerByIP(ctx, ip, port) + return r.options.FilterPeerByIP(ctx, addrPort) } func (r *Router) filterPeersID(ctx context.Context, id types.NodeID) error { @@ -409,12 +409,13 @@ func (r *Router) acceptPeers(ctx context.Context, transport Transport) error { if err != nil { return fmt.Errorf("failed to accept connection: %w", err) } - incomingIP := conn.RemoteEndpoint().IP - if err := r.connTracker.AddConn(incomingIP); err != nil { + r.metrics.NewConnections.With("direction","in").Add(1) + incomingAddr := conn.RemoteEndpoint().Addr + if err := r.connTracker.AddConn(incomingAddr); err != nil { closeErr := conn.Close() r.logger.Error("rate limiting incoming peer", "err", err, - "ip", incomingIP.String(), + "addr", incomingAddr.String(), "close_err", closeErr, ) @@ -428,13 +429,12 @@ func (r *Router) acceptPeers(ctx context.Context, transport Transport) error { func (r *Router) openConnection(ctx context.Context, conn Connection) error { defer conn.Close() - defer r.connTracker.RemoveConn(conn.RemoteEndpoint().IP) + incomingAddr := conn.RemoteEndpoint().Addr + defer r.connTracker.RemoveConn(incomingAddr) - re := conn.RemoteEndpoint() - incomingIP := re.IP - if err := r.filterPeersIP(ctx, incomingIP, re.Port); err != nil { - r.logger.Debug("peer filtered by IP", "ip", incomingIP.String(), "err", err) + if err := r.filterPeersIP(ctx, incomingAddr); err != nil { + r.logger.Debug("peer filtered by IP", "ip", incomingAddr, "err", err) return nil } @@ -609,6 +609,7 @@ func (r *Router) dialPeer(ctx context.Context, address NodeAddress) (Connection, if err != nil { r.logger.Debug("failed to dial endpoint", "peer", address.NodeID, "endpoint", endpoint, "err", err) } else { + r.metrics.NewConnections.With("direction","out").Add(1) r.logger.Debug("dialed peer", "peer", address.NodeID, "endpoint", endpoint) return conn, nil } @@ -795,14 +796,13 @@ func (r *Router) sendPeer(ctx context.Context, peerID types.NodeID, conn Connect // evictPeers evicts connected peers as requested by the peer manager. func (r *Router) evictPeers(ctx context.Context) error { for { - peerID, err := r.peerManager.EvictNext(ctx) + ev, err := r.peerManager.EvictNext(ctx) if err != nil { return fmt.Errorf("failed to find next peer to evict: %w", err) } - - r.logger.Info("evicting peer", "peer", peerID) + r.logger.Info("evicting peer", "peer", ev.ID,"cause",ev.Cause) for states := range r.peerStates.Lock() { - if s, ok := states[peerID]; ok { + if s, ok := states[ev.ID]; ok { s.cancel() } } @@ -818,10 +818,6 @@ func (r *Router) AddChDescToBeAdded(chDesc *ChannelDescriptor, callback func(*Ch // OnStart implements service.Service. 
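With FilterPeerByIP now taking a single netip.AddrPort (see the RouterOptions change earlier in this hunk), a callback can work directly with the netip predicates. The policy below is invented purely for illustration; only the signature mirrors the new option.

package main

import (
	"context"
	"fmt"
	"net/netip"
)

// filterPeerByIP rejects loopback and private ranges; any real policy would
// be supplied by the node operator or the ABCI filter query.
func filterPeerByIP(_ context.Context, ap netip.AddrPort) error {
	addr := ap.Addr().Unmap()
	if addr.IsLoopback() || addr.IsPrivate() {
		return fmt.Errorf("refusing to peer with non-public address %v", ap)
	}
	return nil
}

func main() {
	fmt.Println(filterPeerByIP(context.Background(), netip.MustParseAddrPort("127.0.0.1:26656"))) // rejected
	fmt.Println(filterPeerByIP(context.Background(), netip.MustParseAddrPort("1.2.3.4:26656")))   // <nil>
}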
func (r *Router) OnStart(ctx context.Context) error { - if err := r.transport.Listen(r.endpoint); err != nil { - return err - } - for _, chDescWithCb := range r.chDescsToBeAdded { if ch, err := r.OpenChannel(chDescWithCb.chDesc); err != nil { return err @@ -830,9 +826,12 @@ func (r *Router) OnStart(ctx context.Context) error { } } - r.Spawn("dialPeers", func(ctx context.Context) error { return r.dialPeers(ctx) }) - r.Spawn("evictPeers", func(ctx context.Context) error { return r.evictPeers(ctx) }) - r.Spawn("acceptPeers", func(ctx context.Context) error { return r.acceptPeers(ctx, r.transport) }) + r.SpawnCritical("transport.Run",func(ctx context.Context) error { + return r.transport.Run(ctx, r.endpoint) + }) + r.SpawnCritical("dialPeers", func(ctx context.Context) error { return r.dialPeers(ctx) }) + r.SpawnCritical("evictPeers", func(ctx context.Context) error { return r.evictPeers(ctx) }) + r.SpawnCritical("acceptPeers", func(ctx context.Context) error { return r.acceptPeers(ctx, r.transport) }) return nil } @@ -842,12 +841,7 @@ func (r *Router) OnStart(ctx context.Context) error { // router, to prevent blocked channel sends in reactors. Channels are not closed // here, since that would cause any reactor senders to panic, so it is the // sender's responsibility. -func (r *Router) OnStop() { - // Close transport listeners (unblocks Accept calls). - if err := r.transport.Close(); err != nil { - r.logger.Error("failed to close transport", "err", err) - } -} +func (r *Router) OnStop() { } type ChannelIDSet map[ChannelID]struct{} diff --git a/internal/p2p/transport.go b/internal/p2p/transport.go index 7a965260a..ed8d4cb90 100644 --- a/internal/p2p/transport.go +++ b/internal/p2p/transport.go @@ -4,7 +4,7 @@ import ( "context" "errors" "fmt" - "net" + "net/netip" "github.com/tendermint/tendermint/crypto" "github.com/tendermint/tendermint/types" @@ -23,19 +23,12 @@ type Protocol string // Transport is a connection-oriented mechanism for exchanging data with a peer. type Transport interface { - // Listen starts the transport on the specified endpoint. - Listen(*Endpoint) error - + // Run executes the background tasks of transport. + Run(ctx context.Context, endpoint *Endpoint) error // Protocols returns the protocols supported by the transport. The Router // uses this to pick a transport for an Endpoint. Protocols() []Protocol - // Endpoints returns the local endpoints the transport is listening on, if any. - // - // How to listen is transport-dependent, e.g. MConnTransport uses Listen() while - // MemoryTransport starts listening via MemoryNetwork.CreateTransport(). - Endpoint() (*Endpoint, error) - // Accept waits for the next inbound connection on a listening endpoint, blocking // until either a connection is available or the transport is closed. On closure, // io.EOF is returned and further Accept calls are futile. @@ -44,9 +37,6 @@ type Transport interface { // Dial creates an outbound connection to an endpoint. Dial(context.Context, *Endpoint) (Connection, error) - // Close stops accepting new connections, but does not close active connections. - Close() error - // AddChannelDescriptors is only part of this interface // temporarily AddChannelDescriptors([]*ChannelDescriptor) @@ -115,14 +105,8 @@ type Connection interface { type Endpoint struct { // Protocol specifies the transport protocol. Protocol Protocol - - // IP is an IP address (v4 or v6) to connect to. If set, this defines the - // endpoint as a networked endpoint. - IP net.IP - - // Port is a network port (either TCP or UDP). 
If 0, a default port may be - // used depending on the protocol. - Port uint16 + // TCP endpoint address. + Addr netip.AddrPort // Path is an optional transport-specific path or identifier. Path string @@ -130,15 +114,14 @@ type Endpoint struct { // NewEndpoint constructs an Endpoint from a types.NetAddress structure. func NewEndpoint(addr string) (*Endpoint, error) { - ip, port, err := types.ParseAddressString(addr) + addrPort, err := types.ParseAddressString(addr) if err != nil { return nil, err } return &Endpoint{ Protocol: MConnProtocol, - IP: ip, - Port: port, + Addr: addrPort, }, nil } @@ -149,9 +132,9 @@ func (e Endpoint) NodeAddress(nodeID types.NodeID) NodeAddress { Protocol: e.Protocol, Path: e.Path, } - if len(e.IP) > 0 { - address.Hostname = e.IP.String() - address.Port = e.Port + if e.Addr!=(netip.AddrPort{}) { + address.Hostname = e.Addr.Addr().String() + address.Port = e.Addr.Port() } return address } @@ -161,7 +144,7 @@ func (e Endpoint) String() string { // If this is a non-networked endpoint with a valid node ID as a path, // assume that path is a node ID (to handle opaque URLs of the form // scheme:id). - if e.IP == nil { + if e.Addr == (netip.AddrPort{}) { if nodeID, err := types.NewNodeID(e.Path); err == nil { return e.NodeAddress(nodeID).String() } @@ -171,20 +154,17 @@ func (e Endpoint) String() string { // Validate validates the endpoint. func (e Endpoint) Validate() error { - switch { - case e.Protocol == "": + if e.Protocol == "" { return errors.New("endpoint has no protocol") - - case len(e.IP) > 0 && e.IP.To16() == nil: - return fmt.Errorf("invalid IP address %v", e.IP) - - case e.Port > 0 && len(e.IP) == 0: - return fmt.Errorf("endpoint has port %v but no IP", e.Port) - - case len(e.IP) == 0 && e.Path == "": - return errors.New("endpoint has neither path nor IP") - - default: - return nil } + if e.Addr==(netip.AddrPort{}) { + if e.Path == "" { + return errors.New("endpoint has neither path nor IP") + } + } else { + if !e.Addr.IsValid() { + return fmt.Errorf("endpoint has invalid address %q", e.Addr.String()) + } + } + return nil } diff --git a/internal/p2p/transport_mconn.go b/internal/p2p/transport_mconn.go index 3709fb58f..1aec41486 100644 --- a/internal/p2p/transport_mconn.go +++ b/internal/p2p/transport_mconn.go @@ -7,7 +7,7 @@ import ( "io" "math" "net" - "strconv" + "net/netip" "sync" "golang.org/x/net/netutil" @@ -16,6 +16,8 @@ import ( "github.com/tendermint/tendermint/internal/libs/protoio" "github.com/tendermint/tendermint/internal/p2p/conn" "github.com/tendermint/tendermint/libs/log" + "github.com/tendermint/tendermint/libs/utils" + "github.com/tendermint/tendermint/libs/utils/scope" p2pproto "github.com/tendermint/tendermint/proto/tendermint/p2p" "github.com/tendermint/tendermint/types" ) @@ -44,10 +46,7 @@ type MConnTransport struct { options MConnTransportOptions mConnConfig conn.MConnConfig channelDescs []*ChannelDescriptor - - closeOnce sync.Once - doneCh chan struct{} - listener net.Listener + listener chan *mConnConnection } // NewMConnTransport sets up a new MConnection transport. This uses the @@ -55,69 +54,30 @@ type MConnTransport struct { // conn.MConnection. 
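The rewritten Endpoint.Validate above distinguishes an endpoint with no address at all (the netip.AddrPort zero value) from one whose address is present but invalid. The two checks it relies on behave as follows; the values here are just for illustration.

package main

import (
	"fmt"
	"net/netip"
)

func main() {
	var zero netip.AddrPort
	fmt.Println(zero == netip.AddrPort{}) // true: no address set at all
	fmt.Println(zero.IsValid())           // false

	ap := netip.AddrPortFrom(netip.MustParseAddr("127.0.0.1"), 26656)
	fmt.Println(ap == netip.AddrPort{}) // false: an address is set
	fmt.Println(ap.IsValid())           // true
}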
func NewMConnTransport( logger log.Logger, + endpoint *Endpoint, mConnConfig conn.MConnConfig, channelDescs []*ChannelDescriptor, options MConnTransportOptions, ) *MConnTransport { return &MConnTransport{ logger: logger, + endpoint: endpoint, options: options, mConnConfig: mConnConfig, - doneCh: make(chan struct{}), channelDescs: channelDescs, + // This is rendezvous channel, so that no unclosed connections get stuck inside + // when transport is closing. + listener: make(chan *mConnConnection), } } -// String implements Transport. -func (m *MConnTransport) String() string { - return string(MConnProtocol) -} - -// Protocols implements Transport. We support tcp for backwards-compatibility. -func (m *MConnTransport) Protocols() []Protocol { - return []Protocol{MConnProtocol, TCPProtocol} -} - -// Endpoint implements Transport. -func (m *MConnTransport) Endpoint() (*Endpoint, error) { - if m.listener == nil { - return nil, errors.New("listenter not defined") - } - select { - case <-m.doneCh: - return nil, errors.New("transport closed") - default: - } - - endpoint := &Endpoint{ - Protocol: MConnProtocol, - } - if addr, ok := m.listener.Addr().(*net.TCPAddr); ok { - endpoint.IP = addr.IP - endpoint.Port = uint16(addr.Port) - } - return endpoint, nil -} - -// Listen asynchronously listens for inbound connections on the given endpoint. -// It must be called exactly once before calling Accept(), and the caller must -// call Close() to shut down the listener. -// -// FIXME: Listen currently only supports listening on a single endpoint, it -// might be useful to support listening on multiple addresses (e.g. IPv4 and -// IPv6, or a private and public address) via multiple Listen() calls. -func (m *MConnTransport) Listen(endpoint *Endpoint) error { - if m.listener != nil { - return errors.New("transport is already listening") - } - if err := m.validateEndpoint(endpoint); err != nil { +func (m *MConnTransport) Run(ctx context.Context) error { + if err := m.validateEndpoint(m.endpoint); err != nil { return err } - - listener, err := net.Listen("tcp", net.JoinHostPort( - endpoint.IP.String(), strconv.Itoa(int(endpoint.Port)))) + listener, err := net.Listen("tcp", endpoint.Addr.String()) if err != nil { - return err + return fmt.Errorf("net.Listen(): %w",err) } if m.options.MaxAcceptedConnections > 0 { // FIXME: This will establish the inbound connection but simply hang it @@ -127,46 +87,42 @@ func (m *MConnTransport) Listen(endpoint *Endpoint) error { // This was just carried over from the legacy P2P stack. listener = netutil.LimitListener(listener, int(m.options.MaxAcceptedConnections)) } - m.listener = listener + return scope.Run(ctx, func(ctx context.Context, s scope.Scope) error { + s.Spawn(func() error { + <-ctx.Done() + listener.Close() + return nil + }) + for { + conn, err := listener.Accept() + if err!=nil { + if errors.Is(err, io.EOF) { + return nil + } + return err + } + mconn := newMConnConnection(m.logger, conn, m.mConnConfig, m.channelDescs) + if err:=utils.Send(ctx, m.listener, mconn); err!=nil { + mconn.Close() + return err + } + } + }) +} - return nil +// String implements Transport. +func (m *MConnTransport) String() string { + return string(MConnProtocol) +} + +// Protocols implements Transport. We support tcp for backwards-compatibility. +func (m *MConnTransport) Protocols() []Protocol { + return []Protocol{MConnProtocol, TCPProtocol} } // Accept implements Transport. 
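The Run loop above owns the listener, hands accepted connections over an unbuffered (rendezvous) channel, and is unblocked on shutdown by closing the listener from a watcher goroutine. Stripped of the MConn specifics, the shape is the following standalone sketch; the address and error strings are illustrative.

package main

import (
	"context"
	"fmt"
	"net"
	"time"
)

// run accepts connections and forwards them over an unbuffered channel, so no
// accepted connection is left buffered when the loop stops.
func run(ctx context.Context, conns chan<- net.Conn) error {
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		return fmt.Errorf("net.Listen(): %w", err)
	}
	go func() {
		<-ctx.Done()
		ln.Close() // unblocks the Accept call below
	}()
	for {
		c, err := ln.Accept()
		if err != nil {
			if ctx.Err() != nil {
				return nil // listener was closed as part of shutdown
			}
			return err
		}
		select {
		case conns <- c:
		case <-ctx.Done():
			c.Close()
			return nil
		}
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()
	conns := make(chan net.Conn)
	fmt.Println(run(ctx, conns)) // <nil> once the deadline closes the listener
}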
func (m *MConnTransport) Accept(ctx context.Context) (Connection, error) { - if m.listener == nil { - return nil, errors.New("transport is not listening") - } - - conCh := make(chan net.Conn) - errCh := make(chan error) - go func() { - tcpConn, err := m.listener.Accept() - if err != nil { - select { - case errCh <- err: - case <-ctx.Done(): - } - } - select { - case conCh <- tcpConn: - case <-ctx.Done(): - } - }() - - select { - case <-ctx.Done(): - m.listener.Close() - return nil, io.EOF - case <-m.doneCh: - m.listener.Close() - return nil, io.EOF - case err := <-errCh: - return nil, err - case tcpConn := <-conCh: - return newMConnConnection(m.logger, tcpConn, m.mConnConfig, m.channelDescs), nil - } - + return utils.Recv(ctx, m.listener) } // Dial implements Transport. @@ -174,37 +130,17 @@ func (m *MConnTransport) Dial(ctx context.Context, endpoint *Endpoint) (Connecti if err := m.validateEndpoint(endpoint); err != nil { return nil, err } - if endpoint.Port == 0 { - endpoint.Port = 26657 + if endpoint.Addr.Port() == 0 { + endpoint.Addr = netip.AddrPortFrom(endpoint.Addr.Addr(),26657) } - dialer := net.Dialer{} - tcpConn, err := dialer.DialContext(ctx, "tcp", net.JoinHostPort( - endpoint.IP.String(), strconv.Itoa(int(endpoint.Port)))) + tcpConn, err := dialer.DialContext(ctx, "tcp", endpoint.Addr.String()) if err != nil { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - return nil, err - } + return nil,fmt.Errorf("dialer.DialContext(%v): %w", endpoint.Addr, err) } - return newMConnConnection(m.logger, tcpConn, m.mConnConfig, m.channelDescs), nil } -// Close implements Transport. -func (m *MConnTransport) Close() error { - var err error - m.closeOnce.Do(func() { - close(m.doneCh) - if m.listener != nil { - err = m.listener.Close() - } - }) - return err -} - // SetChannels sets the channel descriptors to be used when // establishing a connection. 
// @@ -224,8 +160,8 @@ func (m *MConnTransport) validateEndpoint(endpoint *Endpoint) error { if endpoint.Protocol != MConnProtocol && endpoint.Protocol != TCPProtocol { return fmt.Errorf("unsupported protocol %q", endpoint.Protocol) } - if len(endpoint.IP) == 0 { - return errors.New("endpoint has no IP address") + if !endpoint.Addr.IsValid() { + return errors.New("endpoint has invalid address") } if endpoint.Path != "" { return fmt.Errorf("endpoints with path not supported (got %q)", endpoint.Path) @@ -428,13 +364,10 @@ func (c *mConnConnection) SendMessage(ctx context.Context, chID ChannelID, msg [ select { case err := <-c.errorCh: return err - case <-ctx.Done(): - return io.EOF default: - if ok := c.mconn.Send(chID, msg); !ok { - return errors.New("sending message timed out") + if err := c.mconn.Send(ctx, chID, msg); err!=nil { + return fmt.Errorf("m.mconn.Send(%v): %w", chID, err) } - return nil } } @@ -459,8 +392,7 @@ func (c *mConnConnection) LocalEndpoint() Endpoint { Protocol: MConnProtocol, } if addr, ok := c.conn.LocalAddr().(*net.TCPAddr); ok { - endpoint.IP = addr.IP - endpoint.Port = uint16(addr.Port) + endpoint.Addr = addr.AddrPort() } return endpoint } @@ -471,8 +403,7 @@ func (c *mConnConnection) RemoteEndpoint() Endpoint { Protocol: MConnProtocol, } if addr, ok := c.conn.RemoteAddr().(*net.TCPAddr); ok { - endpoint.IP = addr.IP - endpoint.Port = uint16(addr.Port) + endpoint.Addr = addr.AddrPort() } return endpoint } diff --git a/internal/p2p/transport_mconn_test.go b/internal/p2p/transport_mconn_test.go index 18d7f4fb3..434e5997e 100644 --- a/internal/p2p/transport_mconn_test.go +++ b/internal/p2p/transport_mconn_test.go @@ -1,11 +1,12 @@ package p2p_test import ( - "io" + "net/netip" "net" "testing" "time" + "github.com/tendermint/tendermint/libs/utils/tcp" "github.com/fortytw2/leaktest" "github.com/stretchr/testify/require" @@ -17,67 +18,29 @@ import ( // Transports are mainly tested by common tests in transport_test.go, we // register a transport factory here to get included in those tests. 
func init() { - testTransports["mconn"] = func(t *testing.T) p2p.Transport { + testTransports["mconn"] = func(t *testing.T, addr netip.AddrPort) p2p.Transport { transport := p2p.NewMConnTransport( log.NewNopLogger(), conn.DefaultMConnConfig(), []*p2p.ChannelDescriptor{{ID: chID, Priority: 1}}, p2p.MConnTransportOptions{}, ) - err := transport.Listen(&p2p.Endpoint{ - Protocol: p2p.MConnProtocol, - IP: net.IPv4(127, 0, 0, 1), - Port: 0, // assign a random port - }) - require.NoError(t, err) - - t.Cleanup(func() { _ = transport.Close() }) - + go func() { + if err:=transport.Run(t.Context(),&p2p.Endpoint{ + Protocol: p2p.MConnProtocol, + Addr: addr, + }); err != nil { + panic(err) + } + }() return transport } } -func TestMConnTransport_AcceptBeforeListen(t *testing.T) { - transport := p2p.NewMConnTransport( - log.NewNopLogger(), - conn.DefaultMConnConfig(), - []*p2p.ChannelDescriptor{{ID: chID, Priority: 1}}, - p2p.MConnTransportOptions{ - MaxAcceptedConnections: 2, - }, - ) - t.Cleanup(func() { - _ = transport.Close() - }) - ctx := t.Context() - - _, err := transport.Accept(ctx) - require.Error(t, err) - require.NotEqual(t, io.EOF, err) // io.EOF should be returned after Close() -} - func TestMConnTransport_AcceptMaxAcceptedConnections(t *testing.T) { ctx := t.Context() - transport := p2p.NewMConnTransport( - log.NewNopLogger(), - conn.DefaultMConnConfig(), - []*p2p.ChannelDescriptor{{ID: chID, Priority: 1}}, - p2p.MConnTransportOptions{ - MaxAcceptedConnections: 2, - }, - ) - t.Cleanup(func() { - _ = transport.Close() - }) - err := transport.Listen(&p2p.Endpoint{ - Protocol: p2p.MConnProtocol, - IP: net.IPv4(127, 0, 0, 1), - }) - require.NoError(t, err) - endpoint, err := transport.Endpoint() - require.NoError(t, err) - require.NotNil(t, endpoint) + transport := testTransports["mconn"](t) // Start a goroutine to just accept any connections. acceptCh := make(chan p2p.Connection, 10) diff --git a/internal/p2p/transport_memory.go b/internal/p2p/transport_memory.go index 3eb4c5b51..53a6cbe3e 100644 --- a/internal/p2p/transport_memory.go +++ b/internal/p2p/transport_memory.go @@ -5,7 +5,7 @@ import ( "errors" "fmt" "io" - "net" + "net/netip" "sync" "github.com/tendermint/tendermint/crypto" @@ -61,22 +61,6 @@ func (n *MemoryNetwork) GetTransport(id types.NodeID) *MemoryTransport { return n.transports[id] } -// RemoveTransport removes a transport from the network and closes it. -func (n *MemoryNetwork) RemoveTransport(id types.NodeID) { - n.mtx.Lock() - t, ok := n.transports[id] - delete(n.transports, id) - n.mtx.Unlock() - - if ok { - // Close may recursively call RemoveTransport() again, but this is safe - // because we've already removed the transport from the map above. - if err := t.Close(); err != nil { - n.logger.Error("failed to close memory transport", "id", id, "err", err) - } - } -} - // Size returns the number of transports in the network. 
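The mconn test factory earlier in this patch launches transport.Run in a goroutine bound to t.Context(). In isolation, the pattern — and why the goroutine does not leak — looks like the sketch below; runUntilCanceled is a made-up stand-in for any Run method that blocks until its context is canceled.

package example

import (
	"context"
	"errors"
	"testing"
)

// runUntilCanceled stands in for transport.Run: it blocks until ctx is done.
func runUntilCanceled(ctx context.Context) error {
	<-ctx.Done()
	return ctx.Err()
}

func TestBackgroundRun(t *testing.T) {
	done := make(chan error, 1)
	go func() { done <- runUntilCanceled(t.Context()) }()

	// t.Context() is canceled just before Cleanup functions run, so waiting
	// for the goroutine inside a Cleanup both releases it and surfaces its error.
	t.Cleanup(func() {
		if err := <-done; !errors.Is(err, context.Canceled) {
			t.Errorf("unexpected shutdown error: %v", err)
		}
	})
}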
func (n *MemoryNetwork) Size() int { return len(n.transports) @@ -119,7 +103,14 @@ func (t *MemoryTransport) String() string { return string(MemoryProtocol) } -func (*MemoryTransport) Listen(*Endpoint) error { return nil } +func (t *MemoryTransport) Run(ctx context.Context, e *Endpoint) error { + <-ctx.Done() + t.network.mtx.Lock() + delete(t.network.transports, t.nodeID) + t.network.mtx.Unlock() + t.closeFn() + return nil +} func (t *MemoryTransport) AddChannelDescriptors([]*ChannelDescriptor) {} @@ -139,8 +130,7 @@ func (t *MemoryTransport) Endpoint() (*Endpoint, error) { Path: string(t.nodeID), // An arbitrary IP and port is used in order for the pex // reactor to be able to send addresses to one another. - IP: net.IPv4zero, - Port: 0, + Addr: netip.AddrPort{}, }, nil } @@ -202,13 +192,6 @@ func (t *MemoryTransport) Dial(ctx context.Context, endpoint *Endpoint) (Connect } } -// Close implements Transport. -func (t *MemoryTransport) Close() error { - t.network.RemoveTransport(t.nodeID) - t.closeFn() - return nil -} - // MemoryConnection is an in-memory connection between two transport endpoints. type MemoryConnection struct { logger log.Logger diff --git a/internal/p2p/transport_test.go b/internal/p2p/transport_test.go index ccb783f1d..32b8eda73 100644 --- a/internal/p2p/transport_test.go +++ b/internal/p2p/transport_test.go @@ -4,6 +4,7 @@ import ( "context" "io" "net" + "net/netip" "testing" "time" @@ -15,10 +16,13 @@ import ( "github.com/tendermint/tendermint/internal/p2p" "github.com/tendermint/tendermint/libs/bytes" "github.com/tendermint/tendermint/types" + "github.com/tendermint/tendermint/libs/utils/tcp" ) // transportFactory is used to set up transports for tests. -type transportFactory func(t *testing.T) p2p.Transport +type transportFactory interface { + SpawnTransport(ctx context.Context, addr netip.AddrPort) p2p.Transport +} // testTransports is a registry of transport factories for withTransports(). var testTransports = map[string]transportFactory{} @@ -28,7 +32,6 @@ var testTransports = map[string]transportFactory{} func withTransports(t *testing.T, tester func(*testing.T, transportFactory)) { t.Helper() for name, transportFactory := range testTransports { - transportFactory := transportFactory t.Run(name, func(t *testing.T) { t.Cleanup(leaktest.Check(t)) tester(t, transportFactory) @@ -40,7 +43,7 @@ func TestTransport_AcceptClose(t *testing.T) { // Just test accept unblock on close, happy path is tested widely elsewhere. 
withTransports(t, func(t *testing.T, makeTransport transportFactory) { ctx := t.Context() - a := makeTransport(t) + a := makeTransport(t,tcp.TestReserveAddr()) opctx, opcancel := context.WithTimeout(ctx, 200*time.Millisecond) defer opcancel() diff --git a/libs/service/service.go b/libs/service/service.go index 685b267c1..c96b32046 100644 --- a/libs/service/service.go +++ b/libs/service/service.go @@ -2,6 +2,7 @@ package service import ( "context" + "fmt" "github.com/tendermint/tendermint/libs/log" "github.com/tendermint/tendermint/libs/utils" "sync" @@ -164,6 +165,21 @@ func (bs *BaseService) Spawn(name string, task func(ctx context.Context) error) }() } +func (bs *BaseService) SpawnCritical(name string, task func(ctx context.Context) error) { + inner := bs.inner.Load() + if inner == nil { + panic("service is not started yet") + } + + inner.wg.Add(1) + go func() { + defer inner.wg.Done() + if err := utils.IgnoreCancel(task(inner.ctx)); err != nil { + panic(fmt.Sprintf("critical task failed: name=%v, service=%v: %v", name, bs.name, err)) + } + }() +} + // IsRunning implements Service by returning true or false depending on the // service's state. func (bs *BaseService) IsRunning() bool { diff --git a/libs/utils/tcp/tcp.go b/libs/utils/tcp/tcp.go new file mode 100644 index 000000000..2a5dd6bfa --- /dev/null +++ b/libs/utils/tcp/tcp.go @@ -0,0 +1,79 @@ +package tcp + +import ( + "context" + "errors" + "net" + "net/netip" + "syscall" + + "golang.org/x/sys/unix" + + "github.com/tendermint/tendermint/libs/utils" +) + +var reservedAddrs = utils.NewMutex(map[netip.AddrPort]struct{}{}) + +// IPv4Loopback returns the IPv4 loopback address. +func IPv4Loopback() netip.Addr { return netip.AddrFrom4([4]byte{127, 0, 0, 1}) } + +// Listen opens a TCP listener on the given address. +// It takes into account the reserved addresses (in tests) and sets the SO_REUSEPORT. +// nolint: contextcheck +func Listen(addr netip.AddrPort) (net.Listener, error) { + if addr.Port() == 0 { + return nil, errors.New("listening on anyport (i.e. 0) is not allowed. If you are implementing a test use TestReserveAddr() instead") // nolint:lll + } + cfg := net.ListenConfig{} + for addrs := range reservedAddrs.Lock() { + if _, ok := addrs[addr]; ok { + cfg.Control = func(network, address string, c syscall.RawConn) error { + var errInner error + if err := c.Control(func(fd uintptr) { + errInner = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1) + }); err != nil { + return err + } + return errInner + } + } + } + // Passing the background context is ok, because Listen is + // non-blocking if it doesn't need to resolve the address + // against a DNS server. + return cfg.Listen(context.Background(), "tcp", addr.String()) +} + +// TestReserveAddr (testonly) reserves a port in ephemeral range to open a TCP listener on it. +// Reservation prevents race conditions with other processes. +func TestReserveAddr() netip.AddrPort { + // Bind a new socket to reserve a port, + // Don't mark it as listening to avoid the kernel from queueing up connections + // on that socket. 
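// The other half of the contract (illustrative sketch, not part of the patch):
// a test reserves an address here and later binds it for real via Listen, which
// spots the reservation and sets SO_REUSEPORT so the second bind succeeds while
// the reserving socket is still open.
//
//	addr := tcp.TestReserveAddr() // reserve a loopback port without listening
//	ln, err := tcp.Listen(addr)   // rebinds the reserved port thanks to SO_REUSEPORT
//	if err != nil {
//		panic(err)
//	}
//	defer ln.Close()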
+ fd, err := unix.Socket(unix.AF_INET, unix.SOCK_STREAM, 0) + if err != nil { + panic(err) + } + if err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_REUSEADDR, 1); err != nil { + panic(err) + } + if err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil { + panic(err) + } + ip := IPv4Loopback() + addrAny := &unix.SockaddrInet4{Port: 0, Addr: ip.As4()} + if err := unix.Bind(fd, addrAny); err != nil { + panic(err) + } + + addrRaw, err := unix.Getsockname(fd) + if err != nil { + panic(err) + } + port := uint16(addrRaw.(*unix.SockaddrInet4).Port) + addr := netip.AddrPortFrom(ip, port) + for addrs := range reservedAddrs.Lock() { + addrs[addr] = struct{}{} + } + return addr +} diff --git a/node/node.go b/node/node.go index 6ed7ef5c1..7a3c2d43a 100644 --- a/node/node.go +++ b/node/node.go @@ -5,8 +5,8 @@ import ( "errors" "fmt" "net" + "net/netip" "net/http" - "strconv" "strings" "time" @@ -826,9 +826,9 @@ func getRouterConfig(conf *config.Config, appClient abciclient.Client) p2p.Route return nil } - opts.FilterPeerByIP = func(ctx context.Context, ip net.IP, port uint16) error { + opts.FilterPeerByIP = func(ctx context.Context, addrPort netip.AddrPort) error { res, err := appClient.Query(ctx, &abci.RequestQuery{ - Path: fmt.Sprintf("/p2p/filter/addr/%s", net.JoinHostPort(ip.String(), strconv.Itoa(int(port)))), + Path: fmt.Sprintf("/p2p/filter/addr/%v", addrPort), }) if err != nil { return err diff --git a/types/node_info.go b/types/node_info.go index fd47816e2..3c7758b67 100644 --- a/types/node_info.go +++ b/types/node_info.go @@ -3,8 +3,7 @@ package types import ( "errors" "fmt" - "net" - "strconv" + "net/netip" "strings" "github.com/tendermint/tendermint/libs/bytes" @@ -77,7 +76,7 @@ func (info NodeInfo) ID() NodeID { // url-encoding), and we just need to be careful with how we handle that in our // clients. (e.g. off by default). func (info NodeInfo) Validate() error { - if _, _, err := ParseAddressString(info.ID().AddressString(info.ListenAddr)); err != nil { + if _, err := ParseAddressString(info.ID().AddressString(info.ListenAddr)); err != nil { return err } @@ -236,48 +235,22 @@ func NodeInfoFromProto(pb *tmp2p.NodeInfo) (NodeInfo, error) { // ParseAddressString reads an address string, and returns the IP // address and port information, returning an error for any validation // errors. 
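// Illustrative call under the new signature (the address literal is a
// placeholder, not taken from the patch); note that host names are no longer
// resolved here, only literal IP:port strings parse after the change:
//
//	addrPort, err := types.ParseAddressString("<40-hex-node-id>@127.0.0.1:26656")
//	// addrPort is a netip.AddrPort; the node ID prefix is validated and discarded.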
-func ParseAddressString(addr string) (net.IP, uint16, error) { +func ParseAddressString(addr string) (netip.AddrPort, error) { addrWithoutProtocol := removeProtocolIfDefined(addr) spl := strings.Split(addrWithoutProtocol, "@") if len(spl) != 2 { - return nil, 0, errors.New("invalid address") + return netip.AddrPort{}, errors.New("invalid address") } id, err := NewNodeID(spl[0]) if err != nil { - return nil, 0, err + return netip.AddrPort{}, err } if err := id.Validate(); err != nil { - return nil, 0, err + return netip.AddrPort{}, err } - - addrWithoutProtocol = spl[1] - - // get host and port - host, portStr, err := net.SplitHostPort(addrWithoutProtocol) - if err != nil { - return nil, 0, err - } - if len(host) == 0 { - return nil, 0, err - } - - ip := net.ParseIP(host) - if ip == nil { - ips, err := net.LookupIP(host) - if err != nil { - return nil, 0, err - } - ip = ips[0] - } - - port, err := strconv.ParseUint(portStr, 10, 16) - if err != nil { - return nil, 0, err - } - - return ip, uint16(port), nil + return netip.ParseAddrPort(spl[1]) } func removeProtocolIfDefined(addr string) string { From 36c939fa867ed96fb37c339b3579a0f533ed9903 Mon Sep 17 00:00:00 2001 From: Grzegorz Prusak Date: Fri, 29 Aug 2025 14:58:20 +0200 Subject: [PATCH 29/41] WIP --- internal/p2p/address.go | 8 +- internal/p2p/address_test.go | 60 ++---- internal/p2p/conn_tracker_test.go | 23 ++- internal/p2p/mocks/transport.go | 36 +++- internal/p2p/p2ptest/network.go | 7 +- internal/p2p/router.go | 5 +- internal/p2p/router_filter_test.go | 4 +- internal/p2p/router_test.go | 11 - internal/p2p/transport.go | 14 +- internal/p2p/transport_mconn.go | 39 +++- internal/p2p/transport_mconn_test.go | 282 ++++++++++++++------------ internal/p2p/transport_memory.go | 14 +- internal/p2p/transport_memory_test.go | 33 ++- internal/p2p/transport_test.go | 232 +++++++-------------- node/setup.go | 13 +- 15 files changed, 358 insertions(+), 423 deletions(-) diff --git a/internal/p2p/address.go b/internal/p2p/address.go index 9a034a4ad..31d7561f7 100644 --- a/internal/p2p/address.go +++ b/internal/p2p/address.go @@ -98,7 +98,7 @@ func ParseNodeAddress(urlString string) (NodeAddress, error) { // Resolve resolves a NodeAddress into a set of Endpoints, by expanding // out a DNS hostname to IP addresses. 
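// Rough usage after the switch from []*Endpoint to []Endpoint (sketch; the
// address string and ctx are placeholders):
//
//	addr, _ := p2p.ParseNodeAddress("tcp://<node-id>@localhost:26657")
//	endpoints, err := addr.Resolve(ctx) // one value Endpoint per resolved IP
//	for _, ep := range endpoints {
//		_ = ep // ep is a p2p.Endpoint value, no pointer indirection
//	}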
-func (a NodeAddress) Resolve(ctx context.Context) ([]*Endpoint, error) { +func (a NodeAddress) Resolve(ctx context.Context) ([]Endpoint, error) { if a.Protocol == "" { return nil, errors.New("address has no protocol") } @@ -110,7 +110,7 @@ func (a NodeAddress) Resolve(ctx context.Context) ([]*Endpoint, error) { if a.NodeID == "" { return nil, errors.New("local address has no node ID") } - return []*Endpoint{{ + return []Endpoint{{ Protocol: a.Protocol, Path: string(a.NodeID), }}, nil @@ -120,11 +120,11 @@ func (a NodeAddress) Resolve(ctx context.Context) ([]*Endpoint, error) { if err != nil { return nil, err } - endpoints := make([]*Endpoint, len(ips)) + endpoints := make([]Endpoint, len(ips)) for i, ip := range ips { ip,ok := netip.AddrFromSlice(ip) if !ok { return nil, fmt.Errorf("LookupIP returned invalid IP %q", ip) } - endpoints[i] = &Endpoint{ + endpoints[i] = Endpoint{ Protocol: a.Protocol, Addr: netip.AddrPortFrom(ip, a.Port), Path: a.Path, diff --git a/internal/p2p/address_test.go b/internal/p2p/address_test.go index 7c6fdb9bc..e0e982d45 100644 --- a/internal/p2p/address_test.go +++ b/internal/p2p/address_test.go @@ -1,7 +1,7 @@ package p2p_test import ( - "net" + "net/netip" "strings" "testing" @@ -9,6 +9,7 @@ import ( "github.com/tendermint/tendermint/crypto/ed25519" "github.com/tendermint/tendermint/internal/p2p" + "github.com/tendermint/tendermint/libs/utils/tcp" "github.com/tendermint/tendermint/types" ) @@ -202,61 +203,61 @@ func TestNodeAddress_Resolve(t *testing.T) { testcases := []struct { address p2p.NodeAddress - expect *p2p.Endpoint + expect p2p.Endpoint ok bool }{ // Valid networked addresses (with hostname). { p2p.NodeAddress{Protocol: "tcp", Hostname: "127.0.0.1", Port: 80, Path: "/path"}, - &p2p.Endpoint{Protocol: "tcp", IP: net.IPv4(127, 0, 0, 1), Port: 80, Path: "/path"}, + p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(tcp.IPv4Loopback(),80), Path: "/path"}, true, }, { p2p.NodeAddress{Protocol: "tcp", Hostname: "127.0.0.1"}, - &p2p.Endpoint{Protocol: "tcp", IP: net.IPv4(127, 0, 0, 1)}, + p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(tcp.IPv4Loopback(),0)}, true, }, { p2p.NodeAddress{Protocol: "tcp", Hostname: "::1"}, - &p2p.Endpoint{Protocol: "tcp", IP: net.IPv6loopback}, + p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(netip.IPv6Loopback(),0)}, true, }, { p2p.NodeAddress{Protocol: "tcp", Hostname: "8.8.8.8"}, - &p2p.Endpoint{Protocol: "tcp", IP: net.IPv4(8, 8, 8, 8)}, + p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}),0)}, true, }, { p2p.NodeAddress{Protocol: "tcp", Hostname: "2001:0db8::ff00:0042:8329"}, - &p2p.Endpoint{Protocol: "tcp", IP: []byte{ - 0x20, 0x01, 0x0d, 0xb8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0x42, 0x83, 0x29}}, + p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(netip.AddrFrom16([16]byte{ + 0x20, 0x01, 0x0d, 0xb8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0x42, 0x83, 0x29}),0)}, true, }, { p2p.NodeAddress{Protocol: "tcp", Hostname: "some.missing.host.tendermint.com"}, - &p2p.Endpoint{}, + p2p.Endpoint{}, false, }, // Valid non-networked addresses. { p2p.NodeAddress{Protocol: "memory", NodeID: id}, - &p2p.Endpoint{Protocol: "memory", Path: string(id)}, + p2p.Endpoint{Protocol: "memory", Path: string(id)}, true, }, { p2p.NodeAddress{Protocol: "memory", NodeID: id, Path: string(id)}, - &p2p.Endpoint{Protocol: "memory", Path: string(id)}, + p2p.Endpoint{Protocol: "memory", Path: string(id)}, true, }, // Invalid addresses. 
- {p2p.NodeAddress{}, &p2p.Endpoint{}, false}, - {p2p.NodeAddress{Hostname: "127.0.0.1"}, &p2p.Endpoint{}, false}, - {p2p.NodeAddress{Protocol: "tcp", Hostname: "127.0.0.1:80"}, &p2p.Endpoint{}, false}, - {p2p.NodeAddress{Protocol: "memory"}, &p2p.Endpoint{}, false}, - {p2p.NodeAddress{Protocol: "memory", Path: string(id)}, &p2p.Endpoint{}, false}, - {p2p.NodeAddress{Protocol: "tcp", Hostname: "💥"}, &p2p.Endpoint{}, false}, + {p2p.NodeAddress{}, p2p.Endpoint{}, false}, + {p2p.NodeAddress{Hostname: "127.0.0.1"}, p2p.Endpoint{}, false}, + {p2p.NodeAddress{Protocol: "tcp", Hostname: "127.0.0.1:80"}, p2p.Endpoint{}, false}, + {p2p.NodeAddress{Protocol: "memory"}, p2p.Endpoint{}, false}, + {p2p.NodeAddress{Protocol: "memory", Path: string(id)}, p2p.Endpoint{}, false}, + {p2p.NodeAddress{Protocol: "tcp", Hostname: "💥"}, p2p.Endpoint{}, false}, } for _, tc := range testcases { t.Run(tc.address.String(), func(t *testing.T) { @@ -265,25 +266,6 @@ func TestNodeAddress_Resolve(t *testing.T) { require.Error(t, err) return } - - // Special handling for localhost tests - accept either IPv4 or IPv6 - if tc.address.Hostname == "localhost" && tc.address.Port == 80 && tc.address.Path == "/path" { - hasIPv4 := false - hasIPv6 := false - for _, ep := range endpoints { - if ep.Protocol == "tcp" && ep.Port == 80 && ep.Path == "/path" { - if ep.IP.Equal(net.IPv4(127, 0, 0, 1)) { - hasIPv4 = true - } - if ep.IP.Equal(net.IPv6loopback) { - hasIPv6 = true - } - } - } - require.True(t, hasIPv4 || hasIPv6, "localhost should resolve to either IPv4 or IPv6") - return - } - require.Contains(t, endpoints, tc.expect) }) } @@ -291,13 +273,11 @@ func TestNodeAddress_Resolve(t *testing.T) { addr := p2p.NodeAddress{Protocol: "tcp", Hostname: "localhost", Port: 80, Path: "/path"} endpoints, err := addr.Resolve(t.Context()) require.NoError(t, err) - - want := &p2p.Endpoint{Protocol: "tcp", Port: 80, Path: "/path"} require.True(t, len(endpoints) > 0) for _, got := range endpoints { - require.True(t, got.IP.IsLoopback()) + require.True(t, got.Addr.Addr().IsLoopback()) // Any loopback address is acceptable, so ignore it in comparison. 
- want.IP = got.IP + want := &p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(got.Addr.Addr(),80), Path: "/path"} require.Equal(t, want, got) } }) diff --git a/internal/p2p/conn_tracker_test.go b/internal/p2p/conn_tracker_test.go index daa3351f2..93216bdcd 100644 --- a/internal/p2p/conn_tracker_test.go +++ b/internal/p2p/conn_tracker_test.go @@ -3,7 +3,7 @@ package p2p import ( "math" "math/rand" - "net" + "net/netip" "testing" "time" @@ -14,8 +14,15 @@ func randByte() byte { return byte(rand.Intn(math.MaxUint8)) } -func randLocalIPv4() net.IP { - return net.IPv4(127, randByte(), randByte(), randByte()) +func randPort() uint16 { + return uint16(rand.Intn(math.MaxUint16)) +} + +func randLocalAddr() netip.AddrPort { + return netip.AddrPortFrom( + netip.AddrFrom4([4]byte{127, randByte(), randByte(), randByte()}), + randPort(), + ) } func TestConnTracker(t *testing.T) { @@ -35,7 +42,7 @@ func TestConnTracker(t *testing.T) { }) t.Run("RepeatedAdding", func(t *testing.T) { ct := factory() - ip := randLocalIPv4() + ip := randLocalAddr() require.NoError(t, ct.AddConn(ip)) for i := 0; i < 100; i++ { _ = ct.AddConn(ip) @@ -45,14 +52,14 @@ func TestConnTracker(t *testing.T) { t.Run("AddingMany", func(t *testing.T) { ct := factory() for i := 0; i < 100; i++ { - _ = ct.AddConn(randLocalIPv4()) + _ = ct.AddConn(randLocalAddr()) } require.Equal(t, 100, ct.Len()) }) t.Run("Cycle", func(t *testing.T) { ct := factory() for i := 0; i < 100; i++ { - ip := randLocalIPv4() + ip := randLocalAddr() require.NoError(t, ct.AddConn(ip)) ct.RemoveConn(ip) } @@ -63,7 +70,7 @@ func TestConnTracker(t *testing.T) { t.Run("VeryShort", func(t *testing.T) { ct := newConnTracker(10, time.Microsecond) for i := 0; i < 10; i++ { - ip := randLocalIPv4() + ip := randLocalAddr() require.NoError(t, ct.AddConn(ip)) time.Sleep(2 * time.Microsecond) require.NoError(t, ct.AddConn(ip)) @@ -73,7 +80,7 @@ func TestConnTracker(t *testing.T) { t.Run("Window", func(t *testing.T) { const window = 100 * time.Millisecond ct := newConnTracker(10, window) - ip := randLocalIPv4() + ip := randLocalAddr() require.NoError(t, ct.AddConn(ip)) ct.RemoveConn(ip) require.Error(t, ct.AddConn(ip)) diff --git a/internal/p2p/mocks/transport.go b/internal/p2p/mocks/transport.go index 7a789f8dd..e2ee3c913 100644 --- a/internal/p2p/mocks/transport.go +++ b/internal/p2p/mocks/transport.go @@ -53,7 +53,7 @@ func (_m *Transport) AddChannelDescriptors(_a0 []*conn.ChannelDescriptor) { } // Dial provides a mock function with given fields: _a0, _a1 -func (_m *Transport) Dial(_a0 context.Context, _a1 *p2p.Endpoint) (p2p.Connection, error) { +func (_m *Transport) Dial(_a0 context.Context, _a1 p2p.Endpoint) (p2p.Connection, error) { ret := _m.Called(_a0, _a1) if len(ret) == 0 { @@ -62,10 +62,10 @@ func (_m *Transport) Dial(_a0 context.Context, _a1 *p2p.Endpoint) (p2p.Connectio var r0 p2p.Connection var r1 error - if rf, ok := ret.Get(0).(func(context.Context, *p2p.Endpoint) (p2p.Connection, error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, p2p.Endpoint) (p2p.Connection, error)); ok { return rf(_a0, _a1) } - if rf, ok := ret.Get(0).(func(context.Context, *p2p.Endpoint) p2p.Connection); ok { + if rf, ok := ret.Get(0).(func(context.Context, p2p.Endpoint) p2p.Connection); ok { r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { @@ -73,7 +73,7 @@ func (_m *Transport) Dial(_a0 context.Context, _a1 *p2p.Endpoint) (p2p.Connectio } } - if rf, ok := ret.Get(1).(func(context.Context, *p2p.Endpoint) error); ok { + if rf, ok := ret.Get(1).(func(context.Context, 
p2p.Endpoint) error); ok { r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) @@ -82,6 +82,24 @@ func (_m *Transport) Dial(_a0 context.Context, _a1 *p2p.Endpoint) (p2p.Connectio return r0, r1 } +// Endpoint provides a mock function with no fields +func (_m *Transport) Endpoint() p2p.Endpoint { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Endpoint") + } + + var r0 p2p.Endpoint + if rf, ok := ret.Get(0).(func() p2p.Endpoint); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(p2p.Endpoint) + } + + return r0 +} + // Protocols provides a mock function with no fields func (_m *Transport) Protocols() []p2p.Protocol { ret := _m.Called() @@ -102,17 +120,17 @@ func (_m *Transport) Protocols() []p2p.Protocol { return r0 } -// Run provides a mock function with given fields: ctx, endpoint -func (_m *Transport) Run(ctx context.Context, endpoint *p2p.Endpoint) error { - ret := _m.Called(ctx, endpoint) +// Run provides a mock function with given fields: ctx +func (_m *Transport) Run(ctx context.Context) error { + ret := _m.Called(ctx) if len(ret) == 0 { panic("no return value specified for Run") } var r0 error - if rf, ok := ret.Get(0).(func(context.Context, *p2p.Endpoint) error); ok { - r0 = rf(ctx, endpoint) + if rf, ok := ret.Get(0).(func(context.Context) error); ok { + r0 = rf(ctx) } else { r0 = ret.Error(0) } diff --git a/internal/p2p/p2ptest/network.go b/internal/p2p/p2ptest/network.go index 940394e01..a0f6ed51c 100644 --- a/internal/p2p/p2ptest/network.go +++ b/internal/p2p/p2ptest/network.go @@ -245,10 +245,6 @@ func (n *Network) MakeNode(ctx context.Context, t *testing.T, opts NodeOptions) } transport := n.memoryNetwork.CreateTransport(nodeID) - ep, err := transport.Endpoint() - require.NoError(t, err) - require.NotNil(t, ep, "transport not listening an endpoint") - maxRetryTime := 1000 * time.Millisecond if opts.MaxRetryTime > 0 { maxRetryTime = opts.MaxRetryTime @@ -271,7 +267,6 @@ func (n *Network) MakeNode(ctx context.Context, t *testing.T, opts NodeOptions) peerManager, func() *types.NodeInfo { return &nodeInfo }, transport, - ep, nil, p2p.RouterOptions{DialSleep: func(_ context.Context) error { return nil }}, ) @@ -290,7 +285,7 @@ func (n *Network) MakeNode(ctx context.Context, t *testing.T, opts NodeOptions) return &Node{ NodeID: nodeID, NodeInfo: nodeInfo, - NodeAddress: ep.NodeAddress(nodeID), + NodeAddress: transport.Endpoint().NodeAddress(nodeID), PrivKey: privKey, Router: router, PeerManager: peerManager, diff --git a/internal/p2p/router.go b/internal/p2p/router.go index 6f67e06f4..73e82dd81 100644 --- a/internal/p2p/router.go +++ b/internal/p2p/router.go @@ -147,7 +147,6 @@ type Router struct { peerManager *PeerManager chDescs []*ChannelDescriptor transport Transport - endpoint *Endpoint connTracker connectionTracker peerStates utils.RWMutex[map[types.NodeID]*peerState] @@ -180,7 +179,6 @@ func NewRouter( peerManager *PeerManager, nodeInfoProducer func() *types.NodeInfo, transport Transport, - endpoint *Endpoint, dynamicIDFilterer func(context.Context, types.NodeID) error, options RouterOptions, ) (*Router, error) { @@ -201,7 +199,6 @@ func NewRouter( ), chDescs: make([]*ChannelDescriptor, 0), transport: transport, - endpoint: endpoint, peerManager: peerManager, options: options, channelQueues: map[ChannelID]*Queue{}, @@ -827,7 +824,7 @@ func (r *Router) OnStart(ctx context.Context) error { } r.SpawnCritical("transport.Run",func(ctx context.Context) error { - return r.transport.Run(ctx, r.endpoint) + return r.transport.Run(ctx) }) 
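// Note (annotation, not part of the patch): SpawnCritical, added to BaseService
// earlier in this series, wraps the task with utils.IgnoreCancel and panics on
// any remaining error, so a failing transport.Run (or dialPeers/evictPeers
// below) takes the node down rather than failing silently.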
r.SpawnCritical("dialPeers", func(ctx context.Context) error { return r.dialPeers(ctx) }) r.SpawnCritical("evictPeers", func(ctx context.Context) error { return r.evictPeers(ctx) }) diff --git a/internal/p2p/router_filter_test.go b/internal/p2p/router_filter_test.go index 5b1d7219a..afd9879bd 100644 --- a/internal/p2p/router_filter_test.go +++ b/internal/p2p/router_filter_test.go @@ -3,7 +3,7 @@ package p2p import ( "context" "errors" - "net" + "net/netip" "testing" "time" @@ -21,7 +21,7 @@ func TestConnectionFiltering(t *testing.T) { logger: logger, connTracker: newConnTracker(1, time.Second), options: RouterOptions{ - FilterPeerByIP: func(ctx context.Context, ip net.IP, port uint16) error { + FilterPeerByIP: func(ctx context.Context, addr netip.AddrPort) error { filterByIPCount++ return errors.New("mock") }, diff --git a/internal/p2p/router_test.go b/internal/p2p/router_test.go index caf60f8e5..433b477c4 100644 --- a/internal/p2p/router_test.go +++ b/internal/p2p/router_test.go @@ -115,7 +115,6 @@ func TestRouter_Channel_Basic(t *testing.T) { peerManager, func() *types.NodeInfo { return &selfInfo }, testnet.RandomNode().Transport, - &p2p.Endpoint{}, nil, p2p.RouterOptions{}, ) @@ -400,7 +399,6 @@ func TestRouter_AcceptPeers(t *testing.T) { func() *types.NodeInfo { return &selfInfo }, mockTransport, nil, - nil, p2p.RouterOptions{}, ) require.NoError(t, err) @@ -457,7 +455,6 @@ func TestRouter_AcceptPeers_Errors(t *testing.T) { func() *types.NodeInfo { return &selfInfo }, mockTransport, nil, - nil, p2p.RouterOptions{}, ) require.NoError(t, err) @@ -512,7 +509,6 @@ func TestRouter_AcceptPeers_HeadOfLineBlocking(t *testing.T) { func() *types.NodeInfo { return &selfInfo }, mockTransport, nil, - nil, p2p.RouterOptions{}, ) require.NoError(t, err) @@ -611,7 +607,6 @@ func TestRouter_DialPeers(t *testing.T) { func() *types.NodeInfo { return &selfInfo }, mockTransport, nil, - nil, p2p.RouterOptions{}, ) require.NoError(t, err) @@ -696,7 +691,6 @@ func TestRouter_DialPeers_Parallel(t *testing.T) { func() *types.NodeInfo { return &selfInfo }, mockTransport, nil, - nil, p2p.RouterOptions{ DialSleep: func(_ context.Context) error { return nil }, NumConcurrentDials: func() int { @@ -771,7 +765,6 @@ func TestRouter_EvictPeers(t *testing.T) { func() *types.NodeInfo { return &selfInfo }, mockTransport, nil, - nil, p2p.RouterOptions{}, ) require.NoError(t, err) @@ -833,7 +826,6 @@ func TestRouter_ChannelCompatability(t *testing.T) { func() *types.NodeInfo { return &selfInfo }, mockTransport, nil, - nil, p2p.RouterOptions{}, ) require.NoError(t, err) @@ -888,7 +880,6 @@ func TestRouter_DontSendOnInvalidChannel(t *testing.T) { func() *types.NodeInfo { return &selfInfo }, mockTransport, nil, - nil, p2p.RouterOptions{}, ) require.NoError(t, err) @@ -952,7 +943,6 @@ func TestRouter_Channel_FilterByID(t *testing.T) { peerManager, func() *types.NodeInfo { return &selfInfo }, mockTransport, - &p2p.Endpoint{}, nil, p2p.RouterOptions{}, ) @@ -976,7 +966,6 @@ func TestRouter_Channel_FilterByID(t *testing.T) { peerManager, func() *types.NodeInfo { return &selfInfo }, mockTransport, - &p2p.Endpoint{}, func(_ context.Context, _ types.NodeID) error { return errors.New("should filter") }, p2p.RouterOptions{}, ) diff --git a/internal/p2p/transport.go b/internal/p2p/transport.go index ed8d4cb90..1b202f8c9 100644 --- a/internal/p2p/transport.go +++ b/internal/p2p/transport.go @@ -24,18 +24,22 @@ type Protocol string // Transport is a connection-oriented mechanism for exchanging data with a peer. 
type Transport interface { // Run executes the background tasks of transport. - Run(ctx context.Context, endpoint *Endpoint) error + Run(ctx context.Context) error // Protocols returns the protocols supported by the transport. The Router // uses this to pick a transport for an Endpoint. Protocols() []Protocol + // Endpoints returns the local endpoints the transport is listening on. + Endpoint() Endpoint + + // Accept waits for the next inbound connection on a listening endpoint, blocking // until either a connection is available or the transport is closed. On closure, // io.EOF is returned and further Accept calls are futile. Accept(context.Context) (Connection, error) // Dial creates an outbound connection to an endpoint. - Dial(context.Context, *Endpoint) (Connection, error) + Dial(context.Context, Endpoint) (Connection, error) // AddChannelDescriptors is only part of this interface // temporarily @@ -113,13 +117,13 @@ type Endpoint struct { } // NewEndpoint constructs an Endpoint from a types.NetAddress structure. -func NewEndpoint(addr string) (*Endpoint, error) { +func NewEndpoint(addr string) (Endpoint, error) { addrPort, err := types.ParseAddressString(addr) if err != nil { - return nil, err + return Endpoint{}, err } - return &Endpoint{ + return Endpoint{ Protocol: MConnProtocol, Addr: addrPort, }, nil diff --git a/internal/p2p/transport_mconn.go b/internal/p2p/transport_mconn.go index 1aec41486..f164630ba 100644 --- a/internal/p2p/transport_mconn.go +++ b/internal/p2p/transport_mconn.go @@ -17,6 +17,7 @@ import ( "github.com/tendermint/tendermint/internal/p2p/conn" "github.com/tendermint/tendermint/libs/log" "github.com/tendermint/tendermint/libs/utils" + "github.com/tendermint/tendermint/libs/utils/tcp" "github.com/tendermint/tendermint/libs/utils/scope" p2pproto "github.com/tendermint/tendermint/proto/tendermint/p2p" "github.com/tendermint/tendermint/types" @@ -43,9 +44,11 @@ type MConnTransportOptions struct { // Tendermint protocol ("MConn"). type MConnTransport struct { logger log.Logger + endpoint Endpoint options MConnTransportOptions mConnConfig conn.MConnConfig channelDescs []*ChannelDescriptor + started chan struct{} listener chan *mConnConnection } @@ -54,31 +57,43 @@ type MConnTransport struct { // conn.MConnection. func NewMConnTransport( logger log.Logger, - endpoint *Endpoint, + endpoint Endpoint, mConnConfig conn.MConnConfig, channelDescs []*ChannelDescriptor, options MConnTransportOptions, ) *MConnTransport { return &MConnTransport{ logger: logger, - endpoint: endpoint, + endpoint: endpoint, options: options, mConnConfig: mConnConfig, channelDescs: channelDescs, // This is rendezvous channel, so that no unclosed connections get stuck inside // when transport is closing. + started: make(chan struct{}), listener: make(chan *mConnConnection), } } +// WaitForStart waits until transport starts listening for incoming connections. 
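// Typical caller pattern (sketch, mirroring the tests later in this patch):
// spawn Run on a goroutine, then gate dials on WaitForStart so nothing tries to
// connect before the listener socket is bound.
//
//	go func() { _ = m.Run(ctx) }()
//	if err := m.WaitForStart(ctx); err != nil {
//		return err
//	}
//	conn, err := m.Dial(ctx, m.Endpoint())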
+func (m *MConnTransport) WaitForStart(ctx context.Context) error { + _,_,err := utils.RecvOrClosed(ctx, m.started) + return err +} + +func (m *MConnTransport) Endpoint() Endpoint { + return m.endpoint +} + func (m *MConnTransport) Run(ctx context.Context) error { if err := m.validateEndpoint(m.endpoint); err != nil { return err } - listener, err := net.Listen("tcp", endpoint.Addr.String()) + listener, err := tcp.Listen(m.endpoint.Addr) if err != nil { return fmt.Errorf("net.Listen(): %w",err) } + close(m.started) // signal that we are listening if m.options.MaxAcceptedConnections > 0 { // FIXME: This will establish the inbound connection but simply hang it // until another connection is released. It would probably be better to @@ -96,7 +111,7 @@ func (m *MConnTransport) Run(ctx context.Context) error { for { conn, err := listener.Accept() if err!=nil { - if errors.Is(err, io.EOF) { + if errors.Is(err, net.ErrClosed) { return nil } return err @@ -126,7 +141,7 @@ func (m *MConnTransport) Accept(ctx context.Context) (Connection, error) { } // Dial implements Transport. -func (m *MConnTransport) Dial(ctx context.Context, endpoint *Endpoint) (Connection, error) { +func (m *MConnTransport) Dial(ctx context.Context, endpoint Endpoint) (Connection, error) { if err := m.validateEndpoint(endpoint); err != nil { return nil, err } @@ -152,19 +167,21 @@ func (m *MConnTransport) AddChannelDescriptors(channelDesc []*ChannelDescriptor) m.channelDescs = append(m.channelDescs, channelDesc...) } +type InvalidEndpointErr struct { error } + // validateEndpoint validates an endpoint. -func (m *MConnTransport) validateEndpoint(endpoint *Endpoint) error { +func (m *MConnTransport) validateEndpoint(endpoint Endpoint) error { if err := endpoint.Validate(); err != nil { - return err + return InvalidEndpointErr{err} } if endpoint.Protocol != MConnProtocol && endpoint.Protocol != TCPProtocol { - return fmt.Errorf("unsupported protocol %q", endpoint.Protocol) + return InvalidEndpointErr{fmt.Errorf("unsupported protocol %q", endpoint.Protocol)} } if !endpoint.Addr.IsValid() { - return errors.New("endpoint has invalid address") + return InvalidEndpointErr{errors.New("endpoint has invalid address")} } if endpoint.Path != "" { - return fmt.Errorf("endpoints with path not supported (got %q)", endpoint.Path) + return InvalidEndpointErr{fmt.Errorf("endpoints with path not supported (got %q)", endpoint.Path)} } return nil } @@ -336,7 +353,7 @@ func (c *mConnConnection) onReceive(ctx context.Context, chID ChannelID, payload // onError is a callback for MConnection errors. The error is passed via errorCh // to ReceiveMessage (but not SendMessage, for legacy P2P stack behavior). 
-func (c *mConnConnection) onError(ctx context.Context, e interface{}) { +func (c *mConnConnection) onError(ctx context.Context, e any) { err, ok := e.(error) if !ok { err = fmt.Errorf("%v", err) diff --git a/internal/p2p/transport_mconn_test.go b/internal/p2p/transport_mconn_test.go index 434e5997e..104b06ba2 100644 --- a/internal/p2p/transport_mconn_test.go +++ b/internal/p2p/transport_mconn_test.go @@ -1,12 +1,17 @@ package p2p_test import ( + "context" "net/netip" - "net" + "io" "testing" "time" + "fmt" + "errors" "github.com/tendermint/tendermint/libs/utils/tcp" + "github.com/tendermint/tendermint/libs/utils" + "github.com/tendermint/tendermint/libs/utils/scope" "github.com/fortytw2/leaktest" "github.com/stretchr/testify/require" @@ -18,169 +23,186 @@ import ( // Transports are mainly tested by common tests in transport_test.go, we // register a transport factory here to get included in those tests. func init() { - testTransports["mconn"] = func(t *testing.T, addr netip.AddrPort) p2p.Transport { - transport := p2p.NewMConnTransport( - log.NewNopLogger(), - conn.DefaultMConnConfig(), - []*p2p.ChannelDescriptor{{ID: chID, Priority: 1}}, - p2p.MConnTransportOptions{}, - ) - go func() { - if err:=transport.Run(t.Context(),&p2p.Endpoint{ - Protocol: p2p.MConnProtocol, - Addr: addr, - }); err != nil { - panic(err) - } - }() - return transport + testTransports["mconn"] = func() func(context.Context) p2p.Transport { + return func(ctx context.Context) p2p.Transport { + transport := p2p.NewMConnTransport( + log.NewNopLogger(), + p2p.Endpoint{ + Protocol: p2p.MConnProtocol, + Addr: tcp.TestReserveAddr(), + }, + conn.DefaultMConnConfig(), + []*p2p.ChannelDescriptor{{ID: chID, Priority: 1}}, + p2p.MConnTransportOptions{}, + ) + go func() { + if err:=transport.Run(ctx); err != nil { + panic(err) + } + }() + return transport + } } } +func connect(ctx context.Context, tr *p2p.MConnTransport) (c1 p2p.Connection, c2 p2p.Connection, err error) { + defer func() { + if err != nil { + if c1 != nil { c1.Close() } + if c2 != nil { c2.Close() } + } + }() + // Here we are utilizing the fact that MConnTransport accepts connection proactively + // before Accept is called. + c1, err = tr.Dial(ctx, tr.Endpoint()) + if err != nil { return nil,nil,fmt.Errorf("Dial(): %w", err) } + c2, err = tr.Accept(ctx) + if err != nil { return nil,nil,fmt.Errorf("Accept(): %w", err) } + if got,want := c1.LocalEndpoint(),c2.RemoteEndpoint(); got!=want { + return nil,nil,fmt.Errorf("c1.LocalEndpoint() = %v, want %v", got, want) + } + if got,want := c1.RemoteEndpoint(),c2.LocalEndpoint(); got!=want { + return nil,nil,fmt.Errorf("c1.RemoteEndpoint() = %v, want %v", got, want) + } + return c1,c2,nil +} + func TestMConnTransport_AcceptMaxAcceptedConnections(t *testing.T) { ctx := t.Context() - - transport := testTransports["mconn"](t) - - // Start a goroutine to just accept any connections. 
- acceptCh := make(chan p2p.Connection, 10) - go func() { - for { - conn, err := transport.Accept(ctx) - if err != nil { - return + transport := p2p.NewMConnTransport( + log.NewNopLogger(), + p2p.Endpoint{ + Protocol: p2p.MConnProtocol, + Addr: tcp.TestReserveAddr(), + }, + conn.DefaultMConnConfig(), + []*p2p.ChannelDescriptor{{ID: chID, Priority: 1}}, + p2p.MConnTransportOptions{ + MaxAcceptedConnections: 2, + }, + ) + + err := utils.IgnoreCancel(scope.Run(ctx, func(ctx context.Context, s scope.Scope) error { + s.SpawnBgNamed("transport", func() error { return transport.Run(ctx) }) + if err:=transport.WaitForStart(ctx); err!=nil { + return err + } + t.Logf("The first two connections should be accepted just fine.") + + a1,a2,err := connect(ctx, transport) + if err!=nil { return fmt.Errorf("1st connect(): %w", err) } + defer a1.Close() + defer a2.Close() + + b1,b2,err := connect(ctx, transport) + if err!=nil { return fmt.Errorf("2nd connect(): %w",err) } + defer b1.Close() + defer b2.Close() + + t.Logf("The third connection will be dialed successfully, but the accept should not go through.") + c1, err := transport.Dial(ctx, transport.Endpoint()) + if err!=nil { return fmt.Errorf("3rd Dial(): %w", err) } + defer c1.Close() + if err := utils.WithTimeout(ctx, time.Second, func(ctx context.Context) error { + c2, err := transport.Accept(ctx) + if err==nil { + c2.Close() } - acceptCh <- conn + return err + }); !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("Accept() over cap: %v, want %v", err, context.DeadlineExceeded) } - }() - // The first two connections should be accepted just fine. - dial1, err := transport.Dial(ctx, endpoint) - require.NoError(t, err) - defer dial1.Close() - accept1 := <-acceptCh - defer accept1.Close() - require.Equal(t, dial1.LocalEndpoint(), accept1.RemoteEndpoint()) - - dial2, err := transport.Dial(ctx, endpoint) - require.NoError(t, err) - defer dial2.Close() - accept2 := <-acceptCh - defer accept2.Close() - require.Equal(t, dial2.LocalEndpoint(), accept2.RemoteEndpoint()) - - // The third connection will be dialed successfully, but the accept should - // not go through. - dial3, err := transport.Dial(ctx, endpoint) - require.NoError(t, err) - defer dial3.Close() - select { - case <-acceptCh: - require.Fail(t, "unexpected accept") - case <-time.After(time.Second): + t.Logf("once either of the other connections are closed, the accept goes through.") + a1.Close() + a2.Close() // we close both a1 and a2 to make sure the connection count drops below the limit. + c2,err := transport.Accept(ctx) + if err!=nil { return fmt.Errorf("3rd Accept(): %w",err) } + defer c2.Close() + return nil + })) + if err != nil { + t.Fatal(err) } - - // However, once either of the other connections are closed, the accept - // goes through. - require.NoError(t, accept1.Close()) - accept3 := <-acceptCh - defer accept3.Close() - require.Equal(t, dial3.LocalEndpoint(), accept3.RemoteEndpoint()) } func TestMConnTransport_Listen(t *testing.T) { - ctx := t.Context() + reservePort := func(ip netip.Addr) netip.AddrPort { + addr := tcp.TestReserveAddr() + return netip.AddrPortFrom(ip, addr.Port()) + } testcases := []struct { - endpoint *p2p.Endpoint + endpoint p2p.Endpoint ok bool }{ // Valid v4 and v6 addresses, with mconn and tcp protocols. 
- {&p2p.Endpoint{Protocol: p2p.MConnProtocol, IP: net.IPv4zero}, true}, - {&p2p.Endpoint{Protocol: p2p.MConnProtocol, IP: net.IPv4(127, 0, 0, 1)}, true}, - {&p2p.Endpoint{Protocol: p2p.MConnProtocol, IP: net.IPv6zero}, true}, - {&p2p.Endpoint{Protocol: p2p.MConnProtocol, IP: net.IPv6loopback}, true}, - {&p2p.Endpoint{Protocol: p2p.TCPProtocol, IP: net.IPv4zero}, true}, + {p2p.Endpoint{Protocol: p2p.MConnProtocol, Addr: reservePort(netip.IPv4Unspecified())}, true}, + {p2p.Endpoint{Protocol: p2p.MConnProtocol, Addr: reservePort(tcp.IPv4Loopback())}, true}, + {p2p.Endpoint{Protocol: p2p.MConnProtocol, Addr: reservePort(netip.IPv6Unspecified())}, true}, + {p2p.Endpoint{Protocol: p2p.MConnProtocol, Addr: reservePort(netip.IPv6Loopback())}, true}, + {p2p.Endpoint{Protocol: p2p.TCPProtocol, Addr: reservePort(netip.IPv4Unspecified())}, true}, // Invalid endpoints. - {&p2p.Endpoint{}, false}, - {&p2p.Endpoint{Protocol: p2p.MConnProtocol, Path: "foo"}, false}, - {&p2p.Endpoint{Protocol: p2p.MConnProtocol, IP: net.IPv4zero, Path: "foo"}, false}, + {p2p.Endpoint{}, false}, + {p2p.Endpoint{Protocol: p2p.MConnProtocol, Path: "foo"}, false}, + {p2p.Endpoint{Protocol: p2p.MConnProtocol, Addr: reservePort(netip.IPv4Unspecified()), Path: "foo"}, false}, } for _, tc := range testcases { t.Run(tc.endpoint.String(), func(t *testing.T) { + ctx := t.Context() t.Cleanup(leaktest.Check(t)) transport := p2p.NewMConnTransport( log.NewNopLogger(), + tc.endpoint, conn.DefaultMConnConfig(), []*p2p.ChannelDescriptor{{ID: chID, Priority: 1}}, p2p.MConnTransportOptions{}, ) + if got,want := transport.Endpoint(),tc.endpoint; got!=want { + t.Fatalf("transport.Endpoint() = %v, want %v", got, want) + } - // Transport should not listen on any endpoints yet. - endpoint, err := transport.Endpoint() - require.Error(t, err) - require.Nil(t, endpoint) - - // Start listening, and check any expected errors. - err = transport.Listen(tc.endpoint) + err := utils.IgnoreCancel(scope.Run(ctx, func(ctx context.Context, s scope.Scope) error { + s.SpawnBgNamed("transport", func() error { return transport.Run(ctx) }) + if err:=transport.WaitForStart(ctx); err!=nil { + return err + } + s.SpawnNamed("dial",func() error { + conn, err := transport.Dial(ctx, tc.endpoint) + if err != nil { return fmt.Errorf("transport.Dial(): %w", err) } + if err:=conn.Close(); err!=nil { + return fmt.Errorf("conn.Close(): %w", err) + } + if _, _, err := conn.ReceiveMessage(ctx); !errors.Is(err,io.EOF) { + return fmt.Errorf("conn.ReceiveMessage() = %v, want %v", err, io.EOF) + } + return nil + }) + s.SpawnNamed("accept",func() error { + conn, err := transport.Accept(ctx) + if err != nil { return fmt.Errorf("transport.Accept(): %w",err) } + if err:=conn.Close(); err!=nil { + return fmt.Errorf("conn.Close(): %w", err) + } + if _, _, err := conn.ReceiveMessage(ctx); !errors.Is(err,io.EOF) { + return fmt.Errorf("conn.ReceiveMessage() = %v, want %v", err, io.EOF) + } + return nil + }) + return nil + })) if !tc.ok { - require.Error(t, err) - return - } - require.NoError(t, err) - - // Check the endpoint. 
- endpoint, err = transport.Endpoint() - require.NoError(t, err) - require.NotNil(t, endpoint) - - require.Equal(t, p2p.MConnProtocol, endpoint.Protocol) - if tc.endpoint.IP.IsUnspecified() { - require.True(t, endpoint.IP.IsUnspecified(), - "expected unspecified IP, got %v", endpoint.IP) - } else { - require.True(t, tc.endpoint.IP.Equal(endpoint.IP), - "expected %v, got %v", tc.endpoint.IP, endpoint.IP) + var want p2p.InvalidEndpointErr + if !errors.As(err, &want) { + t.Fatalf("error = %v, want %T", err, want) + } + } else if err != nil { + t.Fatal(err) } - require.NotZero(t, endpoint.Port) - require.Empty(t, endpoint.Path) - - dialedChan := make(chan struct{}) - - var peerConn p2p.Connection - go func() { - // Dialing the endpoint should work. - var err error - ctx := t.Context() - - peerConn, err = transport.Dial(ctx, endpoint) - require.NoError(t, err) - close(dialedChan) - }() - - conn, err := transport.Accept(ctx) - require.NoError(t, err) - _ = conn.Close() - <-dialedChan - - // closing the connection should not error - require.NoError(t, peerConn.Close()) - - // try to read from the connection should error - _, _, err = peerConn.ReceiveMessage(ctx) - require.Error(t, err) - - // Trying to listen again should error. - err = transport.Listen(tc.endpoint) - require.Error(t, err) - - // close the transport - _ = transport.Close() - // Dialing the closed endpoint should error - _, err = transport.Dial(ctx, endpoint) + _, err = transport.Dial(ctx, tc.endpoint) require.Error(t, err) }) } diff --git a/internal/p2p/transport_memory.go b/internal/p2p/transport_memory.go index 53a6cbe3e..ccbc3dc5a 100644 --- a/internal/p2p/transport_memory.go +++ b/internal/p2p/transport_memory.go @@ -103,7 +103,7 @@ func (t *MemoryTransport) String() string { return string(MemoryProtocol) } -func (t *MemoryTransport) Run(ctx context.Context, e *Endpoint) error { +func (t *MemoryTransport) Run(ctx context.Context) error { <-ctx.Done() t.network.mtx.Lock() delete(t.network.transports, t.nodeID) @@ -120,18 +120,14 @@ func (t *MemoryTransport) Protocols() []Protocol { } // Endpoints implements Transport. -func (t *MemoryTransport) Endpoint() (*Endpoint, error) { - if n := t.network.GetTransport(t.nodeID); n == nil { - return nil, errors.New("node not defined") - } - - return &Endpoint{ +func (t *MemoryTransport) Endpoint() Endpoint { + return Endpoint{ Protocol: MemoryProtocol, Path: string(t.nodeID), // An arbitrary IP and port is used in order for the pex // reactor to be able to send addresses to one another. Addr: netip.AddrPort{}, - }, nil + } } // Accept implements Transport. @@ -148,7 +144,7 @@ func (t *MemoryTransport) Accept(ctx context.Context) (Connection, error) { } // Dial implements Transport. 
-func (t *MemoryTransport) Dial(ctx context.Context, endpoint *Endpoint) (Connection, error) { +func (t *MemoryTransport) Dial(ctx context.Context, endpoint Endpoint) (Connection, error) { if endpoint.Protocol != MemoryProtocol { return nil, fmt.Errorf("invalid protocol %q", endpoint.Protocol) } diff --git a/internal/p2p/transport_memory_test.go b/internal/p2p/transport_memory_test.go index 33d96cdb8..ef5b1f299 100644 --- a/internal/p2p/transport_memory_test.go +++ b/internal/p2p/transport_memory_test.go @@ -3,9 +3,7 @@ package p2p_test import ( "bytes" "encoding/hex" - "testing" - - "github.com/stretchr/testify/require" + "context" "github.com/tendermint/tendermint/internal/p2p" "github.com/tendermint/tendermint/libs/log" @@ -15,22 +13,19 @@ import ( // Transports are mainly tested by common tests in transport_test.go, we // register a transport factory here to get included in those tests. func init() { - var network *p2p.MemoryNetwork // shared by transports in the same test - - testTransports["memory"] = func(t *testing.T) p2p.Transport { - if network == nil { - network = p2p.NewMemoryNetwork(log.NewNopLogger(), 1) + testTransports["memory"] = func() func(context.Context) p2p.Transport { + network := p2p.NewMemoryNetwork(log.NewNopLogger(), 1) + return func(ctx context.Context) p2p.Transport { + i := byte(network.Size()) + nodeID, err := types.NewNodeID(hex.EncodeToString(bytes.Repeat([]byte{i<<4 + i}, 20))) + if err!=nil { panic(err) } + t := network.CreateTransport(nodeID) + go func() { + if err:=t.Run(ctx); err!=nil { + panic(err) + } + }() + return t } - i := byte(network.Size()) - nodeID, err := types.NewNodeID(hex.EncodeToString(bytes.Repeat([]byte{i<<4 + i}, 20))) - require.NoError(t, err) - transport := network.CreateTransport(nodeID) - - t.Cleanup(func() { - require.NoError(t, transport.Close()) - network = nil // set up a new memory network for the next test - }) - - return transport } } diff --git a/internal/p2p/transport_test.go b/internal/p2p/transport_test.go index 32b8eda73..17958ea00 100644 --- a/internal/p2p/transport_test.go +++ b/internal/p2p/transport_test.go @@ -3,7 +3,6 @@ package p2p_test import ( "context" "io" - "net" "net/netip" "testing" "time" @@ -16,16 +15,13 @@ import ( "github.com/tendermint/tendermint/internal/p2p" "github.com/tendermint/tendermint/libs/bytes" "github.com/tendermint/tendermint/types" - "github.com/tendermint/tendermint/libs/utils/tcp" ) // transportFactory is used to set up transports for tests. -type transportFactory interface { - SpawnTransport(ctx context.Context, addr netip.AddrPort) p2p.Transport -} +type transportFactory = func(ctx context.Context) p2p.Transport // testTransports is a registry of transport factories for withTransports(). -var testTransports = map[string]transportFactory{} +var testTransports = map[string](func() transportFactory){} // withTransports is a test helper that runs a test against all transports // registered in testTransports. @@ -34,58 +30,27 @@ func withTransports(t *testing.T, tester func(*testing.T, transportFactory)) { for name, transportFactory := range testTransports { t.Run(name, func(t *testing.T) { t.Cleanup(leaktest.Check(t)) - tester(t, transportFactory) + tester(t, transportFactory()) }) } } -func TestTransport_AcceptClose(t *testing.T) { - // Just test accept unblock on close, happy path is tested widely elsewhere. 
- withTransports(t, func(t *testing.T, makeTransport transportFactory) { - ctx := t.Context() - a := makeTransport(t,tcp.TestReserveAddr()) - opctx, opcancel := context.WithTimeout(ctx, 200*time.Millisecond) - defer opcancel() - - _, err := a.Accept(opctx) - require.Error(t, err) - require.Equal(t, io.EOF, err) - - <-opctx.Done() - _ = a.Close() - - // Closed transport should return error immediately, - // because the transport is closed. We use the base - // context (ctx) rather than the operation context - // (opctx) because using the later would mean this - // could error because the context was canceled. - _, err = a.Accept(ctx) - require.Error(t, err) - require.Equal(t, io.EOF, err) - }) -} - func TestTransport_DialEndpoints(t *testing.T) { ipTestCases := []struct { - ip net.IP + ip netip.Addr ok bool }{ - {net.IPv4zero, true}, - {net.IPv6zero, true}, - - {nil, false}, - {net.IPv4bcast, false}, - {net.IPv4allsys, false}, - {[]byte{1, 2, 3}, false}, - {[]byte{1, 2, 3, 4, 5}, false}, + {netip.IPv4Unspecified(), true}, + {netip.IPv6Unspecified(), true}, + + {netip.AddrFrom4([4]byte{255,255,255,255}), false}, + {netip.AddrFrom4([4]byte{224, 0, 0, 1}), false}, } withTransports(t, func(t *testing.T, makeTransport transportFactory) { ctx := t.Context() - a := makeTransport(t) - endpoint, err := a.Endpoint() - require.NoError(t, err) - require.NotNil(t, endpoint) + a := makeTransport(ctx) + endpoint := a.Endpoint() // Spawn a goroutine to simply accept any connections until closed. go func() { @@ -104,28 +69,27 @@ func TestTransport_DialEndpoints(t *testing.T) { require.NoError(t, conn.Close()) // Dialing empty endpoint should error. - _, err = a.Dial(ctx, &p2p.Endpoint{}) + _, err = a.Dial(ctx, p2p.Endpoint{}) require.Error(t, err) // Dialing without protocol should error. - noProtocol := *endpoint + noProtocol := endpoint noProtocol.Protocol = "" - _, err = a.Dial(ctx, &noProtocol) + _, err = a.Dial(ctx, noProtocol) require.Error(t, err) // Dialing with invalid protocol should error. - fooProtocol := *endpoint + fooProtocol := endpoint fooProtocol.Protocol = "foo" - _, err = a.Dial(ctx, &fooProtocol) + _, err = a.Dial(ctx, fooProtocol) require.Error(t, err) // Tests for networked endpoints (with IP). - if len(endpoint.IP) > 0 && endpoint.Protocol != p2p.MemoryProtocol { + if endpoint.Addr!=(netip.AddrPort{}) && endpoint.Protocol != p2p.MemoryProtocol { for _, tc := range ipTestCases { t.Run(tc.ip.String(), func(t *testing.T) { e := endpoint - require.NotNil(t, e) - e.IP = tc.ip + e.Addr = netip.AddrPortFrom(tc.ip,endpoint.Addr.Port()) conn, err := a.Dial(ctx, e) if tc.ok { require.NoError(t, err) @@ -138,8 +102,7 @@ func TestTransport_DialEndpoints(t *testing.T) { // Non-networked endpoints should error. noIP := endpoint - noIP.IP = nil - noIP.Port = 0 + noIP.Addr = netip.AddrPort{} noIP.Path = "foo" _, err := a.Dial(ctx, noIP) require.Error(t, err) @@ -158,25 +121,21 @@ func TestTransport_Dial(t *testing.T) { // Most just tests dial failures, happy path is tested widely elsewhere. withTransports(t, func(t *testing.T, makeTransport transportFactory) { ctx := t.Context() - a := makeTransport(t) - b := makeTransport(t) + a := makeTransport(ctx) + b := makeTransport(ctx) - aEndpoint, err := a.Endpoint() - require.NoError(t, err) - require.NotNil(t, aEndpoint) - bEndpoint, err := b.Endpoint() - require.NoError(t, err) - require.NotNil(t, bEndpoint) + aEndpoint := a.Endpoint() + bEndpoint := b.Endpoint() // Context cancellation should error. 
We can't test timeouts since we'd // need a non-responsive endpoint. cancelCtx, cancel := context.WithCancel(ctx) cancel() - _, err = a.Dial(cancelCtx, bEndpoint) + _, err := a.Dial(cancelCtx, bEndpoint) require.Error(t, err) // Unavailable endpoint should error. - err = b.Close() + // TODO: err = b.Close() require.NoError(t, err) _, err = a.Dial(ctx, bEndpoint) require.Error(t, err) @@ -200,49 +159,35 @@ func TestTransport_Dial(t *testing.T) { func TestTransport_Endpoints(t *testing.T) { withTransports(t, func(t *testing.T, makeTransport transportFactory) { - a := makeTransport(t) - b := makeTransport(t) + ctx := t.Context() + a := makeTransport(ctx) + b := makeTransport(ctx) // Both transports return valid and different endpoints. - aEndpoint, err := a.Endpoint() - require.NoError(t, err) - require.NotNil(t, aEndpoint) - bEndpoint, err := b.Endpoint() - require.NoError(t, err) - require.NotNil(t, bEndpoint) + aEndpoint := a.Endpoint() + bEndpoint := b.Endpoint() require.NotEqual(t, aEndpoint, bEndpoint) - for _, endpoint := range []*p2p.Endpoint{aEndpoint, bEndpoint} { + for _, endpoint := range []p2p.Endpoint{aEndpoint, bEndpoint} { err := endpoint.Validate() require.NoError(t, err, "invalid endpoint %q", endpoint) } - - // When closed, the transport should no longer return any endpoints. - require.NoError(t, a.Close()) - aEndpoint, err = a.Endpoint() - require.Error(t, err) - require.Nil(t, aEndpoint) - bEndpoint, err = b.Endpoint() - require.NoError(t, err) - require.NotNil(t, bEndpoint) }) } func TestTransport_Protocols(t *testing.T) { withTransports(t, func(t *testing.T, makeTransport transportFactory) { - a := makeTransport(t) + ctx := t.Context() + a := makeTransport(ctx) protocols := a.Protocols() - endpoint, err := a.Endpoint() - require.NoError(t, err) + endpoint := a.Endpoint() require.NotEmpty(t, protocols) - require.NotNil(t, endpoint) - require.Contains(t, protocols, endpoint.Protocol) }) } func TestTransport_String(t *testing.T) { withTransports(t, func(t *testing.T, makeTransport transportFactory) { - a := makeTransport(t) + a := makeTransport(t.Context()) require.NotEmpty(t, a.String()) }) } @@ -250,8 +195,8 @@ func TestTransport_String(t *testing.T) { func TestConnection_Handshake(t *testing.T) { withTransports(t, func(t *testing.T, makeTransport transportFactory) { ctx := t.Context() - a := makeTransport(t) - b := makeTransport(t) + a := makeTransport(ctx) + b := makeTransport(ctx) ab, ba := dialAccept(ctx, t, a, b) // A handshake should pass the given keys and NodeInfo. @@ -302,8 +247,8 @@ func TestConnection_Handshake(t *testing.T) { func TestConnection_HandshakeCancel(t *testing.T) { withTransports(t, func(t *testing.T, makeTransport transportFactory) { ctx := t.Context() - a := makeTransport(t) - b := makeTransport(t) + a := makeTransport(ctx) + b := makeTransport(ctx) // Handshake should error on context cancellation. 
ab, ba := dialAccept(ctx, t, a, b) @@ -330,8 +275,8 @@ func TestConnection_HandshakeCancel(t *testing.T) { func TestConnection_FlushClose(t *testing.T) { withTransports(t, func(t *testing.T, makeTransport transportFactory) { ctx := t.Context() - a := makeTransport(t) - b := makeTransport(t) + a := makeTransport(ctx) + b := makeTransport(ctx) ab, _ := dialAcceptHandshake(ctx, t, a, b) err := ab.Close() @@ -349,8 +294,8 @@ func TestConnection_FlushClose(t *testing.T) { func TestConnection_LocalRemoteEndpoint(t *testing.T) { withTransports(t, func(t *testing.T, makeTransport transportFactory) { ctx := t.Context() - a := makeTransport(t) - b := makeTransport(t) + a := makeTransport(ctx) + b := makeTransport(ctx) ab, ba := dialAcceptHandshake(ctx, t, a, b) // Local and remote connection endpoints correspond to each other. @@ -365,8 +310,8 @@ func TestConnection_SendReceive(t *testing.T) { withTransports(t, func(t *testing.T, makeTransport transportFactory) { ctx := t.Context() - a := makeTransport(t) - b := makeTransport(t) + a := makeTransport(ctx) + b := makeTransport(ctx) ab, ba := dialAcceptHandshake(ctx, t, a, b) // Can send and receive a to b. @@ -386,19 +331,6 @@ func TestConnection_SendReceive(t *testing.T) { require.NoError(t, err) require.Equal(t, []byte("bar"), msg) - // Connections should still be active after closing the transports. - err = a.Close() - require.NoError(t, err) - err = b.Close() - require.NoError(t, err) - - err = ab.SendMessage(ctx, chID, []byte("still here")) - require.NoError(t, err) - ch, msg, err = ba.ReceiveMessage(ctx) - require.NoError(t, err) - require.Equal(t, chID, ch) - require.Equal(t, []byte("still here"), msg) - // Close one side of the connection. Both sides should then error // with io.EOF when trying to send or receive. err = ba.Close() @@ -424,8 +356,8 @@ func TestConnection_SendReceive(t *testing.T) { func TestConnection_String(t *testing.T) { withTransports(t, func(t *testing.T, makeTransport transportFactory) { ctx := t.Context() - a := makeTransport(t) - b := makeTransport(t) + a := makeTransport(ctx) + b := makeTransport(ctx) ab, _ := dialAccept(ctx, t, a, b) require.NotEmpty(t, ab.String()) }) @@ -433,9 +365,8 @@ func TestConnection_String(t *testing.T) { func TestEndpoint_NodeAddress(t *testing.T) { var ( - ip4 = []byte{1, 2, 3, 4} - ip4in6 = net.IPv4(1, 2, 3, 4) - ip6 = []byte{0xb1, 0x0c, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01} + ip4 = netip.AddrFrom4([4]byte{1, 2, 3, 4}) + ip6 = netip.AddrFrom16([16]byte{0xb1, 0x0c, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01}) id = types.NodeID("00112233445566778899aabbccddeeff00112233") ) @@ -445,15 +376,11 @@ func TestEndpoint_NodeAddress(t *testing.T) { }{ // Valid endpoints. { - p2p.Endpoint{Protocol: "tcp", IP: ip4, Port: 8080, Path: "path"}, + p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip4,8080), Path: "path"}, p2p.NodeAddress{Protocol: "tcp", Hostname: "1.2.3.4", Port: 8080, Path: "path"}, }, { - p2p.Endpoint{Protocol: "tcp", IP: ip4in6, Port: 8080, Path: "path"}, - p2p.NodeAddress{Protocol: "tcp", Hostname: "1.2.3.4", Port: 8080, Path: "path"}, - }, - { - p2p.Endpoint{Protocol: "tcp", IP: ip6, Port: 8080, Path: "path"}, + p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip6, 8080), Path: "path"}, p2p.NodeAddress{Protocol: "tcp", Hostname: "b10c::1", Port: 8080, Path: "path"}, }, { @@ -468,8 +395,8 @@ func TestEndpoint_NodeAddress(t *testing.T) { // Partial (invalid) endpoints. 
{p2p.Endpoint{}, p2p.NodeAddress{}}, {p2p.Endpoint{Protocol: "tcp"}, p2p.NodeAddress{Protocol: "tcp"}}, - {p2p.Endpoint{IP: net.IPv4(1, 2, 3, 4)}, p2p.NodeAddress{Hostname: "1.2.3.4"}}, - {p2p.Endpoint{Port: 8080}, p2p.NodeAddress{}}, + {p2p.Endpoint{Addr: netip.AddrPortFrom(ip4,0)}, p2p.NodeAddress{Hostname: "1.2.3.4"}}, + {p2p.Endpoint{Addr: netip.AddrPortFrom(netip.IPv4Unspecified(),8080)}, p2p.NodeAddress{}}, {p2p.Endpoint{Path: "path"}, p2p.NodeAddress{Path: "path"}}, } for _, tc := range testcases { @@ -487,9 +414,8 @@ func TestEndpoint_NodeAddress(t *testing.T) { func TestEndpoint_String(t *testing.T) { var ( - ip4 = []byte{1, 2, 3, 4} - ip4in6 = net.IPv4(1, 2, 3, 4) - ip6 = []byte{0xb1, 0x0c, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01} + ip4 = netip.AddrFrom4([4]byte{1, 2, 3, 4}) + ip6 = netip.AddrFrom16([16]byte{0xb1, 0x0c, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01}) nodeID = types.NodeID("00112233445566778899aabbccddeeff00112233") ) @@ -503,24 +429,23 @@ func TestEndpoint_String(t *testing.T) { {p2p.Endpoint{Protocol: "file", Path: "👋"}, "file:///%F0%9F%91%8B"}, // IPv4 endpoints. - {p2p.Endpoint{Protocol: "tcp", IP: ip4}, "tcp://1.2.3.4"}, - {p2p.Endpoint{Protocol: "tcp", IP: ip4in6}, "tcp://1.2.3.4"}, - {p2p.Endpoint{Protocol: "tcp", IP: ip4, Port: 8080}, "tcp://1.2.3.4:8080"}, - {p2p.Endpoint{Protocol: "tcp", IP: ip4, Port: 8080, Path: "/path"}, "tcp://1.2.3.4:8080/path"}, - {p2p.Endpoint{Protocol: "tcp", IP: ip4, Path: "path/👋"}, "tcp://1.2.3.4/path/%F0%9F%91%8B"}, + {p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip4,0)}, "tcp://1.2.3.4"}, + {p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip4,8080)}, "tcp://1.2.3.4:8080"}, + {p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip4,8080), Path: "/path"}, "tcp://1.2.3.4:8080/path"}, + {p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip4,0), Path: "path/👋"}, "tcp://1.2.3.4/path/%F0%9F%91%8B"}, // IPv6 endpoints. - {p2p.Endpoint{Protocol: "tcp", IP: ip6}, "tcp://b10c::1"}, - {p2p.Endpoint{Protocol: "tcp", IP: ip6, Port: 8080}, "tcp://[b10c::1]:8080"}, - {p2p.Endpoint{Protocol: "tcp", IP: ip6, Port: 8080, Path: "/path"}, "tcp://[b10c::1]:8080/path"}, - {p2p.Endpoint{Protocol: "tcp", IP: ip6, Path: "path/👋"}, "tcp://b10c::1/path/%F0%9F%91%8B"}, + {p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip6,0)}, "tcp://b10c::1"}, + {p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip6,8080)}, "tcp://[b10c::1]:8080"}, + {p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip6,8080), Path: "/path"}, "tcp://[b10c::1]:8080/path"}, + {p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip6,0), Path: "path/👋"}, "tcp://b10c::1/path/%F0%9F%91%8B"}, // Partial (invalid) endpoints. 
{p2p.Endpoint{}, ""}, {p2p.Endpoint{Protocol: "tcp"}, "tcp:"}, - {p2p.Endpoint{IP: []byte{1, 2, 3, 4}}, "1.2.3.4"}, - {p2p.Endpoint{IP: []byte{0xb1, 0x0c, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01}}, "b10c::1"}, - {p2p.Endpoint{Port: 8080}, ""}, + {p2p.Endpoint{Addr: netip.AddrPortFrom(ip4,0)}, "1.2.3.4"}, + {p2p.Endpoint{Addr: netip.AddrPortFrom(ip6,0)}, "b10c::1"}, + {p2p.Endpoint{Addr: netip.AddrPortFrom(netip.IPv4Unspecified(),8080)}, ""}, {p2p.Endpoint{Path: "foo"}, "/foo"}, } for _, tc := range testcases { @@ -531,30 +456,25 @@ func TestEndpoint_String(t *testing.T) { } func TestEndpoint_Validate(t *testing.T) { - var ( - ip4 = []byte{1, 2, 3, 4} - ip4in6 = net.IPv4(1, 2, 3, 4) - ip6 = []byte{0xb1, 0x0c, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01} - ) + ip4 := netip.AddrFrom4([4]byte{1, 2, 3, 4}) + ip6 := netip.AddrFrom16([16]byte{0xb1, 0x0c, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01}) testcases := []struct { endpoint p2p.Endpoint expectValid bool }{ // Valid endpoints. - {p2p.Endpoint{Protocol: "tcp", IP: ip4}, true}, - {p2p.Endpoint{Protocol: "tcp", IP: ip4in6}, true}, - {p2p.Endpoint{Protocol: "tcp", IP: ip6}, true}, - {p2p.Endpoint{Protocol: "tcp", IP: ip4, Port: 8008}, true}, - {p2p.Endpoint{Protocol: "tcp", IP: ip4, Port: 8080, Path: "path"}, true}, + {p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip4,0)}, true}, + {p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip6,0)}, true}, + {p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip4,8008)}, true}, + {p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip4,8080), Path: "path"}, true}, {p2p.Endpoint{Protocol: "memory", Path: "path"}, true}, // Invalid endpoints. {p2p.Endpoint{}, false}, - {p2p.Endpoint{IP: ip4}, false}, + {p2p.Endpoint{Addr: netip.AddrPortFrom(ip4,0)}, false}, {p2p.Endpoint{Protocol: "tcp"}, false}, - {p2p.Endpoint{Protocol: "tcp", IP: []byte{1, 2, 3}}, false}, - {p2p.Endpoint{Protocol: "tcp", Port: 8080, Path: "path"}, false}, + {p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(netip.IPv4Unspecified(),8080), Path: "path"}, false}, } for _, tc := range testcases { t.Run(tc.endpoint.String(), func(t *testing.T) { @@ -573,9 +493,7 @@ func TestEndpoint_Validate(t *testing.T) { func dialAccept(ctx context.Context, t *testing.T, a, b p2p.Transport) (p2p.Connection, p2p.Connection) { t.Helper() - endpoint, err := b.Endpoint() - require.NoError(t, err) - require.NotNil(t, endpoint, "peer not listening on any endpoints") + endpoint := b.Endpoint() ctx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() diff --git a/node/setup.go b/node/setup.go index 526382440..205c205c9 100644 --- a/node/setup.go +++ b/node/setup.go @@ -300,23 +300,21 @@ func createRouter( p2pLogger := logger.With("module", "p2p") + ep, err := p2p.NewEndpoint(nodeKey.ID.AddressString(cfg.P2P.ListenAddress)) + if err != nil { + return nil, err + } transportConf := conn.DefaultMConnConfig() transportConf.FlushThrottle = cfg.P2P.FlushThrottleTimeout transportConf.SendRate = cfg.P2P.SendRate transportConf.RecvRate = cfg.P2P.RecvRate transportConf.MaxPacketMsgPayloadSize = cfg.P2P.MaxPacketMsgPayloadSize transport := p2p.NewMConnTransport( - p2pLogger, transportConf, []*p2p.ChannelDescriptor{}, + p2pLogger, ep, transportConf, []*p2p.ChannelDescriptor{}, p2p.MConnTransportOptions{ MaxAcceptedConnections: uint32(cfg.P2P.MaxConnections), }, ) - - ep, err := p2p.NewEndpoint(nodeKey.ID.AddressString(cfg.P2P.ListenAddress)) - if err != nil { - return nil, err - } - return p2p.NewRouter( p2pLogger, p2pMetrics, @@ 
-324,7 +322,6 @@ func createRouter( peerManager, nodeInfoProducer, transport, - ep, nil, // TODO: replace with mempool CheckTx failure based filterer getRouterConfig(cfg, appClient), ) From 66111f82cc290b8f504dcc5300469d2b389d8a67 Mon Sep 17 00:00:00 2001 From: Grzegorz Prusak Date: Fri, 29 Aug 2025 16:31:43 +0200 Subject: [PATCH 30/41] fixed address tests --- internal/p2p/address.go | 1 + internal/p2p/address_test.go | 15 +++++++++++---- internal/p2p/transport.go | 9 ++++----- internal/p2p/transport_mconn_test.go | 5 +++++ internal/p2p/transport_test.go | 5 ++--- libs/utils/tcp/tcp.go | 5 +++++ 6 files changed, 28 insertions(+), 12 deletions(-) diff --git a/internal/p2p/address.go b/internal/p2p/address.go index 31d7561f7..7f7d48bbf 100644 --- a/internal/p2p/address.go +++ b/internal/p2p/address.go @@ -123,6 +123,7 @@ func (a NodeAddress) Resolve(ctx context.Context) ([]Endpoint, error) { endpoints := make([]Endpoint, len(ips)) for i, ip := range ips { ip,ok := netip.AddrFromSlice(ip) + fmt.Printf("%v\n",ip) if !ok { return nil, fmt.Errorf("LookupIP returned invalid IP %q", ip) } endpoints[i] = Endpoint{ Protocol: a.Protocol, diff --git a/internal/p2p/address_test.go b/internal/p2p/address_test.go index e0e982d45..121550516 100644 --- a/internal/p2p/address_test.go +++ b/internal/p2p/address_test.go @@ -5,11 +5,10 @@ import ( "strings" "testing" - "github.com/stretchr/testify/require" - "github.com/tendermint/tendermint/crypto/ed25519" "github.com/tendermint/tendermint/internal/p2p" "github.com/tendermint/tendermint/libs/utils/tcp" + "github.com/tendermint/tendermint/libs/utils/require" "github.com/tendermint/tendermint/types" ) @@ -266,7 +265,15 @@ func TestNodeAddress_Resolve(t *testing.T) { require.Error(t, err) return } - require.Contains(t, endpoints, tc.expect) + ok := false + tc.expect.Addr = tcp.Norm(tc.expect.Addr) + for _,e := range endpoints { + e.Addr = tcp.Norm(e.Addr) + ok = ok || e==tc.expect + } + if !ok { + t.Fatalf("%v not in %v",tc.expect,endpoints) + } }) } t.Run("Resolve localhost", func(t *testing.T) { @@ -277,7 +284,7 @@ func TestNodeAddress_Resolve(t *testing.T) { for _, got := range endpoints { require.True(t, got.Addr.Addr().IsLoopback()) // Any loopback address is acceptable, so ignore it in comparison. 
- want := &p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(got.Addr.Addr(),80), Path: "/path"} + want := p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(got.Addr.Addr(),80), Path: "/path"} require.Equal(t, want, got) } }) diff --git a/internal/p2p/transport.go b/internal/p2p/transport.go index 1b202f8c9..c1ac098ab 100644 --- a/internal/p2p/transport.go +++ b/internal/p2p/transport.go @@ -161,11 +161,10 @@ func (e Endpoint) Validate() error { if e.Protocol == "" { return errors.New("endpoint has no protocol") } - if e.Addr==(netip.AddrPort{}) { - if e.Path == "" { - return errors.New("endpoint has neither path nor IP") - } - } else { + if (e.Addr==netip.AddrPort{}) && (e.Path=="") { + return errors.New("endpoint has neither path nor IP") + } + if e.Addr!=(netip.AddrPort{}) { if !e.Addr.IsValid() { return fmt.Errorf("endpoint has invalid address %q", e.Addr.String()) } diff --git a/internal/p2p/transport_mconn_test.go b/internal/p2p/transport_mconn_test.go index 104b06ba2..f161cdc40 100644 --- a/internal/p2p/transport_mconn_test.go +++ b/internal/p2p/transport_mconn_test.go @@ -40,11 +40,16 @@ func init() { panic(err) } }() + if err:=transport.WaitForStart(ctx); err!=nil { + panic(err) + } return transport } } } +// Establishes a connection to the transport. +// Returns both ends of the connection. func connect(ctx context.Context, tr *p2p.MConnTransport) (c1 p2p.Connection, c2 p2p.Connection, err error) { defer func() { if err != nil { diff --git a/internal/p2p/transport_test.go b/internal/p2p/transport_test.go index 17958ea00..014d518b4 100644 --- a/internal/p2p/transport_test.go +++ b/internal/p2p/transport_test.go @@ -118,6 +118,7 @@ func TestTransport_DialEndpoints(t *testing.T) { } func TestTransport_Dial(t *testing.T) { + t.Skip() // TODO // Most just tests dial failures, happy path is tested widely elsewhere. withTransports(t, func(t *testing.T, makeTransport transportFactory) { ctx := t.Context() @@ -396,7 +397,6 @@ func TestEndpoint_NodeAddress(t *testing.T) { {p2p.Endpoint{}, p2p.NodeAddress{}}, {p2p.Endpoint{Protocol: "tcp"}, p2p.NodeAddress{Protocol: "tcp"}}, {p2p.Endpoint{Addr: netip.AddrPortFrom(ip4,0)}, p2p.NodeAddress{Hostname: "1.2.3.4"}}, - {p2p.Endpoint{Addr: netip.AddrPortFrom(netip.IPv4Unspecified(),8080)}, p2p.NodeAddress{}}, {p2p.Endpoint{Path: "path"}, p2p.NodeAddress{Path: "path"}}, } for _, tc := range testcases { @@ -445,7 +445,7 @@ func TestEndpoint_String(t *testing.T) { {p2p.Endpoint{Protocol: "tcp"}, "tcp:"}, {p2p.Endpoint{Addr: netip.AddrPortFrom(ip4,0)}, "1.2.3.4"}, {p2p.Endpoint{Addr: netip.AddrPortFrom(ip6,0)}, "b10c::1"}, - {p2p.Endpoint{Addr: netip.AddrPortFrom(netip.IPv4Unspecified(),8080)}, ""}, + {p2p.Endpoint{Addr: netip.AddrPortFrom(netip.IPv4Unspecified(),8080)}, "0.0.0.0:8080"}, {p2p.Endpoint{Path: "foo"}, "/foo"}, } for _, tc := range testcases { @@ -474,7 +474,6 @@ func TestEndpoint_Validate(t *testing.T) { {p2p.Endpoint{}, false}, {p2p.Endpoint{Addr: netip.AddrPortFrom(ip4,0)}, false}, {p2p.Endpoint{Protocol: "tcp"}, false}, - {p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(netip.IPv4Unspecified(),8080), Path: "path"}, false}, } for _, tc := range testcases { t.Run(tc.endpoint.String(), func(t *testing.T) { diff --git a/libs/utils/tcp/tcp.go b/libs/utils/tcp/tcp.go index 2a5dd6bfa..ccfe745b5 100644 --- a/libs/utils/tcp/tcp.go +++ b/libs/utils/tcp/tcp.go @@ -17,6 +17,11 @@ var reservedAddrs = utils.NewMutex(map[netip.AddrPort]struct{}{}) // IPv4Loopback returns the IPv4 loopback address. 
func IPv4Loopback() netip.Addr { return netip.AddrFrom4([4]byte{127, 0, 0, 1}) } +// Norm normalizes address by unmapping IPv4 -> IPv6 embedding. +func Norm(addr netip.AddrPort) netip.AddrPort { + return netip.AddrPortFrom(addr.Addr().Unmap(),addr.Port()) +} + // Listen opens a TCP listener on the given address. // It takes into account the reserved addresses (in tests) and sets the SO_REUSEPORT. // nolint: contextcheck From 2e83f6ea73f955cb3ad77659933a8bd1ff97a2a0 Mon Sep 17 00:00:00 2001 From: Grzegorz Prusak Date: Fri, 29 Aug 2025 16:50:04 +0200 Subject: [PATCH 31/41] peermanager tests --- internal/p2p/peermanager.go | 11 +++++++- internal/p2p/peermanager_test.go | 43 ++++++++++++++++++++++---------- libs/utils/require/require.go | 21 ++++++++++++++++ 3 files changed, 61 insertions(+), 14 deletions(-) diff --git a/internal/p2p/peermanager.go b/internal/p2p/peermanager.go index 45d0d8e70..438641f4e 100644 --- a/internal/p2p/peermanager.go +++ b/internal/p2p/peermanager.go @@ -28,6 +28,15 @@ const ( retryNever time.Duration = math.MaxInt64 ) +type DialFailuresError struct { + Failures uint32 + Address types.NodeID +} + +func (e DialFailuresError) Error() string { + return fmt.Sprintf("dialing failed %d times will not retry for address=%s, deleting peer",e.Failures,e.Address) +} + // PeerStatus is a peer status. // // The peer manager has many more internal states for a peer (e.g. dialing, @@ -614,7 +623,7 @@ func (m *PeerManager) DialFailed(ctx context.Context, address NodeAddress) error if err := m.store.Delete(address.NodeID); err != nil { return err } - return fmt.Errorf("dialing failed %d times will not retry for address=%s, deleting peer", addressInfo.DialFailures, address.NodeID) + return DialFailuresError{addressInfo.DialFailures, address.NodeID} } go func() { // Use an explicit timer with deferred cleanup instead of diff --git a/internal/p2p/peermanager_test.go b/internal/p2p/peermanager_test.go index 7a8f1344d..3a75bce04 100644 --- a/internal/p2p/peermanager_test.go +++ b/internal/p2p/peermanager_test.go @@ -12,11 +12,11 @@ import ( "github.com/fortytw2/leaktest" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" dbm "github.com/tendermint/tm-db" "github.com/tendermint/tendermint/internal/p2p" "github.com/tendermint/tendermint/types" + "github.com/tendermint/tendermint/libs/utils/require" ) // FIXME: We should probably have some randomized property-based tests for the @@ -432,7 +432,10 @@ func TestPeerManagerDeleteOnMaxRetries(t *testing.T) { require.GreaterOrEqual(t, elapsed, time.Duration(math.Pow(2, float64(i)))*options.MinRetryTime) } if i == 3 { - require.ErrorContains(t, peerManager.DialFailed(ctx, a), "dialing failed 4 times") + if got,err:=(p2p.DialFailuresError{}),peerManager.DialFailed(ctx, a); !errors.As(err, &got) || got.Failures!=4 { + t.Errorf("expected 4 failures, got error %v", err) + } + continue } require.NoError(t, peerManager.DialFailed(ctx, a)) @@ -1045,7 +1048,9 @@ func TestPeerManager_Dialed_Upgrade(t *testing.T) { // a should now be evicted. 
evict, err := peerManager.TryEvictNext() require.NoError(t, err) - require.Equal(t, a.NodeID, evict) + if ev,ok := evict.Get(); !ok || ev.ID!=a.NodeID { + t.Fatalf("evict = %v, expected %s", evict, a.NodeID) + } } func TestPeerManager_Dialed_UpgradeEvenLower(t *testing.T) { @@ -1101,7 +1106,9 @@ func TestPeerManager_Dialed_UpgradeEvenLower(t *testing.T) { require.NoError(t, peerManager.Dialed(c)) evict, err := peerManager.TryEvictNext() require.NoError(t, err) - require.Equal(t, d.NodeID, evict) + if ev,ok := evict.Get(); !ok || ev.ID!=d.NodeID { + t.Fatalf("evict = %v, expected %s", evict, d.NodeID) + } } func TestPeerManager_Dialed_UpgradeNoEvict(t *testing.T) { @@ -1312,7 +1319,9 @@ func TestPeerManager_Accepted_Upgrade(t *testing.T) { // This should cause a to get evicted. evict, err := peerManager.TryEvictNext() require.NoError(t, err) - require.Equal(t, a.NodeID, evict) + if ev,ok := evict.Get(); !ok || ev.ID!=a.NodeID { + t.Fatalf("evict = %v, expected %s", evict, a.NodeID) + } peerManager.Disconnected(ctx, a.NodeID) // c still cannot get accepted, since it's not scored above b. @@ -1362,7 +1371,9 @@ func TestPeerManager_Accepted_UpgradeDialing(t *testing.T) { // This should cause a to get evicted, and the dial upgrade to fail. evict, err := peerManager.TryEvictNext() require.NoError(t, err) - require.Equal(t, a.NodeID, evict) + if ev,ok := evict.Get(); !ok || ev.ID!=a.NodeID { + t.Fatalf("evict = %v, expected %s", evict, a.NodeID) + } require.Error(t, peerManager.Dialed(b)) } @@ -1450,7 +1461,7 @@ func TestPeerManager_EvictNext(t *testing.T) { peerManager.Errored(a.NodeID, errors.New("foo")) evict, err := peerManager.EvictNext(timeoutCtx) require.NoError(t, err) - require.Equal(t, a.NodeID, evict) + require.Equal(t, a.NodeID, evict.ID) // Since there are no more peers to evict, the next call should block. timeoutCtx, cancel = context.WithTimeout(ctx, 100*time.Millisecond) @@ -1485,7 +1496,7 @@ func TestPeerManager_EvictNext_WakeOnError(t *testing.T) { defer cancel() evict, err := peerManager.EvictNext(ctx) require.NoError(t, err) - require.Equal(t, a.NodeID, evict) + require.Equal(t, a.NodeID, evict.ID) } func TestPeerManager_EvictNext_WakeOnUpgradeDialed(t *testing.T) { @@ -1525,7 +1536,7 @@ func TestPeerManager_EvictNext_WakeOnUpgradeDialed(t *testing.T) { defer cancel() evict, err := peerManager.EvictNext(ctx) require.NoError(t, err) - require.Equal(t, a.NodeID, evict) + require.Equal(t, a.NodeID, evict.ID) } func TestPeerManager_EvictNext_WakeOnUpgradeAccepted(t *testing.T) { @@ -1559,7 +1570,7 @@ func TestPeerManager_EvictNext_WakeOnUpgradeAccepted(t *testing.T) { defer cancel() evict, err := peerManager.EvictNext(ctx) require.NoError(t, err) - require.Equal(t, a.NodeID, evict) + require.Equal(t, a.NodeID, evict.ID) } func TestPeerManager_TryEvictNext(t *testing.T) { ctx := t.Context() @@ -1586,7 +1597,9 @@ func TestPeerManager_TryEvictNext(t *testing.T) { peerManager.Errored(a.NodeID, errors.New("foo")) evict, err = peerManager.TryEvictNext() require.NoError(t, err) - require.Equal(t, a.NodeID, evict) + if ev,ok := evict.Get(); !ok || ev.ID!=a.NodeID { + t.Fatalf("evict = %v, expected %s", evict, a.NodeID) + } // While a is being evicted (before disconnect), it shouldn't get evicted again. 
evict, err = peerManager.TryEvictNext() @@ -1689,7 +1702,9 @@ func TestPeerManager_Errored(t *testing.T) { peerManager.Errored(a.NodeID, errors.New("foo")) evict, err = peerManager.TryEvictNext() require.NoError(t, err) - require.Equal(t, a.NodeID, evict) + if ev,ok := evict.Get(); !ok || ev.ID!=a.NodeID { + t.Fatalf("evict = %v, expected %s", evict, a.NodeID) + } } func TestPeerManager_Subscribe(t *testing.T) { @@ -1738,7 +1753,9 @@ func TestPeerManager_Subscribe(t *testing.T) { evict, err := peerManager.TryEvictNext() require.NoError(t, err) - require.Equal(t, a.NodeID, evict) + if ev,ok := evict.Get(); !ok || ev.ID!=a.NodeID { + t.Fatalf("evict = %v, expected %s", evict, a.NodeID) + } peerManager.Disconnected(ctx, a.NodeID) require.NotEmpty(t, sub.Updates()) diff --git a/libs/utils/require/require.go b/libs/utils/require/require.go index 66bb750d3..a15dfb083 100644 --- a/libs/utils/require/require.go +++ b/libs/utils/require/require.go @@ -17,9 +17,19 @@ var False = require.False // True . var True = require.True +// Zero . +var Zero = require.Zero + +// NotZero . +var NotZero = require.NotZero + // Contains . var Contains = require.Contains +func ElementsMatch[T any](t TestingT, a []T, b []T, msgAndArgs ...any) { + require.ElementsMatch(t,a,b,msgAndArgs...) +} + // EqualError . // TODO: get rid of comparing errors by strings, // use concrete error types instead. @@ -65,11 +75,22 @@ func Less[T cmp.Ordered](t TestingT, e1, e2 T, msgAndArgs ...any) { require.Less(t, e1, e2, msgAndArgs...) } +// LessOrEqual . +func LessOrEqual[T cmp.Ordered](t TestingT, e1, e2 T, msgAndArgs ...any) { + require.LessOrEqual(t, e1, e2, msgAndArgs...) +} + // Greater . func Greater[T cmp.Ordered](t TestingT, e1, e2 T, msgAndArgs ...any) { require.Greater(t, e1, e2, msgAndArgs...) } +// GreaterOrEqual . +func GreaterOrEqual[T cmp.Ordered](t TestingT, e1, e2 T, msgAndArgs ...any) { + require.GreaterOrEqual(t, e1, e2, msgAndArgs...) +} + + // Equal . func Equal[T any](t TestingT, expected, actual T, msgAndArgs ...any) { require.Equal(t, expected, actual, msgAndArgs...) From c838558ae184a1e6d7613cc726b47f6f4dfeef06 Mon Sep 17 00:00:00 2001 From: Grzegorz Prusak Date: Fri, 29 Aug 2025 17:01:22 +0200 Subject: [PATCH 32/41] wip --- internal/mempool/reactor_test.go | 2 +- internal/p2p/transport_memory.go | 24 +++++------------------- 2 files changed, 6 insertions(+), 20 deletions(-) diff --git a/internal/mempool/reactor_test.go b/internal/mempool/reactor_test.go index 3387bcd9c..2de190e8e 100644 --- a/internal/mempool/reactor_test.go +++ b/internal/mempool/reactor_test.go @@ -174,7 +174,7 @@ func TestReactorBroadcastDoesNotPanic(t *testing.T) { go primaryReactor.broadcastTxRoutine(ctx, secondary, rts.mempoolChannels[primary]) wg := &sync.WaitGroup{} - for i := 0; i < 50; i++ { + for range 50 { next := &WrappedTx{} wg.Add(1) go func() { diff --git a/internal/p2p/transport_memory.go b/internal/p2p/transport_memory.go index ccbc3dc5a..79b442d66 100644 --- a/internal/p2p/transport_memory.go +++ b/internal/p2p/transport_memory.go @@ -11,6 +11,7 @@ import ( "github.com/tendermint/tendermint/crypto" "github.com/tendermint/tendermint/libs/log" "github.com/tendermint/tendermint/types" + "github.com/tendermint/tendermint/libs/utils" ) const ( @@ -85,16 +86,12 @@ type MemoryTransport struct { // newMemoryTransport creates a new MemoryTransport. This is for internal use by // MemoryNetwork, use MemoryNetwork.CreateTransport() instead. 
func newMemoryTransport(network *MemoryNetwork, nodeID types.NodeID) *MemoryTransport { - once := &sync.Once{} - closeCh := make(chan struct{}) return &MemoryTransport{ logger: network.logger.With("local", nodeID), network: network, nodeID: nodeID, bufferSize: network.bufferSize, acceptCh: make(chan *MemoryConnection), - closeCh: closeCh, - closeFn: func() { once.Do(func() { close(closeCh) }) }, } } @@ -108,7 +105,6 @@ func (t *MemoryTransport) Run(ctx context.Context) error { t.network.mtx.Lock() delete(t.network.transports, t.nodeID) t.network.mtx.Unlock() - t.closeFn() return nil } @@ -132,15 +128,7 @@ func (t *MemoryTransport) Endpoint() Endpoint { // Accept implements Transport. func (t *MemoryTransport) Accept(ctx context.Context) (Connection, error) { - select { - case <-t.closeCh: - return nil, io.EOF - case conn := <-t.acceptCh: - t.logger.Info("accepted connection", "remote", conn.RemoteEndpoint().Path) - return conn, nil - case <-ctx.Done(): - return nil, io.EOF - } + return utils.Recv(ctx,t.acceptCh) } // Dial implements Transport. @@ -180,12 +168,10 @@ func (t *MemoryTransport) Dial(ctx context.Context, endpoint Endpoint) (Connecti inConn.closeCh = closeCh inConn.closeFn = closeFn - select { - case peer.acceptCh <- inConn: - return outConn, nil - case <-ctx.Done(): - return nil, io.EOF + if err:=utils.Send(ctx,peer.acceptCh,inConn); err!=nil { + return nil, err } + return outConn, nil } // MemoryConnection is an in-memory connection between two transport endpoints. From 0d858a7034c436df5774684ad6d6d699f68bad76 Mon Sep 17 00:00:00 2001 From: Grzegorz Prusak Date: Mon, 1 Sep 2025 13:46:13 +0200 Subject: [PATCH 33/41] pex kinda fixed --- internal/p2p/p2ptest/network.go | 2 + internal/p2p/peermanager.go | 18 +++--- internal/p2p/pex/reactor.go | 108 +++++++++++++------------------ internal/p2p/pex/reactor_test.go | 86 +++++++++++------------- internal/p2p/router.go | 60 +++++++---------- internal/p2p/transport_test.go | 41 ------------ 6 files changed, 118 insertions(+), 197 deletions(-) diff --git a/internal/p2p/p2ptest/network.go b/internal/p2p/p2ptest/network.go index a0f6ed51c..d4fcfd110 100644 --- a/internal/p2p/p2ptest/network.go +++ b/internal/p2p/p2ptest/network.go @@ -219,6 +219,7 @@ func (n *Network) Remove(ctx context.Context, t *testing.T, id types.NodeID) { // Node is a node in a Network, with a Router and a PeerManager. type Node struct { + Logger log.Logger NodeID types.NodeID NodeInfo types.NodeInfo NodeAddress p2p.NodeAddress @@ -283,6 +284,7 @@ func (n *Network) MakeNode(ctx context.Context, t *testing.T, opts NodeOptions) }) return &Node{ + Logger: logger, NodeID: nodeID, NodeInfo: nodeInfo, NodeAddress: transport.Endpoint().NodeAddress(nodeID), diff --git a/internal/p2p/peermanager.go b/internal/p2p/peermanager.go index 438641f4e..5210b1387 100644 --- a/internal/p2p/peermanager.go +++ b/internal/p2p/peermanager.go @@ -619,6 +619,7 @@ func (m *PeerManager) DialFailed(ctx context.Context, address NodeAddress) error // calculate the retry delay outside the goroutine, since it must hold // the mutex lock. 
if d := m.retryDelay(addressInfo.DialFailures, peer.Persistent); d != 0 && d != retryNever { + m.logger.Info("will dial","after",d) if d == m.options.MaxRetryTime { if err := m.store.Delete(address.NodeID); err != nil { return err @@ -626,12 +627,8 @@ func (m *PeerManager) DialFailed(ctx context.Context, address NodeAddress) error return DialFailuresError{addressInfo.DialFailures, address.NodeID} } go func() { - // Use an explicit timer with deferred cleanup instead of - // time.After(), to avoid leaking goroutines on PeerManager.Close(). - timer := time.NewTimer(d) - defer timer.Stop() select { - case <-timer.C: + case <-time.After(d): m.dialWaker.Wake() case <-ctx.Done(): } @@ -645,9 +642,11 @@ func (m *PeerManager) DialFailed(ctx context.Context, address NodeAddress) error // Dialed marks a peer as successfully dialed. Any further connections will be // rejected, and once disconnected the peer may be dialed again. -func (m *PeerManager) Dialed(address NodeAddress) error { +func (m *PeerManager) Dialed(address NodeAddress) (err error) { + m.logger.Info("DUPASO dial() LOCK","peer",address.NodeID[:5]) m.mtx.Lock() defer m.mtx.Unlock() + defer m.logger.Info("DUPASO dial() UNLOCK","peer",address.NodeID[:5],"err",err) delete(m.dialing, address.NodeID) @@ -664,8 +663,7 @@ func (m *PeerManager) Dialed(address NodeAddress) error { return fmt.Errorf("rejecting connection to self (%v)", address.NodeID) } if m.connected[address.NodeID] { - dupeConnectionErr := fmt.Errorf("cant dial, peer=%q is already connected", address.NodeID) - return dupeConnectionErr + return fmt.Errorf("cant dial, peer=%q is already connected", address.NodeID) } if m.options.MaxConnected > 0 && m.NumConnected() >= int(m.options.MaxConnected) { if upgradeFromPeer == "" || m.NumConnected() >= @@ -732,8 +730,7 @@ func (m *PeerManager) Accepted(peerID types.NodeID) error { return fmt.Errorf("rejecting connection from self (%v)", peerID) } if m.connected[peerID] { - dupeConnectionErr := fmt.Errorf("can't accept, peer=%q is already connected", peerID) - return dupeConnectionErr + return fmt.Errorf("can't accept, peer=%q is already connected", peerID) } if !m.options.isUnconditional(peerID) && m.options.MaxConnected > 0 && m.NumConnected() >= int(m.options.MaxConnected)+int(m.options.MaxConnectedUpgrade) { @@ -987,6 +984,7 @@ func (m *PeerManager) Subscribe(ctx context.Context) *PeerUpdates { // instance in a timely fashion and close the subscription when done, // otherwise the PeerManager will halt. 
func (m *PeerManager) Register(ctx context.Context, peerUpdates *PeerUpdates) { + m.logger.Info("DUPASON REGISTER") m.mtx.Lock() defer m.mtx.Unlock() m.subscriptions[peerUpdates] = peerUpdates diff --git a/internal/p2p/pex/reactor.go b/internal/p2p/pex/reactor.go index 972ed0499..a31114f94 100644 --- a/internal/p2p/pex/reactor.go +++ b/internal/p2p/pex/reactor.go @@ -2,6 +2,7 @@ package pex import ( "context" + "errors" "fmt" "sync" "time" @@ -11,6 +12,7 @@ import ( "github.com/tendermint/tendermint/internal/p2p/conn" "github.com/tendermint/tendermint/libs/log" "github.com/tendermint/tendermint/libs/service" + "github.com/tendermint/tendermint/libs/utils" protop2p "github.com/tendermint/tendermint/proto/tendermint/p2p" "github.com/tendermint/tendermint/types" ) @@ -52,13 +54,7 @@ const ( fullCapacityInterval = 10 * time.Minute ) -type NoPeersAvailableError struct { - error -} - -func (e *NoPeersAvailableError) Error() string { - return fmt.Sprintf("no available peers to send a PEX request to (retrying)") -} +var NoPeersAvailableError = errors.New("no available peers to send a PEX request to (retrying)") // TODO: We should decide whether we want channel descriptors to be housed // within each reactor (as they are now) or, considering that the reactor doesn't @@ -114,7 +110,7 @@ type Reactor struct { channel *p2p.Channel // Used to signal a restart the node on the application level - restartCh chan struct{} + restartCh chan<- struct{} restartNoAvailablePeersWindow time.Duration } @@ -123,7 +119,7 @@ func NewReactor( logger log.Logger, peerManager *p2p.PeerManager, peerEvents p2p.PeerEventSubscriber, - restartCh chan struct{}, + restartCh chan<- struct{}, selfRemediationConfig *config.SelfRemediationConfig, ) *Reactor { r := &Reactor{ @@ -152,8 +148,8 @@ func (r *Reactor) SetChannel(ch *p2p.Channel) { // OnStop to ensure the outbound p2p Channels are closed. func (r *Reactor) OnStart(ctx context.Context) error { peerUpdates := r.peerEvents(ctx) - go r.processPexCh(ctx, r.channel) - go r.processPeerUpdates(ctx, peerUpdates) + r.Spawn("processPexCh",func(ctx context.Context) error { return r.processPexCh(ctx) }) + r.Spawn("processPeerUpdates",func(ctx context.Context) error { return r.processPeerUpdates(ctx, peerUpdates) }) return nil } @@ -163,16 +159,14 @@ func (r *Reactor) OnStop() {} // processPexCh implements a blocking event loop where we listen for p2p // Envelope messages from the pexCh. -func (r *Reactor) processPexCh(ctx context.Context, pexCh *p2p.Channel) { +func (r *Reactor) processPexCh(ctx context.Context) error { incoming := make(chan *p2p.Envelope) go func() { defer close(incoming) - iter := pexCh.Receive(ctx) + iter := r.channel.Receive(ctx) for iter.Next(ctx) { - select { - case <-ctx.Done(): + if err:=utils.Send(ctx, incoming, iter.Envelope()); err!=nil { return - case incoming <- iter.Envelope(): } } }() @@ -184,52 +178,48 @@ func (r *Reactor) processPexCh(ctx context.Context, pexCh *p2p.Channel) { lastNoAvailablePeersTime := time.Now() timer := time.NewTimer(0) - defer timer.Stop() - for { timer.Reset(nextPeerRequest) select { case <-ctx.Done(): - return + return nil case <-timer.C: // back off sending peer requests if there's none available. 
// Let the loop continue to handle incoming pex messages - if noAvailablePeerFailCounter > 0 { - waitPeriod := float64(noAvailablePeersWaitPeriod) * float64(noAvailablePeerFailCounter) - if time.Since(lastNoAvailablePeersTime).Seconds() < time.Duration(waitPeriod).Seconds() { - r.logger.Debug(fmt.Sprintf("waiting for more peers to become available still in the waitPeriod=%f\n", time.Duration(waitPeriod).Seconds())) - continue - } + waitPeriod := noAvailablePeersWaitPeriod * time.Duration(noAvailablePeerFailCounter) + if time.Since(lastNoAvailablePeersTime) < waitPeriod { + r.logger.Debug(fmt.Sprintf("waiting for more peers to become available still in the waitPeriod=%v\n", waitPeriod)) + continue } // Send a request for more peer addresses. - if err := r.sendRequestForPeers(ctx, pexCh); err != nil { - r.logger.Error("failed to send request for peers", "err", err) - if _, ok := err.(*NoPeersAvailableError); ok { + if err := r.sendRequestForPeers(ctx); err != nil { + r.logger.Error("DUPASO failed to send request for peers", "err", err) + if errors.Is(err,NoPeersAvailableError) { noAvailablePeerFailCounter++ lastNoAvailablePeersTime = time.Now() continue } - return + return err } noAvailablePeerFailCounter = 0 case envelope, ok := <-incoming: if !ok { - return // channel closed + return nil// channel closed } // A request from another peer, or a response to one of our requests. - dur, err := r.handlePexMessage(ctx, envelope, pexCh) + dur, err := r.handlePexMessage(ctx, envelope) if err != nil { r.logger.Error("failed to process message", "ch_id", envelope.ChannelID, "envelope", envelope, "err", err) - if serr := pexCh.SendError(ctx, p2p.PeerError{ + if serr := r.channel.SendError(ctx, p2p.PeerError{ NodeID: envelope.From, Err: err, }); serr != nil { - return + return serr } } else if dur != 0 { // We got a useful result; update the poll timer. @@ -244,29 +234,26 @@ func (r *Reactor) processPexCh(ctx context.Context, pexCh *p2p.Channel) { // processPeerUpdates initiates a blocking process where we listen for and handle // PeerUpdate messages. When the reactor is stopped, we will catch the signal and // close the p2p PeerUpdatesCh gracefully. -func (r *Reactor) processPeerUpdates(ctx context.Context, peerUpdates *p2p.PeerUpdates) { +func (r *Reactor) processPeerUpdates(ctx context.Context, peerUpdates *p2p.PeerUpdates) error { for { - select { - case <-ctx.Done(): - return - case peerUpdate := <-peerUpdates.Updates(): - r.processPeerUpdate(peerUpdate) - } + peerUpdate,err:=utils.Recv(ctx,peerUpdates.Updates()) + if err!=nil { return err } + r.processPeerUpdate(peerUpdate) } } // handlePexMessage handles envelopes sent from peers on the PexChannel. // If an update was received, a new polling interval is returned; otherwise the // duration is 0. -func (r *Reactor) handlePexMessage(ctx context.Context, envelope *p2p.Envelope, pexCh *p2p.Channel) (time.Duration, error) { +func (r *Reactor) handlePexMessage(ctx context.Context, envelope *p2p.Envelope) (time.Duration, error) { logger := r.logger.With("peer", envelope.From) switch msg := envelope.Message.(type) { case *protop2p.PexRequest: + r.logger.Info("DUPAS PexRequest","from", envelope.From[:5]) // Verify that this peer hasn't sent us another request too recently. 
if err := r.markPeerRequest(envelope.From); err != nil { - r.logger.Error(fmt.Sprintf("PEX mark peer req from %s error %s", envelope.From, err)) - return 0, err + return 0, fmt.Errorf("PEX mark peer req from %s: %w", envelope.From, err) } // Fetch peers from the peer manager, convert NodeAddresses into URL @@ -278,22 +265,20 @@ func (r *Reactor) handlePexMessage(ctx context.Context, envelope *p2p.Envelope, URL: addr.String(), } } - return 0, pexCh.Send(ctx, p2p.Envelope{ + return 0, r.channel.Send(ctx, p2p.Envelope{ To: envelope.From, Message: &protop2p.PexResponse{Addresses: pexAddresses}, }) case *protop2p.PexResponse: + r.logger.Info("DUPAS PexResponse","from",envelope.From[:5],"got",msg.Addresses) // Verify that this response corresponds to one of our pending requests. if err := r.markPeerResponse(envelope.From); err != nil { - r.logger.Error(fmt.Sprintf("PEX mark peer resp from %s error %s", envelope.From, err)) - return 0, err + return 0, fmt.Errorf("PEX mark peer resp from %s: %w", envelope.From, err) } // Verify that the response does not exceed the safety limit. if len(msg.Addresses) > maxAddresses { - r.logger.Error(fmt.Sprintf("peer %s sent too many addresses (%d > maxiumum %d)", - envelope.From, len(msg.Addresses), maxAddresses)) return 0, fmt.Errorf("peer sent too many addresses (%d > maxiumum %d)", len(msg.Addresses), maxAddresses) } @@ -302,17 +287,17 @@ func (r *Reactor) handlePexMessage(ctx context.Context, envelope *p2p.Envelope, for _, pexAddress := range msg.Addresses { peerAddress, err := p2p.ParseNodeAddress(pexAddress.URL) if err != nil { - r.logger.Error(fmt.Sprintf("PEX parse node address error %s", err)) - continue + return 0,fmt.Errorf("PEX parse node address error %s", err) } added, err := r.peerManager.Add(peerAddress) if err != nil { - logger.Error("failed to add PEX address", "address", peerAddress, "err", err) + // TODO(gprusak): This does not distinguish between bad messages (should drop peer) and internal errors (ignore/abort). + logger.Error("DUPAS failed to add PEX address", "address", peerAddress, "err", err) continue } if added { numAdded++ - logger.Debug("added PEX address", "address", peerAddress) + logger.Info("DUPAS added PEX address", "address", peerAddress.NodeID) } } @@ -357,11 +342,11 @@ func (r *Reactor) processPeerUpdate(peerUpdate p2p.PeerUpdate) { // that peer a request for more peer addresses. The chosen peer is moved into // the requestsSent bucket so that we will not attempt to contact them again // until they've replied or updated. -func (r *Reactor) sendRequestForPeers(ctx context.Context, pexCh *p2p.Channel) error { +func (r *Reactor) sendRequestForPeers(ctx context.Context) error { r.mtx.Lock() defer r.mtx.Unlock() if len(r.availablePeers) == 0 { - return &NoPeersAvailableError{} + return NoPeersAvailableError } // Select an arbitrary peer from the available set. @@ -369,19 +354,16 @@ func (r *Reactor) sendRequestForPeers(ctx context.Context, pexCh *p2p.Channel) e for peerID = range r.availablePeers { break } - - if err := pexCh.Send(ctx, p2p.Envelope{ - To: peerID, - Message: &protop2p.PexRequest{}, - }); err != nil { - return err - } - // Move the peer from available to pending. delete(r.availablePeers, peerID) r.requestsSent[peerID] = struct{}{} - return nil + r.logger.Info("DUPASO PexRequest","to",peerID[:5]) + // TODO(gprusak): blocking send while holding a mutex. 
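 +	// One possible shape for resolving the TODO above, sketched under the
 +	// assumption that r.mtx, r.availablePeers, r.requestsSent and r.channel keep
 +	// their current meaning: do the selection and bookkeeping under the mutex in
 +	// a hypothetical helper (pickAvailablePeer is not part of this patch),
 +	// release the lock, and only then perform the potentially blocking send.
 +	//
 +	//	func (r *Reactor) pickAvailablePeer() (types.NodeID, error) {
 +	//		r.mtx.Lock()
 +	//		defer r.mtx.Unlock()
 +	//		for peerID := range r.availablePeers {
 +	//			// Move the peer from available to pending.
 +	//			delete(r.availablePeers, peerID)
 +	//			r.requestsSent[peerID] = struct{}{}
 +	//			return peerID, nil
 +	//		}
 +	//		return "", NoPeersAvailableError
 +	//	}
 +	//
 +	//	func (r *Reactor) sendRequestForPeers(ctx context.Context) error {
 +	//		peerID, err := r.pickAvailablePeer()
 +	//		if err != nil {
 +	//			return err
 +	//		}
 +	//		return r.channel.Send(ctx, p2p.Envelope{
 +	//			To:      peerID,
 +	//			Message: &protop2p.PexRequest{},
 +	//		})
 +	//	}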
+ return r.channel.Send(ctx, p2p.Envelope{ + To: peerID, + Message: &protop2p.PexRequest{}, + }) } // calculateNextRequestTime selects how long we should wait before attempting diff --git a/internal/p2p/pex/reactor_test.go b/internal/p2p/pex/reactor_test.go index a1c2d20af..cea9d3fe4 100644 --- a/internal/p2p/pex/reactor_test.go +++ b/internal/p2p/pex/reactor_test.go @@ -60,10 +60,10 @@ func TestReactorConnectFullNetwork(t *testing.T) { // make every node be only connected with one other node (it actually ends up // being two because of two way connections but oh well) - testNet.connectN(ctx, t, 1) + testNet.seedAddrs(t) testNet.start(ctx, t) - // assert that all nodes add each other in the network + t.Logf("assert that all nodes add each other in the network") for idx := 0; idx < len(testNet.nodes); idx++ { testNet.requireNumberOfPeers(t, idx, len(testNet.nodes)-1, longWait) } @@ -152,7 +152,7 @@ func TestReactorErrorsOnReceivingTooManyPeers(t *testing.T) { require.True(t, added) addresses := make([]p2pproto.PexAddress, 101) - for i := 0; i < len(addresses); i++ { + for i := range addresses { nodeAddress := p2p.NodeAddress{Protocol: p2p.MemoryProtocol, NodeID: randomNodeID()} addresses[i] = p2pproto.PexAddress{ URL: nodeAddress.String(), @@ -192,17 +192,18 @@ func TestReactorSmallPeerStoreInALargeNetwork(t *testing.T) { ctx := t.Context() testNet := setupNetwork(ctx, t, testOptions{ - TotalNodes: 8, - MaxPeers: 4, + TotalNodes: 4, + MaxPeers: 3, MaxConnected: 3, BufferSize: 8, MaxRetryTime: 5 * time.Minute, }) - testNet.connectN(ctx, t, 1) + testNet.seedAddrs(t) testNet.start(ctx, t) - // test that all nodes reach full capacity + t.Logf("test that all nodes reach full capacity") for _, nodeID := range testNet.nodes { + t.Logf("OK %v",nodeID) require.Eventually(t, func() bool { // nolint:scopelint return testNet.network.Nodes[nodeID].PeerManager.PeerRatio() >= 0.9 @@ -221,7 +222,7 @@ func TestReactorLargePeerStoreInASmallNetwork(t *testing.T) { BufferSize: 5, MaxRetryTime: 5 * time.Minute, }) - testNet.connectN(ctx, t, 1) + testNet.seedAddrs(t) testNet.start(ctx, t) // assert that all nodes add each other in the network @@ -318,13 +319,11 @@ func setupSingle(ctx context.Context, t *testing.T) *singleTestReactor { type reactorTestSuite struct { network *p2ptest.Network - logger log.Logger reactors map[types.NodeID]*pex.Reactor pexChannels map[types.NodeID]*p2p.Channel peerChans map[types.NodeID]chan p2p.PeerUpdate - peerUpdates map[types.NodeID]*p2p.PeerUpdates nodes []types.NodeID mocks []types.NodeID @@ -363,12 +362,10 @@ func setupNetwork(ctx context.Context, t *testing.T, opts testOptions) *reactorT realNodes := opts.TotalNodes - opts.MockNodes rts := &reactorTestSuite{ - logger: log.NewNopLogger().With("testCase", t.Name()), network: p2ptest.MakeNetwork(ctx, t, networkOpts), reactors: make(map[types.NodeID]*pex.Reactor, realNodes), pexChannels: make(map[types.NodeID]*p2p.Channel, opts.TotalNodes), peerChans: make(map[types.NodeID]chan p2p.PeerUpdate, opts.TotalNodes), - peerUpdates: make(map[types.NodeID]*p2p.PeerUpdates, opts.TotalNodes), total: opts.TotalNodes, opts: opts, } @@ -380,17 +377,17 @@ func setupNetwork(ctx context.Context, t *testing.T, opts testOptions) *reactorT idx := 0 for nodeID := range rts.network.Nodes { rts.peerChans[nodeID] = make(chan p2p.PeerUpdate, chBuf) - rts.peerUpdates[nodeID] = p2p.NewPeerUpdates(rts.peerChans[nodeID], chBuf) - rts.network.Nodes[nodeID].PeerManager.Register(ctx, rts.peerUpdates[nodeID]) + peerUpdates := 
p2p.NewPeerUpdates(rts.peerChans[nodeID], chBuf) + rts.network.Nodes[nodeID].PeerManager.Register(ctx, peerUpdates) // the first nodes in the array are always mock nodes if idx < opts.MockNodes { rts.mocks = append(rts.mocks, nodeID) } else { rts.reactors[nodeID] = pex.NewReactor( - rts.logger.With("nodeID", nodeID), + rts.network.Nodes[nodeID].Logger, rts.network.Nodes[nodeID].PeerManager, - func(_ context.Context) *p2p.PeerUpdates { return rts.peerUpdates[nodeID] }, + func(_ context.Context) *p2p.PeerUpdates { return peerUpdates }, make(chan struct{}), config.DefaultSelfRemediationConfig(), ) @@ -439,13 +436,13 @@ func (r *reactorTestSuite) addNodes(ctx context.Context, t *testing.T, nodes int nodeID := node.NodeID r.pexChannels[nodeID] = node.MakeChannelNoCleanup(t, pex.ChannelDescriptor()) r.peerChans[nodeID] = make(chan p2p.PeerUpdate, r.opts.BufferSize) - r.peerUpdates[nodeID] = p2p.NewPeerUpdates(r.peerChans[nodeID], r.opts.BufferSize) - r.network.Nodes[nodeID].PeerManager.Register(ctx, r.peerUpdates[nodeID]) + peerUpdates := p2p.NewPeerUpdates(r.peerChans[nodeID], r.opts.BufferSize) + r.network.Nodes[nodeID].PeerManager.Register(ctx, peerUpdates) r.reactors[nodeID] = pex.NewReactor( - r.logger.With("nodeID", nodeID), + r.network.Nodes[nodeID].Logger, r.network.Nodes[nodeID].PeerManager, - func(_ context.Context) *p2p.PeerUpdates { return r.peerUpdates[nodeID] }, + func(_ context.Context) *p2p.PeerUpdates { return peerUpdates }, make(chan struct{}), config.DefaultSelfRemediationConfig(), ) @@ -628,22 +625,24 @@ func (r *reactorTestSuite) requireNumberOfPeers( } func (r *reactorTestSuite) connectAll(ctx context.Context, t *testing.T) { - r.connectN(ctx, t, r.total-1) -} - -// connects all nodes with n other nodes -func (r *reactorTestSuite) connectN(ctx context.Context, t *testing.T, n int) { - if n >= r.total { - require.Fail(t, "connectN: n must be less than the size of the network - 1") - } - - for i := 0; i < r.total; i++ { - for j := range n { + for i := range r.total { + for j := range r.total-1 { r.connectPeers(ctx, t, i, (i+j+1)%r.total) } } } +// Adds enough addresses to peerManagers, so that all nodes are discoverable. +func (r *reactorTestSuite) seedAddrs(t *testing.T) { + t.Helper() + for i := range r.total-1 { + n1 := r.network.Nodes[r.nodes[i]] + n2 := r.network.Nodes[r.nodes[i+1]] + _,err := n1.PeerManager.Add(n2.NodeAddress) + require.NoError(t, err) + } +} + // connects node1 to node2 func (r *reactorTestSuite) connectPeers(ctx context.Context, t *testing.T, sourceNode, targetNode int) { t.Helper() @@ -661,6 +660,9 @@ func (r *reactorTestSuite) connectPeers(ctx context.Context, t *testing.T, sourc return } + // Subscription is for the ctx lifetime. 
+ ctx,cancel := context.WithCancel(ctx) + defer cancel() sourceSub := n1.PeerManager.Subscribe(ctx) targetSub := n2.PeerManager.Subscribe(ctx) @@ -674,22 +676,12 @@ func (r *reactorTestSuite) connectPeers(ctx context.Context, t *testing.T, sourc return } - select { - case peerUpdate := <-targetSub.Updates(): - require.Equal(t, peerUpdate.NodeID, node1) - require.Equal(t, peerUpdate.Status, p2p.PeerStatusUp) - case <-time.After(2 * time.Second): - require.Fail(t, "timed out waiting for peer", "%v accepting %v", - targetNode, sourceNode) - } - select { - case peerUpdate := <-sourceSub.Updates(): - require.Equal(t, peerUpdate.NodeID, node2) - require.Equal(t, peerUpdate.Status, p2p.PeerStatusUp) - case <-time.After(2 * time.Second): - require.Fail(t, "timed out waiting for peer", "%v dialing %v", - sourceNode, targetNode) - } + peerUpdate := <-targetSub.Updates() + require.Equal(t, peerUpdate.NodeID, node1) + require.Equal(t, peerUpdate.Status, p2p.PeerStatusUp) + peerUpdate = <-sourceSub.Updates() + require.Equal(t, peerUpdate.NodeID, node2) + require.Equal(t, peerUpdate.Status, p2p.PeerStatusUp) added, err = n2.PeerManager.Add(sourceAddress) require.NoError(t, err) diff --git a/internal/p2p/router.go b/internal/p2p/router.go index 73e82dd81..e7b281a95 100644 --- a/internal/p2p/router.go +++ b/internal/p2p/router.go @@ -8,7 +8,6 @@ import ( "math/rand" "net/netip" "runtime" - "strings" "sync" "time" @@ -406,6 +405,7 @@ func (r *Router) acceptPeers(ctx context.Context, transport Transport) error { if err != nil { return fmt.Errorf("failed to accept connection: %w", err) } + r.logger.Info("DUPASON accepted") r.metrics.NewConnections.With("direction","in").Add(1) incomingAddr := conn.RemoteEndpoint().Addr if err := r.connTracker.AddConn(incomingAddr); err != nil { @@ -458,15 +458,7 @@ func (r *Router) openConnection(ctx context.Context, conn Connection) error { r.logger.Debug("peer filtered by node ID", "node", peerInfo.NodeID, "err", err) return nil } - - // TODO(gprusak): this is fragile that updating peerManager requires a lock on peerStates. - // If this is intended, they should just share the same mutex. - // Also currently the pattern of keeping the mutex locked for peerManager accesses is inconsistent. 
- if err := r.runWithPeerMutex(func() error { return r.peerManager.Accepted(peerInfo.NodeID) }); err != nil { - // If peer is trying to reconnect, error and let it reconnect - if strings.Contains(err.Error(), "is already connected") { - r.peerManager.Errored(peerInfo.NodeID, err) - } + if err := r.peerManager.Accepted(peerInfo.NodeID); err != nil { return fmt.Errorf("failed to accept connection: op=incoming/accepted, peer=%v: %w", peerInfo.NodeID, err) } return r.routePeer(ctx, peerInfo.NodeID, conn, toChannelIDs(peerInfo.Channels)) @@ -489,7 +481,7 @@ func (r *Router) dialPeers(ctx context.Context) error { if err != nil { return err } - r.logger.Debug(fmt.Sprintf("Going to dial next peer %s", address.NodeID)) + r.logger.Info(fmt.Sprintf("DUPASO Going to dial next peer %s", address.NodeID[:5])) r.connectPeer(ctx, address) } }) @@ -519,9 +511,9 @@ func (r *Router) connectPeer(ctx context.Context, address NodeAddress) { case errors.Is(err, context.Canceled): return case err != nil: - r.logger.Debug("failed to dial peer", "peer", address, "err", err) + r.logger.Info("DUPASO failed to dial peer", "peer", address, "err", err) if err = r.peerManager.DialFailed(ctx, address); err != nil { - r.logger.Debug("failed to report dial failure", "peer", address, "err", err) + r.logger.Info("DUPASO failed to report dial failure", "peer", address, "err", err) } return } @@ -533,30 +525,24 @@ func (r *Router) connectPeer(ctx context.Context, address NodeAddress) { return case err != nil: r.logger.Debug("failed to handshake with peer", "peer", address, "err", err) - if err = r.peerManager.DialFailed(ctx, address); err != nil { + if err := r.peerManager.DialFailed(ctx, address); err != nil { r.logger.Error("failed to report dial failure", "peer", address, "err", err) } conn.Close() return } - if err := r.runWithPeerMutex(func() error { return r.peerManager.Dialed(address) }); err != nil { - // If peer is trying to reconnect, fail it and let it reconnect - // TODO(gprusak): this symmetric logic for handling duplicate connections is a source of race conditions: - // if 2 nodes try to establish a connection to each other at the same time, both connections will be dropped. - // Instead either: - // * break the symmetry by favoring incoming connection iff my.NodeID > peer.NodeID - // * keep incoming and outcoming connection pools separate to avoid the collision (recommended) - if strings.Contains(err.Error(), "is already connected") { - r.logger.Error(fmt.Sprintf("Disconnecting %s because of %s", address.NodeID, err)) - r.peerManager.Disconnected(ctx, address.NodeID) - } - - r.logger.Debug("failed to dial peer", - "op", "outgoing/dialing", "peer", address.NodeID, "err", err) + // TODO(gprusak): this symmetric logic for handling duplicate connections is a source of race conditions: + // if 2 nodes try to establish a connection to each other at the same time, both connections will be dropped. 
+ // Instead either: + // * break the symmetry by favoring incoming connection iff my.NodeID > peer.NodeID + // * keep incoming and outcoming connection pools separate to avoid the collision (recommended) + if err := r.peerManager.Dialed(address); err != nil { + r.logger.Info("failed to dial peer", "op", "outgoing/dialing", "peer", address.NodeID, "err", err) conn.Close() return } + r.logger.Info("DUPASON dial+hs+dial", "to",address.NodeID[:5], "err", err) r.Spawn("routePeer", func(ctx context.Context) error { defer conn.Close() @@ -573,7 +559,7 @@ func (r *Router) dialPeer(ctx context.Context, address NodeAddress) (Connection, defer cancel() } - r.logger.Debug("dialing peer address", "peer", address) + r.logger.Info("dialing peer address", "peer", address) endpoints, err := address.Resolve(resolveCtx) switch { case err != nil: @@ -604,7 +590,7 @@ func (r *Router) dialPeer(ctx context.Context, address NodeAddress) (Connection, // Internet can't and needs a different public address. conn, err := r.transport.Dial(dialCtx, endpoint) if err != nil { - r.logger.Debug("failed to dial endpoint", "peer", address.NodeID, "endpoint", endpoint, "err", err) + r.logger.Info("DUPASO failed to dial endpoint", "peer", address.NodeID, "endpoint", endpoint, "err", err) } else { r.metrics.NewConnections.With("direction","out").Add(1) r.logger.Debug("dialed peer", "peer", address.NodeID, "endpoint", endpoint) @@ -675,9 +661,10 @@ func (r *Router) runWithPeerMutex(fn func() error) error { // channels. It will close the given connection and send queue when done, or if // they are closed elsewhere it will cause this method to shut down and return. func (r *Router) routePeer(ctx context.Context, peerID types.NodeID, conn Connection, channels ChannelIDSet) error { + r.logger.Info("DUPASON [PRE] new connection","from",peerID[:5]) r.metrics.Peers.Add(1) r.peerManager.Ready(ctx, peerID, channels) - + r.logger.Info("DUPASON new connection","from",peerID[:5]) peerCtx, cancel := context.WithCancel(ctx) state := &peerState{ cancel: cancel, @@ -725,27 +712,27 @@ func (r *Router) receivePeer(ctx context.Context, peerID types.NodeID, conn Conn return err } + r.logger.Info("DUPA RECV","from",peerID[:5],"chan",chID) r.channelMtx.RLock() queue, ok := r.channelQueues[chID] messageType := r.channelMessages[chID] r.channelMtx.RUnlock() if !ok { + // TODO(gprusak): verify if this is a misbehavior, and drop the peer if it is. 
r.logger.Debug("dropping message for unknown channel", "peer", peerID, "channel", chID) continue } msg := proto.Clone(messageType) if err := proto.Unmarshal(bz, msg); err != nil { - r.logger.Error("message decoding failed, dropping message", "peer", peerID, "err", err) - continue + return fmt.Errorf("message decoding failed, dropping message: [peer=%v] %w", peerID, err) } if wrapper, ok := msg.(Wrapper); ok { msg, err = wrapper.Unwrap() if err != nil { - r.logger.Error("failed to unwrap message", "err", err) - continue + return fmt.Errorf("failed to unwrap message: %w", err) } } @@ -781,6 +768,7 @@ func (r *Router) sendPeer(ctx context.Context, peerID types.NodeID, conn Connect continue } + r.logger.Info("DUPA SEND","to",envelope.To[:5],"chan",envelope.ChannelID) if err = conn.SendMessage(ctx, envelope.ChannelID, bz); err != nil { r.logger.Error("failed to send message", "peer", peerID, "err", err) return err @@ -797,9 +785,9 @@ func (r *Router) evictPeers(ctx context.Context) error { if err != nil { return fmt.Errorf("failed to find next peer to evict: %w", err) } - r.logger.Info("evicting peer", "peer", ev.ID,"cause",ev.Cause) for states := range r.peerStates.Lock() { if s, ok := states[ev.ID]; ok { + r.logger.Info("evicting peer", "peer", ev.ID,"cause",ev.Cause) s.cancel() } } diff --git a/internal/p2p/transport_test.go b/internal/p2p/transport_test.go index 014d518b4..119274a8d 100644 --- a/internal/p2p/transport_test.go +++ b/internal/p2p/transport_test.go @@ -117,48 +117,7 @@ func TestTransport_DialEndpoints(t *testing.T) { }) } -func TestTransport_Dial(t *testing.T) { - t.Skip() // TODO - // Most just tests dial failures, happy path is tested widely elsewhere. - withTransports(t, func(t *testing.T, makeTransport transportFactory) { - ctx := t.Context() - a := makeTransport(ctx) - b := makeTransport(ctx) - - aEndpoint := a.Endpoint() - bEndpoint := b.Endpoint() - - // Context cancellation should error. We can't test timeouts since we'd - // need a non-responsive endpoint. - cancelCtx, cancel := context.WithCancel(ctx) - cancel() - _, err := a.Dial(cancelCtx, bEndpoint) - require.Error(t, err) - - // Unavailable endpoint should error. - // TODO: err = b.Close() - require.NoError(t, err) - _, err = a.Dial(ctx, bEndpoint) - require.Error(t, err) - - // Dialing from a closed transport should still work. 
- errCh := make(chan error, 1) - go func() { - conn, err := a.Accept(ctx) - if err == nil { - _ = conn.Close() - } - errCh <- err - }() - conn, err := b.Dial(ctx, aEndpoint) - require.NoError(t, err) - require.NoError(t, conn.Close()) - require.NoError(t, <-errCh) - }) -} - func TestTransport_Endpoints(t *testing.T) { - withTransports(t, func(t *testing.T, makeTransport transportFactory) { ctx := t.Context() a := makeTransport(ctx) From 9b43f0614c911eee55e2ee9035ab748a5c79c841 Mon Sep 17 00:00:00 2001 From: Grzegorz Prusak Date: Mon, 1 Sep 2025 14:35:30 +0200 Subject: [PATCH 34/41] p2p tests pass --- internal/p2p/peermanager.go | 6 +--- internal/p2p/pex/reactor.go | 9 ++--- internal/p2p/pex/reactor_test.go | 22 +++++++----- internal/p2p/router.go | 16 +++------ internal/p2p/router_test.go | 59 ++++++++++++-------------------- libs/utils/require/require.go | 3 ++ 6 files changed, 47 insertions(+), 68 deletions(-) diff --git a/internal/p2p/peermanager.go b/internal/p2p/peermanager.go index 5210b1387..1ee669ea9 100644 --- a/internal/p2p/peermanager.go +++ b/internal/p2p/peermanager.go @@ -619,7 +619,6 @@ func (m *PeerManager) DialFailed(ctx context.Context, address NodeAddress) error // calculate the retry delay outside the goroutine, since it must hold // the mutex lock. if d := m.retryDelay(addressInfo.DialFailures, peer.Persistent); d != 0 && d != retryNever { - m.logger.Info("will dial","after",d) if d == m.options.MaxRetryTime { if err := m.store.Delete(address.NodeID); err != nil { return err @@ -642,11 +641,9 @@ func (m *PeerManager) DialFailed(ctx context.Context, address NodeAddress) error // Dialed marks a peer as successfully dialed. Any further connections will be // rejected, and once disconnected the peer may be dialed again. -func (m *PeerManager) Dialed(address NodeAddress) (err error) { - m.logger.Info("DUPASO dial() LOCK","peer",address.NodeID[:5]) +func (m *PeerManager) Dialed(address NodeAddress) error { m.mtx.Lock() defer m.mtx.Unlock() - defer m.logger.Info("DUPASO dial() UNLOCK","peer",address.NodeID[:5],"err",err) delete(m.dialing, address.NodeID) @@ -984,7 +981,6 @@ func (m *PeerManager) Subscribe(ctx context.Context) *PeerUpdates { // instance in a timely fashion and close the subscription when done, // otherwise the PeerManager will halt. func (m *PeerManager) Register(ctx context.Context, peerUpdates *PeerUpdates) { - m.logger.Info("DUPASON REGISTER") m.mtx.Lock() defer m.mtx.Unlock() m.subscriptions[peerUpdates] = peerUpdates diff --git a/internal/p2p/pex/reactor.go b/internal/p2p/pex/reactor.go index a31114f94..138ac3795 100644 --- a/internal/p2p/pex/reactor.go +++ b/internal/p2p/pex/reactor.go @@ -196,7 +196,7 @@ func (r *Reactor) processPexCh(ctx context.Context) error { // Send a request for more peer addresses. if err := r.sendRequestForPeers(ctx); err != nil { - r.logger.Error("DUPASO failed to send request for peers", "err", err) + r.logger.Error("failed to send request for peers", "err", err) if errors.Is(err,NoPeersAvailableError) { noAvailablePeerFailCounter++ lastNoAvailablePeersTime = time.Now() @@ -250,7 +250,6 @@ func (r *Reactor) handlePexMessage(ctx context.Context, envelope *p2p.Envelope) switch msg := envelope.Message.(type) { case *protop2p.PexRequest: - r.logger.Info("DUPAS PexRequest","from", envelope.From[:5]) // Verify that this peer hasn't sent us another request too recently. 
if err := r.markPeerRequest(envelope.From); err != nil { return 0, fmt.Errorf("PEX mark peer req from %s: %w", envelope.From, err) @@ -271,7 +270,6 @@ func (r *Reactor) handlePexMessage(ctx context.Context, envelope *p2p.Envelope) }) case *protop2p.PexResponse: - r.logger.Info("DUPAS PexResponse","from",envelope.From[:5],"got",msg.Addresses) // Verify that this response corresponds to one of our pending requests. if err := r.markPeerResponse(envelope.From); err != nil { return 0, fmt.Errorf("PEX mark peer resp from %s: %w", envelope.From, err) @@ -292,12 +290,12 @@ func (r *Reactor) handlePexMessage(ctx context.Context, envelope *p2p.Envelope) added, err := r.peerManager.Add(peerAddress) if err != nil { // TODO(gprusak): This does not distinguish between bad messages (should drop peer) and internal errors (ignore/abort). - logger.Error("DUPAS failed to add PEX address", "address", peerAddress, "err", err) + logger.Error("failed to add PEX address", "address", peerAddress, "err", err) continue } if added { numAdded++ - logger.Info("DUPAS added PEX address", "address", peerAddress.NodeID) + logger.Debug("added PEX address", "address", peerAddress) } } @@ -358,7 +356,6 @@ func (r *Reactor) sendRequestForPeers(ctx context.Context) error { delete(r.availablePeers, peerID) r.requestsSent[peerID] = struct{}{} - r.logger.Info("DUPASO PexRequest","to",peerID[:5]) // TODO(gprusak): blocking send while holding a mutex. return r.channel.Send(ctx, p2p.Envelope{ To: peerID, diff --git a/internal/p2p/pex/reactor_test.go b/internal/p2p/pex/reactor_test.go index cea9d3fe4..714e0706c 100644 --- a/internal/p2p/pex/reactor_test.go +++ b/internal/p2p/pex/reactor_test.go @@ -192,22 +192,21 @@ func TestReactorSmallPeerStoreInALargeNetwork(t *testing.T) { ctx := t.Context() testNet := setupNetwork(ctx, t, testOptions{ - TotalNodes: 4, - MaxPeers: 3, - MaxConnected: 3, - BufferSize: 8, + TotalNodes: 8, + MaxPeers: 7, // total-1, because PeerManager doesn't count self + MaxConnected: 2, // enough capacity to establish a connected graph + BufferSize: 8, // reactor deadlocks if peer updates' subscribers are full (which is stupid) MaxRetryTime: 5 * time.Minute, }) - testNet.seedAddrs(t) + testNet.connectCycle(ctx, t) // Saturate capacity by connecting nodes in a cycle. 
testNet.start(ctx, t) - t.Logf("test that all nodes reach full capacity") + t.Logf("test that peers are gossiped even if connection cap is reached") for _, nodeID := range testNet.nodes { - t.Logf("OK %v",nodeID) require.Eventually(t, func() bool { // nolint:scopelint return testNet.network.Nodes[nodeID].PeerManager.PeerRatio() >= 0.9 - }, longWait, checkFrequency, + }, time.Minute, checkFrequency, "peer ratio is: %f", testNet.network.Nodes[nodeID].PeerManager.PeerRatio()) } } @@ -624,6 +623,13 @@ func (r *reactorTestSuite) requireNumberOfPeers( ) } +func (r *reactorTestSuite) connectCycle(ctx context.Context, t *testing.T) { + if r.total==0 { return } + for i := range r.total { + r.connectPeers(ctx, t, i, (i+1)%r.total) + } +} + func (r *reactorTestSuite) connectAll(ctx context.Context, t *testing.T) { for i := range r.total { for j := range r.total-1 { diff --git a/internal/p2p/router.go b/internal/p2p/router.go index e7b281a95..55055b4f4 100644 --- a/internal/p2p/router.go +++ b/internal/p2p/router.go @@ -405,7 +405,6 @@ func (r *Router) acceptPeers(ctx context.Context, transport Transport) error { if err != nil { return fmt.Errorf("failed to accept connection: %w", err) } - r.logger.Info("DUPASON accepted") r.metrics.NewConnections.With("direction","in").Add(1) incomingAddr := conn.RemoteEndpoint().Addr if err := r.connTracker.AddConn(incomingAddr); err != nil { @@ -481,7 +480,7 @@ func (r *Router) dialPeers(ctx context.Context) error { if err != nil { return err } - r.logger.Info(fmt.Sprintf("DUPASO Going to dial next peer %s", address.NodeID[:5])) + r.logger.Debug(fmt.Sprintf("Going to dial next peer %s", address.NodeID)) r.connectPeer(ctx, address) } }) @@ -511,9 +510,9 @@ func (r *Router) connectPeer(ctx context.Context, address NodeAddress) { case errors.Is(err, context.Canceled): return case err != nil: - r.logger.Info("DUPASO failed to dial peer", "peer", address, "err", err) + r.logger.Debug("failed to dial peer", "peer", address, "err", err) if err = r.peerManager.DialFailed(ctx, address); err != nil { - r.logger.Info("DUPASO failed to report dial failure", "peer", address, "err", err) + r.logger.Debug("failed to report dial failure", "peer", address, "err", err) } return } @@ -542,7 +541,6 @@ func (r *Router) connectPeer(ctx context.Context, address NodeAddress) { conn.Close() return } - r.logger.Info("DUPASON dial+hs+dial", "to",address.NodeID[:5], "err", err) r.Spawn("routePeer", func(ctx context.Context) error { defer conn.Close() @@ -559,7 +557,7 @@ func (r *Router) dialPeer(ctx context.Context, address NodeAddress) (Connection, defer cancel() } - r.logger.Info("dialing peer address", "peer", address) + r.logger.Debug("dialing peer address", "peer", address) endpoints, err := address.Resolve(resolveCtx) switch { case err != nil: @@ -590,7 +588,7 @@ func (r *Router) dialPeer(ctx context.Context, address NodeAddress) (Connection, // Internet can't and needs a different public address. conn, err := r.transport.Dial(dialCtx, endpoint) if err != nil { - r.logger.Info("DUPASO failed to dial endpoint", "peer", address.NodeID, "endpoint", endpoint, "err", err) + r.logger.Debug("failed to dial endpoint", "peer", address.NodeID, "endpoint", endpoint, "err", err) } else { r.metrics.NewConnections.With("direction","out").Add(1) r.logger.Debug("dialed peer", "peer", address.NodeID, "endpoint", endpoint) @@ -661,10 +659,8 @@ func (r *Router) runWithPeerMutex(fn func() error) error { // channels. 
It will close the given connection and send queue when done, or if // they are closed elsewhere it will cause this method to shut down and return. func (r *Router) routePeer(ctx context.Context, peerID types.NodeID, conn Connection, channels ChannelIDSet) error { - r.logger.Info("DUPASON [PRE] new connection","from",peerID[:5]) r.metrics.Peers.Add(1) r.peerManager.Ready(ctx, peerID, channels) - r.logger.Info("DUPASON new connection","from",peerID[:5]) peerCtx, cancel := context.WithCancel(ctx) state := &peerState{ cancel: cancel, @@ -712,7 +708,6 @@ func (r *Router) receivePeer(ctx context.Context, peerID types.NodeID, conn Conn return err } - r.logger.Info("DUPA RECV","from",peerID[:5],"chan",chID) r.channelMtx.RLock() queue, ok := r.channelQueues[chID] messageType := r.channelMessages[chID] @@ -768,7 +763,6 @@ func (r *Router) sendPeer(ctx context.Context, peerID types.NodeID, conn Connect continue } - r.logger.Info("DUPA SEND","to",envelope.To[:5],"chan",envelope.ChannelID) if err = conn.SendMessage(ctx, envelope.ChannelID, bz); err != nil { r.logger.Error("failed to send message", "peer", peerID, "err", err) return err diff --git a/internal/p2p/router_test.go b/internal/p2p/router_test.go index 433b477c4..50b15957d 100644 --- a/internal/p2p/router_test.go +++ b/internal/p2p/router_test.go @@ -16,7 +16,6 @@ import ( "github.com/gogo/protobuf/proto" gogotypes "github.com/gogo/protobuf/types" "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" dbm "github.com/tendermint/tm-db" "github.com/tendermint/tendermint/crypto" @@ -24,6 +23,7 @@ import ( "github.com/tendermint/tendermint/internal/p2p/mocks" "github.com/tendermint/tendermint/internal/p2p/p2ptest" "github.com/tendermint/tendermint/libs/log" + "github.com/tendermint/tendermint/libs/utils/require" "github.com/tendermint/tendermint/types" ) @@ -379,11 +379,9 @@ func TestRouter_AcceptPeers(t *testing.T) { } mockTransport := &mocks.Transport{} - mockTransport.On("String").Maybe().Return("mock") - mockTransport.On("Close").Return(nil).Maybe() mockTransport.On("Accept", mock.Anything).Once().Return(mockConnection, nil) - mockTransport.On("Accept", mock.Anything).Maybe().Return(nil, io.EOF) - mockTransport.On("Listen", mock.Anything).Return(nil) + mockTransport.On("Accept", mock.Anything).Maybe().Return(nil, context.Canceled) + mockTransport.On("Run", mock.Anything).Return(nil) // Set up and start the router. peerManager, err := p2p.NewPeerManager(log.NewNopLogger(), selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{}, p2p.NopMetrics()) @@ -438,10 +436,8 @@ func TestRouter_AcceptPeers_Errors(t *testing.T) { // Set up a mock transport that returns io.EOF once, which should prevent // the router from calling Accept again. mockTransport := &mocks.Transport{} - mockTransport.On("String").Maybe().Return("mock") - mockTransport.On("Accept", mock.Anything).Once().Return(nil, io.EOF) - mockTransport.On("Close").Return(nil) - mockTransport.On("Listen", mock.Anything).Return(nil) + mockTransport.On("Accept", mock.Anything).Once().Return(nil, context.Canceled) + mockTransport.On("Run", mock.Anything).Return(nil) // Set up and start the router. 
peerManager, err := p2p.NewPeerManager(log.NewNopLogger(), selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{}, p2p.NopMetrics()) @@ -489,13 +485,11 @@ func TestRouter_AcceptPeers_HeadOfLineBlocking(t *testing.T) { mockConnection.On("RemoteEndpoint").Return(p2p.Endpoint{}) mockTransport := &mocks.Transport{} - mockTransport.On("String").Maybe().Return("mock") - mockTransport.On("Close").Return(nil) mockTransport.On("Accept", mock.Anything).Times(3).Run(func(_ mock.Arguments) { acceptCh <- true }).Return(mockConnection, nil) - mockTransport.On("Accept", mock.Anything).Once().Return(nil, io.EOF) - mockTransport.On("Listen", mock.Anything).Return(nil) + mockTransport.On("Accept", mock.Anything).Once().Return(nil, context.Canceled) + mockTransport.On("Run", mock.Anything).Return(nil) // Set up and start the router. peerManager, err := p2p.NewPeerManager(log.NewNopLogger(), selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{}, p2p.NopMetrics()) @@ -558,7 +552,7 @@ func TestRouter_DialPeers(t *testing.T) { ctx := t.Context() address := p2p.NodeAddress{Protocol: "mock", NodeID: tc.dialID} - endpoint := &p2p.Endpoint{Protocol: "mock", Path: string(tc.dialID)} + endpoint := p2p.Endpoint{Protocol: "mock", Path: string(tc.dialID)} // Set up a mock transport that handshakes. connCtx, connCancel := context.WithCancel(ctx) @@ -575,10 +569,8 @@ func TestRouter_DialPeers(t *testing.T) { } mockTransport := &mocks.Transport{} - mockTransport.On("String").Maybe().Return("mock") - mockTransport.On("Close").Return(nil).Maybe() - mockTransport.On("Listen", mock.Anything).Return(nil) - mockTransport.On("Accept", mock.Anything).Maybe().Return(nil, io.EOF) + mockTransport.On("Run", mock.Anything).Return(nil) + mockTransport.On("Accept", mock.Anything).Maybe().Return(nil, context.Canceled) if tc.dialErr == nil { mockTransport.On("Dial", mock.Anything, endpoint).Once().Return(mockConnection, nil) // This handles the retry when a dialed connection gets closed after ReceiveMessage @@ -656,12 +648,10 @@ func TestRouter_DialPeers_Parallel(t *testing.T) { mockConnection.On("Close").Return(nil) mockTransport := &mocks.Transport{} - mockTransport.On("String").Maybe().Return("mock") - mockTransport.On("Close").Return(nil) - mockTransport.On("Listen", mock.Anything).Return(nil) - mockTransport.On("Accept", mock.Anything).Once().Return(nil, io.EOF) + mockTransport.On("Run", mock.Anything).Return(nil) + mockTransport.On("Accept", mock.Anything).Once().Return(nil, context.Canceled) for _, address := range []p2p.NodeAddress{a, b, c} { - endpoint := &p2p.Endpoint{Protocol: address.Protocol, Path: string(address.NodeID)} + endpoint := p2p.Endpoint{Protocol: address.Protocol, Path: string(address.NodeID)} mockTransport.On("Dial", mock.Anything, endpoint).Run(func(_ mock.Arguments) { dialCh <- true }).Return(mockConnection, nil) @@ -745,11 +735,9 @@ func TestRouter_EvictPeers(t *testing.T) { }).Return(nil) mockTransport := &mocks.Transport{} - mockTransport.On("String").Maybe().Return("mock") - mockTransport.On("Close").Return(nil) mockTransport.On("Accept", mock.Anything).Once().Return(mockConnection, nil) - mockTransport.On("Accept", mock.Anything).Maybe().Return(nil, io.EOF) - mockTransport.On("Listen", mock.Anything).Return(nil) + mockTransport.On("Accept", mock.Anything).Maybe().Return(nil, context.Canceled) + mockTransport.On("Run", mock.Anything).Return(nil) // Set up and start the router. 
peerManager, err := p2p.NewPeerManager(logger, selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{}, p2p.NopMetrics()) @@ -808,11 +796,9 @@ func TestRouter_ChannelCompatability(t *testing.T) { mockConnection.On("Close").Return(nil) mockTransport := &mocks.Transport{} - mockTransport.On("String").Maybe().Return("mock") - mockTransport.On("Close").Return(nil) + mockTransport.On("Run", mock.Anything).Return(nil) mockTransport.On("Accept", mock.Anything).Once().Return(mockConnection, nil) - mockTransport.On("Accept", mock.Anything).Once().Return(nil, io.EOF) - mockTransport.On("Listen", mock.Anything).Return(nil) + mockTransport.On("Accept", mock.Anything).Once().Return(nil, context.Canceled) // Set up and start the router. peerManager, err := p2p.NewPeerManager(log.NewNopLogger(), selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{}, p2p.NopMetrics()) @@ -860,11 +846,9 @@ func TestRouter_DontSendOnInvalidChannel(t *testing.T) { mockTransport := &mocks.Transport{} mockTransport.On("AddChannelDescriptors", mock.Anything).Return() - mockTransport.On("String").Maybe().Return("mock") - mockTransport.On("Close").Return(nil) mockTransport.On("Accept", mock.Anything).Once().Return(mockConnection, nil) - mockTransport.On("Accept", mock.Anything).Maybe().Return(nil, io.EOF) - mockTransport.On("Listen", mock.Anything).Return(nil) + mockTransport.On("Accept", mock.Anything).Maybe().Return(nil, context.Canceled) + mockTransport.On("Run", mock.Anything).Return(nil) // Set up and start the router. peerManager, err := p2p.NewPeerManager(log.NewNopLogger(), selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{}, p2p.NopMetrics()) @@ -927,10 +911,9 @@ func TestRouter_Channel_FilterByID(t *testing.T) { mockTransport := &mocks.Transport{} mockTransport.On("AddChannelDescriptors", mock.Anything).Return() mockTransport.On("String").Maybe().Return("mock") - mockTransport.On("Close").Return(nil) mockTransport.On("Accept", mock.Anything).Once().Return(mockConnection, nil) - mockTransport.On("Accept", mock.Anything).Maybe().Return(nil, io.EOF) - mockTransport.On("Listen", mock.Anything).Return(nil) + mockTransport.On("Accept", mock.Anything).Maybe().Return(nil, context.Canceled) + mockTransport.On("Run", mock.Anything).Return(nil) peerManager, err := p2p.NewPeerManager(log.NewNopLogger(), selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{}, p2p.NopMetrics()) require.NoError(t, err) diff --git a/libs/utils/require/require.go b/libs/utils/require/require.go index a15dfb083..1bfa00eda 100644 --- a/libs/utils/require/require.go +++ b/libs/utils/require/require.go @@ -30,6 +30,9 @@ func ElementsMatch[T any](t TestingT, a []T, b []T, msgAndArgs ...any) { require.ElementsMatch(t,a,b,msgAndArgs...) } +// Eventually . +var Eventually = require.Eventually + // EqualError . // TODO: get rid of comparing errors by strings, // use concrete error types instead. 
From 9da3bfcbb81e83ca0b4e2e84217c7f2e8d0689e0 Mon Sep 17 00:00:00 2001 From: Grzegorz Prusak Date: Mon, 1 Sep 2025 14:44:25 +0200 Subject: [PATCH 35/41] test fix --- types/node_info_test.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/types/node_info_test.go b/types/node_info_test.go index 1f8480d02..16a47fdb7 100644 --- a/types/node_info_test.go +++ b/types/node_info_test.go @@ -242,11 +242,10 @@ func TestParseAddressString(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - addr, port, err := ParseAddressString(tc.addr) + addr, err := ParseAddressString(tc.addr) if tc.correct { require.NoError(t, err, tc.addr) assert.Contains(t, tc.expected, addr.String()) - assert.Contains(t, tc.expected, fmt.Sprint(port)) } else { assert.Error(t, err, "%v", tc.addr) } From 19a29cb9eabafa8841b75cd6b5b51082a6a45fde Mon Sep 17 00:00:00 2001 From: Grzegorz Prusak Date: Mon, 1 Sep 2025 14:44:47 +0200 Subject: [PATCH 36/41] fmt --- internal/p2p/address.go | 8 ++- internal/p2p/address_test.go | 20 +++--- internal/p2p/conn/connection.go | 14 ++-- internal/p2p/p2ptest/network.go | 2 +- internal/p2p/peermanager.go | 24 +++---- internal/p2p/peermanager_test.go | 18 ++--- internal/p2p/pex/reactor.go | 18 ++--- internal/p2p/pex/reactor_test.go | 14 ++-- internal/p2p/router.go | 11 ++- internal/p2p/transport.go | 9 ++- internal/p2p/transport_mconn.go | 18 ++--- internal/p2p/transport_mconn_test.go | 96 ++++++++++++++++----------- internal/p2p/transport_memory.go | 6 +- internal/p2p/transport_memory_test.go | 8 ++- internal/p2p/transport_test.go | 48 +++++++------- libs/utils/require/require.go | 3 +- libs/utils/tcp/tcp.go | 2 +- node/node.go | 2 +- 18 files changed, 173 insertions(+), 148 deletions(-) diff --git a/internal/p2p/address.go b/internal/p2p/address.go index 7f7d48bbf..d67f3c906 100644 --- a/internal/p2p/address.go +++ b/internal/p2p/address.go @@ -122,9 +122,11 @@ func (a NodeAddress) Resolve(ctx context.Context) ([]Endpoint, error) { } endpoints := make([]Endpoint, len(ips)) for i, ip := range ips { - ip,ok := netip.AddrFromSlice(ip) - fmt.Printf("%v\n",ip) - if !ok { return nil, fmt.Errorf("LookupIP returned invalid IP %q", ip) } + ip, ok := netip.AddrFromSlice(ip) + fmt.Printf("%v\n", ip) + if !ok { + return nil, fmt.Errorf("LookupIP returned invalid IP %q", ip) + } endpoints[i] = Endpoint{ Protocol: a.Protocol, Addr: netip.AddrPortFrom(ip, a.Port), diff --git a/internal/p2p/address_test.go b/internal/p2p/address_test.go index 121550516..6b660aac1 100644 --- a/internal/p2p/address_test.go +++ b/internal/p2p/address_test.go @@ -7,8 +7,8 @@ import ( "github.com/tendermint/tendermint/crypto/ed25519" "github.com/tendermint/tendermint/internal/p2p" - "github.com/tendermint/tendermint/libs/utils/tcp" "github.com/tendermint/tendermint/libs/utils/require" + "github.com/tendermint/tendermint/libs/utils/tcp" "github.com/tendermint/tendermint/types" ) @@ -208,28 +208,28 @@ func TestNodeAddress_Resolve(t *testing.T) { // Valid networked addresses (with hostname). 
{ p2p.NodeAddress{Protocol: "tcp", Hostname: "127.0.0.1", Port: 80, Path: "/path"}, - p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(tcp.IPv4Loopback(),80), Path: "/path"}, + p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(tcp.IPv4Loopback(), 80), Path: "/path"}, true, }, { p2p.NodeAddress{Protocol: "tcp", Hostname: "127.0.0.1"}, - p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(tcp.IPv4Loopback(),0)}, + p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(tcp.IPv4Loopback(), 0)}, true, }, { p2p.NodeAddress{Protocol: "tcp", Hostname: "::1"}, - p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(netip.IPv6Loopback(),0)}, + p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(netip.IPv6Loopback(), 0)}, true, }, { p2p.NodeAddress{Protocol: "tcp", Hostname: "8.8.8.8"}, - p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}),0)}, + p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 0)}, true, }, { p2p.NodeAddress{Protocol: "tcp", Hostname: "2001:0db8::ff00:0042:8329"}, p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(netip.AddrFrom16([16]byte{ - 0x20, 0x01, 0x0d, 0xb8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0x42, 0x83, 0x29}),0)}, + 0x20, 0x01, 0x0d, 0xb8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0x42, 0x83, 0x29}), 0)}, true, }, { @@ -267,12 +267,12 @@ func TestNodeAddress_Resolve(t *testing.T) { } ok := false tc.expect.Addr = tcp.Norm(tc.expect.Addr) - for _,e := range endpoints { + for _, e := range endpoints { e.Addr = tcp.Norm(e.Addr) - ok = ok || e==tc.expect + ok = ok || e == tc.expect } if !ok { - t.Fatalf("%v not in %v",tc.expect,endpoints) + t.Fatalf("%v not in %v", tc.expect, endpoints) } }) } @@ -284,7 +284,7 @@ func TestNodeAddress_Resolve(t *testing.T) { for _, got := range endpoints { require.True(t, got.Addr.Addr().IsLoopback()) // Any loopback address is acceptable, so ignore it in comparison. - want := p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(got.Addr.Addr(),80), Path: "/path"} + want := p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(got.Addr.Addr(), 80), Path: "/path"} require.Equal(t, want, got) } }) diff --git a/internal/p2p/conn/connection.go b/internal/p2p/conn/connection.go index 8977dd6c1..83054ecb4 100644 --- a/internal/p2p/conn/connection.go +++ b/internal/p2p/conn/connection.go @@ -19,10 +19,10 @@ import ( "github.com/tendermint/tendermint/internal/libs/flowrate" "github.com/tendermint/tendermint/internal/libs/protoio" "github.com/tendermint/tendermint/internal/libs/timer" - "github.com/tendermint/tendermint/libs/utils" "github.com/tendermint/tendermint/libs/log" tmmath "github.com/tendermint/tendermint/libs/math" "github.com/tendermint/tendermint/libs/service" + "github.com/tendermint/tendermint/libs/utils" tmp2p "github.com/tendermint/tendermint/proto/tendermint/p2p" ) @@ -317,7 +317,7 @@ func (c *MConnection) Send(ctx context.Context, chID ChannelID, msgBytes []byte) } if err := channel.sendBytes(ctx, msgBytes); err != nil { - return fmt.Errorf("channel.sendBytes(): %v",err) + return fmt.Errorf("channel.sendBytes(): %v", err) } // Wake up sendRoutine if necessary select { @@ -643,11 +643,11 @@ type channel struct { // See https://github.com/tendermint/tendermint/issues/7000. 
recentlySent int64 - conn *MConnection - desc ChannelDescriptor - sendQueue chan []byte - recving []byte - sending []byte + conn *MConnection + desc ChannelDescriptor + sendQueue chan []byte + recving []byte + sending []byte maxPacketMsgPayloadSize int diff --git a/internal/p2p/p2ptest/network.go b/internal/p2p/p2ptest/network.go index d4fcfd110..cea53defc 100644 --- a/internal/p2p/p2ptest/network.go +++ b/internal/p2p/p2ptest/network.go @@ -284,7 +284,7 @@ func (n *Network) MakeNode(ctx context.Context, t *testing.T, opts NodeOptions) }) return &Node{ - Logger: logger, + Logger: logger, NodeID: nodeID, NodeInfo: nodeInfo, NodeAddress: transport.Endpoint().NodeAddress(nodeID), diff --git a/internal/p2p/peermanager.go b/internal/p2p/peermanager.go index 1ee669ea9..caf127cd7 100644 --- a/internal/p2p/peermanager.go +++ b/internal/p2p/peermanager.go @@ -18,9 +18,9 @@ import ( dbm "github.com/tendermint/tm-db" tmsync "github.com/tendermint/tendermint/internal/libs/sync" + "github.com/tendermint/tendermint/libs/utils" p2pproto "github.com/tendermint/tendermint/proto/tendermint/p2p" "github.com/tendermint/tendermint/types" - "github.com/tendermint/tendermint/libs/utils" ) const ( @@ -29,12 +29,12 @@ const ( ) type DialFailuresError struct { - Failures uint32 - Address types.NodeID + Failures uint32 + Address types.NodeID } func (e DialFailuresError) Error() string { - return fmt.Sprintf("dialing failed %d times will not retry for address=%s, deleting peer",e.Failures,e.Address) + return fmt.Sprintf("dialing failed %d times will not retry for address=%s, deleting peer", e.Failures, e.Address) } // PeerStatus is a peer status. @@ -325,7 +325,7 @@ type PeerManager struct { upgrading map[types.NodeID]types.NodeID // peers claimed for upgrade (DialNext → Dialed/DialFail) connected map[types.NodeID]bool // connected peers (Dialed/Accepted → Disconnected) ready map[types.NodeID]bool // ready peers (Ready → Disconnected) - evict map[types.NodeID]error // peers scheduled for eviction (Connected → EvictNext) + evict map[types.NodeID]error // peers scheduled for eviction (Connected → EvictNext) evicting map[types.NodeID]bool // peers being evicted (EvictNext → Disconnected) metrics *Metrics } @@ -797,8 +797,8 @@ func (m *PeerManager) EvictNext(ctx context.Context) (Eviction, error) { if err != nil { return Eviction{}, err } - if ev,ok := ev.Get(); ok { - return ev,nil + if ev, ok := ev.Get(); ok { + return ev, nil } select { case <-m.evictWaker.Sleep(): @@ -809,7 +809,7 @@ func (m *PeerManager) EvictNext(ctx context.Context) (Eviction, error) { } type Eviction struct { - ID types.NodeID + ID types.NodeID Cause error } @@ -821,11 +821,11 @@ func (m *PeerManager) TryEvictNext() (utils.Option[Eviction], error) { // If any connected peers are explicitly scheduled for eviction, we return a // random one. 
- for peerID,cause := range m.evict { + for peerID, cause := range m.evict { delete(m.evict, peerID) if m.connected[peerID] && !m.evicting[peerID] { m.evicting[peerID] = true - return utils.Some(Eviction{peerID,cause}), nil + return utils.Some(Eviction{peerID, cause}), nil } } @@ -842,7 +842,7 @@ func (m *PeerManager) TryEvictNext() (utils.Option[Eviction], error) { peer := ranked[i] if m.connected[peer.ID] && !m.evicting[peer.ID] { m.evicting[peer.ID] = true - return utils.Some(Eviction{peer.ID,errors.New("too many peers")}), nil + return utils.Some(Eviction{peer.ID, errors.New("too many peers")}), nil } } @@ -1156,7 +1156,7 @@ func (m *PeerManager) findUpgradeCandidate(id types.NodeID, score PeerScore) typ case candidate.Score() >= score: return "" // no further peers can be scored lower, due to sorting case !m.connected[candidate.ID]: - case m.evict[candidate.ID]!=nil: + case m.evict[candidate.ID] != nil: case m.evicting[candidate.ID]: case m.upgrading[candidate.ID] != "": default: diff --git a/internal/p2p/peermanager_test.go b/internal/p2p/peermanager_test.go index 3a75bce04..04df86f82 100644 --- a/internal/p2p/peermanager_test.go +++ b/internal/p2p/peermanager_test.go @@ -15,8 +15,8 @@ import ( dbm "github.com/tendermint/tm-db" "github.com/tendermint/tendermint/internal/p2p" - "github.com/tendermint/tendermint/types" "github.com/tendermint/tendermint/libs/utils/require" + "github.com/tendermint/tendermint/types" ) // FIXME: We should probably have some randomized property-based tests for the @@ -432,7 +432,7 @@ func TestPeerManagerDeleteOnMaxRetries(t *testing.T) { require.GreaterOrEqual(t, elapsed, time.Duration(math.Pow(2, float64(i)))*options.MinRetryTime) } if i == 3 { - if got,err:=(p2p.DialFailuresError{}),peerManager.DialFailed(ctx, a); !errors.As(err, &got) || got.Failures!=4 { + if got, err := (p2p.DialFailuresError{}), peerManager.DialFailed(ctx, a); !errors.As(err, &got) || got.Failures != 4 { t.Errorf("expected 4 failures, got error %v", err) } @@ -1048,7 +1048,7 @@ func TestPeerManager_Dialed_Upgrade(t *testing.T) { // a should now be evicted. evict, err := peerManager.TryEvictNext() require.NoError(t, err) - if ev,ok := evict.Get(); !ok || ev.ID!=a.NodeID { + if ev, ok := evict.Get(); !ok || ev.ID != a.NodeID { t.Fatalf("evict = %v, expected %s", evict, a.NodeID) } } @@ -1106,7 +1106,7 @@ func TestPeerManager_Dialed_UpgradeEvenLower(t *testing.T) { require.NoError(t, peerManager.Dialed(c)) evict, err := peerManager.TryEvictNext() require.NoError(t, err) - if ev,ok := evict.Get(); !ok || ev.ID!=d.NodeID { + if ev, ok := evict.Get(); !ok || ev.ID != d.NodeID { t.Fatalf("evict = %v, expected %s", evict, d.NodeID) } } @@ -1319,7 +1319,7 @@ func TestPeerManager_Accepted_Upgrade(t *testing.T) { // This should cause a to get evicted. evict, err := peerManager.TryEvictNext() require.NoError(t, err) - if ev,ok := evict.Get(); !ok || ev.ID!=a.NodeID { + if ev, ok := evict.Get(); !ok || ev.ID != a.NodeID { t.Fatalf("evict = %v, expected %s", evict, a.NodeID) } peerManager.Disconnected(ctx, a.NodeID) @@ -1371,7 +1371,7 @@ func TestPeerManager_Accepted_UpgradeDialing(t *testing.T) { // This should cause a to get evicted, and the dial upgrade to fail. 
evict, err := peerManager.TryEvictNext() require.NoError(t, err) - if ev,ok := evict.Get(); !ok || ev.ID!=a.NodeID { + if ev, ok := evict.Get(); !ok || ev.ID != a.NodeID { t.Fatalf("evict = %v, expected %s", evict, a.NodeID) } require.Error(t, peerManager.Dialed(b)) @@ -1597,7 +1597,7 @@ func TestPeerManager_TryEvictNext(t *testing.T) { peerManager.Errored(a.NodeID, errors.New("foo")) evict, err = peerManager.TryEvictNext() require.NoError(t, err) - if ev,ok := evict.Get(); !ok || ev.ID!=a.NodeID { + if ev, ok := evict.Get(); !ok || ev.ID != a.NodeID { t.Fatalf("evict = %v, expected %s", evict, a.NodeID) } @@ -1702,7 +1702,7 @@ func TestPeerManager_Errored(t *testing.T) { peerManager.Errored(a.NodeID, errors.New("foo")) evict, err = peerManager.TryEvictNext() require.NoError(t, err) - if ev,ok := evict.Get(); !ok || ev.ID!=a.NodeID { + if ev, ok := evict.Get(); !ok || ev.ID != a.NodeID { t.Fatalf("evict = %v, expected %s", evict, a.NodeID) } } @@ -1753,7 +1753,7 @@ func TestPeerManager_Subscribe(t *testing.T) { evict, err := peerManager.TryEvictNext() require.NoError(t, err) - if ev,ok := evict.Get(); !ok || ev.ID!=a.NodeID { + if ev, ok := evict.Get(); !ok || ev.ID != a.NodeID { t.Fatalf("evict = %v, expected %s", evict, a.NodeID) } diff --git a/internal/p2p/pex/reactor.go b/internal/p2p/pex/reactor.go index 138ac3795..f0bb7f1d4 100644 --- a/internal/p2p/pex/reactor.go +++ b/internal/p2p/pex/reactor.go @@ -148,8 +148,8 @@ func (r *Reactor) SetChannel(ch *p2p.Channel) { // OnStop to ensure the outbound p2p Channels are closed. func (r *Reactor) OnStart(ctx context.Context) error { peerUpdates := r.peerEvents(ctx) - r.Spawn("processPexCh",func(ctx context.Context) error { return r.processPexCh(ctx) }) - r.Spawn("processPeerUpdates",func(ctx context.Context) error { return r.processPeerUpdates(ctx, peerUpdates) }) + r.Spawn("processPexCh", func(ctx context.Context) error { return r.processPexCh(ctx) }) + r.Spawn("processPeerUpdates", func(ctx context.Context) error { return r.processPeerUpdates(ctx, peerUpdates) }) return nil } @@ -165,7 +165,7 @@ func (r *Reactor) processPexCh(ctx context.Context) error { defer close(incoming) iter := r.channel.Receive(ctx) for iter.Next(ctx) { - if err:=utils.Send(ctx, incoming, iter.Envelope()); err!=nil { + if err := utils.Send(ctx, incoming, iter.Envelope()); err != nil { return } } @@ -197,7 +197,7 @@ func (r *Reactor) processPexCh(ctx context.Context) error { // Send a request for more peer addresses. if err := r.sendRequestForPeers(ctx); err != nil { r.logger.Error("failed to send request for peers", "err", err) - if errors.Is(err,NoPeersAvailableError) { + if errors.Is(err, NoPeersAvailableError) { noAvailablePeerFailCounter++ lastNoAvailablePeersTime = time.Now() continue @@ -207,7 +207,7 @@ func (r *Reactor) processPexCh(ctx context.Context) error { noAvailablePeerFailCounter = 0 case envelope, ok := <-incoming: if !ok { - return nil// channel closed + return nil // channel closed } // A request from another peer, or a response to one of our requests. @@ -236,8 +236,10 @@ func (r *Reactor) processPexCh(ctx context.Context) error { // close the p2p PeerUpdatesCh gracefully. 
func (r *Reactor) processPeerUpdates(ctx context.Context, peerUpdates *p2p.PeerUpdates) error { for { - peerUpdate,err:=utils.Recv(ctx,peerUpdates.Updates()) - if err!=nil { return err } + peerUpdate, err := utils.Recv(ctx, peerUpdates.Updates()) + if err != nil { + return err + } r.processPeerUpdate(peerUpdate) } } @@ -285,7 +287,7 @@ func (r *Reactor) handlePexMessage(ctx context.Context, envelope *p2p.Envelope) for _, pexAddress := range msg.Addresses { peerAddress, err := p2p.ParseNodeAddress(pexAddress.URL) if err != nil { - return 0,fmt.Errorf("PEX parse node address error %s", err) + return 0, fmt.Errorf("PEX parse node address error %s", err) } added, err := r.peerManager.Add(peerAddress) if err != nil { diff --git a/internal/p2p/pex/reactor_test.go b/internal/p2p/pex/reactor_test.go index 714e0706c..89e6a8788 100644 --- a/internal/p2p/pex/reactor_test.go +++ b/internal/p2p/pex/reactor_test.go @@ -322,7 +322,7 @@ type reactorTestSuite struct { reactors map[types.NodeID]*pex.Reactor pexChannels map[types.NodeID]*p2p.Channel - peerChans map[types.NodeID]chan p2p.PeerUpdate + peerChans map[types.NodeID]chan p2p.PeerUpdate nodes []types.NodeID mocks []types.NodeID @@ -624,7 +624,9 @@ func (r *reactorTestSuite) requireNumberOfPeers( } func (r *reactorTestSuite) connectCycle(ctx context.Context, t *testing.T) { - if r.total==0 { return } + if r.total == 0 { + return + } for i := range r.total { r.connectPeers(ctx, t, i, (i+1)%r.total) } @@ -632,7 +634,7 @@ func (r *reactorTestSuite) connectCycle(ctx context.Context, t *testing.T) { func (r *reactorTestSuite) connectAll(ctx context.Context, t *testing.T) { for i := range r.total { - for j := range r.total-1 { + for j := range r.total - 1 { r.connectPeers(ctx, t, i, (i+j+1)%r.total) } } @@ -641,10 +643,10 @@ func (r *reactorTestSuite) connectAll(ctx context.Context, t *testing.T) { // Adds enough addresses to peerManagers, so that all nodes are discoverable. func (r *reactorTestSuite) seedAddrs(t *testing.T) { t.Helper() - for i := range r.total-1 { + for i := range r.total - 1 { n1 := r.network.Nodes[r.nodes[i]] n2 := r.network.Nodes[r.nodes[i+1]] - _,err := n1.PeerManager.Add(n2.NodeAddress) + _, err := n1.PeerManager.Add(n2.NodeAddress) require.NoError(t, err) } } @@ -667,7 +669,7 @@ func (r *reactorTestSuite) connectPeers(ctx context.Context, t *testing.T, sourc } // Subscription is for the ctx lifetime. 
- ctx,cancel := context.WithCancel(ctx) + ctx, cancel := context.WithCancel(ctx) defer cancel() sourceSub := n1.PeerManager.Subscribe(ctx) targetSub := n2.PeerManager.Subscribe(ctx) diff --git a/internal/p2p/router.go b/internal/p2p/router.go index 55055b4f4..2b9355503 100644 --- a/internal/p2p/router.go +++ b/internal/p2p/router.go @@ -405,7 +405,7 @@ func (r *Router) acceptPeers(ctx context.Context, transport Transport) error { if err != nil { return fmt.Errorf("failed to accept connection: %w", err) } - r.metrics.NewConnections.With("direction","in").Add(1) + r.metrics.NewConnections.With("direction", "in").Add(1) incomingAddr := conn.RemoteEndpoint().Addr if err := r.connTracker.AddConn(incomingAddr); err != nil { closeErr := conn.Close() @@ -428,7 +428,6 @@ func (r *Router) openConnection(ctx context.Context, conn Connection) error { incomingAddr := conn.RemoteEndpoint().Addr defer r.connTracker.RemoveConn(incomingAddr) - if err := r.filterPeersIP(ctx, incomingAddr); err != nil { r.logger.Debug("peer filtered by IP", "ip", incomingAddr, "err", err) return nil @@ -590,7 +589,7 @@ func (r *Router) dialPeer(ctx context.Context, address NodeAddress) (Connection, if err != nil { r.logger.Debug("failed to dial endpoint", "peer", address.NodeID, "endpoint", endpoint, "err", err) } else { - r.metrics.NewConnections.With("direction","out").Add(1) + r.metrics.NewConnections.With("direction", "out").Add(1) r.logger.Debug("dialed peer", "peer", address.NodeID, "endpoint", endpoint) return conn, nil } @@ -781,7 +780,7 @@ func (r *Router) evictPeers(ctx context.Context) error { } for states := range r.peerStates.Lock() { if s, ok := states[ev.ID]; ok { - r.logger.Info("evicting peer", "peer", ev.ID,"cause",ev.Cause) + r.logger.Info("evicting peer", "peer", ev.ID, "cause", ev.Cause) s.cancel() } } @@ -805,7 +804,7 @@ func (r *Router) OnStart(ctx context.Context) error { } } - r.SpawnCritical("transport.Run",func(ctx context.Context) error { + r.SpawnCritical("transport.Run", func(ctx context.Context) error { return r.transport.Run(ctx) }) r.SpawnCritical("dialPeers", func(ctx context.Context) error { return r.dialPeers(ctx) }) @@ -820,7 +819,7 @@ func (r *Router) OnStart(ctx context.Context) error { // router, to prevent blocked channel sends in reactors. Channels are not closed // here, since that would cause any reactor senders to panic, so it is the // sender's responsibility. -func (r *Router) OnStop() { } +func (r *Router) OnStop() {} type ChannelIDSet map[ChannelID]struct{} diff --git a/internal/p2p/transport.go b/internal/p2p/transport.go index c1ac098ab..540c165af 100644 --- a/internal/p2p/transport.go +++ b/internal/p2p/transport.go @@ -32,7 +32,6 @@ type Transport interface { // Endpoints returns the local endpoints the transport is listening on. Endpoint() Endpoint - // Accept waits for the next inbound connection on a listening endpoint, blocking // until either a connection is available or the transport is closed. On closure, // io.EOF is returned and further Accept calls are futile. 
@@ -125,7 +124,7 @@ func NewEndpoint(addr string) (Endpoint, error) { return Endpoint{ Protocol: MConnProtocol, - Addr: addrPort, + Addr: addrPort, }, nil } @@ -136,7 +135,7 @@ func (e Endpoint) NodeAddress(nodeID types.NodeID) NodeAddress { Protocol: e.Protocol, Path: e.Path, } - if e.Addr!=(netip.AddrPort{}) { + if e.Addr != (netip.AddrPort{}) { address.Hostname = e.Addr.Addr().String() address.Port = e.Addr.Port() } @@ -161,10 +160,10 @@ func (e Endpoint) Validate() error { if e.Protocol == "" { return errors.New("endpoint has no protocol") } - if (e.Addr==netip.AddrPort{}) && (e.Path=="") { + if (e.Addr == netip.AddrPort{}) && (e.Path == "") { return errors.New("endpoint has neither path nor IP") } - if e.Addr!=(netip.AddrPort{}) { + if e.Addr != (netip.AddrPort{}) { if !e.Addr.IsValid() { return fmt.Errorf("endpoint has invalid address %q", e.Addr.String()) } diff --git a/internal/p2p/transport_mconn.go b/internal/p2p/transport_mconn.go index f164630ba..0b13916e3 100644 --- a/internal/p2p/transport_mconn.go +++ b/internal/p2p/transport_mconn.go @@ -17,8 +17,8 @@ import ( "github.com/tendermint/tendermint/internal/p2p/conn" "github.com/tendermint/tendermint/libs/log" "github.com/tendermint/tendermint/libs/utils" - "github.com/tendermint/tendermint/libs/utils/tcp" "github.com/tendermint/tendermint/libs/utils/scope" + "github.com/tendermint/tendermint/libs/utils/tcp" p2pproto "github.com/tendermint/tendermint/proto/tendermint/p2p" "github.com/tendermint/tendermint/types" ) @@ -77,7 +77,7 @@ func NewMConnTransport( // WaitForStart waits until transport starts listening for incoming connections. func (m *MConnTransport) WaitForStart(ctx context.Context) error { - _,_,err := utils.RecvOrClosed(ctx, m.started) + _, _, err := utils.RecvOrClosed(ctx, m.started) return err } @@ -91,7 +91,7 @@ func (m *MConnTransport) Run(ctx context.Context) error { } listener, err := tcp.Listen(m.endpoint.Addr) if err != nil { - return fmt.Errorf("net.Listen(): %w",err) + return fmt.Errorf("net.Listen(): %w", err) } close(m.started) // signal that we are listening if m.options.MaxAcceptedConnections > 0 { @@ -110,14 +110,14 @@ func (m *MConnTransport) Run(ctx context.Context) error { }) for { conn, err := listener.Accept() - if err!=nil { + if err != nil { if errors.Is(err, net.ErrClosed) { return nil } return err } mconn := newMConnConnection(m.logger, conn, m.mConnConfig, m.channelDescs) - if err:=utils.Send(ctx, m.listener, mconn); err!=nil { + if err := utils.Send(ctx, m.listener, mconn); err != nil { mconn.Close() return err } @@ -146,12 +146,12 @@ func (m *MConnTransport) Dial(ctx context.Context, endpoint Endpoint) (Connectio return nil, err } if endpoint.Addr.Port() == 0 { - endpoint.Addr = netip.AddrPortFrom(endpoint.Addr.Addr(),26657) + endpoint.Addr = netip.AddrPortFrom(endpoint.Addr.Addr(), 26657) } dialer := net.Dialer{} tcpConn, err := dialer.DialContext(ctx, "tcp", endpoint.Addr.String()) if err != nil { - return nil,fmt.Errorf("dialer.DialContext(%v): %w", endpoint.Addr, err) + return nil, fmt.Errorf("dialer.DialContext(%v): %w", endpoint.Addr, err) } return newMConnConnection(m.logger, tcpConn, m.mConnConfig, m.channelDescs), nil } @@ -167,7 +167,7 @@ func (m *MConnTransport) AddChannelDescriptors(channelDesc []*ChannelDescriptor) m.channelDescs = append(m.channelDescs, channelDesc...) } -type InvalidEndpointErr struct { error } +type InvalidEndpointErr struct{ error } // validateEndpoint validates an endpoint. 
func (m *MConnTransport) validateEndpoint(endpoint Endpoint) error { @@ -382,7 +382,7 @@ func (c *mConnConnection) SendMessage(ctx context.Context, chID ChannelID, msg [ case err := <-c.errorCh: return err default: - if err := c.mconn.Send(ctx, chID, msg); err!=nil { + if err := c.mconn.Send(ctx, chID, msg); err != nil { return fmt.Errorf("m.mconn.Send(%v): %w", chID, err) } return nil diff --git a/internal/p2p/transport_mconn_test.go b/internal/p2p/transport_mconn_test.go index f161cdc40..64121809c 100644 --- a/internal/p2p/transport_mconn_test.go +++ b/internal/p2p/transport_mconn_test.go @@ -2,18 +2,18 @@ package p2p_test import ( "context" - "net/netip" + "errors" + "fmt" "io" + "net/netip" "testing" "time" - "fmt" - "errors" - "github.com/tendermint/tendermint/libs/utils/tcp" - "github.com/tendermint/tendermint/libs/utils" - "github.com/tendermint/tendermint/libs/utils/scope" "github.com/fortytw2/leaktest" "github.com/stretchr/testify/require" + "github.com/tendermint/tendermint/libs/utils" + "github.com/tendermint/tendermint/libs/utils/scope" + "github.com/tendermint/tendermint/libs/utils/tcp" "github.com/tendermint/tendermint/internal/p2p" "github.com/tendermint/tendermint/internal/p2p/conn" @@ -29,18 +29,18 @@ func init() { log.NewNopLogger(), p2p.Endpoint{ Protocol: p2p.MConnProtocol, - Addr: tcp.TestReserveAddr(), + Addr: tcp.TestReserveAddr(), }, conn.DefaultMConnConfig(), []*p2p.ChannelDescriptor{{ID: chID, Priority: 1}}, p2p.MConnTransportOptions{}, ) go func() { - if err:=transport.Run(ctx); err != nil { + if err := transport.Run(ctx); err != nil { panic(err) } }() - if err:=transport.WaitForStart(ctx); err!=nil { + if err := transport.WaitForStart(ctx); err != nil { panic(err) } return transport @@ -53,23 +53,31 @@ func init() { func connect(ctx context.Context, tr *p2p.MConnTransport) (c1 p2p.Connection, c2 p2p.Connection, err error) { defer func() { if err != nil { - if c1 != nil { c1.Close() } - if c2 != nil { c2.Close() } + if c1 != nil { + c1.Close() + } + if c2 != nil { + c2.Close() + } } }() // Here we are utilizing the fact that MConnTransport accepts connection proactively // before Accept is called. 
c1, err = tr.Dial(ctx, tr.Endpoint()) - if err != nil { return nil,nil,fmt.Errorf("Dial(): %w", err) } + if err != nil { + return nil, nil, fmt.Errorf("Dial(): %w", err) + } c2, err = tr.Accept(ctx) - if err != nil { return nil,nil,fmt.Errorf("Accept(): %w", err) } - if got,want := c1.LocalEndpoint(),c2.RemoteEndpoint(); got!=want { - return nil,nil,fmt.Errorf("c1.LocalEndpoint() = %v, want %v", got, want) + if err != nil { + return nil, nil, fmt.Errorf("Accept(): %w", err) + } + if got, want := c1.LocalEndpoint(), c2.RemoteEndpoint(); got != want { + return nil, nil, fmt.Errorf("c1.LocalEndpoint() = %v, want %v", got, want) } - if got,want := c1.RemoteEndpoint(),c2.LocalEndpoint(); got!=want { - return nil,nil,fmt.Errorf("c1.RemoteEndpoint() = %v, want %v", got, want) + if got, want := c1.RemoteEndpoint(), c2.LocalEndpoint(); got != want { + return nil, nil, fmt.Errorf("c1.RemoteEndpoint() = %v, want %v", got, want) } - return c1,c2,nil + return c1, c2, nil } func TestMConnTransport_AcceptMaxAcceptedConnections(t *testing.T) { @@ -78,7 +86,7 @@ func TestMConnTransport_AcceptMaxAcceptedConnections(t *testing.T) { log.NewNopLogger(), p2p.Endpoint{ Protocol: p2p.MConnProtocol, - Addr: tcp.TestReserveAddr(), + Addr: tcp.TestReserveAddr(), }, conn.DefaultMConnConfig(), []*p2p.ChannelDescriptor{{ID: chID, Priority: 1}}, @@ -89,28 +97,34 @@ func TestMConnTransport_AcceptMaxAcceptedConnections(t *testing.T) { err := utils.IgnoreCancel(scope.Run(ctx, func(ctx context.Context, s scope.Scope) error { s.SpawnBgNamed("transport", func() error { return transport.Run(ctx) }) - if err:=transport.WaitForStart(ctx); err!=nil { + if err := transport.WaitForStart(ctx); err != nil { return err } t.Logf("The first two connections should be accepted just fine.") - a1,a2,err := connect(ctx, transport) - if err!=nil { return fmt.Errorf("1st connect(): %w", err) } + a1, a2, err := connect(ctx, transport) + if err != nil { + return fmt.Errorf("1st connect(): %w", err) + } defer a1.Close() defer a2.Close() - b1,b2,err := connect(ctx, transport) - if err!=nil { return fmt.Errorf("2nd connect(): %w",err) } + b1, b2, err := connect(ctx, transport) + if err != nil { + return fmt.Errorf("2nd connect(): %w", err) + } defer b1.Close() defer b2.Close() t.Logf("The third connection will be dialed successfully, but the accept should not go through.") c1, err := transport.Dial(ctx, transport.Endpoint()) - if err!=nil { return fmt.Errorf("3rd Dial(): %w", err) } + if err != nil { + return fmt.Errorf("3rd Dial(): %w", err) + } defer c1.Close() if err := utils.WithTimeout(ctx, time.Second, func(ctx context.Context) error { c2, err := transport.Accept(ctx) - if err==nil { + if err == nil { c2.Close() } return err @@ -121,8 +135,10 @@ func TestMConnTransport_AcceptMaxAcceptedConnections(t *testing.T) { t.Logf("once either of the other connections are closed, the accept goes through.") a1.Close() a2.Close() // we close both a1 and a2 to make sure the connection count drops below the limit. 
- c2,err := transport.Accept(ctx) - if err!=nil { return fmt.Errorf("3rd Accept(): %w",err) } + c2, err := transport.Accept(ctx) + if err != nil { + return fmt.Errorf("3rd Accept(): %w", err) + } defer c2.Close() return nil })) @@ -165,33 +181,37 @@ func TestMConnTransport_Listen(t *testing.T) { []*p2p.ChannelDescriptor{{ID: chID, Priority: 1}}, p2p.MConnTransportOptions{}, ) - if got,want := transport.Endpoint(),tc.endpoint; got!=want { + if got, want := transport.Endpoint(), tc.endpoint; got != want { t.Fatalf("transport.Endpoint() = %v, want %v", got, want) } err := utils.IgnoreCancel(scope.Run(ctx, func(ctx context.Context, s scope.Scope) error { s.SpawnBgNamed("transport", func() error { return transport.Run(ctx) }) - if err:=transport.WaitForStart(ctx); err!=nil { + if err := transport.WaitForStart(ctx); err != nil { return err } - s.SpawnNamed("dial",func() error { + s.SpawnNamed("dial", func() error { conn, err := transport.Dial(ctx, tc.endpoint) - if err != nil { return fmt.Errorf("transport.Dial(): %w", err) } - if err:=conn.Close(); err!=nil { + if err != nil { + return fmt.Errorf("transport.Dial(): %w", err) + } + if err := conn.Close(); err != nil { return fmt.Errorf("conn.Close(): %w", err) } - if _, _, err := conn.ReceiveMessage(ctx); !errors.Is(err,io.EOF) { + if _, _, err := conn.ReceiveMessage(ctx); !errors.Is(err, io.EOF) { return fmt.Errorf("conn.ReceiveMessage() = %v, want %v", err, io.EOF) } return nil }) - s.SpawnNamed("accept",func() error { + s.SpawnNamed("accept", func() error { conn, err := transport.Accept(ctx) - if err != nil { return fmt.Errorf("transport.Accept(): %w",err) } - if err:=conn.Close(); err!=nil { + if err != nil { + return fmt.Errorf("transport.Accept(): %w", err) + } + if err := conn.Close(); err != nil { return fmt.Errorf("conn.Close(): %w", err) } - if _, _, err := conn.ReceiveMessage(ctx); !errors.Is(err,io.EOF) { + if _, _, err := conn.ReceiveMessage(ctx); !errors.Is(err, io.EOF) { return fmt.Errorf("conn.ReceiveMessage() = %v, want %v", err, io.EOF) } return nil diff --git a/internal/p2p/transport_memory.go b/internal/p2p/transport_memory.go index 79b442d66..f4fa08bf5 100644 --- a/internal/p2p/transport_memory.go +++ b/internal/p2p/transport_memory.go @@ -10,8 +10,8 @@ import ( "github.com/tendermint/tendermint/crypto" "github.com/tendermint/tendermint/libs/log" - "github.com/tendermint/tendermint/types" "github.com/tendermint/tendermint/libs/utils" + "github.com/tendermint/tendermint/types" ) const ( @@ -128,7 +128,7 @@ func (t *MemoryTransport) Endpoint() Endpoint { // Accept implements Transport. func (t *MemoryTransport) Accept(ctx context.Context) (Connection, error) { - return utils.Recv(ctx,t.acceptCh) + return utils.Recv(ctx, t.acceptCh) } // Dial implements Transport. 
@@ -168,7 +168,7 @@ func (t *MemoryTransport) Dial(ctx context.Context, endpoint Endpoint) (Connecti inConn.closeCh = closeCh inConn.closeFn = closeFn - if err:=utils.Send(ctx,peer.acceptCh,inConn); err!=nil { + if err := utils.Send(ctx, peer.acceptCh, inConn); err != nil { return nil, err } return outConn, nil diff --git a/internal/p2p/transport_memory_test.go b/internal/p2p/transport_memory_test.go index ef5b1f299..334101bd8 100644 --- a/internal/p2p/transport_memory_test.go +++ b/internal/p2p/transport_memory_test.go @@ -2,8 +2,8 @@ package p2p_test import ( "bytes" - "encoding/hex" "context" + "encoding/hex" "github.com/tendermint/tendermint/internal/p2p" "github.com/tendermint/tendermint/libs/log" @@ -18,10 +18,12 @@ func init() { return func(ctx context.Context) p2p.Transport { i := byte(network.Size()) nodeID, err := types.NewNodeID(hex.EncodeToString(bytes.Repeat([]byte{i<<4 + i}, 20))) - if err!=nil { panic(err) } + if err != nil { + panic(err) + } t := network.CreateTransport(nodeID) go func() { - if err:=t.Run(ctx); err!=nil { + if err := t.Run(ctx); err != nil { panic(err) } }() diff --git a/internal/p2p/transport_test.go b/internal/p2p/transport_test.go index 119274a8d..a088d0a4a 100644 --- a/internal/p2p/transport_test.go +++ b/internal/p2p/transport_test.go @@ -43,7 +43,7 @@ func TestTransport_DialEndpoints(t *testing.T) { {netip.IPv4Unspecified(), true}, {netip.IPv6Unspecified(), true}, - {netip.AddrFrom4([4]byte{255,255,255,255}), false}, + {netip.AddrFrom4([4]byte{255, 255, 255, 255}), false}, {netip.AddrFrom4([4]byte{224, 0, 0, 1}), false}, } @@ -85,11 +85,11 @@ func TestTransport_DialEndpoints(t *testing.T) { require.Error(t, err) // Tests for networked endpoints (with IP). - if endpoint.Addr!=(netip.AddrPort{}) && endpoint.Protocol != p2p.MemoryProtocol { + if endpoint.Addr != (netip.AddrPort{}) && endpoint.Protocol != p2p.MemoryProtocol { for _, tc := range ipTestCases { t.Run(tc.ip.String(), func(t *testing.T) { e := endpoint - e.Addr = netip.AddrPortFrom(tc.ip,endpoint.Addr.Port()) + e.Addr = netip.AddrPortFrom(tc.ip, endpoint.Addr.Port()) conn, err := a.Dial(ctx, e) if tc.ok { require.NoError(t, err) @@ -325,9 +325,9 @@ func TestConnection_String(t *testing.T) { func TestEndpoint_NodeAddress(t *testing.T) { var ( - ip4 = netip.AddrFrom4([4]byte{1, 2, 3, 4}) - ip6 = netip.AddrFrom16([16]byte{0xb1, 0x0c, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01}) - id = types.NodeID("00112233445566778899aabbccddeeff00112233") + ip4 = netip.AddrFrom4([4]byte{1, 2, 3, 4}) + ip6 = netip.AddrFrom16([16]byte{0xb1, 0x0c, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01}) + id = types.NodeID("00112233445566778899aabbccddeeff00112233") ) testcases := []struct { @@ -336,7 +336,7 @@ func TestEndpoint_NodeAddress(t *testing.T) { }{ // Valid endpoints. { - p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip4,8080), Path: "path"}, + p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip4, 8080), Path: "path"}, p2p.NodeAddress{Protocol: "tcp", Hostname: "1.2.3.4", Port: 8080, Path: "path"}, }, { @@ -355,7 +355,7 @@ func TestEndpoint_NodeAddress(t *testing.T) { // Partial (invalid) endpoints. 
{p2p.Endpoint{}, p2p.NodeAddress{}}, {p2p.Endpoint{Protocol: "tcp"}, p2p.NodeAddress{Protocol: "tcp"}}, - {p2p.Endpoint{Addr: netip.AddrPortFrom(ip4,0)}, p2p.NodeAddress{Hostname: "1.2.3.4"}}, + {p2p.Endpoint{Addr: netip.AddrPortFrom(ip4, 0)}, p2p.NodeAddress{Hostname: "1.2.3.4"}}, {p2p.Endpoint{Path: "path"}, p2p.NodeAddress{Path: "path"}}, } for _, tc := range testcases { @@ -388,23 +388,23 @@ func TestEndpoint_String(t *testing.T) { {p2p.Endpoint{Protocol: "file", Path: "👋"}, "file:///%F0%9F%91%8B"}, // IPv4 endpoints. - {p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip4,0)}, "tcp://1.2.3.4"}, - {p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip4,8080)}, "tcp://1.2.3.4:8080"}, - {p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip4,8080), Path: "/path"}, "tcp://1.2.3.4:8080/path"}, - {p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip4,0), Path: "path/👋"}, "tcp://1.2.3.4/path/%F0%9F%91%8B"}, + {p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip4, 0)}, "tcp://1.2.3.4"}, + {p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip4, 8080)}, "tcp://1.2.3.4:8080"}, + {p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip4, 8080), Path: "/path"}, "tcp://1.2.3.4:8080/path"}, + {p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip4, 0), Path: "path/👋"}, "tcp://1.2.3.4/path/%F0%9F%91%8B"}, // IPv6 endpoints. - {p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip6,0)}, "tcp://b10c::1"}, - {p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip6,8080)}, "tcp://[b10c::1]:8080"}, - {p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip6,8080), Path: "/path"}, "tcp://[b10c::1]:8080/path"}, - {p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip6,0), Path: "path/👋"}, "tcp://b10c::1/path/%F0%9F%91%8B"}, + {p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip6, 0)}, "tcp://b10c::1"}, + {p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip6, 8080)}, "tcp://[b10c::1]:8080"}, + {p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip6, 8080), Path: "/path"}, "tcp://[b10c::1]:8080/path"}, + {p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip6, 0), Path: "path/👋"}, "tcp://b10c::1/path/%F0%9F%91%8B"}, // Partial (invalid) endpoints. {p2p.Endpoint{}, ""}, {p2p.Endpoint{Protocol: "tcp"}, "tcp:"}, - {p2p.Endpoint{Addr: netip.AddrPortFrom(ip4,0)}, "1.2.3.4"}, - {p2p.Endpoint{Addr: netip.AddrPortFrom(ip6,0)}, "b10c::1"}, - {p2p.Endpoint{Addr: netip.AddrPortFrom(netip.IPv4Unspecified(),8080)}, "0.0.0.0:8080"}, + {p2p.Endpoint{Addr: netip.AddrPortFrom(ip4, 0)}, "1.2.3.4"}, + {p2p.Endpoint{Addr: netip.AddrPortFrom(ip6, 0)}, "b10c::1"}, + {p2p.Endpoint{Addr: netip.AddrPortFrom(netip.IPv4Unspecified(), 8080)}, "0.0.0.0:8080"}, {p2p.Endpoint{Path: "foo"}, "/foo"}, } for _, tc := range testcases { @@ -423,15 +423,15 @@ func TestEndpoint_Validate(t *testing.T) { expectValid bool }{ // Valid endpoints. 
- {p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip4,0)}, true}, - {p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip6,0)}, true}, - {p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip4,8008)}, true}, - {p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip4,8080), Path: "path"}, true}, + {p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip4, 0)}, true}, + {p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip6, 0)}, true}, + {p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip4, 8008)}, true}, + {p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip4, 8080), Path: "path"}, true}, {p2p.Endpoint{Protocol: "memory", Path: "path"}, true}, // Invalid endpoints. {p2p.Endpoint{}, false}, - {p2p.Endpoint{Addr: netip.AddrPortFrom(ip4,0)}, false}, + {p2p.Endpoint{Addr: netip.AddrPortFrom(ip4, 0)}, false}, {p2p.Endpoint{Protocol: "tcp"}, false}, } for _, tc := range testcases { diff --git a/libs/utils/require/require.go b/libs/utils/require/require.go index 1bfa00eda..438df3dfd 100644 --- a/libs/utils/require/require.go +++ b/libs/utils/require/require.go @@ -27,7 +27,7 @@ var NotZero = require.NotZero var Contains = require.Contains func ElementsMatch[T any](t TestingT, a []T, b []T, msgAndArgs ...any) { - require.ElementsMatch(t,a,b,msgAndArgs...) + require.ElementsMatch(t, a, b, msgAndArgs...) } // Eventually . @@ -93,7 +93,6 @@ func GreaterOrEqual[T cmp.Ordered](t TestingT, e1, e2 T, msgAndArgs ...any) { require.GreaterOrEqual(t, e1, e2, msgAndArgs...) } - // Equal . func Equal[T any](t TestingT, expected, actual T, msgAndArgs ...any) { require.Equal(t, expected, actual, msgAndArgs...) diff --git a/libs/utils/tcp/tcp.go b/libs/utils/tcp/tcp.go index ccfe745b5..4baef6cb6 100644 --- a/libs/utils/tcp/tcp.go +++ b/libs/utils/tcp/tcp.go @@ -19,7 +19,7 @@ func IPv4Loopback() netip.Addr { return netip.AddrFrom4([4]byte{127, 0, 0, 1}) } // Norm normalizes address by unmapping IPv4 -> IPv6 embedding. func Norm(addr netip.AddrPort) netip.AddrPort { - return netip.AddrPortFrom(addr.Addr().Unmap(),addr.Port()) + return netip.AddrPortFrom(addr.Addr().Unmap(), addr.Port()) } // Listen opens a TCP listener on the given address. 
diff --git a/node/node.go b/node/node.go index 7a3c2d43a..6e175eddd 100644 --- a/node/node.go +++ b/node/node.go @@ -5,8 +5,8 @@ import ( "errors" "fmt" "net" - "net/netip" "net/http" + "net/netip" "strings" "time" From 05dbdb74c62c62375a9bbfbc6e6ce28294c1f45e Mon Sep 17 00:00:00 2001 From: Grzegorz Prusak Date: Mon, 1 Sep 2025 14:45:53 +0200 Subject: [PATCH 37/41] tidy --- go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 2106ce130..d87a05273 100644 --- a/go.mod +++ b/go.mod @@ -226,7 +226,6 @@ require ( go.uber.org/zap v1.21.0 // indirect golang.org/x/exp/typeparams v0.0.0-20220218215828-6cf2b201936e // indirect golang.org/x/mod v0.17.0 // indirect - golang.org/x/sys v0.28.0 // indirect golang.org/x/term v0.27.0 // indirect golang.org/x/text v0.21.0 // indirect golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect @@ -254,5 +253,6 @@ require ( go.opentelemetry.io/otel v1.9.0 go.opentelemetry.io/otel/sdk v1.9.0 go.opentelemetry.io/otel/trace v1.9.0 + golang.org/x/sys v0.28.0 google.golang.org/protobuf v1.28.0 ) From 30272db332cd8ff9fb2e9a59f07480c09d858a16 Mon Sep 17 00:00:00 2001 From: Grzegorz Prusak Date: Mon, 1 Sep 2025 15:10:29 +0200 Subject: [PATCH 38/41] fixed flaky test --- privval/socket_listeners_test.go | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/privval/socket_listeners_test.go b/privval/socket_listeners_test.go index e91d111d0..ff9c24cab 100644 --- a/privval/socket_listeners_test.go +++ b/privval/socket_listeners_test.go @@ -6,8 +6,6 @@ import ( "testing" "time" - "github.com/stretchr/testify/require" - "github.com/tendermint/tendermint/crypto/ed25519" ) @@ -111,8 +109,14 @@ func TestListenerTimeoutReadWrite(t *testing.T) { ) for _, tc := range listenerTestCases(t, timeoutAccept, timeoutReadWrite) { go func(dialer SocketDialer) { - _, err := dialer() - require.NoError(t, err) + conn, err := dialer() + if err != nil { + panic(err) // this is not the main goroutine, so "require.NoError" won't work + } + // If we don't close this properly, the test gets flaky because connection + // closes at random. 
+ defer conn.Close() + <-t.Context().Done() }(tc.dialer) c, err := tc.listener.Accept() From 4dc086a2b84fb75cf2d67725501d7400d0a780ef Mon Sep 17 00:00:00 2001 From: Grzegorz Prusak Date: Tue, 2 Sep 2025 15:32:27 +0200 Subject: [PATCH 39/41] fixed panic in handshakePeer --- internal/p2p/peermanager.go | 6 ++++++ internal/p2p/router.go | 32 +++++++++++++------------------- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/internal/p2p/peermanager.go b/internal/p2p/peermanager.go index 14b5f5b7f..f1a79cddc 100644 --- a/internal/p2p/peermanager.go +++ b/internal/p2p/peermanager.go @@ -495,6 +495,12 @@ func (m *PeerManager) Add(address NodeAddress) (bool, error) { return true, nil } +func (m *PeerManager) Delete(id types.NodeID) error { + m.mtx.Lock() + defer m.mtx.Unlock() + return m.store.Delete(id) +} + func (m *PeerManager) GetBlockSyncPeers() map[types.NodeID]bool { return m.options.blocksyncPeers } diff --git a/internal/p2p/router.go b/internal/p2p/router.go index 2b9355503..3f40d4cb6 100644 --- a/internal/p2p/router.go +++ b/internal/p2p/router.go @@ -614,31 +614,32 @@ func (r *Router) handshakePeer( nodeInfo := r.nodeInfoProducer() peerInfo, peerKey, err := conn.Handshake(ctx, *nodeInfo, r.privKey) if err != nil { - return peerInfo, err + return types.NodeInfo{}, err } - if err = peerInfo.Validate(); err != nil { - return peerInfo, fmt.Errorf("invalid handshake NodeInfo: %w", err) + // Authenticate the peer first. + if types.NodeIDFromPubKey(peerKey) != peerInfo.NodeID { + return types.NodeInfo{}, fmt.Errorf("peer's public key did not match its node ID %q (expected %q)", + peerInfo.NodeID, types.NodeIDFromPubKey(peerKey)) } + if err = peerInfo.Validate(); err != nil { + return types.NodeInfo{}, fmt.Errorf("invalid handshake NodeInfo: %w", err) + } if peerInfo.Network != nodeInfo.Network { - if err := r.peerManager.store.Delete(peerInfo.NodeID); err != nil { - return peerInfo, fmt.Errorf("problem removing peer from store from incorrect network [%s]: %w", peerInfo.Network, err) + if err := r.peerManager.Delete(peerInfo.NodeID); err != nil { + return types.NodeInfo{}, fmt.Errorf("problem removing peer from store from incorrect network [%s]: %w", peerInfo.Network, err) } - return peerInfo, fmt.Errorf("connected to peer from wrong network, %q, removed from peer store", peerInfo.Network) + return types.NodeInfo{}, fmt.Errorf("connected to peer from wrong network, %q, removed from peer store", peerInfo.Network) } - if types.NodeIDFromPubKey(peerKey) != peerInfo.NodeID { - return peerInfo, fmt.Errorf("peer's public key did not match its node ID %q (expected %q)", - peerInfo.NodeID, types.NodeIDFromPubKey(peerKey)) - } if expectID != "" && expectID != peerInfo.NodeID { - return peerInfo, fmt.Errorf("expected to connect with peer %q, got %q", + return types.NodeInfo{}, fmt.Errorf("expected to connect with peer %q, got %q", expectID, peerInfo.NodeID) } if err := nodeInfo.CompatibleWith(peerInfo); err != nil { - return peerInfo, ErrRejected{ + return types.NodeInfo{}, ErrRejected{ err: err, id: peerInfo.ID(), isIncompatible: true, @@ -647,13 +648,6 @@ func (r *Router) handshakePeer( return peerInfo, nil } -func (r *Router) runWithPeerMutex(fn func() error) error { - for range r.peerStates.Lock() { - return fn() - } - panic("unreachable") -} - // routePeer routes inbound and outbound messages between a peer and the reactor // channels. It will close the given connection and send queue when done, or if // they are closed elsewhere it will cause this method to shut down and return. 
From 0effa48b40af130f1117df48c09cc83caae4992e Mon Sep 17 00:00:00 2001 From: Grzegorz Prusak Date: Tue, 2 Sep 2025 16:40:46 +0200 Subject: [PATCH 40/41] fixed test --- internal/p2p/router_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/p2p/router_test.go b/internal/p2p/router_test.go index 50b15957d..36fe90254 100644 --- a/internal/p2p/router_test.go +++ b/internal/p2p/router_test.go @@ -345,7 +345,7 @@ func TestRouter_AcceptPeers(t *testing.T) { ok bool }{ "valid handshake": {peerInfo, peerKey.PubKey(), true}, - "empty handshake": {types.NodeInfo{}, nil, false}, + "empty handshake": {types.NodeInfo{}, peerKey.PubKey(), false}, "invalid key": {peerInfo, selfKey.PubKey(), false}, "self handshake": {selfInfo, selfKey.PubKey(), false}, "incompatible peer": { @@ -528,7 +528,7 @@ func TestRouter_DialPeers(t *testing.T) { ok bool }{ "valid dial": {peerInfo.NodeID, peerInfo, peerKey.PubKey(), nil, true}, - "empty handshake": {peerInfo.NodeID, types.NodeInfo{}, nil, nil, false}, + "empty handshake": {peerInfo.NodeID, types.NodeInfo{}, peerKey.PubKey(), nil, false}, "invalid key": {peerInfo.NodeID, peerInfo, selfKey.PubKey(), nil, false}, "unexpected node ID": {peerInfo.NodeID, selfInfo, selfKey.PubKey(), nil, false}, "dial error": {peerInfo.NodeID, peerInfo, peerKey.PubKey(), errors.New("boom"), false}, From 9392ba7f37baed0a13c53ed87ec8c6cc8e8f454a Mon Sep 17 00:00:00 2001 From: Grzegorz Prusak Date: Wed, 3 Sep 2025 12:57:01 +0200 Subject: [PATCH 41/41] removed log --- internal/p2p/address.go | 1 - 1 file changed, 1 deletion(-) diff --git a/internal/p2p/address.go b/internal/p2p/address.go index d67f3c906..18d95100b 100644 --- a/internal/p2p/address.go +++ b/internal/p2p/address.go @@ -123,7 +123,6 @@ func (a NodeAddress) Resolve(ctx context.Context) ([]Endpoint, error) { endpoints := make([]Endpoint, len(ips)) for i, ip := range ips { ip, ok := netip.AddrFromSlice(ip) - fmt.Printf("%v\n", ip) if !ok { return nil, fmt.Errorf("LookupIP returned invalid IP %q", ip) }