From eb2a856ce5272408e4534e67bc11840097d29af2 Mon Sep 17 00:00:00 2001 From: chenquan Date: Mon, 12 Sep 2022 11:34:21 +0800 Subject: [PATCH] feat: support tx context --- conn.go | 2 +- stmt.go | 1 - tx.go | 15 +++++++++++++-- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/conn.go b/conn.go index 96cf567..88388e6 100644 --- a/conn.go +++ b/conn.go @@ -120,7 +120,7 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (dd driver.Tx return nil, err } - return &tx{Tx: t, TxHook: c.ConnHook.(TxHook)}, nil + return &tx{Tx: t, TxHook: c.ConnHook.(TxHook), txContext: ctx}, nil } func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) { diff --git a/stmt.go b/stmt.go index ffc55ea..c672ca2 100644 --- a/stmt.go +++ b/stmt.go @@ -55,7 +55,6 @@ func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (rows } func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (r driver.Result, err error) { - ctx, args, err = s.BeforeStmtExecContext(ctx, s.query, args, nil) defer func() { _, r, err = s.AfterStmtExecContext(ctx, s.query, args, r, err) diff --git a/tx.go b/tx.go index bd1687a..f6b94d8 100644 --- a/tx.go +++ b/tx.go @@ -5,13 +5,24 @@ import ( "database/sql/driver" ) +type txContextKey struct{} type tx struct { driver.Tx TxHook + txContext context.Context +} + +func TxContextFromContext(ctx context.Context) context.Context { + value := ctx.Value(txContextKey{}) + if value != nil { + return value.(context.Context) + } + + return nil } func (t *tx) Commit() (err error) { - ctx := context.Background() + ctx := context.WithValue(context.Background(), txContextKey{}, t.txContext) ctx, err = t.BeforeCommit(ctx, nil) defer func() { _, err = t.AfterCommit(ctx, err) @@ -29,7 +40,7 @@ func (t *tx) Commit() (err error) { } func (t *tx) Rollback() (err error) { - ctx := context.Background() + ctx := context.WithValue(context.Background(), txContextKey{}, t.txContext) ctx, err = t.BeforeRollback(ctx, nil) defer func() { _, err = t.AfterRollback(ctx, err)