Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions store/whitelist/cachemulti/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ type sdkCacheMultiStore = storetypes.CacheMultiStore
type Store struct {
sdkCacheMultiStore

storeKeyToWriteWhitelist map[storetypes.StoreKey][]string
storeKeyToWriteWhitelist map[string][]string
}

func NewStore(parent storetypes.CacheMultiStore, storeKeyToWriteWhitelist map[storetypes.StoreKey][]string) storetypes.CacheMultiStore {
func NewStore(parent storetypes.CacheMultiStore, storeKeyToWriteWhitelist map[string][]string) storetypes.CacheMultiStore {
return &Store{
sdkCacheMultiStore: parent,
storeKeyToWriteWhitelist: storeKeyToWriteWhitelist,
Expand All @@ -29,8 +29,9 @@ func (cms Store) CacheMultiStore() storetypes.CacheMultiStore {

func (cms Store) GetKVStore(key storetypes.StoreKey) storetypes.KVStore {
rawKVStore := cms.sdkCacheMultiStore.GetKVStore(key)
if writeWhitelist, ok := cms.storeKeyToWriteWhitelist[key]; ok {
if writeWhitelist, ok := cms.storeKeyToWriteWhitelist[key.Name()]; ok {
return kv.NewStore(rawKVStore, writeWhitelist)
}
return rawKVStore
// whitelist nothing
return kv.NewStore(rawKVStore, []string{})
}
8 changes: 4 additions & 4 deletions store/whitelist/cachemulti/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ import (
var (
WhitelistedStoreKey = storetypes.NewKVStoreKey("whitelisted")
NotWhitelistedStoreKey = storetypes.NewKVStoreKey("not-whitelisted")
TestStoreKeyToWriteWhitelist = map[storetypes.StoreKey][]string{
WhitelistedStoreKey: {"foo"},
TestStoreKeyToWriteWhitelist = map[string][]string{
WhitelistedStoreKey.Name(): {"foo"},
}
)

Expand All @@ -37,8 +37,8 @@ func TestWhitelistDisabledStore(t *testing.T) {
multistore := store.NewTestCacheMultiStore(stores)
whitelistMultistore := cachemulti.NewStore(multistore, TestStoreKeyToWriteWhitelist)
kvStore := whitelistMultistore.GetKVStore(NotWhitelistedStoreKey)
require.NotPanics(t, func() { kvStore.Delete([]byte("bar")) })
require.NotPanics(t, func() { kvStore.Delete([]byte("foo")) })
require.Panics(t, func() { kvStore.Delete([]byte("bar")) })
require.Panics(t, func() { kvStore.Delete([]byte("foo")) })
}

func TestCacheStillWhitelist(t *testing.T) {
Expand Down
9 changes: 5 additions & 4 deletions store/whitelist/multi/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ import (
type Store struct {
storetypes.MultiStore

storeKeyToWriteWhitelist map[storetypes.StoreKey][]string
storeKeyToWriteWhitelist map[string][]string
}

func NewStore(parent storetypes.MultiStore, storeKeyToWriteWhitelist map[storetypes.StoreKey][]string) storetypes.MultiStore {
func NewStore(parent storetypes.MultiStore, storeKeyToWriteWhitelist map[string][]string) storetypes.MultiStore {
return &Store{
MultiStore: parent,
storeKeyToWriteWhitelist: storeKeyToWriteWhitelist,
Expand All @@ -25,8 +25,9 @@ func (cms Store) CacheMultiStore() storetypes.CacheMultiStore {

func (cms Store) GetKVStore(key storetypes.StoreKey) storetypes.KVStore {
rawKVStore := cms.MultiStore.GetKVStore(key)
if writeWhitelist, ok := cms.storeKeyToWriteWhitelist[key]; ok {
if writeWhitelist, ok := cms.storeKeyToWriteWhitelist[key.Name()]; ok {
return kv.NewStore(rawKVStore, writeWhitelist)
}
return rawKVStore
// whitelist nothing
return kv.NewStore(rawKVStore, []string{})
}
8 changes: 4 additions & 4 deletions store/whitelist/multi/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ import (
var (
WhitelistedStoreKey = storetypes.NewKVStoreKey("whitelisted")
NotWhitelistedStoreKey = storetypes.NewKVStoreKey("not-whitelisted")
TestStoreKeyToWriteWhitelist = map[storetypes.StoreKey][]string{
WhitelistedStoreKey: {"foo"},
TestStoreKeyToWriteWhitelist = map[string][]string{
WhitelistedStoreKey.Name(): {"foo"},
}
)

Expand All @@ -37,8 +37,8 @@ func TestWhitelistDisabledStore(t *testing.T) {
multistore := store.NewTestCacheMultiStore(stores)
whitelistMultistore := multi.NewStore(multistore, TestStoreKeyToWriteWhitelist)
kvStore := whitelistMultistore.GetKVStore(NotWhitelistedStoreKey)
require.NotPanics(t, func() { kvStore.Delete([]byte("bar")) })
require.NotPanics(t, func() { kvStore.Delete([]byte("foo")) })
require.Panics(t, func() { kvStore.Delete([]byte("bar")) })
require.Panics(t, func() { kvStore.Delete([]byte("foo")) })
}

func TestCacheStillWhitelist(t *testing.T) {
Expand Down
28 changes: 28 additions & 0 deletions utils/panic.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package utils

import (
"fmt"

"github.com/armon/go-metrics"
"github.com/cosmos/cosmos-sdk/telemetry"
sdk "github.com/cosmos/cosmos-sdk/types"
)

func PanicHandler(recoverCallback func(any)) func() {
return func() {
if err := recover(); err != nil {
recoverCallback(err)
}
}
}

func MetricsPanicCallback(err any, ctx sdk.Context, key string) {
ctx.Logger().Error(fmt.Sprintf("panic occurred during order matching for: %s", key))
telemetry.IncrCounterWithLabels(
[]string{key},
1,
[]metrics.Label{
telemetry.NewLabel("error", fmt.Sprintf("%s", err)),
},
)
}
124 changes: 94 additions & 30 deletions x/dex/cache/cache.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
package dex

import (
"fmt"
"time"

sdk "github.com/cosmos/cosmos-sdk/types"

"github.com/sei-protocol/sei-chain/utils/datastructures"
typesutils "github.com/sei-protocol/sei-chain/x/dex/types/utils"
)

const SynchronizationTimeoutInSeconds = 5

type memStateItem interface {
GetAccount() string
}
Expand Down Expand Up @@ -46,88 +53,145 @@ func (i *memStateItems[T]) Copy() *memStateItems[T] {
}

type MemState struct {
BlockOrders *datastructures.TypedNestedSyncMap[
blockOrders *datastructures.TypedNestedSyncMap[
typesutils.ContractAddress,
typesutils.PairString,
*BlockOrders,
]
BlockCancels *datastructures.TypedNestedSyncMap[
blockCancels *datastructures.TypedNestedSyncMap[
typesutils.ContractAddress,
typesutils.PairString,
*BlockCancellations,
]
DepositInfo *datastructures.TypedSyncMap[typesutils.ContractAddress, *DepositInfo]
LiquidationRequests *datastructures.TypedSyncMap[typesutils.ContractAddress, *LiquidationRequests]
depositInfo *datastructures.TypedSyncMap[typesutils.ContractAddress, *DepositInfo]
liquidationRequests *datastructures.TypedSyncMap[typesutils.ContractAddress, *LiquidationRequests]
}

func NewMemState() *MemState {
return &MemState{
BlockOrders: datastructures.NewTypedNestedSyncMap[
blockOrders: datastructures.NewTypedNestedSyncMap[
typesutils.ContractAddress,
typesutils.PairString,
*BlockOrders,
](),
BlockCancels: datastructures.NewTypedNestedSyncMap[
blockCancels: datastructures.NewTypedNestedSyncMap[
typesutils.ContractAddress,
typesutils.PairString,
*BlockCancellations,
](),
DepositInfo: datastructures.NewTypedSyncMap[typesutils.ContractAddress, *DepositInfo](),
LiquidationRequests: datastructures.NewTypedSyncMap[typesutils.ContractAddress, *LiquidationRequests](),
depositInfo: datastructures.NewTypedSyncMap[typesutils.ContractAddress, *DepositInfo](),
liquidationRequests: datastructures.NewTypedSyncMap[typesutils.ContractAddress, *LiquidationRequests](),
}
}

func (s *MemState) GetBlockOrders(contractAddr typesutils.ContractAddress, pair typesutils.PairString) *BlockOrders {
ordersForPair, _ := s.BlockOrders.LoadOrStoreNested(contractAddr, pair, NewOrders())
func (s *MemState) GetAllBlockOrders(ctx sdk.Context, contractAddr typesutils.ContractAddress) *datastructures.TypedSyncMap[typesutils.PairString, *BlockOrders] {
s.SynchronizeAccess(ctx, contractAddr)
ordersMap, _ := s.blockOrders.LoadOrStore(contractAddr, datastructures.NewTypedSyncMap[typesutils.PairString, *BlockOrders]())
return ordersMap
}

func (s *MemState) GetBlockOrders(ctx sdk.Context, contractAddr typesutils.ContractAddress, pair typesutils.PairString) *BlockOrders {
s.SynchronizeAccess(ctx, contractAddr)
ordersForPair, _ := s.blockOrders.LoadOrStoreNested(contractAddr, pair, NewOrders())
return ordersForPair
}

func (s *MemState) GetBlockCancels(contractAddr typesutils.ContractAddress, pair typesutils.PairString) *BlockCancellations {
cancelsForPair, _ := s.BlockCancels.LoadOrStoreNested(contractAddr, pair, NewCancels())
func (s *MemState) GetBlockCancels(ctx sdk.Context, contractAddr typesutils.ContractAddress, pair typesutils.PairString) *BlockCancellations {
s.SynchronizeAccess(ctx, contractAddr)
cancelsForPair, _ := s.blockCancels.LoadOrStoreNested(contractAddr, pair, NewCancels())
return cancelsForPair
}

func (s *MemState) GetDepositInfo(contractAddr typesutils.ContractAddress) *DepositInfo {
depositsForContract, _ := s.DepositInfo.LoadOrStore(contractAddr, NewDepositInfo())
func (s *MemState) GetDepositInfo(ctx sdk.Context, contractAddr typesutils.ContractAddress) *DepositInfo {
s.SynchronizeAccess(ctx, contractAddr)
depositsForContract, _ := s.depositInfo.LoadOrStore(contractAddr, NewDepositInfo())
return depositsForContract
}

func (s *MemState) GetLiquidationRequests(contractAddr typesutils.ContractAddress) *LiquidationRequests {
liquidationsForContract, _ := s.LiquidationRequests.LoadOrStore(contractAddr, NewLiquidationRequests())
func (s *MemState) GetLiquidationRequests(ctx sdk.Context, contractAddr typesutils.ContractAddress) *LiquidationRequests {
s.SynchronizeAccess(ctx, contractAddr)
liquidationsForContract, _ := s.liquidationRequests.LoadOrStore(contractAddr, NewLiquidationRequests())
return liquidationsForContract
}

func (s *MemState) Clear() {
s.BlockOrders = datastructures.NewTypedNestedSyncMap[
s.blockOrders = datastructures.NewTypedNestedSyncMap[
typesutils.ContractAddress,
typesutils.PairString,
*BlockOrders,
]()
s.BlockCancels = datastructures.NewTypedNestedSyncMap[
s.blockCancels = datastructures.NewTypedNestedSyncMap[
typesutils.ContractAddress,
typesutils.PairString,
*BlockCancellations,
]()
s.DepositInfo = datastructures.NewTypedSyncMap[typesutils.ContractAddress, *DepositInfo]()
s.LiquidationRequests = datastructures.NewTypedSyncMap[typesutils.ContractAddress, *LiquidationRequests]()
s.depositInfo = datastructures.NewTypedSyncMap[typesutils.ContractAddress, *DepositInfo]()
s.liquidationRequests = datastructures.NewTypedSyncMap[typesutils.ContractAddress, *LiquidationRequests]()
}

func (s *MemState) ClearCancellationForPair(contractAddr typesutils.ContractAddress, pair typesutils.PairString) {
s.BlockCancels.StoreNested(contractAddr, pair, NewCancels())
func (s *MemState) ClearCancellationForPair(ctx sdk.Context, contractAddr typesutils.ContractAddress, pair typesutils.PairString) {
s.SynchronizeAccess(ctx, contractAddr)
s.blockCancels.StoreNested(contractAddr, pair, NewCancels())
}

func (s *MemState) DeepCopy() *MemState {
copy := NewMemState()
copy.BlockOrders = s.BlockOrders.DeepCopy(func(o *BlockOrders) *BlockOrders { return o.Copy() })
copy.BlockCancels = s.BlockCancels.DeepCopy(func(o *BlockCancellations) *BlockCancellations { return o.Copy() })
copy.DepositInfo = s.DepositInfo.DeepCopy(func(o *DepositInfo) *DepositInfo { return o.Copy() })
copy.LiquidationRequests = s.LiquidationRequests.DeepCopy(func(o *LiquidationRequests) *LiquidationRequests { return o.Copy() })
copy.blockOrders = s.blockOrders.DeepCopy(func(o *BlockOrders) *BlockOrders { return o.Copy() })
copy.blockCancels = s.blockCancels.DeepCopy(func(o *BlockCancellations) *BlockCancellations { return o.Copy() })
copy.depositInfo = s.depositInfo.DeepCopy(func(o *DepositInfo) *DepositInfo { return o.Copy() })
copy.liquidationRequests = s.liquidationRequests.DeepCopy(func(o *LiquidationRequests) *LiquidationRequests { return o.Copy() })
return copy
}

func (s *MemState) DeepFilterAccount(account string) {
s.BlockOrders.DeepApply(func(o *BlockOrders) { o.FilterByAccount(account) })
s.BlockCancels.DeepApply(func(o *BlockCancellations) { o.FilterByAccount(account) })
s.DepositInfo.DeepApply(func(o *DepositInfo) { o.FilterByAccount(account) })
s.LiquidationRequests.DeepApply(func(o *LiquidationRequests) { o.FilterByAccount(account) })
s.blockOrders.DeepApply(func(o *BlockOrders) { o.FilterByAccount(account) })
s.blockCancels.DeepApply(func(o *BlockCancellations) { o.FilterByAccount(account) })
s.depositInfo.DeepApply(func(o *DepositInfo) { o.FilterByAccount(account) })
s.liquidationRequests.DeepApply(func(o *LiquidationRequests) { o.FilterByAccount(account) })
}

func (s *MemState) SynchronizeAccess(ctx sdk.Context, contractAddr typesutils.ContractAddress) {
executingContract := GetExecutingContract(ctx)
if executingContract == nil {
// not accessed by contract. no need to synchronize
return
}
targetContractAddr := string(contractAddr)
if executingContract.ContractAddr == targetContractAddr {
// access by the contract itself does not need synchronization
return
}
for _, dependency := range executingContract.Dependencies {
if dependency.Dependency != targetContractAddr {
continue
}
terminationSignals := GetTerminationSignals(ctx)
if terminationSignals == nil {
// synchronization should fail in the case of no termination signal to prevent race conditions.
panic("no termination signal map found in context")
}
targetChannel, ok := terminationSignals.Load(dependency.ImmediateElderSibling)
if !ok {
// synchronization should fail in the case of no termination signal to prevent race conditions.
panic(fmt.Sprintf("no termination signal channel for contract %s in context", dependency.ImmediateElderSibling))
}

select {
case <-targetChannel:
// since buffered channel can only be consumed once, we need to
// requeue so that it can unblock other goroutines that waits for
// the same channel.
targetChannel <- struct{}{}
case <-time.After(SynchronizationTimeoutInSeconds * time.Second):
// synchronization should fail in the case of timeout to prevent race conditions.
panic(fmt.Sprintf("timing out waiting for termination of %s", dependency.ImmediateElderSibling))
}

return
}

// fail loudly so that the offending contract can be rolled back.
// eventually we will automatically de-register contracts that have to be rolled back
// so that this would not become a point of attack in terms of performance.
panic(fmt.Sprintf("Contract %s trying to access state of %s which is not a registered dependency", executingContract.ContractAddr, targetContractAddr))
}
Loading