59 changes: 40 additions & 19 deletions cmd/src/snapshot_upload.go
@@ -36,6 +36,7 @@ BUCKET
 	flagSet := flag.NewFlagSet("upload", flag.ExitOnError)
 	bucketName := flagSet.String("bucket", "", "destination Cloud Storage bucket name")
 	credentialsPath := flagSet.String("credentials", "", "JSON credentials file for Google Cloud service account")
+	trimExtensions := flagSet.Bool("trim-extensions", true, "trim EXTENSION statements from database dumps for import to Google Cloud SQL")
 
 	snapshotCommands = append(snapshotCommands, &command{
 		flagSet: flagSet,
@@ -59,8 +60,9 @@ BUCKET
 	}
 
 	type upload struct {
-		file *os.File
-		stat os.FileInfo
+		file           *os.File
+		stat           os.FileInfo
+		trimExtensions bool
 	}
 	var (
 		uploads []upload // index aligned with progressBars
@@ -76,8 +78,9 @@ BUCKET
 			return errors.Wrap(err, "get file size")
 		}
 		uploads = append(uploads, upload{
-			file: f,
-			stat: stat,
+			file:           f,
+			stat:           stat,
+			trimExtensions: false, // not a database dump
 		})
 		progressBars = append(progressBars, output.ProgressBar{
 			Label: stat.Name(),
@@ -95,8 +98,9 @@ BUCKET
 			return errors.Wrap(err, "get file size")
 		}
 		uploads = append(uploads, upload{
-			file: f,
-			stat: stat,
+			file:           f,
+			stat:           stat,
+			trimExtensions: *trimExtensions,
 		})
 		progressBars = append(progressBars, output.ProgressBar{
 			Label: stat.Name(),
@@ -116,7 +120,7 @@ BUCKET
 		g.Go(func(ctx context.Context) error {
 			progressFn := func(p int64) { progress.SetValue(i, float64(p)) }
 
-			if err := copyToBucket(ctx, u.file, u.stat, bucket, progressFn); err != nil {
+			if err := copyDumpToBucket(ctx, u.file, u.stat, bucket, progressFn, u.trimExtensions); err != nil {
 				return errors.Wrap(err, u.stat.Name())
 			}
 
@@ -139,26 +143,43 @@ BUCKET
 		})
 	}
 
-func copyToBucket(ctx context.Context, src io.Reader, stat fs.FileInfo, dst *storage.BucketHandle, progressFn func(int64)) error {
-	writer := dst.Object(stat.Name()).NewWriter(ctx)
-	writer.ProgressFunc = progressFn
-	defer writer.Close()
+func copyDumpToBucket(ctx context.Context, src io.ReadSeeker, stat fs.FileInfo, dst *storage.BucketHandle, progressFn func(int64), trimExtensions bool) error {
+	// Set up object to write to
+	object := dst.Object(stat.Name()).NewWriter(ctx)
+	object.ProgressFunc = progressFn
+	defer object.Close()
 
+	// To assert against actual file size
+	var totalWritten int64
+
+	// Do a partial copy that trims out unwanted statements
+	if trimExtensions {
+		written, err := pgdump.PartialCopyWithoutExtensions(object, src, progressFn)
+		if err != nil {
+			return errors.Wrap(err, "trim extensions and upload")
+		}
+		totalWritten += written
+	}
+
 	// io.Copy is the best way to copy from a reader to writer in Go, and storage.Writer
-	// has its own chunking mechanisms internally.
-	written, err := io.Copy(writer, src)
+	// has its own chunking mechanisms internally. io.Reader is stateful, so this copy
+	// simply continues from where PartialCopyWithoutExtensions left off.
+	written, err := io.Copy(object, src)
 	if err != nil {
-		return err
+		return errors.Wrap(err, "upload")
 	}
+	totalWritten += written
 
-	// Progress is not called on completion, so we call it manually after io.Copy is done
+	// Progress is not called on completion of io.Copy, so we call it manually after to
+	// update our pretty progress bars.
 	progressFn(written)
 
-	// Validate we have sent all data
+	// Validate we have sent all data. PartialCopyWithoutExtensions may add some bytes,
+	// so the check is not a strict equality.
 	size := stat.Size()
-	if written != size {
-		return errors.Newf("expected to write %d bytes, but actually wrote %d bytes",
-			size, written)
+	if totalWritten < size {
+		return errors.Newf("expected to write %d bytes, but actually wrote %d bytes (short by %d bytes)",
+			size, totalWritten, size-totalWritten)
 	}
 
 	return nil
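For context, a minimal sketch (not part of this diff) of the filter-then-stream handoff that copyDumpToBucket performs. The module import path is an assumption, and internal/pgdump is only importable from inside this repository, so treat it as illustrative:

package main

import (
	"bytes"
	"fmt"
	"io"
	"strings"

	// assumed module path; internal packages resolve only within src-cli itself
	"github.com/sourcegraph/src-cli/internal/pgdump"
)

func main() {
	// A toy dump: an extension comment, then the first table definition.
	src := strings.NewReader(`COMMENT ON EXTENSION pg_trgm
CREATE TABLE repos (
...
)
`)
	var dst bytes.Buffer

	// Phase 1: line-by-line filtering until the first CREATE TABLE.
	if _, err := pgdump.PartialCopyWithoutExtensions(&dst, src, func(int64) {}); err != nil {
		panic(err)
	}

	// Phase 2: src was re-seeked to the filter's stopping point, so a plain
	// io.Copy streams the remainder untouched.
	if _, err := io.Copy(&dst, src); err != nil {
		panic(err)
	}

	fmt.Print(dst.String()) // the COMMENT ON EXTENSION line is now "-- " prefixed
}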
67 changes: 67 additions & 0 deletions internal/pgdump/extensions.go
@@ -0,0 +1,67 @@
package pgdump

import (
	"bufio"
	"bytes"
	"io"

	"github.com/sourcegraph/sourcegraph/lib/errors"
)

// PartialCopyWithoutExtensions will perform a partial copy of a SQL database dump from
// src to dst while commenting out EXTENSION-related statements (currently COMMENT ON
// EXTENSION lines). When it determines there are no more EXTENSION-related statements,
// it will return, resetting src to the position of the last contents written to dst.
//
// This is needed for import to Google Cloud SQL, which does not accept many EXTENSION
// statements. For more details, see https://cloud.google.com/sql/docs/postgres/import-export/import-export-dmp
//
// Filtering requires reading entire lines into memory - this can be a very expensive
// operation, so when filtering is complete the more efficient io.Copy should be used
// to perform the remainder of the copy from src to dst.
func PartialCopyWithoutExtensions(dst io.Writer, src io.ReadSeeker, progressFn func(int64)) (int64, error) {
	var (
		reader = bufio.NewReader(src)
		// position we have consumed up to - tracked separately, because bufio.Reader
		// may have read ahead on src. This allows us to reset src later.
		consumed int64
		// number of bytes we have actually written to dst - it should always be returned.
		written int64
		// set to true when we have done all our filtering
		noMoreExtensions bool
	)

	for !noMoreExtensions {
		// Read up to a line, keeping track of our position in src
		line, err := reader.ReadBytes('\n')
		consumed += int64(len(line))
		if err != nil {
			return written, err
		}

		// Once we start seeing table creations, we are definitely done with extensions,
		// so we can hand off the rest to the superior io.Copy implementation.
Comment on lines +42 to +43

Contributor: Are we sure about this? I'm lacking some context on whether this is used by end users, but this is a large assumption. The official Postgres docs use tools like split for large dumps: https://www.postgresql.org/docs/12/backup-dump.html#BACKUP-DUMP-LARGE

Member (author): If you're referring to the assumption that extensions always come before table creations, then yes - extensions are typically a prerequisite.

As for handing off to io.Copy, yes - the io.Copy implementation is more robust and efficient, and the GCS bucket handler will handle chunking for us. I don't think Cloud SQL will be happy with split dumps, and to piece them together we would need to download them somewhere, which I'd like to avoid (right now the process can just have us import directly from GCS).

Member (author): I think we can revisit things like split if size becomes a blocker, but the streaming nature of io.Copy should mitigate the effects of size.
if bytes.HasPrefix(line, []byte("CREATE TABLE")) {
// we are done with extensions
noMoreExtensions = true
} else if bytes.HasPrefix(line, []byte("COMMENT ON EXTENSION")) {
// comment out this line
line = append([]byte("-- "), line...)
}

// Write this line and update our progress before returning on error
lineWritten, err := dst.Write(line)
written += int64(lineWritten)
progressFn(written)
if err != nil {
return written, err
}
}

// No more extensions - reset src to the last actual consumed position
_, err := src.Seek(consumed, io.SeekStart)
if err != nil {
return written, errors.Wrap(err, "reset src position")
}
return written, nil
}
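A quick illustration (again, not from the diff) of why consumed is tracked by hand: bufio.Reader buffers ahead of what ReadBytes returns, so the underlying reader's offset overshoots and must be re-seeked before handing off to io.Copy:

package main

import (
	"bufio"
	"fmt"
	"io"
	"strings"
)

func main() {
	src := strings.NewReader("line one\nline two\nline three\n")
	reader := bufio.NewReader(src)

	line, _ := reader.ReadBytes('\n') // logically consumes only "line one\n"...
	fmt.Println(len(line))            // 9

	pos, _ := src.Seek(0, io.SeekCurrent) // ...but bufio has buffered all of src
	fmt.Println(pos)                      // 29

	// Seeking back to the consumed count lets io.Copy resume after "line one\n".
	src.Seek(int64(len(line)), io.SeekStart)
	rest, _ := io.Copy(io.Discard, src)
	fmt.Println(rest) // 20
}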
70 changes: 70 additions & 0 deletions internal/pgdump/extensions_test.go
@@ -0,0 +1,70 @@
package pgdump

import (
	"bytes"
	"io"
	"os"
	"path/filepath"
	"runtime"
	"testing"

	"github.com/hexops/autogold"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestPartialCopyWithoutExtensions(t *testing.T) {
	if runtime.GOOS == "windows" {
		t.Skip("Test doesn't work on Windows because of weirdness with t.TempDir() handling")
	}

	// Create test data in a temp file to exercise a file-backed io.ReadSeeker
	src, err := os.Create(filepath.Join(t.TempDir(), t.Name()))
	require.NoError(t, err)
	_, err = src.WriteString(`-- Some comment

CREATE EXTENSION foobar

COMMENT ON EXTENSION barbaz

CREATE TYPE asdf

CREATE TABLE robert (
...
)

CREATE TABLE bobhead (
...
)`)
	require.NoError(t, err)
	_, err = src.Seek(0, io.SeekStart)
	require.NoError(t, err)

	// Set up target to assert against
	var dst bytes.Buffer

	// Perform partial copy
	_, err = PartialCopyWithoutExtensions(&dst, src, func(i int64) {})
	assert.NoError(t, err)

	// Copy rest of contents
	_, err = io.Copy(&dst, src)
	assert.NoError(t, err)

	// Assert contents (update with -update)
	autogold.Want("partial-copy-without-extensions", `-- Some comment

CREATE EXTENSION foobar

-- COMMENT ON EXTENSION barbaz

CREATE TYPE asdf

CREATE TABLE robert (
...
)

CREATE TABLE bobhead (
...
)`).Equal(t, dst.String())
}