From f45e35b14465fe4714f810c327ed548b88d24ab7 Mon Sep 17 00:00:00 2001 From: Borna Kapusta Date: Tue, 27 Jan 2026 23:09:36 +0100 Subject: [PATCH] feat: add rate limiting support with token bucket algorithm - Configure global and per-route rate limiters - Token bucket with total_rps, per_ip_rps, burst settings - Automatic cleanup of stale per-IP entries - Returns HTTP 429 with Retry-After header when exceeded - New metric: gatekeeper_rate_limited_total{route,limiter,reason} Closes #14 --- AGENTS.md | 35 +++- CHANGELOG.md | 3 + cmd/gatekeeperd/main.go | 39 ++++ config/example.yaml | 19 ++ go.mod | 5 +- go.sum | 6 + internal/config/config.go | 72 ++++++- internal/config/config_test.go | 173 +++++++++++++++ internal/metrics/metrics.go | 14 ++ internal/metrics/metrics_test.go | 6 + internal/proxy/handler.go | 48 +++++ internal/proxy/handler_test.go | 326 +++++++++++++++++++++++++++++ internal/ratelimit/limiter.go | 157 ++++++++++++++ internal/ratelimit/limiter_test.go | 298 ++++++++++++++++++++++++++ internal/ratelimit/set.go | 65 ++++++ internal/ratelimit/set_test.go | 142 +++++++++++++ 16 files changed, 1392 insertions(+), 16 deletions(-) create mode 100644 internal/ratelimit/limiter.go create mode 100644 internal/ratelimit/limiter_test.go create mode 100644 internal/ratelimit/set.go create mode 100644 internal/ratelimit/set_test.go diff --git a/AGENTS.md b/AGENTS.md index 0ca7eb3..3c17258 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -47,12 +47,13 @@ See [docs/CODING_STANDARDS.md](docs/CODING_STANDARDS.md) for: 1. TLS termination (autocert or ingress) 2. Route lookup by hostname and path -3. IP validation against configured allowlist -4. Signature verification using provider-specific algorithm -5. Either: +3. Rate limiting (if configured) - returns 429 with Retry-After header if exceeded +4. IP validation against configured allowlist +5. Signature verification using provider-specific algorithm +6. Either: - Forward to destination (transparent proxy), or - Deliver via relay to waiting relay client -6. Log result with minimal information (IP, path, success/failure) +7. Log result with minimal information (IP, path, success/failure) ### Delivery Modes @@ -109,6 +110,31 @@ In relay mode, the relay client inside the private network initiates an outbound | json_field | Microsoft Graph | Token embedded in JSON body at configurable path | | noop | Testing | Always succeeds | +### Rate Limiting + +Rate limiting protects against abuse using a token bucket algorithm. Configure named limiters and reference them from routes or set a global default. + +```yaml +rate_limiters: + default: + total_rps: 100 # Total requests per second across all IPs + per_ip_rps: 10 # Per client IP (0 = disabled) + burst: 20 # Spike allowance + cleanup_interval: 5m # Stale entry cleanup interval (default: 5m) + idle_timeout: 10m # Remove idle per-IP entries after (default: 10m) + +global: + default_rate_limiter: default # Apply to all routes without explicit limiter + +routes: + - hostname: example.com + path: /webhook + rate_limiter: default # Override or specify per-route + destination: http://backend:8080 +``` + +When rate limited, returns HTTP 429 with `Retry-After: 1` header. Metrics: `gatekeeper_rate_limited_total{route,limiter,reason}` where reason is `total` or `per_ip`. + ### Configuration Loading Configuration can be loaded from file or from environment variables: @@ -155,6 +181,7 @@ These are user-facing interactive wizards, not coding agent instructions. In Cla - Relay client config: internal/relayclient/config.go - Verifier interface: internal/verifier/verifier.go - HTTP handler: internal/proxy/handler.go +- Rate limiter: internal/ratelimit/limiter.go, internal/ratelimit/set.go - Relay manager: internal/relay/manager.go - Redis relay manager: internal/relay/redis_manager.go - Relay handler: internal/relay/handler.go diff --git a/CHANGELOG.md b/CHANGELOG.md index bf167d9..71a48d1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added +- Rate limiting support with token bucket algorithm: configure global and per-route rate limiters with total and per-IP limits, burst allowance, and automatic cleanup of stale entries. Returns HTTP 429 with Retry-After header when exceeded. New metric: `gatekeeper_rate_limited_total{route,limiter,reason}` + ## [0.2.6] - 2026-01-27 ### Added - Microsoft Graph subscription validation handling: automatically responds to `validationToken` query parameter on `json_field` verifier routes, enabling webhook setup without backend involvement (similar to Slack URL verification) diff --git a/cmd/gatekeeperd/main.go b/cmd/gatekeeperd/main.go index eba04c2..f4ff28b 100644 --- a/cmd/gatekeeperd/main.go +++ b/cmd/gatekeeperd/main.go @@ -18,6 +18,7 @@ import ( "github.com/tight-line/gatekeeper/internal/ipfilter" "github.com/tight-line/gatekeeper/internal/metrics" "github.com/tight-line/gatekeeper/internal/proxy" + "github.com/tight-line/gatekeeper/internal/ratelimit" "github.com/tight-line/gatekeeper/internal/relay" "github.com/tight-line/gatekeeper/internal/server" ) @@ -95,6 +96,13 @@ func run() error { logger.Info("debug payloads enabled - request/response bodies will be logged") } + // Setup rate limiters if configured + rateLimiters := buildRateLimiters(cfg, logger) + if rateLimiters != nil { + handler.SetRateLimiters(rateLimiters, cfg.Global.DefaultRateLimiter) + defer rateLimiters.Stop() + } + // Setup relay manager if any routes use relay tokens relayHandler, cleanup, err := setupRelayManager(cfg, handler, logger, *redisURI) if err != nil { @@ -377,6 +385,37 @@ func runHTTPServer(ctx context.Context, addr string, handler http.Handler, logge } } +// buildRateLimiters builds the rate limiter set from config +func buildRateLimiters(cfg *config.Config, logger *slog.Logger) *ratelimit.Set { + if len(cfg.RateLimiters) == 0 { + return nil + } + + limiters := ratelimit.NewSet() + for name, rlCfg := range cfg.RateLimiters { + limiter := ratelimit.New(name, ratelimit.Config{ + TotalRPS: rlCfg.TotalRPS, + PerIPRPS: rlCfg.PerIPRPS, + Burst: rlCfg.Burst, + CleanupInterval: rlCfg.CleanupInterval, + IdleTimeout: rlCfg.IdleTimeout, + }) + limiters.Add(name, limiter) + logger.Info("rate limiter loaded", + "name", name, + "total_rps", rlCfg.TotalRPS, + "per_ip_rps", rlCfg.PerIPRPS, + "burst", rlCfg.Burst, + ) + } + + if cfg.Global.DefaultRateLimiter != "" { + logger.Info("global default rate limiter set", "limiter", cfg.Global.DefaultRateLimiter) + } + + return limiters +} + func init() { fmt.Fprintf(os.Stderr, "gatekeeperd %s\n", version) } diff --git a/config/example.yaml b/config/example.yaml index 05fdbec..7ecaff2 100644 --- a/config/example.yaml +++ b/config/example.yaml @@ -6,6 +6,7 @@ global: acme_cache_dir: "/var/cache/gatekeeper/certs" metrics_port: 9090 log_level: info + # default_rate_limiter: default # Optional: apply this limiter to all routes by default # Named IP allowlists (CIDRs) # Each list can be static or dynamically fetched @@ -68,6 +69,23 @@ verifiers: none: type: noop +# Rate limiters - define named rate limiters that can be applied to routes +# Each limiter uses a token bucket algorithm with total and per-IP limits +rate_limiters: + # Default rate limiter with reasonable limits + default: + total_rps: 100 # Total requests per second across all IPs + per_ip_rps: 10 # Requests per second per client IP (0 = disabled) + burst: 20 # Spike allowance (token bucket capacity) + cleanup_interval: 5m # How often to scan for stale per-IP entries (default: 5m) + idle_timeout: 10m # Remove per-IP limiter after idle time (default: 10m) + + # Strict rate limiter for high-risk endpoints + strict: + total_rps: 10 + per_ip_rps: 2 + burst: 5 + # Proxy routes - each hostname is explicitly enumerated # ACME certs are obtained automatically for each hostname when using -tls flag routes: @@ -76,6 +94,7 @@ routes: path: /events ip_allowlist: aws verifier: avvo-slack + rate_limiter: default # Optional: apply rate limiting to this route destination: http://10.1.1.50:8080/webhooks/slack # Avvo Google Calendar diff --git a/go.mod b/go.mod index 9932e1d..9009116 100644 --- a/go.mod +++ b/go.mod @@ -5,16 +5,18 @@ go 1.25.0 toolchain go1.25.1 require ( + github.com/alicebob/miniredis/v2 v2.36.0 github.com/google/uuid v1.6.0 github.com/itchyny/gojq v0.12.18 github.com/prometheus/client_golang v1.23.2 + github.com/redis/go-redis/v9 v9.17.2 github.com/santhosh-tekuri/jsonschema/v5 v5.3.1 golang.org/x/crypto v0.47.0 + golang.org/x/time v0.14.0 gopkg.in/yaml.v3 v3.0.1 ) require ( - github.com/alicebob/miniredis/v2 v2.36.0 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect @@ -24,7 +26,6 @@ require ( github.com/prometheus/client_model v0.6.2 // indirect github.com/prometheus/common v0.66.1 // indirect github.com/prometheus/procfs v0.16.1 // indirect - github.com/redis/go-redis/v9 v9.17.2 // indirect github.com/yuin/gopher-lua v1.1.1 // indirect go.yaml.in/yaml/v2 v2.4.2 // indirect golang.org/x/net v0.48.0 // indirect diff --git a/go.sum b/go.sum index bd42f08..6cb86d4 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,10 @@ github.com/alicebob/miniredis/v2 v2.36.0 h1:yKczg+ez0bQYsG/PrgqtMMmCfl820RPu27kV github.com/alicebob/miniredis/v2 v2.36.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= @@ -59,6 +63,8 @@ golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= +golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= +golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc= google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/internal/config/config.go b/internal/config/config.go index 3e8e930..07722b8 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -11,20 +11,22 @@ import ( // Config is the root configuration structure type Config struct { - Global GlobalConfig `yaml:"global"` - IPAllowlists map[string]IPAllowlist `yaml:"ip_allowlists"` - Verifiers map[string]VerifierConfig `yaml:"verifiers"` - Validators map[string]ValidatorConfig `yaml:"validators"` - Routes []RouteConfig `yaml:"routes"` + Global GlobalConfig `yaml:"global"` + IPAllowlists map[string]IPAllowlist `yaml:"ip_allowlists"` + Verifiers map[string]VerifierConfig `yaml:"verifiers"` + Validators map[string]ValidatorConfig `yaml:"validators"` + RateLimiters map[string]RateLimiterConfig `yaml:"rate_limiters"` + Routes []RouteConfig `yaml:"routes"` } // GlobalConfig contains global settings type GlobalConfig struct { - ACMEEmail string `yaml:"acme_email"` - ACMECacheDir string `yaml:"acme_cache_dir"` - MetricsPort int `yaml:"metrics_port"` - LogLevel string `yaml:"log_level"` - MaxBodySize int64 `yaml:"max_body_size"` // Maximum request body size in bytes (default: 10MB) + ACMEEmail string `yaml:"acme_email"` + ACMECacheDir string `yaml:"acme_cache_dir"` + MetricsPort int `yaml:"metrics_port"` + LogLevel string `yaml:"log_level"` + MaxBodySize int64 `yaml:"max_body_size"` // Maximum request body size in bytes (default: 10MB) + DefaultRateLimiter string `yaml:"default_rate_limiter"` // Optional default rate limiter for all routes } // DefaultMaxBodySize is the default maximum request body size (10MB) @@ -41,6 +43,15 @@ type IPAllowlist struct { RefreshInterval time.Duration `yaml:"refresh_interval,omitempty"` } +// RateLimiterConfig defines a named rate limiter +type RateLimiterConfig struct { + TotalRPS float64 `yaml:"total_rps"` // Total requests per second across all IPs + PerIPRPS float64 `yaml:"per_ip_rps"` // Requests per second per client IP (0 = disabled) + Burst int `yaml:"burst"` // Burst allowance for spike handling + CleanupInterval time.Duration `yaml:"cleanup_interval"` // How often to scan for stale per-IP entries (default: 5m) + IdleTimeout time.Duration `yaml:"idle_timeout"` // Remove per-IP limiter after idle time (default: 10m) +} + // VerifierConfig defines a webhook signature verifier type VerifierConfig struct { Type string `yaml:"type"` // slack, github, shopify, api_key, hmac, json_field, query_param, header_query_param, noop @@ -83,6 +94,7 @@ type RouteConfig struct { IPAllowlist string `yaml:"ip_allowlist"` Verifier string `yaml:"verifier"` Validator string `yaml:"validator,omitempty"` // Optional payload structure validator + RateLimiter string `yaml:"rate_limiter,omitempty"` // Optional rate limiter (falls back to global default) Destination string `yaml:"destination,omitempty"` // Direct delivery URL RelayToken string `yaml:"relay_token,omitempty"` // Relay delivery token (mutually exclusive with destination) PreserveHost bool `yaml:"preserve_host,omitempty"` // Pass original Host header to destination (default: false) @@ -169,6 +181,9 @@ func (c *Config) Validate() error { if err := c.validateValidators(); err != nil { return err } + if err := c.validateRateLimiters(); err != nil { + return err + } return nil } @@ -211,6 +226,11 @@ func (c *Config) validateRoute(i int, route RouteConfig) error { return fmt.Errorf("route %d: validator %q not found", i, route.Validator) } } + if route.RateLimiter != "" { + if _, ok := c.RateLimiters[route.RateLimiter]; !ok { + return fmt.Errorf("route %d: rate_limiter %q not found", i, route.RateLimiter) + } + } return nil } @@ -357,6 +377,38 @@ func validateValidator(name string, v ValidatorConfig) error { return nil } +// validateRateLimiters checks that all rate limiter configs are valid +func (c *Config) validateRateLimiters() error { + // Validate global default rate limiter reference + if c.Global.DefaultRateLimiter != "" { + if _, ok := c.RateLimiters[c.Global.DefaultRateLimiter]; !ok { + return fmt.Errorf("global: default_rate_limiter %q not found", c.Global.DefaultRateLimiter) + } + } + + // Validate each rate limiter config + for name, rl := range c.RateLimiters { + if err := validateRateLimiter(name, rl); err != nil { + return err + } + } + return nil +} + +// validateRateLimiter validates a single rate limiter configuration +func validateRateLimiter(name string, rl RateLimiterConfig) error { + if rl.TotalRPS <= 0 { + return fmt.Errorf("rate_limiter %q: total_rps must be positive", name) + } + if rl.PerIPRPS < 0 { + return fmt.Errorf("rate_limiter %q: per_ip_rps cannot be negative", name) + } + if rl.Burst <= 0 { + return fmt.Errorf("rate_limiter %q: burst must be positive", name) + } + return nil +} + // GetHostnames returns all unique hostnames configured in routes func (c *Config) GetHostnames() []string { seen := make(map[string]bool) diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 378ef6c..ce77eab 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -958,3 +958,176 @@ func TestValidate_ValidHeaderQueryParamVerifier(t *testing.T) { t.Errorf("unexpected error: %v", err) } } + +func TestValidate_RateLimiter_MissingTotalRPS(t *testing.T) { + cfg := &Config{ + RateLimiters: map[string]RateLimiterConfig{ + "test": { + PerIPRPS: 10, + Burst: 20, + }, + }, + } + + err := cfg.Validate() + if err == nil { + t.Error("expected validation error for rate_limiter without total_rps") + } +} + +func TestValidate_RateLimiter_NegativePerIPRPS(t *testing.T) { + cfg := &Config{ + RateLimiters: map[string]RateLimiterConfig{ + "test": { + TotalRPS: 100, + PerIPRPS: -1, + Burst: 20, + }, + }, + } + + err := cfg.Validate() + if err == nil { + t.Error("expected validation error for rate_limiter with negative per_ip_rps") + } +} + +func TestValidate_RateLimiter_MissingBurst(t *testing.T) { + cfg := &Config{ + RateLimiters: map[string]RateLimiterConfig{ + "test": { + TotalRPS: 100, + PerIPRPS: 10, + }, + }, + } + + err := cfg.Validate() + if err == nil { + t.Error("expected validation error for rate_limiter without burst") + } +} + +func TestValidate_ValidRateLimiter(t *testing.T) { + cfg := &Config{ + RateLimiters: map[string]RateLimiterConfig{ + "test": { + TotalRPS: 100, + PerIPRPS: 10, + Burst: 20, + }, + }, + } + + err := cfg.Validate() + if err != nil { + t.Errorf("unexpected error: %v", err) + } +} + +func TestValidate_RateLimiter_ZeroPerIPRPS_Allowed(t *testing.T) { + cfg := &Config{ + RateLimiters: map[string]RateLimiterConfig{ + "test": { + TotalRPS: 100, + PerIPRPS: 0, // Zero means per-IP limiting disabled + Burst: 20, + }, + }, + } + + err := cfg.Validate() + if err != nil { + t.Errorf("unexpected error (per_ip_rps=0 should be allowed): %v", err) + } +} + +func TestValidate_RouteReferencesInvalidRateLimiter(t *testing.T) { + cfg := &Config{ + Routes: []RouteConfig{ + { + Hostname: "test.example.com", + Path: "/webhook", + Destination: "http://backend:8080", + RateLimiter: "nonexistent", + }, + }, + } + + err := cfg.Validate() + if err == nil { + t.Error("expected validation error for invalid rate_limiter reference") + } +} + +func TestValidate_RouteWithValidRateLimiter(t *testing.T) { + cfg := &Config{ + RateLimiters: map[string]RateLimiterConfig{ + "default": { + TotalRPS: 100, + PerIPRPS: 10, + Burst: 20, + }, + }, + Routes: []RouteConfig{ + { + Hostname: "test.example.com", + Path: "/webhook", + Destination: "http://backend:8080", + RateLimiter: "default", + }, + }, + } + + err := cfg.Validate() + if err != nil { + t.Errorf("unexpected error: %v", err) + } +} + +func TestValidate_GlobalDefaultRateLimiter_NotFound(t *testing.T) { + cfg := &Config{ + Global: GlobalConfig{ + DefaultRateLimiter: "nonexistent", + }, + Routes: []RouteConfig{ + { + Hostname: "test.example.com", + Path: "/webhook", + Destination: "http://backend:8080", + }, + }, + } + + err := cfg.Validate() + if err == nil { + t.Error("expected validation error for invalid global default_rate_limiter") + } +} + +func TestValidate_GlobalDefaultRateLimiter_Valid(t *testing.T) { + cfg := &Config{ + Global: GlobalConfig{ + DefaultRateLimiter: "default", + }, + RateLimiters: map[string]RateLimiterConfig{ + "default": { + TotalRPS: 100, + PerIPRPS: 10, + Burst: 20, + }, + }, + Routes: []RouteConfig{ + { + Hostname: "test.example.com", + Path: "/webhook", + Destination: "http://backend:8080", + }, + }, + } + + err := cfg.Validate() + if err != nil { + t.Errorf("unexpected error: %v", err) + } +} diff --git a/internal/metrics/metrics.go b/internal/metrics/metrics.go index 479001d..bfd6dfd 100644 --- a/internal/metrics/metrics.go +++ b/internal/metrics/metrics.go @@ -139,6 +139,15 @@ var ( }, []string{"token"}, ) + + // RateLimitedTotal counts requests denied by rate limiting + RateLimitedTotal = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "gatekeeper_rate_limited_total", + Help: "Total number of requests denied by rate limiting", + }, + []string{"route", "limiter", "reason"}, + ) ) // Handler returns the Prometheus metrics HTTP handler @@ -207,3 +216,8 @@ func RecordRelayWebhooksPending(token string, count int) { func RecordRelayClientsConnected(token string, count int) { RelayClientsConnected.WithLabelValues(token).Set(float64(count)) } + +// RecordRateLimited records a request denied by rate limiting +func RecordRateLimited(route, limiter, reason string) { + RateLimitedTotal.WithLabelValues(route, limiter, reason).Inc() +} diff --git a/internal/metrics/metrics_test.go b/internal/metrics/metrics_test.go index 280fac8..debb134 100644 --- a/internal/metrics/metrics_test.go +++ b/internal/metrics/metrics_test.go @@ -105,3 +105,9 @@ func TestRecordRelayClientsConnected(t *testing.T) { RecordRelayClientsConnected("token1", 2) RecordRelayClientsConnected("token2", 1) } + +func TestRecordRateLimited(t *testing.T) { + t.Helper() + RecordRateLimited("/webhook", "default", "total") + RecordRateLimited("/webhook", "default", "per_ip") +} diff --git a/internal/proxy/handler.go b/internal/proxy/handler.go index 48aaf7a..5c6ded8 100644 --- a/internal/proxy/handler.go +++ b/internal/proxy/handler.go @@ -20,6 +20,7 @@ import ( gkhttputil "github.com/tight-line/gatekeeper/internal/httputil" "github.com/tight-line/gatekeeper/internal/ipfilter" "github.com/tight-line/gatekeeper/internal/metrics" + "github.com/tight-line/gatekeeper/internal/ratelimit" "github.com/tight-line/gatekeeper/internal/relay" "github.com/tight-line/gatekeeper/internal/validator" "github.com/tight-line/gatekeeper/internal/verifier" @@ -51,6 +52,8 @@ type Handler struct { verifierTypes map[string]string // verifier name -> type (e.g., "slack", "github") validators map[string]validator.Validator filters *ipfilter.FilterSet + rateLimiters *ratelimit.Set + defaultRateLimiter string relay relay.Manager logger *slog.Logger trustXForwardedFor bool @@ -115,6 +118,12 @@ func (h *Handler) SetRelayManager(rm relay.Manager) { h.relay = rm } +// SetRateLimiters sets the rate limiters and default limiter for handling rate limiting +func (h *Handler) SetRateLimiters(limiters *ratelimit.Set, defaultLimiter string) { + h.rateLimiters = limiters + h.defaultRateLimiter = defaultLimiter +} + // buildVerifier creates a verifier from config func buildVerifier(vc config.VerifierConfig) (verifier.Verifier, error) { switch vc.Type { @@ -178,6 +187,10 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } + if !h.checkRateLimit(w, r, ctx) { + return + } + if !h.checkIPAllowlist(w, r, ctx) { return } @@ -230,6 +243,41 @@ func (h *Handler) handleNotFound(w http.ResponseWriter, r *http.Request, ctx *re http.Error(w, "Not Found", http.StatusNotFound) } +func (h *Handler) checkRateLimit(w http.ResponseWriter, r *http.Request, ctx *requestContext) bool { + if h.rateLimiters == nil { + return true + } + + // Determine which limiter to use: route-specific or global default + limiterName := ctx.route.RateLimiter + if limiterName == "" { + limiterName = h.defaultRateLimiter + } + if limiterName == "" { + return true + } + + clientIP := h.getClientIP(r) + allowed, reason := h.rateLimiters.Allow(limiterName, clientIP) + if allowed { + return true + } + + h.logger.Warn("rate limited", + "hostname", ctx.hostname, + "path", r.URL.Path, + "client_ip", clientIP, + "limiter", limiterName, + "reason", reason, + ) + metrics.RecordRateLimited(ctx.route.Path, limiterName, reason) + metrics.RecordRequest(ctx.hostname, ctx.route.Path, "429", time.Since(ctx.start).Seconds()) + + w.Header().Set("Retry-After", "1") + http.Error(w, "Too Many Requests", http.StatusTooManyRequests) + return false +} + func (h *Handler) checkIPAllowlist(w http.ResponseWriter, r *http.Request, ctx *requestContext) bool { if ctx.route.IPAllowlist == "" { return true diff --git a/internal/proxy/handler_test.go b/internal/proxy/handler_test.go index 1eb9563..89c118f 100644 --- a/internal/proxy/handler_test.go +++ b/internal/proxy/handler_test.go @@ -21,6 +21,7 @@ import ( "github.com/tight-line/gatekeeper/internal/config" "github.com/tight-line/gatekeeper/internal/ipfilter" + "github.com/tight-line/gatekeeper/internal/ratelimit" "github.com/tight-line/gatekeeper/internal/relay" "github.com/tight-line/gatekeeper/internal/verifier" ) @@ -2823,3 +2824,328 @@ func TestHandler_MicrosoftGraphValidation_Relay(t *testing.T) { t.Errorf("expected body 'relay-test-token', got %q", rr.Body.String()) } } + +func TestHandler_RateLimiting_NoLimiter(t *testing.T) { + // Without rate limiters configured, requests should pass through + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer backend.Close() + + cfg := &config.Config{ + Routes: []config.RouteConfig{ + { + Hostname: "test.com", + Path: "/webhook", + Destination: backend.URL, + }, + }, + } + + filters := ipfilter.NewFilterSet() + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + handler, err := NewHandler(cfg, filters, logger, HandlerOptions{}) + if err != nil { + t.Fatalf("failed to create handler: %v", err) + } + // No SetRateLimiters call - rate limiting not configured + + // Multiple requests should all succeed + for i := 0; i < 10; i++ { + req := httptest.NewRequest(http.MethodPost, "https://test.com/webhook", strings.NewReader("test")) + req.Host = "test.com" + req.RemoteAddr = "127.0.0.1:12345" + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("request %d: expected 200, got %d", i, rr.Code) + } + } +} + +func TestHandler_RateLimiting_TotalLimit(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer backend.Close() + + cfg := &config.Config{ + Routes: []config.RouteConfig{ + { + Hostname: "test.com", + Path: "/webhook", + Destination: backend.URL, + RateLimiter: "strict", + }, + }, + } + + filters := ipfilter.NewFilterSet() + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + handler, err := NewHandler(cfg, filters, logger, HandlerOptions{}) + if err != nil { + t.Fatalf("failed to create handler: %v", err) + } + + // Create rate limiter set with very strict limits + limiters := ratelimit.NewSet() + defer limiters.Stop() + limiters.Add("strict", ratelimit.New("strict", ratelimit.Config{ + TotalRPS: 1, + PerIPRPS: 0, // Only total limiting + Burst: 1, + })) + handler.SetRateLimiters(limiters, "") + + // First request should succeed + req := httptest.NewRequest(http.MethodPost, "https://test.com/webhook", strings.NewReader("test")) + req.Host = "test.com" + req.RemoteAddr = "127.0.0.1:12345" + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("first request: expected 200, got %d", rr.Code) + } + + // Second request should be rate limited + req = httptest.NewRequest(http.MethodPost, "https://test.com/webhook", strings.NewReader("test")) + req.Host = "test.com" + req.RemoteAddr = "192.168.1.1:12345" // Different IP, but total limit applies + rr = httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusTooManyRequests { + t.Errorf("second request: expected 429, got %d", rr.Code) + } + if rr.Header().Get("Retry-After") != "1" { + t.Errorf("expected Retry-After: 1, got %q", rr.Header().Get("Retry-After")) + } +} + +func TestHandler_RateLimiting_PerIPLimit(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer backend.Close() + + cfg := &config.Config{ + Routes: []config.RouteConfig{ + { + Hostname: "test.com", + Path: "/webhook", + Destination: backend.URL, + RateLimiter: "per-ip", + }, + }, + } + + filters := ipfilter.NewFilterSet() + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + handler, err := NewHandler(cfg, filters, logger, HandlerOptions{}) + if err != nil { + t.Fatalf("failed to create handler: %v", err) + } + + // Burst applies to both total and per-IP equally + // With burst=5 and per_ip_rps=1, each IP gets 5 burst requests before hitting the per-IP limit + // High total RPS ensures total limit refills fast enough to not be a bottleneck + limiters := ratelimit.NewSet() + defer limiters.Stop() + limiters.Add("per-ip", ratelimit.New("per-ip", ratelimit.Config{ + TotalRPS: 10000, // Very high total limit (refills quickly) + PerIPRPS: 1, // Low per-IP limit + Burst: 5, // Allow 5 burst requests per IP + })) + handler.SetRateLimiters(limiters, "") + + // First 5 requests from IP1 should succeed (using burst) + for i := 0; i < 5; i++ { + req := httptest.NewRequest(http.MethodPost, "https://test.com/webhook", strings.NewReader("test")) + req.Host = "test.com" + req.RemoteAddr = "192.168.1.1:12345" + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("IP1 request %d: expected 200, got %d", i+1, rr.Code) + } + } + + // 6th request from IP1 should be rate limited (burst exhausted) + req := httptest.NewRequest(http.MethodPost, "https://test.com/webhook", strings.NewReader("test")) + req.Host = "test.com" + req.RemoteAddr = "192.168.1.1:12345" + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusTooManyRequests { + t.Errorf("IP1 6th request: expected 429, got %d", rr.Code) + } + + // First request from IP2 should succeed (different per-IP limiter with its own burst) + req = httptest.NewRequest(http.MethodPost, "https://test.com/webhook", strings.NewReader("test")) + req.Host = "test.com" + req.RemoteAddr = "192.168.1.2:12345" + rr = httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("IP2 first request: expected 200, got %d", rr.Code) + } +} + +func TestHandler_RateLimiting_GlobalDefault(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer backend.Close() + + cfg := &config.Config{ + Routes: []config.RouteConfig{ + { + Hostname: "test.com", + Path: "/webhook", + Destination: backend.URL, + // No RateLimiter specified - should use global default + }, + }, + } + + filters := ipfilter.NewFilterSet() + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + handler, err := NewHandler(cfg, filters, logger, HandlerOptions{}) + if err != nil { + t.Fatalf("failed to create handler: %v", err) + } + + limiters := ratelimit.NewSet() + defer limiters.Stop() + limiters.Add("default", ratelimit.New("default", ratelimit.Config{ + TotalRPS: 1, + Burst: 1, + })) + handler.SetRateLimiters(limiters, "default") // Set global default + + // First request should succeed + req := httptest.NewRequest(http.MethodPost, "https://test.com/webhook", strings.NewReader("test")) + req.Host = "test.com" + req.RemoteAddr = "127.0.0.1:12345" + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("first request: expected 200, got %d", rr.Code) + } + + // Second request should be rate limited + req = httptest.NewRequest(http.MethodPost, "https://test.com/webhook", strings.NewReader("test")) + req.Host = "test.com" + req.RemoteAddr = "127.0.0.1:12345" + rr = httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusTooManyRequests { + t.Errorf("second request: expected 429, got %d", rr.Code) + } +} + +func TestHandler_RateLimiting_RouteOverridesDefault(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer backend.Close() + + cfg := &config.Config{ + Routes: []config.RouteConfig{ + { + Hostname: "test.com", + Path: "/webhook", + Destination: backend.URL, + RateLimiter: "lenient", // Route-specific limiter + }, + }, + } + + filters := ipfilter.NewFilterSet() + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + handler, err := NewHandler(cfg, filters, logger, HandlerOptions{}) + if err != nil { + t.Fatalf("failed to create handler: %v", err) + } + + limiters := ratelimit.NewSet() + defer limiters.Stop() + // Strict default that would block after 1 request + limiters.Add("default", ratelimit.New("default", ratelimit.Config{ + TotalRPS: 1, + Burst: 1, + })) + // Lenient route-specific limiter + limiters.Add("lenient", ratelimit.New("lenient", ratelimit.Config{ + TotalRPS: 100, + Burst: 10, + })) + handler.SetRateLimiters(limiters, "default") + + // Multiple requests should succeed (using lenient limiter, not default) + for i := 0; i < 5; i++ { + req := httptest.NewRequest(http.MethodPost, "https://test.com/webhook", strings.NewReader("test")) + req.Host = "test.com" + req.RemoteAddr = "127.0.0.1:12345" + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("request %d: expected 200, got %d", i, rr.Code) + } + } +} + +func TestHandler_RateLimiting_NoDefaultNoRoute(t *testing.T) { + // When no default and no route limiter, rate limiting is skipped + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer backend.Close() + + cfg := &config.Config{ + Routes: []config.RouteConfig{ + { + Hostname: "test.com", + Path: "/webhook", + Destination: backend.URL, + // No RateLimiter specified + }, + }, + } + + filters := ipfilter.NewFilterSet() + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + handler, err := NewHandler(cfg, filters, logger, HandlerOptions{}) + if err != nil { + t.Fatalf("failed to create handler: %v", err) + } + + limiters := ratelimit.NewSet() + defer limiters.Stop() + limiters.Add("unused", ratelimit.New("unused", ratelimit.Config{ + TotalRPS: 1, + Burst: 1, + })) + handler.SetRateLimiters(limiters, "") // No global default + + // Multiple requests should succeed (no limiter applied) + for i := 0; i < 10; i++ { + req := httptest.NewRequest(http.MethodPost, "https://test.com/webhook", strings.NewReader("test")) + req.Host = "test.com" + req.RemoteAddr = "127.0.0.1:12345" + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("request %d: expected 200, got %d", i, rr.Code) + } + } +} diff --git a/internal/ratelimit/limiter.go b/internal/ratelimit/limiter.go new file mode 100644 index 0000000..bf9f0be --- /dev/null +++ b/internal/ratelimit/limiter.go @@ -0,0 +1,157 @@ +package ratelimit + +import ( + "sync" + "time" + + "golang.org/x/time/rate" +) + +// DefaultCleanupInterval is the default interval for cleaning up stale per-IP entries +const DefaultCleanupInterval = 5 * time.Minute + +// DefaultIdleTimeout is the default time after which idle per-IP entries are removed +const DefaultIdleTimeout = 10 * time.Minute + +// Config holds rate limiter configuration +type Config struct { + TotalRPS float64 // Total requests per second across all IPs + PerIPRPS float64 // Requests per second per client IP (0 = disabled) + Burst int // Burst allowance for spike handling + CleanupInterval time.Duration // How often to scan for stale per-IP entries + IdleTimeout time.Duration // Remove per-IP limiter after idle time +} + +// ipEntry holds a per-IP rate limiter and its last access time +type ipEntry struct { + limiter *rate.Limiter + lastSeen time.Time +} + +// Limiter implements rate limiting with both total and per-IP limits +type Limiter struct { + name string + config Config + + total *rate.Limiter + + mu sync.Mutex + perIP map[string]*ipEntry + stopCh chan struct{} + stopped bool +} + +// New creates a new rate limiter with the given configuration +func New(name string, cfg Config) *Limiter { + // Apply defaults + if cfg.CleanupInterval <= 0 { + cfg.CleanupInterval = DefaultCleanupInterval + } + if cfg.IdleTimeout <= 0 { + cfg.IdleTimeout = DefaultIdleTimeout + } + + l := &Limiter{ + name: name, + config: cfg, + total: rate.NewLimiter(rate.Limit(cfg.TotalRPS), cfg.Burst), + perIP: make(map[string]*ipEntry), + stopCh: make(chan struct{}), + } + + // Start cleanup goroutine if per-IP limiting is enabled + if cfg.PerIPRPS > 0 { + go l.cleanupLoop() + } + + return l +} + +// Allow checks if a request from the given client IP should be allowed. +// Returns the reason for denial ("total" or "per_ip") if denied, or "" if allowed. +func (l *Limiter) Allow(clientIP string) (allowed bool, reason string) { + // Check total rate limit first + if !l.total.Allow() { + return false, "total" + } + + // Check per-IP rate limit if enabled + if l.config.PerIPRPS > 0 { + limiter := l.getOrCreateIPLimiter(clientIP) + if !limiter.Allow() { + return false, "per_ip" + } + } + + return true, "" +} + +// getOrCreateIPLimiter returns the rate limiter for the given IP, creating one if needed +func (l *Limiter) getOrCreateIPLimiter(clientIP string) *rate.Limiter { + l.mu.Lock() + defer l.mu.Unlock() + + entry, ok := l.perIP[clientIP] + if ok { + entry.lastSeen = time.Now() + return entry.limiter + } + + limiter := rate.NewLimiter(rate.Limit(l.config.PerIPRPS), l.config.Burst) + l.perIP[clientIP] = &ipEntry{ + limiter: limiter, + lastSeen: time.Now(), + } + return limiter +} + +// cleanupLoop periodically removes stale per-IP entries +func (l *Limiter) cleanupLoop() { + ticker := time.NewTicker(l.config.CleanupInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + l.cleanup() + case <-l.stopCh: + return + } + } +} + +// cleanup removes per-IP entries that have been idle for too long +func (l *Limiter) cleanup() { + l.mu.Lock() + defer l.mu.Unlock() + + cutoff := time.Now().Add(-l.config.IdleTimeout) + for ip, entry := range l.perIP { + if entry.lastSeen.Before(cutoff) { + delete(l.perIP, ip) + } + } +} + +// Stop stops the cleanup goroutine +func (l *Limiter) Stop() { + l.mu.Lock() + defer l.mu.Unlock() + + if !l.stopped { + l.stopped = true + close(l.stopCh) + } +} + +// Name returns the limiter's name +func (l *Limiter) Name() string { + return l.name +} + +// PerIPCount returns the current number of tracked per-IP entries (for testing/metrics) +func (l *Limiter) PerIPCount() int { + l.mu.Lock() + defer l.mu.Unlock() + return len(l.perIP) +} diff --git a/internal/ratelimit/limiter_test.go b/internal/ratelimit/limiter_test.go new file mode 100644 index 0000000..d478a15 --- /dev/null +++ b/internal/ratelimit/limiter_test.go @@ -0,0 +1,298 @@ +package ratelimit + +import ( + "testing" + "time" +) + +func TestNew_DefaultValues(t *testing.T) { + l := New("test", Config{ + TotalRPS: 100, + PerIPRPS: 10, + Burst: 20, + }) + defer l.Stop() + + if l.config.CleanupInterval != DefaultCleanupInterval { + t.Errorf("expected default cleanup interval %v, got %v", DefaultCleanupInterval, l.config.CleanupInterval) + } + if l.config.IdleTimeout != DefaultIdleTimeout { + t.Errorf("expected default idle timeout %v, got %v", DefaultIdleTimeout, l.config.IdleTimeout) + } +} + +func TestNew_CustomValues(t *testing.T) { + l := New("test", Config{ + TotalRPS: 100, + PerIPRPS: 10, + Burst: 20, + CleanupInterval: 1 * time.Minute, + IdleTimeout: 2 * time.Minute, + }) + defer l.Stop() + + if l.config.CleanupInterval != 1*time.Minute { + t.Errorf("expected cleanup interval 1m, got %v", l.config.CleanupInterval) + } + if l.config.IdleTimeout != 2*time.Minute { + t.Errorf("expected idle timeout 2m, got %v", l.config.IdleTimeout) + } +} + +func TestLimiter_Name(t *testing.T) { + l := New("my-limiter", Config{ + TotalRPS: 100, + Burst: 10, + }) + defer l.Stop() + + if l.Name() != "my-limiter" { + t.Errorf("expected name 'my-limiter', got %q", l.Name()) + } +} + +func TestLimiter_Allow_TotalLimit(t *testing.T) { + l := New("test", Config{ + TotalRPS: 1, // 1 request per second + Burst: 1, // Only 1 burst + PerIPRPS: 0, // No per-IP limiting + }) + defer l.Stop() + + // First request should be allowed + allowed, reason := l.Allow("192.168.1.1") + if !allowed { + t.Errorf("expected first request to be allowed, got denied with reason: %s", reason) + } + + // Second request should be denied (exceeded burst) + allowed, reason = l.Allow("192.168.1.1") + if allowed { + t.Error("expected second request to be denied due to total limit") + } + if reason != "total" { + t.Errorf("expected reason 'total', got %q", reason) + } +} + +func TestLimiter_Allow_PerIPLimit(t *testing.T) { + l := New("test", Config{ + TotalRPS: 10000, // Very high total limit (refills quickly) + PerIPRPS: 1, // 1 request per second per IP + Burst: 1, // Only 1 burst + }) + defer l.Stop() + + // First request from IP1 should be allowed + allowed, reason := l.Allow("192.168.1.1") + if !allowed { + t.Errorf("expected first request from IP1 to be allowed, got denied with reason: %s", reason) + } + + // Brief pause to let total limiter refill (10000 RPS = 10 tokens/ms) + time.Sleep(1 * time.Millisecond) + + // Second request from IP1 should be denied by per-IP limit + // (total has refilled, but per-IP with 1 RPS hasn't) + allowed, reason = l.Allow("192.168.1.1") + if allowed { + t.Error("expected second request from IP1 to be denied") + } + if reason != "per_ip" { + t.Errorf("expected reason 'per_ip', got %q", reason) + } + + // Brief pause to let total limiter refill again + time.Sleep(1 * time.Millisecond) + + // First request from IP2 should be allowed + // (total has refilled, and IP2 has its own fresh per-IP limiter) + allowed, reason = l.Allow("192.168.1.2") + if !allowed { + t.Errorf("expected first request from IP2 to be allowed, got denied with reason: %s", reason) + } +} + +func TestLimiter_Allow_NoPerIPLimit(t *testing.T) { + l := New("test", Config{ + TotalRPS: 100, + PerIPRPS: 0, // Per-IP limiting disabled + Burst: 100, + }) + defer l.Stop() + + // Multiple requests from same IP should all be allowed (up to total limit) + for i := 0; i < 50; i++ { + allowed, reason := l.Allow("192.168.1.1") + if !allowed { + t.Errorf("request %d: expected to be allowed, got denied with reason: %s", i, reason) + } + } +} + +func TestLimiter_Allow_BurstHandling(t *testing.T) { + l := New("test", Config{ + TotalRPS: 10000, // Very high total limit (refills quickly) + PerIPRPS: 10, + Burst: 5, // Allow 5 requests burst + }) + defer l.Stop() + + // Should allow burst of 5 requests + for i := 0; i < 5; i++ { + allowed, reason := l.Allow("192.168.1.1") + if !allowed { + t.Errorf("burst request %d: expected to be allowed, got denied with reason: %s", i, reason) + } + } + + // Brief pause to let total limiter refill + time.Sleep(1 * time.Millisecond) + + // 6th request should be denied by per-IP limit + // (total has refilled, but per-IP burst is exhausted) + allowed, reason := l.Allow("192.168.1.1") + if allowed { + t.Error("expected 6th request to be denied (burst exceeded)") + } + if reason != "per_ip" { + t.Errorf("expected reason 'per_ip', got %q", reason) + } +} + +func TestLimiter_PerIPCount(t *testing.T) { + l := New("test", Config{ + TotalRPS: 100, + PerIPRPS: 10, + Burst: 10, + CleanupInterval: 1 * time.Hour, // Long interval so no cleanup during test + IdleTimeout: 1 * time.Hour, + }) + defer l.Stop() + + if l.PerIPCount() != 0 { + t.Errorf("expected 0 per-IP entries, got %d", l.PerIPCount()) + } + + l.Allow("192.168.1.1") + if l.PerIPCount() != 1 { + t.Errorf("expected 1 per-IP entry, got %d", l.PerIPCount()) + } + + l.Allow("192.168.1.2") + if l.PerIPCount() != 2 { + t.Errorf("expected 2 per-IP entries, got %d", l.PerIPCount()) + } + + // Same IP should not create new entry + l.Allow("192.168.1.1") + if l.PerIPCount() != 2 { + t.Errorf("expected still 2 per-IP entries, got %d", l.PerIPCount()) + } +} + +func TestLimiter_Cleanup(t *testing.T) { + l := New("test", Config{ + TotalRPS: 100, + PerIPRPS: 10, + Burst: 10, + CleanupInterval: 10 * time.Millisecond, + IdleTimeout: 20 * time.Millisecond, + }) + defer l.Stop() + + // Create some per-IP entries + l.Allow("192.168.1.1") + l.Allow("192.168.1.2") + + if l.PerIPCount() != 2 { + t.Errorf("expected 2 per-IP entries, got %d", l.PerIPCount()) + } + + // Wait for idle timeout + cleanup interval + time.Sleep(50 * time.Millisecond) + + // Entries should be cleaned up + if l.PerIPCount() != 0 { + t.Errorf("expected 0 per-IP entries after cleanup, got %d", l.PerIPCount()) + } +} + +func TestLimiter_Cleanup_ActiveEntriesNotRemoved(t *testing.T) { + l := New("test", Config{ + TotalRPS: 100, + PerIPRPS: 10, + Burst: 10, + CleanupInterval: 10 * time.Millisecond, + IdleTimeout: 50 * time.Millisecond, + }) + defer l.Stop() + + // Create entries + l.Allow("192.168.1.1") + l.Allow("192.168.1.2") + + // Keep one IP active + for i := 0; i < 5; i++ { + time.Sleep(15 * time.Millisecond) + l.Allow("192.168.1.1") // Refresh IP1 + } + + // IP1 should still exist, IP2 should be cleaned up + count := l.PerIPCount() + if count != 1 { + t.Errorf("expected 1 per-IP entry (active one), got %d", count) + } +} + +func TestLimiter_Stop(t *testing.T) { + l := New("test", Config{ + TotalRPS: 100, + PerIPRPS: 10, + Burst: 10, + CleanupInterval: 10 * time.Millisecond, + IdleTimeout: 10 * time.Millisecond, + }) + + // Stop should be idempotent + l.Stop() + l.Stop() // Should not panic +} + +func TestLimiter_NoPerIPCleanupGoroutine(t *testing.T) { + // When PerIPRPS is 0, no cleanup goroutine should be started + l := New("test", Config{ + TotalRPS: 100, + PerIPRPS: 0, // Per-IP limiting disabled + Burst: 10, + }) + defer l.Stop() + + // This shouldn't cause any issues - no cleanup goroutine running + time.Sleep(10 * time.Millisecond) +} + +func TestLimiter_TotalLimitBeforePerIP(t *testing.T) { + // Test that total limit is checked before per-IP limit + l := New("test", Config{ + TotalRPS: 1, // Very low total limit + PerIPRPS: 100, + Burst: 1, + }) + defer l.Stop() + + // First request allowed + allowed, _ := l.Allow("192.168.1.1") + if !allowed { + t.Error("expected first request to be allowed") + } + + // Second request should be denied by TOTAL limit (not per-IP) + allowed, reason := l.Allow("192.168.1.2") // Different IP + if allowed { + t.Error("expected second request to be denied") + } + if reason != "total" { + t.Errorf("expected reason 'total' (total should be checked first), got %q", reason) + } +} diff --git a/internal/ratelimit/set.go b/internal/ratelimit/set.go new file mode 100644 index 0000000..1515458 --- /dev/null +++ b/internal/ratelimit/set.go @@ -0,0 +1,65 @@ +package ratelimit + +import ( + "sync" +) + +// Set manages a collection of named rate limiters +type Set struct { + mu sync.RWMutex + limiters map[string]*Limiter +} + +// NewSet creates a new limiter set +func NewSet() *Set { + return &Set{ + limiters: make(map[string]*Limiter), + } +} + +// Add adds a limiter to the set +func (s *Set) Add(name string, limiter *Limiter) { + s.mu.Lock() + defer s.mu.Unlock() + s.limiters[name] = limiter +} + +// Get returns the limiter with the given name, or nil if not found +func (s *Set) Get(name string) *Limiter { + s.mu.RLock() + defer s.mu.RUnlock() + return s.limiters[name] +} + +// Allow checks if a request is allowed by the named limiter. +// Returns (true, "") if allowed, (false, reason) if denied. +// If the limiter is not found, returns (true, "") (fail open). +func (s *Set) Allow(name, clientIP string) (allowed bool, reason string) { + limiter := s.Get(name) + if limiter == nil { + return true, "" + } + return limiter.Allow(clientIP) +} + +// Stop stops all limiters in the set +func (s *Set) Stop() { + s.mu.Lock() + defer s.mu.Unlock() + + for _, limiter := range s.limiters { + limiter.Stop() + } +} + +// Names returns the names of all limiters in the set +func (s *Set) Names() []string { + s.mu.RLock() + defer s.mu.RUnlock() + + names := make([]string, 0, len(s.limiters)) + for name := range s.limiters { + names = append(names, name) + } + return names +} diff --git a/internal/ratelimit/set_test.go b/internal/ratelimit/set_test.go new file mode 100644 index 0000000..e8c9280 --- /dev/null +++ b/internal/ratelimit/set_test.go @@ -0,0 +1,142 @@ +package ratelimit + +import ( + "sort" + "testing" +) + +func TestNewSet(t *testing.T) { + s := NewSet() + if s == nil { + t.Fatal("expected non-nil Set") + } + if len(s.Names()) != 0 { + t.Errorf("expected empty set, got %d limiters", len(s.Names())) + } +} + +func TestSet_AddAndGet(t *testing.T) { + s := NewSet() + defer s.Stop() + + limiter := New("test", Config{ + TotalRPS: 100, + PerIPRPS: 10, + Burst: 20, + }) + + s.Add("test", limiter) + + got := s.Get("test") + if got == nil { + t.Fatal("expected to get limiter, got nil") + } + if got.Name() != "test" { + t.Errorf("expected limiter name 'test', got %q", got.Name()) + } +} + +func TestSet_Get_NotFound(t *testing.T) { + s := NewSet() + + got := s.Get("nonexistent") + if got != nil { + t.Errorf("expected nil for nonexistent limiter, got %v", got) + } +} + +func TestSet_Allow_LimiterExists(t *testing.T) { + s := NewSet() + defer s.Stop() + + limiter := New("test", Config{ + TotalRPS: 1, + PerIPRPS: 1, + Burst: 1, + }) + s.Add("test", limiter) + + // First request should be allowed + allowed, reason := s.Allow("test", "192.168.1.1") + if !allowed { + t.Errorf("expected first request to be allowed, got denied with reason: %s", reason) + } + + // Second request should be denied + allowed, _ = s.Allow("test", "192.168.1.1") + if allowed { + t.Error("expected second request to be denied") + } +} + +func TestSet_Allow_LimiterNotFound(t *testing.T) { + s := NewSet() + + // Should fail open when limiter not found + allowed, reason := s.Allow("nonexistent", "192.168.1.1") + if !allowed { + t.Errorf("expected request to be allowed (fail open), got denied with reason: %s", reason) + } + if reason != "" { + t.Errorf("expected empty reason, got %q", reason) + } +} + +func TestSet_Names(t *testing.T) { + s := NewSet() + defer s.Stop() + + s.Add("limiter1", New("limiter1", Config{TotalRPS: 100, Burst: 10})) + s.Add("limiter2", New("limiter2", Config{TotalRPS: 100, Burst: 10})) + s.Add("limiter3", New("limiter3", Config{TotalRPS: 100, Burst: 10})) + + names := s.Names() + if len(names) != 3 { + t.Errorf("expected 3 names, got %d", len(names)) + } + + sort.Strings(names) + expected := []string{"limiter1", "limiter2", "limiter3"} + for i, name := range names { + if name != expected[i] { + t.Errorf("expected name %q at index %d, got %q", expected[i], i, name) + } + } +} + +func TestSet_Stop(t *testing.T) { + s := NewSet() + + s.Add("limiter1", New("limiter1", Config{TotalRPS: 100, PerIPRPS: 10, Burst: 10})) + s.Add("limiter2", New("limiter2", Config{TotalRPS: 100, PerIPRPS: 10, Burst: 10})) + + // Stop should not panic and should stop all limiters + s.Stop() +} + +func TestSet_Concurrent(t *testing.T) { + s := NewSet() + defer s.Stop() + + s.Add("test", New("test", Config{ + TotalRPS: 1000, + PerIPRPS: 100, + Burst: 100, + })) + + // Run concurrent Allow calls + done := make(chan struct{}) + for i := 0; i < 10; i++ { + go func(id int) { + for j := 0; j < 100; j++ { + s.Allow("test", "192.168.1.1") + } + done <- struct{}{} + }(i) + } + + // Wait for all goroutines + for i := 0; i < 10; i++ { + <-done + } +}