diff --git a/docs/howto/tx.md b/docs/howto/tx.md index df80a2d7..c5315db8 100644 --- a/docs/howto/tx.md +++ b/docs/howto/tx.md @@ -54,5 +54,5 @@ if err != nil { :class: tip [TX.Commit](https://pkg.go.dev/github.com/canonical/sqlair#TX.Commit), [TX.Rollback](https://pkg.go.dev/github.com/canonical/sqlair#TX.Rollback), -[sqlair.ErrTXDone](https://pkg.go.dev/github.com/canonical/sqlair#ErrTXDone) +[sql.ErrTXDone](https://pkg.go.dev/database/sql#ErrTxDone) ``` diff --git a/package_test.go b/package_test.go index 436b634e..752de899 100644 --- a/package_test.go +++ b/package_test.go @@ -727,6 +727,76 @@ func (s *PackageSuite) TestGetAllErrors(c *C) { } } +// TestGetAllMutatesExistingSliceValues is a regression test to assert a +// behaviour change made to [Query.GetAll]. Prior to this calling [Query.GetAll] +// with one or more slice values that had pre existing elements len(s) > 0 would +// maintain the existing elements in the slice. +// +// This is a problem for transaction closures that are retried due to failures +// and may have had a successful read into a slice prior to a transaction +// failure. +func (s *PackageSuite) TestGetAllMutatesExistingSliceValues(c *C) { + db, tables := s.personAndAddressDB(c) + defer dropTables(c, db, tables...) + + stmt, err := sqlair.Prepare("SELECT &Person.* FROM person", Person{}) + c.Assert(err, IsNil) + + var dbVals []Person + err = db.Query(nil, stmt).GetAll(&dbVals) + c.Assert(err, IsNil) + + initialLen := len(dbVals) + c.Check(initialLen >= 0, Equals, true, + Commentf("expected at least one or more person records")) + + err = db.Query(nil, stmt).GetAll(&dbVals) + c.Assert(err, IsNil) + + c.Check(len(dbVals), Equals, initialLen) +} + +// TestGetAllMutatesMultipleExistingSliceValues is a regression test to assert a +// behaviour change made to [Query.GetAll]. Prior to this calling [Query.GetAll] +// with one or more slice values that had pre existing elements len(s) > 0 would +// maintain the existing elements in the slice. +// +// This is a problem for transaction closures that are retried due to failures +// and may have had a successful read into a slice prior to a transaction +// failure. +func (s *PackageSuite) TestGetAllMutatesMultipleExistingSliceValues(c *C) { + db, tables := s.personAndAddressDB(c) + defer dropTables(c, db, tables...) + + stmt, err := sqlair.Prepare(` +SELECT person.* AS &Person.*, + address.* AS &Address.* +FROM person +INNER JOIN address ON person.address_id = address.id +`, Address{}, Person{}) + c.Assert(err, IsNil) + + var ( + addressDBVals []Address + personDBVals []Person + ) + err = db.Query(nil, stmt).GetAll(&personDBVals, &addressDBVals) + c.Assert(err, IsNil) + + initialAddressLen := len(personDBVals) + c.Check(initialAddressLen >= 0, Equals, true, + Commentf("expected at least one or more address records")) + initialPersonLen := len(personDBVals) + c.Check(initialPersonLen >= 0, Equals, true, + Commentf("expected at least one or more person records")) + + err = db.Query(nil, stmt).GetAll(&personDBVals, &addressDBVals) + c.Assert(err, IsNil) + + c.Check(len(addressDBVals), Equals, initialAddressLen) + c.Check(len(personDBVals), Equals, initialPersonLen) +} + func (s *PackageSuite) TestRun(c *C) { db, tables := s.personAndAddressDB(c) defer dropTables(c, db, tables...) @@ -1118,10 +1188,10 @@ func (s *PackageSuite) TestTransactionErrors(c *C) { // Test error when running query after rollback against the public error variable. tx, err = db.Begin(ctx, nil) c.Assert(err, IsNil) - + err = tx.Rollback() c.Assert(err, IsNil) - + err = tx.Query(ctx, insertStmt, &derek).Run() if !errors.Is(err, sql.ErrTxDone) { c.Errorf("expected %q, got %q", sql.ErrTxDone, err) diff --git a/sqlair.go b/sqlair.go index ad23d0b1..74e4b0d9 100644 --- a/sqlair.go +++ b/sqlair.go @@ -317,6 +317,9 @@ func (o *Outcome) Result() sql.Result { // A pointer to an empty [Outcome] struct may be provided as the first output // variable to get information about query execution. // +// A provided slice that already contains one or more values will have its +// length reset to 0 before scanning into the slice. +// // [ErrNoRows] will be returned if no rows are found. func (q *Query) GetAll(sliceArgs ...any) (err error) { if q.err != nil { @@ -348,6 +351,9 @@ func (q *Query) GetAll(sliceArgs ...any) (err error) { if sliceVal.Kind() != reflect.Slice { return fmt.Errorf("need pointer to slice, got pointer to %s", sliceVal.Kind()) } + // Set the length of the slice value back to zero maintaining any + // existing capacity. + sliceVal.SetLen(0) sliceVals = append(sliceVals, sliceVal) }