diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 3fa92833..d3a64e2a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -26,7 +26,7 @@ jobs: run: | # This corresponds with the list in Makefile:1, but omits the "userns" # and "capability" modules, which require go1.21 as minimum. - echo 'PACKAGES=mountinfo mount reexec sequential signal symlink user' >> $GITHUB_ENV + echo 'PACKAGES=atomicwriter mountinfo mount reexec sequential signal symlink user' >> $GITHUB_ENV - name: go mod tidy run: | make foreach CMD="go mod tidy" diff --git a/Makefile b/Makefile index bc0c1246..7f68e84d 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -PACKAGES ?= capability mountinfo mount reexec sequential signal symlink user userns # IMPORTANT: when updating this list, also update the conditional one in .github/workflows/test.yml +PACKAGES ?= atomicwriter capability mountinfo mount reexec sequential signal symlink user userns # IMPORTANT: when updating this list, also update the conditional one in .github/workflows/test.yml BINDIR ?= _build/bin CROSS ?= linux/arm linux/arm64 linux/ppc64le linux/s390x \ freebsd/amd64 openbsd/amd64 darwin/amd64 darwin/arm64 windows/amd64 @@ -29,9 +29,12 @@ test: test-local test: CMD=go test $(RUN_VIA_SUDO) -v -coverprofile=coverage.txt -covermode=atomic . test: foreach -# Test the mount module against the local mountinfo source code instead of the -# release specified in its go.mod. This allows catching regressions / breaking -# changes in mountinfo. +# Some modules in this repo have interdependencies: +# - mount depends on mountinfo +# - atomicwriter depends on sequential +# +# The code below tests these modules against their local dependencies +# to catch regressions / breaking changes early. .PHONY: test-local test-local: MOD = -modfile=go-local.mod test-local: @@ -39,6 +42,10 @@ test-local: # Run go mod tidy to make sure mountinfo dependency versions are met. 
cd mount && go mod tidy $(MOD) && go test $(MOD) $(RUN_VIA_SUDO) -v . $(RM) mount/go-local.* + echo 'replace github.com/moby/sys/sequential => ../sequential' | cat atomicwriter/go.mod - > atomicwriter/go-local.mod + # Run go mod tidy to make sure sequential dependency versions are met. + cd atomicwriter && go mod tidy $(MOD) && go test $(MOD) $(RUN_VIA_SUDO) -v . + $(RM) atomicwriter/go-local.* .PHONY: lint lint: $(BINDIR)/golangci-lint diff --git a/atomicwriter/atomicwriter.go b/atomicwriter/atomicwriter.go new file mode 100644 index 00000000..d0d3be88 --- /dev/null +++ b/atomicwriter/atomicwriter.go @@ -0,0 +1,245 @@ +// Package atomicwriter provides utilities to perform atomic writes to a +// file or set of files. +package atomicwriter + +import ( + "errors" + "fmt" + "io" + "os" + "path/filepath" + "syscall" + + "github.com/moby/sys/sequential" +) + +func validateDestination(fileName string) error { + if fileName == "" { + return errors.New("file name is empty") + } + if dir := filepath.Dir(fileName); dir != "" && dir != "." && dir != ".." { + di, err := os.Stat(dir) + if err != nil { + return fmt.Errorf("invalid output path: %w", err) + } + if !di.IsDir() { + return fmt.Errorf("invalid output path: %w", &os.PathError{Op: "stat", Path: dir, Err: syscall.ENOTDIR}) + } + } + + // Deliberately using Lstat here to match the behavior of [os.Rename], + // which is used when completing the write and does not resolve symlinks. 
+ fi, err := os.Lstat(fileName) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return fmt.Errorf("failed to stat output path: %w", err) + } + + switch mode := fi.Mode(); { + case mode.IsRegular(): + return nil // Regular file + case mode&os.ModeDir != 0: + return errors.New("cannot write to a directory") + case mode&os.ModeSymlink != 0: + return errors.New("cannot write to a symbolic link directly") + case mode&os.ModeNamedPipe != 0: + return errors.New("cannot write to a named pipe (FIFO)") + case mode&os.ModeSocket != 0: + return errors.New("cannot write to a socket") + case mode&os.ModeDevice != 0: + if mode&os.ModeCharDevice != 0 { + return errors.New("cannot write to a character device file") + } + return errors.New("cannot write to a block device file") + case mode&os.ModeSetuid != 0: + return errors.New("cannot write to a setuid file") + case mode&os.ModeSetgid != 0: + return errors.New("cannot write to a setgid file") + case mode&os.ModeSticky != 0: + return errors.New("cannot write to a sticky bit file") + default: + return fmt.Errorf("unknown file mode: %[1]s (%#[1]o)", mode) + } +} + +// New returns a WriteCloser so that writing to it writes to a +// temporary file and closing it atomically renames the temporary file to +// the destination path. Writing and closing concurrently is not allowed. +// NOTE: umask is not considered for the file's permissions. +// +// New uses [sequential.CreateTemp] to use sequential file access on Windows, +// avoiding depleting the standby list unnecessarily. On Linux, this equates to +// a regular [os.CreateTemp]. Refer to the [Win32 API documentation] for details +// on sequential file access. 
+// +// [Win32 API documentation]: https://learn.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilea#FILE_FLAG_SEQUENTIAL_SCAN +func New(filename string, perm os.FileMode) (io.WriteCloser, error) { + if err := validateDestination(filename); err != nil { + return nil, err + } + abspath, err := filepath.Abs(filename) + if err != nil { + return nil, err + } + + f, err := sequential.CreateTemp(filepath.Dir(abspath), ".tmp-"+filepath.Base(filename)) + if err != nil { + return nil, err + } + return &atomicFileWriter{ + f: f, + fn: abspath, + perm: perm, + }, nil +} + +// WriteFile atomically writes data to a file named by filename and with the +// specified permission bits. The given filename is created if it does not exist, +// but the destination directory must exist. It can be used as a drop-in replacement +// for [os.WriteFile], but currently does not allow the destination path to be +// a symlink. WriteFile is implemented using [New]. +// +// NOTE: umask is not considered for the file's permissions. 
+func WriteFile(filename string, data []byte, perm os.FileMode) error { + f, err := New(filename, perm) + if err != nil { + return err + } + n, err := f.Write(data) + if err == nil && n < len(data) { + err = io.ErrShortWrite + f.(*atomicFileWriter).writeErr = err + } + if err1 := f.Close(); err == nil { + err = err1 + } + return err +} + +type atomicFileWriter struct { + f *os.File + fn string + writeErr error + written bool + perm os.FileMode +} + +func (w *atomicFileWriter) Write(dt []byte) (int, error) { + w.written = true + n, err := w.f.Write(dt) + if err != nil { + w.writeErr = err + } + return n, err +} + +func (w *atomicFileWriter) Close() (retErr error) { + defer func() { + if err := os.Remove(w.f.Name()); !errors.Is(err, os.ErrNotExist) && retErr == nil { + retErr = err + } + }() + if err := w.f.Sync(); err != nil { + _ = w.f.Close() + return err + } + if err := w.f.Close(); err != nil { + return err + } + if err := os.Chmod(w.f.Name(), w.perm); err != nil { + return err + } + if w.writeErr == nil && w.written { + return os.Rename(w.f.Name(), w.fn) + } + return nil +} + +// WriteSet is used to atomically write a set +// of files and ensure they are visible at the same time. +// Must be committed to a new directory. +type WriteSet struct { + root string +} + +// NewWriteSet creates a new atomic write set to +// atomically create a set of files. The given directory +// is used as the base directory for storing files before +// commit. If no temporary directory is given the system +// default is used. +func NewWriteSet(tmpDir string) (*WriteSet, error) { + td, err := os.MkdirTemp(tmpDir, "write-set-") + if err != nil { + return nil, err + } + + return &WriteSet{ + root: td, + }, nil +} + +// WriteFile writes a file to the set, guaranteeing the file +// has been synced. 
+func (ws *WriteSet) WriteFile(filename string, data []byte, perm os.FileMode) error { + f, err := ws.FileWriter(filename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, perm) + if err != nil { + return err + } + n, err := f.Write(data) + if err == nil && n < len(data) { + err = io.ErrShortWrite + } + if err1 := f.Close(); err == nil { + err = err1 + } + return err +} + +type syncFileCloser struct { + *os.File +} + +func (w syncFileCloser) Close() error { + err := w.File.Sync() + if err1 := w.File.Close(); err == nil { + err = err1 + } + return err +} + +// FileWriter opens a file writer inside the set. The file +// should be synced and closed before calling [WriteSet.Commit]. +// +// FileWriter uses [sequential.OpenFile] to use sequential file access on Windows, +// avoiding depleting the standby list unnecessarily. On Linux, this equates to +// a regular [os.OpenFile]. Refer to the [Win32 API documentation] for details +// on sequential file access. +// +// [Win32 API documentation]: https://learn.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilea#FILE_FLAG_SEQUENTIAL_SCAN +func (ws *WriteSet) FileWriter(name string, flag int, perm os.FileMode) (io.WriteCloser, error) { + f, err := sequential.OpenFile(filepath.Join(ws.root, name), flag, perm) + if err != nil { + return nil, err + } + return syncFileCloser{f}, nil +} + +// Cancel cancels the set and removes all temporary data +// created in the set. +func (ws *WriteSet) Cancel() error { + return os.RemoveAll(ws.root) +} + +// Commit moves all created files to the target directory. The +// target directory must not exist and the parent of the target +// directory must exist. +func (ws *WriteSet) Commit(target string) error { + return os.Rename(ws.root, target) +} + +// String returns the location the set is writing to. 
+func (ws *WriteSet) String() string { + return ws.root +} diff --git a/atomicwriter/atomicwriter_test.go b/atomicwriter/atomicwriter_test.go new file mode 100644 index 00000000..e98f7f33 --- /dev/null +++ b/atomicwriter/atomicwriter_test.go @@ -0,0 +1,325 @@ +package atomicwriter + +import ( + "bytes" + "errors" + "os" + "path/filepath" + "runtime" + "strings" + "syscall" + "testing" +) + +// testMode returns the file-mode to use in tests, accounting for Windows +// not supporting full Linux file mode. +func testMode() os.FileMode { + if runtime.GOOS == "windows" { + return 0o666 + } + return 0o640 +} + +// assertFile asserts the given fileName to exist, and to have the expected +// content and mode. +func assertFile(t *testing.T, fileName string, fileContent []byte, expectedMode os.FileMode) { + t.Helper() + actual, err := os.ReadFile(fileName) + if err != nil { + t.Fatalf("Error reading from file: %v", err) + } + + if !bytes.Equal(actual, fileContent) { + t.Errorf("Data mismatch, expected %q, got %q", fileContent, actual) + } + + st, err := os.Stat(fileName) + if err != nil { + t.Fatalf("Error statting file: %v", err) + } + if st.Mode() != expectedMode { + t.Errorf("Mode mismatched, expected %o, got %o", expectedMode, st.Mode()) + } +} + +// assertFileCount asserts the given directory has the expected number +// of files, and returns the list of files found. 
+func assertFileCount(t *testing.T, directory string, expected int) []os.DirEntry { + t.Helper() + files, err := os.ReadDir(directory) + if err != nil { + t.Fatalf("Error reading dir: %v", err) + } + if len(files) != expected { + t.Errorf("Expected %d files, got %d: %v", expected, len(files), files) + } + return files +} + +func TestNew(t *testing.T) { + for _, tc := range []string{"normal", "symlinked"} { + tmpDir := t.TempDir() + parentDir := tmpDir + actualParentDir := parentDir + if tc == "symlinked" { + actualParentDir = filepath.Join(tmpDir, "parent-dir") + if err := os.Mkdir(actualParentDir, 0o700); err != nil { + t.Fatal(err) + } + parentDir = filepath.Join(tmpDir, "parent-dir-symlink") + if err := os.Symlink(actualParentDir, parentDir); err != nil { + t.Fatal(err) + } + } + t.Run(tc, func(t *testing.T) { + for _, tc := range []string{"new-file", "existing-file"} { + t.Run(tc, func(t *testing.T) { + fileName := filepath.Join(parentDir, "test.txt") + var origFileCount int + if tc == "existing-file" { + if err := os.WriteFile(fileName, []byte("original content"), testMode()); err != nil { + t.Fatalf("Error writing file: %v", err) + } + origFileCount = 1 + } + writer, err := New(fileName, testMode()) + if writer == nil { + t.Errorf("Writer is nil") + } + if err != nil { + t.Fatalf("Error creating new atomicwriter: %v", err) + } + files := assertFileCount(t, actualParentDir, origFileCount+1) + if tmpFileName := files[0].Name(); !strings.HasPrefix(tmpFileName, ".tmp-test.txt") { + t.Errorf("Unexpected file name for temp-file: %s", tmpFileName) + } + + // Closing the writer without writing should clean up the temp-file, + // and should not replace the destination file. 
+ if err = writer.Close(); err != nil { + t.Errorf("Error closing writer: %v", err) + } + assertFileCount(t, actualParentDir, origFileCount) + if tc == "existing-file" { + assertFile(t, fileName, []byte("original content"), testMode()) + } + }) + } + }) + } +} + +func TestNewInvalid(t *testing.T) { + t.Run("missing target dir", func(t *testing.T) { + tmpDir := t.TempDir() + fileName := filepath.Join(tmpDir, "missing-dir", "test.txt") + writer, err := New(fileName, testMode()) + if writer != nil { + t.Errorf("Should not have created writer") + } + if !errors.Is(err, os.ErrNotExist) { + t.Errorf("Should produce a 'not found' error, but got %[1]T (%[1]v)", err) + } + }) + t.Run("target dir is not a directory", func(t *testing.T) { + tmpDir := t.TempDir() + parentPath := filepath.Join(tmpDir, "not-a-dir") + err := os.WriteFile(parentPath, nil, testMode()) + if err != nil { + t.Fatalf("Error writing file: %v", err) + } + fileName := filepath.Join(parentPath, "new-file.txt") + writer, err := New(fileName, testMode()) + if writer != nil { + t.Errorf("Should not have created writer") + } + // This should match the behavior of os.WriteFile, which returns a [os.PathError] with [syscall.ENOTDIR]. 
+ if !errors.Is(err, syscall.ENOTDIR) { + t.Errorf("Should produce a 'not a directory' error, but got %[1]T (%[1]v)", err) + } + }) + t.Run("empty filename", func(t *testing.T) { + writer, err := New("", testMode()) + if writer != nil { + t.Errorf("Should not have created writer") + } + if err == nil || err.Error() != "file name is empty" { + t.Errorf("Should produce a 'file name is empty' error, but got %[1]T (%[1]v)", err) + } + }) + t.Run("directory", func(t *testing.T) { + tmpDir := t.TempDir() + writer, err := New(tmpDir, testMode()) + if writer != nil { + t.Errorf("Should not have created writer") + } + if err == nil || err.Error() != "cannot write to a directory" { + t.Errorf("Should produce a 'cannot write to a directory' error, but got %[1]T (%[1]v)", err) + } + }) + t.Run("symlinked file", func(t *testing.T) { + tmpDir := t.TempDir() + linkTarget := filepath.Join(tmpDir, "symlink-target") + if err := os.WriteFile(linkTarget, []byte("orig content"), testMode()); err != nil { + t.Fatal(err) + } + fileName := filepath.Join(tmpDir, "symlinked-file") + if err := os.Symlink(linkTarget, fileName); err != nil { + t.Fatal(err) + } + writer, err := New(fileName, testMode()) + if writer != nil { + t.Errorf("Should not have created writer") + } + if err == nil || err.Error() != "cannot write to a symbolic link directly" { + t.Errorf("Should produce a 'cannot write to a symbolic link directly' error, but got %[1]T (%[1]v)", err) + } + }) +} + +func TestWriteFile(t *testing.T) { + t.Run("empty filename", func(t *testing.T) { + err := WriteFile("", nil, testMode()) + if err == nil || err.Error() != "file name is empty" { + t.Errorf("Should produce a 'file name is empty' error, but got %[1]T (%[1]v)", err) + } + }) + t.Run("write to directory", func(t *testing.T) { + err := WriteFile(t.TempDir(), nil, testMode()) + if err == nil || err.Error() != "cannot write to a directory" { + t.Errorf("Should produce a 'cannot write to a directory' error, but got %[1]T (%[1]v)", err) 
+ } + }) + t.Run("write to file", func(t *testing.T) { + tmpDir := t.TempDir() + fileName := filepath.Join(tmpDir, "test.txt") + fileContent := []byte("file content") + fileMode := testMode() + if err := WriteFile(fileName, fileContent, fileMode); err != nil { + t.Fatalf("Error writing to file: %v", err) + } + assertFile(t, fileName, fileContent, fileMode) + assertFileCount(t, tmpDir, 1) + }) + t.Run("missing parent directory", func(t *testing.T) { + tmpDir := t.TempDir() + fileName := filepath.Join(tmpDir, "missing-dir", "test.txt") + fileContent := []byte("file content") + fileMode := testMode() + if err := WriteFile(fileName, fileContent, fileMode); !errors.Is(err, os.ErrNotExist) { + t.Errorf("Should produce a 'not found' error, but got %[1]T (%[1]v)", err) + } + assertFileCount(t, tmpDir, 0) + }) + t.Run("symlinked file", func(t *testing.T) { + tmpDir := t.TempDir() + linkTarget := filepath.Join(tmpDir, "symlink-target") + originalContent := []byte("original content") + fileMode := testMode() + if err := os.WriteFile(linkTarget, originalContent, fileMode); err != nil { + t.Fatal(err) + } + if err := os.Symlink(linkTarget, filepath.Join(tmpDir, "symlinked-file")); err != nil { + t.Fatal(err) + } + origFileCount := 2 + assertFileCount(t, tmpDir, origFileCount) + + fileName := filepath.Join(tmpDir, "symlinked-file") + err := WriteFile(fileName, []byte("new content"), testMode()) + if err == nil || err.Error() != "cannot write to a symbolic link directly" { + t.Errorf("Should produce a 'cannot write to a symbolic link directly' error, but got %[1]T (%[1]v)", err) + } + assertFile(t, linkTarget, originalContent, fileMode) + assertFileCount(t, tmpDir, origFileCount) + }) + t.Run("symlinked directory", func(t *testing.T) { + tmpDir := t.TempDir() + actualParentDir := filepath.Join(tmpDir, "parent-dir") + if err := os.Mkdir(actualParentDir, 0o700); err != nil { + t.Fatal(err) + } + actualTargetFile := filepath.Join(actualParentDir, "target-file") + if err := 
os.WriteFile(actualTargetFile, []byte("orig content"), testMode()); err != nil { + t.Fatal(err) + } + parentDir := filepath.Join(tmpDir, "parent-dir-symlink") + if err := os.Symlink(actualParentDir, parentDir); err != nil { + t.Fatal(err) + } + origFileCount := 1 + assertFileCount(t, actualParentDir, origFileCount) + + fileName := filepath.Join(parentDir, "target-file") + fileContent := []byte("new content") + fileMode := testMode() + if err := WriteFile(fileName, fileContent, fileMode); err != nil { + t.Fatalf("Error writing to file: %v", err) + } + assertFile(t, fileName, fileContent, fileMode) + assertFile(t, actualTargetFile, fileContent, fileMode) + assertFileCount(t, actualParentDir, origFileCount) + }) +} + +func TestWriteSetCommit(t *testing.T) { + tmpDir := t.TempDir() + + if err := os.Mkdir(filepath.Join(tmpDir, "tmp"), 0o700); err != nil { + t.Fatalf("Error creating tmp directory: %s", err) + } + + targetDir := filepath.Join(tmpDir, "target") + ws, err := NewWriteSet(filepath.Join(tmpDir, "tmp")) + if err != nil { + t.Fatalf("Error creating atomic write set: %s", err) + } + + fileContent := []byte("file content") + fileMode := testMode() + + if err := ws.WriteFile("foo", fileContent, fileMode); err != nil { + t.Fatalf("Error writing to file: %v", err) + } + + if _, err := os.ReadFile(filepath.Join(targetDir, "foo")); err == nil { + t.Fatalf("Expected error reading file where should not exist") + } + + if err := ws.Commit(targetDir); err != nil { + t.Fatalf("Error committing file: %s", err) + } + + assertFile(t, filepath.Join(targetDir, "foo"), fileContent, fileMode) + assertFileCount(t, targetDir, 1) +} + +func TestWriteSetCancel(t *testing.T) { + tmpDir := t.TempDir() + + if err := os.Mkdir(filepath.Join(tmpDir, "tmp"), 0o700); err != nil { + t.Fatalf("Error creating tmp directory: %s", err) + } + + ws, err := NewWriteSet(filepath.Join(tmpDir, "tmp")) + if err != nil { + t.Fatalf("Error creating atomic write set: %s", err) + } + + fileContent := 
[]byte("file content") + fileMode := testMode() + if err := ws.WriteFile("foo", fileContent, fileMode); err != nil { + t.Fatalf("Error writing to file: %v", err) + } + + if err := ws.Cancel(); err != nil { + t.Fatalf("Error cancelling write set: %s", err) + } + + if _, err := os.ReadFile(filepath.Join(tmpDir, "target", "foo")); err == nil { + t.Fatalf("Expected error reading file where should not exist") + } else if !errors.Is(err, os.ErrNotExist) { + t.Fatalf("Unexpected error reading file: %s", err) + } + assertFileCount(t, filepath.Join(tmpDir, "tmp"), 0) // Cancel removes the set's temporary directory. +} diff --git a/atomicwriter/go.mod b/atomicwriter/go.mod new file mode 100644 index 00000000..cb5908ea --- /dev/null +++ b/atomicwriter/go.mod @@ -0,0 +1,7 @@ +module github.com/moby/sys/atomicwriter + +go 1.18 + +require github.com/moby/sys/sequential v0.6.0 + +require golang.org/x/sys v0.1.0 // indirect diff --git a/atomicwriter/go.sum b/atomicwriter/go.sum new file mode 100644 index 00000000..b9a5d4aa --- /dev/null +++ b/atomicwriter/go.sum @@ -0,0 +1,4 @@ +github.com/moby/sys/sequential v0.6.0 h1:qrx7XFUd/5DxtqcoH1h438hF5TmOvzC/lspjy7zgvCU= +github.com/moby/sys/sequential v0.6.0/go.mod h1:uyv8EUTrca5PnDsdMGXhZe6CCe8U/UiTWd+lL+7b/Ko= +golang.org/x/sys v0.1.0 h1:kunALQeHf1/185U1i0GOB/fy1IPRDDpuoOOqRReG57U= +golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=