Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/howto/tx.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,5 +54,5 @@ if err != nil {
:class: tip
[TX.Commit](https://pkg.go.dev/github.com/canonical/sqlair#TX.Commit),
[TX.Rollback](https://pkg.go.dev/github.com/canonical/sqlair#TX.Rollback),
[sqlair.ErrTXDone](https://pkg.go.dev/github.com/canonical/sqlair#ErrTXDone)
[sql.ErrTXDone](https://pkg.go.dev/database/sql#ErrTxDone)
```
74 changes: 72 additions & 2 deletions package_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,76 @@ func (s *PackageSuite) TestGetAllErrors(c *C) {
}
}

// TestGetAllMutatesExistingSliceValues is a regression test to assert a
// behaviour change made to [Query.GetAll]. Prior to this calling [Query.GetAll]
// with one or more slice values that had pre existing elements len(s) > 0 would
// maintain the existing elements in the slice.
//
// This is a problem for transaction closures that are retried due to failures
// and may have had a successful read into a slice prior to a transaction
// failure.
func (s *PackageSuite) TestGetAllMutatesExistingSliceValues(c *C) {
db, tables := s.personAndAddressDB(c)
defer dropTables(c, db, tables...)

stmt, err := sqlair.Prepare("SELECT &Person.* FROM person", Person{})
c.Assert(err, IsNil)

var dbVals []Person
err = db.Query(nil, stmt).GetAll(&dbVals)
c.Assert(err, IsNil)

initialLen := len(dbVals)
c.Check(initialLen >= 0, Equals, true,
Commentf("expected at least one or more person records"))

err = db.Query(nil, stmt).GetAll(&dbVals)
c.Assert(err, IsNil)

c.Check(len(dbVals), Equals, initialLen)
}

// TestGetAllMutatesMultipleExistingSliceValues is a regression test to assert a
// behaviour change made to [Query.GetAll]. Prior to this calling [Query.GetAll]
// with one or more slice values that had pre existing elements len(s) > 0 would
// maintain the existing elements in the slice.
//
// This is a problem for transaction closures that are retried due to failures
// and may have had a successful read into a slice prior to a transaction
// failure.
func (s *PackageSuite) TestGetAllMutatesMultipleExistingSliceValues(c *C) {
db, tables := s.personAndAddressDB(c)
defer dropTables(c, db, tables...)

stmt, err := sqlair.Prepare(`
SELECT person.* AS &Person.*,
address.* AS &Address.*
FROM person
INNER JOIN address ON person.address_id = address.id
`, Address{}, Person{})
c.Assert(err, IsNil)

var (
addressDBVals []Address
personDBVals []Person
)
err = db.Query(nil, stmt).GetAll(&personDBVals, &addressDBVals)
c.Assert(err, IsNil)

initialAddressLen := len(personDBVals)
c.Check(initialAddressLen >= 0, Equals, true,
Commentf("expected at least one or more address records"))
initialPersonLen := len(personDBVals)
c.Check(initialPersonLen >= 0, Equals, true,
Commentf("expected at least one or more person records"))

err = db.Query(nil, stmt).GetAll(&personDBVals, &addressDBVals)
c.Assert(err, IsNil)

c.Check(len(addressDBVals), Equals, initialAddressLen)
c.Check(len(personDBVals), Equals, initialPersonLen)
}

func (s *PackageSuite) TestRun(c *C) {
db, tables := s.personAndAddressDB(c)
defer dropTables(c, db, tables...)
Expand Down Expand Up @@ -1118,10 +1188,10 @@ func (s *PackageSuite) TestTransactionErrors(c *C) {
// Test error when running query after rollback against the public error variable.
tx, err = db.Begin(ctx, nil)
c.Assert(err, IsNil)

err = tx.Rollback()
c.Assert(err, IsNil)

err = tx.Query(ctx, insertStmt, &derek).Run()
if !errors.Is(err, sql.ErrTxDone) {
c.Errorf("expected %q, got %q", sql.ErrTxDone, err)
Expand Down
6 changes: 6 additions & 0 deletions sqlair.go
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,9 @@ func (o *Outcome) Result() sql.Result {
// A pointer to an empty [Outcome] struct may be provided as the first output
// variable to get information about query execution.
//
// A provided slice that already contains one or more values will have its
// length reset to 0 before scanning into the slice.
//
// [ErrNoRows] will be returned if no rows are found.
func (q *Query) GetAll(sliceArgs ...any) (err error) {
if q.err != nil {
Expand Down Expand Up @@ -348,6 +351,9 @@ func (q *Query) GetAll(sliceArgs ...any) (err error) {
if sliceVal.Kind() != reflect.Slice {
return fmt.Errorf("need pointer to slice, got pointer to %s", sliceVal.Kind())
}
// Set the length of the slice value back to zero maintaining any
// existing capacity.
sliceVal.SetLen(0)
sliceVals = append(sliceVals, sliceVal)
}

Expand Down
Loading