Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 34 additions & 5 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,17 @@ var (
_ driver.ExecerContext = (*conn)(nil)
)

type conn struct {
driver.Conn
ConnHook
}
type (
conn struct {
driver.Conn
ConnHook
}
connKey struct{}
)

func (c *conn) Close() (err error) {
ctx, err := c.BeforeClose(context.Background(), nil)
ctx := c.newConnContext(context.Background())
ctx, err = c.BeforeClose(ctx, nil)
defer func() {
_, err = c.AfterClose(ctx, err)
}()
Expand All @@ -31,6 +35,7 @@ func (c *conn) Close() (err error) {
}

func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (result driver.Result, err error) {
ctx = c.newConnContext(ctx)
ctx, query, args, err = c.BeforeExecContext(ctx, query, args, nil)
defer func() {
_, result, err = c.AfterExecContext(ctx, query, args, result, err)
Expand Down Expand Up @@ -60,6 +65,7 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
}

func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (rows driver.Rows, err error) {
ctx = c.newConnContext(ctx)
ctx, query, args, err = c.BeforeQueryContext(ctx, query, args, nil)
defer func() {
_, rows, err = c.AfterQueryContext(ctx, query, args, rows, err)
Expand Down Expand Up @@ -90,6 +96,7 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam
}

func (c *conn) PrepareContext(ctx context.Context, query string) (s driver.Stmt, err error) {
ctx = c.newConnContext(ctx)
ctx, query, err = c.BeforePrepareContext(ctx, query, nil)
defer func() {
_, s, err = c.AfterPrepareContext(ctx, query, s, err)
Expand All @@ -113,6 +120,7 @@ func (c *conn) PrepareContext(ctx context.Context, query string) (s driver.Stmt,
}

func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (dd driver.Tx, err error) {
ctx = c.newConnContext(ctx)
ctx, opts, err = c.BeforeBeginTx(ctx, opts, nil)
defer func() {
_, dd, err = c.AfterBeginTx(ctx, opts, dd, err)
Expand Down Expand Up @@ -147,3 +155,24 @@ func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) {

return dargs, nil
}

func ConnFromContext(ctx context.Context) interface {
driver.Conn
driver.ConnPrepareContext
driver.ConnBeginTx
} {
value := ctx.Value(connKey{})
if value == nil {
return nil
}

return value.(interface {
driver.Conn
driver.ConnPrepareContext
driver.ConnBeginTx
})
}

func (c *conn) newConnContext(ctx context.Context) context.Context {
return context.WithValue(ctx, connKey{}, c)
}
10 changes: 10 additions & 0 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,16 @@ func (m *mockConnBeginTx) BeginTx(_ context.Context, _ driver.TxOptions) (driver
return &mockTx{}, nil
}

var _ driver.ConnPrepareContext = (*mockConnPrepareContext)(nil)

type mockConnPrepareContext struct {
*mockConn
}

func (m *mockConnPrepareContext) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
return &mockStmt{}, nil
}

// -----------------

func createMockConn() (*conn, *mockHook) {
Expand Down
27 changes: 25 additions & 2 deletions stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ type (
prepareContext context.Context
}
prepareContextKey struct{}
stmtKey struct{}
)

func PrepareContextFromContext(ctx context.Context) context.Context {
Expand All @@ -30,11 +31,28 @@ func PrepareContextFromContext(ctx context.Context) context.Context {
return nil
}

func StmtFromContext(ctx context.Context) interface {
driver.Stmt
driver.StmtExecContext
driver.StmtQueryContext
} {
value := ctx.Value(stmtKey{})
if value != nil {
return nil
}

return value.(interface {
driver.Stmt
driver.StmtExecContext
driver.StmtQueryContext
})
}

// -----------------

func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (rows driver.Rows, err error) {
query := s.query
ctx = context.WithValue(ctx, prepareContextKey{}, s.prepareContext)
ctx = s.newStmtContext(ctx)
ctx, args, err = s.BeforeStmtQueryContext(ctx, query, args, nil)
defer func() {
_, rows, err = s.AfterStmtQueryContext(ctx, query, args, rows, err)
Expand All @@ -59,7 +77,7 @@ func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (rows

func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (r driver.Result, err error) {
query := s.query
ctx = context.WithValue(ctx, prepareContextKey{}, s.prepareContext)
ctx = s.newStmtContext(ctx)
ctx, args, err = s.BeforeStmtExecContext(ctx, query, args, nil)
defer func() {
_, r, err = s.AfterStmtExecContext(ctx, query, args, r, err)
Expand All @@ -80,3 +98,8 @@ func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (r dri
return s.Exec(value)
}
}

func (s *stmt) newStmtContext(ctx context.Context) context.Context {
ctx = context.WithValue(ctx, prepareContextKey{}, s.prepareContext)
return context.WithValue(ctx, stmtKey{}, s)
}