From 027a00eb26a3899f3b4a8d67732c0edaf2d134a3 Mon Sep 17 00:00:00 2001 From: Cyson Date: Wed, 16 Nov 2022 17:13:24 -0800 Subject: [PATCH] Remove memstate for order cancellation --- x/dex/cache/cache.go | 50 ++++++----- x/dex/cache/cache_test.go | 5 ++ x/dex/cache/cancel.go | 62 +++++++++++--- x/dex/cache/cancel_test.go | 82 ++++++++++++++----- x/dex/contract/whitelist.go | 1 + .../msgserver/msg_server_cancel_orders.go | 9 +- x/dex/types/cancel.go | 5 -- x/dex/types/keys.go | 8 ++ 8 files changed, 155 insertions(+), 67 deletions(-) delete mode 100644 x/dex/types/cancel.go diff --git a/x/dex/cache/cache.go b/x/dex/cache/cache.go index eb51bc85f9..7102e23046 100644 --- a/x/dex/cache/cache.go +++ b/x/dex/cache/cache.go @@ -66,23 +66,13 @@ func (i *memStateItems[T]) Copy() *memStateItems[T] { } type MemState struct { - storeKey sdk.StoreKey - blockCancels *datastructures.TypedNestedSyncMap[ - typesutils.ContractAddress, - typesutils.PairString, - *BlockCancellations, - ] + storeKey sdk.StoreKey depositInfo *datastructures.TypedSyncMap[typesutils.ContractAddress, *DepositInfo] } func NewMemState(storeKey sdk.StoreKey) *MemState { return &MemState{ - storeKey: storeKey, - blockCancels: datastructures.NewTypedNestedSyncMap[ - typesutils.ContractAddress, - typesutils.PairString, - *BlockCancellations, - ](), + storeKey: storeKey, depositInfo: datastructures.NewTypedSyncMap[typesutils.ContractAddress, *DepositInfo](), } } @@ -123,8 +113,14 @@ func (s *MemState) GetBlockOrders(ctx sdk.Context, contractAddr typesutils.Contr func (s *MemState) GetBlockCancels(ctx sdk.Context, contractAddr typesutils.ContractAddress, pair typesutils.PairString) *BlockCancellations { s.SynchronizeAccess(ctx, contractAddr) - cancelsForPair, _ := s.blockCancels.LoadOrStoreNested(contractAddr, pair, NewCancels()) - return cancelsForPair + return NewCancels( + prefix.NewStore( + ctx.KVStore(s.storeKey), + types.MemCancelPrefixForPair( + string(contractAddr), string(pair), + ), + ), + ) } func (s *MemState) GetDepositInfo(ctx sdk.Context, contractAddr typesutils.ContractAddress) *DepositInfo { @@ -139,22 +135,26 @@ func (s *MemState) GetDepositInfo(ctx sdk.Context, contractAddr typesutils.Contr func (s *MemState) Clear(ctx sdk.Context) { DeepDelete(ctx.KVStore(s.storeKey), types.KeyPrefix(types.MemOrderKey), func(_ []byte) bool { return true }) - s.blockCancels = datastructures.NewTypedNestedSyncMap[ - typesutils.ContractAddress, - typesutils.PairString, - *BlockCancellations, - ]() + DeepDelete(ctx.KVStore(s.storeKey), types.KeyPrefix(types.MemCancelKey), func(_ []byte) bool { return true }) DeepDelete(ctx.KVStore(s.storeKey), types.KeyPrefix(types.MemDepositKey), func(_ []byte) bool { return true }) } func (s *MemState) ClearCancellationForPair(ctx sdk.Context, contractAddr typesutils.ContractAddress, pair typesutils.PairString) { s.SynchronizeAccess(ctx, contractAddr) - s.blockCancels.StoreNested(contractAddr, pair, NewCancels()) + DeepDelete(ctx.KVStore(s.storeKey), types.KeyPrefix(types.MemCancelKey), func(v []byte) bool { + var c types.Cancellation + if err := c.Unmarshal(v); err != nil { + panic(err) + } + return c.ContractAddr == string(contractAddr) && typesutils.GetPairString(&types.Pair{ + AssetDenom: c.AssetDenom, + PriceDenom: c.PriceDenom, + }) == pair + }) } func (s *MemState) DeepCopy() *MemState { copy := NewMemState(s.storeKey) - copy.blockCancels = s.blockCancels.DeepCopy(func(o *BlockCancellations) *BlockCancellations { return o.Copy() }) return copy } @@ -166,7 +166,13 @@ func (s *MemState) DeepFilterAccount(ctx sdk.Context, account string) { } return o.Account == account }) - s.blockCancels.DeepApply(func(o *BlockCancellations) { o.FilterByAccount(account) }) + DeepDelete(ctx.KVStore(s.storeKey), types.KeyPrefix(types.MemCancelKey), func(v []byte) bool { + var c types.Cancellation + if err := c.Unmarshal(v); err != nil { + panic(err) + } + return c.Creator == account + }) DeepDelete(ctx.KVStore(s.storeKey), types.KeyPrefix(types.MemDepositKey), func(v []byte) bool { var d types.DepositInfoEntry if err := d.Unmarshal(v); err != nil { diff --git a/x/dex/cache/cache_test.go b/x/dex/cache/cache_test.go index 0f39df5a77..840ad4349b 100644 --- a/x/dex/cache/cache_test.go +++ b/x/dex/cache/cache_test.go @@ -82,8 +82,13 @@ func TestClear(t *testing.T) { Account: "test", ContractAddr: TEST_CONTRACT, }) + stateOne.GetBlockCancels(ctx, utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Add(&types.Cancellation{ + Id: 2, + ContractAddr: TEST_CONTRACT, + }) stateOne.Clear(ctx) require.Equal(t, 0, len(stateOne.GetBlockOrders(ctx, utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Get())) + require.Equal(t, 0, len(stateOne.GetBlockCancels(ctx, utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Get())) } func TestSynchronization(t *testing.T) { diff --git a/x/dex/cache/cancel.go b/x/dex/cache/cancel.go index 95fa2d748d..3036bc242e 100644 --- a/x/dex/cache/cancel.go +++ b/x/dex/cache/cancel.go @@ -1,29 +1,65 @@ package dex import ( - "github.com/sei-protocol/sei-chain/utils" + "encoding/binary" + + "github.com/cosmos/cosmos-sdk/store/prefix" + sdk "github.com/cosmos/cosmos-sdk/types" "github.com/sei-protocol/sei-chain/x/dex/types" ) type BlockCancellations struct { - memStateItems[*types.Cancellation] + cancelStore *prefix.Store +} + +func NewCancels(cancelStore prefix.Store) *BlockCancellations { + return &BlockCancellations{cancelStore: &cancelStore} } -func NewCancels() *BlockCancellations { - return &BlockCancellations{memStateItems: NewItems(utils.PtrCopier[types.Cancellation])} +func (o *BlockCancellations) Has(cancel *types.Cancellation) bool { + keybz := make([]byte, 8) + binary.BigEndian.PutUint64(keybz, cancel.Id) + return o.cancelStore.Has(keybz) } -func (o *BlockCancellations) Copy() *BlockCancellations { - return &BlockCancellations{memStateItems: *o.memStateItems.Copy()} +func (o *BlockCancellations) Get() (list []*types.Cancellation) { + iterator := sdk.KVStorePrefixIterator(o.cancelStore, []byte{}) + + defer iterator.Close() + + for ; iterator.Valid(); iterator.Next() { + var val types.Cancellation + if err := val.Unmarshal(iterator.Value()); err != nil { + panic(err) + } + list = append(list, &val) + } + + return } -func (o *BlockCancellations) GetIdsToCancel() []uint64 { - o.mu.Lock() - defer o.mu.Unlock() +func (o *BlockCancellations) GetIdsToCancel() (list []uint64) { + iterator := sdk.KVStorePrefixIterator(o.cancelStore, []byte{}) + + defer iterator.Close() + + for ; iterator.Valid(); iterator.Next() { + var val types.Cancellation + if err := val.Unmarshal(iterator.Value()); err != nil { + panic(err) + } + list = append(list, val.Id) + } + + return +} - res := []uint64{} - for _, cancel := range o.internal { - res = append(res, cancel.Id) +func (o *BlockCancellations) Add(newItem *types.Cancellation) { + keybz := make([]byte, 8) + binary.BigEndian.PutUint64(keybz, newItem.Id) + if valbz, err := newItem.Marshal(); err != nil { + panic(err) + } else { + o.cancelStore.Set(keybz, valbz) } - return res } diff --git a/x/dex/cache/cancel_test.go b/x/dex/cache/cancel_test.go index 67a0b44f08..b284c0d422 100644 --- a/x/dex/cache/cancel_test.go +++ b/x/dex/cache/cancel_test.go @@ -3,31 +3,75 @@ package dex_test import ( "testing" + keepertest "github.com/sei-protocol/sei-chain/testutil/keeper" dex "github.com/sei-protocol/sei-chain/x/dex/cache" "github.com/sei-protocol/sei-chain/x/dex/types" + "github.com/sei-protocol/sei-chain/x/dex/types/utils" "github.com/stretchr/testify/require" ) -func TestCancelCopy(t *testing.T) { - cancels := dex.NewCancels() - cancel := types.Cancellation{ - Id: 1, - Creator: "abc", - } - cancels.Add(&cancel) - copy := cancels.Copy() - copy.Get()[0].Id = 2 - require.Equal(t, uint64(1), cancel.Id) -} - func TestCancelGetIdsToCancel(t *testing.T) { - cancels := dex.NewCancels() - cancel := types.Cancellation{ - Id: 1, - Creator: "abc", - } - cancels.Add(&cancel) - ids := cancels.GetIdsToCancel() + keeper, ctx := keepertest.DexKeeper(t) + stateOne := dex.NewMemState(keeper.GetStoreKey()) + stateOne.GetBlockCancels(ctx, utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Add(&types.Cancellation{ + Id: 1, + Creator: "abc", + ContractAddr: TEST_CONTRACT, + }) + ids := stateOne.GetBlockCancels(ctx, utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).GetIdsToCancel() require.Equal(t, 1, len(ids)) require.Equal(t, uint64(1), ids[0]) } + +func TestCancelGetCancels(t *testing.T) { + keeper, ctx := keepertest.DexKeeper(t) + stateOne := dex.NewMemState(keeper.GetStoreKey()) + stateOne.GetBlockCancels(ctx, utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Add(&types.Cancellation{ + Id: 1, + Creator: "abc", + ContractAddr: TEST_CONTRACT, + }) + stateOne.GetBlockCancels(ctx, utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Add(&types.Cancellation{ + Id: 2, + Creator: "def", + ContractAddr: TEST_CONTRACT, + }) + stateOne.GetBlockCancels(ctx, utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Add(&types.Cancellation{ + Id: 3, + Creator: "efg", + ContractAddr: TEST_CONTRACT, + }) + stateOne.GetBlockCancels(ctx, utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Add(&types.Cancellation{ + Id: 4, + Creator: "efg", + ContractAddr: TEST_CONTRACT, + }) + + cancels := stateOne.GetBlockCancels(ctx, utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Get() + require.Equal(t, 4, len(cancels)) + require.True(t, stateOne.GetBlockCancels(ctx, utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Has(&types.Cancellation{ + Id: 1, + Creator: "abc", + ContractAddr: TEST_CONTRACT, + })) + require.True(t, stateOne.GetBlockCancels(ctx, utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Has(&types.Cancellation{ + Id: 2, + Creator: "def", + ContractAddr: TEST_CONTRACT, + })) + require.True(t, stateOne.GetBlockCancels(ctx, utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Has(&types.Cancellation{ + Id: 3, + Creator: "efg", + ContractAddr: TEST_CONTRACT, + })) + require.True(t, stateOne.GetBlockCancels(ctx, utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Has(&types.Cancellation{ + Id: 4, + Creator: "efg", + ContractAddr: TEST_CONTRACT, + })) + require.False(t, stateOne.GetBlockCancels(ctx, utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Has(&types.Cancellation{ + Id: 5, + Creator: "efg", + ContractAddr: TEST_CONTRACT, + })) +} diff --git a/x/dex/contract/whitelist.go b/x/dex/contract/whitelist.go index c423d48d2f..b083032cf6 100644 --- a/x/dex/contract/whitelist.go +++ b/x/dex/contract/whitelist.go @@ -23,6 +23,7 @@ var DexWhitelistedKeys = []string{ keeper.ContractPrefixKey, types.MemOrderKey, types.MemDepositKey, + types.MemCancelKey, } var WasmWhitelistedKeys = []string{ diff --git a/x/dex/keeper/msgserver/msg_server_cancel_orders.go b/x/dex/keeper/msgserver/msg_server_cancel_orders.go index d0dfbbe382..416e172464 100644 --- a/x/dex/keeper/msgserver/msg_server_cancel_orders.go +++ b/x/dex/keeper/msgserver/msg_server_cancel_orders.go @@ -36,14 +36,7 @@ func (k msgServer) CancelOrders(goCtx context.Context, msg *types.MsgCancelOrder pair := types.Pair{PriceDenom: cancellation.PriceDenom, AssetDenom: cancellation.AssetDenom} pairStr := typesutils.GetPairString(&pair) pairBlockCancellations := dexutils.GetMemState(ctx.Context()).GetBlockCancels(ctx, typesutils.ContractAddress(msg.GetContractAddr()), pairStr) - cancelledInCurrentBlock := false - for _, cancelInCurrentBlock := range pairBlockCancellations.Get() { - if cancelInCurrentBlock.Id == cancellation.Id { - cancelledInCurrentBlock = true - break - } - } - if !cancelledInCurrentBlock { + if !pairBlockCancellations.Has(cancellation) { // only cancel if it's not cancelled in a previous tx in the same block cancel := types.Cancellation{ Id: cancellation.Id, diff --git a/x/dex/types/cancel.go b/x/dex/types/cancel.go deleted file mode 100644 index bd4ec526c3..0000000000 --- a/x/dex/types/cancel.go +++ /dev/null @@ -1,5 +0,0 @@ -package types - -func (c *Cancellation) GetAccount() string { - return c.Creator -} diff --git a/x/dex/types/keys.go b/x/dex/types/keys.go index 6d65a090cc..348e0a06a1 100644 --- a/x/dex/types/keys.go +++ b/x/dex/types/keys.go @@ -144,6 +144,13 @@ func MemOrderPrefixForPair(contractAddr string, pairString string) []byte { ) } +func MemCancelPrefixForPair(contractAddr string, pairString string) []byte { + return append( + append(KeyPrefix(MemCancelKey), KeyPrefix(contractAddr)...), + []byte(pairString)..., + ) +} + func MemOrderPrefix(contractAddr string) []byte { return append(KeyPrefix(MemOrderKey), KeyPrefix(contractAddr)...) } @@ -181,4 +188,5 @@ const ( MemOrderKey = "MemOrder-" MemDepositKey = "MemDeposit-" + MemCancelKey = "MemCancel-" )