Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 47 additions & 2 deletions tools/preconf-rpc/handlers/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ type BlockTracker interface {

type Sender interface {
Enqueue(ctx context.Context, txn *sender.Transaction) error
CancelTransaction(ctx context.Context, txHash common.Hash) (bool, error)
}

type rpcMethodHandler struct {
Expand Down Expand Up @@ -122,8 +123,6 @@ func (h *rpcMethodHandler) RegisterMethods(server *rpcserver.JSONRPCServer) {
server.RegisterHandler("eth_getTransactionCount", h.handleGetTxCount)
server.RegisterHandler("eth_getBlockByHash", h.handleGetBlockByHash)
// Custom methods for MEV Commit
server.RegisterHandler("mevcommit_getTransactionCommitments", h.handleGetTxCommitments)
server.RegisterHandler("mevcommit_getBalance", h.handleMevCommitGetBalance)
server.RegisterHandler("mevcommit_optInBlock", func(ctx context.Context, params ...any) (json.RawMessage, bool, error) {
timeToOptIn, err := h.bidder.Estimate()
if err != nil {
Expand Down Expand Up @@ -188,6 +187,9 @@ func (h *rpcMethodHandler) RegisterMethods(server *rpcserver.JSONRPCServer) {
h.logger.Debug("Estimated bridge price", "bidAmount", bridgeCost, "bridgeAddress", h.bridgeAddress.Hex())
return resultJSON, false, nil
})
server.RegisterHandler("mevcommit_cancelTransaction", h.handleCancelTransaction)
server.RegisterHandler("mevcommit_getTransactionCommitments", h.handleGetTxCommitments)
server.RegisterHandler("mevcommit_getBalance", h.handleMevCommitGetBalance)
}

func getNextBlockPrice(blockPrices *pricer.BlockPrices) *big.Int {
Expand Down Expand Up @@ -586,3 +588,46 @@ func (h *rpcMethodHandler) handleMevCommitGetBalance(ctx context.Context, params

return json.RawMessage(fmt.Sprintf(`{"balance": "%s"}`, balance)), false, nil
}

func (r *rpcMethodHandler) handleCancelTransaction(ctx context.Context, params ...any) (json.RawMessage, bool, error) {
if len(params) != 1 {
return nil, false, rpcserver.NewJSONErr(
rpcserver.CodeInvalidRequest,
"cancelTransaction requires exactly one parameter",
)
}

if params[0] == nil {
return nil, false, rpcserver.NewJSONErr(
rpcserver.CodeParseError,
"cancelTransaction parameter cannot be null",
)
}

txHashStr := params[0].(string)
if len(txHashStr) < 2 || txHashStr[:2] != "0x" {
Comment thread
aloknerurkar marked this conversation as resolved.
return nil, false, rpcserver.NewJSONErr(
rpcserver.CodeParseError,
"cancelTransaction parameter must be a hex string starting with '0x'",
)
}

txHash := common.HexToHash(txHashStr)

cancelled, err := r.sndr.CancelTransaction(ctx, txHash)
if err != nil {
r.logger.Error("Failed to cancel transaction", "error", err, "txHash", txHash)
return nil, false, rpcserver.NewJSONErr(
rpcserver.CodeCustomError,
"failed to cancel transaction",
)
}

if !cancelled {
r.logger.Info("Transaction not found or already processed", "txHash", txHash)
return json.RawMessage(fmt.Sprintf(`{"cancelled": false, "txHash": "%s"}`, txHash.Hex())), false, nil
}

r.logger.Info("Transaction cancelled successfully", "txHash", txHash)
return json.RawMessage(fmt.Sprintf(`{"cancelled": true, "txHash": "%s"}`, txHash.Hex())), false, nil
}
89 changes: 74 additions & 15 deletions tools/preconf-rpc/sender/sender.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ var (
ErrEmptyTransactionTo = errors.New("empty transaction 'to' address")
ErrNegativeTransactionValue = errors.New("negative transaction value")
ErrZeroGasLimit = errors.New("zero gas limit")
ErrTransactionCancelled = errors.New("transaction cancelled by user")
)

type Transaction struct {
Expand All @@ -71,6 +72,7 @@ type Store interface {
AddBalance(ctx context.Context, account common.Address, amount *big.Int) error
DeductBalance(ctx context.Context, account common.Address, amount *big.Int) error
StoreTransaction(ctx context.Context, txn *Transaction, commitments []*bidderapiv1.Commitment) error
GetTransactionByHash(ctx context.Context, txnHash common.Hash) (*Transaction, error)
}

type Bidder interface {
Expand Down Expand Up @@ -119,9 +121,9 @@ type TxSender struct {
egCtx context.Context
trigger chan struct{}
workerPool chan struct{}
inflightTxns map[common.Hash]struct{}
inflightTxns map[common.Hash]chan struct{}
inflightAccount map[common.Address]struct{}
inflightMu sync.Mutex
inflightMu sync.RWMutex
txnAttemptHistory *lru.Cache[common.Hash, *txnAttempt]
}

Expand Down Expand Up @@ -150,7 +152,7 @@ func NewTxSender(
logger: logger.With("component", "TxSender"),
workerPool: make(chan struct{}, 512),
trigger: make(chan struct{}, 1),
inflightTxns: make(map[common.Hash]struct{}),
inflightTxns: make(map[common.Hash]chan struct{}),
inflightAccount: make(map[common.Address]struct{}),
txnAttemptHistory: txnAttemptHistory,
}, nil
Expand Down Expand Up @@ -210,6 +212,57 @@ func (t *TxSender) Enqueue(ctx context.Context, tx *Transaction) error {
return nil
}

func (t *TxSender) CancelTransaction(ctx context.Context, txnHash common.Hash) (bool, error) {
t.inflightMu.RLock()
cancel, found := t.inflightTxns[txnHash]
t.inflightMu.RUnlock()
if !found {
t.logger.Warn("Transaction not found in flight", "hash", txnHash.Hex())
return false, nil
}

t.logger.Info("Cancelling transaction", "hash", txnHash.Hex())
close(cancel) // Signal the transaction processing to stop

ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()

for {
select {
case <-ctx.Done():
t.logger.Info("Context cancelled while waiting for transaction cancellation")
return false, ctx.Err()
case <-ticker.C:
t.inflightMu.RLock()
_, stillInFlight := t.inflightTxns[txnHash]
t.inflightMu.RUnlock()
if !stillInFlight {
txn, err := t.store.GetTransactionByHash(ctx, txnHash)
switch {
case err != nil:
t.logger.Error("Failed to get transaction by hash", "hash", txnHash.Hex(), "error", err)
return false, fmt.Errorf("failed to get transaction by hash: %w", err)
case txn.Status == TxStatusFailed:
if txn.Details == ErrTransactionCancelled.Error() {
t.logger.Info("Transaction successfully cancelled", "hash", txnHash.Hex())
return true, nil
}
t.logger.Warn(
"Transaction failed with other error",
"hash", txnHash.Hex(),
"status", txn.Status,
"details", txn.Details,
)
return false, fmt.Errorf("transaction failed: %s", txn.Details)
case txn.Status == TxStatusPreConfirmed || txn.Status == TxStatusConfirmed:
t.logger.Info("Transaction already confirmed or pre-confirmed", "hash", txnHash.Hex(), "status", txn.Status)
return false, errors.New("transaction already confirmed or pre-confirmed")
}
}
}
}
}

func (t *TxSender) Start(ctx context.Context) chan struct{} {
t.eg, t.egCtx = errgroup.WithContext(ctx)
done := make(chan struct{})
Expand Down Expand Up @@ -237,23 +290,24 @@ func (t *TxSender) Start(ctx context.Context) chan struct{} {
return done
}

func (t *TxSender) markInflight(txn *Transaction) bool {
func (t *TxSender) markInflight(txn *Transaction) (bool, <-chan struct{}) {
t.inflightMu.Lock()
defer t.inflightMu.Unlock()

if _, ok := t.inflightTxns[txn.Hash()]; ok {
t.logger.Debug("Transaction already in flight, skipping", "hash", txn.Hash().Hex())
return false
return false, nil
}
if _, ok := t.inflightAccount[txn.Sender]; ok {
t.logger.Debug("Transaction sender already has an inflight transaction, skipping", "sender", txn.Sender.Hex())
t.triggerSender() // Trigger to reprocess later
return false
return false, nil
}

t.inflightTxns[txn.Hash()] = struct{}{}
cancel := make(chan struct{})
t.inflightTxns[txn.Hash()] = cancel
t.inflightAccount[txn.Sender] = struct{}{}
return true
return true, cancel
}

func (t *TxSender) markCompleted(txn *Transaction) {
Expand Down Expand Up @@ -284,14 +338,15 @@ func (t *TxSender) processQueuedTransactions(ctx context.Context) {
case t.workerPool <- struct{}{}:
t.eg.Go(func() error {
defer func() { <-t.workerPool }()
if !t.markInflight(txn) {
canExecute, cancel := t.markInflight(txn)
if !canExecute {
// Transaction is already being processed or sender has an inflight transaction
return nil
}
defer t.markCompleted(txn)

t.logger.Info("Processing transaction", "sender", txn.Sender.Hex(), "type", txn.Type)
if err := t.processTransaction(ctx, txn); err != nil {
if err := t.processTransaction(ctx, txn, cancel); err != nil {
t.logger.Error("Failed to process transaction", "sender", txn.Sender.Hex(), "error", err)
txn.Status = TxStatusFailed
txn.Details = err.Error()
Expand All @@ -303,7 +358,7 @@ func (t *TxSender) processQueuedTransactions(ctx context.Context) {
}
}

func (t *TxSender) processTransaction(ctx context.Context, txn *Transaction) error {
func (t *TxSender) processTransaction(ctx context.Context, txn *Transaction, cancel <-chan struct{}) error {
var (
result bidResult
err error
Expand All @@ -313,6 +368,8 @@ BID_LOOP:
select {
case <-ctx.Done():
return ctx.Err()
case <-cancel:
return ErrTransactionCancelled
default:
}

Expand All @@ -330,6 +387,8 @@ BID_LOOP:
return ctx.Err()
case <-time.After(retryErr.retryAfter):
// Wait for the specified retry duration before retrying
case <-cancel:
return ErrTransactionCancelled
}
continue
}
Expand Down Expand Up @@ -457,7 +516,10 @@ func (t *TxSender) sendBid(

start := time.Now()

prices, err := t.pricer.EstimatePrice(ctx)
cctx, cancel := context.WithTimeout(ctx, bidTimeout)
defer cancel()

prices, err := t.pricer.EstimatePrice(cctx)
if err != nil {
t.logger.Error("Failed to estimate transaction price", "error", err)
return bidResult{}, &errRetry{
Expand Down Expand Up @@ -516,9 +578,6 @@ func (t *TxSender) sendBid(
slashAmount = new(big.Int).Set(txn.Value())
}

cctx, cancel := context.WithTimeout(ctx, bidTimeout)
defer cancel()

bidC, err := t.bidder.Bid(
cctx,
cost,
Expand Down
89 changes: 89 additions & 0 deletions tools/preconf-rpc/sender/sender_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ type mockStore struct {
queued map[common.Address][]*sender.Transaction
nonce map[common.Address]uint64
balances map[common.Address]*big.Int
byHash map[common.Hash]*sender.Transaction
preconfirmedTxns chan result
}

Expand All @@ -37,6 +38,7 @@ func newMockStore() *mockStore {
nonce: make(map[common.Address]uint64),
balances: make(map[common.Address]*big.Int),
preconfirmedTxns: make(chan result, 10),
byHash: make(map[common.Hash]*sender.Transaction),
}
}

Expand Down Expand Up @@ -136,9 +138,25 @@ func (m *mockStore) StoreTransaction(
break
}
}
m.byHash[txn.Hash()] = txn
return nil
}

func (m *mockStore) GetTransactionByHash(
ctx context.Context,
hash common.Hash,
) (*sender.Transaction, error) {
m.mu.Lock()
defer m.mu.Unlock()

txn, exists := m.byHash[hash]
if !exists {
return nil, errors.New("transaction not found")
}

return txn, nil
}

type bidOp struct {
bidAmount *big.Int
slashAmount *big.Int
Expand Down Expand Up @@ -527,3 +545,74 @@ func TestSender(t *testing.T) {
cancel()
<-done
}

func TestCancelTransaction(t *testing.T) {
t.Parallel()

st := newMockStore()
testPricer := &mockPricer{
out: make(chan *pricer.BlockPrices, 10),
errOut: make(chan error, 1),
}
bidder := &mockBidder{
optinEstimate: make(chan int64),
in: make(chan bidOp, 10),
out: make(chan chan optinbidder.BidStatus, 10),
}
blockTracker := &mockBlockTracker{
in: make(chan op, 10),
out: make(chan bool, 10),
}

sndr, err := sender.NewTxSender(
st,
bidder,
testPricer,
blockTracker,
&mockTransferer{},
big.NewInt(1), // Settlement chain ID
util.NewTestLogger(os.Stdout),
)
if err != nil {
t.Fatalf("failed to create sender: %v", err)
}

ctx, cancel := context.WithCancel(context.Background())

done := sndr.Start(ctx)

tx1 := &sender.Transaction{
Transaction: types.NewTransaction(
1,
common.HexToAddress("0x1234567890123456789012345678901234567890"),
big.NewInt(100),
21000,
big.NewInt(1),
nil,
),
Sender: common.HexToAddress("0x1234567890123456789012345678901234567890"),
Type: sender.TxTypeRegular,
Raw: "0x1234567890123456789012345678901234567890",
}

if err := st.AddBalance(ctx, tx1.Sender, big.NewInt(5e18)); err != nil {
t.Fatalf("failed to add balance: %v", err)
}

if err := sndr.Enqueue(ctx, tx1); err != nil {
t.Fatalf("failed to enqueue transaction: %v", err)
}

bidder.optinEstimate <- 18

cancelled, err := sndr.CancelTransaction(ctx, tx1.Hash())
if err != nil {
t.Fatalf("failed to cancel transaction: %v", err)
}
if !cancelled {
t.Fatal("expected transaction to be cancelled, but it was not")
}

cancel()
<-done
}
Loading