From 3638384acf88af0d72738f335db4679e745b7345 Mon Sep 17 00:00:00 2001 From: chenquan Date: Tue, 4 Oct 2022 23:46:06 +0800 Subject: [PATCH] feat(conn): add connection in ctx --- conn.go | 39 ++++++++++++++++++++++++++++++++++----- conn_test.go | 10 ++++++++++ stmt.go | 27 +++++++++++++++++++++++++-- 3 files changed, 69 insertions(+), 7 deletions(-) diff --git a/conn.go b/conn.go index 952d2eb..4c737c5 100644 --- a/conn.go +++ b/conn.go @@ -13,13 +13,17 @@ var ( _ driver.ExecerContext = (*conn)(nil) ) -type conn struct { - driver.Conn - ConnHook -} +type ( + conn struct { + driver.Conn + ConnHook + } + connKey struct{} +) func (c *conn) Close() (err error) { - ctx, err := c.BeforeClose(context.Background(), nil) + ctx := c.newConnContext(context.Background()) + ctx, err = c.BeforeClose(ctx, nil) defer func() { _, err = c.AfterClose(ctx, err) }() @@ -31,6 +35,7 @@ func (c *conn) Close() (err error) { } func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (result driver.Result, err error) { + ctx = c.newConnContext(ctx) ctx, query, args, err = c.BeforeExecContext(ctx, query, args, nil) defer func() { _, result, err = c.AfterExecContext(ctx, query, args, result, err) @@ -60,6 +65,7 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name } func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (rows driver.Rows, err error) { + ctx = c.newConnContext(ctx) ctx, query, args, err = c.BeforeQueryContext(ctx, query, args, nil) defer func() { _, rows, err = c.AfterQueryContext(ctx, query, args, rows, err) @@ -90,6 +96,7 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam } func (c *conn) PrepareContext(ctx context.Context, query string) (s driver.Stmt, err error) { + ctx = c.newConnContext(ctx) ctx, query, err = c.BeforePrepareContext(ctx, query, nil) defer func() { _, s, err = c.AfterPrepareContext(ctx, query, s, err) @@ -113,6 +120,7 @@ func (c *conn) PrepareContext(ctx context.Context, query string) (s driver.Stmt, } func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (dd driver.Tx, err error) { + ctx = c.newConnContext(ctx) ctx, opts, err = c.BeforeBeginTx(ctx, opts, nil) defer func() { _, dd, err = c.AfterBeginTx(ctx, opts, dd, err) @@ -147,3 +155,24 @@ func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) { return dargs, nil } + +func ConnFromContext(ctx context.Context) interface { + driver.Conn + driver.ConnPrepareContext + driver.ConnBeginTx +} { + value := ctx.Value(connKey{}) + if value == nil { + return nil + } + + return value.(interface { + driver.Conn + driver.ConnPrepareContext + driver.ConnBeginTx + }) +} + +func (c *conn) newConnContext(ctx context.Context) context.Context { + return context.WithValue(ctx, connKey{}, c) +} diff --git a/conn_test.go b/conn_test.go index d34bf04..5c994ac 100644 --- a/conn_test.go +++ b/conn_test.go @@ -83,6 +83,16 @@ func (m *mockConnBeginTx) BeginTx(_ context.Context, _ driver.TxOptions) (driver return &mockTx{}, nil } +var _ driver.ConnPrepareContext = (*mockConnPrepareContext)(nil) + +type mockConnPrepareContext struct { + *mockConn +} + +func (m *mockConnPrepareContext) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { + return &mockStmt{}, nil +} + // ----------------- func createMockConn() (*conn, *mockHook) { diff --git a/stmt.go b/stmt.go index 2c24919..9069864 100644 --- a/stmt.go +++ b/stmt.go @@ -19,6 +19,7 @@ type ( prepareContext context.Context } prepareContextKey struct{} + stmtKey struct{} ) func PrepareContextFromContext(ctx context.Context) context.Context { @@ -30,11 +31,28 @@ func PrepareContextFromContext(ctx context.Context) context.Context { return nil } +func StmtFromContext(ctx context.Context) interface { + driver.Stmt + driver.StmtExecContext + driver.StmtQueryContext +} { + value := ctx.Value(stmtKey{}) + if value != nil { + return nil + } + + return value.(interface { + driver.Stmt + driver.StmtExecContext + driver.StmtQueryContext + }) +} + // ----------------- func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (rows driver.Rows, err error) { query := s.query - ctx = context.WithValue(ctx, prepareContextKey{}, s.prepareContext) + ctx = s.newStmtContext(ctx) ctx, args, err = s.BeforeStmtQueryContext(ctx, query, args, nil) defer func() { _, rows, err = s.AfterStmtQueryContext(ctx, query, args, rows, err) @@ -59,7 +77,7 @@ func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (rows func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (r driver.Result, err error) { query := s.query - ctx = context.WithValue(ctx, prepareContextKey{}, s.prepareContext) + ctx = s.newStmtContext(ctx) ctx, args, err = s.BeforeStmtExecContext(ctx, query, args, nil) defer func() { _, r, err = s.AfterStmtExecContext(ctx, query, args, r, err) @@ -80,3 +98,8 @@ func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (r dri return s.Exec(value) } } + +func (s *stmt) newStmtContext(ctx context.Context) context.Context { + ctx = context.WithValue(ctx, prepareContextKey{}, s.prepareContext) + return context.WithValue(ctx, stmtKey{}, s) +}