diff --git a/README.md b/README.md index 87ce2c3..6779bb9 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ [![lint](https://github.com/hyp3rd/sectools/actions/workflows/lint.yml/badge.svg)](https://github.com/hyp3rd/sectools/actions/workflows/lint.yml) [![test](https://github.com/hyp3rd/sectools/actions/workflows/test.yml/badge.svg)](https://github.com/hyp3rd/sectools/actions/workflows/test.yml) [![security](https://github.com/hyp3rd/sectools/actions/workflows/security.yml/badge.svg)](https://github.com/hyp3rd/sectools/actions/workflows/security.yml) -Security-focused Go helpers for file I/O, in-memory handling of sensitive data, auth tokens, password hashing, and safe numeric conversions. +Security-focused Go helpers for file I/O, in-memory handling of sensitive data, auth tokens, password hashing, input validation/sanitization, and safe numeric conversions. ## Features @@ -16,6 +16,8 @@ Security-focused Go helpers for file I/O, in-memory handling of sensitive data, - Secure in-memory buffers with best-effort zeroization - JWT/PASETO helpers with strict validation and safe defaults - Password hashing presets for argon2id/bcrypt with rehash detection +- Email and URL validation with optional DNS/redirect/reputation checks +- HTML/Markdown sanitization, SQL input guards, and filename sanitizers - Safe integer conversion helpers with overflow/negative guards ## Requirements @@ -200,6 +202,69 @@ func main() { } ``` +### Input validation + +```go +package main + +import ( + "context" + + "github.com/hyp3rd/sectools/pkg/validate" +) + +func main() { + emailValidator, err := validate.NewEmailValidator( + validate.WithEmailVerifyDomain(true), + ) + if err != nil { + panic(err) + } + + _, _ = emailValidator.Validate(context.Background(), "user@example.com") + + urlValidator, err := validate.NewURLValidator( + validate.WithURLCheckRedirects(3), + ) + if err != nil { + panic(err) + } + + _, _ = urlValidator.Validate(context.Background(), "https://example.com") +} +``` 
+ +### Sanitization + +```go +package main + +import ( + "github.com/hyp3rd/sectools/pkg/sanitize" +) + +func main() { + htmlSanitizer, err := sanitize.NewHTMLSanitizer() + if err != nil { + panic(err) + } + + safeHTML, _ := htmlSanitizer.Sanitize("hello") + + sqlSanitizer, err := sanitize.NewSQLSanitizer( + sanitize.WithSQLMode(sanitize.SQLModeIdentifier), + sanitize.WithSQLAllowQualifiedIdentifiers(true), + ) + if err != nil { + panic(err) + } + + safeIdentifier, _ := sqlSanitizer.Sanitize("public.users") + + _, _ = safeHTML, safeIdentifier +} +``` + ## Security and behavior notes - `ReadFile` only permits relative paths under `os.TempDir()` by default. Use `NewWithOptions` with `WithAllowAbsolute` to allow absolute paths or alternate roots. diff --git a/cspell.json b/cspell.json index cebd178..66db5c4 100644 --- a/cspell.json +++ b/cspell.json @@ -30,6 +30,7 @@ "anchore", "argon2", "argon2id", + "Atext", "Atoi", "aud", "autobuild", @@ -37,6 +38,7 @@ "behaviour", "benchmem", "benchtime", + "bücher", "bufbuild", "CODEOWNERS", "CodeQL", @@ -85,9 +87,12 @@ "GOTOOLCHAIN", "govulncheck", "honnef", + "hostnames", "HS256", "hyperlogger", "iat", + "IDN", + "idna", "Infof", "internalio", "ints", @@ -97,15 +102,19 @@ "jwt", "Keyfunc", "kid", + "localhost", "localmodule", + "markdown", "memprofile", "mkdocs", "mlock", "mvdan", + "MX", "myproject", "mypy", "myuser", "nbf", + "nethtml", "Newf", "nolint", "nonamedreturns", @@ -126,6 +135,10 @@ "pycache", "recvcheck", "Renovate", + "sanitization", + "sanitize", + "sanitizer", + "sanitizers", "SBOM", "SBOMs", "sectauth", @@ -137,6 +150,7 @@ "sigstore", "SLSA", "softprops", + "sql", "staticcheck", "stdlib", "strconv", @@ -146,6 +160,7 @@ "syncdir", "syncer", "tagalign", + "tautologies", "TempDir", "TempFile", "TempPrefix", @@ -154,6 +169,8 @@ "Tracef", "uid", "umask", + "URLHTTP", + "userinfo", "varnamelen", "vettool", "Warnf", diff --git a/docs/security-checklist.md b/docs/security-checklist.md index 2569902..26d3bb4 100644 --- 
a/docs/security-checklist.md +++ b/docs/security-checklist.md @@ -31,6 +31,18 @@ This checklist is a quick reference for teams using sectools in production. - Rehash stored passwords when `needsRehash` is true. - Enforce bcrypt's 72-byte limit to avoid silent truncation. +## Input Validation + +- Use `pkg/validate` for email/URL parsing instead of ad-hoc regexes. +- Enable DNS verification only when you can tolerate network lookups and timeouts. +- Keep URL schemes restricted and avoid enabling private IPs unless required. + +## Sanitization + +- Use `pkg/sanitize` for HTML/Markdown sanitization instead of ad-hoc escaping. +- Prefer parameterized SQL queries; use `SQLSanitizer` only for identifiers or literals when needed. +- Use `SQLInjectionDetector` as a heuristic guard for untrusted input before query composition. + ## Cleanup - Use `Remove`/`RemoveAll` to enforce root scoping. diff --git a/docs/usage.md b/docs/usage.md index acb4382..2811de3 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -8,6 +8,8 @@ supporting implementations in `internal/`. - `pkg/io`: secure file read/write helpers. - `pkg/auth`: JWT and PASETO helpers with strict validation. - `pkg/password`: password hashing helpers. +- `pkg/validate`: email and URL validation helpers. +- `pkg/sanitize`: HTML/Markdown sanitizers, SQL input guards, and filename sanitizers. - `pkg/memory`: secure in-memory buffers. - `pkg/converters`: safe numeric conversions. - `internal/io`: implementation details; not part of the public API contract. @@ -274,6 +276,103 @@ Behavior: - `Verify` returns `needsRehash` when parameters or cost drift from the current preset. - Bcrypt rejects passwords longer than 72 bytes to avoid silent truncation. 
+## pkg/validate + +### Email validation + +```go +func NewEmailValidator(opts ...EmailOption) (*EmailValidator, error) +func (v *EmailValidator) Validate(ctx context.Context, input string) (EmailResult, error) +``` + +Behavior: + +- Rejects display names by default; use `WithEmailAllowDisplayName(true)` to permit them. +- Validates local part syntax (dot-atom by default); quoted local parts are optional. +- Validates domain labels and length; IDN domains require `WithEmailAllowIDN(true)`. +- Optional DNS verification with `WithEmailVerifyDomain(true)` using MX and optional A/AAAA fallback. + +### URL validation + +```go +func NewURLValidator(opts ...URLOption) (*URLValidator, error) +func (v *URLValidator) Validate(ctx context.Context, raw string) (URLResult, error) +``` + +Behavior: + +- Enforces `https` only; non-https schemes are rejected (even if other schemes are configured). +- Rejects userinfo by default; use `WithURLAllowUserInfo(true)` to permit it. +- Blocks private/loopback IPs by default; use `WithURLAllowPrivateIP(true)` to permit them. +- Optional redirect checks with `WithURLCheckRedirects` and an HTTP client. +- Optional reputation checks with `WithURLReputationChecker`. + +## pkg/sanitize + +### HTML sanitization + +```go +func NewHTMLSanitizer(opts ...HTMLOption) (*HTMLSanitizer, error) +func (s *HTMLSanitizer) Sanitize(input string) (string, error) +``` + +Behavior: + +- Escapes HTML by default (`HTMLSanitizeEscape`). +- Supports stripping tags to plain text with `WithHTMLMode(HTMLSanitizeStrip)`. +- Allows custom policies via `WithHTMLPolicy`. + +### Markdown sanitization + +```go +func NewMarkdownSanitizer(opts ...MarkdownOption) (*MarkdownSanitizer, error) +func (s *MarkdownSanitizer) Sanitize(input string) (string, error) +``` + +Behavior: + +- Escapes raw HTML by default. +- Allows raw HTML with `WithMarkdownAllowRawHTML(true)`. 
+ +### SQL sanitization + +```go +func NewSQLSanitizer(opts ...SQLOption) (*SQLSanitizer, error) +func (s *SQLSanitizer) Sanitize(input string) (string, error) +``` + +Behavior: + +- Identifier mode rejects unsafe characters and can allow dotted identifiers. +- Literal mode escapes single quotes using SQL-standard doubling. +- LIKE mode escapes `%`/`_` and the configured escape character, and doubles single quotes. +- Always prefer parameterized queries; sanitization is a safety net. + +### SQL injection detection + +```go +func NewSQLInjectionDetector(opts ...SQLDetectOption) (*SQLInjectionDetector, error) +func (d *SQLInjectionDetector) Detect(input string) error +``` + +Behavior: + +- Flags common SQL injection patterns (comments, statement separators, tautologies). +- The detector is heuristic; tune patterns if your input includes SQL-like content. + +### Filename sanitization + +```go +func NewFilenameSanitizer(opts ...FilenameOption) (*FilenameSanitizer, error) +func (s *FilenameSanitizer) Sanitize(input string) (string, error) +``` + +Behavior: + +- Normalizes a single filename or path segment. +- Replaces disallowed characters with a configurable replacement rune. +- Rejects empty input, reserved dot segments (`.`/`..`), and names ending in a dot or space. + ## pkg/memory `SecureBuffer` is a public type for holding sensitive data in memory. 
diff --git a/go.mod b/go.mod index dd09120..993c56c 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/hyp3rd/hyperlogger v0.0.8 github.com/stretchr/testify v1.11.1 golang.org/x/crypto v0.46.0 + golang.org/x/net v0.48.0 ) require ( @@ -16,6 +17,7 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/goccy/go-json v0.10.5 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - golang.org/x/sys v0.39.0 // indirect + golang.org/x/sys v0.40.0 // indirect + golang.org/x/text v0.33.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index f145d92..422d99d 100644 --- a/go.sum +++ b/go.sum @@ -20,8 +20,12 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU= golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0= -golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= -golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU= +golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY= +golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= +golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= +golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/pkg/sanitize/doc.go b/pkg/sanitize/doc.go new file mode 100644 index 
0000000..1221d6d --- /dev/null +++ b/pkg/sanitize/doc.go @@ -0,0 +1,2 @@ +// Package sanitize provides safe-by-default sanitizers for untrusted input. +package sanitize diff --git a/pkg/sanitize/errors.go b/pkg/sanitize/errors.go new file mode 100644 index 0000000..0855a37 --- /dev/null +++ b/pkg/sanitize/errors.go @@ -0,0 +1,39 @@ +package sanitize + +import "github.com/hyp3rd/ewrap" + +var ( + // ErrInvalidHTMLConfig indicates an invalid HTML sanitizer configuration. + ErrInvalidHTMLConfig = ewrap.New("invalid html sanitize config") + // ErrInvalidMarkdownConfig indicates an invalid Markdown sanitizer configuration. + ErrInvalidMarkdownConfig = ewrap.New("invalid markdown sanitize config") + // ErrInvalidSQLConfig indicates an invalid SQL sanitizer configuration. + ErrInvalidSQLConfig = ewrap.New("invalid sql sanitize config") + // ErrInvalidFilenameConfig indicates an invalid filename sanitizer configuration. + ErrInvalidFilenameConfig = ewrap.New("invalid filename sanitize config") + + // ErrHTMLTooLong indicates the HTML input exceeds the configured limit. + ErrHTMLTooLong = ewrap.New("html input too long") + // ErrHTMLInvalid indicates the HTML input could not be parsed safely. + ErrHTMLInvalid = ewrap.New("html input invalid") + // ErrMarkdownTooLong indicates the Markdown input exceeds the configured limit. + ErrMarkdownTooLong = ewrap.New("markdown input too long") + + // ErrSQLInputTooLong indicates the SQL input exceeds the configured limit. + ErrSQLInputTooLong = ewrap.New("sql input too long") + // ErrSQLIdentifierInvalid indicates the SQL identifier is invalid. + ErrSQLIdentifierInvalid = ewrap.New("sql identifier invalid") + // ErrSQLLiteralInvalid indicates the SQL literal is invalid. + ErrSQLLiteralInvalid = ewrap.New("sql literal invalid") + // ErrSQLLikeEscapeInvalid indicates the SQL LIKE escape character is invalid. 
+ ErrSQLLikeEscapeInvalid = ewrap.New("sql like escape invalid") + // ErrSQLInjectionDetected indicates the input matched SQL injection heuristics. + ErrSQLInjectionDetected = ewrap.New("sql injection detected") + + // ErrFilenameEmpty indicates the filename is empty after sanitization. + ErrFilenameEmpty = ewrap.New("filename empty") + // ErrFilenameTooLong indicates the filename exceeds the configured limit. + ErrFilenameTooLong = ewrap.New("filename too long") + // ErrFilenameInvalid indicates the filename contains invalid characters. + ErrFilenameInvalid = ewrap.New("filename invalid") +) diff --git a/pkg/sanitize/filename.go b/pkg/sanitize/filename.go new file mode 100644 index 0000000..17aca40 --- /dev/null +++ b/pkg/sanitize/filename.go @@ -0,0 +1,197 @@ +package sanitize + +import ( + "strings" + "unicode" + "unicode/utf8" +) + +const ( + filenameDefaultMaxLength = 255 + filenameDefaultReplacement = '_' + filenameDeleteRune = 0x7f +) + +// FilenameOption configures filename sanitization. +type FilenameOption func(*filenameOptions) error + +type filenameOptions struct { + maxLength int + allowSpaces bool + allowUnicode bool + allowLeadingDot bool + replacement rune +} + +// FilenameSanitizer sanitizes a single filename or path segment. +type FilenameSanitizer struct { + opts filenameOptions +} + +// NewFilenameSanitizer constructs a filename sanitizer with options. +func NewFilenameSanitizer(opts ...FilenameOption) (*FilenameSanitizer, error) { + cfg := filenameOptions{ + maxLength: filenameDefaultMaxLength, + replacement: filenameDefaultReplacement, + } + + for _, opt := range opts { + if opt == nil { + continue + } + + err := opt(&cfg) + if err != nil { + return nil, err + } + } + + err := validateFilenameConfig(cfg) + if err != nil { + return nil, err + } + + return &FilenameSanitizer{opts: cfg}, nil +} + +// WithFilenameMaxLength sets the maximum accepted filename length. 
+func WithFilenameMaxLength(maxLength int) FilenameOption { + return func(cfg *filenameOptions) error { + if maxLength <= 0 { + return ErrInvalidFilenameConfig + } + + cfg.maxLength = maxLength + + return nil + } +} + +// WithFilenameAllowSpaces allows spaces in filenames. +func WithFilenameAllowSpaces(allow bool) FilenameOption { + return func(cfg *filenameOptions) error { + cfg.allowSpaces = allow + + return nil + } +} + +// WithFilenameAllowUnicode allows Unicode characters in filenames. +func WithFilenameAllowUnicode(allow bool) FilenameOption { + return func(cfg *filenameOptions) error { + cfg.allowUnicode = allow + + return nil + } +} + +// WithFilenameAllowLeadingDot allows filenames starting with a dot. +func WithFilenameAllowLeadingDot(allow bool) FilenameOption { + return func(cfg *filenameOptions) error { + cfg.allowLeadingDot = allow + + return nil + } +} + +// WithFilenameReplacement sets the replacement rune for invalid characters. +func WithFilenameReplacement(replacement rune) FilenameOption { + return func(cfg *filenameOptions) error { + if !isValidReplacement(replacement) { + return ErrInvalidFilenameConfig + } + + cfg.replacement = replacement + + return nil + } +} + +// Sanitize normalizes a filename or path segment. +func (s *FilenameSanitizer) Sanitize(input string) (string, error) { + value := strings.TrimSpace(input) + if value == "" { + return "", ErrFilenameEmpty + } + + if len(value) > s.opts.maxLength { + return "", ErrFilenameTooLong + } + + var builder strings.Builder + builder.Grow(len(value)) + + for _, ch := range value { + if isAllowedFilenameRune(ch, s.opts, builder.Len() == 0) { + builder.WriteRune(ch) + + continue + } + + builder.WriteRune(s.opts.replacement) + } + + result := builder.String() + if len(result) > s.opts.maxLength { + return "", ErrFilenameTooLong + } + + if result == "." || result == ".." 
{ + return "", ErrFilenameInvalid + } + + if strings.HasSuffix(result, ".") || strings.HasSuffix(result, " ") { + return "", ErrFilenameInvalid + } + + return result, nil +} + +func validateFilenameConfig(cfg filenameOptions) error { + if cfg.maxLength <= 0 { + return ErrInvalidFilenameConfig + } + + if !isValidReplacement(cfg.replacement) { + return ErrInvalidFilenameConfig + } + + return nil +} + +func isValidReplacement(replacement rune) bool { + if replacement == '.' || unicode.IsSpace(replacement) { + return false + } + + return isAllowedFilenameRune(replacement, filenameOptions{allowUnicode: false, allowSpaces: false}, false) +} + +func isAllowedFilenameRune(ch rune, cfg filenameOptions, isStart bool) bool { + if !cfg.allowUnicode && ch >= utf8.RuneSelf { + return false + } + + if ch == 0 || ch == utf8.RuneError { + return false + } + + if ch < ' ' || ch == filenameDeleteRune { + return false + } + + if !cfg.allowSpaces && unicode.IsSpace(ch) { + return false + } + + if !cfg.allowLeadingDot && isStart && ch == '.' 
{ + return false + } + + switch ch { + case '/', '\\', ':', '*', '?', '"', '<', '>', '|': + return false + default: + return true + } +} diff --git a/pkg/sanitize/filename_test.go b/pkg/sanitize/filename_test.go new file mode 100644 index 0000000..b943f34 --- /dev/null +++ b/pkg/sanitize/filename_test.go @@ -0,0 +1,75 @@ +package sanitize + +import "testing" + +func TestFilenameSanitizeBasic(t *testing.T) { + sanitizer, err := NewFilenameSanitizer() + if err != nil { + t.Fatalf("expected sanitizer, got %v", err) + } + + output, err := sanitizer.Sanitize("report.pdf") + if err != nil { + t.Fatalf("expected sanitized filename, got %v", err) + } + + if output != "report.pdf" { + t.Fatalf("expected report.pdf, got %q", output) + } +} + +func TestFilenameSanitizeLeadingDot(t *testing.T) { + sanitizer, err := NewFilenameSanitizer() + if err != nil { + t.Fatalf("expected sanitizer, got %v", err) + } + + output, err := sanitizer.Sanitize(".env") + if err != nil { + t.Fatalf("expected sanitized filename, got %v", err) + } + + if output != "_env" { + t.Fatalf("expected _env, got %q", output) + } +} + +func TestFilenameSanitizeSeparators(t *testing.T) { + sanitizer, err := NewFilenameSanitizer() + if err != nil { + t.Fatalf("expected sanitizer, got %v", err) + } + + output, err := sanitizer.Sanitize("foo/bar") + if err != nil { + t.Fatalf("expected sanitized filename, got %v", err) + } + + if output != "foo_bar" { + t.Fatalf("expected foo_bar, got %q", output) + } +} + +func TestFilenameSanitizeEmpty(t *testing.T) { + sanitizer, err := NewFilenameSanitizer() + if err != nil { + t.Fatalf("expected sanitizer, got %v", err) + } + + _, err = sanitizer.Sanitize(" ") + if err != ErrFilenameEmpty { + t.Fatalf("expected ErrFilenameEmpty, got %v", err) + } +} + +func TestFilenameSanitizeMaxLength(t *testing.T) { + sanitizer, err := NewFilenameSanitizer(WithFilenameMaxLength(3)) + if err != nil { + t.Fatalf("expected sanitizer, got %v", err) + } + + _, err = sanitizer.Sanitize("abcd") 
+ if err != ErrFilenameTooLong { + t.Fatalf("expected ErrFilenameTooLong, got %v", err) + } +} diff --git a/pkg/sanitize/html.go b/pkg/sanitize/html.go new file mode 100644 index 0000000..bd2e6c3 --- /dev/null +++ b/pkg/sanitize/html.go @@ -0,0 +1,182 @@ +package sanitize + +import ( + "fmt" + "html" + "strings" + + nethtml "golang.org/x/net/html" +) + +const ( + htmlDefaultMaxLength = 100_000 +) + +// HTMLSanitizeMode describes how HTML is sanitized. +type HTMLSanitizeMode int + +const ( + // HTMLSanitizeEscape escapes HTML tags and entities. + HTMLSanitizeEscape HTMLSanitizeMode = iota + // HTMLSanitizeStrip removes HTML tags and returns plain text. + HTMLSanitizeStrip +) + +// HTMLPolicy defines a custom HTML sanitizer. +type HTMLPolicy interface { + Sanitize(input string) (string, error) +} + +// HTMLPolicyFunc adapts a function to HTMLPolicy. +type HTMLPolicyFunc func(input string) (string, error) + +// Sanitize implements HTMLPolicy. +func (fn HTMLPolicyFunc) Sanitize(input string) (string, error) { + return fn(input) +} + +// HTMLOption configures the HTML sanitizer. +type HTMLOption func(*htmlOptions) error + +type htmlOptions struct { + maxLength int + mode HTMLSanitizeMode + policy HTMLPolicy +} + +// HTMLSanitizer sanitizes HTML input with safe defaults. +type HTMLSanitizer struct { + opts htmlOptions +} + +// NewHTMLSanitizer constructs an HTML sanitizer with options. +func NewHTMLSanitizer(opts ...HTMLOption) (*HTMLSanitizer, error) { + cfg := htmlOptions{ + maxLength: htmlDefaultMaxLength, + mode: HTMLSanitizeEscape, + } + + for _, opt := range opts { + if opt == nil { + continue + } + + err := opt(&cfg) + if err != nil { + return nil, err + } + } + + err := validateHTMLConfig(cfg) + if err != nil { + return nil, err + } + + return &HTMLSanitizer{opts: cfg}, nil +} + +// WithHTMLMode sets the HTML sanitization mode. 
+func WithHTMLMode(mode HTMLSanitizeMode) HTMLOption { + return func(cfg *htmlOptions) error { + cfg.mode = mode + + return nil + } +} + +// WithHTMLMaxLength sets the maximum accepted HTML input length. +func WithHTMLMaxLength(maxLength int) HTMLOption { + return func(cfg *htmlOptions) error { + if maxLength <= 0 { + return ErrInvalidHTMLConfig + } + + cfg.maxLength = maxLength + + return nil + } +} + +// WithHTMLPolicy sets a custom HTML policy. +func WithHTMLPolicy(policy HTMLPolicy) HTMLOption { + return func(cfg *htmlOptions) error { + if policy == nil { + return ErrInvalidHTMLConfig + } + + cfg.policy = policy + + return nil + } +} + +// Sanitize sanitizes HTML content and returns a safe string. +func (s *HTMLSanitizer) Sanitize(input string) (string, error) { + if len(input) > s.opts.maxLength { + return "", ErrHTMLTooLong + } + + if s.opts.policy != nil { + return s.opts.policy.Sanitize(input) + } + + switch s.opts.mode { + case HTMLSanitizeEscape: + return html.EscapeString(input), nil + case HTMLSanitizeStrip: + return stripHTML(input) + default: + return "", ErrInvalidHTMLConfig + } +} + +func validateHTMLConfig(cfg htmlOptions) error { + if cfg.maxLength <= 0 { + return ErrInvalidHTMLConfig + } + + if cfg.mode != HTMLSanitizeEscape && cfg.mode != HTMLSanitizeStrip { + return ErrInvalidHTMLConfig + } + + return nil +} + +func stripHTML(input string) (string, error) { + doc, err := nethtml.Parse(strings.NewReader(input)) + if err != nil { + return "", fmt.Errorf("%w: %w", ErrHTMLInvalid, err) + } + + var builder strings.Builder + appendHTMLText(&builder, doc) + + return builder.String(), nil +} + +func appendHTMLText(builder *strings.Builder, node *nethtml.Node) { + if node == nil { + return + } + + if node.Type == nethtml.ElementNode && isStripElement(node.Data) { + return + } + + if node.Type == nethtml.TextNode { + builder.WriteString(node.Data) + } + + for child := node.FirstChild; child != nil; child = child.NextSibling { + appendHTMLText(builder, 
child) + } +} + +func isStripElement(tag string) bool { + switch tag { + case "script", "style": + return true + default: + return false + } +} diff --git a/pkg/sanitize/html_test.go b/pkg/sanitize/html_test.go new file mode 100644 index 0000000..9ebe61d --- /dev/null +++ b/pkg/sanitize/html_test.go @@ -0,0 +1,70 @@ +package sanitize + +import ( + "html" + "testing" +) + +func TestHTMLSanitizeEscape(t *testing.T) { + sanitizer, err := NewHTMLSanitizer() + if err != nil { + t.Fatalf("expected sanitizer, got %v", err) + } + + input := `` + output, err := sanitizer.Sanitize(input) + if err != nil { + t.Fatalf("expected sanitized html, got %v", err) + } + + expected := html.EscapeString(input) + if output != expected { + t.Fatalf("expected %q, got %q", expected, output) + } +} + +func TestHTMLSanitizeStrip(t *testing.T) { + sanitizer, err := NewHTMLSanitizer(WithHTMLMode(HTMLSanitizeStrip)) + if err != nil { + t.Fatalf("expected sanitizer, got %v", err) + } + + output, err := sanitizer.Sanitize("Hello World") + if err != nil { + t.Fatalf("expected sanitized html, got %v", err) + } + + if output != "Hello World" { + t.Fatalf("expected stripped text, got %q", output) + } +} + +func TestHTMLSanitizePolicy(t *testing.T) { + sanitizer, err := NewHTMLSanitizer(WithHTMLPolicy(HTMLPolicyFunc(func(_ string) (string, error) { + return "policy", nil + }))) + if err != nil { + t.Fatalf("expected sanitizer, got %v", err) + } + + output, err := sanitizer.Sanitize("ignored") + if err != nil { + t.Fatalf("expected sanitized html, got %v", err) + } + + if output != "policy" { + t.Fatalf("expected policy output, got %q", output) + } +} + +func TestHTMLSanitizeMaxLength(t *testing.T) { + sanitizer, err := NewHTMLSanitizer(WithHTMLMaxLength(1)) + if err != nil { + t.Fatalf("expected sanitizer, got %v", err) + } + + _, err = sanitizer.Sanitize("ab") + if err != ErrHTMLTooLong { + t.Fatalf("expected ErrHTMLTooLong, got %v", err) + } +} diff --git a/pkg/sanitize/markdown.go 
b/pkg/sanitize/markdown.go new file mode 100644 index 0000000..c2d4531 --- /dev/null +++ b/pkg/sanitize/markdown.go @@ -0,0 +1,88 @@ +package sanitize + +import "html" + +const ( + markdownDefaultMaxLength = 100_000 +) + +// MarkdownOption configures the Markdown sanitizer. +type MarkdownOption func(*markdownOptions) error + +type markdownOptions struct { + maxLength int + allowRawHTML bool +} + +// MarkdownSanitizer sanitizes Markdown input with safe defaults. +type MarkdownSanitizer struct { + opts markdownOptions +} + +// NewMarkdownSanitizer constructs a Markdown sanitizer with options. +func NewMarkdownSanitizer(opts ...MarkdownOption) (*MarkdownSanitizer, error) { + cfg := markdownOptions{ + maxLength: markdownDefaultMaxLength, + } + + for _, opt := range opts { + if opt == nil { + continue + } + + err := opt(&cfg) + if err != nil { + return nil, err + } + } + + err := validateMarkdownConfig(cfg) + if err != nil { + return nil, err + } + + return &MarkdownSanitizer{opts: cfg}, nil +} + +// WithMarkdownMaxLength sets the maximum accepted Markdown input length. +func WithMarkdownMaxLength(maxLength int) MarkdownOption { + return func(cfg *markdownOptions) error { + if maxLength <= 0 { + return ErrInvalidMarkdownConfig + } + + cfg.maxLength = maxLength + + return nil + } +} + +// WithMarkdownAllowRawHTML allows raw HTML inside Markdown. +func WithMarkdownAllowRawHTML(allow bool) MarkdownOption { + return func(cfg *markdownOptions) error { + cfg.allowRawHTML = allow + + return nil + } +} + +// Sanitize sanitizes Markdown input and returns a safe string. 
+func (s *MarkdownSanitizer) Sanitize(input string) (string, error) { + if len(input) > s.opts.maxLength { + return "", ErrMarkdownTooLong + } + + if s.opts.allowRawHTML { + return input, nil + } + + return html.EscapeString(input), nil +} + +func validateMarkdownConfig(cfg markdownOptions) error { + if cfg.maxLength <= 0 { + return ErrInvalidMarkdownConfig + } + + return nil +} diff --git a/pkg/sanitize/markdown_test.go b/pkg/sanitize/markdown_test.go new file mode 100644 index 0000000..e5fcfa1 --- /dev/null +++ b/pkg/sanitize/markdown_test.go @@ -0,0 +1,53 @@ +package sanitize + +import ( + "html" + "testing" +) + +func TestMarkdownSanitizeEscape(t *testing.T) { + sanitizer, err := NewMarkdownSanitizer() + if err != nil { + t.Fatalf("expected sanitizer, got %v", err) + } + + input := "hello" + output, err := sanitizer.Sanitize(input) + if err != nil { + t.Fatalf("expected sanitized markdown, got %v", err) + } + + expected := html.EscapeString(input) + if output != expected { + t.Fatalf("expected %q, got %q", expected, output) + } +} + +func TestMarkdownAllowRawHTML(t *testing.T) { + sanitizer, err := NewMarkdownSanitizer(WithMarkdownAllowRawHTML(true)) + if err != nil { + t.Fatalf("expected sanitizer, got %v", err) + } + + input := "hello" + output, err := sanitizer.Sanitize(input) + if err != nil { + t.Fatalf("expected sanitized markdown, got %v", err) + } + + if output != input { + t.Fatalf("expected raw html, got %q", output) + } +} + +func TestMarkdownMaxLength(t *testing.T) { + sanitizer, err := NewMarkdownSanitizer(WithMarkdownMaxLength(1)) + if err != nil { + t.Fatalf("expected sanitizer, got %v", err) + } + + _, err = sanitizer.Sanitize("ab") + if err != ErrMarkdownTooLong { + t.Fatalf("expected ErrMarkdownTooLong, got %v", err) + } +} diff --git a/pkg/sanitize/sql.go b/pkg/sanitize/sql.go new file mode 100644 index 0000000..78dcc1f --- /dev/null +++ b/pkg/sanitize/sql.go @@ -0,0 +1,293 @@ +package sanitize + +import ( + "strings" + "unicode/utf8" +) + 
+const ( + sqlDefaultIdentifierMaxLength = 128 + sqlDefaultLiteralMaxLength = 4096 + sqlDefaultLikeMaxLength = 4096 + sqlDefaultLikeEscape = '\\' +) + +// SQLMode describes the SQL sanitization strategy. +type SQLMode int + +const ( + // SQLModeIdentifier sanitizes SQL identifiers (table/column names). + SQLModeIdentifier SQLMode = iota + // SQLModeLiteral sanitizes SQL literals for safe embedding in string literals. + SQLModeLiteral + // SQLModeLikePattern sanitizes SQL LIKE patterns with escaping. + SQLModeLikePattern +) + +// SQLOption configures the SQL sanitizer. +type SQLOption func(*sqlOptions) error + +type sqlOptions struct { + mode SQLMode + maxLength int + allowQualified bool + likeEscape rune +} + +// SQLSanitizer sanitizes SQL inputs with safe defaults. +type SQLSanitizer struct { + opts sqlOptions +} + +// NewSQLSanitizer constructs a SQL sanitizer with options. +func NewSQLSanitizer(opts ...SQLOption) (*SQLSanitizer, error) { + cfg := sqlOptions{ + mode: SQLModeIdentifier, + likeEscape: sqlDefaultLikeEscape, + } + + for _, opt := range opts { + if opt == nil { + continue + } + + err := opt(&cfg) + if err != nil { + return nil, err + } + } + + if cfg.maxLength == 0 { + cfg.maxLength = sqlDefaultMaxLength(cfg.mode) + } + + err := validateSQLConfig(cfg) + if err != nil { + return nil, err + } + + return &SQLSanitizer{opts: cfg}, nil +} + +// WithSQLMode sets the SQL sanitization mode. +func WithSQLMode(mode SQLMode) SQLOption { + return func(cfg *sqlOptions) error { + cfg.mode = mode + + return nil + } +} + +// WithSQLMaxLength sets the maximum accepted SQL input length. +func WithSQLMaxLength(maxLength int) SQLOption { + return func(cfg *sqlOptions) error { + if maxLength <= 0 { + return ErrInvalidSQLConfig + } + + cfg.maxLength = maxLength + + return nil + } +} + +// WithSQLAllowQualifiedIdentifiers allows dotted identifiers (schema.table). 
+func WithSQLAllowQualifiedIdentifiers(allow bool) SQLOption { + return func(cfg *sqlOptions) error { + cfg.allowQualified = allow + + return nil + } +} + +// WithSQLLikeEscapeChar sets the escape character for SQL LIKE patterns. +func WithSQLLikeEscapeChar(ch rune) SQLOption { + return func(cfg *sqlOptions) error { + cfg.likeEscape = ch + + return nil + } +} + +// Sanitize sanitizes SQL input for the configured mode. +func (s *SQLSanitizer) Sanitize(input string) (string, error) { + if len(input) > s.opts.maxLength { + return "", ErrSQLInputTooLong + } + + switch s.opts.mode { + case SQLModeIdentifier: + return s.sanitizeIdentifier(input) + case SQLModeLiteral: + return sanitizeSQLLiteral(input) + case SQLModeLikePattern: + return s.sanitizeLikePattern(input) + default: + return "", ErrInvalidSQLConfig + } +} + +func validateSQLConfig(cfg sqlOptions) error { + if cfg.maxLength <= 0 { + return ErrInvalidSQLConfig + } + + if cfg.mode != SQLModeIdentifier && cfg.mode != SQLModeLiteral && cfg.mode != SQLModeLikePattern { + return ErrInvalidSQLConfig + } + + if cfg.allowQualified && cfg.mode != SQLModeIdentifier { + return ErrInvalidSQLConfig + } + + if cfg.mode == SQLModeLikePattern && !isValidLikeEscape(cfg.likeEscape) { + return ErrSQLLikeEscapeInvalid + } + + return nil +} + +func (s *SQLSanitizer) sanitizeIdentifier(input string) (string, error) { + value := strings.TrimSpace(input) + if value == "" { + return "", ErrSQLIdentifierInvalid + } + + if s.opts.allowQualified { + return sanitizeQualifiedIdentifier(value) + } + + err := validateIdentifierSegment(value) + if err != nil { + return "", err + } + + return value, nil +} + +func sanitizeQualifiedIdentifier(value string) (string, error) { + parts := strings.SplitSeq(value, ".") + for part := range parts { + err := validateIdentifierSegment(part) + if err != nil { + return "", err + } + } + + return value, nil +} + +func validateIdentifierSegment(segment string) error { + if segment == "" { + return 
ErrSQLIdentifierInvalid + } + + for index := range len(segment) { + ch := segment[index] + if ch >= utf8.RuneSelf { + return ErrSQLIdentifierInvalid + } + + if index == 0 { + if !isIdentifierStart(ch) { + return ErrSQLIdentifierInvalid + } + + continue + } + + if !isIdentifierPart(ch) { + return ErrSQLIdentifierInvalid + } + } + + return nil +} + +func isIdentifierStart(ch byte) bool { + if ch == '_' { + return true + } + + return isASCIIAlpha(ch) +} + +func isIdentifierPart(ch byte) bool { + return isIdentifierStart(ch) || isASCIIDigit(ch) +} + +func isASCIIAlpha(ch byte) bool { + return (ch >= 'A' && ch <= 'Z') || (ch >= 'a' && ch <= 'z') +} + +func isASCIIDigit(ch byte) bool { + return ch >= '0' && ch <= '9' +} + +func sanitizeSQLLiteral(input string) (string, error) { + if strings.ContainsRune(input, 0) { + return "", ErrSQLLiteralInvalid + } + + return strings.ReplaceAll(input, "'", "''"), nil +} + +func (s *SQLSanitizer) sanitizeLikePattern(input string) (string, error) { + if strings.ContainsRune(input, 0) { + return "", ErrSQLLiteralInvalid + } + + return escapeLikePattern(input, s.opts.likeEscape) +} + +func escapeLikePattern(input string, escape rune) (string, error) { + if !isValidLikeEscape(escape) { + return "", ErrSQLLikeEscapeInvalid + } + + var builder strings.Builder + builder.Grow(len(input)) + + for _, ch := range input { + switch ch { + case '\'': + builder.WriteString("''") + + continue + case '%', '_': + builder.WriteRune(escape) + default: + if ch == escape { + builder.WriteRune(escape) + } + } + + builder.WriteRune(ch) + } + + return builder.String(), nil +} + +func isValidLikeEscape(ch rune) bool { + if ch == 0 { + return false + } + + if ch == '\'' || ch == '%' || ch == '_' { + return false + } + + return ch < utf8.RuneSelf +} + +func sqlDefaultMaxLength(mode SQLMode) int { + if mode == SQLModeIdentifier { + return sqlDefaultIdentifierMaxLength + } + + if mode == SQLModeLikePattern { + return sqlDefaultLikeMaxLength + } + + return 
sqlDefaultLiteralMaxLength +} diff --git a/pkg/sanitize/sql_detect.go b/pkg/sanitize/sql_detect.go new file mode 100644 index 0000000..294f14f --- /dev/null +++ b/pkg/sanitize/sql_detect.go @@ -0,0 +1,180 @@ +package sanitize + +import ( + "strings" + "unicode" +) + +const ( + sqlDetectDefaultMaxLength = 4096 +) + +// SQLDetectOption configures the SQL injection detector. +type SQLDetectOption func(*sqlDetectOptions) error + +type sqlDetectOptions struct { + maxLength int + patterns []string +} + +// SQLInjectionDetector checks inputs for SQL injection heuristics. +type SQLInjectionDetector struct { + opts sqlDetectOptions +} + +// NewSQLInjectionDetector constructs a detector with safe defaults. +func NewSQLInjectionDetector(opts ...SQLDetectOption) (*SQLInjectionDetector, error) { + cfg := sqlDetectOptions{ + maxLength: sqlDetectDefaultMaxLength, + patterns: defaultSQLInjectionPatterns(), + } + + for _, opt := range opts { + if opt == nil { + continue + } + + err := opt(&cfg) + if err != nil { + return nil, err + } + } + + err := validateSQLDetectConfig(cfg) + if err != nil { + return nil, err + } + + return &SQLInjectionDetector{opts: cfg}, nil +} + +// WithSQLDetectMaxLength sets the maximum input length for detection. +func WithSQLDetectMaxLength(maxLength int) SQLDetectOption { + return func(cfg *sqlDetectOptions) error { + if maxLength <= 0 { + return ErrInvalidSQLConfig + } + + cfg.maxLength = maxLength + + return nil + } +} + +// WithSQLDetectPatterns replaces the default detection patterns. +func WithSQLDetectPatterns(patterns ...string) SQLDetectOption { + return func(cfg *sqlDetectOptions) error { + if len(patterns) == 0 { + return ErrInvalidSQLConfig + } + + normalized := normalizeDetectPatterns(patterns) + if len(normalized) == 0 { + return ErrInvalidSQLConfig + } + + cfg.patterns = normalized + + return nil + } +} + +// Detect returns ErrSQLInjectionDetected when a pattern matches. 
// normalizeDetectInput lowercases input and collapses whitespace so that
// detection patterns can be matched with a plain substring search.
//
// Each run of Unicode whitespace that is followed by a non-space character
// becomes a single ASCII space (so leading whitespace leaves one leading
// space); trailing whitespace is dropped.
func normalizeDetectInput(input string) string {
	lowered := strings.ToLower(input)

	var out strings.Builder
	out.Grow(len(lowered))

	needSpace := false

	for _, r := range lowered {
		switch {
		case unicode.IsSpace(r):
			// Defer the separator until the next visible character so
			// trailing whitespace never reaches the output.
			needSpace = true
		case needSpace:
			out.WriteByte(' ')
			out.WriteRune(r)

			needSpace = false
		default:
			out.WriteRune(r)
		}
	}

	return out.String()
}
{ + sanitizer, err := NewSQLSanitizer() + if err != nil { + t.Fatalf("expected sanitizer, got %v", err) + } + + output, err := sanitizer.Sanitize("users") + if err != nil { + t.Fatalf("expected sanitized identifier, got %v", err) + } + + if output != "users" { + t.Fatalf("expected users, got %q", output) + } + + _, err = sanitizer.Sanitize("users;drop") + if err != ErrSQLIdentifierInvalid { + t.Fatalf("expected ErrSQLIdentifierInvalid, got %v", err) + } +} + +func TestSQLSanitizeQualifiedIdentifier(t *testing.T) { + sanitizer, err := NewSQLSanitizer(WithSQLAllowQualifiedIdentifiers(true)) + if err != nil { + t.Fatalf("expected sanitizer, got %v", err) + } + + output, err := sanitizer.Sanitize("public.users") + if err != nil { + t.Fatalf("expected sanitized identifier, got %v", err) + } + + if output != "public.users" { + t.Fatalf("expected public.users, got %q", output) + } + + _, err = sanitizer.Sanitize("public..users") + if err != ErrSQLIdentifierInvalid { + t.Fatalf("expected ErrSQLIdentifierInvalid, got %v", err) + } +} + +func TestSQLSanitizeLiteral(t *testing.T) { + sanitizer, err := NewSQLSanitizer(WithSQLMode(SQLModeLiteral)) + if err != nil { + t.Fatalf("expected sanitizer, got %v", err) + } + + output, err := sanitizer.Sanitize("O'Reilly") + if err != nil { + t.Fatalf("expected sanitized literal, got %v", err) + } + + if output != "O''Reilly" { + t.Fatalf("expected escaped literal, got %q", output) + } + + _, err = sanitizer.Sanitize("bad\x00") + if err != ErrSQLLiteralInvalid { + t.Fatalf("expected ErrSQLLiteralInvalid, got %v", err) + } +} + +func TestSQLSanitizeLikePattern(t *testing.T) { + sanitizer, err := NewSQLSanitizer(WithSQLMode(SQLModeLikePattern)) + if err != nil { + t.Fatalf("expected sanitizer, got %v", err) + } + + output, err := sanitizer.Sanitize(`50%_off\`) + if err != nil { + t.Fatalf("expected sanitized pattern, got %v", err) + } + + if output != `50\%\_off\\` { + t.Fatalf("expected escaped pattern, got %q", output) + } +} + +func 
// DNSResolver abstracts DNS lookups for email validation.
//
// NewEmailValidator falls back to net.DefaultResolver when domain
// verification is enabled and no resolver was supplied, so custom
// implementations are mainly useful for tests or alternative DNS backends.
type DNSResolver interface {
	// LookupMX returns the MX records for name. verifyDomain consults it
	// first when checking a domain.
	LookupMX(ctx context.Context, name string) ([]*net.MX, error)
	// LookupHost returns the addresses for name. verifyDomain uses it as the
	// A/AAAA fallback when no usable MX record was found.
	LookupHost(ctx context.Context, name string) ([]string, error)
}
// WithEmailAllowDisplayName permits display names like "Name <user@example.com>".
//
// When disabled (the default), an input that parses as a mail address but
// differs from its bare address is rejected with ErrEmailDisplayName.
func WithEmailAllowDisplayName(allow bool) EmailOption {
	return func(cfg *emailOptions) error {
		cfg.allowDisplayName = allow

		return nil
	}
}
+func WithEmailVerifyDomain(verify bool) EmailOption { + return func(cfg *emailOptions) error { + cfg.verifyDomain = verify + + return nil + } +} + +// WithEmailRequireMX enforces MX records when domain verification is enabled. +func WithEmailRequireMX(require bool) EmailOption { + return func(cfg *emailOptions) error { + cfg.requireMX = require + + return nil + } +} + +// WithEmailAllowARecordFallback enables A/AAAA fallback when MX is missing. +func WithEmailAllowARecordFallback(allow bool) EmailOption { + return func(cfg *emailOptions) error { + cfg.allowARecordFallback = allow + + return nil + } +} + +// WithEmailDNSResolver sets a custom DNS resolver. +func WithEmailDNSResolver(resolver DNSResolver) EmailOption { + return func(cfg *emailOptions) error { + if resolver == nil { + return ErrInvalidEmailConfig + } + + cfg.resolver = resolver + + return nil + } +} + +// Validate validates an email address and optionally verifies its domain. +func (v *EmailValidator) Validate(ctx context.Context, input string) (EmailResult, error) { + trimmed, err := normalizeEmailInput(input) + if err != nil { + return EmailResult{}, err + } + + address, err := v.normalizeAddress(trimmed) + if err != nil { + return EmailResult{}, err + } + + localPart, domain, err := splitEmail(address) + if err != nil { + return EmailResult{}, err + } + + err = v.validateLocalPart(localPart) + if err != nil { + return EmailResult{}, err + } + + domainInfo, err := v.validateDomain(domain) + if err != nil { + return EmailResult{}, err + } + + result := EmailResult{ + Address: address, + LocalPart: localPart, + Domain: domainInfo.normalized, + DomainASCII: domainInfo.ascii, + } + + err = v.applyDomainVerification(ctx, domainInfo, &result) + if err != nil { + return EmailResult{}, err + } + + return result, nil +} + +func normalizeEmailInput(input string) (string, error) { + trimmed := strings.TrimSpace(input) + if trimmed == "" { + return "", ErrEmailEmpty + } + + if len(trimmed) > 
emailMaxAddressLength { + return "", ErrEmailAddressTooLong + } + + return trimmed, nil +} + +func (v *EmailValidator) normalizeAddress(input string) (string, error) { + addr, err := mail.ParseAddress(input) + if err != nil { + addr = nil + } + + if addr == nil { + return input, nil + } + + address := addr.Address + if !v.opts.allowDisplayName && input != address { + return "", ErrEmailDisplayName + } + + return address, nil +} + +func (v *EmailValidator) validateLocalPart(local string) error { + if len(local) > emailMaxLocalLength { + return ErrEmailLocalPartTooLong + } + + return validateLocalPartSyntax(local, v.opts.allowQuotedLocal) +} + +type emailDomainInfo struct { + normalized string + ascii string + isIPLiteral bool +} + +func (v *EmailValidator) validateDomain(domain string) (emailDomainInfo, error) { + domainInfo, err := normalizeDomain(domain, v.opts.allowIDN) + if err != nil { + return emailDomainInfo{}, err + } + + if domainInfo.isIPLiteral { + if !v.opts.allowIPLiteral { + return emailDomainInfo{}, ErrEmailIPLiteralNotAllowed + } + + return domainInfo, nil + } + + if len(domainInfo.ascii) > emailMaxDomainLength { + return emailDomainInfo{}, ErrEmailDomainTooLong + } + + if v.opts.requireTLD && !strings.ContainsRune(domainInfo.ascii, emailDot) { + return emailDomainInfo{}, ErrEmailDomainInvalid + } + + err = validateDomainLabels(domainInfo.ascii) + if err != nil { + return emailDomainInfo{}, err + } + + return domainInfo, nil +} + +func (v *EmailValidator) applyDomainVerification(ctx context.Context, domainInfo emailDomainInfo, result *EmailResult) error { + if !v.opts.verifyDomain || domainInfo.isIPLiteral { + return nil + } + + if ctx == nil { + return ErrEmailInvalid + } + + verification, err := v.verifyDomain(ctx, domainInfo.ascii) + if err != nil { + return err + } + + result.DomainVerified = verification.verified + result.VerifiedByMX = verification.byMX + result.VerifiedByA = verification.byA + + return nil +} + +type domainVerification struct 
// isValidQuotedLocal reports whether local is a well-formed quoted local
// part, including its surrounding double quotes.
//
// It returns false when local is not wrapped in double quotes, when it
// contains a CR or LF anywhere, when the content between the quotes holds an
// unescaped double quote, or when that content ends in a dangling backslash
// (which would escape the closing quote). A backslash escapes the single
// character that follows it. An empty quoted string (`""`) is accepted.
func isValidQuotedLocal(local string) bool {
	if len(local) < 2 || local[0] == '"' != true || local[len(local)-1] != '"' {
		return false
	}

	interior := local[1 : len(local)-1]
	escaped := false

	for _, r := range interior {
		switch {
		case r == '\n' || r == '\r':
			// Line breaks are never allowed, escaped or not.
			return false
		case escaped:
			escaped = false
		case r == '\\':
			escaped = true
		case r == '"':
			// A bare quote would terminate the quoted string early.
			return false
		}
	}

	// A pending escape means the closing quote itself was escaped.
	return !escaped
}
// isASCII reports whether value consists solely of ASCII bytes, i.e. every
// byte is below utf8.RuneSelf (0x80). The empty string is ASCII.
//
// The comparison is inclusive of utf8.RuneSelf: 0x80 is the first non-ASCII
// byte value, so it must be rejected as well.
func isASCII(value string) bool {
	for i := 0; i < len(value); i++ {
		if value[i] >= utf8.RuneSelf {
			return false
		}
	}

	return true
}
{ + continue + } + + if strings.TrimSpace(record.Host) == "" { + continue + } + + return true + } + + return false +} + +func isNotFound(err error) bool { + dnsErr := &net.DNSError{} + ok := errors.As(err, &dnsErr) + + return ok && dnsErr.IsNotFound +} diff --git a/pkg/validate/email_test.go b/pkg/validate/email_test.go new file mode 100644 index 0000000..65036b9 --- /dev/null +++ b/pkg/validate/email_test.go @@ -0,0 +1,210 @@ +package validate + +import ( + "context" + "net" + "testing" + + "golang.org/x/net/idna" +) + +type fakeResolver struct { + mxRecords map[string][]*net.MX + hosts map[string][]string + mxErr map[string]error + hostErr map[string]error +} + +func (r *fakeResolver) LookupMX(_ context.Context, name string) ([]*net.MX, error) { + if err, ok := r.mxErr[name]; ok { + return nil, err + } + return r.mxRecords[name], nil +} + +func (r *fakeResolver) LookupHost(_ context.Context, name string) ([]string, error) { + if err, ok := r.hostErr[name]; ok { + return nil, err + } + return r.hosts[name], nil +} + +func TestEmailValidateBasic(t *testing.T) { + validator, err := NewEmailValidator() + if err != nil { + t.Fatalf("expected validator, got %v", err) + } + + result, err := validator.Validate(context.Background(), "user@example.com") + if err != nil { + t.Fatalf("expected valid email, got %v", err) + } + + if result.DomainASCII != "example.com" { + t.Fatalf("expected domain ascii, got %s", result.DomainASCII) + } +} + +func TestEmailRejectDisplayName(t *testing.T) { + validator, err := NewEmailValidator() + if err != nil { + t.Fatalf("expected validator, got %v", err) + } + + _, err = validator.Validate(context.Background(), "Name ") + if err != ErrEmailDisplayName { + t.Fatalf("expected ErrEmailDisplayName, got %v", err) + } +} + +func TestEmailAllowDisplayName(t *testing.T) { + validator, err := NewEmailValidator(WithEmailAllowDisplayName(true)) + if err != nil { + t.Fatalf("expected validator, got %v", err) + } + + result, err := 
validator.Validate(context.Background(), "Name ") + if err != nil { + t.Fatalf("expected valid email, got %v", err) + } + + if result.Address != "user@example.com" { + t.Fatalf("expected normalized address, got %s", result.Address) + } +} + +func TestEmailInvalidLocalPart(t *testing.T) { + validator, err := NewEmailValidator() + if err != nil { + t.Fatalf("expected validator, got %v", err) + } + + _, err = validator.Validate(context.Background(), "user..dot@example.com") + if err != ErrEmailLocalPartInvalid { + t.Fatalf("expected ErrEmailLocalPartInvalid, got %v", err) + } +} + +func TestEmailRequireTLD(t *testing.T) { + validator, err := NewEmailValidator() + if err != nil { + t.Fatalf("expected validator, got %v", err) + } + + _, err = validator.Validate(context.Background(), "user@localhost") + if err != ErrEmailDomainInvalid { + t.Fatalf("expected ErrEmailDomainInvalid, got %v", err) + } +} + +func TestEmailDomainVerificationMX(t *testing.T) { + resolver := &fakeResolver{ + mxRecords: map[string][]*net.MX{ + "example.com": {{Host: "mx.example.com."}}, + }, + } + + validator, err := NewEmailValidator( + WithEmailVerifyDomain(true), + WithEmailDNSResolver(resolver), + ) + if err != nil { + t.Fatalf("expected validator, got %v", err) + } + + result, err := validator.Validate(context.Background(), "user@example.com") + if err != nil { + t.Fatalf("expected valid email, got %v", err) + } + + if !result.DomainVerified || !result.VerifiedByMX { + t.Fatalf("expected mx verification") + } +} + +func TestEmailDomainVerificationFallback(t *testing.T) { + resolver := &fakeResolver{ + hosts: map[string][]string{ + "example.com": {"203.0.113.10"}, + }, + } + + validator, err := NewEmailValidator( + WithEmailVerifyDomain(true), + WithEmailDNSResolver(resolver), + ) + if err != nil { + t.Fatalf("expected validator, got %v", err) + } + + result, err := validator.Validate(context.Background(), "user@example.com") + if err != nil { + t.Fatalf("expected valid email, got %v", err) + 
} + + if !result.DomainVerified || !result.VerifiedByA { + t.Fatalf("expected A record verification") + } +} + +func TestEmailDomainVerificationUnverified(t *testing.T) { + resolver := &fakeResolver{} + + validator, err := NewEmailValidator( + WithEmailVerifyDomain(true), + WithEmailDNSResolver(resolver), + ) + if err != nil { + t.Fatalf("expected validator, got %v", err) + } + + _, err = validator.Validate(context.Background(), "user@example.com") + if err != ErrEmailDomainUnverified { + t.Fatalf("expected ErrEmailDomainUnverified, got %v", err) + } +} + +func TestEmailAllowIDN(t *testing.T) { + validator, err := NewEmailValidator(WithEmailAllowIDN(true)) + if err != nil { + t.Fatalf("expected validator, got %v", err) + } + + result, err := validator.Validate(context.Background(), "user@bücher.example") + if err != nil { + t.Fatalf("expected valid email, got %v", err) + } + + ascii, err := idna.Lookup.ToASCII("bücher.example") + if err != nil { + t.Fatalf("expected idna conversion, got %v", err) + } + + if result.DomainASCII != ascii { + t.Fatalf("expected ascii domain %s, got %s", ascii, result.DomainASCII) + } +} + +func TestEmailIPLiteralDisallowed(t *testing.T) { + validator, err := NewEmailValidator() + if err != nil { + t.Fatalf("expected validator, got %v", err) + } + + _, err = validator.Validate(context.Background(), "user@[127.0.0.1]") + if err != ErrEmailIPLiteralNotAllowed { + t.Fatalf("expected ErrEmailIPLiteralNotAllowed, got %v", err) + } +} + +func TestEmailIPLiteralAllowed(t *testing.T) { + validator, err := NewEmailValidator(WithEmailAllowIPLiteral(true)) + if err != nil { + t.Fatalf("expected validator, got %v", err) + } + + _, err = validator.Validate(context.Background(), "user@[127.0.0.1]") + if err != nil { + t.Fatalf("expected valid ip-literal, got %v", err) + } +} diff --git a/pkg/validate/errors.go b/pkg/validate/errors.go new file mode 100644 index 0000000..a491af4 --- /dev/null +++ b/pkg/validate/errors.go @@ -0,0 +1,61 @@ +package 
validate + +import "github.com/hyp3rd/ewrap" + +var ( + // ErrInvalidEmailConfig indicates that the email validation configuration is invalid. + ErrInvalidEmailConfig = ewrap.New("invalid email validation config") + // ErrInvalidURLConfig indicates that the URL validation configuration is invalid. + ErrInvalidURLConfig = ewrap.New("invalid url validation config") + // ErrEmailEmpty indicates that the email is empty. + ErrEmailEmpty = ewrap.New("email is empty") + // ErrEmailInvalid indicates that the email is invalid. + ErrEmailInvalid = ewrap.New("email is invalid") + // ErrEmailDisplayName indicates that the email display name is not allowed. + ErrEmailDisplayName = ewrap.New("email display name is not allowed") + // ErrEmailLocalPartInvalid indicates that the email local part is invalid. + ErrEmailLocalPartInvalid = ewrap.New("email local part is invalid") + // ErrEmailDomainInvalid indicates that the email domain is invalid. + ErrEmailDomainInvalid = ewrap.New("email domain is invalid") + // ErrEmailDomainTooLong indicates that the email domain is too long. + ErrEmailDomainTooLong = ewrap.New("email domain is too long") + // ErrEmailLocalPartTooLong indicates that the email local part is too long. + ErrEmailLocalPartTooLong = ewrap.New("email local part is too long") + // ErrEmailAddressTooLong indicates that the email address is too long. + ErrEmailAddressTooLong = ewrap.New("email address is too long") + // ErrEmailQuotedLocalPart indicates that the email quoted local part is not allowed. + ErrEmailQuotedLocalPart = ewrap.New("email quoted local part is not allowed") + // ErrEmailIPLiteralNotAllowed indicates that the email ip-literal domain is not allowed. + ErrEmailIPLiteralNotAllowed = ewrap.New("email ip-literal domain is not allowed") + // ErrEmailIDNNotAllowed indicates that the email idn domains are not allowed. 
+ ErrEmailIDNNotAllowed = ewrap.New("email idn domains are not allowed") + // ErrEmailDomainLookupFailed indicates that the email domain lookup failed. + ErrEmailDomainLookupFailed = ewrap.New("email domain lookup failed") + // ErrEmailDomainUnverified indicates that the email domain is unverified. + ErrEmailDomainUnverified = ewrap.New("email domain is unverified") + + // ErrURLInvalid indicates that the URL is invalid. + ErrURLInvalid = ewrap.New("url is invalid") + // ErrURLTooLong indicates that the URL is too long. + ErrURLTooLong = ewrap.New("url is too long") + // ErrURLSchemeNotAllowed indicates that the URL scheme is not allowed. + ErrURLSchemeNotAllowed = ewrap.New("url scheme is not allowed") + // ErrURLHostMissing indicates that the URL host is required. + ErrURLHostMissing = ewrap.New("url host is required") + // ErrURLUserInfoNotAllowed indicates that the URL userinfo is not allowed. + ErrURLUserInfoNotAllowed = ewrap.New("url userinfo is not allowed") + // ErrURLHostNotAllowed indicates that the URL host is not allowed. + ErrURLHostNotAllowed = ewrap.New("url host is not allowed") + // ErrURLPrivateIPNotAllowed indicates that the URL private IP is not allowed. + ErrURLPrivateIPNotAllowed = ewrap.New("url private ip is not allowed") + // ErrURLRedirectNotAllowed indicates that URL redirects are not allowed. + ErrURLRedirectNotAllowed = ewrap.New("url redirect is not allowed") + // ErrURLRedirectLoop indicates that a URL redirect loop was detected. + ErrURLRedirectLoop = ewrap.New("url redirect loop detected") + // ErrURLRedirectLimit indicates that the URL redirect limit was exceeded. + ErrURLRedirectLimit = ewrap.New("url redirect limit exceeded") + // ErrURLReputationFailed indicates that the URL reputation check failed. + ErrURLReputationFailed = ewrap.New("url reputation check failed") + // ErrURLReputationBlocked indicates that the URL reputation check blocked the URL. 
+ ErrURLReputationBlocked = ewrap.New("url reputation blocked") +) diff --git a/pkg/validate/url.go b/pkg/validate/url.go new file mode 100644 index 0000000..59c9e23 --- /dev/null +++ b/pkg/validate/url.go @@ -0,0 +1,688 @@ +package validate + +import ( + "context" + "fmt" + "net" + "net/http" + "net/url" + "strings" + "time" + + "golang.org/x/net/idna" +) + +const ( + urlDefaultMaxLength = 2048 + urlDefaultMaxRedirects = 10 + urlDefaultTimeout = 5 * time.Second + + schemeHTTPS = "https" + + httpMethodHead = "HEAD" + httpMethodGet = "GET" + + redirectStatusMultipleChoices = 300 + redirectStatusMovedPermanently = 301 + redirectStatusFound = 302 + redirectStatusSeeOther = 303 + redirectStatusTemporaryRedirect = 307 + redirectStatusPermanentRedirect = 308 +) + +// URLReputationChecker evaluates a URL's reputation. +type URLReputationChecker interface { + Check(ctx context.Context, target *url.URL) (ReputationResult, error) +} + +// URLReputationCheckerFunc adapts a function to URLReputationChecker. +type URLReputationCheckerFunc func(ctx context.Context, target *url.URL) (ReputationResult, error) + +// Check implements URLReputationChecker. +func (f URLReputationCheckerFunc) Check(ctx context.Context, target *url.URL) (ReputationResult, error) { + return f(ctx, target) +} + +// ReputationVerdict indicates reputation outcome. +type ReputationVerdict int + +const ( + // ReputationUnknown indicates an unknown reputation verdict. + ReputationUnknown ReputationVerdict = iota + // ReputationAllowed indicates an allowed reputation verdict. + ReputationAllowed + // ReputationBlocked indicates a blocked reputation verdict. + ReputationBlocked +) + +// ReputationResult describes a reputation check result. +type ReputationResult struct { + Verdict ReputationVerdict + Reason string +} + +// StaticReputation checks against allow/block lists. 
+type StaticReputation struct { + allow map[string]struct{} + block map[string]struct{} +} + +// NewStaticReputation constructs a static checker. +func NewStaticReputation(allowHosts, blockHosts []string) *StaticReputation { + return &StaticReputation{ + allow: normalizeHostSet(allowHosts), + block: normalizeHostSet(blockHosts), + } +} + +// Check implements URLReputationChecker. +func (s *StaticReputation) Check(_ context.Context, target *url.URL) (ReputationResult, error) { + if target == nil { + return ReputationResult{Verdict: ReputationUnknown}, nil + } + + host := strings.ToLower(target.Hostname()) + if host == "" { + return ReputationResult{Verdict: ReputationUnknown}, nil + } + + if _, ok := s.block[host]; ok { + return ReputationResult{Verdict: ReputationBlocked, Reason: "blocked"}, nil + } + + if len(s.allow) > 0 { + if _, ok := s.allow[host]; ok { + return ReputationResult{Verdict: ReputationAllowed}, nil + } + + return ReputationResult{Verdict: ReputationBlocked, Reason: "not allowed"}, nil + } + + return ReputationResult{Verdict: ReputationUnknown}, nil +} + +// URLOption configures URLValidator. +type URLOption func(*urlOptions) error + +type urlOptions struct { + allowedSchemes map[string]struct{} + allowUserInfo bool + allowIDN bool + allowIPLiteral bool + allowPrivateIP bool + allowLocalhost bool + maxLength int + checkRedirects bool + maxRedirects int + redirectMethod string + httpClient *http.Client + reputationChecker URLReputationChecker + allowedHosts map[string]struct{} + blockedHosts map[string]struct{} +} + +// URLResult describes URL validation output. +type URLResult struct { + NormalizedURL string + FinalURL string + Redirects []URLRedirect + Reputation ReputationResult +} + +// URLRedirect captures a single redirect hop. +type URLRedirect struct { + From string + To string + StatusCode int +} + +// URLValidator validates URLs with optional redirect and reputation checks. 
+type URLValidator struct { + opts urlOptions +} + +// NewURLValidator constructs a validator with options. +func NewURLValidator(opts ...URLOption) (*URLValidator, error) { + cfg := urlOptions{ + allowedSchemes: map[string]struct{}{ + schemeHTTPS: {}, + }, + maxLength: urlDefaultMaxLength, + maxRedirects: urlDefaultMaxRedirects, + redirectMethod: httpMethodHead, + } + + for _, opt := range opts { + if opt == nil { + continue + } + + err := opt(&cfg) + if err != nil { + return nil, err + } + } + + err := validateURLOptions(&cfg) + if err != nil { + return nil, err + } + + return &URLValidator{opts: cfg}, nil +} + +// WithURLAllowedSchemes sets allowed schemes. +func WithURLAllowedSchemes(schemes ...string) URLOption { + return func(cfg *urlOptions) error { + clean := make(map[string]struct{}) + + for _, scheme := range schemes { + value := strings.ToLower(strings.TrimSpace(scheme)) + if value == "" { + continue + } + + if value != schemeHTTPS { + return ErrInvalidURLConfig + } + + clean[value] = struct{}{} + } + + if len(clean) == 0 { + return ErrInvalidURLConfig + } + + cfg.allowedSchemes = clean + + return nil + } +} + +// WithURLAllowUserInfo allows userinfo in URLs. +func WithURLAllowUserInfo(allow bool) URLOption { + return func(cfg *urlOptions) error { + cfg.allowUserInfo = allow + + return nil + } +} + +// WithURLAllowIDN allows IDN hostnames. +func WithURLAllowIDN(allow bool) URLOption { + return func(cfg *urlOptions) error { + cfg.allowIDN = allow + + return nil + } +} + +// WithURLAllowIPLiteral allows IP literal hosts. +func WithURLAllowIPLiteral(allow bool) URLOption { + return func(cfg *urlOptions) error { + cfg.allowIPLiteral = allow + + return nil + } +} + +// WithURLAllowPrivateIP allows private/loopback IPs. +func WithURLAllowPrivateIP(allow bool) URLOption { + return func(cfg *urlOptions) error { + cfg.allowPrivateIP = allow + + return nil + } +} + +// WithURLAllowLocalhost allows localhost hostnames. 
+func WithURLAllowLocalhost(allow bool) URLOption { + return func(cfg *urlOptions) error { + cfg.allowLocalhost = allow + + return nil + } +} + +// WithURLMaxLength sets the max URL length. +func WithURLMaxLength(maxLen int) URLOption { + return func(cfg *urlOptions) error { + if maxLen <= 0 { + return ErrInvalidURLConfig + } + + cfg.maxLength = maxLen + + return nil + } +} + +// WithURLCheckRedirects enables redirect checks with a max hop count. +func WithURLCheckRedirects(maxRedirects int) URLOption { + return func(cfg *urlOptions) error { + if maxRedirects <= 0 { + return ErrInvalidURLConfig + } + + cfg.checkRedirects = true + cfg.maxRedirects = maxRedirects + + return nil + } +} + +// WithURLRedirectMethod sets the HTTP method for redirect checks. +func WithURLRedirectMethod(method string) URLOption { + return func(cfg *urlOptions) error { + value := strings.ToUpper(strings.TrimSpace(method)) + if value != httpMethodHead && value != httpMethodGet { + return ErrInvalidURLConfig + } + + cfg.redirectMethod = value + + return nil + } +} + +// WithURLHTTPClient sets a custom HTTP client for redirect checks. +func WithURLHTTPClient(client *http.Client) URLOption { + return func(cfg *urlOptions) error { + if client == nil { + return ErrInvalidURLConfig + } + + cfg.httpClient = client + + return nil + } +} + +// WithURLReputationChecker sets a reputation checker. +func WithURLReputationChecker(checker URLReputationChecker) URLOption { + return func(cfg *urlOptions) error { + if checker == nil { + return ErrInvalidURLConfig + } + + cfg.reputationChecker = checker + + return nil + } +} + +// WithURLAllowedHosts restricts validation to specific hosts. +func WithURLAllowedHosts(hosts ...string) URLOption { + return func(cfg *urlOptions) error { + cfg.allowedHosts = normalizeHostSet(hosts) + + return nil + } +} + +// WithURLBlockedHosts blocks specific hosts. 
+func WithURLBlockedHosts(hosts ...string) URLOption { + return func(cfg *urlOptions) error { + cfg.blockedHosts = normalizeHostSet(hosts) + + return nil + } +} + +// Validate validates the URL, optionally checking redirects and reputation. +func (v *URLValidator) Validate(ctx context.Context, raw string) (URLResult, error) { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return URLResult{}, ErrURLInvalid + } + + if len(trimmed) > v.opts.maxLength { + return URLResult{}, ErrURLTooLong + } + + parsed, err := url.Parse(trimmed) + if err != nil { + return URLResult{}, ErrURLInvalid + } + + err = v.validateParsed(parsed) + if err != nil { + return URLResult{}, err + } + + result := URLResult{ + NormalizedURL: parsed.String(), + FinalURL: parsed.String(), + } + + if v.opts.reputationChecker != nil { + err := v.checkReputation(ctx, parsed) + if err != nil { + return URLResult{}, err + } + } + + if v.opts.checkRedirects { + finalURL, redirects, err := v.followRedirects(ctx, parsed) + if err != nil { + return URLResult{}, err + } + + result.FinalURL = finalURL.String() + result.Redirects = redirects + + if v.opts.reputationChecker != nil { + err := v.checkReputation(ctx, finalURL) + if err != nil { + return URLResult{}, err + } + } + } + + return result, nil +} + +func validateURLOptions(cfg *urlOptions) error { + if len(cfg.allowedSchemes) == 0 { + return ErrInvalidURLConfig + } + + if len(cfg.allowedSchemes) != 1 { + return ErrInvalidURLConfig + } + + if _, ok := cfg.allowedSchemes[schemeHTTPS]; !ok { + return ErrInvalidURLConfig + } + + if cfg.maxLength <= 0 { + return ErrInvalidURLConfig + } + + if cfg.checkRedirects && cfg.maxRedirects <= 0 { + return ErrInvalidURLConfig + } + + if cfg.redirectMethod != httpMethodHead && cfg.redirectMethod != httpMethodGet { + return ErrInvalidURLConfig + } + + return nil +} + +func normalizeHostSet(hosts []string) map[string]struct{} { + clean := make(map[string]struct{}) + + for _, host := range hosts { + value := 
strings.ToLower(strings.TrimSpace(host)) + if value == "" { + continue + } + + clean[value] = struct{}{} + } + + return clean +} + +func (v *URLValidator) validateParsed(parsed *url.URL) error { + if parsed == nil { + return ErrURLInvalid + } + + err := v.validateScheme(parsed) + if err != nil { + return err + } + + err = v.validateUserInfo(parsed) + if err != nil { + return err + } + + host, err := v.normalizedHost(parsed) + if err != nil { + return err + } + + err = v.validateHost(host) + if err != nil { + return err + } + + return v.validateIPHost(host) +} + +func (v *URLValidator) validateScheme(parsed *url.URL) error { + scheme := strings.ToLower(parsed.Scheme) + if scheme == "" { + return ErrURLInvalid + } + + if _, ok := v.opts.allowedSchemes[scheme]; !ok { + return ErrURLSchemeNotAllowed + } + + return nil +} + +func (v *URLValidator) validateUserInfo(parsed *url.URL) error { + if parsed.User != nil && !v.opts.allowUserInfo { + return ErrURLUserInfoNotAllowed + } + + return nil +} + +func (v *URLValidator) normalizedHost(parsed *url.URL) (string, error) { + host := parsed.Hostname() + if host == "" { + return "", ErrURLHostMissing + } + + return normalizeHost(host, v.opts.allowIDN) +} + +func (v *URLValidator) validateHost(host string) error { + if !v.opts.allowLocalhost && isLocalhost(host) { + return ErrURLHostNotAllowed + } + + return v.checkHostRestrictions(host) +} + +func (v *URLValidator) validateIPHost(host string) error { + ip := net.ParseIP(host) + if ip == nil { + return nil + } + + if !v.opts.allowIPLiteral { + return ErrURLHostNotAllowed + } + + if !v.opts.allowPrivateIP && isPrivateIP(ip) { + return ErrURLPrivateIPNotAllowed + } + + return nil +} + +func normalizeHost(host string, allowIDN bool) (string, error) { + normalized := strings.TrimSuffix(host, ".") + if normalized == "" { + return "", ErrURLHostMissing + } + + if !isASCII(normalized) { + if !allowIDN { + return "", ErrURLHostNotAllowed + } + + converted, err := 
idna.Lookup.ToASCII(normalized) + if err != nil { + return "", ErrURLHostNotAllowed + } + + normalized = converted + } + + return strings.ToLower(normalized), nil +} + +func (v *URLValidator) checkHostRestrictions(host string) error { + if _, ok := v.opts.blockedHosts[host]; ok { + return ErrURLHostNotAllowed + } + + if len(v.opts.allowedHosts) > 0 { + if _, ok := v.opts.allowedHosts[host]; !ok { + return ErrURLHostNotAllowed + } + } + + return nil +} + +func isLocalhost(host string) bool { + return host == "localhost" || strings.HasSuffix(host, ".localhost") +} + +func isPrivateIP(ip net.IP) bool { + if ip == nil { + return false + } + + if ip.IsLoopback() || ip.IsUnspecified() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { + return true + } + + if ip.IsPrivate() { + return true + } + + return ip.IsMulticast() +} + +func (v *URLValidator) followRedirects(ctx context.Context, start *url.URL) (*url.URL, []URLRedirect, error) { + if ctx == nil { + return nil, nil, ErrURLInvalid + } + + client := v.httpClient() + current := start + visited := make(map[string]struct{}) + redirects := make([]URLRedirect, 0) + + for range v.opts.maxRedirects { + hopKey := current.String() + if _, ok := visited[hopKey]; ok { + return nil, nil, ErrURLRedirectLoop + } + + visited[hopKey] = struct{}{} + + nextURL, redirect, err := v.nextRedirect(ctx, client, current) + if err != nil { + return nil, nil, err + } + + if redirect == nil { + return current, redirects, nil + } + + redirects = append(redirects, *redirect) + current = nextURL + } + + return nil, nil, ErrURLRedirectLimit +} + +func (v *URLValidator) nextRedirect(ctx context.Context, client *http.Client, current *url.URL) (*url.URL, *URLRedirect, error) { + req, err := http.NewRequestWithContext(ctx, v.opts.redirectMethod, current.String(), nil) + if err != nil { + return nil, nil, ErrURLInvalid + } + + resp, err := client.Do(req) + if err != nil { + return nil, nil, ErrURLRedirectNotAllowed + } + + //nolint:errcheck + _ = 
resp.Body.Close() + + if !isRedirectStatus(resp.StatusCode) { + return current, nil, nil + } + + location := resp.Header.Get("Location") + if location == "" { + return nil, nil, ErrURLRedirectNotAllowed + } + + nextURL, err := url.Parse(location) + if err != nil { + return nil, nil, ErrURLRedirectNotAllowed + } + + nextURL = current.ResolveReference(nextURL) + + err = v.validateParsed(nextURL) + if err != nil { + return nil, nil, err + } + + redirect := URLRedirect{ + From: current.String(), + To: nextURL.String(), + StatusCode: resp.StatusCode, + } + + return nextURL, &redirect, nil +} + +func isRedirectStatus(code int) bool { + switch code { + case redirectStatusMultipleChoices, + redirectStatusMovedPermanently, + redirectStatusFound, + redirectStatusSeeOther, + redirectStatusTemporaryRedirect, + redirectStatusPermanentRedirect: + return true + default: + return false + } +} + +func (v *URLValidator) httpClient() *http.Client { + client := v.opts.httpClient + if client == nil { + client = &http.Client{Timeout: urlDefaultTimeout} + } + + clone := *client + clone.CheckRedirect = func(_ *http.Request, _ []*http.Request) error { + return http.ErrUseLastResponse + } + + return &clone +} + +func (v *URLValidator) checkReputation(ctx context.Context, target *url.URL) error { + if ctx == nil { + return ErrURLInvalid + } + + result, err := v.opts.reputationChecker.Check(ctx, target) + if err != nil { + return fmt.Errorf("%w: %w", ErrURLReputationFailed, err) + } + + if result.Verdict == ReputationBlocked { + return ErrURLReputationBlocked + } + + return nil +} diff --git a/pkg/validate/url_test.go b/pkg/validate/url_test.go new file mode 100644 index 0000000..a515930 --- /dev/null +++ b/pkg/validate/url_test.go @@ -0,0 +1,148 @@ +package validate + +import ( + "context" + "io" + "net/http" + "strings" + "testing" +) + +type fakeRoundTripper struct { + responses map[string]*http.Response +} + +func (f *fakeRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) 
{ + if resp, ok := f.responses[req.URL.String()]; ok { + return resp, nil + } + + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("")), + Header: make(http.Header), + }, nil +} + +func TestURLValidateBasic(t *testing.T) { + validator, err := NewURLValidator() + if err != nil { + t.Fatalf("expected validator, got %v", err) + } + + _, err = validator.Validate(context.Background(), "https://example.com/path") + if err != nil { + t.Fatalf("expected valid url, got %v", err) + } +} + +func TestURLRejectUserInfo(t *testing.T) { + validator, err := NewURLValidator() + if err != nil { + t.Fatalf("expected validator, got %v", err) + } + + _, err = validator.Validate(context.Background(), "https://user:pass@example.com") + if err != ErrURLUserInfoNotAllowed { + t.Fatalf("expected ErrURLUserInfoNotAllowed, got %v", err) + } +} + +func TestURLRejectPrivateIP(t *testing.T) { + validator, err := NewURLValidator( + WithURLAllowIPLiteral(true), + ) + if err != nil { + t.Fatalf("expected validator, got %v", err) + } + + _, err = validator.Validate(context.Background(), "https://127.0.0.1") + if err != ErrURLPrivateIPNotAllowed { + t.Fatalf("expected ErrURLPrivateIPNotAllowed, got %v", err) + } +} + +func TestURLRedirectCheck(t *testing.T) { + client := &http.Client{ + Transport: &fakeRoundTripper{ + responses: map[string]*http.Response{ + "https://example.com/start": { + StatusCode: http.StatusFound, + Header: http.Header{"Location": []string{"/final"}}, + Body: io.NopCloser(strings.NewReader("")), + }, + "https://example.com/final": { + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader("")), + }, + }, + }, + } + + validator, err := NewURLValidator( + WithURLCheckRedirects(3), + WithURLHTTPClient(client), + ) + if err != nil { + t.Fatalf("expected validator, got %v", err) + } + + result, err := validator.Validate(context.Background(), "https://example.com/start") + if err != nil { + 
t.Fatalf("expected valid url, got %v", err) + } + + if len(result.Redirects) != 1 { + t.Fatalf("expected 1 redirect, got %d", len(result.Redirects)) + } +} + +func TestURLRedirectLoop(t *testing.T) { + client := &http.Client{ + Transport: &fakeRoundTripper{ + responses: map[string]*http.Response{ + "https://example.com/loop": { + StatusCode: http.StatusFound, + Header: http.Header{"Location": []string{"/loop"}}, + Body: io.NopCloser(strings.NewReader("")), + }, + }, + }, + } + + validator, err := NewURLValidator( + WithURLCheckRedirects(2), + WithURLHTTPClient(client), + ) + if err != nil { + t.Fatalf("expected validator, got %v", err) + } + + _, err = validator.Validate(context.Background(), "https://example.com/loop") + if err != ErrURLRedirectLoop { + t.Fatalf("expected ErrURLRedirectLoop, got %v", err) + } +} + +func TestURLReputationBlock(t *testing.T) { + checker := NewStaticReputation(nil, []string{"example.com"}) + validator, err := NewURLValidator( + WithURLReputationChecker(checker), + ) + if err != nil { + t.Fatalf("expected validator, got %v", err) + } + + _, err = validator.Validate(context.Background(), "https://example.com") + if err != ErrURLReputationBlocked { + t.Fatalf("expected ErrURLReputationBlocked, got %v", err) + } +} + +func TestURLRejectHTTPWithSchemesOption(t *testing.T) { + _, err := NewURLValidator(WithURLAllowedSchemes("http")) + if err != ErrInvalidURLConfig { + t.Fatalf("expected ErrInvalidURLConfig, got %v", err) + } +}