diff --git a/changelog.md b/changelog.md index dfd2066f08..7f54957e3e 100644 --- a/changelog.md +++ b/changelog.md @@ -21,6 +21,7 @@ ### Fixes * [2654](https://github.com/zeta-chain/node/pull/2654) - add validation for authorization list in when validating genesis state for authorization module +* [2672](https://github.com/zeta-chain/node/pull/2672) - check observer set for duplicates when adding a new observer or updating an existing one ## v19.0.0 diff --git a/x/observer/keeper/msg_server_add_observer.go b/x/observer/keeper/msg_server_add_observer.go index 11f7701dde..2c7368eade 100644 --- a/x/observer/keeper/msg_server_add_observer.go +++ b/x/observer/keeper/msg_server_add_observer.go @@ -52,13 +52,15 @@ func (k msgServer) AddObserver( return &types.MsgAddObserverResponse{}, nil } - k.AddObserverToSet(ctx, msg.ObserverAddress) - observerSet, _ := k.GetObserverSet(ctx) + // Add observer to the observer set and update the observer count + count, err := k.AddObserverToSet(ctx, msg.ObserverAddress) + if err != nil { + return &types.MsgAddObserverResponse{}, err + } - k.SetLastObserverCount(ctx, &types.LastObserverCount{Count: observerSet.LenUint()}) EmitEventAddObserver( ctx, - observerSet.LenUint(), + count, msg.ObserverAddress, granteeAddress.String(), msg.ZetaclientGranteePubkey, diff --git a/x/observer/keeper/msg_server_add_observer_test.go b/x/observer/keeper/msg_server_add_observer_test.go index b26a2fe5d0..183a97c8e8 100644 --- a/x/observer/keeper/msg_server_add_observer_test.go +++ b/x/observer/keeper/msg_server_add_observer_test.go @@ -52,7 +52,37 @@ func TestMsgServer_AddObserver(t *testing.T) { require.Equal(t, &types.MsgAddObserverResponse{}, res) }) - t.Run("should add if add node account only false", func(t *testing.T) { + t.Run("unable to add observer if observer already exists", func(t *testing.T) { + //ARRANGE + k, ctx, _, _ := keepertest.ObserverKeeperWithMocks(t, keepertest.ObserverMockOptions{ + UseAuthorityMock: true, + }) + authorityMock := keepertest.GetObserverAuthorityMock(t, k) + admin := sample.AccAddress() + observerAddress := sample.AccAddress() + wctx := sdk.WrapSDKContext(ctx) + + _, found := k.GetLastObserverCount(ctx) + require.False(t, found) + srv := keeper.NewMsgServerImpl(*k) + k.SetObserverSet(ctx, types.ObserverSet{ObserverList: []string{observerAddress}}) + + msg := types.MsgAddObserver{ + Creator: admin, + ZetaclientGranteePubkey: sample.PubKeyString(), + AddNodeAccountOnly: false, + ObserverAddress: observerAddress, + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + + // ACT + _, err := srv.AddObserver(wctx, &msg) + + // ASSERT + require.ErrorIs(t, err, types.ErrDuplicateObserver) + }) + + t.Run("should add observer if add node account only false", func(t *testing.T) { k, ctx, _, _ := keepertest.ObserverKeeperWithMocks(t, keepertest.ObserverMockOptions{ UseAuthorityMock: true, }) diff --git a/x/observer/keeper/msg_server_update_observer_test.go b/x/observer/keeper/msg_server_update_observer_test.go index 2a4308590f..66e646eb53 100644 --- a/x/observer/keeper/msg_server_update_observer_test.go +++ b/x/observer/keeper/msg_server_update_observer_test.go @@ -73,6 +73,61 @@ func TestMsgServer_UpdateObserver(t *testing.T) { require.Equal(t, newOperatorAddress.String(), acc.Operator) }) + t.Run( + "unable to update a tombstoned observer if the new address already exists in the observer set", + func(t *testing.T) { + //ARRANGE + k, ctx, _, _ := keepertest.ObserverKeeper(t) + srv := keeper.NewMsgServerImpl(*k) + // #nosec G404 test purpose - weak randomness is not an issue here + r := rand.New(rand.NewSource(9)) + // Set validator in the store + validator := sample.Validator(t, r) + validatorNew := sample.Validator(t, r) + validatorNew.Status = stakingtypes.Bonded + k.GetStakingKeeper().SetValidator(ctx, validatorNew) + k.GetStakingKeeper().SetValidator(ctx, validator) + + consAddress, err := validator.GetConsAddr() + require.NoError(t, err) + k.GetSlashingKeeper().SetValidatorSigningInfo(ctx, consAddress, slashingtypes.ValidatorSigningInfo{ + Address: consAddress.String(), + StartHeight: 0, + JailedUntil: ctx.BlockHeader().Time.Add(1000000 * time.Second), + Tombstoned: true, + MissedBlocksCounter: 1, + }) + + accAddressOfValidator, err := types.GetAccAddressFromOperatorAddress(validator.OperatorAddress) + require.NoError(t, err) + + newOperatorAddress, err := types.GetAccAddressFromOperatorAddress(validatorNew.OperatorAddress) + require.NoError(t, err) + + observerList := []string{accAddressOfValidator.String(), newOperatorAddress.String()} + k.SetObserverSet(ctx, types.ObserverSet{ + ObserverList: observerList, + }) + k.SetNodeAccount(ctx, types.NodeAccount{ + Operator: accAddressOfValidator.String(), + }) + k.SetLastObserverCount(ctx, &types.LastObserverCount{ + Count: uint64(len(observerList)), + }) + + //ACT + _, err = srv.UpdateObserver(sdk.WrapSDKContext(ctx), &types.MsgUpdateObserver{ + Creator: accAddressOfValidator.String(), + OldObserverAddress: accAddressOfValidator.String(), + NewObserverAddress: newOperatorAddress.String(), + UpdateReason: types.ObserverUpdateReason_Tombstoned, + }) + + // ASSERT + require.ErrorContains(t, err, types.ErrDuplicateObserver.Error()) + }, + ) + t.Run("unable to update to a non validator address", func(t *testing.T) { k, ctx, _, _ := keepertest.ObserverKeeper(t) srv := keeper.NewMsgServerImpl(*k) diff --git a/x/observer/keeper/observer_set.go b/x/observer/keeper/observer_set.go index c8a22e0e0f..4a6a7044f1 100644 --- a/x/observer/keeper/observer_set.go +++ b/x/observer/keeper/observer_set.go @@ -1,6 +1,7 @@ package keeper import ( + "cosmossdk.io/errors" "github.com/cosmos/cosmos-sdk/store/prefix" sdk "github.com/cosmos/cosmos-sdk/types" @@ -36,23 +37,29 @@ func (k Keeper) IsAddressPartOfObserverSet(ctx sdk.Context, address string) bool return false } -func (k Keeper) AddObserverToSet(ctx sdk.Context, address string) { +// AddObserverToSet adds an observer to the observer set.It makes sure the updated observer set is valid. +// It also sets the observer count and returns the updated length of the observer set. +func (k Keeper) AddObserverToSet(ctx sdk.Context, address string) (uint64, error) { observerSet, found := k.GetObserverSet(ctx) if !found { - k.SetObserverSet(ctx, types.ObserverSet{ - ObserverList: []string{address}, - }) - return - } - for _, addr := range observerSet.ObserverList { - if addr == address { - return + observerSet = types.ObserverSet{ + ObserverList: []string{}, } } + observerSet.ObserverList = append(observerSet.ObserverList, address) + if err := observerSet.Validate(); err != nil { + return 0, err + } + k.SetObserverSet(ctx, observerSet) + newCount := observerSet.LenUint() + k.SetLastObserverCount(ctx, &types.LastObserverCount{Count: newCount}) + + return newCount, nil } +// RemoveObserverFromSet removes an observer from the observer set. func (k Keeper) RemoveObserverFromSet(ctx sdk.Context, address string) { observerSet, found := k.GetObserverSet(ctx) if !found { @@ -67,17 +74,28 @@ func (k Keeper) RemoveObserverFromSet(ctx sdk.Context, address string) { } } +// UpdateObserverAddress updates an observer address in the observer set.It makes sure the updated observer set is valid. func (k Keeper) UpdateObserverAddress(ctx sdk.Context, oldObserverAddress, newObserverAddress string) error { observerSet, found := k.GetObserverSet(ctx) if !found { return types.ErrObserverSetNotFound } + found = false for i, addr := range observerSet.ObserverList { if addr == oldObserverAddress { observerSet.ObserverList[i] = newObserverAddress - k.SetObserverSet(ctx, observerSet) - return nil + found = true + break } } - return types.ErrUpdateObserver + if !found { + return errors.Wrapf(types.ErrObserverNotFound, "observer %s", oldObserverAddress) + } + + err := observerSet.Validate() + if err != nil { + return errors.Wrap(types.ErrUpdateObserver, err.Error()) + } + k.SetObserverSet(ctx, observerSet) + return nil } diff --git a/x/observer/keeper/observer_set_test.go b/x/observer/keeper/observer_set_test.go index 3ae4a99cf2..41366943f4 100644 --- a/x/observer/keeper/observer_set_test.go +++ b/x/observer/keeper/observer_set_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/stretchr/testify/require" + "github.com/zeta-chain/zetacore/x/observer/types" keepertest "github.com/zeta-chain/zetacore/testutil/keeper" "github.com/zeta-chain/zetacore/testutil/sample" @@ -35,33 +36,64 @@ func TestKeeper_IsAddressPartOfObserverSet(t *testing.T) { func TestKeeper_AddObserverToSet(t *testing.T) { t.Run("add observer to set", func(t *testing.T) { + // ARRANGE k, ctx, _, _ := keepertest.ObserverKeeper(t) os := sample.ObserverSet(10) k.SetObserverSet(ctx, os) newObserver := sample.AccAddress() - k.AddObserverToSet(ctx, newObserver) + + // ACT + countReturned, err := k.AddObserverToSet(ctx, newObserver) + + // ASSERT + require.NoError(t, err) require.True(t, k.IsAddressPartOfObserverSet(ctx, newObserver)) require.False(t, k.IsAddressPartOfObserverSet(ctx, sample.AccAddress())) osNew, found := k.GetObserverSet(ctx) require.True(t, found) require.Len(t, osNew.ObserverList, len(os.ObserverList)+1) + count, found := k.GetLastObserverCount(ctx) + require.True(t, found) + require.Equal(t, osNew.LenUint(), count.Count) + require.Equal(t, osNew.LenUint(), countReturned) }) t.Run("add observer to set if set doesn't exist", func(t *testing.T) { + // ARRANGE k, ctx, _, _ := keepertest.ObserverKeeper(t) newObserver := sample.AccAddress() - k.AddObserverToSet(ctx, newObserver) + + // ACT + countReturned, err := k.AddObserverToSet(ctx, newObserver) + + // ASSERT + require.NoError(t, err) require.True(t, k.IsAddressPartOfObserverSet(ctx, newObserver)) osNew, found := k.GetObserverSet(ctx) require.True(t, found) require.Len(t, osNew.ObserverList, 1) + count, found := k.GetLastObserverCount(ctx) + require.True(t, found) + require.Equal(t, osNew.LenUint(), count.Count) + require.Equal(t, osNew.LenUint(), countReturned) + }) - // add same address again, len doesn't change - k.AddObserverToSet(ctx, newObserver) + t.Run("cannot add observer to set the address is already part of the set", func(t *testing.T) { + // ARRANGE + k, ctx, _, _ := keepertest.ObserverKeeper(t) + newObserver := sample.AccAddress() + _, err := k.AddObserverToSet(ctx, newObserver) + require.NoError(t, err) require.True(t, k.IsAddressPartOfObserverSet(ctx, newObserver)) - osNew, found = k.GetObserverSet(ctx) + osNew, found := k.GetObserverSet(ctx) require.True(t, found) require.Len(t, osNew.ObserverList, 1) + + // ACT + _, err = k.AddObserverToSet(ctx, newObserver) + + // ASSERT + require.ErrorIs(t, err, types.ErrDuplicateObserver) }) } @@ -95,6 +127,33 @@ func TestKeeper_UpdateObserverAddress(t *testing.T) { require.True(t, found) require.Equal(t, newObserverAddress, observerSet.ObserverList[len(observerSet.ObserverList)-1]) }) + t.Run("unable to update observer list observe set not found", func(t *testing.T) { + // ARRANGE + k, ctx, _, _ := keepertest.ObserverKeeper(t) + oldObserverAddress := sample.AccAddress() + newObserverAddress := sample.AccAddress() + + // ACT + err := k.UpdateObserverAddress(ctx, oldObserverAddress, newObserverAddress) + + // ASSERT + require.ErrorIs(t, err, types.ErrObserverSetNotFound) + }) + t.Run("unable to update observer list if the new list is not valid", func(t *testing.T) { + // ARRANGE + k, ctx, _, _ := keepertest.ObserverKeeper(t) + oldObserverAddress := sample.AccAddress() + newObserverAddress := sample.AccAddress() + observerSet := sample.ObserverSet(10) + observerSet.ObserverList = append(observerSet.ObserverList, []string{oldObserverAddress, newObserverAddress}...) + k.SetObserverSet(ctx, observerSet) + + // ACT + err := k.UpdateObserverAddress(ctx, oldObserverAddress, newObserverAddress) + + // ASSERT + require.ErrorContains(t, err, types.ErrDuplicateObserver.Error()) + }) t.Run("should error if observer address not found", func(t *testing.T) { k, ctx, _, _ := keepertest.ObserverKeeper(t) oldObserverAddress := sample.AccAddress() @@ -103,7 +162,7 @@ func TestKeeper_UpdateObserverAddress(t *testing.T) { observerSet.ObserverList = append(observerSet.ObserverList, oldObserverAddress) k.SetObserverSet(ctx, observerSet) err := k.UpdateObserverAddress(ctx, sample.AccAddress(), newObserverAddress) - require.Error(t, err) + require.ErrorIs(t, err, types.ErrObserverNotFound) }) t.Run("update observer address long observerList", func(t *testing.T) { k, ctx, _, _ := keepertest.ObserverKeeper(t) diff --git a/x/observer/types/errors.go b/x/observer/types/errors.go index 6485e613ed..218521f242 100644 --- a/x/observer/types/errors.go +++ b/x/observer/types/errors.go @@ -46,7 +46,10 @@ var ( ErrObserverSetNotFound = errorsmod.Register(ModuleName, 1130, "observer set not found") ErrTssNotFound = errorsmod.Register(ModuleName, 1131, "tss not found") - ErrInboundDisabled = errorsmod.Register(ModuleName, 1132, "inbound tx processing is disabled") - ErrInvalidZetaCoinTypes = errorsmod.Register(ModuleName, 1133, "invalid zeta coin types") - ErrNotObserver = errorsmod.Register(ModuleName, 1134, "sender is not an observer") + ErrInboundDisabled = errorsmod.Register(ModuleName, 1132, "inbound tx processing is disabled") + ErrInvalidZetaCoinTypes = errorsmod.Register(ModuleName, 1133, "invalid zeta coin types") + ErrNotObserver = errorsmod.Register(ModuleName, 1134, "sender is not an observer") + ErrDuplicateObserver = errorsmod.Register(ModuleName, 1135, "observer already exists") + ErrObserverNotFound = errorsmod.Register(ModuleName, 1136, "observer not found") + ErrInvalidObserverAddress = errorsmod.Register(ModuleName, 1137, "invalid observer address") ) diff --git a/x/observer/types/observer_set.go b/x/observer/types/observer_set.go index ffa2d1c05a..db9473f187 100644 --- a/x/observer/types/observer_set.go +++ b/x/observer/types/observer_set.go @@ -1,6 +1,7 @@ package types import ( + "cosmossdk.io/errors" sdk "github.com/cosmos/cosmos-sdk/types" "github.com/zeta-chain/zetacore/pkg/chains" @@ -14,13 +15,23 @@ func (m *ObserverSet) LenUint() uint64 { return uint64(len(m.ObserverList)) } -// Validate observer mapper contains an existing chain +// Validate observer set verifies that the observer set is valid +// - All observer addresses are valid +// - No duplicate observer addresses func (m *ObserverSet) Validate() error { + observers := make(map[string]struct{}) for _, observerAddress := range m.ObserverList { + // Check for valid observer addresses _, err := sdk.AccAddressFromBech32(observerAddress) if err != nil { - return err + return errors.Wrapf(ErrInvalidObserverAddress, "observer %s err %s", observerAddress, err.Error()) } + // Check for duplicates + if _, ok := observers[observerAddress]; ok { + return errors.Wrapf(ErrDuplicateObserver, "observer %s", observerAddress) + } + + observers[observerAddress] = struct{}{} } return nil } diff --git a/x/observer/types/observer_set_test.go b/x/observer/types/observer_set_test.go index 69a9a19f96..8344757b18 100644 --- a/x/observer/types/observer_set_test.go +++ b/x/observer/types/observer_set_test.go @@ -10,17 +10,39 @@ import ( "github.com/zeta-chain/zetacore/x/observer/types" ) -func TestObserverSet(t *testing.T) { - observerSet := sample.ObserverSet(4) +func TestObserverSet_Validate(t *testing.T) { + observer1Address := sample.AccAddress() + tt := []struct { + name string + observer types.ObserverSet + wantErr require.ErrorAssertionFunc + }{ + { + name: "observer set with duplicate observer", + observer: types.ObserverSet{ObserverList: []string{observer1Address, observer1Address}}, + wantErr: func(t require.TestingT, err error, i ...interface{}) { + require.ErrorIs(t, err, types.ErrDuplicateObserver) + }, + }, + { + name: "observer set with invalid observer", + observer: types.ObserverSet{ObserverList: []string{"invalid"}}, + wantErr: func(t require.TestingT, err error, i ...interface{}) { + require.ErrorContains(t, err, "decoding bech32 failed") + }, + }, + { + name: "observer set with valid observer", + observer: types.ObserverSet{ObserverList: []string{observer1Address}}, + wantErr: require.NoError, + }, + } + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + tc.wantErr(t, tc.observer.Validate()) + }) - require.Equal(t, int(4), observerSet.Len()) - require.Equal(t, uint64(4), observerSet.LenUint()) - err := observerSet.Validate() - require.NoError(t, err) - - observerSet.ObserverList[0] = "invalid" - err = observerSet.Validate() - require.Error(t, err) + } } func TestCheckReceiveStatus(t *testing.T) {