diff --git a/store/whitelist/cachemulti/store.go b/store/whitelist/cachemulti/store.go index dbfeeb3234..c1404a4fab 100644 --- a/store/whitelist/cachemulti/store.go +++ b/store/whitelist/cachemulti/store.go @@ -13,10 +13,10 @@ type sdkCacheMultiStore = storetypes.CacheMultiStore type Store struct { sdkCacheMultiStore - storeKeyToWriteWhitelist map[storetypes.StoreKey][]string + storeKeyToWriteWhitelist map[string][]string } -func NewStore(parent storetypes.CacheMultiStore, storeKeyToWriteWhitelist map[storetypes.StoreKey][]string) storetypes.CacheMultiStore { +func NewStore(parent storetypes.CacheMultiStore, storeKeyToWriteWhitelist map[string][]string) storetypes.CacheMultiStore { return &Store{ sdkCacheMultiStore: parent, storeKeyToWriteWhitelist: storeKeyToWriteWhitelist, @@ -29,8 +29,9 @@ func (cms Store) CacheMultiStore() storetypes.CacheMultiStore { func (cms Store) GetKVStore(key storetypes.StoreKey) storetypes.KVStore { rawKVStore := cms.sdkCacheMultiStore.GetKVStore(key) - if writeWhitelist, ok := cms.storeKeyToWriteWhitelist[key]; ok { + if writeWhitelist, ok := cms.storeKeyToWriteWhitelist[key.Name()]; ok { return kv.NewStore(rawKVStore, writeWhitelist) } - return rawKVStore + // whitelist nothing + return kv.NewStore(rawKVStore, []string{}) } diff --git a/store/whitelist/cachemulti/store_test.go b/store/whitelist/cachemulti/store_test.go index 42d3856d35..4f5de2037b 100644 --- a/store/whitelist/cachemulti/store_test.go +++ b/store/whitelist/cachemulti/store_test.go @@ -13,8 +13,8 @@ import ( var ( WhitelistedStoreKey = storetypes.NewKVStoreKey("whitelisted") NotWhitelistedStoreKey = storetypes.NewKVStoreKey("not-whitelisted") - TestStoreKeyToWriteWhitelist = map[storetypes.StoreKey][]string{ - WhitelistedStoreKey: {"foo"}, + TestStoreKeyToWriteWhitelist = map[string][]string{ + WhitelistedStoreKey.Name(): {"foo"}, } ) @@ -37,8 +37,8 @@ func TestWhitelistDisabledStore(t *testing.T) { multistore := store.NewTestCacheMultiStore(stores) whitelistMultistore := cachemulti.NewStore(multistore, TestStoreKeyToWriteWhitelist) kvStore := whitelistMultistore.GetKVStore(NotWhitelistedStoreKey) - require.NotPanics(t, func() { kvStore.Delete([]byte("bar")) }) - require.NotPanics(t, func() { kvStore.Delete([]byte("foo")) }) + require.Panics(t, func() { kvStore.Delete([]byte("bar")) }) + require.Panics(t, func() { kvStore.Delete([]byte("foo")) }) } func TestCacheStillWhitelist(t *testing.T) { diff --git a/store/whitelist/multi/store.go b/store/whitelist/multi/store.go index fdfa3ad81b..82e187eaf7 100644 --- a/store/whitelist/multi/store.go +++ b/store/whitelist/multi/store.go @@ -9,10 +9,10 @@ import ( type Store struct { storetypes.MultiStore - storeKeyToWriteWhitelist map[storetypes.StoreKey][]string + storeKeyToWriteWhitelist map[string][]string } -func NewStore(parent storetypes.MultiStore, storeKeyToWriteWhitelist map[storetypes.StoreKey][]string) storetypes.MultiStore { +func NewStore(parent storetypes.MultiStore, storeKeyToWriteWhitelist map[string][]string) storetypes.MultiStore { return &Store{ MultiStore: parent, storeKeyToWriteWhitelist: storeKeyToWriteWhitelist, @@ -25,8 +25,9 @@ func (cms Store) CacheMultiStore() storetypes.CacheMultiStore { func (cms Store) GetKVStore(key storetypes.StoreKey) storetypes.KVStore { rawKVStore := cms.MultiStore.GetKVStore(key) - if writeWhitelist, ok := cms.storeKeyToWriteWhitelist[key]; ok { + if writeWhitelist, ok := cms.storeKeyToWriteWhitelist[key.Name()]; ok { return kv.NewStore(rawKVStore, writeWhitelist) } - return rawKVStore + // whitelist nothing + return kv.NewStore(rawKVStore, []string{}) } diff --git a/store/whitelist/multi/store_test.go b/store/whitelist/multi/store_test.go index 81e3081e7c..c12645066d 100644 --- a/store/whitelist/multi/store_test.go +++ b/store/whitelist/multi/store_test.go @@ -13,8 +13,8 @@ import ( var ( WhitelistedStoreKey = storetypes.NewKVStoreKey("whitelisted") NotWhitelistedStoreKey = storetypes.NewKVStoreKey("not-whitelisted") - TestStoreKeyToWriteWhitelist = map[storetypes.StoreKey][]string{ - WhitelistedStoreKey: {"foo"}, + TestStoreKeyToWriteWhitelist = map[string][]string{ + WhitelistedStoreKey.Name(): {"foo"}, } ) @@ -37,8 +37,8 @@ func TestWhitelistDisabledStore(t *testing.T) { multistore := store.NewTestCacheMultiStore(stores) whitelistMultistore := multi.NewStore(multistore, TestStoreKeyToWriteWhitelist) kvStore := whitelistMultistore.GetKVStore(NotWhitelistedStoreKey) - require.NotPanics(t, func() { kvStore.Delete([]byte("bar")) }) - require.NotPanics(t, func() { kvStore.Delete([]byte("foo")) }) + require.Panics(t, func() { kvStore.Delete([]byte("bar")) }) + require.Panics(t, func() { kvStore.Delete([]byte("foo")) }) } func TestCacheStillWhitelist(t *testing.T) { diff --git a/utils/panic.go b/utils/panic.go new file mode 100644 index 0000000000..b711290180 --- /dev/null +++ b/utils/panic.go @@ -0,0 +1,28 @@ +package utils + +import ( + "fmt" + + "github.com/armon/go-metrics" + "github.com/cosmos/cosmos-sdk/telemetry" + sdk "github.com/cosmos/cosmos-sdk/types" +) + +func PanicHandler(recoverCallback func(any)) func() { + return func() { + if err := recover(); err != nil { + recoverCallback(err) + } + } +} + +func MetricsPanicCallback(err any, ctx sdk.Context, key string) { + ctx.Logger().Error(fmt.Sprintf("panic occurred during order matching for: %s", key)) + telemetry.IncrCounterWithLabels( + []string{key}, + 1, + []metrics.Label{ + telemetry.NewLabel("error", fmt.Sprintf("%s", err)), + }, + ) +} diff --git a/x/dex/cache/cache.go b/x/dex/cache/cache.go index 1172a6b08d..0bb318d034 100644 --- a/x/dex/cache/cache.go +++ b/x/dex/cache/cache.go @@ -1,10 +1,17 @@ package dex import ( + "fmt" + "time" + + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/sei-protocol/sei-chain/utils/datastructures" typesutils "github.com/sei-protocol/sei-chain/x/dex/types/utils" ) +const SynchronizationTimeoutInSeconds = 5 + type memStateItem interface { GetAccount() string } @@ -46,88 +53,145 @@ func (i *memStateItems[T]) Copy() *memStateItems[T] { } type MemState struct { - BlockOrders *datastructures.TypedNestedSyncMap[ + blockOrders *datastructures.TypedNestedSyncMap[ typesutils.ContractAddress, typesutils.PairString, *BlockOrders, ] - BlockCancels *datastructures.TypedNestedSyncMap[ + blockCancels *datastructures.TypedNestedSyncMap[ typesutils.ContractAddress, typesutils.PairString, *BlockCancellations, ] - DepositInfo *datastructures.TypedSyncMap[typesutils.ContractAddress, *DepositInfo] - LiquidationRequests *datastructures.TypedSyncMap[typesutils.ContractAddress, *LiquidationRequests] + depositInfo *datastructures.TypedSyncMap[typesutils.ContractAddress, *DepositInfo] + liquidationRequests *datastructures.TypedSyncMap[typesutils.ContractAddress, *LiquidationRequests] } func NewMemState() *MemState { return &MemState{ - BlockOrders: datastructures.NewTypedNestedSyncMap[ + blockOrders: datastructures.NewTypedNestedSyncMap[ typesutils.ContractAddress, typesutils.PairString, *BlockOrders, ](), - BlockCancels: datastructures.NewTypedNestedSyncMap[ + blockCancels: datastructures.NewTypedNestedSyncMap[ typesutils.ContractAddress, typesutils.PairString, *BlockCancellations, ](), - DepositInfo: datastructures.NewTypedSyncMap[typesutils.ContractAddress, *DepositInfo](), - LiquidationRequests: datastructures.NewTypedSyncMap[typesutils.ContractAddress, *LiquidationRequests](), + depositInfo: datastructures.NewTypedSyncMap[typesutils.ContractAddress, *DepositInfo](), + liquidationRequests: datastructures.NewTypedSyncMap[typesutils.ContractAddress, *LiquidationRequests](), } } -func (s *MemState) GetBlockOrders(contractAddr typesutils.ContractAddress, pair typesutils.PairString) *BlockOrders { - ordersForPair, _ := s.BlockOrders.LoadOrStoreNested(contractAddr, pair, NewOrders()) +func (s *MemState) GetAllBlockOrders(ctx sdk.Context, contractAddr typesutils.ContractAddress) *datastructures.TypedSyncMap[typesutils.PairString, *BlockOrders] { + s.SynchronizeAccess(ctx, contractAddr) + ordersMap, _ := s.blockOrders.LoadOrStore(contractAddr, datastructures.NewTypedSyncMap[typesutils.PairString, *BlockOrders]()) + return ordersMap +} + +func (s *MemState) GetBlockOrders(ctx sdk.Context, contractAddr typesutils.ContractAddress, pair typesutils.PairString) *BlockOrders { + s.SynchronizeAccess(ctx, contractAddr) + ordersForPair, _ := s.blockOrders.LoadOrStoreNested(contractAddr, pair, NewOrders()) return ordersForPair } -func (s *MemState) GetBlockCancels(contractAddr typesutils.ContractAddress, pair typesutils.PairString) *BlockCancellations { - cancelsForPair, _ := s.BlockCancels.LoadOrStoreNested(contractAddr, pair, NewCancels()) +func (s *MemState) GetBlockCancels(ctx sdk.Context, contractAddr typesutils.ContractAddress, pair typesutils.PairString) *BlockCancellations { + s.SynchronizeAccess(ctx, contractAddr) + cancelsForPair, _ := s.blockCancels.LoadOrStoreNested(contractAddr, pair, NewCancels()) return cancelsForPair } -func (s *MemState) GetDepositInfo(contractAddr typesutils.ContractAddress) *DepositInfo { - depositsForContract, _ := s.DepositInfo.LoadOrStore(contractAddr, NewDepositInfo()) +func (s *MemState) GetDepositInfo(ctx sdk.Context, contractAddr typesutils.ContractAddress) *DepositInfo { + s.SynchronizeAccess(ctx, contractAddr) + depositsForContract, _ := s.depositInfo.LoadOrStore(contractAddr, NewDepositInfo()) return depositsForContract } -func (s *MemState) GetLiquidationRequests(contractAddr typesutils.ContractAddress) *LiquidationRequests { - liquidationsForContract, _ := s.LiquidationRequests.LoadOrStore(contractAddr, NewLiquidationRequests()) +func (s *MemState) GetLiquidationRequests(ctx sdk.Context, contractAddr typesutils.ContractAddress) *LiquidationRequests { + s.SynchronizeAccess(ctx, contractAddr) + liquidationsForContract, _ := s.liquidationRequests.LoadOrStore(contractAddr, NewLiquidationRequests()) return liquidationsForContract } func (s *MemState) Clear() { - s.BlockOrders = datastructures.NewTypedNestedSyncMap[ + s.blockOrders = datastructures.NewTypedNestedSyncMap[ typesutils.ContractAddress, typesutils.PairString, *BlockOrders, ]() - s.BlockCancels = datastructures.NewTypedNestedSyncMap[ + s.blockCancels = datastructures.NewTypedNestedSyncMap[ typesutils.ContractAddress, typesutils.PairString, *BlockCancellations, ]() - s.DepositInfo = datastructures.NewTypedSyncMap[typesutils.ContractAddress, *DepositInfo]() - s.LiquidationRequests = datastructures.NewTypedSyncMap[typesutils.ContractAddress, *LiquidationRequests]() + s.depositInfo = datastructures.NewTypedSyncMap[typesutils.ContractAddress, *DepositInfo]() + s.liquidationRequests = datastructures.NewTypedSyncMap[typesutils.ContractAddress, *LiquidationRequests]() } -func (s *MemState) ClearCancellationForPair(contractAddr typesutils.ContractAddress, pair typesutils.PairString) { - s.BlockCancels.StoreNested(contractAddr, pair, NewCancels()) +func (s *MemState) ClearCancellationForPair(ctx sdk.Context, contractAddr typesutils.ContractAddress, pair typesutils.PairString) { + s.SynchronizeAccess(ctx, contractAddr) + s.blockCancels.StoreNested(contractAddr, pair, NewCancels()) } func (s *MemState) DeepCopy() *MemState { copy := NewMemState() - copy.BlockOrders = s.BlockOrders.DeepCopy(func(o *BlockOrders) *BlockOrders { return o.Copy() }) - copy.BlockCancels = s.BlockCancels.DeepCopy(func(o *BlockCancellations) *BlockCancellations { return o.Copy() }) - copy.DepositInfo = s.DepositInfo.DeepCopy(func(o *DepositInfo) *DepositInfo { return o.Copy() }) - copy.LiquidationRequests = s.LiquidationRequests.DeepCopy(func(o *LiquidationRequests) *LiquidationRequests { return o.Copy() }) + copy.blockOrders = s.blockOrders.DeepCopy(func(o *BlockOrders) *BlockOrders { return o.Copy() }) + copy.blockCancels = s.blockCancels.DeepCopy(func(o *BlockCancellations) *BlockCancellations { return o.Copy() }) + copy.depositInfo = s.depositInfo.DeepCopy(func(o *DepositInfo) *DepositInfo { return o.Copy() }) + copy.liquidationRequests = s.liquidationRequests.DeepCopy(func(o *LiquidationRequests) *LiquidationRequests { return o.Copy() }) return copy } func (s *MemState) DeepFilterAccount(account string) { - s.BlockOrders.DeepApply(func(o *BlockOrders) { o.FilterByAccount(account) }) - s.BlockCancels.DeepApply(func(o *BlockCancellations) { o.FilterByAccount(account) }) - s.DepositInfo.DeepApply(func(o *DepositInfo) { o.FilterByAccount(account) }) - s.LiquidationRequests.DeepApply(func(o *LiquidationRequests) { o.FilterByAccount(account) }) + s.blockOrders.DeepApply(func(o *BlockOrders) { o.FilterByAccount(account) }) + s.blockCancels.DeepApply(func(o *BlockCancellations) { o.FilterByAccount(account) }) + s.depositInfo.DeepApply(func(o *DepositInfo) { o.FilterByAccount(account) }) + s.liquidationRequests.DeepApply(func(o *LiquidationRequests) { o.FilterByAccount(account) }) +} + +func (s *MemState) SynchronizeAccess(ctx sdk.Context, contractAddr typesutils.ContractAddress) { + executingContract := GetExecutingContract(ctx) + if executingContract == nil { + // not accessed by contract. no need to synchronize + return + } + targetContractAddr := string(contractAddr) + if executingContract.ContractAddr == targetContractAddr { + // access by the contract itself does not need synchronization + return + } + for _, dependency := range executingContract.Dependencies { + if dependency.Dependency != targetContractAddr { + continue + } + terminationSignals := GetTerminationSignals(ctx) + if terminationSignals == nil { + // synchronization should fail in the case of no termination signal to prevent race conditions. + panic("no termination signal map found in context") + } + targetChannel, ok := terminationSignals.Load(dependency.ImmediateElderSibling) + if !ok { + // synchronization should fail in the case of no termination signal to prevent race conditions. + panic(fmt.Sprintf("no termination signal channel for contract %s in context", dependency.ImmediateElderSibling)) + } + + select { + case <-targetChannel: + // since buffered channel can only be consumed once, we need to + // requeue so that it can unblock other goroutines that waits for + // the same channel. + targetChannel <- struct{}{} + case <-time.After(SynchronizationTimeoutInSeconds * time.Second): + // synchronization should fail in the case of timeout to prevent race conditions. + panic(fmt.Sprintf("timing out waiting for termination of %s", dependency.ImmediateElderSibling)) + } + + return + } + + // fail loudly so that the offending contract can be rolled back. + // eventually we will automatically de-register contracts that have to be rolled back + // so that this would not become a point of attack in terms of performance. + panic(fmt.Sprintf("Contract %s trying to access state of %s which is not a registered dependency", executingContract.ContractAddr, targetContractAddr)) } diff --git a/x/dex/cache/cache_test.go b/x/dex/cache/cache_test.go index c135d16c43..4fc7925682 100644 --- a/x/dex/cache/cache_test.go +++ b/x/dex/cache/cache_test.go @@ -1,8 +1,11 @@ package dex_test import ( + "context" "testing" + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/sei-protocol/sei-chain/utils/datastructures" dex "github.com/sei-protocol/sei-chain/x/dex/cache" "github.com/sei-protocol/sei-chain/x/dex/types" "github.com/sei-protocol/sei-chain/x/dex/types/utils" @@ -15,67 +18,112 @@ const ( ) func TestDeepCopy(t *testing.T) { + ctx := sdk.Context{} stateOne := dex.NewMemState() - stateOne.GetBlockOrders(utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Add(&types.Order{ + stateOne.GetBlockOrders(ctx, utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Add(&types.Order{ Id: 1, Account: "test", ContractAddr: TEST_CONTRACT, }) stateTwo := stateOne.DeepCopy() - stateTwo.GetBlockOrders(utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Add(&types.Order{ + stateTwo.GetBlockOrders(ctx, utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Add(&types.Order{ Id: 2, Account: "test", ContractAddr: TEST_CONTRACT, }) // old state must not be changed - require.Equal(t, 1, len(stateOne.GetBlockOrders(utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Get())) + require.Equal(t, 1, len(stateOne.GetBlockOrders(ctx, utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Get())) // new state must be changed - require.Equal(t, 2, len(stateTwo.GetBlockOrders(utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Get())) + require.Equal(t, 2, len(stateTwo.GetBlockOrders(ctx, utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Get())) } func TestDeepFilterAccounts(t *testing.T) { + ctx := sdk.Context{} stateOne := dex.NewMemState() - stateOne.GetBlockOrders(utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Add(&types.Order{ + stateOne.GetBlockOrders(ctx, utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Add(&types.Order{ Id: 1, Account: "test", ContractAddr: TEST_CONTRACT, }) - stateOne.GetBlockOrders(utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Add(&types.Order{ + stateOne.GetBlockOrders(ctx, utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Add(&types.Order{ Id: 2, Account: "test2", ContractAddr: TEST_CONTRACT, }) - stateOne.GetBlockCancels(utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Add(&types.Cancellation{ + stateOne.GetBlockCancels(ctx, utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Add(&types.Cancellation{ Id: 1, Creator: "test", }) - stateOne.GetBlockCancels(utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Add(&types.Cancellation{ + stateOne.GetBlockCancels(ctx, utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Add(&types.Cancellation{ Id: 2, Creator: "test2", }) - stateOne.GetDepositInfo(utils.ContractAddress(TEST_CONTRACT)).Add(&dex.DepositInfoEntry{ + stateOne.GetDepositInfo(ctx, utils.ContractAddress(TEST_CONTRACT)).Add(&dex.DepositInfoEntry{ Creator: "test", }) - stateOne.GetDepositInfo(utils.ContractAddress(TEST_CONTRACT)).Add(&dex.DepositInfoEntry{ + stateOne.GetDepositInfo(ctx, utils.ContractAddress(TEST_CONTRACT)).Add(&dex.DepositInfoEntry{ Creator: "test2", }) - stateOne.GetLiquidationRequests(utils.ContractAddress(TEST_CONTRACT)).Add(&dex.LiquidationRequest{Requestor: "test", AccountToLiquidate: ""}) - stateOne.GetLiquidationRequests(utils.ContractAddress(TEST_CONTRACT)).Add(&dex.LiquidationRequest{Requestor: "test2", AccountToLiquidate: ""}) + stateOne.GetLiquidationRequests(ctx, utils.ContractAddress(TEST_CONTRACT)).Add(&dex.LiquidationRequest{Requestor: "test", AccountToLiquidate: ""}) + stateOne.GetLiquidationRequests(ctx, utils.ContractAddress(TEST_CONTRACT)).Add(&dex.LiquidationRequest{Requestor: "test2", AccountToLiquidate: ""}) stateOne.DeepFilterAccount("test") - require.Equal(t, 1, stateOne.BlockOrders.Len()) - require.Equal(t, 1, stateOne.BlockCancels.Len()) - require.Equal(t, 1, stateOne.DepositInfo.Len()) - require.Equal(t, 1, stateOne.LiquidationRequests.Len()) + require.Equal(t, 1, stateOne.GetAllBlockOrders(ctx, utils.ContractAddress(TEST_CONTRACT)).Len()) + require.Equal(t, 1, len(stateOne.GetBlockCancels(ctx, utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Get())) + require.Equal(t, 1, len(stateOne.GetDepositInfo(ctx, utils.ContractAddress(TEST_CONTRACT)).Get())) + require.Equal(t, 1, len(stateOne.GetLiquidationRequests(ctx, utils.ContractAddress(TEST_CONTRACT)).Get())) } func TestClear(t *testing.T) { + ctx := sdk.Context{} stateOne := dex.NewMemState() - stateOne.GetBlockOrders(utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Add(&types.Order{ + stateOne.GetBlockOrders(ctx, utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Add(&types.Order{ Id: 1, Account: "test", ContractAddr: TEST_CONTRACT, }) stateOne.Clear() - require.Equal(t, 0, len(stateOne.GetBlockOrders(utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Get())) + require.Equal(t, 0, len(stateOne.GetBlockOrders(ctx, utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Get())) +} + +func TestSynchronization(t *testing.T) { + ctx := sdk.Context{} + stateOne := dex.NewMemState() + targetContract := utils.ContractAddress(TEST_CONTRACT) + // no go context + require.NotPanics(t, func() { stateOne.SynchronizeAccess(ctx, targetContract) }) + // no executing contract + goCtx := context.Background() + ctx = ctx.WithContext(goCtx) + require.NotPanics(t, func() { stateOne.SynchronizeAccess(ctx, targetContract) }) + // executing contract same as target contract + executingContract := types.ContractInfo{ContractAddr: TEST_CONTRACT} + ctx = ctx.WithContext(context.WithValue(goCtx, dex.CtxKeyExecutingContract, executingContract)) + require.NotPanics(t, func() { stateOne.SynchronizeAccess(ctx, targetContract) }) + // executing contract attempting to access non-dependency + executingContract.ContractAddr = "executing" + ctx = ctx.WithContext(context.WithValue(goCtx, dex.CtxKeyExecutingContract, executingContract)) + require.Panics(t, func() { stateOne.SynchronizeAccess(ctx, targetContract) }) + // no termination map + executingContract.Dependencies = []*types.ContractDependencyInfo{ + {Dependency: TEST_CONTRACT, ImmediateElderSibling: "elder"}, + } + ctx = ctx.WithContext(context.WithValue(goCtx, dex.CtxKeyExecutingContract, executingContract)) + require.Panics(t, func() { stateOne.SynchronizeAccess(ctx, targetContract) }) + // no termination signal channel for sibling + terminationSignals := datastructures.NewTypedSyncMap[string, chan struct{}]() + goCtx = context.WithValue(goCtx, dex.CtxKeyExecutingContract, executingContract) + goCtx = context.WithValue(goCtx, dex.CtxKeyExecTermSignal, terminationSignals) + ctx = ctx.WithContext(goCtx) + require.Panics(t, func() { stateOne.SynchronizeAccess(ctx, targetContract) }) + // termination signal times out + siblingChan := make(chan struct{}, 1) + terminationSignals.Store("elder", siblingChan) + require.Panics(t, func() { stateOne.SynchronizeAccess(ctx, targetContract) }) + // termination signal sent + go func() { + siblingChan <- struct{}{} + }() + require.NotPanics(t, func() { stateOne.SynchronizeAccess(ctx, targetContract) }) + <-siblingChan // the channel should be re-populated } diff --git a/x/dex/cache/context.go b/x/dex/cache/context.go new file mode 100644 index 0000000000..ea6d25a74f --- /dev/null +++ b/x/dex/cache/context.go @@ -0,0 +1,44 @@ +package dex + +import ( + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/sei-protocol/sei-chain/utils/datastructures" + "github.com/sei-protocol/sei-chain/x/dex/types" +) + +type CtxKeyType string + +const ( + CtxKeyExecTermSignal = CtxKeyType("execution-termination-signals") + CtxKeyExecutingContract = CtxKeyType("executing-contract") +) + +func GetExecutingContract(ctx sdk.Context) *types.ContractInfo { + if ctx.Context() == nil { + return nil + } + executingContract := ctx.Context().Value(CtxKeyExecutingContract) + if executingContract == nil { + return nil + } + contract, ok := executingContract.(types.ContractInfo) + if !ok { + return nil + } + return &contract +} + +func GetTerminationSignals(ctx sdk.Context) *datastructures.TypedSyncMap[string, chan struct{}] { + if ctx.Context() == nil { + return nil + } + signals := ctx.Context().Value(CtxKeyExecTermSignal) + if signals == nil { + return nil + } + typedSignals, ok := signals.(*datastructures.TypedSyncMap[string, chan struct{}]) + if !ok { + return nil + } + return typedSignals +} diff --git a/x/dex/cache/order_test.go b/x/dex/cache/order_test.go index fc1948a597..908f604a2e 100644 --- a/x/dex/cache/order_test.go +++ b/x/dex/cache/order_test.go @@ -23,20 +23,22 @@ func TestOrderFilterByAccount(t *testing.T) { } func TestMarkFailedToPlaceByAccounts(t *testing.T) { + ctx := sdk.Context{} stateOne := dex.NewMemState() - stateOne.GetBlockOrders(utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Add(&types.Order{ + stateOne.GetBlockOrders(ctx, utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Add(&types.Order{ Id: 1, Account: "test", ContractAddr: TEST_CONTRACT, }) - stateOne.GetBlockOrders(utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).MarkFailedToPlaceByAccounts([]string{"test"}) + stateOne.GetBlockOrders(ctx, utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).MarkFailedToPlaceByAccounts([]string{"test"}) require.Equal(t, types.OrderStatus_FAILED_TO_PLACE, - stateOne.GetBlockOrders(utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Get()[0].Status) + stateOne.GetBlockOrders(ctx, utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Get()[0].Status) } func TestMarkFailedToPlace(t *testing.T) { + ctx := sdk.Context{} stateOne := dex.NewMemState() - stateOne.GetBlockOrders(utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Add(&types.Order{ + stateOne.GetBlockOrders(ctx, utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Add(&types.Order{ Id: 1, Account: "test", ContractAddr: TEST_CONTRACT, @@ -45,16 +47,17 @@ func TestMarkFailedToPlace(t *testing.T) { ID: 1, Reason: "some reason", } - stateOne.GetBlockOrders(utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).MarkFailedToPlace([]wasm.UnsuccessfulOrder{unsuccessfulOrder}) + stateOne.GetBlockOrders(ctx, utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).MarkFailedToPlace([]wasm.UnsuccessfulOrder{unsuccessfulOrder}) require.Equal(t, types.OrderStatus_FAILED_TO_PLACE, - stateOne.GetBlockOrders(utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Get()[0].Status) + stateOne.GetBlockOrders(ctx, utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Get()[0].Status) require.Equal(t, "some reason", - stateOne.GetBlockOrders(utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Get()[0].StatusDescription) + stateOne.GetBlockOrders(ctx, utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Get()[0].StatusDescription) } func TestGetSortedMarketOrders(t *testing.T) { + ctx := sdk.Context{} stateOne := dex.NewMemState() - stateOne.GetBlockOrders(utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Add(&types.Order{ + stateOne.GetBlockOrders(ctx, utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Add(&types.Order{ Id: 1, Account: "test", ContractAddr: TEST_CONTRACT, @@ -62,7 +65,7 @@ func TestGetSortedMarketOrders(t *testing.T) { OrderType: types.OrderType_LIQUIDATION, Price: sdk.MustNewDecFromStr("150"), }) - stateOne.GetBlockOrders(utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Add(&types.Order{ + stateOne.GetBlockOrders(ctx, utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Add(&types.Order{ Id: 2, Account: "test", ContractAddr: TEST_CONTRACT, @@ -70,7 +73,7 @@ func TestGetSortedMarketOrders(t *testing.T) { OrderType: types.OrderType_MARKET, Price: sdk.MustNewDecFromStr("100"), }) - stateOne.GetBlockOrders(utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Add(&types.Order{ + stateOne.GetBlockOrders(ctx, utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Add(&types.Order{ Id: 3, Account: "test", ContractAddr: TEST_CONTRACT, @@ -78,7 +81,7 @@ func TestGetSortedMarketOrders(t *testing.T) { OrderType: types.OrderType_MARKET, Price: sdk.MustNewDecFromStr("0"), }) - stateOne.GetBlockOrders(utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Add(&types.Order{ + stateOne.GetBlockOrders(ctx, utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Add(&types.Order{ Id: 4, Account: "test", ContractAddr: TEST_CONTRACT, @@ -86,7 +89,7 @@ func TestGetSortedMarketOrders(t *testing.T) { OrderType: types.OrderType_LIQUIDATION, Price: sdk.MustNewDecFromStr("100"), }) - stateOne.GetBlockOrders(utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Add(&types.Order{ + stateOne.GetBlockOrders(ctx, utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Add(&types.Order{ Id: 5, Account: "test", ContractAddr: TEST_CONTRACT, @@ -94,7 +97,7 @@ func TestGetSortedMarketOrders(t *testing.T) { OrderType: types.OrderType_MARKET, Price: sdk.MustNewDecFromStr("80"), }) - stateOne.GetBlockOrders(utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Add(&types.Order{ + stateOne.GetBlockOrders(ctx, utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Add(&types.Order{ Id: 6, Account: "test", ContractAddr: TEST_CONTRACT, @@ -102,7 +105,7 @@ func TestGetSortedMarketOrders(t *testing.T) { OrderType: types.OrderType_MARKET, Price: sdk.MustNewDecFromStr("0"), }) - stateOne.GetBlockOrders(utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Add(&types.Order{ + stateOne.GetBlockOrders(ctx, utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Add(&types.Order{ Id: 7, Account: "test", ContractAddr: TEST_CONTRACT, @@ -110,7 +113,7 @@ func TestGetSortedMarketOrders(t *testing.T) { OrderType: types.OrderType_LIMIT, Price: sdk.MustNewDecFromStr("100"), }) - stateOne.GetBlockOrders(utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Add(&types.Order{ + stateOne.GetBlockOrders(ctx, utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).Add(&types.Order{ Id: 8, Account: "test", ContractAddr: TEST_CONTRACT, @@ -120,7 +123,7 @@ func TestGetSortedMarketOrders(t *testing.T) { }) marketBuysWithLiquidation := stateOne.GetBlockOrders( - utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).GetSortedMarketOrders( + ctx, utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).GetSortedMarketOrders( types.PositionDirection_LONG, true, ) require.Equal(t, uint64(3), marketBuysWithLiquidation[0].Id) @@ -128,14 +131,14 @@ func TestGetSortedMarketOrders(t *testing.T) { require.Equal(t, uint64(2), marketBuysWithLiquidation[2].Id) marketBuysWithoutLiquidation := stateOne.GetBlockOrders( - utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).GetSortedMarketOrders( + ctx, utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).GetSortedMarketOrders( types.PositionDirection_LONG, false, ) require.Equal(t, uint64(3), marketBuysWithoutLiquidation[0].Id) require.Equal(t, uint64(2), marketBuysWithoutLiquidation[1].Id) marketSellsWithLiquidation := stateOne.GetBlockOrders( - utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).GetSortedMarketOrders( + ctx, utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).GetSortedMarketOrders( types.PositionDirection_SHORT, true, ) require.Equal(t, uint64(6), marketSellsWithLiquidation[0].Id) @@ -143,7 +146,7 @@ func TestGetSortedMarketOrders(t *testing.T) { require.Equal(t, uint64(4), marketSellsWithLiquidation[2].Id) marketSellsWithoutLiquidation := stateOne.GetBlockOrders( - utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).GetSortedMarketOrders( + ctx, utils.ContractAddress(TEST_CONTRACT), utils.PairString(TEST_PAIR)).GetSortedMarketOrders( types.PositionDirection_SHORT, false, ) require.Equal(t, uint64(6), marketSellsWithoutLiquidation[0].Id) diff --git a/x/dex/contract/abci.go b/x/dex/contract/abci.go new file mode 100644 index 0000000000..3c3ee42542 --- /dev/null +++ b/x/dex/contract/abci.go @@ -0,0 +1,183 @@ +package contract + +import ( + "context" + "fmt" + "sync" + + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/sei-protocol/sei-chain/store/whitelist/multi" + "github.com/sei-protocol/sei-chain/utils" + "github.com/sei-protocol/sei-chain/utils/datastructures" + dexcache "github.com/sei-protocol/sei-chain/x/dex/cache" + "github.com/sei-protocol/sei-chain/x/dex/keeper" + dexkeeperabci "github.com/sei-protocol/sei-chain/x/dex/keeper/abci" + dexkeeperutils "github.com/sei-protocol/sei-chain/x/dex/keeper/utils" + "github.com/sei-protocol/sei-chain/x/dex/types" + dextypeswasm "github.com/sei-protocol/sei-chain/x/dex/types/wasm" + "github.com/sei-protocol/sei-chain/x/store" + otrace "go.opentelemetry.io/otel/trace" +) + +type environment struct { + validContractsInfo []types.ContractInfo + failedContractAddresses datastructures.SyncSet[string] + finalizeBlockMessages *datastructures.TypedSyncMap[string, *dextypeswasm.SudoFinalizeBlockMsg] + settlementsByContract *datastructures.TypedSyncMap[string, []*types.SettlementEntry] + executionTerminationSignals *datastructures.TypedSyncMap[string, chan struct{}] + + finalizeMsgMutex *sync.Mutex +} + +func EndBlockerAtomic(ctx sdk.Context, keeper *keeper.Keeper, validContractsInfo []types.ContractInfo, tracer *otrace.Tracer) ([]types.ContractInfo, bool) { + env := newEnv(validContractsInfo) + ctx, msCached := cacheAndDecorateContext(ctx, env) + memStateCopy := keeper.MemState.DeepCopy() + + handleDeposits(ctx, env, keeper, tracer) + + runner := NewParallelRunner(func(contract types.ContractInfo) { + orderMatchingRunnable(ctx, env, keeper, contract, tracer) + }, validContractsInfo) + runner.Run() + + handleSettlements(ctx, env, keeper) + handleFinalizedBlocks(ctx, env, keeper) + + // No error is thrown for any contract. This should happen most of the time. + if env.failedContractAddresses.Size() == 0 { + msCached.Write() + return env.validContractsInfo, true + } + // restore keeper in-memory state + *keeper.MemState = *memStateCopy + + return filterNewValidContracts(env, keeper), false +} + +func newEnv(validContractsInfo []types.ContractInfo) *environment { + finalizeBlockMessages := datastructures.NewTypedSyncMap[string, *dextypeswasm.SudoFinalizeBlockMsg]() + settlementsByContract := datastructures.NewTypedSyncMap[string, []*types.SettlementEntry]() + executionTerminationSignals := datastructures.NewTypedSyncMap[string, chan struct{}]() + for _, contract := range validContractsInfo { + finalizeBlockMessages.Store(contract.ContractAddr, dextypeswasm.NewSudoFinalizeBlockMsg()) + settlementsByContract.Store(contract.ContractAddr, []*types.SettlementEntry{}) + executionTerminationSignals.Store(contract.ContractAddr, make(chan struct{}, 1)) + } + return &environment{ + validContractsInfo: validContractsInfo, + failedContractAddresses: datastructures.NewSyncSet([]string{}), + finalizeBlockMessages: finalizeBlockMessages, + settlementsByContract: settlementsByContract, + executionTerminationSignals: executionTerminationSignals, + finalizeMsgMutex: &sync.Mutex{}, + } +} + +func cacheAndDecorateContext(ctx sdk.Context, env *environment) (sdk.Context, sdk.CacheMultiStore) { + cachedCtx, msCached := store.GetCachedContext(ctx) + goCtx := context.WithValue(cachedCtx.Context(), dexcache.CtxKeyExecTermSignal, env.executionTerminationSignals) + cachedCtx = cachedCtx.WithContext(goCtx) + return cachedCtx, msCached +} + +func decorateContextForContract(ctx sdk.Context, contractInfo types.ContractInfo) sdk.Context { + goCtx := context.WithValue(ctx.Context(), dexcache.CtxKeyExecutingContract, contractInfo) + return ctx.WithContext(goCtx).WithMultiStore(multi.NewStore(ctx.MultiStore(), GetWhitelistMap(contractInfo.ContractAddr))) +} + +func handleDeposits(ctx sdk.Context, env *environment, keeper *keeper.Keeper, tracer *otrace.Tracer) { + // Handle deposit sequentially since they mutate `bank` state which is shared by all contracts + keeperWrapper := dexkeeperabci.KeeperWrapper{Keeper: keeper} + for _, contract := range env.validContractsInfo { + if err := keeperWrapper.HandleEBDeposit(ctx.Context(), ctx, tracer, contract.ContractAddr); err != nil { + env.failedContractAddresses.Add(contract.ContractAddr) + } + } +} + +func handleSettlements(ctx sdk.Context, env *environment, keeper *keeper.Keeper) { + contractsNeedOrderMatching := datastructures.NewSyncSet([]string{}) + for _, contract := range env.validContractsInfo { + if contract.NeedOrderMatching { + contractsNeedOrderMatching.Add(contract.ContractAddr) + } + } + env.settlementsByContract.Range(func(contractAddr string, settlements []*types.SettlementEntry) bool { + if !contractsNeedOrderMatching.Contains(contractAddr) { + return true + } + if err := HandleSettlements(ctx, contractAddr, keeper, settlements); err != nil { + ctx.Logger().Error(fmt.Sprintf("Error handling settlements for %s", contractAddr)) + env.failedContractAddresses.Add(contractAddr) + } + return true + }) +} + +func handleFinalizedBlocks(ctx sdk.Context, env *environment, keeper *keeper.Keeper) { + contractsNeedHook := datastructures.NewSyncSet([]string{}) + for _, contract := range env.validContractsInfo { + if contract.NeedHook { + contractsNeedHook.Add(contract.ContractAddr) + } + } + env.finalizeBlockMessages.Range(func(contractAddr string, finalizeBlockMsg *dextypeswasm.SudoFinalizeBlockMsg) bool { + if !contractsNeedHook.Contains(contractAddr) { + return true + } + if _, err := dexkeeperutils.CallContractSudo(ctx, keeper, contractAddr, finalizeBlockMsg); err != nil { + ctx.Logger().Error(fmt.Sprintf("Error calling FinalizeBlock of %s", contractAddr)) + env.failedContractAddresses.Add(contractAddr) + } + return true + }) +} + +func orderMatchingRunnable(ctx sdk.Context, env *environment, keeper *keeper.Keeper, contractInfo types.ContractInfo, tracer *otrace.Tracer) { + defer utils.PanicHandler(func(err any) { orderMatchingRecoverCallback(err, ctx, env, contractInfo) })() + defer func() { + if channel, ok := env.executionTerminationSignals.Load(contractInfo.ContractAddr); ok { + channel <- struct{}{} + } + }() + + if !contractInfo.NeedOrderMatching { + return + } + ctx = decorateContextForContract(ctx, contractInfo) + ctx.Logger().Info(fmt.Sprintf("End block for %s", contractInfo.ContractAddr)) + if orderResultsMap, settlements, err := HandleExecutionForContract(ctx, contractInfo, keeper, tracer); err != nil { + ctx.Logger().Error(fmt.Sprintf("Error for EndBlock of %s", contractInfo.ContractAddr)) + env.failedContractAddresses.Add(contractInfo.ContractAddr) + } else { + for account, orderResults := range orderResultsMap { + // only add to finalize message for contract addresses + if msg, ok := env.finalizeBlockMessages.Load(account); ok { + env.finalizeMsgMutex.Lock() + msg.AddContractResult(orderResults) + env.finalizeMsgMutex.Unlock() + } + } + env.settlementsByContract.Store(contractInfo.ContractAddr, settlements) + } +} + +func orderMatchingRecoverCallback(err any, ctx sdk.Context, env *environment, contractInfo types.ContractInfo) { + utils.MetricsPanicCallback(err, ctx, fmt.Sprintf("%s%s", types.ModuleName, "endblockpanic")) + // idempotent + env.failedContractAddresses.Add(contractInfo.ContractAddr) +} + +func filterNewValidContracts(env *environment, keeper *keeper.Keeper) []types.ContractInfo { + newValidContracts := []types.ContractInfo{} + for _, contract := range env.validContractsInfo { + if !env.failedContractAddresses.Contains(contract.ContractAddr) { + newValidContracts = append(newValidContracts, contract) + } + } + for _, failedContractAddress := range env.failedContractAddresses.ToOrderedSlice(datastructures.StringComparator) { + keeper.MemState.DeepFilterAccount(failedContractAddress) + } + return newValidContracts +} diff --git a/x/dex/contract/execution.go b/x/dex/contract/execution.go index 52799761fd..3afd7b0ad7 100644 --- a/x/dex/contract/execution.go +++ b/x/dex/contract/execution.go @@ -1,10 +1,15 @@ package contract import ( + "fmt" + "sync" + sdk "github.com/cosmos/cosmos-sdk/types" "go.opentelemetry.io/otel/attribute" otrace "go.opentelemetry.io/otel/trace" + "github.com/sei-protocol/sei-chain/store/whitelist/multi" + "github.com/sei-protocol/sei-chain/utils" dexcache "github.com/sei-protocol/sei-chain/x/dex/cache" "github.com/sei-protocol/sei-chain/x/dex/exchange" "github.com/sei-protocol/sei-chain/x/dex/keeper" @@ -83,7 +88,7 @@ func cancelForPair( dexkeeper *keeper.Keeper, orderbook *types.OrderBook, ) { - cancels := dexkeeper.MemState.GetBlockCancels(typedContractAddr, typedPairStr) + cancels := dexkeeper.MemState.GetBlockCancels(ctx, typedContractAddr, typedPairStr) originalOrdersToCancel := dexkeeper.GetOrdersByIds(ctx, string(typedContractAddr), cancels.GetIdsToCancel()) exchange.CancelOrders(cancels.Get(), orderbook, originalOrdersToCancel) } @@ -95,7 +100,7 @@ func matchMarketOrderForPair( dexkeeper *keeper.Keeper, orderbook *types.OrderBook, ) exchange.ExecutionOutcome { - orders := dexkeeper.MemState.GetBlockOrders(typedContractAddr, typedPairStr) + orders := dexkeeper.MemState.GetBlockOrders(ctx, typedContractAddr, typedPairStr) marketBuys := orders.GetSortedMarketOrders(types.PositionDirection_LONG, true) marketSells := orders.GetSortedMarketOrders(types.PositionDirection_SHORT, true) marketBuyOutcome := exchange.MatchMarketOrders( @@ -120,7 +125,7 @@ func matchLimitOrderForPair( dexkeeper *keeper.Keeper, orderbook *types.OrderBook, ) exchange.ExecutionOutcome { - orders := dexkeeper.MemState.GetBlockOrders(typedContractAddr, typedPairStr) + orders := dexkeeper.MemState.GetBlockOrders(ctx, typedContractAddr, typedPairStr) limitBuys := orders.GetLimitOrders(types.PositionDirection_LONG) limitSells := orders.GetLimitOrders(types.PositionDirection_SHORT) return exchange.MatchLimitOrders( @@ -138,8 +143,8 @@ func UpdateOrderState( dexkeeper *keeper.Keeper, settlements []*types.SettlementEntry, ) { - orders := dexkeeper.MemState.GetBlockOrders(typedContractAddr, typedPairStr) - cancels := dexkeeper.MemState.GetBlockCancels(typedContractAddr, typedPairStr) + orders := dexkeeper.MemState.GetBlockOrders(ctx, typedContractAddr, typedPairStr) + cancels := dexkeeper.MemState.GetBlockCancels(ctx, typedContractAddr, typedPairStr) // First add any new order, whether successfully placed or not, to the store for _, order := range orders.Get() { if order.Quantity.IsZero() { @@ -160,19 +165,20 @@ func UpdateOrderState( } } // Finally update market order status based on execution result - for _, marketOrderID := range getUnfulfilledPlacedMarketOrderIds(typedContractAddr, typedPairStr, dexkeeper) { + for _, marketOrderID := range getUnfulfilledPlacedMarketOrderIds(ctx, typedContractAddr, typedPairStr, dexkeeper) { dexkeeper.UpdateOrderStatus(ctx, string(typedContractAddr), marketOrderID, types.OrderStatus_CANCELLED) } } func PrepareCancelUnfulfilledMarketOrders( + ctx sdk.Context, typedContractAddr dextypesutils.ContractAddress, typedPairStr dextypesutils.PairString, dexkeeper *keeper.Keeper, ) { - dexkeeper.MemState.ClearCancellationForPair(typedContractAddr, typedPairStr) - for _, marketOrderID := range getUnfulfilledPlacedMarketOrderIds(typedContractAddr, typedPairStr, dexkeeper) { - dexkeeper.MemState.GetBlockCancels(typedContractAddr, typedPairStr).Add(&types.Cancellation{ + dexkeeper.MemState.ClearCancellationForPair(ctx, typedContractAddr, typedPairStr) + for _, marketOrderID := range getUnfulfilledPlacedMarketOrderIds(ctx, typedContractAddr, typedPairStr, dexkeeper) { + dexkeeper.MemState.GetBlockCancels(ctx, typedContractAddr, typedPairStr).Add(&types.Cancellation{ Id: marketOrderID, Initiator: types.CancellationInitiator_USER, }) @@ -180,12 +186,13 @@ func PrepareCancelUnfulfilledMarketOrders( } func getUnfulfilledPlacedMarketOrderIds( + ctx sdk.Context, typedContractAddr dextypesutils.ContractAddress, typedPairStr dextypesutils.PairString, dexkeeper *keeper.Keeper, ) []uint64 { res := []uint64{} - for _, order := range dexkeeper.MemState.GetBlockOrders(typedContractAddr, typedPairStr).Get() { + for _, order := range dexkeeper.MemState.GetBlockOrders(ctx, typedContractAddr, typedPairStr).Get() { if order.Status == types.OrderStatus_FAILED_TO_PLACE { continue } @@ -198,6 +205,49 @@ func getUnfulfilledPlacedMarketOrderIds( return res } +func ExecutePairsInParallel(ctx sdk.Context, contractAddr string, dexkeeper *keeper.Keeper) ([]func(), []*types.SettlementEntry) { + typedContractAddr := dextypesutils.ContractAddress(contractAddr) + registeredPairs := dexkeeper.GetAllRegisteredPairs(ctx, contractAddr) + orderUpdaters := []func(){} + settlements := []*types.SettlementEntry{} + + mu := sync.Mutex{} + wg := sync.WaitGroup{} + anyPanicked := false + + for _, pair := range registeredPairs { + wg.Add(1) + + pair := pair + pairCtx := ctx.WithMultiStore(multi.NewStore(ctx.MultiStore(), GetPerPairWhitelistMap(contractAddr, pair))) + go func() { + defer wg.Done() + defer utils.PanicHandler(func(err any) { + anyPanicked = true + utils.MetricsPanicCallback(err, ctx, fmt.Sprintf("%s-%s|%s", contractAddr, pair.PriceDenom, pair.AssetDenom)) + })() + + pairCopy := pair + pairSettlements := ExecutePair(pairCtx, contractAddr, pair, dexkeeper) + PrepareCancelUnfulfilledMarketOrders(pairCtx, typedContractAddr, dextypesutils.GetPairString(&pairCopy), dexkeeper) + + mu.Lock() + defer mu.Unlock() + orderUpdaters = append(orderUpdaters, func() { + UpdateOrderState(ctx, typedContractAddr, dextypesutils.GetPairString(&pairCopy), dexkeeper, pairSettlements) + }) + settlements = append(settlements, pairSettlements...) + }() + } + wg.Wait() + if anyPanicked { + // need to re-throw panic to the top level goroutine + panic("panicked during pair execution") + } + + return orderUpdaters, settlements +} + func HandleExecutionForContract( ctx sdk.Context, contract types.ContractInfo, @@ -206,21 +256,17 @@ func HandleExecutionForContract( ) (map[string]dextypeswasm.ContractOrderResult, []*types.SettlementEntry, error) { contractAddr := contract.ContractAddr typedContractAddr := dextypesutils.ContractAddress(contractAddr) - registeredPairs := dexkeeper.GetAllRegisteredPairs(ctx, contractAddr) orderResults := map[string]dextypeswasm.ContractOrderResult{} - settlements := []*types.SettlementEntry{} + // Call contract hooks so that contracts can do internal bookkeeping if err := CallPreExecutionHooks(ctx, contractAddr, dexkeeper, tracer); err != nil { - return orderResults, settlements, err + return orderResults, []*types.SettlementEntry{}, err } - for _, pair := range registeredPairs { - pairCopy := pair - pairSettlements := ExecutePair(ctx, contractAddr, pair, dexkeeper) - UpdateOrderState(ctx, typedContractAddr, dextypesutils.GetPairString(&pairCopy), dexkeeper, pairSettlements) - PrepareCancelUnfulfilledMarketOrders(typedContractAddr, dextypesutils.GetPairString(&pairCopy), dexkeeper) + orderUpdaters, settlements := ExecutePairsInParallel(ctx, contractAddr, dexkeeper) - settlements = append(settlements, pairSettlements...) + for _, orderUpdater := range orderUpdaters { + orderUpdater() } // Cancel unfilled market orders if err := CancelUnfulfilledMarketOrders(ctx, contractAddr, dexkeeper, tracer); err != nil { @@ -228,12 +274,9 @@ func HandleExecutionForContract( } // populate order placement results for FinalizeBlock hook - contractOrdersMap, ok := dexkeeper.MemState.BlockOrders.Load(typedContractAddr) - if ok { - contractOrdersMap.DeepApply(func(orders *dexcache.BlockOrders) { - dextypeswasm.PopulateOrderPlacementResults(contractAddr, orders.Get(), orderResults) - }) - } + dexkeeper.MemState.GetAllBlockOrders(ctx, typedContractAddr).DeepApply(func(orders *dexcache.BlockOrders) { + dextypeswasm.PopulateOrderPlacementResults(contractAddr, orders.Get(), orderResults) + }) dextypeswasm.PopulateOrderExecutionResults(contractAddr, settlements, orderResults) return orderResults, settlements, nil } diff --git a/x/dex/contract/runner.go b/x/dex/contract/runner.go new file mode 100644 index 0000000000..6aae908278 --- /dev/null +++ b/x/dex/contract/runner.go @@ -0,0 +1,131 @@ +package contract + +import ( + "sync/atomic" + + "github.com/sei-protocol/sei-chain/utils/datastructures" + "github.com/sei-protocol/sei-chain/x/dex/types" + "github.com/sei-protocol/sei-chain/x/dex/types/utils" +) + +type ParallelRunner struct { + runnable func(contract types.ContractInfo) + + contractAddrToInfo *datastructures.TypedSyncMap[utils.ContractAddress, *types.ContractInfo] + readyContracts *datastructures.TypedSyncMap[utils.ContractAddress, struct{}] + readyCnt int64 + inProgressCnt int64 + someContractFinished chan struct{} +} + +func NewParallelRunner(runnable func(contract types.ContractInfo), contracts []types.ContractInfo) ParallelRunner { + contractAddrToInfo := datastructures.NewTypedSyncMap[utils.ContractAddress, *types.ContractInfo]() + contractsFrontier := datastructures.NewTypedSyncMap[utils.ContractAddress, struct{}]() + for _, contract := range contracts { + // runner will mutate ContractInfo fields + copy := contract + typedContractAddr := utils.ContractAddress(contract.ContractAddr) + contractAddrToInfo.Store(typedContractAddr, ©) + if copy.NumIncomingDependencies == 0 { + contractsFrontier.Store(typedContractAddr, struct{}{}) + } + } + return ParallelRunner{ + runnable: runnable, + contractAddrToInfo: contractAddrToInfo, + readyContracts: contractsFrontier, + readyCnt: int64(contractsFrontier.Len()), + inProgressCnt: 0, + someContractFinished: make(chan struct{}), + } +} + +// We define "frontier contract" as a contract which: +// 1. Has not finished running yet, and +// 2. either: +// a. has no other contracts depending on it, or +// b. for which all contracts that depend on it have already finished. +// Consequently, the set of frontier contracts will mutate throughout the +// `Run` method, until all contracts finish their runs. +// The key principle here is that at any moment, we can have all frontier +// contracts running concurrently, since there must be no ancestral +// relationships among them due to the definition above. +// The simplest implementation would be: +// ``` +// while there is any contract left: +// run all frontier contracts concurrently +// wait for all runs to finish +// update the frontier set +// ``` +// We can illustrate why this implementation is not optimal with the following +// example: +// Suppose we have four contracts, where A depends on B, and C depends on +// D. The run for A, B, C, D takes 5s, 5s, 8s, 2s, respectively. +// With the implementation above, the first iteration would take 8s since +// it runs A and C, and the second iteration would take 5s since it runs +// B and D. However C doesn't actually need to wait for B to finish, and +// if C runs immediately after A finishes, the whole process would take +// max(5 + 5, 8 + 2) = 10s, which is 3s faster than the implementation +// above. +// So we can optimize the implementation to be: +// ``` +// while there is any contract left: +// run all frontier contracts concurrently +// wait for any existing run (could be from previous iteration) to finish +// update the frontier set +// ``` +// With the example above, the whole process would take 3 iterations: +// Iter 1 (A, C run): 5s since it finishes when A finishes +// Iter 2 (B run): 3s since it finishes when C finishes +// Iter 3 (D run): 2s since it finishes when B, D finish +// +// The following `Run` method implements the pseudocode above. +func (r *ParallelRunner) Run() { + // The ordering of the two conditions below matters, since readyCnt + // is updated before inProgressCnt. + for r.inProgressCnt > 0 || r.readyCnt > 0 { + // r.readyContracts represent all frontier contracts that have + // not started running yet. + r.readyContracts.Range(func(key utils.ContractAddress, _ struct{}) bool { + atomic.AddInt64(&r.inProgressCnt, 1) + go r.wrapRunnable(key) + // Since the frontier contract has started running, we need + // to remove it from r.readyContracts so that it won't + // double-run. + r.readyContracts.Delete(key) + // The reason we use a separate readyCnt is because `sync.Map` + // doesn't provide an atomic way to get its length. + atomic.AddInt64(&r.readyCnt, -1) + return true + }) + // This corresponds to the "wait for any existing run (could be + // from previous iteration) to finish" part in the pseudocode above. + <-r.someContractFinished + } +} + +func (r *ParallelRunner) wrapRunnable(contractAddr utils.ContractAddress) { + contractInfo, _ := r.contractAddrToInfo.Load(contractAddr) + r.runnable(*contractInfo) + + // Check if there is any contract that should be promoted to the frontier set. + if contractInfo.Dependencies != nil { + for _, dependency := range contractInfo.Dependencies { + dependentContract := dependency.Dependency + typedDependentContract := utils.ContractAddress(dependentContract) + dependentInfo, _ := r.contractAddrToInfo.Load(typedDependentContract) + // It's okay to mutate ContractInfo here since it's a copy made in the runner's + // constructor. + newNumIncomingPaths := atomic.AddInt64(&dependentInfo.NumIncomingDependencies, -1) + // This corresponds to the "for which all contracts that depend on it have + // already finished." definition for frontier contract. + if newNumIncomingPaths == 0 { + r.readyContracts.Store(typedDependentContract, struct{}{}) + atomic.AddInt64(&r.readyCnt, 1) + } + } + } + + atomic.AddInt64(&r.inProgressCnt, -1) // this has to happen after any potential increment to readyCnt + r.someContractFinished <- struct{}{} +} diff --git a/x/dex/contract/runner_test.go b/x/dex/contract/runner_test.go new file mode 100644 index 0000000000..d4447378c6 --- /dev/null +++ b/x/dex/contract/runner_test.go @@ -0,0 +1,95 @@ +package contract_test + +import ( + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/sei-protocol/sei-chain/x/dex/contract" + "github.com/sei-protocol/sei-chain/x/dex/types" + "github.com/stretchr/testify/require" +) + +var counter int64 = 0 +var dependencyCheck = sync.Map{} + +func noopRunnable(_ types.ContractInfo) { + atomic.AddInt64(&counter, 1) +} + +func idleRunnable(_ types.ContractInfo) { + time.Sleep(5 * time.Second) + atomic.AddInt64(&counter, 1) +} + +func dependencyCheckRunnable(contractInfo types.ContractInfo) { + if contractInfo.ContractAddr == "C" { + _, hasA := dependencyCheck.Load("A") + _, hasB := dependencyCheck.Load("B") + if !hasA || !hasB { + return + } + } + dependencyCheck.Store(contractInfo.ContractAddr, struct{}{}) +} + +func TestRunnerSingleContract(t *testing.T) { + counter = 0 + contractInfo := types.ContractInfo{ + ContractAddr: "A", + NumIncomingDependencies: 0, + } + runner := contract.NewParallelRunner(noopRunnable, []types.ContractInfo{contractInfo}) + runner.Run() + require.Equal(t, int64(1), counter) +} + +func TestRunnerParallelContract(t *testing.T) { + counter = 0 + contractInfoA := types.ContractInfo{ + ContractAddr: "A", + NumIncomingDependencies: 0, + } + contractInfoB := types.ContractInfo{ + ContractAddr: "B", + NumIncomingDependencies: 0, + } + runner := contract.NewParallelRunner(idleRunnable, []types.ContractInfo{contractInfoA, contractInfoB}) + start := time.Now() + runner.Run() + end := time.Now() + duration := end.Sub(start) + require.Equal(t, int64(2), counter) + require.True(t, duration.Seconds() < 10) // would not be flaky unless it's running on really slow hardware +} + +func TestRunnerParallelContractWithDependency(t *testing.T) { + counter = 0 + contractInfoA := types.ContractInfo{ + ContractAddr: "A", + NumIncomingDependencies: 0, + Dependencies: []*types.ContractDependencyInfo{ + { + Dependency: "C", + }, + }, + } + contractInfoB := types.ContractInfo{ + ContractAddr: "B", + NumIncomingDependencies: 0, + Dependencies: []*types.ContractDependencyInfo{ + { + Dependency: "C", + }, + }, + } + contractInfoC := types.ContractInfo{ + ContractAddr: "C", + NumIncomingDependencies: 2, + } + runner := contract.NewParallelRunner(dependencyCheckRunnable, []types.ContractInfo{contractInfoC, contractInfoB, contractInfoA}) + runner.Run() + _, hasC := dependencyCheck.Load("C") + require.True(t, hasC) +} diff --git a/x/dex/contract/whitelist.go b/x/dex/contract/whitelist.go new file mode 100644 index 0000000000..2dac6a88c5 --- /dev/null +++ b/x/dex/contract/whitelist.go @@ -0,0 +1,69 @@ +package contract + +import ( + wasmtypes "github.com/CosmWasm/wasmd/x/wasm/types" + storetypes "github.com/cosmos/cosmos-sdk/store/types" + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/sei-protocol/sei-chain/utils" + "github.com/sei-protocol/sei-chain/x/dex/types" +) + +var DexWhitelistedKeys = []string{ + types.LongBookKey, + types.ShortBookKey, + types.OrderKey, + types.AccountActiveOrdersKey, + types.CancelKey, + types.TwapKey, + types.PriceKey, + types.SettlementEntryKey, + types.NextOrderIDKey, +} + +var WasmWhitelistedKeys = []string{ + string(wasmtypes.ContractStorePrefix), +} + +var DexPerPairWhitelistedKeys = []string{ + types.LongBookKey, + types.ShortBookKey, + types.PriceKey, +} + +func GetWhitelistMap(contractAddr string) map[string][]string { + res := map[string][]string{} + res[storetypes.NewKVStoreKey(types.StoreKey).Name()] = GetDexWhitelistedPrefixes(contractAddr) + res[storetypes.NewKVStoreKey(wasmtypes.StoreKey).Name()] = GetWasmWhitelistedPrefixes(contractAddr) + return res +} + +func GetPerPairWhitelistMap(contractAddr string, pair types.Pair) map[string][]string { + res := map[string][]string{} + res[storetypes.NewKVStoreKey(types.StoreKey).Name()] = GetDexPerPairWhitelistedPrefixes(contractAddr, pair) + return res +} + +func GetDexWhitelistedPrefixes(contractAddr string) []string { + return utils.Map(DexWhitelistedKeys, func(key string) string { + return string(append( + types.KeyPrefix(key), types.KeyPrefix(contractAddr)..., + )) + }) +} + +func GetWasmWhitelistedPrefixes(contractAddr string) []string { + addr, _ := sdk.AccAddressFromBech32(contractAddr) + return utils.Map(WasmWhitelistedKeys, func(key string) string { + return string(append( + []byte(key), addr..., + )) + }) +} + +func GetDexPerPairWhitelistedPrefixes(contractAddr string, pair types.Pair) []string { + return utils.Map(DexWhitelistedKeys, func(key string) string { + return string(append(append( + types.KeyPrefix(key), types.KeyPrefix(contractAddr)..., + ), types.PairPrefix(pair.PriceDenom, pair.AssetDenom)...)) + }) +} diff --git a/x/dex/keeper/abci/end_block_cancel_orders.go b/x/dex/keeper/abci/end_block_cancel_orders.go index a3e0b199c2..5465f0c3df 100644 --- a/x/dex/keeper/abci/end_block_cancel_orders.go +++ b/x/dex/keeper/abci/end_block_cancel_orders.go @@ -18,7 +18,7 @@ func (w KeeperWrapper) HandleEBCancelOrders(ctx context.Context, sdkCtx sdk.Cont span.SetAttributes(attribute.String("contractAddr", contractAddr)) typedContractAddr := typesutils.ContractAddress(contractAddr) - msg := w.getCancelSudoMsg(typedContractAddr, registeredPairs) + msg := w.getCancelSudoMsg(sdkCtx, typedContractAddr, registeredPairs) if _, err := utils.CallContractSudo(sdkCtx, w.Keeper, contractAddr, msg); err != nil { sdkCtx.Logger().Error(fmt.Sprintf("Error during cancellation: %s", err.Error())) return err @@ -27,11 +27,11 @@ func (w KeeperWrapper) HandleEBCancelOrders(ctx context.Context, sdkCtx sdk.Cont return nil } -func (w KeeperWrapper) getCancelSudoMsg(typedContractAddr typesutils.ContractAddress, registeredPairs []types.Pair) wasm.SudoOrderCancellationMsg { +func (w KeeperWrapper) getCancelSudoMsg(sdkCtx sdk.Context, typedContractAddr typesutils.ContractAddress, registeredPairs []types.Pair) wasm.SudoOrderCancellationMsg { idsToCancel := []uint64{} for _, pair := range registeredPairs { typedPairStr := typesutils.GetPairString(&pair) //nolint:gosec // THIS MAY BE CAUSE FOR CONCERN AND WE MIGHT WANT TO REFACTOR. - for _, cancel := range w.MemState.GetBlockCancels(typedContractAddr, typedPairStr).Get() { + for _, cancel := range w.MemState.GetBlockCancels(sdkCtx, typedContractAddr, typedPairStr).Get() { idsToCancel = append(idsToCancel, cancel.Id) } } diff --git a/x/dex/keeper/abci/end_block_deposit.go b/x/dex/keeper/abci/end_block_deposit.go new file mode 100644 index 0000000000..00bc36bb86 --- /dev/null +++ b/x/dex/keeper/abci/end_block_deposit.go @@ -0,0 +1,56 @@ +package abci + +import ( + "context" + "fmt" + + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/sei-protocol/sei-chain/x/dex/keeper/utils" + "github.com/sei-protocol/sei-chain/x/dex/types" + typesutils "github.com/sei-protocol/sei-chain/x/dex/types/utils" + "github.com/sei-protocol/sei-chain/x/dex/types/wasm" + "go.opentelemetry.io/otel/attribute" + otrace "go.opentelemetry.io/otel/trace" +) + +func (w KeeperWrapper) HandleEBDeposit(ctx context.Context, sdkCtx sdk.Context, tracer *otrace.Tracer, contractAddr string) error { + _, span := (*tracer).Start(ctx, "SudoPlaceOrders") + defer span.End() + span.SetAttributes(attribute.String("contractAddr", contractAddr)) + + typedContractAddr := typesutils.ContractAddress(contractAddr) + msg := w.GetDepositSudoMsg(sdkCtx, typedContractAddr) + _, err := utils.CallContractSudo(sdkCtx, w.Keeper, contractAddr, msg) // deposit + if err != nil { + sdkCtx.Logger().Error(fmt.Sprintf("Error during deposit: %s", err.Error())) + return err + } + + return nil +} + +func (w KeeperWrapper) GetDepositSudoMsg(ctx sdk.Context, typedContractAddr typesutils.ContractAddress) wasm.SudoOrderPlacementMsg { + contractDepositInfo := []wasm.ContractDepositInfo{} + for _, depositInfo := range w.MemState.GetDepositInfo(ctx, typedContractAddr).Get() { + fund := sdk.NewCoins(sdk.NewCoin(depositInfo.Denom, depositInfo.Amount.RoundInt())) + sender, err := sdk.AccAddressFromBech32(depositInfo.Creator) + if err != nil { + ctx.Logger().Error("Invalid deposit creator") + } + receiver, err := sdk.AccAddressFromBech32(string(typedContractAddr)) + if err != nil { + ctx.Logger().Error("Invalid deposit contract") + } + if err := w.BankKeeper.SendCoins(ctx, sender, receiver, fund); err == nil { + contractDepositInfo = append(contractDepositInfo, depositInfo.ToContractDepositInfo()) + } else { + ctx.Logger().Error(err.Error()) + } + } + return wasm.SudoOrderPlacementMsg{ + OrderPlacements: wasm.OrderPlacementMsgDetails{ + Orders: []types.Order{}, + Deposits: contractDepositInfo, + }, + } +} diff --git a/x/dex/keeper/abci/end_block_liquidation.go b/x/dex/keeper/abci/end_block_liquidation.go index c4e52e96c9..807475d42d 100644 --- a/x/dex/keeper/abci/end_block_liquidation.go +++ b/x/dex/keeper/abci/end_block_liquidation.go @@ -19,7 +19,7 @@ func (w KeeperWrapper) HandleEBLiquidation(ctx context.Context, sdkCtx sdk.Conte liquidationSpan.SetAttributes(attribute.String("contractAddr", contractAddr)) typedContractAddr := typesutils.ContractAddress(contractAddr) - msg := w.getLiquidationSudoMsg(typedContractAddr) + msg := w.getLiquidationSudoMsg(sdkCtx, typedContractAddr) data, err := utils.CallContractSudo(sdkCtx, w.Keeper, contractAddr, msg) if err != nil { return err @@ -38,14 +38,14 @@ func (w KeeperWrapper) HandleEBLiquidation(ctx context.Context, sdkCtx sdk.Conte // Clear up all user-initiated order activities in the current block for _, pair := range registeredPairs { typedPairStr := typesutils.GetPairString(&pair) //nolint:gosec // USING THE POINTER HERE COULD BE BAD LET'S CHECK IT. - w.MemState.GetBlockCancels(typedContractAddr, typedPairStr).FilterByIds(liquidatedAccountsActiveOrderIds) - w.MemState.GetBlockOrders(typedContractAddr, typedPairStr).MarkFailedToPlaceByAccounts(response.SuccessfulAccounts) + w.MemState.GetBlockCancels(sdkCtx, typedContractAddr, typedPairStr).FilterByIds(liquidatedAccountsActiveOrderIds) + w.MemState.GetBlockOrders(sdkCtx, typedContractAddr, typedPairStr).MarkFailedToPlaceByAccounts(response.SuccessfulAccounts) } // Cancel all outstanding orders of liquidated accounts, as denoted as cancelled via liquidation for id, order := range w.GetOrdersByIds(sdkCtx, contractAddr, liquidatedAccountsActiveOrderIds) { pair := types.Pair{PriceDenom: order.PriceDenom, AssetDenom: order.AssetDenom} typedPairStr := typesutils.GetPairString(&pair) - w.MemState.GetBlockCancels(typedContractAddr, typedPairStr).Add(&types.Cancellation{ + w.MemState.GetBlockCancels(sdkCtx, typedContractAddr, typedPairStr).Add(&types.Cancellation{ Id: id, Initiator: types.CancellationInitiator_LIQUIDATED, }) @@ -60,21 +60,21 @@ func (w KeeperWrapper) HandleEBLiquidation(ctx context.Context, sdkCtx sdk.Conte func (w KeeperWrapper) PlaceLiquidationOrders(ctx sdk.Context, contractAddr string, liquidationOrders []types.Order) { ctx.Logger().Info("Placing liquidation orders...") - nextID := w.GetNextOrderID(ctx) + nextID := w.GetNextOrderID(ctx, contractAddr) for _, order := range liquidationOrders { ctx.Logger().Info(fmt.Sprintf("Liquidation order %s", order.String())) pair := types.Pair{PriceDenom: order.PriceDenom, AssetDenom: order.AssetDenom} - orders := w.MemState.GetBlockOrders(typesutils.ContractAddress(contractAddr), typesutils.GetPairString(&pair)) + orders := w.MemState.GetBlockOrders(ctx, typesutils.ContractAddress(contractAddr), typesutils.GetPairString(&pair)) order.Id = nextID orderCopy := order orders.Add(&orderCopy) nextID++ } - w.SetNextOrderID(ctx, nextID) + w.SetNextOrderID(ctx, contractAddr, nextID) } -func (w KeeperWrapper) getLiquidationSudoMsg(typedContractAddr typesutils.ContractAddress) wasm.SudoLiquidationMsg { - cachedLiquidationRequests := w.MemState.GetLiquidationRequests(typedContractAddr) +func (w KeeperWrapper) getLiquidationSudoMsg(ctx sdk.Context, typedContractAddr typesutils.ContractAddress) wasm.SudoLiquidationMsg { + cachedLiquidationRequests := w.MemState.GetLiquidationRequests(ctx, typedContractAddr) liquidationRequests := []wasm.LiquidationRequest{} for _, cachedLiquidationRequest := range cachedLiquidationRequests.Get() { liquidationRequests = append(liquidationRequests, wasm.LiquidationRequest{ diff --git a/x/dex/keeper/abci/end_block_liquidation_test.go b/x/dex/keeper/abci/end_block_liquidation_test.go index 95156b5f39..40fe030718 100644 --- a/x/dex/keeper/abci/end_block_liquidation_test.go +++ b/x/dex/keeper/abci/end_block_liquidation_test.go @@ -18,5 +18,5 @@ func TestPlaceLiquidationOrders(t *testing.T) { } wrapper := abci.KeeperWrapper{Keeper: keeper} wrapper.PlaceLiquidationOrders(ctx, keepertest.TestContract, []types.Order{liquidationOrder}) - require.Equal(t, 1, len(keeper.MemState.GetBlockOrders(typesutils.ContractAddress(keepertest.TestContract), typesutils.GetPairString(&keepertest.TestPair)).Get())) + require.Equal(t, 1, len(keeper.MemState.GetBlockOrders(ctx, typesutils.ContractAddress(keepertest.TestContract), typesutils.GetPairString(&keepertest.TestPair)).Get())) } diff --git a/x/dex/keeper/abci/end_block_place_orders.go b/x/dex/keeper/abci/end_block_place_orders.go index 02d919ab83..e1b1ef16b6 100644 --- a/x/dex/keeper/abci/end_block_place_orders.go +++ b/x/dex/keeper/abci/end_block_place_orders.go @@ -24,14 +24,9 @@ func (w KeeperWrapper) HandleEBPlaceOrders(ctx context.Context, sdkCtx sdk.Conte typedContractAddr := typesutils.ContractAddress(contractAddr) msgs := w.GetPlaceSudoMsg(sdkCtx, typedContractAddr, registeredPairs) - _, err := utils.CallContractSudo(sdkCtx, w.Keeper, contractAddr, msgs[0]) // deposit - if err != nil { - sdkCtx.Logger().Error(fmt.Sprintf("Error during deposit: %s", err.Error())) - return err - } responses := []wasm.SudoOrderPlacementResponse{} - for _, msg := range msgs[1:] { + for _, msg := range msgs { data, err := utils.CallContractSudo(sdkCtx, w.Keeper, contractAddr, msg) if err != nil { sdkCtx.Logger().Error(fmt.Sprintf("Error during order placement: %s", err.Error())) @@ -49,7 +44,7 @@ func (w KeeperWrapper) HandleEBPlaceOrders(ctx context.Context, sdkCtx sdk.Conte for _, pair := range registeredPairs { typedPairStr := typesutils.GetPairString(&pair) //nolint:gosec // USING THE POINTER HERE COULD BE BAD, LET'S CHECK IT. for _, response := range responses { - w.MemState.GetBlockOrders(typedContractAddr, typedPairStr).MarkFailedToPlace(response.UnsuccessfulOrders) + w.MemState.GetBlockOrders(sdkCtx, typedContractAddr, typedPairStr).MarkFailedToPlace(response.UnsuccessfulOrders) } } span.End() @@ -57,11 +52,11 @@ func (w KeeperWrapper) HandleEBPlaceOrders(ctx context.Context, sdkCtx sdk.Conte } func (w KeeperWrapper) GetPlaceSudoMsg(ctx sdk.Context, typedContractAddr typesutils.ContractAddress, registeredPairs []types.Pair) []wasm.SudoOrderPlacementMsg { - msgs := []wasm.SudoOrderPlacementMsg{w.GetDepositSudoMsg(ctx, typedContractAddr)} + msgs := []wasm.SudoOrderPlacementMsg{} contractOrderPlacements := []types.Order{} for _, pair := range registeredPairs { typedPairStr := typesutils.GetPairString(&pair) //nolint:gosec // USING THE POINTER HERE COULD BE BAD, LET'S CHECK IT. - for _, order := range w.MemState.GetBlockOrders(typedContractAddr, typedPairStr).Get() { + for _, order := range w.MemState.GetBlockOrders(ctx, typedContractAddr, typedPairStr).Get() { contractOrderPlacements = append(contractOrderPlacements, *order) if len(contractOrderPlacements) == MaxOrdersPerSudoCall { msgs = append(msgs, wasm.SudoOrderPlacementMsg{ @@ -82,29 +77,3 @@ func (w KeeperWrapper) GetPlaceSudoMsg(ctx sdk.Context, typedContractAddr typesu }) return msgs } - -func (w KeeperWrapper) GetDepositSudoMsg(ctx sdk.Context, typedContractAddr typesutils.ContractAddress) wasm.SudoOrderPlacementMsg { - contractDepositInfo := []wasm.ContractDepositInfo{} - for _, depositInfo := range w.MemState.GetDepositInfo(typedContractAddr).Get() { - fund := sdk.NewCoins(sdk.NewCoin(depositInfo.Denom, depositInfo.Amount.RoundInt())) - sender, err := sdk.AccAddressFromBech32(depositInfo.Creator) - if err != nil { - ctx.Logger().Error("Invalid deposit creator") - } - receiver, err := sdk.AccAddressFromBech32(string(typedContractAddr)) - if err != nil { - ctx.Logger().Error("Invalid deposit contract") - } - if err := w.BankKeeper.SendCoins(ctx, sender, receiver, fund); err == nil { - contractDepositInfo = append(contractDepositInfo, depositInfo.ToContractDepositInfo()) - } else { - ctx.Logger().Error(err.Error()) - } - } - return wasm.SudoOrderPlacementMsg{ - OrderPlacements: wasm.OrderPlacementMsgDetails{ - Orders: []types.Order{}, - Deposits: contractDepositInfo, - }, - } -} diff --git a/x/dex/keeper/abci/end_block_place_orders_test.go b/x/dex/keeper/abci/end_block_place_orders_test.go index d536162e94..b46ac896ed 100644 --- a/x/dex/keeper/abci/end_block_place_orders_test.go +++ b/x/dex/keeper/abci/end_block_place_orders_test.go @@ -17,7 +17,7 @@ import ( func TestGetPlaceSudoMsg(t *testing.T) { pair := types.Pair{PriceDenom: keepertest.TestPriceDenom, AssetDenom: keepertest.TestAssetDenom} keeper, ctx := keepertest.DexKeeper(t) - keeper.MemState.GetBlockOrders(keepertest.TestContract, typesutils.GetPairString(&pair)).Add( + keeper.MemState.GetBlockOrders(ctx, keepertest.TestContract, typesutils.GetPairString(&pair)).Add( &types.Order{ Id: 1, Price: sdk.OneDec(), @@ -31,7 +31,7 @@ func TestGetPlaceSudoMsg(t *testing.T) { ) wrapper := abci.KeeperWrapper{Keeper: keeper} msgs := wrapper.GetPlaceSudoMsg(ctx, keepertest.TestContract, []types.Pair{pair}) - require.Equal(t, 2, len(msgs)) + require.Equal(t, 1, len(msgs)) } func TestGetDepositSudoMsg(t *testing.T) { @@ -43,7 +43,7 @@ func TestGetDepositSudoMsg(t *testing.T) { bankkeeper.MintCoins(ctx, minttypes.ModuleName, amounts) bankkeeper.SendCoinsFromModuleToAccount(ctx, minttypes.ModuleName, testAccount, amounts) keeper := testApp.DexKeeper - keeper.MemState.GetDepositInfo(keepertest.TestContract).Add( + keeper.MemState.GetDepositInfo(ctx, keepertest.TestContract).Add( &dex.DepositInfoEntry{ Creator: testAccount.String(), Denom: amounts[0].Denom, diff --git a/x/dex/keeper/msgserver/msg_server_cancel_orders.go b/x/dex/keeper/msgserver/msg_server_cancel_orders.go index ce70b5809d..dc68382caf 100644 --- a/x/dex/keeper/msgserver/msg_server_cancel_orders.go +++ b/x/dex/keeper/msgserver/msg_server_cancel_orders.go @@ -22,7 +22,7 @@ func (k msgServer) CancelOrders(goCtx context.Context, msg *types.MsgCancelOrder order := orderMap[orderIDToCancel] pair := types.Pair{PriceDenom: order.PriceDenom, AssetDenom: order.AssetDenom} pairStr := typesutils.GetPairString(&pair) - pairBlockCancellations := k.MemState.GetBlockCancels(typesutils.ContractAddress(msg.GetContractAddr()), pairStr) + pairBlockCancellations := k.MemState.GetBlockCancels(ctx, typesutils.ContractAddress(msg.GetContractAddr()), pairStr) cancelledInCurrentBlock := false for _, cancelInCurrentBlock := range pairBlockCancellations.Get() { if cancelInCurrentBlock.Id == orderIDToCancel { diff --git a/x/dex/keeper/msgserver/msg_server_liquidate.go b/x/dex/keeper/msgserver/msg_server_liquidate.go index ca00e9f0e3..e7a1c629ce 100644 --- a/x/dex/keeper/msgserver/msg_server_liquidate.go +++ b/x/dex/keeper/msgserver/msg_server_liquidate.go @@ -3,14 +3,16 @@ package msgserver import ( "context" + sdk "github.com/cosmos/cosmos-sdk/types" dexcache "github.com/sei-protocol/sei-chain/x/dex/cache" "github.com/sei-protocol/sei-chain/x/dex/types" typesutils "github.com/sei-protocol/sei-chain/x/dex/types/utils" ) func (k msgServer) Liquidate(goCtx context.Context, msg *types.MsgLiquidation) (*types.MsgLiquidationResponse, error) { + ctx := sdk.UnwrapSDKContext(goCtx) k.MemState.GetLiquidationRequests( - typesutils.ContractAddress(msg.GetContractAddr()), + ctx, typesutils.ContractAddress(msg.GetContractAddr()), ).Add(&dexcache.LiquidationRequest{Requestor: msg.Creator, AccountToLiquidate: msg.AccountToLiquidate}) return &types.MsgLiquidationResponse{}, nil diff --git a/x/dex/keeper/msgserver/msg_server_place_orders.go b/x/dex/keeper/msgserver/msg_server_place_orders.go index 09000bbbbf..55e73cbc5c 100644 --- a/x/dex/keeper/msgserver/msg_server_place_orders.go +++ b/x/dex/keeper/msgserver/msg_server_place_orders.go @@ -41,7 +41,7 @@ func (k msgServer) transferFunds(goCtx context.Context, msg *types.MsgPlaceOrder if fund.Amount.IsNil() || fund.IsNegative() { return errors.New("fund deposits cannot be nil or negative") } - k.MemState.GetDepositInfo(typesutils.ContractAddress(msg.GetContractAddr())).Add(&dexcache.DepositInfoEntry{ + k.MemState.GetDepositInfo(ctx, typesutils.ContractAddress(msg.GetContractAddr())).Add(&dexcache.DepositInfoEntry{ Creator: msg.Creator, Denom: fund.Denom, Amount: sdk.NewDec(fund.Amount.Int64()), @@ -84,7 +84,7 @@ func (k msgServer) PlaceOrders(goCtx context.Context, msg *types.MsgPlaceOrders) return nil, err } - nextID := k.GetNextOrderID(ctx) + nextID := k.GetNextOrderID(ctx, msg.ContractAddr) idsInResp := []uint64{} for _, order := range msg.GetOrders() { ticksize, found := k.Keeper.GetTickSizeForPair(ctx, msg.GetContractAddr(), types.Pair{PriceDenom: order.PriceDenom, AssetDenom: order.AssetDenom}) @@ -96,11 +96,11 @@ func (k msgServer) PlaceOrders(goCtx context.Context, msg *types.MsgPlaceOrders) order.Id = nextID order.Account = msg.Creator order.ContractAddr = msg.GetContractAddr() - k.MemState.GetBlockOrders(typesutils.ContractAddress(msg.GetContractAddr()), pairStr).Add(order) + k.MemState.GetBlockOrders(ctx, typesutils.ContractAddress(msg.GetContractAddr()), pairStr).Add(order) idsInResp = append(idsInResp, nextID) nextID++ } - k.SetNextOrderID(ctx, nextID) + k.SetNextOrderID(ctx, msg.ContractAddr, nextID) return &types.MsgPlaceOrdersResponse{ OrderIds: idsInResp, diff --git a/x/dex/keeper/order_placement.go b/x/dex/keeper/order_placement.go index 8af24cc00b..3dab2f3880 100644 --- a/x/dex/keeper/order_placement.go +++ b/x/dex/keeper/order_placement.go @@ -8,8 +8,8 @@ import ( "github.com/sei-protocol/sei-chain/x/dex/types" ) -func (k Keeper) GetNextOrderID(ctx sdk.Context) uint64 { - store := prefix.NewStore(ctx.KVStore(k.storeKey), []byte{}) +func (k Keeper) GetNextOrderID(ctx sdk.Context, contractAddr string) uint64 { + store := prefix.NewStore(ctx.KVStore(k.storeKey), types.NextOrderIDPrefix(contractAddr)) byteKey := types.KeyPrefix(types.NextOrderIDKey) bz := store.Get(byteKey) if bz == nil { @@ -18,8 +18,8 @@ func (k Keeper) GetNextOrderID(ctx sdk.Context) uint64 { return binary.BigEndian.Uint64(bz) } -func (k Keeper) SetNextOrderID(ctx sdk.Context, nextID uint64) { - store := prefix.NewStore(ctx.KVStore(k.storeKey), []byte{}) +func (k Keeper) SetNextOrderID(ctx sdk.Context, contractAddr string, nextID uint64) { + store := prefix.NewStore(ctx.KVStore(k.storeKey), types.NextOrderIDPrefix(contractAddr)) byteKey := types.KeyPrefix(types.NextOrderIDKey) bz := make([]byte, 8) binary.BigEndian.PutUint64(bz, nextID) diff --git a/x/dex/keeper/query/grpc_query_order_simulation.go b/x/dex/keeper/query/grpc_query_order_simulation.go index d5f30d28dd..9aefbffef1 100644 --- a/x/dex/keeper/query/grpc_query_order_simulation.go +++ b/x/dex/keeper/query/grpc_query_order_simulation.go @@ -52,7 +52,7 @@ func (k KeeperWrapper) getMatchedPriceQuantities(ctx sdk.Context, req *types.Que // exclude liquidity to be cancelled pair := types.Pair{PriceDenom: req.Order.PriceDenom, AssetDenom: req.Order.AssetDenom} - for _, cancel := range k.MemState.GetBlockCancels(utils.ContractAddress(req.Order.ContractAddr), utils.GetPairString(&pair)).Get() { + for _, cancel := range k.MemState.GetBlockCancels(ctx, utils.ContractAddress(req.Order.ContractAddr), utils.GetPairString(&pair)).Get() { orderToBeCancelled := k.GetOrdersByIds(ctx, req.Order.ContractAddr, []uint64{cancel.Id}) if _, ok := orderToBeCancelled[cancel.Id]; !ok { continue @@ -80,7 +80,7 @@ func (k KeeperWrapper) getMatchedPriceQuantities(ctx sdk.Context, req *types.Que // exclude liquidity to be taken ptr := 0 - for _, order := range k.MemState.GetBlockOrders(utils.ContractAddress(req.Order.ContractAddr), utils.GetPairString(&pair)).GetSortedMarketOrders( + for _, order := range k.MemState.GetBlockOrders(ctx, utils.ContractAddress(req.Order.ContractAddr), utils.GetPairString(&pair)).GetSortedMarketOrders( orderDirection, false, ) { // If existing market order has price zero, it means it doesn't specify a worst price and will always have precedence over the simulated diff --git a/x/dex/keeper/query/grpc_query_order_simulation_test.go b/x/dex/keeper/query/grpc_query_order_simulation_test.go index d620e84bbf..ca0f948188 100644 --- a/x/dex/keeper/query/grpc_query_order_simulation_test.go +++ b/x/dex/keeper/query/grpc_query_order_simulation_test.go @@ -70,7 +70,7 @@ func TestGetOrderSimulation(t *testing.T) { Quantity: sdk.MustNewDecFromStr("2"), PositionDirection: types.PositionDirection_SHORT, }) - keeper.MemState.GetBlockCancels(utils.ContractAddress(keepertest.TestContract), utils.GetPairString(&keepertest.TestPair)).Add( + keeper.MemState.GetBlockCancels(ctx, utils.ContractAddress(keepertest.TestContract), utils.GetPairString(&keepertest.TestPair)).Add( &types.Cancellation{Id: 1}, ) res, err = wrapper.GetOrderSimulation(wctx, &types.QueryOrderSimulationRequest{Order: &testOrder}) @@ -78,7 +78,7 @@ func TestGetOrderSimulation(t *testing.T) { require.Equal(t, sdk.MustNewDecFromStr("4"), *res.ExecutedQuantity) // liquidity taken by earlier market orders - keeper.MemState.GetBlockOrders(utils.ContractAddress(keepertest.TestContract), utils.GetPairString(&keepertest.TestPair)).Add( + keeper.MemState.GetBlockOrders(ctx, utils.ContractAddress(keepertest.TestContract), utils.GetPairString(&keepertest.TestPair)).Add( &types.Order{ Account: keepertest.TestAccount, ContractAddr: keepertest.TestContract, diff --git a/x/dex/migrations/v6_to_v7.go b/x/dex/migrations/v6_to_v7.go new file mode 100644 index 0000000000..c7df17849f --- /dev/null +++ b/x/dex/migrations/v6_to_v7.go @@ -0,0 +1,84 @@ +package migrations + +import ( + "encoding/binary" + + "github.com/cosmos/cosmos-sdk/store/prefix" + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/sei-protocol/sei-chain/x/dex/keeper" + "github.com/sei-protocol/sei-chain/x/dex/types" +) + +func V6ToV7(ctx sdk.Context, storeKey sdk.StoreKey) error { + backfillOrderIDPerContract(ctx, storeKey) + reformatPriceState(ctx, storeKey) + return nil +} + +// this function backfills contract order ID according to the old global order ID +func backfillOrderIDPerContract(ctx sdk.Context, storeKey sdk.StoreKey) { + oldStore := prefix.NewStore( + ctx.KVStore(storeKey), + []byte{}, + ) + oldKey := types.KeyPrefix(types.NextOrderIDKey) + oldIDBytes := oldStore.Get(oldKey) + if oldIDBytes == nil { + // nothing to backfill + return + } + oldID := binary.BigEndian.Uint64(oldIDBytes) + + contractStore := prefix.NewStore(ctx.KVStore(storeKey), []byte(keeper.ContractPrefixKey)) + iterator := sdk.KVStorePrefixIterator(contractStore, []byte{}) + + defer iterator.Close() + + for ; iterator.Valid(); iterator.Next() { + contract := types.ContractInfo{} + if err := contract.Unmarshal(iterator.Value()); err == nil { + if contract.NeedOrderMatching { + newIDStore := prefix.NewStore(ctx.KVStore(storeKey), types.NextOrderIDPrefix(contract.ContractAddr)) + byteKey := types.KeyPrefix(types.NextOrderIDKey) + bz := make([]byte, 8) + binary.BigEndian.PutUint64(bz, oldID) + newIDStore.Set(byteKey, bz) + } + } + } +} + +func reformatPriceState(ctx sdk.Context, storeKey sdk.StoreKey) { + contractStore := prefix.NewStore(ctx.KVStore(storeKey), []byte(keeper.ContractPrefixKey)) + iterator := sdk.KVStorePrefixIterator(contractStore, []byte{}) + + defer iterator.Close() + + for ; iterator.Valid(); iterator.Next() { + contract := types.ContractInfo{} + if err := contract.Unmarshal(iterator.Value()); err == nil { + pairStore := prefix.NewStore(ctx.KVStore(storeKey), types.RegisteredPairPrefix(contract.ContractAddr)) + pairIterator := sdk.KVStorePrefixIterator(pairStore, []byte{}) + for ; pairIterator.Valid(); pairIterator.Next() { + pair := types.Pair{} + if err := pair.Unmarshal(pairIterator.Value()); err == nil { + oldPriceStore := prefix.NewStore(ctx.KVStore(storeKey), append( + append( + append(types.KeyPrefix(types.PriceKey), types.KeyPrefix(contract.ContractAddr)...), + types.KeyPrefix(pair.PriceDenom)..., + ), + types.KeyPrefix(pair.AssetDenom)..., + )) + newPriceStore := prefix.NewStore(ctx.KVStore(storeKey), types.PricePrefix(contract.ContractAddr, pair.PriceDenom, pair.AssetDenom)) + oldPriceIterator := sdk.KVStorePrefixIterator(oldPriceStore, []byte{}) + for ; oldPriceIterator.Valid(); oldPriceIterator.Next() { + newPriceStore.Set(oldPriceIterator.Key(), oldPriceIterator.Value()) + } + oldPriceIterator.Close() + } + } + + pairIterator.Close() + } + } +} diff --git a/x/dex/migrations/v6_to_v7_test.go b/x/dex/migrations/v6_to_v7_test.go new file mode 100644 index 0000000000..e00f9adbce --- /dev/null +++ b/x/dex/migrations/v6_to_v7_test.go @@ -0,0 +1,89 @@ +package migrations_test + +import ( + "encoding/binary" + "testing" + + "github.com/cosmos/cosmos-sdk/store" + "github.com/cosmos/cosmos-sdk/store/prefix" + storetypes "github.com/cosmos/cosmos-sdk/store/types" + sdk "github.com/cosmos/cosmos-sdk/types" + keepertest "github.com/sei-protocol/sei-chain/testutil/keeper" + "github.com/sei-protocol/sei-chain/x/dex/keeper" + "github.com/sei-protocol/sei-chain/x/dex/migrations" + "github.com/sei-protocol/sei-chain/x/dex/types" + "github.com/stretchr/testify/require" + "github.com/tendermint/tendermint/libs/log" + tmproto "github.com/tendermint/tendermint/proto/tendermint/types" + tmdb "github.com/tendermint/tm-db" +) + +func TestMigrate6to7(t *testing.T) { + storeKey := sdk.NewKVStoreKey(types.StoreKey) + memStoreKey := storetypes.NewMemoryStoreKey(types.MemStoreKey) + + db := tmdb.NewMemDB() + stateStore := store.NewCommitMultiStore(db) + stateStore.MountStoreWithDB(storeKey, sdk.StoreTypeIAVL, db) + stateStore.MountStoreWithDB(memStoreKey, sdk.StoreTypeMemory, nil) + require.NoError(t, stateStore.LoadLatestVersion()) + + ctx := sdk.NewContext(stateStore, tmproto.Header{}, false, log.NewNopLogger()) + + // write old order ID + store := prefix.NewStore(ctx.KVStore(storeKey), []byte{}) + oldID := make([]byte, 8) + binary.BigEndian.PutUint64(oldID, 10) + store.Set(types.KeyPrefix(types.NextOrderIDKey), oldID) + + // write old price state + store = prefix.NewStore(ctx.KVStore(storeKey), append( + append( + append(types.KeyPrefix(types.PriceKey), types.KeyPrefix(keepertest.TestContract)...), + types.KeyPrefix(keepertest.TestPriceDenom)..., + ), + types.KeyPrefix(keepertest.TestAssetDenom)...), + ) + + price := types.Price{ + SnapshotTimestampInSeconds: 5, + Price: sdk.MustNewDecFromStr("123.4"), + Pair: &keepertest.TestPair, + } + priceBytes, _ := price.Marshal() + store.Set(keeper.GetKeyForTs(price.SnapshotTimestampInSeconds), priceBytes) + + // register contract / pair + store = prefix.NewStore( + ctx.KVStore(storeKey), + []byte(keeper.ContractPrefixKey), + ) + contract := types.ContractInfo{ + CodeId: 1, + ContractAddr: keepertest.TestContract, + NeedOrderMatching: true, + } + contractBytes, _ := contract.Marshal() + store.Set([]byte(contract.ContractAddr), contractBytes) + + store = prefix.NewStore(ctx.KVStore(storeKey), types.RegisteredPairPrefix(keepertest.TestContract)) + keyBytes := make([]byte, 8) + binary.BigEndian.PutUint64(keyBytes, 0) + pairBytes, _ := keepertest.TestPair.Marshal() + store.Set(keyBytes, pairBytes) + + err := migrations.V6ToV7(ctx, storeKey) + require.Nil(t, err) + + store = prefix.NewStore(ctx.KVStore(storeKey), types.NextOrderIDPrefix(keepertest.TestContract)) + byteKey := types.KeyPrefix(types.NextOrderIDKey) + bz := store.Get(byteKey) + require.Equal(t, uint64(10), binary.BigEndian.Uint64(bz)) + + store = prefix.NewStore(ctx.KVStore(storeKey), types.PricePrefix(keepertest.TestContract, keepertest.TestPriceDenom, keepertest.TestAssetDenom)) + key := keeper.GetKeyForTs(5) + priceRes := types.Price{} + b := store.Get(key) + _ = priceRes.Unmarshal(b) + require.Equal(t, price, priceRes) +} diff --git a/x/dex/module.go b/x/dex/module.go index 7968a3b437..ef2b5dfabb 100644 --- a/x/dex/module.go +++ b/x/dex/module.go @@ -20,7 +20,6 @@ import ( "github.com/cosmos/cosmos-sdk/telemetry" sdk "github.com/cosmos/cosmos-sdk/types" "github.com/cosmos/cosmos-sdk/types/module" - "github.com/sei-protocol/sei-chain/utils/datastructures" "github.com/sei-protocol/sei-chain/utils/tracing" "github.com/sei-protocol/sei-chain/x/dex/client/cli/query" "github.com/sei-protocol/sei-chain/x/dex/client/cli/tx" @@ -29,11 +28,8 @@ import ( dexkeeperabci "github.com/sei-protocol/sei-chain/x/dex/keeper/abci" "github.com/sei-protocol/sei-chain/x/dex/keeper/msgserver" dexkeeperquery "github.com/sei-protocol/sei-chain/x/dex/keeper/query" - dexkeeperutils "github.com/sei-protocol/sei-chain/x/dex/keeper/utils" "github.com/sei-protocol/sei-chain/x/dex/migrations" "github.com/sei-protocol/sei-chain/x/dex/types" - dextypeswasm "github.com/sei-protocol/sei-chain/x/dex/types/wasm" - "github.com/sei-protocol/sei-chain/x/store" ) var ( @@ -179,6 +175,9 @@ func (am AppModule) RegisterServices(cfg module.Configurator) { _ = cfg.RegisterMigration(types.ModuleName, 5, func(ctx sdk.Context) error { return migrations.V5ToV6(ctx, am.keeper.GetStoreKey(), am.keeper.Cdc) }) + _ = cfg.RegisterMigration(types.ModuleName, 6, func(ctx sdk.Context) error { + return migrations.V6ToV7(ctx, am.keeper.GetStoreKey()) + }) } // RegisterInvariants registers the capability module's invariants. @@ -203,17 +202,10 @@ func (am AppModule) ExportGenesis(ctx sdk.Context, cdc codec.JSONCodec) json.Raw } // ConsensusVersion implements ConsensusVersion. -func (AppModule) ConsensusVersion() uint64 { return 6 } +func (AppModule) ConsensusVersion() uint64 { return 7 } func (am AppModule) getAllContractInfo(ctx sdk.Context) []types.ContractInfo { - unsorted := am.keeper.GetAllContractInfo(ctx) - sorted, err := contract.TopologicalSortContractInfo(unsorted) - if err != nil { - // This should never happen unless there is a bug in contract registration. - // Chain needs to be halted to prevent bad states from being written - panic(err) - } - return sorted + return am.keeper.GetAllContractInfo(ctx) } // BeginBlock executes all ABCI BeginBlock logic respective to the capability module. @@ -282,80 +274,22 @@ func (am AppModule) EndBlock(ctx sdk.Context, _ abci.RequestEndBlock) (ret []abc } }() - validContractAddresses := map[string]types.ContractInfo{} - for _, contractInfo := range am.getAllContractInfo(ctx) { - validContractAddresses[contractInfo.ContractAddr] = contractInfo - } + validContractsInfo := am.getAllContractInfo(ctx) // Each iteration is atomic. If an iteration finishes without any error, it will return, // otherwise it will rollback any state change, filter out contracts that cause the error, // and proceed to the next iteration. The loop is guaranteed to finish since // `validContractAddresses` will always decrease in size every iteration. - iterCounter := len(validContractAddresses) - for len(validContractAddresses) > 0 { - failedContractAddresses := datastructures.NewSyncSet([]string{}) - cachedCtx, msCached := store.GetCachedContext(ctx) - // cache keeper in-memory state - memStateCopy := am.keeper.MemState.DeepCopy() - finalizeBlockMessages := map[string]*dextypeswasm.SudoFinalizeBlockMsg{} - settlementsByContract := map[string][]*types.SettlementEntry{} - for contractAddr := range validContractAddresses { - finalizeBlockMessages[contractAddr] = dextypeswasm.NewSudoFinalizeBlockMsg() - settlementsByContract[contractAddr] = []*types.SettlementEntry{} - } - - for contractAddr, contractInfo := range validContractAddresses { - if !contractInfo.NeedOrderMatching { - continue - } - ctx.Logger().Info(fmt.Sprintf("End block for %s", contractAddr)) - if orderResultsMap, settlements, err := contract.HandleExecutionForContract(cachedCtx, contractInfo, &am.keeper, am.tracingInfo.Tracer); err != nil { - ctx.Logger().Error(fmt.Sprintf("Error for EndBlock of %s", contractAddr)) - failedContractAddresses.Add(contractAddr) - } else { - for account, orderResults := range orderResultsMap { - // only add to finalize message for contract addresses - if msg, ok := finalizeBlockMessages[account]; ok { - msg.AddContractResult(orderResults) - } - } - settlementsByContract[contractAddr] = settlements - } - } - - for contractAddr, settlements := range settlementsByContract { - if !validContractAddresses[contractAddr].NeedOrderMatching { - continue - } - if err := contract.HandleSettlements(cachedCtx, contractAddr, &am.keeper, settlements); err != nil { - ctx.Logger().Error(fmt.Sprintf("Error handling settlements for %s", contractAddr)) - failedContractAddresses.Add(contractAddr) - } - } - - for contractAddr, finalizeBlockMsg := range finalizeBlockMessages { - if !validContractAddresses[contractAddr].NeedHook { - continue - } - if _, err := dexkeeperutils.CallContractSudo(cachedCtx, &am.keeper, contractAddr, finalizeBlockMsg); err != nil { - ctx.Logger().Error(fmt.Sprintf("Error calling FinalizeBlock of %s", contractAddr)) - failedContractAddresses.Add(contractAddr) - } - } - - // No error is thrown for any contract. This should happen most of the time. - if failedContractAddresses.Size() == 0 { - msCached.Write() - return []abci.ValidatorUpdate{} - } - // restore keeper in-memory state - *am.keeper.MemState = *memStateCopy - // exclude orders by failed contracts from in-memory state, - // then update `validContractAddresses` - for _, failedContractAddress := range failedContractAddresses.ToOrderedSlice(datastructures.StringComparator) { - am.keeper.MemState.DeepFilterAccount(failedContractAddress) - delete(validContractAddresses, failedContractAddress) + iterCounter := len(validContractsInfo) + for len(validContractsInfo) > 0 { + newValidContractsInfo, ok := contract.EndBlockerAtomic(ctx, &am.keeper, validContractsInfo, am.tracingInfo.Tracer) + if ok { + break } + validContractsInfo = newValidContractsInfo + // technically we don't really need this if `EndBlockerAtomic` guarantees that `validContractsInfo` size will + // always shrink if not `ok`, but just in case, we decided to have an explicit termination criteria here to + // prevent the chain from being stuck. iterCounter-- if iterCounter == 0 { ctx.Logger().Error("All contracts failed in dex EndBlock. Doing nothing.") @@ -363,6 +297,5 @@ func (am AppModule) EndBlock(ctx sdk.Context, _ abci.RequestEndBlock) (ret []abc } } - // don't call `ctx.Write` if all contracts have error return []abci.ValidatorUpdate{} } diff --git a/x/dex/module_test.go b/x/dex/module_test.go index c43f7b9d6a..855b6bdd8e 100644 --- a/x/dex/module_test.go +++ b/x/dex/module_test.go @@ -62,7 +62,7 @@ func TestEndBlockMarketOrder(t *testing.T) { dexkeeper.SetContract(ctx, &types.ContractInfo{CodeId: 123, ContractAddr: contractAddr.String(), NeedHook: true, NeedOrderMatching: true}) dexkeeper.AddRegisteredPair(ctx, contractAddr.String(), pair) // place one order to a nonexistent contract - dexkeeper.MemState.GetBlockOrders(utils.ContractAddress(contractAddr.String()), utils.GetPairString(&pair)).Add( + dexkeeper.MemState.GetBlockOrders(ctx, utils.ContractAddress(contractAddr.String()), utils.GetPairString(&pair)).Add( &types.Order{ Id: 1, Account: testAccount.String(), @@ -76,7 +76,7 @@ func TestEndBlockMarketOrder(t *testing.T) { Data: "{\"position_effect\":\"Open\",\"leverage\":\"1\"}", }, ) - dexkeeper.MemState.GetBlockOrders(utils.ContractAddress(contractAddr.String()), utils.GetPairString(&pair)).Add( + dexkeeper.MemState.GetBlockOrders(ctx, utils.ContractAddress(contractAddr.String()), utils.GetPairString(&pair)).Add( &types.Order{ Id: 2, Account: testAccount.String(), @@ -90,7 +90,7 @@ func TestEndBlockMarketOrder(t *testing.T) { Data: "{\"position_effect\":\"Open\",\"leverage\":\"1\"}", }, ) - dexkeeper.MemState.GetDepositInfo(utils.ContractAddress(contractAddr.String())).Add( + dexkeeper.MemState.GetDepositInfo(ctx, utils.ContractAddress(contractAddr.String())).Add( &dexcache.DepositInfoEntry{ Creator: testAccount.String(), Denom: "usei", @@ -104,7 +104,7 @@ func TestEndBlockMarketOrder(t *testing.T) { require.True(t, found) dexkeeper.MemState.Clear() - dexkeeper.MemState.GetBlockOrders(utils.ContractAddress(contractAddr.String()), utils.GetPairString(&pair)).Add( + dexkeeper.MemState.GetBlockOrders(ctx, utils.ContractAddress(contractAddr.String()), utils.GetPairString(&pair)).Add( &types.Order{ Id: 2, Account: testAccount.String(), @@ -136,7 +136,7 @@ func TestEndBlockMarketOrder(t *testing.T) { require.Equal(t, 1, len(settlements.Entries)) dexkeeper.MemState.Clear() - dexkeeper.MemState.GetBlockOrders(utils.ContractAddress(contractAddr.String()), utils.GetPairString(&pair)).Add( + dexkeeper.MemState.GetBlockOrders(ctx, utils.ContractAddress(contractAddr.String()), utils.GetPairString(&pair)).Add( &types.Order{ Id: 3, Account: testAccount.String(), @@ -167,7 +167,7 @@ func TestEndBlockRollback(t *testing.T) { dexkeeper.SetContract(ctx, &types.ContractInfo{CodeId: 123, ContractAddr: keepertest.TestContract, NeedHook: true, NeedOrderMatching: true}) dexkeeper.AddRegisteredPair(ctx, keepertest.TestContract, pair) // place one order to a nonexistent contract - dexkeeper.MemState.GetBlockOrders(utils.ContractAddress(keepertest.TestContract), utils.GetPairString(&pair)).Add( + dexkeeper.MemState.GetBlockOrders(ctx, utils.ContractAddress(keepertest.TestContract), utils.GetPairString(&pair)).Add( &types.Order{ Id: 1, Account: keepertest.TestAccount, @@ -197,7 +197,7 @@ func TestEndBlockPartialRollback(t *testing.T) { dexkeeper.SetContract(ctx, &types.ContractInfo{CodeId: 123, ContractAddr: keepertest.TestContract, NeedHook: true, NeedOrderMatching: true}) dexkeeper.AddRegisteredPair(ctx, keepertest.TestContract, pair) // place one order to a nonexistent contract - dexkeeper.MemState.GetBlockOrders(utils.ContractAddress(keepertest.TestContract), utils.GetPairString(&pair)).Add( + dexkeeper.MemState.GetBlockOrders(ctx, utils.ContractAddress(keepertest.TestContract), utils.GetPairString(&pair)).Add( &types.Order{ Id: 1, Account: keepertest.TestAccount, @@ -235,7 +235,7 @@ func TestEndBlockPartialRollback(t *testing.T) { dexkeeper.SetContract(ctx, &types.ContractInfo{CodeId: 123, ContractAddr: contractAddr.String(), NeedHook: true, NeedOrderMatching: true}) dexkeeper.AddRegisteredPair(ctx, contractAddr.String(), pair) // place one order to a nonexistent contract - dexkeeper.MemState.GetBlockOrders(utils.ContractAddress(contractAddr.String()), utils.GetPairString(&pair)).Add( + dexkeeper.MemState.GetBlockOrders(ctx, utils.ContractAddress(contractAddr.String()), utils.GetPairString(&pair)).Add( &types.Order{ Id: 2, Account: testAccount.String(), @@ -249,7 +249,7 @@ func TestEndBlockPartialRollback(t *testing.T) { Data: "{\"position_effect\":\"Open\",\"leverage\":\"1\"}", }, ) - dexkeeper.MemState.GetDepositInfo(utils.ContractAddress(contractAddr.String())).Add( + dexkeeper.MemState.GetDepositInfo(ctx, utils.ContractAddress(contractAddr.String())).Add( &dexcache.DepositInfoEntry{ Creator: testAccount.String(), Denom: "uusdc", @@ -335,7 +335,7 @@ func TestEndBlockPanicHandling(t *testing.T) { } dexkeeper.SetContract(ctx, &types.ContractInfo{CodeId: 123, ContractAddr: contractAddr.String(), NeedHook: true, NeedOrderMatching: true}) dexkeeper.AddRegisteredPair(ctx, contractAddr.String(), pair) - dexkeeper.MemState.GetBlockOrders(utils.ContractAddress(contractAddr.String()), utils.GetPairString(&pair)).Add( + dexkeeper.MemState.GetBlockOrders(ctx, utils.ContractAddress(contractAddr.String()), utils.GetPairString(&pair)).Add( &types.Order{ Id: 1, Account: testAccount.String(), @@ -349,7 +349,7 @@ func TestEndBlockPanicHandling(t *testing.T) { Data: "{\"position_effect\":\"Open\",\"leverage\":\"1\"}", }, ) - dexkeeper.MemState.GetDepositInfo(utils.ContractAddress(contractAddr.String())).Add( + dexkeeper.MemState.GetDepositInfo(ctx, utils.ContractAddress(contractAddr.String())).Add( &dexcache.DepositInfoEntry{ Creator: testAccount.String(), Denom: "usei", diff --git a/x/dex/types/keys.go b/x/dex/types/keys.go index 81f9ed77bb..88534e9288 100644 --- a/x/dex/types/keys.go +++ b/x/dex/types/keys.go @@ -54,11 +54,8 @@ func TwapPrefix(contractAddr string) []byte { // `Price` constant + contract + price denom + asset denom func PricePrefix(contractAddr string, priceDenom string, assetDenom string) []byte { return append( - append( - append(KeyPrefix(PriceKey), KeyPrefix(contractAddr)...), - KeyPrefix(priceDenom)..., - ), - KeyPrefix(assetDenom)..., + append(KeyPrefix(PriceKey), KeyPrefix(contractAddr)...), + PairPrefix(priceDenom, assetDenom)..., ) } @@ -103,17 +100,19 @@ func AssetListPrefix(assetDenom string) []byte { return append(KeyPrefix(AssetListKey), KeyPrefix(assetDenom)...) } +func NextOrderIDPrefix(contractAddr string) []byte { + return append(KeyPrefix(NextOrderIDKey), KeyPrefix(contractAddr)...) +} + const ( DefaultPriceDenom = "stake" DefaultAssetDenom = "dummy" ) const ( - LongBookKey = "LongBook-value-" - LongBookCountKey = "LongBook-count-" + LongBookKey = "LongBook-value-" - ShortBookKey = "ShortBook-value-" - ShortBookCountKey = "ShortBook-count-" + ShortBookKey = "ShortBook-value-" OrderKey = "order" AccountActiveOrdersKey = "account-active-orders" diff --git a/x/dex/types/keys_test.go b/x/dex/types/keys_test.go index 2e77f8a11a..a091dade92 100644 --- a/x/dex/types/keys_test.go +++ b/x/dex/types/keys_test.go @@ -18,7 +18,7 @@ func TestPricePrefix(t *testing.T) { testPriceDenom := "SEI" testAssetDenom := "ATOM" priceContractBytes := append([]byte(types.PriceKey), []byte(testContract)...) - pairBytes := append([]byte(testPriceDenom), []byte(testAssetDenom)...) + pairBytes := types.PairPrefix(testPriceDenom, testAssetDenom) expectedKey := append(priceContractBytes, pairBytes...) require.Equal(t, expectedKey, types.PricePrefix(testContract, testPriceDenom, testAssetDenom)) }