diff --git a/README.md b/README.md
index 87ce2c3..6779bb9 100644
--- a/README.md
+++ b/README.md
@@ -2,7 +2,7 @@
[](https://github.com/hyp3rd/sectools/actions/workflows/lint.yml) [](https://github.com/hyp3rd/sectools/actions/workflows/test.yml) [](https://github.com/hyp3rd/sectools/actions/workflows/security.yml)
-Security-focused Go helpers for file I/O, in-memory handling of sensitive data, auth tokens, password hashing, and safe numeric conversions.
+Security-focused Go helpers for file I/O, in-memory handling of sensitive data, auth tokens, password hashing, input validation/sanitization, and safe numeric conversions.
## Features
@@ -16,6 +16,8 @@ Security-focused Go helpers for file I/O, in-memory handling of sensitive data,
- Secure in-memory buffers with best-effort zeroization
- JWT/PASETO helpers with strict validation and safe defaults
- Password hashing presets for argon2id/bcrypt with rehash detection
+- Email and URL validation with optional DNS/redirect/reputation checks
+- HTML/Markdown sanitization, SQL input guards, and filename sanitizers
- Safe integer conversion helpers with overflow/negative guards
## Requirements
@@ -200,6 +202,69 @@ func main() {
}
```
+### Input validation
+
+```go
+package main
+
+import (
+ "context"
+
+ "github.com/hyp3rd/sectools/pkg/validate"
+)
+
+func main() {
+ emailValidator, err := validate.NewEmailValidator(
+ validate.WithEmailVerifyDomain(true),
+ )
+ if err != nil {
+ panic(err)
+ }
+
+ _, _ = emailValidator.Validate(context.Background(), "user@example.com")
+
+ urlValidator, err := validate.NewURLValidator(
+ validate.WithURLCheckRedirects(3),
+ )
+ if err != nil {
+ panic(err)
+ }
+
+ _, _ = urlValidator.Validate(context.Background(), "https://example.com")
+}
+```
+
+### Sanitization
+
+```go
+package main
+
+import (
+ "github.com/hyp3rd/sectools/pkg/sanitize"
+)
+
+func main() {
+ htmlSanitizer, err := sanitize.NewHTMLSanitizer()
+ if err != nil {
+ panic(err)
+ }
+
+ safeHTML, _ := htmlSanitizer.Sanitize("hello")
+
+ sqlSanitizer, err := sanitize.NewSQLSanitizer(
+ sanitize.WithSQLMode(sanitize.SQLModeIdentifier),
+ sanitize.WithSQLAllowQualifiedIdentifiers(true),
+ )
+ if err != nil {
+ panic(err)
+ }
+
+ safeIdentifier, _ := sqlSanitizer.Sanitize("public.users")
+
+ _, _ = safeHTML, safeIdentifier
+}
+```
+
## Security and behavior notes
- `ReadFile` only permits relative paths under `os.TempDir()` by default. Use `NewWithOptions` with `WithAllowAbsolute` to allow absolute paths or alternate roots.
diff --git a/cspell.json b/cspell.json
index cebd178..66db5c4 100644
--- a/cspell.json
+++ b/cspell.json
@@ -30,6 +30,7 @@
"anchore",
"argon2",
"argon2id",
+ "Atext",
"Atoi",
"aud",
"autobuild",
@@ -37,6 +38,7 @@
"behaviour",
"benchmem",
"benchtime",
+ "bücher",
"bufbuild",
"CODEOWNERS",
"CodeQL",
@@ -85,9 +87,12 @@
"GOTOOLCHAIN",
"govulncheck",
"honnef",
+ "hostnames",
"HS256",
"hyperlogger",
"iat",
+ "IDN",
+ "idna",
"Infof",
"internalio",
"ints",
@@ -97,15 +102,19 @@
"jwt",
"Keyfunc",
"kid",
+ "localhost",
"localmodule",
+ "markdown",
"memprofile",
"mkdocs",
"mlock",
"mvdan",
+ "MX",
"myproject",
"mypy",
"myuser",
"nbf",
+ "nethtml",
"Newf",
"nolint",
"nonamedreturns",
@@ -126,6 +135,10 @@
"pycache",
"recvcheck",
"Renovate",
+ "sanitization",
+ "sanitize",
+ "sanitizer",
+ "sanitizers",
"SBOM",
"SBOMs",
"sectauth",
@@ -137,6 +150,7 @@
"sigstore",
"SLSA",
"softprops",
+ "sql",
"staticcheck",
"stdlib",
"strconv",
@@ -146,6 +160,7 @@
"syncdir",
"syncer",
"tagalign",
+ "tautologies",
"TempDir",
"TempFile",
"TempPrefix",
@@ -154,6 +169,8 @@
"Tracef",
"uid",
"umask",
+ "URLHTTP",
+ "userinfo",
"varnamelen",
"vettool",
"Warnf",
diff --git a/docs/security-checklist.md b/docs/security-checklist.md
index 2569902..26d3bb4 100644
--- a/docs/security-checklist.md
+++ b/docs/security-checklist.md
@@ -31,6 +31,18 @@ This checklist is a quick reference for teams using sectools in production.
- Rehash stored passwords when `needsRehash` is true.
- Enforce bcrypt's 72-byte limit to avoid silent truncation.
+## Input Validation
+
+- Use `pkg/validate` for email/URL parsing instead of ad-hoc regexes.
+- Enable DNS verification only when you can tolerate network lookups and timeouts.
+- Keep URL schemes restricted and avoid enabling private IPs unless required.
+
+## Sanitization
+
+- Use `pkg/sanitize` for HTML/Markdown sanitization instead of ad-hoc escaping.
+- Prefer parameterized SQL queries; use `SQLSanitizer` only for identifiers or literals when needed.
+- Use `SQLInjectionDetector` as a heuristic guard for untrusted input before query composition.
+
## Cleanup
- Use `Remove`/`RemoveAll` to enforce root scoping.
diff --git a/docs/usage.md b/docs/usage.md
index acb4382..2811de3 100644
--- a/docs/usage.md
+++ b/docs/usage.md
@@ -8,6 +8,8 @@ supporting implementations in `internal/`.
- `pkg/io`: secure file read/write helpers.
- `pkg/auth`: JWT and PASETO helpers with strict validation.
- `pkg/password`: password hashing helpers.
+- `pkg/validate`: email and URL validation helpers.
+- `pkg/sanitize`: HTML/Markdown sanitizers, SQL input guards, and filename sanitizers.
- `pkg/memory`: secure in-memory buffers.
- `pkg/converters`: safe numeric conversions.
- `internal/io`: implementation details; not part of the public API contract.
@@ -274,6 +276,103 @@ Behavior:
- `Verify` returns `needsRehash` when parameters or cost drift from the current preset.
- Bcrypt rejects passwords longer than 72 bytes to avoid silent truncation.
+## pkg/validate
+
+### Email validation
+
+```go
+func NewEmailValidator(opts ...EmailOption) (*EmailValidator, error)
+func (v *EmailValidator) Validate(ctx context.Context, input string) (EmailResult, error)
+```
+
+Behavior:
+
+- Rejects display names by default; use `WithEmailAllowDisplayName(true)` to permit.
+- Validates local part syntax (dot-atom by default); quoted local parts are optional.
+- Validates domain labels and length; IDN domains require `WithEmailAllowIDN(true)`.
+- Optional DNS verification with `WithEmailVerifyDomain(true)` using MX and optional A/AAAA fallback.
+
+### URL validation
+
+```go
+func NewURLValidator(opts ...URLOption) (*URLValidator, error)
+func (v *URLValidator) Validate(ctx context.Context, raw string) (URLResult, error)
+```
+
+Behavior:
+
+- Enforces `https` only; non-https schemes are rejected, even if they are explicitly configured.
+- Rejects userinfo by default; use `WithURLAllowUserInfo(true)` to permit.
+- Blocks private/loopback IPs by default; use `WithURLAllowPrivateIP(true)` to permit.
+- Optional redirect checks with `WithURLCheckRedirects` and an HTTP client.
+- Optional reputation checks with `WithURLReputationChecker`.
+
+## pkg/sanitize
+
+### HTML sanitization
+
+```go
+func NewHTMLSanitizer(opts ...HTMLOption) (*HTMLSanitizer, error)
+func (s *HTMLSanitizer) Sanitize(input string) (string, error)
+```
+
+Behavior:
+
+- Escapes HTML by default (`HTMLSanitizeEscape`).
+- Supports stripping tags to plain text with `WithHTMLMode(HTMLSanitizeStrip)`.
+- Allows custom policies via `WithHTMLPolicy`.
+
+### Markdown sanitization
+
+```go
+func NewMarkdownSanitizer(opts ...MarkdownOption) (*MarkdownSanitizer, error)
+func (s *MarkdownSanitizer) Sanitize(input string) (string, error)
+```
+
+Behavior:
+
+- Escapes raw HTML by default.
+- Allows raw HTML with `WithMarkdownAllowRawHTML(true)`.
+
+### SQL sanitization
+
+```go
+func NewSQLSanitizer(opts ...SQLOption) (*SQLSanitizer, error)
+func (s *SQLSanitizer) Sanitize(input string) (string, error)
+```
+
+Behavior:
+
+- Identifier mode rejects unsafe characters and can allow dotted identifiers.
+- Literal mode escapes single quotes using SQL-standard doubling.
+- LIKE mode escapes `%`/`_` and the configured escape character.
+- Always prefer parameterized queries; sanitization is a safety net.
+
+### SQL injection detection
+
+```go
+func NewSQLInjectionDetector(opts ...SQLDetectOption) (*SQLInjectionDetector, error)
+func (d *SQLInjectionDetector) Detect(input string) error
+```
+
+Behavior:
+
+- Flags common SQL injection patterns (comments, statement separators, tautologies).
+- The detector is heuristic; tune patterns if your input includes SQL-like content.
+
+### Filename sanitization
+
+```go
+func NewFilenameSanitizer(opts ...FilenameOption) (*FilenameSanitizer, error)
+func (s *FilenameSanitizer) Sanitize(input string) (string, error)
+```
+
+Behavior:
+
+- Normalizes a single filename or path segment.
+- Replaces disallowed characters with a configurable replacement rune.
+- Rejects input that is empty after trimming whitespace, the reserved `.` and `..` segments, and names that end with a dot or space.
+
## pkg/memory
`SecureBuffer` is a public type for holding sensitive data in memory.
diff --git a/go.mod b/go.mod
index dd09120..993c56c 100644
--- a/go.mod
+++ b/go.mod
@@ -9,6 +9,7 @@ require (
github.com/hyp3rd/hyperlogger v0.0.8
github.com/stretchr/testify v1.11.1
golang.org/x/crypto v0.46.0
+ golang.org/x/net v0.48.0
)
require (
@@ -16,6 +17,7 @@ require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/goccy/go-json v0.10.5 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
- golang.org/x/sys v0.39.0 // indirect
+ golang.org/x/sys v0.40.0 // indirect
+ golang.org/x/text v0.33.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
diff --git a/go.sum b/go.sum
index f145d92..422d99d 100644
--- a/go.sum
+++ b/go.sum
@@ -20,8 +20,12 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
-golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
-golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
+golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
+golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
+golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
+golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
+golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
+golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
diff --git a/pkg/sanitize/doc.go b/pkg/sanitize/doc.go
new file mode 100644
index 0000000..1221d6d
--- /dev/null
+++ b/pkg/sanitize/doc.go
@@ -0,0 +1,2 @@
+// Package sanitize provides safe-by-default sanitizers for untrusted input.
+package sanitize
diff --git a/pkg/sanitize/errors.go b/pkg/sanitize/errors.go
new file mode 100644
index 0000000..0855a37
--- /dev/null
+++ b/pkg/sanitize/errors.go
@@ -0,0 +1,39 @@
+package sanitize
+
+import "github.com/hyp3rd/ewrap"
+
+var ( // Sentinel errors returned by the constructors, options, and sanitizers of this package.
+ // ErrInvalidHTMLConfig indicates an invalid HTML sanitizer configuration (bad max length, nil policy, or unknown mode).
+ ErrInvalidHTMLConfig = ewrap.New("invalid html sanitize config")
+ // ErrInvalidMarkdownConfig indicates an invalid Markdown sanitizer configuration.
+ ErrInvalidMarkdownConfig = ewrap.New("invalid markdown sanitize config")
+ // ErrInvalidSQLConfig indicates an invalid SQL sanitizer or SQL injection detector configuration.
+ ErrInvalidSQLConfig = ewrap.New("invalid sql sanitize config")
+ // ErrInvalidFilenameConfig indicates an invalid filename sanitizer configuration.
+ ErrInvalidFilenameConfig = ewrap.New("invalid filename sanitize config")
+
+ // ErrHTMLTooLong indicates the HTML input exceeds the configured limit.
+ ErrHTMLTooLong = ewrap.New("html input too long")
+ // ErrHTMLInvalid indicates the HTML input could not be parsed; strip mode wraps the parser error with it.
+ ErrHTMLInvalid = ewrap.New("html input invalid")
+ // ErrMarkdownTooLong indicates the Markdown input exceeds the configured limit.
+ ErrMarkdownTooLong = ewrap.New("markdown input too long")
+
+ // ErrSQLInputTooLong indicates the SQL input exceeds the configured limit (sanitizer or detector).
+ ErrSQLInputTooLong = ewrap.New("sql input too long")
+ // ErrSQLIdentifierInvalid indicates the SQL identifier is empty or contains a character outside [A-Za-z0-9_].
+ ErrSQLIdentifierInvalid = ewrap.New("sql identifier invalid")
+ // ErrSQLLiteralInvalid indicates the SQL literal or LIKE pattern input contains a NUL character.
+ ErrSQLLiteralInvalid = ewrap.New("sql literal invalid")
+ // ErrSQLLikeEscapeInvalid indicates the SQL LIKE escape character is invalid.
+ ErrSQLLikeEscapeInvalid = ewrap.New("sql like escape invalid")
+ // ErrSQLInjectionDetected indicates the input matched SQL injection heuristics.
+ ErrSQLInjectionDetected = ewrap.New("sql injection detected")
+
+ // ErrFilenameEmpty indicates the filename is empty after trimming surrounding whitespace.
+ ErrFilenameEmpty = ewrap.New("filename empty")
+ // ErrFilenameTooLong indicates the filename exceeds the configured limit.
+ ErrFilenameTooLong = ewrap.New("filename too long")
+ // ErrFilenameInvalid indicates the sanitized filename is "." or "..", or ends with a dot or space.
+ ErrFilenameInvalid = ewrap.New("filename invalid")
+)
diff --git a/pkg/sanitize/filename.go b/pkg/sanitize/filename.go
new file mode 100644
index 0000000..17aca40
--- /dev/null
+++ b/pkg/sanitize/filename.go
@@ -0,0 +1,197 @@
+package sanitize
+
+import (
+ "strings"
+ "unicode"
+ "unicode/utf8"
+)
+
+const (
+ filenameDefaultMaxLength = 255 // default length limit, compared against len() and therefore counted in bytes
+ filenameDefaultReplacement = '_' // default rune written in place of a disallowed rune
+ filenameDeleteRune = 0x7f // ASCII DEL, rejected together with the other control characters
+)
+
+// FilenameOption configures filename sanitization; a non-nil error aborts NewFilenameSanitizer.
+type FilenameOption func(*filenameOptions) error
+
+type filenameOptions struct { // resolved settings consulted by Sanitize and isAllowedFilenameRune
+ maxLength int // maximum accepted length in bytes; must be > 0
+ allowSpaces bool // keep whitespace runes instead of replacing them
+ allowUnicode bool // keep runes >= utf8.RuneSelf instead of replacing them
+ allowLeadingDot bool // keep a '.' in first position instead of replacing it
+ replacement rune // written for every disallowed rune
+}
+
+// FilenameSanitizer sanitizes a single filename or path segment using the options fixed at construction.
+type FilenameSanitizer struct {
+ opts filenameOptions
+}
+
+// NewFilenameSanitizer builds a sanitizer from the defaults (max length 255, replacement '_'), applies opts in order, and validates the resulting configuration.
+func NewFilenameSanitizer(opts ...FilenameOption) (*FilenameSanitizer, error) {
+ cfg := filenameOptions{
+ maxLength: filenameDefaultMaxLength,
+ replacement: filenameDefaultReplacement,
+ }
+
+ for _, opt := range opts {
+ if opt == nil { // nil options are tolerated and skipped
+ continue
+ }
+
+ err := opt(&cfg)
+ if err != nil { // the first failing option aborts construction
+ return nil, err
+ }
+ }
+
+ err := validateFilenameConfig(cfg)
+ if err != nil {
+ return nil, err
+ }
+
+ return &FilenameSanitizer{opts: cfg}, nil
+}
+
+// WithFilenameMaxLength sets the maximum accepted filename length in bytes; values <= 0 yield ErrInvalidFilenameConfig.
+func WithFilenameMaxLength(maxLength int) FilenameOption {
+ return func(cfg *filenameOptions) error {
+ if maxLength <= 0 {
+ return ErrInvalidFilenameConfig
+ }
+
+ cfg.maxLength = maxLength
+
+ return nil
+ }
+}
+
+// WithFilenameAllowSpaces controls whether whitespace inside a filename is kept; when false (the default) it is replaced.
+func WithFilenameAllowSpaces(allow bool) FilenameOption {
+ return func(cfg *filenameOptions) error {
+ cfg.allowSpaces = allow
+
+ return nil
+ }
+}
+
+// WithFilenameAllowUnicode controls whether non-ASCII runes are kept; when false (the default) they are replaced.
+func WithFilenameAllowUnicode(allow bool) FilenameOption {
+ return func(cfg *filenameOptions) error {
+ cfg.allowUnicode = allow
+
+ return nil
+ }
+}
+
+// WithFilenameAllowLeadingDot controls whether a '.' in first position is kept; when false (the default) it is replaced.
+func WithFilenameAllowLeadingDot(allow bool) FilenameOption {
+ return func(cfg *filenameOptions) error {
+ cfg.allowLeadingDot = allow
+
+ return nil
+ }
+}
+
+// WithFilenameReplacement sets the rune written for disallowed characters; a replacement rejected by isValidReplacement yields ErrInvalidFilenameConfig.
+func WithFilenameReplacement(replacement rune) FilenameOption {
+ return func(cfg *filenameOptions) error {
+ if !isValidReplacement(replacement) {
+ return ErrInvalidFilenameConfig
+ }
+
+ cfg.replacement = replacement
+
+ return nil
+ }
+}
+
+// Sanitize trims surrounding whitespace from input and replaces each disallowed rune with the configured replacement; it returns ErrFilenameEmpty, ErrFilenameTooLong, or ErrFilenameInvalid when the name is rejected.
+func (s *FilenameSanitizer) Sanitize(input string) (string, error) {
+ value := strings.TrimSpace(input)
+ if value == "" {
+ return "", ErrFilenameEmpty
+ }
+
+ if len(value) > s.opts.maxLength { // len counts bytes, so multi-byte runes consume more of the limit
+ return "", ErrFilenameTooLong
+ }
+
+ var builder strings.Builder
+ builder.Grow(len(value))
+
+ for _, ch := range value {
+ if isAllowedFilenameRune(ch, s.opts, builder.Len() == 0) { // an empty builder means ch is the first rune
+ builder.WriteRune(ch)
+
+ continue
+ }
+
+ builder.WriteRune(s.opts.replacement)
+ }
+
+ result := builder.String()
+ if len(result) > s.opts.maxLength { // re-check the limit on the rewritten name
+ return "", ErrFilenameTooLong
+ }
+
+ if result == "." || result == ".." { // reserved current/parent directory segments
+ return "", ErrFilenameInvalid
+ }
+
+ if strings.HasSuffix(result, ".") || strings.HasSuffix(result, " ") { // NOTE(review): presumably guards against platforms that drop trailing dots/spaces — confirm intent
+ return "", ErrFilenameInvalid
+ }
+
+ return result, nil
+}
+
+func validateFilenameConfig(cfg filenameOptions) error { // final check of the assembled options: positive max length and a usable replacement rune
+ if cfg.maxLength <= 0 {
+ return ErrInvalidFilenameConfig
+ }
+
+ if !isValidReplacement(cfg.replacement) {
+ return ErrInvalidFilenameConfig
+ }
+
+ return nil
+}
+
+func isValidReplacement(replacement rune) bool { // reports whether replacement may stand in for disallowed runes
+ if replacement == '.' || unicode.IsSpace(replacement) { // a dot or whitespace could recreate the names Sanitize rejects
+ return false
+ }
+
+ return isAllowedFilenameRune(replacement, filenameOptions{allowUnicode: false, allowSpaces: false}, false) // judged under the strictest options, so it must be an allowed ASCII rune
+}
+
+func isAllowedFilenameRune(ch rune, cfg filenameOptions, isStart bool) bool { // reports whether ch may be kept as-is; isStart marks the first rune of the name
+ if !cfg.allowUnicode && ch >= utf8.RuneSelf { // non-ASCII rune while Unicode is disabled
+ return false
+ }
+
+ if ch == 0 || ch == utf8.RuneError { // NUL and U+FFFD (what ranging over a string yields for invalid UTF-8)
+ return false
+ }
+
+ if ch < ' ' || ch == filenameDeleteRune { // ASCII control characters and DEL
+ return false
+ }
+
+ if !cfg.allowSpaces && unicode.IsSpace(ch) {
+ return false
+ }
+
+ if !cfg.allowLeadingDot && isStart && ch == '.' {
+ return false
+ }
+
+ switch ch { // NOTE(review): path separators plus what looks like the Windows-reserved punctuation set — confirm
+ case '/', '\\', ':', '*', '?', '"', '<', '>', '|':
+ return false
+ default:
+ return true
+ }
+}
diff --git a/pkg/sanitize/filename_test.go b/pkg/sanitize/filename_test.go
new file mode 100644
index 0000000..b943f34
--- /dev/null
+++ b/pkg/sanitize/filename_test.go
@@ -0,0 +1,75 @@
+package sanitize
+
+import "testing"
+
+func TestFilenameSanitizeBasic(t *testing.T) { // an already-safe name must pass through unchanged with default options
+ sanitizer, err := NewFilenameSanitizer()
+ if err != nil {
+ t.Fatalf("expected sanitizer, got %v", err)
+ }
+
+ output, err := sanitizer.Sanitize("report.pdf")
+ if err != nil {
+ t.Fatalf("expected sanitized filename, got %v", err)
+ }
+
+ if output != "report.pdf" {
+ t.Fatalf("expected report.pdf, got %q", output)
+ }
+}
+
+func TestFilenameSanitizeLeadingDot(t *testing.T) { // by default a leading dot is replaced with '_'
+ sanitizer, err := NewFilenameSanitizer()
+ if err != nil {
+ t.Fatalf("expected sanitizer, got %v", err)
+ }
+
+ output, err := sanitizer.Sanitize(".env")
+ if err != nil {
+ t.Fatalf("expected sanitized filename, got %v", err)
+ }
+
+ if output != "_env" {
+ t.Fatalf("expected _env, got %q", output)
+ }
+}
+
+func TestFilenameSanitizeSeparators(t *testing.T) { // a path separator inside the name is replaced rather than rejected
+ sanitizer, err := NewFilenameSanitizer()
+ if err != nil {
+ t.Fatalf("expected sanitizer, got %v", err)
+ }
+
+ output, err := sanitizer.Sanitize("foo/bar")
+ if err != nil {
+ t.Fatalf("expected sanitized filename, got %v", err)
+ }
+
+ if output != "foo_bar" {
+ t.Fatalf("expected foo_bar, got %q", output)
+ }
+}
+
+func TestFilenameSanitizeEmpty(t *testing.T) { // whitespace-only input must be reported as empty
+ sanitizer, err := NewFilenameSanitizer()
+ if err != nil {
+ t.Fatalf("expected sanitizer, got %v", err)
+ }
+
+ _, err = sanitizer.Sanitize(" ")
+ if err != ErrFilenameEmpty { // NOTE(review): direct sentinel comparison; errors.Is would also tolerate wrapped errors
+ t.Fatalf("expected ErrFilenameEmpty, got %v", err)
+ }
+}
+
+func TestFilenameSanitizeMaxLength(t *testing.T) { // input longer than the configured limit must be rejected
+ sanitizer, err := NewFilenameSanitizer(WithFilenameMaxLength(3))
+ if err != nil {
+ t.Fatalf("expected sanitizer, got %v", err)
+ }
+
+ _, err = sanitizer.Sanitize("abcd")
+ if err != ErrFilenameTooLong { // NOTE(review): direct sentinel comparison; errors.Is would also tolerate wrapped errors
+ t.Fatalf("expected ErrFilenameTooLong, got %v", err)
+ }
+}
diff --git a/pkg/sanitize/html.go b/pkg/sanitize/html.go
new file mode 100644
index 0000000..bd2e6c3
--- /dev/null
+++ b/pkg/sanitize/html.go
@@ -0,0 +1,182 @@
+package sanitize
+
+import (
+ "fmt"
+ "html"
+ "strings"
+
+ nethtml "golang.org/x/net/html"
+)
+
+const (
+ htmlDefaultMaxLength = 100_000 // default input limit, compared against len() and therefore counted in bytes
+)
+
+// HTMLSanitizeMode selects the built-in strategy used when no custom policy is configured.
+type HTMLSanitizeMode int
+
+const (
+ // HTMLSanitizeEscape escapes the input with html.EscapeString; it is the default mode.
+ HTMLSanitizeEscape HTMLSanitizeMode = iota
+ // HTMLSanitizeStrip parses the input and returns only its text nodes, dropping script and style content.
+ HTMLSanitizeStrip
+)
+
+// HTMLPolicy defines a custom HTML sanitizer; when configured it replaces the built-in modes.
+type HTMLPolicy interface {
+ Sanitize(input string) (string, error)
+}
+
+// HTMLPolicyFunc adapts a plain function to the HTMLPolicy interface.
+type HTMLPolicyFunc func(input string) (string, error)
+
+// Sanitize implements HTMLPolicy by calling fn with input and returning its results unchanged.
+func (fn HTMLPolicyFunc) Sanitize(input string) (string, error) {
+ return fn(input)
+}
+
+// HTMLOption configures the HTML sanitizer; a non-nil error aborts NewHTMLSanitizer.
+type HTMLOption func(*htmlOptions) error
+
+type htmlOptions struct { // resolved settings consulted by HTMLSanitizer.Sanitize
+ maxLength int // maximum accepted input length in bytes; must be > 0
+ mode HTMLSanitizeMode // built-in strategy, used only when policy is nil
+ policy HTMLPolicy // optional custom sanitizer that takes precedence over mode
+}
+
+// HTMLSanitizer sanitizes HTML input using the options fixed at construction.
+type HTMLSanitizer struct {
+ opts htmlOptions
+}
+
+// NewHTMLSanitizer builds a sanitizer from the defaults (100_000-byte limit, escape mode), applies opts in order, and validates the resulting configuration.
+func NewHTMLSanitizer(opts ...HTMLOption) (*HTMLSanitizer, error) {
+ cfg := htmlOptions{
+ maxLength: htmlDefaultMaxLength,
+ mode: HTMLSanitizeEscape,
+ }
+
+ for _, opt := range opts {
+ if opt == nil { // nil options are tolerated and skipped
+ continue
+ }
+
+ err := opt(&cfg)
+ if err != nil { // the first failing option aborts construction
+ return nil, err
+ }
+ }
+
+ err := validateHTMLConfig(cfg)
+ if err != nil {
+ return nil, err
+ }
+
+ return &HTMLSanitizer{opts: cfg}, nil
+}
+
+// WithHTMLMode sets the built-in sanitization mode; unknown modes are rejected later by validateHTMLConfig.
+func WithHTMLMode(mode HTMLSanitizeMode) HTMLOption {
+ return func(cfg *htmlOptions) error {
+ cfg.mode = mode
+
+ return nil
+ }
+}
+
+// WithHTMLMaxLength sets the maximum accepted HTML input length in bytes; values <= 0 yield ErrInvalidHTMLConfig.
+func WithHTMLMaxLength(maxLength int) HTMLOption {
+ return func(cfg *htmlOptions) error {
+ if maxLength <= 0 {
+ return ErrInvalidHTMLConfig
+ }
+
+ cfg.maxLength = maxLength
+
+ return nil
+ }
+}
+
+// WithHTMLPolicy installs a custom policy that Sanitize uses instead of the built-in modes; a nil policy yields ErrInvalidHTMLConfig.
+func WithHTMLPolicy(policy HTMLPolicy) HTMLOption {
+ return func(cfg *htmlOptions) error {
+ if policy == nil {
+ return ErrInvalidHTMLConfig
+ }
+
+ cfg.policy = policy
+
+ return nil
+ }
+}
+
+// Sanitize rejects over-long input with ErrHTMLTooLong, then delegates to the custom policy if one is set, otherwise escapes or strips according to the configured mode.
+func (s *HTMLSanitizer) Sanitize(input string) (string, error) {
+ if len(input) > s.opts.maxLength { // the limit is enforced before any policy or mode runs
+ return "", ErrHTMLTooLong
+ }
+
+ if s.opts.policy != nil { // a custom policy overrides the configured mode entirely
+ return s.opts.policy.Sanitize(input)
+ }
+
+ switch s.opts.mode {
+ case HTMLSanitizeEscape:
+ return html.EscapeString(input), nil
+ case HTMLSanitizeStrip:
+ return stripHTML(input)
+ default: // only reachable for a mode value that validateHTMLConfig would have rejected
+ return "", ErrInvalidHTMLConfig
+ }
+}
+
+func validateHTMLConfig(cfg htmlOptions) error { // final check of the assembled options: positive max length and a known mode
+ if cfg.maxLength <= 0 {
+ return ErrInvalidHTMLConfig
+ }
+
+ if cfg.mode != HTMLSanitizeEscape && cfg.mode != HTMLSanitizeStrip {
+ return ErrInvalidHTMLConfig
+ }
+
+ return nil
+}
+
+func stripHTML(input string) (string, error) { // parses input as an HTML document and returns the concatenated text nodes
+ doc, err := nethtml.Parse(strings.NewReader(input))
+ if err != nil {
+ return "", fmt.Errorf("%w: %w", ErrHTMLInvalid, err) // both the sentinel and the parser error stay matchable via errors.Is
+ }
+
+ var builder strings.Builder
+ appendHTMLText(&builder, doc)
+
+ return builder.String(), nil // NOTE(review): presumably text nodes hold entity-decoded text, so the result may still contain '<' or '&' and is not itself HTML-escaped — confirm against the x/net/html docs
+}
+
+func appendHTMLText(builder *strings.Builder, node *nethtml.Node) { // depth-first walk that appends the data of every text node under node
+ if node == nil {
+ return
+ }
+
+ if node.Type == nethtml.ElementNode && isStripElement(node.Data) { // skip the whole subtree of script/style elements
+ return
+ }
+
+ if node.Type == nethtml.TextNode {
+ builder.WriteString(node.Data)
+ }
+
+ for child := node.FirstChild; child != nil; child = child.NextSibling { // recurse into children in document order
+ appendHTMLText(builder, child)
+ }
+}
+
+func isStripElement(tag string) bool { // reports whether an element's entire content is dropped in strip mode
+ switch tag {
+ case "script", "style": // the comparison is exact, so tag must already be in this lowercase form
+ return true
+ default:
+ return false
+ }
+}
diff --git a/pkg/sanitize/html_test.go b/pkg/sanitize/html_test.go
new file mode 100644
index 0000000..9ebe61d
--- /dev/null
+++ b/pkg/sanitize/html_test.go
@@ -0,0 +1,70 @@
+package sanitize
+
+import (
+ "html"
+ "testing"
+)
+
+func TestHTMLSanitizeEscape(t *testing.T) { // the default mode must produce exactly html.EscapeString(input)
+ sanitizer, err := NewHTMLSanitizer()
+ if err != nil {
+ t.Fatalf("expected sanitizer, got %v", err)
+ }
+
+ input := ``
+ output, err := sanitizer.Sanitize(input)
+ if err != nil {
+ t.Fatalf("expected sanitized html, got %v", err)
+ }
+
+ expected := html.EscapeString(input)
+ if output != expected {
+ t.Fatalf("expected %q, got %q", expected, output)
+ }
+}
+
+func TestHTMLSanitizeStrip(t *testing.T) { // strip mode must return the text content of the input
+ sanitizer, err := NewHTMLSanitizer(WithHTMLMode(HTMLSanitizeStrip))
+ if err != nil {
+ t.Fatalf("expected sanitizer, got %v", err)
+ }
+
+ output, err := sanitizer.Sanitize("Hello World")
+ if err != nil {
+ t.Fatalf("expected sanitized html, got %v", err)
+ }
+
+ if output != "Hello World" {
+ t.Fatalf("expected stripped text, got %q", output)
+ }
+}
+
+func TestHTMLSanitizePolicy(t *testing.T) { // a configured policy must take over and its output be returned verbatim
+ sanitizer, err := NewHTMLSanitizer(WithHTMLPolicy(HTMLPolicyFunc(func(_ string) (string, error) {
+ return "policy", nil
+ })))
+ if err != nil {
+ t.Fatalf("expected sanitizer, got %v", err)
+ }
+
+ output, err := sanitizer.Sanitize("ignored")
+ if err != nil {
+ t.Fatalf("expected sanitized html, got %v", err)
+ }
+
+ if output != "policy" {
+ t.Fatalf("expected policy output, got %q", output)
+ }
+}
+
+func TestHTMLSanitizeMaxLength(t *testing.T) { // input longer than the configured limit must be rejected
+ sanitizer, err := NewHTMLSanitizer(WithHTMLMaxLength(1))
+ if err != nil {
+ t.Fatalf("expected sanitizer, got %v", err)
+ }
+
+ _, err = sanitizer.Sanitize("ab")
+ if err != ErrHTMLTooLong { // NOTE(review): direct sentinel comparison; errors.Is would also tolerate wrapped errors
+ t.Fatalf("expected ErrHTMLTooLong, got %v", err)
+ }
+}
diff --git a/pkg/sanitize/markdown.go b/pkg/sanitize/markdown.go
new file mode 100644
index 0000000..c2d4531
--- /dev/null
+++ b/pkg/sanitize/markdown.go
@@ -0,0 +1,88 @@
+package sanitize
+
+import "html"
+
+const (
+ markdownDefaultMaxLength = 100_000 // default input limit, compared against len() and therefore counted in bytes
+)
+
+// MarkdownOption configures the Markdown sanitizer; a non-nil error aborts NewMarkdownSanitizer.
+type MarkdownOption func(*markdownOptions) error
+
+type markdownOptions struct { // resolved settings consulted by MarkdownSanitizer.Sanitize
+ maxLength int // maximum accepted input length in bytes; must be > 0
+ allowRawHTML bool // when true the input is returned without escaping
+}
+
+// MarkdownSanitizer sanitizes Markdown input using the options fixed at construction.
+type MarkdownSanitizer struct {
+ opts markdownOptions
+}
+
+// NewMarkdownSanitizer builds a sanitizer from the defaults (100_000-byte limit, raw HTML escaped), applies opts in order, and validates the resulting configuration.
+func NewMarkdownSanitizer(opts ...MarkdownOption) (*MarkdownSanitizer, error) {
+ cfg := markdownOptions{
+ maxLength: markdownDefaultMaxLength,
+ }
+
+ for _, opt := range opts {
+ if opt == nil { // nil options are tolerated and skipped
+ continue
+ }
+
+ err := opt(&cfg)
+ if err != nil { // the first failing option aborts construction
+ return nil, err
+ }
+ }
+
+ err := validateMarkdownConfig(cfg)
+ if err != nil {
+ return nil, err
+ }
+
+ return &MarkdownSanitizer{opts: cfg}, nil
+}
+
+// WithMarkdownMaxLength sets the maximum accepted Markdown input length in bytes; values <= 0 yield ErrInvalidMarkdownConfig.
+func WithMarkdownMaxLength(maxLength int) MarkdownOption {
+ return func(cfg *markdownOptions) error {
+ if maxLength <= 0 {
+ return ErrInvalidMarkdownConfig
+ }
+
+ cfg.maxLength = maxLength
+
+ return nil
+ }
+}
+
+// WithMarkdownAllowRawHTML controls whether Sanitize returns the input untouched (true) or HTML-escapes it (false, the default).
+func WithMarkdownAllowRawHTML(allow bool) MarkdownOption {
+ return func(cfg *markdownOptions) error {
+ cfg.allowRawHTML = allow
+
+ return nil
+ }
+}
+
+// Sanitize rejects over-long input with ErrMarkdownTooLong, returns input unchanged when raw HTML is allowed, and otherwise escapes it with html.EscapeString.
+func (s *MarkdownSanitizer) Sanitize(input string) (string, error) {
+ if len(input) > s.opts.maxLength {
+ return "", ErrMarkdownTooLong
+ }
+
+ if s.opts.allowRawHTML { // no sanitization is applied on this path
+ return input, nil
+ }
+
+ return html.EscapeString(input), nil // escapes the whole string, so Markdown that uses '>', '&', or quotes is escaped as well
+}
+
+func validateMarkdownConfig(cfg markdownOptions) error { // final check of the assembled options: the max length must be positive
+ if cfg.maxLength <= 0 {
+ return ErrInvalidMarkdownConfig
+ }
+
+ return nil
+}
diff --git a/pkg/sanitize/markdown_test.go b/pkg/sanitize/markdown_test.go
new file mode 100644
index 0000000..e5fcfa1
--- /dev/null
+++ b/pkg/sanitize/markdown_test.go
@@ -0,0 +1,53 @@
+package sanitize
+
+import (
+ "html"
+ "testing"
+)
+
+func TestMarkdownSanitizeEscape(t *testing.T) { // the default configuration must produce exactly html.EscapeString(input)
+ sanitizer, err := NewMarkdownSanitizer()
+ if err != nil {
+ t.Fatalf("expected sanitizer, got %v", err)
+ }
+
+ input := "hello"
+ output, err := sanitizer.Sanitize(input)
+ if err != nil {
+ t.Fatalf("expected sanitized markdown, got %v", err)
+ }
+
+ expected := html.EscapeString(input)
+ if output != expected {
+ t.Fatalf("expected %q, got %q", expected, output)
+ }
+}
+
+func TestMarkdownAllowRawHTML(t *testing.T) { // with raw HTML allowed the input must come back unchanged
+ sanitizer, err := NewMarkdownSanitizer(WithMarkdownAllowRawHTML(true))
+ if err != nil {
+ t.Fatalf("expected sanitizer, got %v", err)
+ }
+
+ input := "hello"
+ output, err := sanitizer.Sanitize(input)
+ if err != nil {
+ t.Fatalf("expected sanitized markdown, got %v", err)
+ }
+
+ if output != input {
+ t.Fatalf("expected raw html, got %q", output)
+ }
+}
+
+func TestMarkdownMaxLength(t *testing.T) { // input longer than the configured limit must be rejected
+ sanitizer, err := NewMarkdownSanitizer(WithMarkdownMaxLength(1))
+ if err != nil {
+ t.Fatalf("expected sanitizer, got %v", err)
+ }
+
+ _, err = sanitizer.Sanitize("ab")
+ if err != ErrMarkdownTooLong { // NOTE(review): direct sentinel comparison; errors.Is would also tolerate wrapped errors
+ t.Fatalf("expected ErrMarkdownTooLong, got %v", err)
+ }
+}
diff --git a/pkg/sanitize/sql.go b/pkg/sanitize/sql.go
new file mode 100644
index 0000000..78dcc1f
--- /dev/null
+++ b/pkg/sanitize/sql.go
@@ -0,0 +1,293 @@
+package sanitize
+
+import (
+ "strings"
+ "unicode/utf8"
+)
+
+const (
+ sqlDefaultIdentifierMaxLength = 128 // default byte limit in identifier mode
+ sqlDefaultLiteralMaxLength = 4096 // default byte limit in literal mode
+ sqlDefaultLikeMaxLength = 4096 // default byte limit in LIKE pattern mode
+ sqlDefaultLikeEscape = '\\' // default escape character for LIKE patterns
+)
+
+// SQLMode describes the SQL sanitization strategy.
+type SQLMode int
+
+const (
+ // SQLModeIdentifier validates SQL identifiers (table/column names); it is the default mode.
+ SQLModeIdentifier SQLMode = iota
+ // SQLModeLiteral doubles single quotes so the value can be embedded in a string literal.
+ SQLModeLiteral
+ // SQLModeLikePattern escapes LIKE wildcards and the escape character, and doubles single quotes.
+ SQLModeLikePattern
+)
+
+// SQLOption configures the SQL sanitizer; a non-nil error aborts NewSQLSanitizer.
+type SQLOption func(*sqlOptions) error
+
+type sqlOptions struct { // resolved settings consulted by SQLSanitizer.Sanitize
+ mode SQLMode // selected sanitization strategy
+ maxLength int // maximum accepted input length in bytes; 0 means "use the per-mode default"
+ allowQualified bool // accept dotted identifiers; only valid in identifier mode
+ likeEscape rune // escape character used in LIKE pattern mode
+}
+
+// SQLSanitizer sanitizes SQL inputs using the options fixed at construction.
+type SQLSanitizer struct {
+ opts sqlOptions
+}
+
+// NewSQLSanitizer builds a sanitizer from the defaults (identifier mode, '\\' LIKE escape), applies opts in order, fills in the per-mode max length if none was set, and validates the result.
+func NewSQLSanitizer(opts ...SQLOption) (*SQLSanitizer, error) {
+ cfg := sqlOptions{
+ mode: SQLModeIdentifier,
+ likeEscape: sqlDefaultLikeEscape,
+ }
+
+ for _, opt := range opts {
+ if opt == nil { // nil options are tolerated and skipped
+ continue
+ }
+
+ err := opt(&cfg)
+ if err != nil { // the first failing option aborts construction
+ return nil, err
+ }
+ }
+
+ if cfg.maxLength == 0 { // resolved after the options so the default follows the final mode
+ cfg.maxLength = sqlDefaultMaxLength(cfg.mode)
+ }
+
+ err := validateSQLConfig(cfg)
+ if err != nil {
+ return nil, err
+ }
+
+ return &SQLSanitizer{opts: cfg}, nil
+}
+
+// WithSQLMode sets the SQL sanitization mode; unknown modes are rejected later by validateSQLConfig.
+func WithSQLMode(mode SQLMode) SQLOption {
+ return func(cfg *sqlOptions) error {
+ cfg.mode = mode
+
+ return nil
+ }
+}
+
+// WithSQLMaxLength sets the maximum accepted SQL input length in bytes; values <= 0 yield ErrInvalidSQLConfig.
+func WithSQLMaxLength(maxLength int) SQLOption {
+ return func(cfg *sqlOptions) error {
+ if maxLength <= 0 {
+ return ErrInvalidSQLConfig
+ }
+
+ cfg.maxLength = maxLength
+
+ return nil
+ }
+}
+
+// WithSQLAllowQualifiedIdentifiers allows dotted identifiers (schema.table); enabling it outside identifier mode makes NewSQLSanitizer fail with ErrInvalidSQLConfig.
+func WithSQLAllowQualifiedIdentifiers(allow bool) SQLOption {
+ return func(cfg *sqlOptions) error {
+ cfg.allowQualified = allow
+
+ return nil
+ }
+}
+
+// WithSQLLikeEscapeChar sets the escape character for SQL LIKE patterns; it is validated by NewSQLSanitizer only when LIKE pattern mode is selected.
+func WithSQLLikeEscapeChar(ch rune) SQLOption {
+ return func(cfg *sqlOptions) error {
+ cfg.likeEscape = ch
+
+ return nil
+ }
+}
+
+// Sanitize rejects over-long input with ErrSQLInputTooLong, then validates or escapes it according to the configured mode.
+func (s *SQLSanitizer) Sanitize(input string) (string, error) {
+ if len(input) > s.opts.maxLength { // measured in bytes on the raw input, before any trimming or escaping
+ return "", ErrSQLInputTooLong
+ }
+
+ switch s.opts.mode {
+ case SQLModeIdentifier:
+ return s.sanitizeIdentifier(input)
+ case SQLModeLiteral:
+ return sanitizeSQLLiteral(input)
+ case SQLModeLikePattern:
+ return s.sanitizeLikePattern(input)
+ default: // only reachable for a mode value that validateSQLConfig would have rejected
+ return "", ErrInvalidSQLConfig
+ }
+}
+
+func validateSQLConfig(cfg sqlOptions) error { // final check of the assembled options before a sanitizer is created
+ if cfg.maxLength <= 0 {
+ return ErrInvalidSQLConfig
+ }
+
+ if cfg.mode != SQLModeIdentifier && cfg.mode != SQLModeLiteral && cfg.mode != SQLModeLikePattern {
+ return ErrInvalidSQLConfig
+ }
+
+ if cfg.allowQualified && cfg.mode != SQLModeIdentifier { // qualified names only make sense for identifiers
+ return ErrInvalidSQLConfig
+ }
+
+ if cfg.mode == SQLModeLikePattern && !isValidLikeEscape(cfg.likeEscape) { // the escape rune is checked only in the mode that uses it
+ return ErrSQLLikeEscapeInvalid
+ }
+
+ return nil
+}
+
+func (s *SQLSanitizer) sanitizeIdentifier(input string) (string, error) { // trims surrounding whitespace and returns the identifier only if it validates; nothing is rewritten
+ value := strings.TrimSpace(input)
+ if value == "" {
+ return "", ErrSQLIdentifierInvalid
+ }
+
+ if s.opts.allowQualified { // dotted names are validated segment by segment
+ return sanitizeQualifiedIdentifier(value)
+ }
+
+ err := validateIdentifierSegment(value)
+ if err != nil {
+ return "", err
+ }
+
+ return value, nil
+}
+
+func sanitizeQualifiedIdentifier(value string) (string, error) { // validates every dot-separated segment and returns value unchanged on success
+ parts := strings.SplitSeq(value, ".")
+ for part := range parts {
+ err := validateIdentifierSegment(part) // empty segments (leading, trailing, or doubled dots) are rejected here
+ if err != nil {
+ return "", err
+ }
+ }
+
+ return value, nil
+}
+
+func validateIdentifierSegment(segment string) error { // accepts only non-empty ASCII segments matching [A-Za-z_][A-Za-z0-9_]*
+ if segment == "" {
+ return ErrSQLIdentifierInvalid
+ }
+
+ for index := range len(segment) { // byte-wise scan; any non-ASCII byte is rejected below
+ ch := segment[index]
+ if ch >= utf8.RuneSelf {
+ return ErrSQLIdentifierInvalid
+ }
+
+ if index == 0 { // the first byte may not be a digit
+ if !isIdentifierStart(ch) {
+ return ErrSQLIdentifierInvalid
+ }
+
+ continue
+ }
+
+ if !isIdentifierPart(ch) {
+ return ErrSQLIdentifierInvalid
+ }
+ }
+
+ return nil
+}
+
+func isIdentifierStart(ch byte) bool { // reports whether ch may begin an identifier: '_' or an ASCII letter
+ if ch == '_' {
+ return true
+ }
+
+ return isASCIIAlpha(ch)
+}
+
+func isIdentifierPart(ch byte) bool { // reports whether ch may follow the first byte: '_', an ASCII letter, or an ASCII digit
+ return isIdentifierStart(ch) || isASCIIDigit(ch)
+}
+
+func isASCIIAlpha(ch byte) bool { // reports whether ch is in A-Z or a-z
+ return (ch >= 'A' && ch <= 'Z') || (ch >= 'a' && ch <= 'z')
+}
+
+func isASCIIDigit(ch byte) bool { // reports whether ch is in 0-9
+ return ch >= '0' && ch <= '9'
+}
+
+func sanitizeSQLLiteral(input string) (string, error) { // rejects NUL characters and doubles every single quote; surrounding quotes are not added
+ if strings.ContainsRune(input, 0) {
+ return "", ErrSQLLiteralInvalid
+ }
+
+ return strings.ReplaceAll(input, "'", "''"), nil
+}
+
+func (s *SQLSanitizer) sanitizeLikePattern(input string) (string, error) { // rejects NUL characters, then escapes the input using the configured escape rune
+ if strings.ContainsRune(input, 0) { // reuses the literal sentinel for NUL in LIKE input
+ return "", ErrSQLLiteralInvalid
+ }
+
+ return escapeLikePattern(input, s.opts.likeEscape)
+}
+
+func escapeLikePattern(input string, escape rune) (string, error) { // returns input with quotes doubled and with '%', '_', and the escape rune each prefixed by escape
+ if !isValidLikeEscape(escape) {
+ return "", ErrSQLLikeEscapeInvalid
+ }
+
+ var builder strings.Builder
+ builder.Grow(len(input)) // lower bound only; escaping can make the output longer
+
+ for _, ch := range input {
+ switch ch {
+ case '\'':
+ builder.WriteString("''") // quote is doubled here, so the shared WriteRune below is skipped
+
+ continue
+ case '%', '_':
+ builder.WriteRune(escape) // neutralize the LIKE wildcards
+ default:
+ if ch == escape { // a literal escape character must itself be escaped
+ builder.WriteRune(escape)
+ }
+ }
+
+ builder.WriteRune(ch)
+ }
+
+ return builder.String(), nil
+}
+
+func isValidLikeEscape(ch rune) bool { // reports whether ch is a usable LIKE escape: a non-NUL ASCII rune other than '\'', '%', or '_'
+ if ch == 0 {
+ return false
+ }
+
+ if ch == '\'' || ch == '%' || ch == '_' { // these already have a meaning in the escaped output
+ return false
+ }
+
+ return ch < utf8.RuneSelf
+}
+
+func sqlDefaultMaxLength(mode SQLMode) int { // returns the default byte limit for mode; any mode other than identifier or LIKE gets the literal limit
+ if mode == SQLModeIdentifier {
+ return sqlDefaultIdentifierMaxLength
+ }
+
+ if mode == SQLModeLikePattern {
+ return sqlDefaultLikeMaxLength
+ }
+
+ return sqlDefaultLiteralMaxLength
+}
diff --git a/pkg/sanitize/sql_detect.go b/pkg/sanitize/sql_detect.go
new file mode 100644
index 0000000..294f14f
--- /dev/null
+++ b/pkg/sanitize/sql_detect.go
@@ -0,0 +1,180 @@
+package sanitize
+
+import (
+ "strings"
+ "unicode"
+)
+
+const (
+ sqlDetectDefaultMaxLength = 4096
+)
+
+// SQLDetectOption configures the SQL injection detector.
+type SQLDetectOption func(*sqlDetectOptions) error
+
+type sqlDetectOptions struct {
+ maxLength int
+ patterns []string
+}
+
+// SQLInjectionDetector checks inputs for SQL injection heuristics.
+type SQLInjectionDetector struct {
+ opts sqlDetectOptions
+}
+
+// NewSQLInjectionDetector builds a detector starting from the default maximum
+// length and pattern list, applies each non-nil option in order, and validates
+// the resulting configuration. The first option or validation error is returned.
+func NewSQLInjectionDetector(opts ...SQLDetectOption) (*SQLInjectionDetector, error) {
+	var settings sqlDetectOptions
+
+	settings.maxLength = sqlDetectDefaultMaxLength
+	settings.patterns = defaultSQLInjectionPatterns()
+
+	for _, apply := range opts {
+		if apply == nil {
+			continue
+		}
+
+		if err := apply(&settings); err != nil {
+			return nil, err
+		}
+	}
+
+	if err := validateSQLDetectConfig(settings); err != nil {
+		return nil, err
+	}
+
+	return &SQLInjectionDetector{opts: settings}, nil
+}
+
+// WithSQLDetectMaxLength sets the maximum input length, in bytes, accepted by
+// Detect. A non-positive maxLength makes the option fail with ErrInvalidSQLConfig.
+func WithSQLDetectMaxLength(maxLength int) SQLDetectOption {
+	return func(cfg *sqlDetectOptions) error {
+		if maxLength > 0 {
+			cfg.maxLength = maxLength
+
+			return nil
+		}
+		return ErrInvalidSQLConfig
+	}
+}
+
+// WithSQLDetectPatterns replaces the default detection patterns. Patterns are
+// lower-cased and trimmed and blank ones dropped; the option fails with
+// ErrInvalidSQLConfig when nothing usable remains.
+func WithSQLDetectPatterns(patterns ...string) SQLDetectOption {
+	return func(cfg *sqlDetectOptions) error {
+		var cleaned []string
+		if len(patterns) > 0 {
+			cleaned = normalizeDetectPatterns(patterns)
+		}
+		if len(cleaned) == 0 {
+			return ErrInvalidSQLConfig
+		}
+		cfg.patterns = cleaned
+
+		return nil
+	}
+}
+
+// Detect returns ErrSQLInputTooLong when input exceeds the configured byte limit, and
+// ErrSQLInjectionDetected when a pattern occurs in the normalized input; nil otherwise.
+func (d *SQLInjectionDetector) Detect(input string) error {
+	if limit := d.opts.maxLength; len(input) > limit {
+		return ErrSQLInputTooLong
+	}
+	// The leading space lets patterns that begin with a space match at the very start.
+	haystack := " " + normalizeDetectInput(input)
+	for i := range d.opts.patterns {
+		if strings.Contains(haystack, d.opts.patterns[i]) {
+			return ErrSQLInjectionDetected
+		}
+	}
+	return nil
+}
+
+// validateSQLDetectConfig reports ErrInvalidSQLConfig when the maximum length
+// is not positive, the pattern list is empty, or any pattern is blank.
+func validateSQLDetectConfig(cfg sqlDetectOptions) error {
+	if cfg.maxLength <= 0 || len(cfg.patterns) == 0 {
+		return ErrInvalidSQLConfig
+	}
+
+	for _, candidate := range cfg.patterns {
+		if len(strings.TrimSpace(candidate)) != 0 {
+			continue
+		}
+
+		return ErrInvalidSQLConfig
+	}
+
+	return nil
+}
+
+// normalizeDetectInput lower-cases input and collapses every run of Unicode
+// whitespace into a single space. A run is emitted only when a non-space rune
+// follows it, so trailing whitespace is dropped while leading whitespace
+// becomes one space.
+func normalizeDetectInput(input string) string {
+	var out strings.Builder
+
+	folded := strings.ToLower(input)
+	out.Grow(len(folded))
+
+	sawSpace := false
+
+	for _, r := range folded {
+		switch {
+		case unicode.IsSpace(r):
+			sawSpace = true
+		case sawSpace:
+			sawSpace = false
+			out.WriteByte(' ')
+			out.WriteRune(r)
+		default:
+			out.WriteRune(r)
+		}
+	}
+	return out.String()
+}
+
+// defaultSQLInjectionPatterns returns a fresh slice of the built-in lower-case
+// substrings. Entries beginning with a space must follow a space to match.
+func defaultSQLInjectionPatterns() []string {
+	return []string{
+		// Comment markers and the statement terminator.
+		"--",
+		"/*", "*/",
+		";",
+		// UNION-based extraction.
+		"union select", "union all select",
+		// OR tautologies.
+		" or 1=1", " or 1 = 1",
+		" or '1'='1'", " or '1' = '1'",
+		" or \"1\"=\"1\"", " or \"1\" = \"1\"",
+		" or true",
+		// AND tautologies.
+		" and 1=1", " and 1 = 1", " and true",
+		// Time-delay probes.
+		"sleep(",
+		"pg_sleep(",
+		"benchmark(",
+		"waitfor delay",
+	}
+}
+
+// normalizeDetectPatterns lower-cases and trims each pattern, discarding any
+// that end up empty. The input slice is left untouched.
+func normalizeDetectPatterns(patterns []string) []string {
+	cleaned := make([]string, 0, len(patterns))
+
+	for i := range patterns {
+		if p := strings.TrimSpace(strings.ToLower(patterns[i])); p != "" {
+			cleaned = append(cleaned, p)
+		}
+	}
+
+	return cleaned
+}
diff --git a/pkg/sanitize/sql_test.go b/pkg/sanitize/sql_test.go
new file mode 100644
index 0000000..bd87551
--- /dev/null
+++ b/pkg/sanitize/sql_test.go
@@ -0,0 +1,99 @@
+package sanitize
+
+import "testing"
+
+func TestSQLSanitizeIdentifier(t *testing.T) {
+ sanitizer, err := NewSQLSanitizer()
+ if err != nil {
+ t.Fatalf("expected sanitizer, got %v", err)
+ }
+
+ output, err := sanitizer.Sanitize("users")
+ if err != nil {
+ t.Fatalf("expected sanitized identifier, got %v", err)
+ }
+
+ if output != "users" {
+ t.Fatalf("expected users, got %q", output)
+ }
+
+ _, err = sanitizer.Sanitize("users;drop")
+ if err != ErrSQLIdentifierInvalid {
+ t.Fatalf("expected ErrSQLIdentifierInvalid, got %v", err)
+ }
+}
+
+func TestSQLSanitizeQualifiedIdentifier(t *testing.T) {
+ sanitizer, err := NewSQLSanitizer(WithSQLAllowQualifiedIdentifiers(true))
+ if err != nil {
+ t.Fatalf("expected sanitizer, got %v", err)
+ }
+
+ output, err := sanitizer.Sanitize("public.users")
+ if err != nil {
+ t.Fatalf("expected sanitized identifier, got %v", err)
+ }
+
+ if output != "public.users" {
+ t.Fatalf("expected public.users, got %q", output)
+ }
+
+ _, err = sanitizer.Sanitize("public..users")
+ if err != ErrSQLIdentifierInvalid {
+ t.Fatalf("expected ErrSQLIdentifierInvalid, got %v", err)
+ }
+}
+
+func TestSQLSanitizeLiteral(t *testing.T) {
+ sanitizer, err := NewSQLSanitizer(WithSQLMode(SQLModeLiteral))
+ if err != nil {
+ t.Fatalf("expected sanitizer, got %v", err)
+ }
+
+ output, err := sanitizer.Sanitize("O'Reilly")
+ if err != nil {
+ t.Fatalf("expected sanitized literal, got %v", err)
+ }
+
+ if output != "O''Reilly" {
+ t.Fatalf("expected escaped literal, got %q", output)
+ }
+
+ _, err = sanitizer.Sanitize("bad\x00")
+ if err != ErrSQLLiteralInvalid {
+ t.Fatalf("expected ErrSQLLiteralInvalid, got %v", err)
+ }
+}
+
+func TestSQLSanitizeLikePattern(t *testing.T) {
+ sanitizer, err := NewSQLSanitizer(WithSQLMode(SQLModeLikePattern))
+ if err != nil {
+ t.Fatalf("expected sanitizer, got %v", err)
+ }
+
+ output, err := sanitizer.Sanitize(`50%_off\`)
+ if err != nil {
+ t.Fatalf("expected sanitized pattern, got %v", err)
+ }
+
+ if output != `50\%\_off\\` {
+ t.Fatalf("expected escaped pattern, got %q", output)
+ }
+}
+
+func TestSQLInjectionDetector(t *testing.T) {
+ detector, err := NewSQLInjectionDetector()
+ if err != nil {
+ t.Fatalf("expected detector, got %v", err)
+ }
+
+ err = detector.Detect("1 OR 1=1; --")
+ if err != ErrSQLInjectionDetected {
+ t.Fatalf("expected ErrSQLInjectionDetected, got %v", err)
+ }
+
+ err = detector.Detect("username")
+ if err != nil {
+ t.Fatalf("expected no error, got %v", err)
+ }
+}
diff --git a/pkg/validate/doc.go b/pkg/validate/doc.go
new file mode 100644
index 0000000..df983cd
--- /dev/null
+++ b/pkg/validate/doc.go
@@ -0,0 +1,2 @@
+// Package validate provides input validation helpers with safe defaults.
+package validate
diff --git a/pkg/validate/email.go b/pkg/validate/email.go
new file mode 100644
index 0000000..24bc230
--- /dev/null
+++ b/pkg/validate/email.go
@@ -0,0 +1,582 @@
+package validate
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net"
+ "net/mail"
+ "strings"
+ "unicode/utf8"
+
+ "golang.org/x/net/idna"
+)
+
+const (
+ emailMaxAddressLength = 254
+ emailMaxLocalLength = 64
+ emailMaxDomainLength = 255
+ emailMaxLabelLength = 63
+ emailMinLabelLength = 1
+
+ emailIPLiteralPrefix = "["
+ emailIPLiteralSuffix = "]"
+ emailIPv6LiteralPrefix = "ipv6:"
+
+ emailDot = '.'
+ emailAt = '@'
+)
+
+// DNSResolver abstracts DNS lookups for email validation.
+type DNSResolver interface {
+ LookupMX(ctx context.Context, name string) ([]*net.MX, error)
+ LookupHost(ctx context.Context, name string) ([]string, error)
+}
+
+// EmailOption configures EmailValidator.
+type EmailOption func(*emailOptions) error
+
+type emailOptions struct {
+ allowDisplayName bool
+ allowQuotedLocal bool
+ allowIPLiteral bool
+ allowIDN bool
+ requireTLD bool
+ verifyDomain bool
+ requireMX bool
+ allowARecordFallback bool
+ resolver DNSResolver
+}
+
+// EmailResult contains normalized email details.
+type EmailResult struct {
+ Address string
+ LocalPart string
+ Domain string
+ DomainASCII string
+ DomainVerified bool
+ VerifiedByMX bool
+ VerifiedByA bool
+}
+
+// EmailValidator validates email addresses with optional DNS checks.
+type EmailValidator struct {
+ opts emailOptions
+}
+
+// NewEmailValidator builds an EmailValidator. By default a dot is required in
+// the domain and A-record fallback is enabled; non-nil options are applied in
+// order. net.DefaultResolver is used when domain verification is on and no
+// resolver was supplied.
+func NewEmailValidator(opts ...EmailOption) (*EmailValidator, error) {
+	var settings emailOptions
+
+	settings.requireTLD = true
+	settings.allowARecordFallback = true
+
+	for _, apply := range opts {
+		if apply == nil {
+			continue
+		}
+		if err := apply(&settings); err != nil {
+			return nil, err
+		}
+	}
+	if settings.resolver == nil && settings.verifyDomain {
+		settings.resolver = net.DefaultResolver
+	}
+
+	return &EmailValidator{opts: settings}, nil
+}
+
+// WithEmailAllowDisplayName permits display names like "Name <user@example.com>".
+func WithEmailAllowDisplayName(allow bool) EmailOption {
+	return func(cfg *emailOptions) error {
+		cfg.allowDisplayName = allow
+
+		return nil
+	}
+}
+
+// WithEmailAllowQuotedLocal permits local parts wrapped in double quotes.
+func WithEmailAllowQuotedLocal(allow bool) EmailOption {
+	return func(cfg *emailOptions) error {
+		cfg.allowQuotedLocal = allow
+
+		return nil
+	}
+}
+
+// WithEmailAllowIPLiteral permits bracketed IP-literal domains such as [192.0.2.1].
+func WithEmailAllowIPLiteral(allow bool) EmailOption {
+	return func(cfg *emailOptions) error {
+		cfg.allowIPLiteral = allow
+
+		return nil
+	}
+}
+
+// WithEmailAllowIDN permits non-ASCII (IDN) domains, which are converted to ASCII for checks.
+func WithEmailAllowIDN(allow bool) EmailOption {
+	return func(cfg *emailOptions) error {
+		cfg.allowIDN = allow
+
+		return nil
+	}
+}
+
+// WithEmailRequireTLD controls whether the domain must contain a dot (on by default).
+func WithEmailRequireTLD(require bool) EmailOption {
+	return func(cfg *emailOptions) error {
+		cfg.requireTLD = require
+
+		return nil
+	}
+}
+
+// WithEmailVerifyDomain enables DNS verification of the domain (MX, then optional host lookup).
+func WithEmailVerifyDomain(verify bool) EmailOption {
+	return func(cfg *emailOptions) error {
+		cfg.verifyDomain = verify
+
+		return nil
+	}
+}
+
+// WithEmailRequireMX makes domain verification fail without a usable MX, skipping any fallback.
+func WithEmailRequireMX(require bool) EmailOption {
+	return func(cfg *emailOptions) error {
+		cfg.requireMX = require
+
+		return nil
+	}
+}
+
+// WithEmailAllowARecordFallback toggles the host-lookup fallback used when no MX verifies (on by default).
+func WithEmailAllowARecordFallback(allow bool) EmailOption {
+	return func(cfg *emailOptions) error {
+		cfg.allowARecordFallback = allow
+
+		return nil
+	}
+}
+
+// WithEmailDNSResolver sets the resolver used for domain verification lookups.
+// A nil resolver makes the option fail with ErrInvalidEmailConfig.
+func WithEmailDNSResolver(resolver DNSResolver) EmailOption {
+	return func(cfg *emailOptions) error {
+		if resolver != nil {
+			cfg.resolver = resolver
+
+			return nil
+		}
+		return ErrInvalidEmailConfig
+	}
+}
+
+// Validate checks input as an email address and returns its normalized parts.
+//
+// The input is trimmed and length-checked, any display name is handled per the
+// validator options, and the local part and domain are validated separately.
+// When domain verification is enabled the domain is also resolved via DNS using
+// ctx. On any failure a zero EmailResult and the error are returned.
+func (v *EmailValidator) Validate(ctx context.Context, input string) (EmailResult, error) {
+	var none EmailResult
+
+	trimmed, err := normalizeEmailInput(input)
+	if err != nil {
+		return none, err
+	}
+
+	address, err := v.normalizeAddress(trimmed)
+	if err != nil {
+		return none, err
+	}
+
+	localPart, domain, err := splitEmail(address)
+	if err != nil {
+		return none, err
+	}
+
+	if err = v.validateLocalPart(localPart); err != nil {
+		return none, err
+	}
+
+	info, err := v.validateDomain(domain)
+	if err != nil {
+		return none, err
+	}
+
+	result := EmailResult{Address: address, LocalPart: localPart, Domain: info.normalized, DomainASCII: info.ascii}
+
+	if err = v.applyDomainVerification(ctx, info, &result); err != nil {
+		return none, err
+	}
+
+	return result, nil
+}
+
+// normalizeEmailInput trims surrounding whitespace and rejects input that is
+// empty or longer than emailMaxAddressLength bytes after trimming.
+func normalizeEmailInput(input string) (string, error) {
+	candidate := strings.TrimSpace(input)
+	switch {
+	case len(candidate) == 0:
+		return "", ErrEmailEmpty
+	case len(candidate) > emailMaxAddressLength:
+		return "", ErrEmailAddressTooLong
+	}
+	return candidate, nil
+}
+
+// normalizeAddress extracts the bare address from input using mail.ParseAddress.
+// Input that does not parse is returned unchanged so the later syntax checks
+// decide its fate. When display names are disallowed, parsed input that
+// differs from its bare address yields ErrEmailDisplayName.
+func (v *EmailValidator) normalizeAddress(input string) (string, error) {
+	parsed, err := mail.ParseAddress(input)
+	if err != nil || parsed == nil {
+		return input, nil
+	}
+
+	bare := parsed.Address
+	if bare == input || v.opts.allowDisplayName {
+		return bare, nil
+	}
+
+	return "", ErrEmailDisplayName
+}
+
+// validateLocalPart enforces the local-part length limit, then checks its syntax.
+func (v *EmailValidator) validateLocalPart(local string) error {
+	if len(local) <= emailMaxLocalLength {
+		return validateLocalPartSyntax(local, v.opts.allowQuotedLocal)
+	}
+	return ErrEmailLocalPartTooLong
+}
+
+type emailDomainInfo struct {
+ normalized string
+ ascii string
+ isIPLiteral bool
+}
+
+// validateDomain normalizes domain and applies the validator's domain policy.
+// IP literals are accepted only when allowed and skip the label checks; other
+// domains must fit emailMaxDomainLength, contain a dot when a TLD is required,
+// and consist of valid labels.
+func (v *EmailValidator) validateDomain(domain string) (emailDomainInfo, error) {
+	var none emailDomainInfo
+
+	info, err := normalizeDomain(domain, v.opts.allowIDN)
+	if err != nil {
+		return none, err
+	}
+
+	switch {
+	case info.isIPLiteral && !v.opts.allowIPLiteral:
+		return none, ErrEmailIPLiteralNotAllowed
+	case info.isIPLiteral:
+		return info, nil
+	case len(info.ascii) > emailMaxDomainLength:
+		return none, ErrEmailDomainTooLong
+	case v.opts.requireTLD && !strings.ContainsRune(info.ascii, emailDot):
+		return none, ErrEmailDomainInvalid
+	}
+
+	if err = validateDomainLabels(info.ascii); err != nil {
+		return none, err
+	}
+
+	return info, nil
+}
+
+// applyDomainVerification runs the DNS check when verification is enabled and
+// the domain is not an IP literal, copying the outcome into result. A nil ctx
+// is rejected with ErrEmailInvalid; lookup errors are returned unchanged.
+func (v *EmailValidator) applyDomainVerification(ctx context.Context, domainInfo emailDomainInfo, result *EmailResult) error {
+	needsLookup := v.opts.verifyDomain && !domainInfo.isIPLiteral
+	if !needsLookup {
+		return nil
+	}
+	if ctx == nil {
+		return ErrEmailInvalid
+	}
+
+	outcome, err := v.verifyDomain(ctx, domainInfo.ascii)
+	if err != nil {
+		return err
+	}
+	result.DomainVerified, result.VerifiedByMX, result.VerifiedByA = outcome.verified, outcome.byMX, outcome.byA
+
+	return nil
+}
+
+type domainVerification struct {
+ verified bool
+ byMX bool
+ byA bool
+}
+
+// splitEmail splits address at its last '@' into local part and domain. It
+// returns ErrEmailInvalid when there is no '@' or either side would be empty.
+func splitEmail(address string) (local, domain string, err error) {
+	idx := strings.LastIndexByte(address, emailAt)
+
+	// idx < 1 covers both "no '@'" and "empty local part"; an '@' in the
+	// final position means the domain is empty.
+	if idx < 1 || idx+1 >= len(address) {
+		return "", "", ErrEmailInvalid
+	}
+
+	local, domain = address[:idx], address[idx+1:]
+
+	return local, domain, nil
+}
+
+// validateLocalPartSyntax checks the local part of an address. A quoted local
+// part needs allowQuoted and must pass isValidQuotedLocal; anything else must
+// be a dot-atom.
+func validateLocalPartSyntax(local string, allowQuoted bool) error {
+	quoted := isQuoted(local)
+
+	switch {
+	case quoted && !allowQuoted:
+		return ErrEmailQuotedLocalPart
+	case quoted && isValidQuotedLocal(local):
+		return nil
+	case quoted:
+		return ErrEmailLocalPartInvalid
+	case isDotAtom(local):
+		return nil
+	default:
+		return ErrEmailLocalPartInvalid
+	}
+}
+
+func isQuoted(local string) bool { // at least two bytes, first and last both '"'
+	return len(local) >= 2 && strings.HasPrefix(local, `"`) && strings.HasSuffix(local, `"`)
+}
+
+// isValidQuotedLocal reports whether local is at least two bytes long and
+// contains no carriage return or line feed. It does not inspect the quotes or
+// any backslash escapes inside the string.
+func isValidQuotedLocal(local string) bool {
+	const lineBreaks = "\r\n"
+
+	if len(local) < 2 {
+		return false
+	}
+
+	// ContainsAny is true when a CR or LF appears anywhere in local.
+	return !strings.ContainsAny(local, lineBreaks)
+}
+
+// isDotAtom reports whether local is a dot-atom: one or more non-empty runs of
+// atext characters separated by single dots. Empty input, leading, trailing or
+// consecutive dots, and any byte outside the atext set (including all
+// non-ASCII bytes) make it return false.
+func isDotAtom(local string) bool {
+	if local == "" {
+		return false
+	}
+
+	// runLen counts the atext bytes seen since the last dot (or the start).
+	runLen := 0
+
+	for i := 0; i < len(local); i++ {
+		ch := local[i]
+		switch {
+		case ch == '.' && runLen == 0:
+			return false
+		case ch == '.':
+			runLen = 0
+		case isAtext(ch):
+			runLen++
+		default:
+			return false
+		}
+	}
+
+	return runLen > 0
+}
+
+// isAtext reports whether ch is in the "atext" set allowed in an unquoted
+// local part: ASCII letters, ASCII digits, and the punctuation characters
+// listed below. The dot is not part of the set.
+//
+//	! # $ % & ' * + - / = ? ^ _ ` { | } ~
+func isAtext(ch byte) bool {
+	const punctuation = "!#$%&'*+-/=?^_`{|}~"
+
+	switch {
+	case 'A' <= ch && ch <= 'Z':
+		return true
+	case 'a' <= ch && ch <= 'z':
+		return true
+	case '0' <= ch && ch <= '9':
+		return true
+	}
+
+	// Anything else must be one of the listed punctuation bytes.
+	return strings.IndexByte(punctuation, ch) >= 0
+}
+
+// normalizeDomain strips one trailing dot from domain and classifies it. A
+// bracketed domain must hold a parsable IP literal; a non-ASCII domain needs
+// allowIDN and is converted with idna.Lookup; the ASCII form is lower-cased.
+func normalizeDomain(domain string, allowIDN bool) (emailDomainInfo, error) {
+	normalized := strings.TrimSuffix(domain, string(emailDot))
+	if normalized == "" {
+		return emailDomainInfo{}, ErrEmailDomainInvalid
+	}
+	// IP literal: the brackets are kept in both the normalized and ASCII forms.
+	if strings.HasPrefix(normalized, emailIPLiteralPrefix) && strings.HasSuffix(normalized, emailIPLiteralSuffix) {
+		literal := normalized[1 : len(normalized)-1]
+		if ip := parseIPLiteral(literal); ip == nil {
+			return emailDomainInfo{isIPLiteral: true}, ErrEmailDomainInvalid
+		}
+
+		return emailDomainInfo{
+			normalized:  normalized,
+			ascii:       normalized,
+			isIPLiteral: true,
+		}, nil
+	}
+	asciiDomain := normalized
+	if !isASCII(normalized) {
+		if !allowIDN {
+			return emailDomainInfo{}, ErrEmailIDNNotAllowed
+		}
+		converted, err := idna.Lookup.ToASCII(normalized)
+		if err != nil {
+			return emailDomainInfo{}, ErrEmailDomainInvalid
+		}
+		asciiDomain = converted
+	}
+	// normalized keeps the caller's casing; only the ASCII form is lower-cased.
+	asciiDomain = strings.ToLower(asciiDomain)
+
+	return emailDomainInfo{
+		normalized: normalized,
+		ascii:      asciiDomain,
+	}, nil
+}
+
+// parseIPLiteral parses the text inside a bracketed IP-literal domain, first
+// stripping a case-insensitive "ipv6:" tag. The tag is matched on the original
+// bytes so the slice offset cannot drift from a length-changing case mapping.
+func parseIPLiteral(literal string) net.IP {
+	candidate := literal
+	if n := len(emailIPv6LiteralPrefix); len(literal) >= n && strings.EqualFold(literal[:n], emailIPv6LiteralPrefix) {
+		candidate = literal[n:]
+	}
+	return net.ParseIP(candidate) // nil when candidate is not a valid IP address
+}
+
+// isASCII reports whether every byte of value is below utf8.RuneSelf (0x80).
+func isASCII(value string) bool {
+	for i := 0; i < len(value); i++ {
+		if value[i] >= utf8.RuneSelf { // 0x80 itself is already non-ASCII
+			return false
+		}
+	}
+	return true
+}
+
+// validateDomainLabels validates each dot-separated label of domain in order
+// and returns the first error encountered.
+func validateDomainLabels(domain string) error {
+	for label := range strings.SplitSeq(domain, ".") {
+		if err := validateDomainLabel(label); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+// validateDomainLabel checks one DNS label: its length must lie within
+// [emailMinLabelLength, emailMaxLabelLength], it must not start or end with a
+// hyphen, and it may contain only ASCII letters, digits and hyphens.
+func validateDomainLabel(label string) error {
+	size := len(label)
+	if size < emailMinLabelLength || size > emailMaxLabelLength || label[0] == '-' || label[size-1] == '-' {
+		return ErrEmailDomainInvalid
+	}
+
+	for _, ch := range []byte(label) {
+		if !isLabelChar(ch) {
+			return ErrEmailDomainInvalid
+		}
+	}
+
+	return nil
+}
+
+// isLabelChar reports whether ch is an ASCII letter, an ASCII digit or '-'.
+func isLabelChar(ch byte) bool {
+	isDigit := '0' <= ch && ch <= '9'
+	isLower := 'a' <= ch && ch <= 'z'
+	isUpper := 'A' <= ch && ch <= 'Z'
+	return ch == '-' || isDigit || isLower || isUpper
+}
+
+// verifyDomain verifies domain via DNS: a usable MX record first, then, unless
+// MX is required, a host lookup when A-record fallback is enabled.
+func (v *EmailValidator) verifyDomain(ctx context.Context, domain string) (domainVerification, error) {
+	mxRecords, err := v.opts.resolver.LookupMX(ctx, domain)
+	if err == nil && hasValidMX(mxRecords) {
+		return domainVerification{verified: true, byMX: true}, nil
+	}
+	// With MX required there is no fallback: any MX lookup error is wrapped and returned.
+	if v.opts.requireMX {
+		if err != nil {
+			return domainVerification{}, fmt.Errorf("%w: %w", ErrEmailDomainLookupFailed, err)
+		}
+
+		return domainVerification{}, ErrEmailDomainUnverified
+	}
+	if v.opts.allowARecordFallback {
+		hosts, hostErr := v.opts.resolver.LookupHost(ctx, domain)
+		if hostErr == nil && len(hosts) > 0 {
+			return domainVerification{verified: true, byA: true}, nil
+		}
+		// A "not found" host error falls through; any other host error is returned.
+		if hostErr != nil && !isNotFound(hostErr) {
+			return domainVerification{}, fmt.Errorf("%w: %w", ErrEmailDomainLookupFailed, hostErr)
+		}
+	}
+	// An MX error other than "not found" is surfaced instead of "unverified".
+	if err != nil && !isNotFound(err) {
+		return domainVerification{}, fmt.Errorf("%w: %w", ErrEmailDomainLookupFailed, err)
+	}
+	return domainVerification{}, ErrEmailDomainUnverified
+}
+
+// hasValidMX reports whether records contains at least one usable entry. Nil
+// entries and entries whose host, after trimming whitespace, is empty or "."
+// are ignored.
+func hasValidMX(records []*net.MX) bool {
+	for _, mx := range records {
+		if mx == nil {
+			continue
+		}
+
+		switch strings.TrimSpace(mx.Host) {
+		case "", ".":
+			continue
+		default:
+			return true
+		}
+	}
+
+	return false
+}
+
+// isNotFound reports whether err wraps a *net.DNSError flagged IsNotFound.
+func isNotFound(err error) bool {
+	var dnsErr *net.DNSError
+
+	return errors.As(err, &dnsErr) && dnsErr.IsNotFound
+}
diff --git a/pkg/validate/email_test.go b/pkg/validate/email_test.go
new file mode 100644
index 0000000..65036b9
--- /dev/null
+++ b/pkg/validate/email_test.go
@@ -0,0 +1,210 @@
+package validate
+
+import (
+ "context"
+ "net"
+ "testing"
+
+ "golang.org/x/net/idna"
+)
+
+type fakeResolver struct {
+ mxRecords map[string][]*net.MX
+ hosts map[string][]string
+ mxErr map[string]error
+ hostErr map[string]error
+}
+
+func (r *fakeResolver) LookupMX(_ context.Context, name string) ([]*net.MX, error) {
+ if err, ok := r.mxErr[name]; ok {
+ return nil, err
+ }
+ return r.mxRecords[name], nil
+}
+
+func (r *fakeResolver) LookupHost(_ context.Context, name string) ([]string, error) {
+ if err, ok := r.hostErr[name]; ok {
+ return nil, err
+ }
+ return r.hosts[name], nil
+}
+
+func TestEmailValidateBasic(t *testing.T) {
+ validator, err := NewEmailValidator()
+ if err != nil {
+ t.Fatalf("expected validator, got %v", err)
+ }
+
+ result, err := validator.Validate(context.Background(), "user@example.com")
+ if err != nil {
+ t.Fatalf("expected valid email, got %v", err)
+ }
+
+ if result.DomainASCII != "example.com" {
+ t.Fatalf("expected domain ascii, got %s", result.DomainASCII)
+ }
+}
+
+// TestEmailRejectDisplayName checks that a display-name form is rejected by default.
+func TestEmailRejectDisplayName(t *testing.T) {
+	validator, err := NewEmailValidator()
+	if err != nil {
+		t.Fatalf("expected validator, got %v", err)
+	}
+	_, err = validator.Validate(context.Background(), "Name <user@example.com>")
+	if err != ErrEmailDisplayName {
+		t.Fatalf("expected ErrEmailDisplayName, got %v", err)
+	}
+}
+
+// TestEmailAllowDisplayName checks that the bare address is extracted when display names are allowed.
+func TestEmailAllowDisplayName(t *testing.T) {
+	validator, err := NewEmailValidator(WithEmailAllowDisplayName(true))
+	if err != nil {
+		t.Fatalf("expected validator, got %v", err)
+	}
+	result, err := validator.Validate(context.Background(), "Name <user@example.com>")
+	if err != nil {
+		t.Fatalf("expected valid email, got %v", err)
+	}
+
+	if result.Address != "user@example.com" {
+		t.Fatalf("expected normalized address, got %s", result.Address)
+	}
+}
+
+func TestEmailInvalidLocalPart(t *testing.T) {
+ validator, err := NewEmailValidator()
+ if err != nil {
+ t.Fatalf("expected validator, got %v", err)
+ }
+
+ _, err = validator.Validate(context.Background(), "user..dot@example.com")
+ if err != ErrEmailLocalPartInvalid {
+ t.Fatalf("expected ErrEmailLocalPartInvalid, got %v", err)
+ }
+}
+
+func TestEmailRequireTLD(t *testing.T) {
+ validator, err := NewEmailValidator()
+ if err != nil {
+ t.Fatalf("expected validator, got %v", err)
+ }
+
+ _, err = validator.Validate(context.Background(), "user@localhost")
+ if err != ErrEmailDomainInvalid {
+ t.Fatalf("expected ErrEmailDomainInvalid, got %v", err)
+ }
+}
+
+func TestEmailDomainVerificationMX(t *testing.T) {
+ resolver := &fakeResolver{
+ mxRecords: map[string][]*net.MX{
+ "example.com": {{Host: "mx.example.com."}},
+ },
+ }
+
+ validator, err := NewEmailValidator(
+ WithEmailVerifyDomain(true),
+ WithEmailDNSResolver(resolver),
+ )
+ if err != nil {
+ t.Fatalf("expected validator, got %v", err)
+ }
+
+ result, err := validator.Validate(context.Background(), "user@example.com")
+ if err != nil {
+ t.Fatalf("expected valid email, got %v", err)
+ }
+
+ if !result.DomainVerified || !result.VerifiedByMX {
+ t.Fatalf("expected mx verification")
+ }
+}
+
+func TestEmailDomainVerificationFallback(t *testing.T) {
+ resolver := &fakeResolver{
+ hosts: map[string][]string{
+ "example.com": {"203.0.113.10"},
+ },
+ }
+
+ validator, err := NewEmailValidator(
+ WithEmailVerifyDomain(true),
+ WithEmailDNSResolver(resolver),
+ )
+ if err != nil {
+ t.Fatalf("expected validator, got %v", err)
+ }
+
+ result, err := validator.Validate(context.Background(), "user@example.com")
+ if err != nil {
+ t.Fatalf("expected valid email, got %v", err)
+ }
+
+ if !result.DomainVerified || !result.VerifiedByA {
+ t.Fatalf("expected A record verification")
+ }
+}
+
+func TestEmailDomainVerificationUnverified(t *testing.T) {
+ resolver := &fakeResolver{}
+
+ validator, err := NewEmailValidator(
+ WithEmailVerifyDomain(true),
+ WithEmailDNSResolver(resolver),
+ )
+ if err != nil {
+ t.Fatalf("expected validator, got %v", err)
+ }
+
+ _, err = validator.Validate(context.Background(), "user@example.com")
+ if err != ErrEmailDomainUnverified {
+ t.Fatalf("expected ErrEmailDomainUnverified, got %v", err)
+ }
+}
+
+func TestEmailAllowIDN(t *testing.T) {
+ validator, err := NewEmailValidator(WithEmailAllowIDN(true))
+ if err != nil {
+ t.Fatalf("expected validator, got %v", err)
+ }
+
+ result, err := validator.Validate(context.Background(), "user@bücher.example")
+ if err != nil {
+ t.Fatalf("expected valid email, got %v", err)
+ }
+
+ ascii, err := idna.Lookup.ToASCII("bücher.example")
+ if err != nil {
+ t.Fatalf("expected idna conversion, got %v", err)
+ }
+
+ if result.DomainASCII != ascii {
+ t.Fatalf("expected ascii domain %s, got %s", ascii, result.DomainASCII)
+ }
+}
+
+func TestEmailIPLiteralDisallowed(t *testing.T) {
+ validator, err := NewEmailValidator()
+ if err != nil {
+ t.Fatalf("expected validator, got %v", err)
+ }
+
+ _, err = validator.Validate(context.Background(), "user@[127.0.0.1]")
+ if err != ErrEmailIPLiteralNotAllowed {
+ t.Fatalf("expected ErrEmailIPLiteralNotAllowed, got %v", err)
+ }
+}
+
+func TestEmailIPLiteralAllowed(t *testing.T) {
+ validator, err := NewEmailValidator(WithEmailAllowIPLiteral(true))
+ if err != nil {
+ t.Fatalf("expected validator, got %v", err)
+ }
+
+ _, err = validator.Validate(context.Background(), "user@[127.0.0.1]")
+ if err != nil {
+ t.Fatalf("expected valid ip-literal, got %v", err)
+ }
+}
diff --git a/pkg/validate/errors.go b/pkg/validate/errors.go
new file mode 100644
index 0000000..a491af4
--- /dev/null
+++ b/pkg/validate/errors.go
@@ -0,0 +1,61 @@
+package validate
+
+import "github.com/hyp3rd/ewrap"
+
+var (
+ // ErrInvalidEmailConfig indicates that the email validation configuration is invalid.
+ ErrInvalidEmailConfig = ewrap.New("invalid email validation config")
+ // ErrInvalidURLConfig indicates that the URL validation configuration is invalid.
+ ErrInvalidURLConfig = ewrap.New("invalid url validation config")
+ // ErrEmailEmpty indicates that the email is empty.
+ ErrEmailEmpty = ewrap.New("email is empty")
+ // ErrEmailInvalid indicates that the email is invalid.
+ ErrEmailInvalid = ewrap.New("email is invalid")
+ // ErrEmailDisplayName indicates that the email display name is not allowed.
+ ErrEmailDisplayName = ewrap.New("email display name is not allowed")
+ // ErrEmailLocalPartInvalid indicates that the email local part is invalid.
+ ErrEmailLocalPartInvalid = ewrap.New("email local part is invalid")
+ // ErrEmailDomainInvalid indicates that the email domain is invalid.
+ ErrEmailDomainInvalid = ewrap.New("email domain is invalid")
+ // ErrEmailDomainTooLong indicates that the email domain is too long.
+ ErrEmailDomainTooLong = ewrap.New("email domain is too long")
+ // ErrEmailLocalPartTooLong indicates that the email local part is too long.
+ ErrEmailLocalPartTooLong = ewrap.New("email local part is too long")
+ // ErrEmailAddressTooLong indicates that the email address is too long.
+ ErrEmailAddressTooLong = ewrap.New("email address is too long")
+ // ErrEmailQuotedLocalPart indicates that the email quoted local part is not allowed.
+ ErrEmailQuotedLocalPart = ewrap.New("email quoted local part is not allowed")
+ // ErrEmailIPLiteralNotAllowed indicates that the email ip-literal domain is not allowed.
+ ErrEmailIPLiteralNotAllowed = ewrap.New("email ip-literal domain is not allowed")
+ // ErrEmailIDNNotAllowed indicates that the email idn domains are not allowed.
+ ErrEmailIDNNotAllowed = ewrap.New("email idn domains are not allowed")
+ // ErrEmailDomainLookupFailed indicates that the email domain lookup failed.
+ ErrEmailDomainLookupFailed = ewrap.New("email domain lookup failed")
+ // ErrEmailDomainUnverified indicates that the email domain is unverified.
+ ErrEmailDomainUnverified = ewrap.New("email domain is unverified")
+
+ // ErrURLInvalid indicates that the URL is invalid.
+ ErrURLInvalid = ewrap.New("url is invalid")
+ // ErrURLTooLong indicates that the URL is too long.
+ ErrURLTooLong = ewrap.New("url is too long")
+ // ErrURLSchemeNotAllowed indicates that the URL scheme is not allowed.
+ ErrURLSchemeNotAllowed = ewrap.New("url scheme is not allowed")
+ // ErrURLHostMissing indicates that the URL host is required.
+ ErrURLHostMissing = ewrap.New("url host is required")
+ // ErrURLUserInfoNotAllowed indicates that the URL userinfo is not allowed.
+ ErrURLUserInfoNotAllowed = ewrap.New("url userinfo is not allowed")
+ // ErrURLHostNotAllowed indicates that the URL host is not allowed.
+ ErrURLHostNotAllowed = ewrap.New("url host is not allowed")
+ // ErrURLPrivateIPNotAllowed indicates that the URL private IP is not allowed.
+ ErrURLPrivateIPNotAllowed = ewrap.New("url private ip is not allowed")
+ // ErrURLRedirectNotAllowed indicates that URL redirects are not allowed.
+ ErrURLRedirectNotAllowed = ewrap.New("url redirect is not allowed")
+ // ErrURLRedirectLoop indicates that a URL redirect loop was detected.
+ ErrURLRedirectLoop = ewrap.New("url redirect loop detected")
+ // ErrURLRedirectLimit indicates that the URL redirect limit was exceeded.
+ ErrURLRedirectLimit = ewrap.New("url redirect limit exceeded")
+ // ErrURLReputationFailed indicates that the URL reputation check failed.
+ ErrURLReputationFailed = ewrap.New("url reputation check failed")
+ // ErrURLReputationBlocked indicates that the URL reputation check blocked the URL.
+ ErrURLReputationBlocked = ewrap.New("url reputation blocked")
+)
diff --git a/pkg/validate/url.go b/pkg/validate/url.go
new file mode 100644
index 0000000..59c9e23
--- /dev/null
+++ b/pkg/validate/url.go
@@ -0,0 +1,688 @@
+package validate
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "net/http"
+ "net/url"
+ "strings"
+ "time"
+
+ "golang.org/x/net/idna"
+)
+
+const (
+ urlDefaultMaxLength = 2048 // default cap on the trimmed input length, in bytes
+ urlDefaultMaxRedirects = 10 // default redirect budget; WithURLCheckRedirects overrides it
+ urlDefaultTimeout = 5 * time.Second // timeout of the fallback client used for redirect probes
+
+ schemeHTTPS = "https" // the only scheme that option validation accepts
+
+ httpMethodHead = "HEAD" // default probe method for redirect checks
+ httpMethodGet = "GET" // alternative probe method accepted by WithURLRedirectMethod
+
+ redirectStatusMultipleChoices = 300 // this and the five statuses below are what isRedirectStatus accepts
+ redirectStatusMovedPermanently = 301
+ redirectStatusFound = 302
+ redirectStatusSeeOther = 303
+ redirectStatusTemporaryRedirect = 307
+ redirectStatusPermanentRedirect = 308
+)
+
+// URLReputationChecker is the hook URLValidator calls to obtain a reputation verdict for a parsed URL.
+type URLReputationChecker interface {
+ Check(ctx context.Context, target *url.URL) (ReputationResult, error) // a non-nil error fails validation
+}
+
+// URLReputationCheckerFunc lets a plain function satisfy URLReputationChecker.
+type URLReputationCheckerFunc func(ctx context.Context, target *url.URL) (ReputationResult, error)
+
+// Check implements URLReputationChecker by invoking the wrapped function with the same arguments.
+func (fn URLReputationCheckerFunc) Check(ctx context.Context, target *url.URL) (ReputationResult, error) {
+	return fn(ctx, target)
+}
+
+// ReputationVerdict is the outcome a reputation check reports for a URL.
+type ReputationVerdict int
+
+const (
+ // ReputationUnknown is the zero value; it does not cause validation to fail.
+ ReputationUnknown ReputationVerdict = iota
+ // ReputationAllowed indicates an allowed reputation verdict.
+ ReputationAllowed
+ // ReputationBlocked makes URL validation fail with ErrURLReputationBlocked.
+ ReputationBlocked
+)
+
+// ReputationResult is what a URLReputationChecker returns for one URL.
+type ReputationResult struct {
+ Verdict ReputationVerdict // only ReputationBlocked causes validation to fail
+ Reason string // free-form explanation; StaticReputation sets "blocked" or "not allowed"
+}
+
+// StaticReputation is a URLReputationChecker backed by fixed allow and block host sets.
+type StaticReputation struct {
+ allow map[string]struct{} // when non-empty, hosts outside this set are reported as blocked
+ block map[string]struct{} // hosts always reported as blocked; consulted before allow
+}
+
+// NewStaticReputation builds a checker whose host lists are trimmed, lower-cased and stripped of blanks.
+func NewStaticReputation(allowHosts, blockHosts []string) *StaticReputation {
+	checker := new(StaticReputation)
+	checker.allow, checker.block = normalizeHostSet(allowHosts), normalizeHostSet(blockHosts)
+
+	return checker
+}
+
+// Check implements URLReputationChecker: the block list wins, then a non-empty allow list is exclusive.
+func (s *StaticReputation) Check(_ context.Context, target *url.URL) (ReputationResult, error) {
+	if target == nil || target.Hostname() == "" {
+		return ReputationResult{Verdict: ReputationUnknown}, nil
+	}
+
+	host := strings.ToLower(target.Hostname())
+	// "example.com." names the same host as "example.com": try both spellings so a trailing dot cannot sidestep the lists.
+	listed := func(set map[string]struct{}) bool {
+		_, exact := set[host]
+		_, bare := set[strings.TrimSuffix(host, ".")]
+		return exact || bare
+	}
+
+	switch {
+	case listed(s.block):
+		return ReputationResult{Verdict: ReputationBlocked, Reason: "blocked"}, nil
+	case len(s.allow) == 0:
+		return ReputationResult{Verdict: ReputationUnknown}, nil
+	case listed(s.allow):
+		return ReputationResult{Verdict: ReputationAllowed}, nil
+	default:
+		return ReputationResult{Verdict: ReputationBlocked, Reason: "not allowed"}, nil
+	}
+}
+
+// URLOption mutates the validator configuration; a non-nil error aborts NewURLValidator.
+type URLOption func(*urlOptions) error
+
+type urlOptions struct { // internal configuration assembled by NewURLValidator
+ allowedSchemes map[string]struct{} // lower-cased schemes; validateURLOptions requires exactly {"https"}
+ allowUserInfo bool // accept URLs that carry a userinfo component
+ allowIDN bool // accept non-ASCII hosts (converted with idna.Lookup)
+ allowIPLiteral bool // accept hosts that parse as IP addresses
+ allowPrivateIP bool // accept IP hosts that isPrivateIP flags
+ allowLocalhost bool // accept "localhost" and names ending in ".localhost"
+ maxLength int // maximum length, in bytes, of the trimmed input
+ checkRedirects bool // follow redirects during Validate
+ maxRedirects int // redirect budget used by followRedirects
+ redirectMethod string // "HEAD" or "GET"
+ httpClient *http.Client // optional client for redirect probes; nil selects a default with urlDefaultTimeout
+ reputationChecker URLReputationChecker // optional; nil disables reputation checks
+ allowedHosts map[string]struct{} // when non-empty, only these hosts pass
+ blockedHosts map[string]struct{} // hosts that never pass
+}
+
+// URLResult is the value Validate returns for an accepted URL.
+type URLResult struct {
+ NormalizedURL string // the parsed input re-serialized with url.URL.String
+ FinalURL string // last URL reached when redirect checks run; otherwise equal to NormalizedURL
+ Redirects []URLRedirect // hops followed, in order; left unset when redirect checks are off
+ Reputation ReputationResult // NOTE(review): never assigned by Validate in this file — confirm intent
+}
+
+// URLRedirect records one redirect hop observed while following a URL.
+type URLRedirect struct {
+ From string // URL that was requested
+ To string // Location target, resolved against From
+ StatusCode int // HTTP status of the redirecting response
+}
+
+// URLValidator validates URLs and can optionally follow redirects and consult a reputation checker.
+type URLValidator struct {
+ opts urlOptions // configuration built by NewURLValidator
+}
+
+// NewURLValidator builds a URLValidator from the defaults plus the supplied options.
+//
+// Defaults: only the https scheme, a 2048-byte length cap, a redirect budget of 10
+// (redirect checks themselves stay off) and HEAD as the probe method.
+// Nil options are skipped. The first option error, or a failed final consistency
+// check, is returned together with a nil validator.
+func NewURLValidator(opts ...URLOption) (*URLValidator, error) {
+	settings := urlOptions{
+		allowedSchemes: map[string]struct{}{schemeHTTPS: {}},
+		maxLength:      urlDefaultMaxLength,
+		maxRedirects:   urlDefaultMaxRedirects,
+		redirectMethod: httpMethodHead,
+	}
+
+	// Options run in order, so a later option overrides an earlier one.
+	for _, apply := range opts {
+		if apply != nil {
+			if err := apply(&settings); err != nil {
+				return nil, err
+			}
+		}
+	}
+
+	if err := validateURLOptions(&settings); err != nil {
+		return nil, err
+	}
+
+	return &URLValidator{opts: settings}, nil
+}
+
+// WithURLAllowedSchemes replaces the set of accepted schemes.
+// Entries are trimmed and lower-cased; blanks are ignored. Anything other than "https",
+// or a list with no usable entry, yields ErrInvalidURLConfig.
+func WithURLAllowedSchemes(schemes ...string) URLOption {
+	return func(o *urlOptions) error {
+		accepted := map[string]struct{}{}
+
+		for _, raw := range schemes {
+			switch name := strings.ToLower(strings.TrimSpace(raw)); name {
+			case "":
+				// Blank entries are ignored rather than rejected.
+			case schemeHTTPS:
+				accepted[name] = struct{}{}
+			default:
+				return ErrInvalidURLConfig
+			}
+		}
+
+		if len(accepted) == 0 {
+			return ErrInvalidURLConfig
+		}
+
+		o.allowedSchemes = accepted
+
+		return nil
+	}
+}
+
+// WithURLAllowUserInfo toggles acceptance of URLs that carry a userinfo component (off by default).
+func WithURLAllowUserInfo(allow bool) URLOption {
+	return func(o *urlOptions) error {
+		o.allowUserInfo = allow
+
+		return nil
+	}
+}
+
+// WithURLAllowIDN toggles acceptance of non-ASCII (internationalized) hostnames (off by default).
+func WithURLAllowIDN(allow bool) URLOption {
+	return func(o *urlOptions) error {
+		o.allowIDN = allow
+
+		return nil
+	}
+}
+
+// WithURLAllowIPLiteral toggles acceptance of hosts written as IP addresses (off by default).
+func WithURLAllowIPLiteral(allow bool) URLOption {
+	return func(o *urlOptions) error {
+		o.allowIPLiteral = allow
+
+		return nil
+	}
+}
+
+// WithURLAllowPrivateIP toggles acceptance of IP hosts that isPrivateIP flags, such as loopback (off by default).
+func WithURLAllowPrivateIP(allow bool) URLOption {
+	return func(o *urlOptions) error {
+		o.allowPrivateIP = allow
+
+		return nil
+	}
+}
+
+// WithURLAllowLocalhost toggles acceptance of "localhost" and names ending in ".localhost" (off by default).
+func WithURLAllowLocalhost(allow bool) URLOption {
+	return func(o *urlOptions) error {
+		o.allowLocalhost = allow
+
+		return nil
+	}
+}
+
+// WithURLMaxLength caps the accepted URL length in bytes; a limit below 1 yields ErrInvalidURLConfig.
+func WithURLMaxLength(maxLen int) URLOption {
+	return func(o *urlOptions) error {
+		if maxLen < 1 {
+			return ErrInvalidURLConfig
+		}
+
+		o.maxLength = maxLen
+
+		return nil
+	}
+}
+
+// WithURLCheckRedirects turns on redirect following with the given budget; a budget below 1 yields ErrInvalidURLConfig.
+func WithURLCheckRedirects(maxRedirects int) URLOption {
+	return func(o *urlOptions) error {
+		if maxRedirects < 1 {
+			return ErrInvalidURLConfig
+		}
+
+		o.maxRedirects = maxRedirects
+		o.checkRedirects = true
+
+		return nil
+	}
+}
+
+// WithURLRedirectMethod picks the redirect probe method: HEAD or GET, matched case-insensitively after trimming.
+func WithURLRedirectMethod(method string) URLOption {
+	return func(o *urlOptions) error {
+		switch verb := strings.ToUpper(strings.TrimSpace(method)); verb {
+		case httpMethodHead, httpMethodGet:
+			o.redirectMethod = verb
+
+			return nil
+		default:
+			return ErrInvalidURLConfig
+		}
+	}
+}
+
+// WithURLHTTPClient supplies the HTTP client used for redirect probes; nil yields ErrInvalidURLConfig.
+func WithURLHTTPClient(client *http.Client) URLOption {
+	return func(o *urlOptions) error {
+		if client != nil {
+			o.httpClient = client
+
+			return nil
+		}
+
+		return ErrInvalidURLConfig
+	}
+}
+
+// WithURLReputationChecker installs the reputation checker consulted by Validate; nil yields ErrInvalidURLConfig.
+func WithURLReputationChecker(checker URLReputationChecker) URLOption {
+	return func(o *urlOptions) error {
+		if checker != nil {
+			o.reputationChecker = checker
+
+			return nil
+		}
+
+		return ErrInvalidURLConfig
+	}
+}
+
+// WithURLAllowedHosts limits validation to the given hosts (trimmed, lower-cased); an empty list lifts the limit.
+func WithURLAllowedHosts(hosts ...string) URLOption {
+	return func(o *urlOptions) error {
+		o.allowedHosts = normalizeHostSet(hosts)
+
+		return nil
+	}
+}
+
+// WithURLBlockedHosts rejects the given hosts (trimmed, lower-cased), replacing any earlier block list.
+func WithURLBlockedHosts(hosts ...string) URLOption {
+	return func(o *urlOptions) error {
+		o.blockedHosts = normalizeHostSet(hosts)
+
+		return nil
+	}
+}
+
+// Validate checks raw against the configured rules and returns the accepted URL.
+//
+// Steps: trim, reject empty or over-long input, parse, apply the scheme/userinfo/host
+// rules, consult the reputation checker (if set), then follow redirects (if enabled)
+// and re-check the final URL's reputation. Failures return a zero URLResult.
+// NOTE(review): URLResult.Reputation is never filled in here — confirm intent.
+func (v *URLValidator) Validate(ctx context.Context, raw string) (URLResult, error) {
+	input := strings.TrimSpace(raw)
+
+	switch {
+	case input == "":
+		return URLResult{}, ErrURLInvalid
+	case len(input) > v.opts.maxLength:
+		return URLResult{}, ErrURLTooLong
+	}
+
+	target, parseErr := url.Parse(input)
+	if parseErr != nil {
+		return URLResult{}, ErrURLInvalid
+	}
+
+	if err := v.validateParsed(target); err != nil {
+		return URLResult{}, err
+	}
+
+	withReputation := v.opts.reputationChecker != nil
+	if withReputation {
+		if err := v.checkReputation(ctx, target); err != nil {
+			return URLResult{}, err
+		}
+	}
+
+	out := URLResult{NormalizedURL: target.String(), FinalURL: target.String()}
+	if !v.opts.checkRedirects {
+		return out, nil
+	}
+
+	last, hops, err := v.followRedirects(ctx, target)
+	if err != nil {
+		return URLResult{}, err
+	}
+
+	if withReputation {
+		if err := v.checkReputation(ctx, last); err != nil {
+			return URLResult{}, err
+		}
+	}
+
+	out.FinalURL, out.Redirects = last.String(), hops
+
+	return out, nil
+}
+
+// validateURLOptions is the final consistency check run by NewURLValidator.
+//
+// It returns ErrInvalidURLConfig unless all of the following hold:
+//   - allowedSchemes has exactly one entry, and that entry is "https";
+//   - maxLength is positive;
+//   - maxRedirects is positive whenever redirect checks are enabled;
+//   - redirectMethod is HEAD or GET.
+//
+// Every violation maps to the same sentinel, so the order of checks is not observable.
+func validateURLOptions(cfg *urlOptions) error {
+	_, httpsAllowed := cfg.allowedSchemes[schemeHTTPS]
+	methodKnown := cfg.redirectMethod == httpMethodHead || cfg.redirectMethod == httpMethodGet
+
+	switch {
+	// Exactly one scheme, and it must be https.
+	case len(cfg.allowedSchemes) != 1, !httpsAllowed:
+		return ErrInvalidURLConfig
+	case cfg.maxLength <= 0:
+		return ErrInvalidURLConfig
+	case cfg.checkRedirects && cfg.maxRedirects <= 0:
+		return ErrInvalidURLConfig
+	case !methodKnown:
+		return ErrInvalidURLConfig
+	default:
+		return nil
+	}
+}
+
+// normalizeHostSet turns a host list into a lookup set.
+// Each entry is trimmed of surrounding whitespace and lower-cased; entries that
+// end up empty are dropped. The result is never nil.
+func normalizeHostSet(hosts []string) map[string]struct{} {
+	set := map[string]struct{}{}
+
+	for i := range hosts {
+		if name := strings.ToLower(strings.TrimSpace(hosts[i])); name != "" {
+			set[name] = struct{}{}
+		}
+	}
+
+	return set
+}
+
+// validateParsed applies the static rules to a parsed URL.
+// Checks stop at the first failure, in this order: nil guard, scheme, userinfo,
+// host normalization, host allow/block rules, IP-literal rules.
+func (v *URLValidator) validateParsed(parsed *url.URL) error {
+	if parsed == nil {
+		return ErrURLInvalid
+	}
+
+	if err := v.validateScheme(parsed); err != nil {
+		return err
+	}
+
+	if err := v.validateUserInfo(parsed); err != nil {
+		return err
+	}
+
+	hostname, err := v.normalizedHost(parsed)
+	if err != nil {
+		return err
+	}
+
+	if err := v.validateHost(hostname); err != nil {
+		return err
+	}
+
+	return v.validateIPHost(hostname)
+}
+
+// validateScheme requires a non-empty scheme that, once lower-cased, is in the allowed set.
+func (v *URLValidator) validateScheme(parsed *url.URL) error {
+	name := strings.ToLower(parsed.Scheme)
+	if name == "" {
+		return ErrURLInvalid
+	}
+	if _, allowed := v.opts.allowedSchemes[name]; !allowed {
+		return ErrURLSchemeNotAllowed
+	}
+
+	return nil
+}
+
+// validateUserInfo rejects URLs with a userinfo component unless allowUserInfo is set.
+func (v *URLValidator) validateUserInfo(parsed *url.URL) error {
+	if v.opts.allowUserInfo || parsed.User == nil {
+		return nil
+	}
+	return ErrURLUserInfoNotAllowed
+}
+
+// normalizedHost extracts the URL's hostname and normalizes it; an empty hostname yields ErrURLHostMissing.
+func (v *URLValidator) normalizedHost(parsed *url.URL) (string, error) {
+	if name := parsed.Hostname(); name != "" {
+		return normalizeHost(name, v.opts.allowIDN)
+	}
+
+	return "", ErrURLHostMissing
+}
+
+// validateHost rejects localhost names (unless allowed), then applies the host allow/block lists.
+func (v *URLValidator) validateHost(host string) error {
+	if isLocalhost(host) && !v.opts.allowLocalhost {
+		return ErrURLHostNotAllowed
+	}
+	return v.checkHostRestrictions(host)
+}
+
+// validateIPHost applies the IP-literal rules; hosts that do not parse as an IP pass untouched.
+// NOTE(review): only literal IPs are inspected here; what a hostname resolves to is not checked in this function.
+func (v *URLValidator) validateIPHost(host string) error {
+	addr := net.ParseIP(host)
+
+	switch {
+	case addr == nil:
+		return nil
+	case !v.opts.allowIPLiteral:
+		return ErrURLHostNotAllowed
+	case isPrivateIP(addr) && !v.opts.allowPrivateIP:
+		return ErrURLPrivateIPNotAllowed
+	default:
+		return nil
+	}
+}
+
+// normalizeHost canonicalizes a hostname: one trailing dot is removed, non-ASCII names are converted
+// with idna.Lookup when allowIDN is set (rejected otherwise), and the result is lower-cased.
+func normalizeHost(host string, allowIDN bool) (string, error) {
+	name := strings.TrimSuffix(host, ".")
+
+	switch {
+	case name == "":
+		return "", ErrURLHostMissing
+	case isASCII(name):
+		return strings.ToLower(name), nil
+	case !allowIDN:
+		return "", ErrURLHostNotAllowed
+	}
+
+	ascii, err := idna.Lookup.ToASCII(name)
+	if err != nil {
+		return "", ErrURLHostNotAllowed
+	}
+
+	return strings.ToLower(ascii), nil
+}
+
+// checkHostRestrictions enforces the configured host lists on an already normalized host.
+// It returns ErrURLHostNotAllowed for a rejected host and nil otherwise.
+func (v *URLValidator) checkHostRestrictions(host string) error {
+	_, blocked := v.opts.blockedHosts[host]
+	_, allowed := v.opts.allowedHosts[host]
+
+	// A blocked host always fails; an allow list only restricts when it has entries.
+	if blocked || (len(v.opts.allowedHosts) > 0 && !allowed) {
+		return ErrURLHostNotAllowed
+	}
+
+	return nil
+}
+
+// isLocalhost reports whether host is "localhost" or ends in ".localhost".
+// The comparison is case-sensitive; validateHost passes a host that normalizeHost already lower-cased.
+func isLocalhost(host string) bool { return strings.HasSuffix(host, ".localhost") || host == "localhost" }
+
+// isPrivateIP reports whether ip falls in a range this package treats as non-public:
+// loopback, unspecified, private-use, link-local (unicast or multicast) or any multicast
+// address. A nil ip is reported as not private.
+func isPrivateIP(ip net.IP) bool {
+	switch {
+	case ip == nil:
+		return false
+	case ip.IsLoopback(), ip.IsUnspecified(), ip.IsPrivate():
+		return true
+	case ip.IsLinkLocalUnicast(), ip.IsLinkLocalMulticast():
+		return true
+	default:
+		return ip.IsMulticast()
+	}
+}
+
+// followRedirects probes start and each redirect target, validating every hop, and returns the first
+// URL that does not redirect plus the hops taken. The budget counts redirects, not requests: a chain of
+// exactly maxRedirects hops is accepted; one more yields ErrURLRedirectLimit. Revisits yield ErrURLRedirectLoop.
+func (v *URLValidator) followRedirects(ctx context.Context, start *url.URL) (*url.URL, []URLRedirect, error) {
+	if ctx == nil {
+		return nil, nil, ErrURLInvalid
+	}
+
+	client, current := v.httpClient(), start
+	visited := make(map[string]struct{})
+	redirects := make([]URLRedirect, 0)
+
+	for {
+		hopKey := current.String()
+		if _, seen := visited[hopKey]; seen {
+			return nil, nil, ErrURLRedirectLoop
+		}
+		visited[hopKey] = struct{}{}
+
+		nextURL, redirect, err := v.nextRedirect(ctx, client, current)
+		switch {
+		case err != nil:
+			return nil, nil, err
+		case redirect == nil:
+			return current, redirects, nil
+		case len(redirects) >= v.opts.maxRedirects:
+			return nil, nil, ErrURLRedirectLimit
+		}
+
+		redirects = append(redirects, *redirect)
+		current = nextURL
+	}
+}
+
+// nextRedirect sends one probe to current and reports where it redirects, if anywhere.
+//
+// A non-redirect status returns (current, nil, nil). For a redirect, the Location header is
+// resolved against current and must pass validateParsed. Request construction failures map to
+// ErrURLInvalid; transport failures and missing or unparsable Location values map to ErrURLRedirectNotAllowed.
+func (v *URLValidator) nextRedirect(ctx context.Context, client *http.Client, current *url.URL) (*url.URL, *URLRedirect, error) {
+	from := current.String()
+
+	probe, err := http.NewRequestWithContext(ctx, v.opts.redirectMethod, from, nil)
+	if err != nil {
+		return nil, nil, ErrURLInvalid
+	}
+
+	resp, err := client.Do(probe)
+	if err != nil {
+		return nil, nil, ErrURLRedirectNotAllowed
+	}
+
+	// Only the status code and headers are used below; the body is closed unread.
+	//nolint:errcheck
+	_ = resp.Body.Close()
+
+	if !isRedirectStatus(resp.StatusCode) {
+		return current, nil, nil
+	}
+
+	location := resp.Header.Get("Location")
+	if location == "" {
+		return nil, nil, ErrURLRedirectNotAllowed
+	}
+
+	ref, err := url.Parse(location)
+	if err != nil {
+		return nil, nil, ErrURLRedirectNotAllowed
+	}
+
+	target := current.ResolveReference(ref)
+	if err := v.validateParsed(target); err != nil {
+		return nil, nil, err
+	}
+
+	return target, &URLRedirect{From: from, To: target.String(), StatusCode: resp.StatusCode}, nil
+}
+
+// isRedirectStatus reports whether code is an HTTP status that nextRedirect treats
+// as a redirect hop.
+//
+// Accepted: 300, 301, 302, 303, 307 and 308. Any other code (for example 304) is
+// treated as a final response.
+func isRedirectStatus(code int) bool {
+	return code == redirectStatusMultipleChoices ||
+		code == redirectStatusMovedPermanently ||
+		code == redirectStatusFound ||
+		code == redirectStatusSeeOther ||
+		code == redirectStatusTemporaryRedirect ||
+		code == redirectStatusPermanentRedirect
+}
+
+// httpClient returns the client used for redirect probes: a shallow copy of the configured client
+// (or of a default with urlDefaultTimeout) whose CheckRedirect stops at the first response, so each
+// hop surfaces to the caller. The configured client itself is not modified.
+func (v *URLValidator) httpClient() *http.Client {
+	probe := http.Client{Timeout: urlDefaultTimeout}
+	if base := v.opts.httpClient; base != nil {
+		probe = *base
+	}
+
+	probe.CheckRedirect = func(*http.Request, []*http.Request) error { return http.ErrUseLastResponse }
+
+	return &probe
+}
+
+// checkReputation asks the configured checker about target; callers must ensure a checker is set.
+func (v *URLValidator) checkReputation(ctx context.Context, target *url.URL) error {
+	if ctx == nil {
+		return ErrURLInvalid
+	}
+
+	verdict, err := v.opts.reputationChecker.Check(ctx, target)
+	switch {
+	case err != nil:
+		return fmt.Errorf("%w: %w", ErrURLReputationFailed, err)
+	case verdict.Verdict == ReputationBlocked:
+		return ErrURLReputationBlocked
+	default:
+		return nil
+	}
+}
diff --git a/pkg/validate/url_test.go b/pkg/validate/url_test.go
new file mode 100644
index 0000000..a515930
--- /dev/null
+++ b/pkg/validate/url_test.go
@@ -0,0 +1,148 @@
+package validate
+
+import (
+ "context"
+ "io"
+ "net/http"
+ "strings"
+ "testing"
+)
+
+type fakeRoundTripper struct { // canned transport for tests: serves responses from a map
+ responses map[string]*http.Response // keyed by the full request URL string
+}
+
+// RoundTrip returns the canned response for the request URL, or an empty 200 when none is registered.
+func (f *fakeRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
+	if canned, found := f.responses[req.URL.String()]; found {
+		return canned, nil
+	}
+
+	fallback := http.Response{StatusCode: http.StatusOK, Header: http.Header{}}
+	fallback.Body = io.NopCloser(strings.NewReader(""))
+
+	return &fallback, nil
+}
+
+// TestURLValidateBasic checks that a default validator accepts a plain https URL.
+func TestURLValidateBasic(t *testing.T) {
+	v, err := NewURLValidator()
+	if err != nil {
+		t.Fatalf("expected validator, got %v", err)
+	}
+
+	if _, err := v.Validate(context.Background(), "https://example.com/path"); err != nil {
+		t.Fatalf("expected valid url, got %v", err)
+	}
+}
+
+// TestURLRejectUserInfo checks that a URL carrying userinfo is rejected by a default validator.
+func TestURLRejectUserInfo(t *testing.T) {
+	v, err := NewURLValidator()
+	if err != nil {
+		t.Fatalf("expected validator, got %v", err)
+	}
+
+	if _, got := v.Validate(context.Background(), "https://user:pass@example.com"); got != ErrURLUserInfoNotAllowed {
+		t.Fatalf("expected ErrURLUserInfoNotAllowed, got %v", got)
+	}
+}
+
+// TestURLRejectPrivateIP checks that a loopback IP host is rejected even when IP literals are allowed.
+func TestURLRejectPrivateIP(t *testing.T) {
+	v, err := NewURLValidator(WithURLAllowIPLiteral(true))
+	if err != nil {
+		t.Fatalf("expected validator, got %v", err)
+	}
+
+	// IP literals are enabled above, so the rejection must come from the private-range rule.
+	_, got := v.Validate(context.Background(), "https://127.0.0.1")
+	if got != ErrURLPrivateIPNotAllowed {
+		t.Fatalf("expected ErrURLPrivateIPNotAllowed, got %v", got)
+	}
+}
+
+// TestURLRedirectCheck follows a single relative redirect served by a canned transport
+// and expects exactly one recorded hop.
+func TestURLRedirectCheck(t *testing.T) {
+	emptyBody := func() io.ReadCloser { return io.NopCloser(strings.NewReader("")) }
+	transport := &fakeRoundTripper{responses: map[string]*http.Response{
+		"https://example.com/start": {
+			StatusCode: http.StatusFound,
+			Header:     http.Header{"Location": []string{"/final"}},
+			Body:       emptyBody(),
+		},
+		"https://example.com/final": {
+			StatusCode: http.StatusOK,
+			Header:     make(http.Header),
+			Body:       emptyBody(),
+		},
+	}}
+
+	v, err := NewURLValidator(
+		WithURLCheckRedirects(3),
+		WithURLHTTPClient(&http.Client{Transport: transport}),
+	)
+	if err != nil {
+		t.Fatalf("expected validator, got %v", err)
+	}
+
+	res, err := v.Validate(context.Background(), "https://example.com/start")
+	if err != nil {
+		t.Fatalf("expected valid url, got %v", err)
+	}
+
+	// The relative Location must resolve against the start URL and count as one hop.
+	if hops := len(res.Redirects); hops != 1 {
+		t.Fatalf("expected 1 redirect, got %d", hops)
+	}
+}
+
+// TestURLRedirectLoop expects ErrURLRedirectLoop when a URL redirects to itself,
+// using a canned transport so no network traffic is involved.
+func TestURLRedirectLoop(t *testing.T) {
+	selfRedirect := &http.Response{
+		StatusCode: http.StatusFound,
+		Header:     http.Header{"Location": []string{"/loop"}},
+		Body:       io.NopCloser(strings.NewReader("")),
+	}
+	transport := &fakeRoundTripper{
+		responses: map[string]*http.Response{"https://example.com/loop": selfRedirect},
+	}
+
+	v, err := NewURLValidator(
+		WithURLCheckRedirects(2),
+		WithURLHTTPClient(&http.Client{Transport: transport}),
+	)
+	if err != nil {
+		t.Fatalf("expected validator, got %v", err)
+	}
+
+	// The second probe would revisit /loop, which must be reported as a loop.
+	_, got := v.Validate(context.Background(), "https://example.com/loop")
+	if got != ErrURLRedirectLoop {
+		t.Fatalf("expected ErrURLRedirectLoop, got %v", got)
+	}
+}
+
+// TestURLReputationBlock checks that a host on the static block list fails validation.
+func TestURLReputationBlock(t *testing.T) {
+	blocklist := NewStaticReputation(nil, []string{"example.com"})
+
+	v, err := NewURLValidator(WithURLReputationChecker(blocklist))
+	if err != nil {
+		t.Fatalf("expected validator, got %v", err)
+	}
+
+	_, got := v.Validate(context.Background(), "https://example.com")
+	if got != ErrURLReputationBlocked {
+		t.Fatalf("expected ErrURLReputationBlocked, got %v", got)
+	}
+}
+
+// TestURLRejectHTTPWithSchemesOption checks that allowing plain http is a configuration error.
+func TestURLRejectHTTPWithSchemesOption(t *testing.T) {
+	if _, err := NewURLValidator(WithURLAllowedSchemes("http")); err != ErrInvalidURLConfig {
+		t.Fatalf("expected ErrInvalidURLConfig, got %v", err)
+	}
+}