Commit 173d2ad

snapshot: trim extension statements from pgdump when uploading (#893)
When uploading a snapshot, trim extension statements that are unsupported by Google Cloud SQL. We do this on upload so that the customer's copy of the dump remains unmodified; if something goes wrong, we still have the original to work with. Trimming can be disabled with -trim-extensions=false.
1 parent 7462d79 commit 173d2ad

File tree

3 files changed (+177 lines, -19 lines):

cmd/src/snapshot_upload.go
internal/pgdump/extensions.go
internal/pgdump/extensions_test.go

cmd/src/snapshot_upload.go

Lines changed: 40 additions & 19 deletions
@@ -36,6 +36,7 @@ BUCKET
 	flagSet := flag.NewFlagSet("upload", flag.ExitOnError)
 	bucketName := flagSet.String("bucket", "", "destination Cloud Storage bucket name")
 	credentialsPath := flagSet.String("credentials", "", "JSON credentials file for Google Cloud service account")
+	trimExtensions := flagSet.Bool("trim-extensions", true, "trim EXTENSION statements from database dumps for import to Google Cloud SQL")
 
 	snapshotCommands = append(snapshotCommands, &command{
 		flagSet: flagSet,
@@ -59,8 +60,9 @@ BUCKET
 	}
 
 	type upload struct {
-		file *os.File
-		stat os.FileInfo
+		file           *os.File
+		stat           os.FileInfo
+		trimExtensions bool
 	}
 	var (
 		uploads []upload // index aligned with progressBars
@@ -76,8 +78,9 @@ BUCKET
 		return errors.Wrap(err, "get file size")
 	}
 	uploads = append(uploads, upload{
-		file: f,
-		stat: stat,
+		file:           f,
+		stat:           stat,
+		trimExtensions: false, // not a database dump
 	})
 	progressBars = append(progressBars, output.ProgressBar{
 		Label: stat.Name(),
@@ -95,8 +98,9 @@ BUCKET
 		return errors.Wrap(err, "get file size")
 	}
 	uploads = append(uploads, upload{
-		file: f,
-		stat: stat,
+		file:           f,
+		stat:           stat,
+		trimExtensions: *trimExtensions,
 	})
 	progressBars = append(progressBars, output.ProgressBar{
 		Label: stat.Name(),
@@ -116,7 +120,7 @@ BUCKET
 	g.Go(func(ctx context.Context) error {
 		progressFn := func(p int64) { progress.SetValue(i, float64(p)) }
 
-		if err := copyToBucket(ctx, u.file, u.stat, bucket, progressFn); err != nil {
+		if err := copyDumpToBucket(ctx, u.file, u.stat, bucket, progressFn, u.trimExtensions); err != nil {
 			return errors.Wrap(err, u.stat.Name())
 		}
 
@@ -139,26 +143,43 @@ BUCKET
 	})
 }
 
-func copyToBucket(ctx context.Context, src io.Reader, stat fs.FileInfo, dst *storage.BucketHandle, progressFn func(int64)) error {
-	writer := dst.Object(stat.Name()).NewWriter(ctx)
-	writer.ProgressFunc = progressFn
-	defer writer.Close()
+func copyDumpToBucket(ctx context.Context, src io.ReadSeeker, stat fs.FileInfo, dst *storage.BucketHandle, progressFn func(int64), trimExtensions bool) error {
+	// Set up object to write to
+	object := dst.Object(stat.Name()).NewWriter(ctx)
+	object.ProgressFunc = progressFn
+	defer object.Close()
+
+	// To assert against actual file size
+	var totalWritten int64
+
+	// Do a partial copy that trims out unwanted statements
+	if trimExtensions {
+		written, err := pgdump.PartialCopyWithoutExtensions(object, src, progressFn)
+		if err != nil {
+			return errors.Wrap(err, "trim extensions and upload")
+		}
+		totalWritten += written
+	}
 
 	// io.Copy is the best way to copy from a reader to writer in Go, and storage.Writer
-	// has its own chunking mechanisms internally.
-	written, err := io.Copy(writer, src)
+	// has its own chunking mechanisms internally. io.Reader is stateful, so this copy
+	// will just continue from where we left off if we use copyAndTrimExtensions.
+	written, err := io.Copy(object, src)
 	if err != nil {
-		return err
+		return errors.Wrap(err, "upload")
 	}
+	totalWritten += written
 
-	// Progress is not called on completion, so we call it manually after io.Copy is done
+	// Progress is not called on completion of io.Copy, so we call it manually after to
+	// update our pretty progress bars.
	progressFn(written)
 
-	// Validate we have sent all data
+	// Validate we have sent all data. copyAndTrimExtensions may add some bytes, so the
+	// check is not a strict equality.
 	size := stat.Size()
-	if written != size {
-		return errors.Newf("expected to write %d bytes, but actually wrote %d bytes",
-			size, written)
+	if totalWritten < size {
+		return errors.Newf("expected to write %d bytes, but actually wrote %d bytes (diff: %d bytes)",
+			size, totalWritten, totalWritten-size)
 	}
 
 	return nil

internal/pgdump/extensions.go

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
+package pgdump
+
+import (
+	"bufio"
+	"bytes"
+	"io"
+
+	"github.com/sourcegraph/sourcegraph/lib/errors"
+)
+
+// PartialCopyWithoutExtensions will perform a partial copy of a SQL database dump from
+// src to dst while commenting out EXTENSIONs-related statements. When it determines there
+// are no more EXTENSIONs-related statements, it will return, resetting src to the position
+// of the last contents written to dst.
+//
+// This is needed for import to Google Cloud Storage, which does not like many EXTENSION
+// statements. For more details, see https://cloud.google.com/sql/docs/postgres/import-export/import-export-dmp
+//
+// Filtering requires reading entire lines into memory - this can be a very expensive
+// operation, so when filtering is complete the more efficient io.Copy should be used
+// to perform the remainder of the copy from src to dst.
+func PartialCopyWithoutExtensions(dst io.Writer, src io.ReadSeeker, progressFn func(int64)) (int64, error) {
+	var (
+		reader = bufio.NewReader(src)
+		// position we have consumed up to, track separately because bufio.Reader may have
+		// read ahead on src. This allows us to reset src later.
+		consumed int64
+		// number of bytes we have actually written to dst - it should always be returned.
+		written int64
+		// set to true when we have done all our filtering
+		noMoreExtensions bool
+	)
+
+	for !noMoreExtensions {
+		// Read up to a line, keeping track of our position in src
+		line, err := reader.ReadBytes('\n')
+		consumed += int64(len(line))
+		if err != nil {
+			return written, err
+		}
+
+		// Once we start seeing table creations, we are definitely done with extensions,
+		// so we can hand off the rest to the superior io.Copy implementation.
+		if bytes.HasPrefix(line, []byte("CREATE TABLE")) {
+			// we are done with extensions
+			noMoreExtensions = true
+		} else if bytes.HasPrefix(line, []byte("COMMENT ON EXTENSION")) {
+			// comment out this line
+			line = append([]byte("-- "), line...)
+		}
+
+		// Write this line and update our progress before returning on error
+		lineWritten, err := dst.Write(line)
+		written += int64(lineWritten)
+		progressFn(written)
+		if err != nil {
+			return written, err
+		}
+	}
+
+	// No more extensions - reset src to the last actual consumed position
+	_, err := src.Seek(consumed, io.SeekStart)
+	if err != nil {
+		return written, errors.Wrap(err, "reset src position")
+	}
+	return written, nil
+}
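
For illustration only (not part of this commit), here is a minimal sketch of how a caller outside the Cloud Storage upload path might drive PartialCopyWithoutExtensions and then finish with io.Copy, mirroring copyDumpToBucket above. The file names and the src-cli import path are assumptions made for the sketch.

package main

import (
	"io"
	"log"
	"os"

	"github.com/sourcegraph/src-cli/internal/pgdump" // assumed module path for this repository
)

func main() {
	// The dump must be an io.ReadSeeker: Seek is what lets PartialCopyWithoutExtensions
	// rewind src to the last position it actually consumed before handing off to io.Copy.
	src, err := os.Open("sourcegraph_db.dump") // hypothetical pg_dump output
	if err != nil {
		log.Fatal(err)
	}
	defer src.Close()

	dst, err := os.Create("sourcegraph_db.trimmed.dump") // hypothetical destination
	if err != nil {
		log.Fatal(err)
	}
	defer dst.Close()

	// Phase 1: line-by-line copy that comments out COMMENT ON EXTENSION statements,
	// stopping once the first CREATE TABLE line is seen.
	trimmed, err := pgdump.PartialCopyWithoutExtensions(dst, src, func(int64) {})
	if err != nil {
		log.Fatal(err)
	}

	// Phase 2: src has been reset to where phase 1 stopped, so a plain io.Copy
	// streams the remainder without per-line buffering.
	rest, err := io.Copy(dst, src)
	if err != nil {
		log.Fatal(err)
	}
	log.Printf("wrote %d bytes (%d written during trimming)", trimmed+rest, trimmed)
}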

internal/pgdump/extensions_test.go

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
+package pgdump
+
+import (
+	"bytes"
+	"io"
+	"os"
+	"path/filepath"
+	"runtime"
+	"testing"
+
+	"github.com/hexops/autogold"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestPartialCopyWithoutExtensions(t *testing.T) {
+	if runtime.GOOS == "windows" {
+		t.Skip("Test doesn't work on Windows of weirdness with t.TempDir() handling")
+	}
+
+	// Create test data - there is no stdlib in-memory io.ReadSeeker implementation
+	src, err := os.Create(filepath.Join(t.TempDir(), t.Name()))
+	require.NoError(t, err)
+	_, err = src.WriteString(`-- Some comment
+
+CREATE EXTENSION foobar
+
+COMMENT ON EXTENSION barbaz
+
+CREATE TYPE asdf
+
+CREATE TABLE robert (
+...
+)
+
+CREATE TABLE bobhead (
+...
+)`)
+	require.NoError(t, err)
+	_, err = src.Seek(0, io.SeekStart)
+	require.NoError(t, err)
+
+	// Set up target to assert against
+	var dst bytes.Buffer
+
+	// Perform partial copy
+	_, err = PartialCopyWithoutExtensions(&dst, src, func(i int64) {})
+	assert.NoError(t, err)
+
+	// Copy rest of contents
+	_, err = io.Copy(&dst, src)
+	assert.NoError(t, err)
+
+	// Assert contents (update with -update)
+	autogold.Want("partial-copy-without-extensions", `-- Some comment
+
+CREATE EXTENSION foobar
+
+-- COMMENT ON EXTENSION barbaz
+
+CREATE TYPE asdf
+
+CREATE TABLE robert (
+...
+)
+
+CREATE TABLE bobhead (
+...
+)`).Equal(t, dst.String())
+}
