diff --git a/internal/services/ecr/store.go b/internal/services/ecr/store.go index bffca9e..44e9237 100644 --- a/internal/services/ecr/store.go +++ b/internal/services/ecr/store.go @@ -14,6 +14,7 @@ import ( "path/filepath" "time" + "github.com/skyoo2003/devcloud/internal/shared" "github.com/skyoo2003/devcloud/internal/storage/sqlite" ) @@ -387,6 +388,12 @@ func (s *ECRStore) InitiateLayerUpload(accountID, repoName string) (string, erro // UploadLayerPart saves layer part blob to the filesystem and records the part metadata. func (s *ECRStore) UploadLayerPart(accountID, repoName, uploadID string, partFirst, partLast int64, blob []byte) error { + // uploadID is expected to be a 32-char lowercase hex value generated by InitiateLayerUpload. + // Reject anything else to prevent path traversal/path injection. + if !shared.ValidateUploadID(uploadID) { + return ErrLayerUploadNotFound + } + // Verify upload exists. var exists int _ = s.db().QueryRow(`SELECT COUNT(*) FROM layers WHERE upload_id=? AND repo_name=? AND account_id=?`, uploadID, repoName, accountID).Scan(&exists) diff --git a/internal/services/s3/provider.go b/internal/services/s3/provider.go index ba44a90..951eaa1 100644 --- a/internal/services/s3/provider.go +++ b/internal/services/s3/provider.go @@ -16,7 +16,8 @@ import ( "net/url" "os" "path/filepath" - "regexp" + + "github.com/skyoo2003/devcloud/internal/shared" "sort" "strconv" "strings" @@ -461,12 +462,6 @@ func generateUploadID() (string, error) { return hex.EncodeToString(b), nil } -var uploadIDPattern = regexp.MustCompile(`^[a-f0-9]{32}$`) - -func isValidUploadID(uploadID string) bool { - return uploadIDPattern.MatchString(uploadID) -} - // multipartDir returns the directory used to store parts for an upload. func (p *S3Provider) multipartDir(uploadID string) string { return filepath.Join(p.fileStore.baseDir, "_multipart", filepath.Base(uploadID)) @@ -891,7 +886,7 @@ func (p *S3Provider) uploadPart(_ context.Context, bucket, key, uploadID, partNu } func (p *S3Provider) completeMultipartUpload(_ context.Context, bucket, key, uploadID string, req *http.Request) (*plugin.Response, error) { - if !isValidUploadID(uploadID) { + if !shared.ValidateUploadID(uploadID) { return xmlError("NoSuchUpload", "upload not found", http.StatusNotFound), nil } @@ -971,7 +966,7 @@ func (p *S3Provider) completeMultipartUpload(_ context.Context, bucket, key, upl } func (p *S3Provider) abortMultipartUpload(_ context.Context, bucket, key, uploadID string) (*plugin.Response, error) { - if !isValidUploadID(uploadID) { + if !shared.ValidateUploadID(uploadID) { return xmlError("InvalidRequest", "invalid uploadId", http.StatusBadRequest), nil } if _, err := p.metaStore.GetMultipartUpload(uploadID); err != nil { diff --git a/internal/shared/validate.go b/internal/shared/validate.go new file mode 100644 index 0000000..9324d62 --- /dev/null +++ b/internal/shared/validate.go @@ -0,0 +1,21 @@ +// SPDX-License-Identifier: Apache-2.0 + +package shared + +// ValidateUploadID checks that id is a 32-character lowercase hex string, +// matching the format produced by InitiateLayerUpload in S3 and ECR. +func ValidateUploadID(id string) bool { + if len(id) != 32 { + return false + } + for _, c := range id { + if !isHexLower(c) { + return false + } + } + return true +} + +func isHexLower(c rune) bool { + return (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') +}