diff --git a/cmd/src/snapshot.go b/cmd/src/snapshot.go
index 0c604ca088..837ed3c21a 100644
--- a/cmd/src/snapshot.go
+++ b/cmd/src/snapshot.go
@@ -8,15 +8,23 @@ import (
 var snapshotCommands commander
 
 func init() {
-	usage := `'src snapshot' manages snapshots of Sourcegraph instance data. All subcommands are currently EXPERIMENTAL.
+	usage := `'src snapshot' manages snapshots of Sourcegraph instance databases. All subcommands are currently EXPERIMENTAL.
 
-USAGE
-	src [-v] snapshot
+Usage:
 
-COMMANDS
+	src snapshot
+
+The commands are:
+
+	databases	export databases from a Sourcegraph instance
+	restore  	restore databases from an export
+	upload   	upload exported databases and summary file when migrating to Sourcegraph Cloud
 	summary	export summary data about an instance for acceptance testing of a restored Sourcegraph instance
 	test   	use exported summary data and instance health indicators to validate a restored and upgraded instance
+
+Use "src snapshot [command] -h" for more information about a command.
+
 `
 	flagSet := flag.NewFlagSet("snapshot", flag.ExitOnError)
diff --git a/cmd/src/snapshot_upload.go b/cmd/src/snapshot_upload.go
index 85e938f6e8..86c141053b 100644
--- a/cmd/src/snapshot_upload.go
+++ b/cmd/src/snapshot_upload.go
@@ -5,9 +5,10 @@ import (
 	"flag"
 	"fmt"
 	"io"
-	"io/fs"
 	"os"
-	"path"
+	"path/filepath"
+	"slices"
+	"strings"
 
 	"cloud.google.com/go/storage"
 	"github.com/sourcegraph/conc/pool"
@@ -15,171 +16,346 @@ import (
 	"github.com/sourcegraph/sourcegraph/lib/output"
 	"google.golang.org/api/option"
 
+	"github.com/sourcegraph/src-cli/internal/cmderrors"
 	"github.com/sourcegraph/src-cli/internal/pgdump"
 )
 
+// Package-level constants and variables
 const srcSnapshotDir = "./src-snapshot"
 
-var srcSnapshotSummaryPath = path.Join(srcSnapshotDir, "summary.json")
+// summaryFile is kept on its own, as it is handled a little differently than the database dumps
+const summaryFile = "summary.json"
 
-// https://pkg.go.dev/cloud.google.com/go/storage#section-readme
+var srcSnapshotSummaryPath = filepath.Join(srcSnapshotDir, summaryFile)
+
+// listOfValidFiles defines the valid snapshot filenames (with extensions) that can be uploaded
+var listOfValidFiles = []string{
+	"codeinsights.sql",
+	"codeintel.sql",
+	"pgsql.sql",
+	summaryFile,
+}
+
+// uploadArgs holds the validated inputs for an upload
+type uploadArgs struct {
+	bucketName      string
+	credentialsPath string
+	filterSQL       bool
+	filesToUpload   []string
+}
+
+// gcsClient bundles the Google Cloud Storage upload client with its context and CLI output
+type gcsClient struct {
+	ctx           context.Context
+	out           *output.Output
+	storageClient *storage.Client
+}
+
+// uploadFile represents a file opened for upload
+type uploadFile struct {
+	file      *os.File
+	stat      os.FileInfo
+	filterSQL bool // Whether to filter incompatible SQL statements during upload; true for database files, false for the summary file
+}
 
 func init() {
-	usage := `'src snapshot upload' uploads instance snapshot contents generated by 'src snapshot databases' and 'src snapshot summary' to the designated bucket.
+	usage := fmt.Sprintf(`'src snapshot upload' uploads the files generated by 'src snapshot databases' and 'src snapshot summary' to the specified GCS bucket, for self-hosted Sourcegraph customers migrating to Sourcegraph Cloud.
+
+Usage:
+
+	src snapshot upload -bucket=$MIGRATION_BUCKET_NAME -credentials=$CREDENTIALS_FILE_PATH [-file=<files>]
+
+Examples:
+
+	src snapshot upload -bucket=example-bucket-name -credentials=path/to/migration_private_key.json
+
+	src snapshot upload -bucket=example-bucket-name -credentials=./migration_private_key.json -file=pgsql.sql
-USAGE
-	src snapshot upload -bucket=$BUCKET -credentials=$CREDENTIALS_FILE
+	src snapshot upload -bucket=example-bucket-name -credentials=./migration_private_key.json -file="codeinsights.sql, codeintel.sql, pgsql.sql"
+
+Args:
+
+	-bucket
+		Name of the Google Cloud Storage bucket provided by Sourcegraph
+		Required
+		Type: string
+
+	-credentials
+		File path to the credentials file provided by Sourcegraph
+		Required
+		Type: file path, as string
+
+	-file
+		Specify which files from the ./src-snapshot directory to upload
+		Optional
+		Type: comma-delimited list of file names, with file-type extensions, as a string
+		Valid values: %s
+		Default: All valid values
+
+`, strings.Join(listOfValidFiles, ", "))
-BUCKET
-	In general, a Google Cloud Storage bucket and relevant credentials will be provided by Sourcegraph when using this functionality to share a snapshot with Sourcegraph.
-`
 	flagSet := flag.NewFlagSet("upload", flag.ExitOnError)
-	bucketName := flagSet.String("bucket", "", "destination Cloud Storage bucket name")
-	credentialsPath := flagSet.String("credentials", "", "JSON credentials file for Google Cloud service account")
-	trimExtensions := flagSet.Bool("trim-extensions", true, "trim EXTENSION statements from database dumps for import to Google Cloud SQL")
+	bucketName := flagSet.String("bucket", "", "Name of the Google Cloud Storage bucket provided by Sourcegraph")
+	credentialsPath := flagSet.String("credentials", "", "File path to the credentials file provided by Sourcegraph")
+	fileArg := flagSet.String("file", strings.Join(listOfValidFiles, ","), "Specify which files from the ./src-snapshot directory to upload")
+	filterSQL := flagSet.Bool("filter-sql", true, "Filter incompatible SQL statements from database snapshots that would break the import into Google Cloud SQL")
+
+	// Register this command with the parent 'src snapshot' command.
+	// The parent snapshot.go command runs all registered subcommands via snapshotCommands.run().
+	// This self-registration pattern allows subcommands to automatically register themselves
+	// when their init() functions run, without requiring a central registry file.
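+	//
+	// Illustrative sketch (editor's note; the dispatch details are an
+	// assumption, not the exact src-cli implementation):
+	//
+	//	var snapshotCommands commander   // declared in snapshot.go
+	//	// each subcommand's init() appends itself:
+	//	snapshotCommands = append(snapshotCommands, &command{flagSet: ..., handler: ...})
+	//	// the parent command then matches args[0] against the registered
+	//	// flag sets and invokes the corresponding handler.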
 	snapshotCommands = append(snapshotCommands, &command{
-		flagSet: flagSet,
-		handler: func(args []string) error {
-			if err := flagSet.Parse(args); err != nil {
-				return err
-			}
+		flagSet:   flagSet,
+		handler:   snapshotUploadHandler(flagSet, bucketName, credentialsPath, filterSQL, fileArg),
+		usageFunc: func() { fmt.Fprint(flag.CommandLine.Output(), usage) },
+	})
+}
 
-			if *bucketName == "" {
-				return errors.New("-bucket required")
-			}
-			if *credentialsPath == "" {
-				return errors.New("-credentials required")
-			}
+// snapshotUploadHandler returns the handler for 'src snapshot upload'; it keeps init() succinct
+func snapshotUploadHandler(flagSet *flag.FlagSet, bucketName, credentialsPath *string, filterSQL *bool, fileArg *string) func([]string) error {
 
-			out := output.NewOutput(flagSet.Output(), output.OutputOpts{Verbose: *verbose})
-			ctx := context.Background()
-			c, err := storage.NewClient(ctx, option.WithCredentialsFile(*credentialsPath))
-			if err != nil {
-				return errors.Wrap(err, "create Cloud Storage client")
-			}
+	return func(args []string) error {
+		if err := flagSet.Parse(args); err != nil {
+			return err
+		}
 
-			type upload struct {
-				file           *os.File
-				stat           os.FileInfo
-				trimExtensions bool
-			}
-			var (
-				uploads      []upload             // index aligned with progressBars
-				progressBars []output.ProgressBar // index aligned with uploads
-			)
-
-			// Open snapshot summary
-			if f, err := os.Open(srcSnapshotSummaryPath); err != nil {
-				return errors.Wrap(err, "failed to open snapshot summary - generate one with 'src snapshot summary'")
-			} else {
-				stat, err := f.Stat()
-				if err != nil {
-					return errors.Wrap(err, "get file size")
-				}
-				uploads = append(uploads, upload{
-					file:           f,
-					stat:           stat,
-					trimExtensions: false, // not a database dump
-				})
-				progressBars = append(progressBars, output.ProgressBar{
-					Label: stat.Name(),
-					Max:   float64(stat.Size()),
-				})
-			}
+		// Validate and parse the inputs into an uploadArgs value
+		uploadArgs, err := validateUploadInputs(*bucketName, *credentialsPath, *fileArg, *filterSQL)
+		if err != nil {
+			return err
+		}
 
-			// Open database dumps
-			for _, o := range pgdump.Outputs(srcSnapshotDir, pgdump.Targets{}) {
-				if f, err := os.Open(o.Output); err != nil {
-					return errors.Wrap(err, "failed to database dump - generate one with 'src snapshot databases'")
-				} else {
-					stat, err := f.Stat()
-					if err != nil {
-						return errors.Wrap(err, "get file size")
-					}
-					uploads = append(uploads, upload{
-						file:           f,
-						stat:           stat,
-						trimExtensions: *trimExtensions,
-					})
-					progressBars = append(progressBars, output.ProgressBar{
-						Label: stat.Name(),
-						Max:   float64(stat.Size()),
-					})
-				}
-			}
+		// Create the Google Cloud Storage client
+		client, err := createGcsClient(flagSet, uploadArgs.credentialsPath)
+		if err != nil {
+			return err
+		}
+
+		// Open files and create progress bars
+		openedFiles, progressBars, err := openFilesAndCreateProgressBars(uploadArgs)
+		if err != nil {
+			return err
+		}
+
+		// Upload files to bucket
+		return uploadFilesToBucket(client, uploadArgs, openedFiles, progressBars)
+	}
+}
+
+// validateUploadInputs validates the user's inputs and converts them into an uploadArgs value
+func validateUploadInputs(bucketName, credentialsPath, fileArg string, filterSQL bool) (*uploadArgs, error) {
+
+	if bucketName == "" {
+		return nil, cmderrors.Usage("-bucket required")
+	}
+
+	if credentialsPath == "" {
+		return nil, cmderrors.Usage("-credentials required")
+	}
+
+	filesToUpload, err := parseFileArg(fileArg)
+	if err != nil {
+		return nil, err
+	}
+
+	return &uploadArgs{
+		bucketName:      bucketName,
+		credentialsPath: credentialsPath,
+		filterSQL:       filterSQL,
+		filesToUpload:   filesToUpload,
+	}, nil
+}
+
+// parseFileArg parses the -file flag's value and returns the list of files to upload
+func parseFileArg(fileArg string) ([]string, error) {
+
+	// Default: all files
+	if fileArg == "" {
+		return listOfValidFiles, nil
+	}
 
-			// Start uploads
-			progress := out.Progress(progressBars, nil)
-			progress.WriteLine(output.Emoji(output.EmojiHourglass, "Starting uploads..."))
-			bucket := c.Bucket(*bucketName)
-			g := pool.New().WithErrors().WithContext(ctx)
-			for i, u := range uploads {
-				i := i
-				u := u
-				g.Go(func(ctx context.Context) error {
-					progressFn := func(p int64) { progress.SetValue(i, float64(p)) }
-
-					if err := copyDumpToBucket(ctx, u.file, u.stat, bucket, progressFn, u.trimExtensions); err != nil {
-						return errors.Wrap(err, u.stat.Name())
-					}
-
-					return nil
-				})
+	var filesToUpload []string
+
+	// Parse comma-delimited list
+	for _, part := range strings.Split(fileArg, ",") {
+
+		// Trim whitespace
+		filename := strings.TrimSpace(part)
+
+		// Validate against list of valid files
+		if !slices.Contains(listOfValidFiles, filename) {
+			return nil, cmderrors.Usagef("invalid -file value %q. Valid values: %s", part, strings.Join(listOfValidFiles, ", "))
+		}
+
+		filesToUpload = append(filesToUpload, filename)
+	}
+
+	// Sort files alphabetically for consistent ordering
+	slices.Sort(filesToUpload)
+
+	// Remove duplicates (slices.Compact works on sorted slices by removing adjacent duplicates)
+	filesToUpload = slices.Compact(filesToUpload)
+
+	return filesToUpload, nil
+}
+
+// createGcsClient creates an authenticated Google Cloud Storage client, bundled with the CLI output and context
+func createGcsClient(flagSet *flag.FlagSet, credentialsPath string) (*gcsClient, error) {
+
+	ctx := context.Background()
+	out := output.NewOutput(flagSet.Output(), output.OutputOpts{Verbose: *verbose})
+
+	// https://pkg.go.dev/cloud.google.com/go/storage#section-readme
+	client, err := storage.NewClient(ctx, option.WithCredentialsFile(credentialsPath))
+	if err != nil {
+		return nil, errors.Wrap(err, "create Google Cloud Storage client")
+	}
+
+	return &gcsClient{
+		ctx:           ctx,
+		out:           out,
+		storageClient: client,
+	}, nil
+}
+
+// openFilesAndCreateProgressBars opens the selected snapshot files from disk and creates progress bars for UI display.
+// Returns slices of uploadFile and progress bars (aligned by index).
+func openFilesAndCreateProgressBars(args *uploadArgs) ([]uploadFile, []output.ProgressBar, error) {
+	var (
+		openedFiles  []uploadFile         // Files opened from disk, ready for upload (aligned with progressBars)
+		progressBars []output.ProgressBar // Progress bars for UI (aligned with openedFiles)
+	)
+
+	// addFile opens a file from disk and registers it for upload.
+	// It appends the file to the openedFiles slice and creates a corresponding progress bar.
+	// For database dumps (!isSummary), SQL filtering is enabled based on args.filterSQL.
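+	//
+	// Worked example (editor's note, not part of the original patch): with
+	// -file="pgsql.sql,summary.json", parseFileArg returns the sorted list
+	// [pgsql.sql summary.json], so openedFiles[0]/progressBars[0] describe
+	// pgsql.sql (filterSQL true when -filter-sql is set) and
+	// openedFiles[1]/progressBars[1] describe summary.json (filterSQL always false).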
+	addFile := func(filePath string) error {
+
+		isSummary := strings.HasSuffix(filePath, summaryFile)
+
+		// Open the file
+		openFile, err := os.Open(filePath)
+		if err != nil {
+			if isSummary {
+				return errors.Wrapf(err, "failed to open snapshot summary %s - generate it with 'src snapshot summary'", filePath)
+			}
+			return errors.Wrapf(err, "failed to open database dump %s - generate it with 'src snapshot databases'", filePath)
+		}
 
-			// Finalize
-			errs := g.Wait()
-			progress.Complete()
-			if errs != nil {
-				out.WriteLine(output.Line(output.EmojiFailure, output.StyleFailure, "Some snapshot contents failed to upload."))
-				return errs
+		// Get file metadata (name, size)
+		stat, err := openFile.Stat()
+		if err != nil {
+			return errors.Wrap(err, "get file size")
+		}
+
+		// Register file for upload
+		openedFiles = append(openedFiles, uploadFile{
+			file:      openFile,
+			stat:      stat,
+			filterSQL: !isSummary && args.filterSQL, // Only filter SQL for database dumps
+		})
+
+		// Create progress bar for this file
+		progressBars = append(progressBars, output.ProgressBar{
+			Label: stat.Name(),
+			Max:   float64(stat.Size()),
+		})
+		return nil
+	}
+
+	// Open each of the user's selected files (chosen via the -file arg)
+	for _, selectedFile := range args.filesToUpload {
+
+		// Construct the full file path
+		filePath := filepath.Join(srcSnapshotDir, selectedFile)
+
+		if err := addFile(filePath); err != nil {
+			return nil, nil, err
+		}
+	}
+
+	return openedFiles, progressBars, nil
+}
+
+// uploadFilesToBucket uploads the prepared files to the Google Cloud Storage bucket.
+// Uploads are performed in parallel with progress tracking.
+func uploadFilesToBucket(client *gcsClient, args *uploadArgs, openedFiles []uploadFile, progressBars []output.ProgressBar) error {
+
+	// Start uploads with progress tracking
+	progress := client.out.Progress(progressBars, nil)
+	progress.WriteLine(output.Emoji(output.EmojiHourglass, "Starting uploads..."))
+	bucket := client.storageClient.Bucket(args.bucketName)
+	uploadPool := pool.New().WithErrors().WithContext(client.ctx)
+
+	// Upload each file in parallel
+	for fileIndex, openedFile := range openedFiles {
+
+		fileIndex := fileIndex
+		openedFile := openedFile
+
+		uploadPool.Go(func(ctx context.Context) error {
+			progressFn := func(bytesWritten int64) { progress.SetValue(fileIndex, float64(bytesWritten)) }
+
+			if err := streamFileToBucket(ctx, &openedFile, bucket, progressFn); err != nil {
+				return errors.Wrap(err, openedFile.stat.Name())
 			}
-			out.WriteLine(output.Emoji(output.EmojiSuccess, "Summary contents uploaded!"))
 
 			return nil
-		},
-		usageFunc: func() { fmt.Fprint(flag.CommandLine.Output(), usage) },
-	})
+		})
+	}
+
+	// Wait for all uploads to complete
+	errs := uploadPool.Wait()
+	progress.Complete()
+	if errs != nil {
+		client.out.WriteLine(output.Line(output.EmojiFailure, output.StyleFailure, "Some file(s) failed to upload."))
+		return errs
+	}
+
+	client.out.WriteLine(output.Emoji(output.EmojiSuccess, "File(s) uploaded successfully!"))
+	return nil
 }
 
-func copyDumpToBucket(ctx context.Context, src io.ReadSeeker, stat fs.FileInfo, dst *storage.BucketHandle, progressFn func(int64), trimExtensions bool) error {
-	// Set up object to write to
-	object := dst.Object(stat.Name()).NewWriter(ctx)
-	object.ProgressFunc = progressFn
-	defer object.Close()
+func streamFileToBucket(ctx context.Context, file *uploadFile, bucket *storage.BucketHandle, progressFn func(int64)) error {
+
+	// Set up a GCS writer for the destination object
+	writer := bucket.Object(file.stat.Name()).NewWriter(ctx)
+	writer.ProgressFunc = progressFn
+	defer writer.Close()
 
 	// To assert against actual file size
-	var totalWritten int64
+	var totalBytesWritten int64
 
-	// Do a partial copy, that filters out incompatible statements
-	if trimExtensions {
-		written, err := pgdump.FilterInvalidLines(object, src, progressFn)
+	// Start with a partial copy that filters out incompatible statements
+	if file.filterSQL {
+		bytesWritten, err := pgdump.FilterInvalidLines(writer, file.file, progressFn)
 		if err != nil {
 			return errors.Wrap(err, "filter out incompatible statements and upload")
 		}
-		totalWritten += written
+		totalBytesWritten += bytesWritten
 	}
 
 	// io.Copy is the best way to copy from a reader to writer in Go,
 	// storage.Writer has its own chunking mechanisms internally.
 	// io.Reader is stateful, so this copy will just continue from where FilterInvalidLines left off, if used
-	written, err := io.Copy(object, src)
+	bytesWritten, err := io.Copy(writer, file.file)
 	if err != nil {
 		return errors.Wrap(err, "upload")
 	}
-	totalWritten += written
+	totalBytesWritten += bytesWritten
 
 	// Progress is not called on completion of io.Copy,
 	// so we call it manually after to update our pretty progress bars.
-	progressFn(written)
+	progressFn(bytesWritten)
 
 	// Validate we have sent all data.
 	// FilterInvalidLines may add some bytes, so the check is not a strict equality.
-	size := stat.Size()
-	if totalWritten < size {
+	fileSize := file.stat.Size()
+	if totalBytesWritten < fileSize {
 		return errors.Newf("expected to write %d bytes, but actually wrote %d bytes (diff: %d bytes)",
-			size, totalWritten, totalWritten-size)
+			fileSize, totalBytesWritten, totalBytesWritten-fileSize)
 	}
 
 	return nil
diff --git a/cmd/src/snapshot_upload_test.go b/cmd/src/snapshot_upload_test.go
new file mode 100644
index 0000000000..568e582111
--- /dev/null
+++ b/cmd/src/snapshot_upload_test.go
@@ -0,0 +1,286 @@
+package main
+
+import (
+	"os"
+	"path/filepath"
+	"slices"
+	"strings"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+
+	"github.com/sourcegraph/src-cli/internal/pgdump"
+)
+
+func setupSnapshotFiles(t *testing.T) string {
+	t.Helper()
+	dir := t.TempDir()
+
+	// Create the snapshot directory structure
+	snapshotDir := filepath.Join(dir, "src-snapshot")
+	err := os.Mkdir(snapshotDir, 0755)
+	require.NoError(t, err)
+
+	// Create summary.json
+	summaryPath := filepath.Join(snapshotDir, "summary.json")
+	err = os.WriteFile(summaryPath, []byte(`{"version": "test"}`), 0644)
+	require.NoError(t, err)
+
+	// Create database dump files
+	for _, output := range pgdump.Outputs(snapshotDir, pgdump.Targets{}) {
+		err = os.WriteFile(output.Output, []byte("-- test SQL dump"), 0644)
+		require.NoError(t, err)
+	}
+
+	return snapshotDir
+}
+
+func TestFileFilterValidation(t *testing.T) {
+	tests := []struct {
+		name      string
+		fileFlag  string
+		wantError bool
+	}{
+		{
+			name:      "valid: summary.json",
+			fileFlag:  "summary.json",
+			wantError: false,
+		},
+		{
+			name:      "valid: pgsql.sql",
+			fileFlag:  "pgsql.sql",
+			wantError: false,
+		},
+		{
+			name:      "valid: codeintel.sql",
+			fileFlag:  "codeintel.sql",
+			wantError: false,
+		},
+		{
+			name:      "valid: codeinsights.sql",
+			fileFlag:  "codeinsights.sql",
+			wantError: false,
+		},
+		{
+			name:      "valid: empty (all files)",
+			fileFlag:  "",
+			wantError: false,
+		},
+		{
+			name:      "valid: comma-delimited",
+			fileFlag:  "summary.json,pgsql.sql",
+			wantError: false,
+		},
+		{
+			name:      "valid: comma-delimited with spaces",
+			fileFlag:  "summary.json, pgsql.sql, codeintel.sql",
+			wantError: false,
+		},
+		{
+			name:      "invalid: unknown file",
+			fileFlag:  "unknown",
+			wantError: true,
+		},
+		{
+			name:      "invalid: typo",
+			fileFlag:  "primry",
+			wantError: true,
+		},
+		{
+			name:      "invalid: one valid, one invalid",
+			fileFlag:  "summary.json,invalid",
+			wantError: true,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			// Test the validation logic using the shared listOfValidFiles
+			var hasError bool
+			if tt.fileFlag == "" {
+				// Empty is valid (defaults to all files)
+				hasError = false
+			} else {
+				parts := strings.Split(tt.fileFlag, ",")
+				for _, part := range parts {
+					filename := strings.TrimSpace(part)
+
+					if !slices.Contains(listOfValidFiles, filename) {
+						hasError = true
+						break
+					}
+				}
+			}
+
+			if tt.wantError {
+				require.True(t, hasError, "expected invalid file flag to be rejected")
+			} else {
+				require.False(t, hasError, "expected valid file flag to be accepted")
+			}
+		})
+	}
+}
+
+func TestFileSelection(t *testing.T) {
+	snapshotDir := setupSnapshotFiles(t)
+
+	tests := []struct {
+		name          string
+		fileFilter    string
+		expectedFiles []string
+	}{
+		{
+			name:       "no filter - all files",
+			fileFilter: "",
+			expectedFiles: []string{
+				"codeinsights.sql",
+				"codeintel.sql",
+				"pgsql.sql",
+				"summary.json",
+			},
+		},
+		{
+			name:       "summary only",
+			fileFilter: "summary.json",
+			expectedFiles: []string{
+				"summary.json",
+			},
+		},
+		{
+			name:       "pgsql only",
+			fileFilter: "pgsql.sql",
+			expectedFiles: []string{
+				"pgsql.sql",
+			},
+		},
+		{
+			name:       "codeintel only",
+			fileFilter: "codeintel.sql",
+			expectedFiles: []string{
+				"codeintel.sql",
+			},
+		},
+		{
+			name:       "codeinsights only",
+			fileFilter: "codeinsights.sql",
+			expectedFiles: []string{
+				"codeinsights.sql",
+			},
+		},
+		{
+			name:       "comma-delimited: summary and pgsql",
+			fileFilter: "summary.json,pgsql.sql",
+			expectedFiles: []string{
+				"summary.json",
+				"pgsql.sql",
+			},
+		},
+		{
+			name:       "comma-delimited: all database files",
+			fileFilter: "pgsql.sql,codeintel.sql,codeinsights.sql",
+			expectedFiles: []string{
+				"pgsql.sql",
+				"codeintel.sql",
+				"codeinsights.sql",
+			},
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			// Parse the file filter into a list (mirroring the parseFileArg logic)
+			var filesToUpload []string
+			if tt.fileFilter == "" {
+				filesToUpload = listOfValidFiles
+			} else {
+				parts := strings.Split(tt.fileFilter, ",")
+				for _, part := range parts {
+					filename := strings.TrimSpace(part)
+					filesToUpload = append(filesToUpload, filename)
+				}
+			}
+
+			// Simulate the file opening logic from openFilesAndCreateProgressBars
+			var selectedFiles []string
+			for _, selectedFile := range filesToUpload {
+				// Construct the path and check whether the file exists
+				filePath := filepath.Join(snapshotDir, selectedFile)
+				if _, err := os.Stat(filePath); err == nil {
+					selectedFiles = append(selectedFiles, selectedFile)
+				}
+			}
+
+			require.Equal(t, tt.expectedFiles, selectedFiles, "selected files should match expected")
+		})
+	}
+}
+
+func TestFilterSQLBehavior(t *testing.T) {
+	tests := []struct {
+		name           string
+		isSummary      bool
+		filterSQLFlag  bool
+		expectedFilter bool
+	}{
+		{
+			name:           "summary file - filterSQL should be false",
+			isSummary:      true,
+			filterSQLFlag:  true,
+			expectedFilter: false,
+		},
+		{
+			name:           "database dump - filterSQL flag true",
+			isSummary:      false,
+			filterSQLFlag:  true,
+			expectedFilter: true,
+		},
+		{
+			name:           "database dump - filterSQL flag false",
+			isSummary:      false,
+			filterSQLFlag:  false,
+			expectedFilter: false,
+		},
+		{
+			name:           "summary file - filterSQL flag false (should still be false)",
"summary file - filterSQL flag false (should still be false)", + isSummary: true, + filterSQLFlag: false, + expectedFilter: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Simulate the filterSQL logic from snapshot_upload.go + actualFilter := !tt.isSummary && tt.filterSQLFlag + + require.Equal(t, tt.expectedFilter, actualFilter, "filterSQL should be set correctly") + }) + } +} + +func TestDatabaseOutputs(t *testing.T) { + snapshotDir := setupSnapshotFiles(t) + + outputs := pgdump.Outputs(snapshotDir, pgdump.Targets{}) + + // Should have exactly 3 database files + require.Len(t, outputs, 3, "should have 3 database outputs") + + expectedFiles := map[string]bool{ + "pgsql.sql": false, + "codeintel.sql": false, + "codeinsights.sql": false, + } + + for _, output := range outputs { + fileName := filepath.Base(output.Output) + _, exists := expectedFiles[fileName] + require.True(t, exists, "unexpected file: %s", fileName) + expectedFiles[fileName] = true + } + + // Verify all expected files were found + for fileName, found := range expectedFiles { + require.True(t, found, "expected file not found: %s", fileName) + } +}