diff --git a/.github/workflows/dependency-review.yml b/.github/workflows/dependency-review.yml index fcb3d61..bd2d544 100644 --- a/.github/workflows/dependency-review.yml +++ b/.github/workflows/dependency-review.yml @@ -17,7 +17,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Harden the runner (Audit all outbound calls) - uses: step-security/harden-runner@f808768d1510423e83855289c910610ca9b43176 # v2.17.0 + uses: step-security/harden-runner@8d3c67de8e2fe68ef647c8db1e6a09f647780f40 # v2.19.0 with: egress-policy: audit diff --git a/Dockerfile b/Dockerfile index 090bfb5..8a89c24 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ # syntax=docker/dockerfile:1 -FROM --platform=$BUILDPLATFORM golang:1.26.2-alpine3.23@sha256:c2a1f7b2095d046ae14b286b18413a05bb82c9bca9b25fe7ff5efef0f0826166 AS build +FROM --platform=$BUILDPLATFORM golang:1.26.2-alpine3.23@sha256:f85330846cde1e57ca9ec309382da3b8e6ae3ab943d2739500e08c86393a21b1 AS build WORKDIR /application COPY . ./ ARG TARGETOS diff --git a/README.md b/README.md index c873274..6c2c317 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # socket-proxy ## Latest image -- `wollomatic/socket-proxy:1.11.4` / `ghcr.io/wollomatic/socket-proxy:1.11.4` +- `wollomatic/socket-proxy:1.12.0` / `ghcr.io/wollomatic/socket-proxy:1.12.0` - `wollomatic/socket-proxy:1` / `ghcr.io/wollomatic/socket-proxy:1` > [!IMPORTANT] @@ -93,10 +93,12 @@ Use Go's regexp syntax to create the patterns for these parameters. To avoid ins Examples (command-line): + `'-allowGET=/v1\..{1,2}/(version|containers/.*|events.*)'` could be used for allowing access to the docker socket for Traefik v2. + `'-allowHEAD=.*'` allows all HEAD requests. ++ `'-allowGET=/version -allowGET=/_ping'` supports using `-allowGET` multiple times Examples (env variables): + `'SP_ALLOW_GET="/v1\..{1,2}/(version|containers/.*|events.*)"'` could be used for allowing access to the docker socket for Traefik v2. + `'SP_ALLOW_HEAD=".*"'` allows all HEAD requests. ++ `'SP_ALLOW_GET="/version" SP_ALLOW_GET_2="/_ping"'` supports multiple `SP_ALLOW_GET` entries For more information, refer to the [Go regexp documentation](https://golang.org/pkg/regexp/syntax/). @@ -104,6 +106,9 @@ An excellent online regexp tester is [regex101.com](https://regex101.com/). To determine which HTTP requests your client application uses, you could switch socket-proxy to debug log level and look at the log output while allowing all requests in a secure environment. +> [!NOTE] +> Starting with version 1.12.0, the socket-proxy supports using multiple -allow* entries in params, environment, or docker labels. + #### Setting up bind mount restrictions By default, socket-proxy does not restrict bind mounts. If you want to add an additional layer of security by restricting which directories can be used as bind mount sources, you can use the `-allowbindmountfrom` parameter or the `SP_ALLOWBINDMOUNTFROM` environment variable. @@ -135,6 +140,8 @@ services: - docker-proxynet # this should be only restricted to traefik and socket-proxy labels: - 'socket-proxy.allow.get=.*' # allow all GET requests to socket-proxy + - 'socket-proxy.allow.head=/version' # HEAD `/version` requests to socket-proxy + - 'socket-proxy.allow.head.1=/exec' # another HEAD `exec` requests to socket-proxy ``` When this is used, it is not necessary to specify the container in `-allowfrom` as the presence of the allowlist labels will grant corresponding access. @@ -235,7 +242,7 @@ socket-proxy can be configured via command-line parameters or via environment va | `-logjson` | `SP_LOGJSON` | (not set/false) | If set, it enables logging in JSON format. If unset, socket-proxy logs in plain text format. | | `-loglevel` | `SP_LOGLEVEL` | `INFO` | Sets the log level. Accepted values are: `DEBUG`, `INFO`, `WARN`, `ERROR`. | | `-proxyport` | `SP_PROXYPORT` | `2375` | Defines the TCP port the proxy listens to. | -| `-shutdowngracetime` | `SP_SHUTDOWNGRACETIME` | `10` | Defines the time in seconds to wait before forcing the shutdown after SIGTERM or SIGINT (socket-proxy first tries to gracefully shut down the TCP server) | | +| `-shutdowngracetime` | `SP_SHUTDOWNGRACETIME` | `10` | Defines the time in seconds to wait before forcing the shutdown after SIGTERM or SIGINT (socket-proxy first tries to gracefully shut down the TCP server) | | `-socketpath` | `SP_SOCKETPATH` | `/var/run/docker.sock` | Specifies the UNIX socket path to connect to. By default, it connects to the Docker daemon socket. | | `-stoponwatchdog` | `SP_STOPONWATCHDOG` | (not set/false) | If set, socket-proxy will be stopped if the watchdog detects that the unix socket is not available. | | `-watchdoginterval` | `SP_WATCHDOGINTERVAL` | `0` | Check for socket availability every x seconds (disable checks, if not set or value is 0) | @@ -269,6 +276,7 @@ socket-proxy can be configured via command-line parameters or via environment va 1.11 - add per-container allowlists specified by Docker container labels (thanks [@amanda-wee](https://github.com/amanda-wee)) +1.12 - support use of allow* multiple times in env, flag and docker labels (thanks [@qianlongzt](https://github.com/qianlongzt)) ## License diff --git a/cmd/socket-proxy/handlehttprequest.go b/cmd/socket-proxy/handlehttprequest.go index 024b2df..f366135 100644 --- a/cmd/socket-proxy/handlehttprequest.go +++ b/cmd/socket-proxy/handlehttprequest.go @@ -5,6 +5,7 @@ import ( "log/slog" "net" "net/http" + "regexp" "github.com/wollomatic/socket-proxy/internal/config" ) @@ -24,7 +25,7 @@ func handleHTTPRequest(w http.ResponseWriter, r *http.Request) { communicateBlockedRequest(w, r, "method not allowed", http.StatusMethodNotAllowed) return } - if !allowed.MatchString(r.URL.Path) { // path does not match regex -> not allowed + if !matchURL(allowed, r.URL.Path) { // path does not match regex -> not allowed communicateBlockedRequest(w, r, "path not allowed", http.StatusForbidden) return } @@ -40,6 +41,15 @@ func handleHTTPRequest(w http.ResponseWriter, r *http.Request) { socketProxy.ServeHTTP(w, r) // #nosec G704 - Request target is always the specified socket } +func matchURL(allowedURIs []*regexp.Regexp, requestURI string) bool { + for _, allowedURI := range allowedURIs { + if allowedURI.MatchString(requestURI) { + return true + } + } + return false +} + // return the relevant allowlist func determineAllowList(r *http.Request) (config.AllowList, bool) { if cfg.ProxySocketEndpoint == "" { // do not perform this check if we proxy to a unix socket diff --git a/internal/config/config.go b/internal/config/config.go index 9d91f69..f929bf1 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -26,7 +26,7 @@ import ( const allowedDockerLabelPrefix = "socket-proxy.allow." -var ( +const ( defaultAllowFrom = "127.0.0.1/32" // allowed IPs to connect to the proxy defaultAllowHealthcheck = false // allow health check requests (HEAD http://localhost:55555/health) defaultLogJSON = false // if true, log in JSON format @@ -67,130 +67,142 @@ type AllowListRegistry struct { } type AllowList struct { - ID string // Container ID (empty for the default allowlist) - AllowedRequests map[string]*regexp.Regexp // map of request methods to request path regex patterns (no requests allowed if empty) - AllowedBindMounts []string // list of from portion of allowed bind mounts (all bind mounts allowed if empty) + ID string // Container ID (empty for the default allowlist) + AllowedRequests map[string][]*regexp.Regexp // map of request methods to request path regex patterns (no requests allowed if empty) + AllowedBindMounts []string // list of from portion of allowed bind mounts (all bind mounts allowed if empty) } // used for list of allowed requests type methodRegex struct { - method string - regexStringFromEnv string - regexStringFromParam string + method string + regexStrings arrayParams } -// mr is the allowlist of requests per http method -// default: regexStringFromEnv and regexStringFromParam are empty, so regexCompiled stays nil and the request is blocked -// if regexStringParam is set with a command line parameter, all requests matching the method and path matching the regex are allowed -// else if regexStringEnv from Environment ist checked -var mr = []methodRegex{ - {method: http.MethodGet}, - {method: http.MethodHead}, - {method: http.MethodPost}, - {method: http.MethodPut}, - {method: http.MethodPatch}, - {method: http.MethodDelete}, - {method: http.MethodConnect}, - {method: http.MethodTrace}, - {method: http.MethodOptions}, +var supportedHTTPMethods = []string{ + http.MethodGet, + http.MethodHead, + http.MethodPost, + http.MethodPut, + http.MethodPatch, + http.MethodDelete, + http.MethodConnect, + http.MethodTrace, + http.MethodOptions, } +// InitConfig reads configuration from environment variables and command-line +// flags, validates the resulting values, and returns the initialized Config. func InitConfig() (*Config, error) { var ( - cfg Config - allowFromString string - listenIP string - proxyPort uint - logLevel string - endpointFileMode uint - allowBindMountFromString string + cfg Config + allowFromString string + listenIP string + proxyPort uint + logLevel string + endpointFileMode uint + allowBindMountFromString string + defaultAllowFromValue = defaultAllowFrom + defaultAllowHealthcheckValue = defaultAllowHealthcheck + defaultLogJSONValue = defaultLogJSON + defaultListenIPValue = defaultListenIP + defaultLogLevelValue = defaultLogLevel + defaultProxyPortValue = defaultProxyPort + defaultShutdownGraceTimeValue = defaultShutdownGraceTime + defaultSocketPathValue = defaultSocketPath + defaultStopOnWatchdogValue = defaultStopOnWatchdog + defaultWatchdogIntervalValue = defaultWatchdogInterval + defaultProxySocketEndpointValue = defaultProxySocketEndpoint + defaultProxySocketEndpointFileModeValue = defaultProxySocketEndpointFileMode + defaultAllowBindMountFromValue = defaultAllowBindMountFrom + defaultProxyContainerNameValue = defaultProxyContainerName ) if val, ok := os.LookupEnv("SP_ALLOWFROM"); ok && val != "" { - defaultAllowFrom = val + defaultAllowFromValue = val } if val, ok := os.LookupEnv("SP_ALLOWHEALTHCHECK"); ok { if parsedVal, err := strconv.ParseBool(val); err == nil { - defaultAllowHealthcheck = parsedVal + defaultAllowHealthcheckValue = parsedVal } } if val, ok := os.LookupEnv("SP_LOGJSON"); ok { if parsedVal, err := strconv.ParseBool(val); err == nil { - defaultLogJSON = parsedVal + defaultLogJSONValue = parsedVal } } if val, ok := os.LookupEnv("SP_LISTENIP"); ok && val != "" { - defaultListenIP = val + defaultListenIPValue = val } if val, ok := os.LookupEnv("SP_LOGLEVEL"); ok && val != "" { - defaultLogLevel = val + defaultLogLevelValue = val } if val, ok := os.LookupEnv("SP_PROXYPORT"); ok && val != "" { if parsedVal, err := strconv.ParseUint(val, 10, 32); err == nil { - defaultProxyPort = uint(parsedVal) + defaultProxyPortValue = uint(parsedVal) } } if val, ok := os.LookupEnv("SP_SHUTDOWNGRACETIME"); ok && val != "" { if parsedVal, err := strconv.ParseUint(val, 10, 32); err == nil { - defaultShutdownGraceTime = uint(parsedVal) + defaultShutdownGraceTimeValue = uint(parsedVal) } } if val, ok := os.LookupEnv("SP_SOCKETPATH"); ok && val != "" { - defaultSocketPath = val + defaultSocketPathValue = val } if val, ok := os.LookupEnv("SP_STOPONWATCHDOG"); ok { if parsedVal, err := strconv.ParseBool(val); err == nil { - defaultStopOnWatchdog = parsedVal + defaultStopOnWatchdogValue = parsedVal } } if val, ok := os.LookupEnv("SP_WATCHDOGINTERVAL"); ok && val != "" { if parsedVal, err := strconv.ParseUint(val, 10, 32); err == nil { - defaultWatchdogInterval = uint(parsedVal) + defaultWatchdogIntervalValue = uint(parsedVal) } } if val, ok := os.LookupEnv("SP_PROXYSOCKETENDPOINT"); ok && val != "" { - defaultProxySocketEndpoint = val + defaultProxySocketEndpointValue = val } if val, ok := os.LookupEnv("SP_PROXYSOCKETENDPOINTFILEMODE"); ok { if parsedVal, err := strconv.ParseUint(val, 8, 32); err == nil { - defaultProxySocketEndpointFileMode = uint(parsedVal) + defaultProxySocketEndpointFileModeValue = uint(parsedVal) } } if val, ok := os.LookupEnv("SP_ALLOWBINDMOUNTFROM"); ok && val != "" { - defaultAllowBindMountFrom = val + defaultAllowBindMountFromValue = val } if val, ok := os.LookupEnv("SP_PROXYCONTAINERNAME"); ok && val != "" { - defaultProxyContainerName = val + defaultProxyContainerNameValue = val } - for i := range mr { - if val, ok := os.LookupEnv("SP_ALLOW_" + mr[i].method); ok && val != "" { - mr[i].regexStringFromEnv = val + methodAllowLists := newMethodRegexes() + + // multiple values per method + // like SP_ALLOW_GET_0, SP_ALLOW_GET_1, ... + allowFromEnv := getAllowFromEnv(os.Environ()) + for i := range methodAllowLists { + if val, ok := allowFromEnv[methodAllowLists[i].method]; ok && len(val) > 0 { + for _, v := range val { + methodAllowLists[i].regexStrings = append(methodAllowLists[i].regexStrings, param{value: v, from: fromEnv}) + } } } - flag.StringVar(&allowFromString, "allowfrom", defaultAllowFrom, "allowed IPs or hostname to connect to the proxy") - flag.BoolVar(&cfg.AllowHealthcheck, "allowhealthcheck", defaultAllowHealthcheck, "allow health check requests (HEAD http://localhost:55555/health)") - flag.BoolVar(&cfg.LogJSON, "logjson", defaultLogJSON, "log in JSON format (otherwise log in plain text") - flag.StringVar(&listenIP, "listenip", defaultListenIP, "ip address to listen on") - flag.StringVar(&logLevel, "loglevel", defaultLogLevel, "set log level: DEBUG, INFO, WARN, ERROR") - flag.UintVar(&proxyPort, "proxyport", defaultProxyPort, "tcp port to listen on") - flag.UintVar(&cfg.ShutdownGraceTime, "shutdowngracetime", defaultShutdownGraceTime, "maximum time in seconds to wait for the server to shut down gracefully") - if cfg.ShutdownGraceTime > math.MaxInt { - return nil, fmt.Errorf("shutdowngracetime has to be smaller than %d", math.MaxInt) // this maximum value has no practical significance - } - flag.StringVar(&cfg.SocketPath, "socketpath", defaultSocketPath, "unix socket path to connect to") - flag.BoolVar(&cfg.StopOnWatchdog, "stoponwatchdog", defaultStopOnWatchdog, "stop the program when the socket gets unavailable (otherwise log only)") - flag.UintVar(&cfg.WatchdogInterval, "watchdoginterval", defaultWatchdogInterval, "watchdog interval in seconds (0 to disable)") - if cfg.WatchdogInterval > math.MaxInt { - return nil, fmt.Errorf("watchdoginterval has to be smaller than %d", math.MaxInt) // this maximum value has no practical significance - } - flag.StringVar(&cfg.ProxySocketEndpoint, "proxysocketendpoint", defaultProxySocketEndpoint, "unix socket endpoint (if set, used instead of the TCP listener)") - flag.UintVar(&endpointFileMode, "proxysocketendpointfilemode", defaultProxySocketEndpointFileMode, "set the file mode of the unix socket endpoint") - flag.StringVar(&allowBindMountFromString, "allowbindmountfrom", defaultAllowBindMountFrom, "allowed directories for bind mounts (comma-separated)") - flag.StringVar(&cfg.ProxyContainerName, "proxycontainername", defaultProxyContainerName, "socket-proxy Docker container name") - for i := range mr { - flag.StringVar(&mr[i].regexStringFromParam, "allow"+mr[i].method, "", "regex for "+mr[i].method+" requests (not set means method is not allowed)") + flag.StringVar(&allowFromString, "allowfrom", defaultAllowFromValue, "allowed IPs or hostname to connect to the proxy") + flag.BoolVar(&cfg.AllowHealthcheck, "allowhealthcheck", defaultAllowHealthcheckValue, "allow health check requests (HEAD http://localhost:55555/health)") + flag.BoolVar(&cfg.LogJSON, "logjson", defaultLogJSONValue, "log in JSON format (otherwise log in plain text") + flag.StringVar(&listenIP, "listenip", defaultListenIPValue, "ip address to listen on") + flag.StringVar(&logLevel, "loglevel", defaultLogLevelValue, "set log level: DEBUG, INFO, WARN, ERROR") + flag.UintVar(&proxyPort, "proxyport", defaultProxyPortValue, "tcp port to listen on") + flag.UintVar(&cfg.ShutdownGraceTime, "shutdowngracetime", defaultShutdownGraceTimeValue, "maximum time in seconds to wait for the server to shut down gracefully") + flag.StringVar(&cfg.SocketPath, "socketpath", defaultSocketPathValue, "unix socket path to connect to") + flag.BoolVar(&cfg.StopOnWatchdog, "stoponwatchdog", defaultStopOnWatchdogValue, "stop the program when the socket gets unavailable (otherwise log only)") + flag.UintVar(&cfg.WatchdogInterval, "watchdoginterval", defaultWatchdogIntervalValue, "watchdog interval in seconds (0 to disable)") + flag.StringVar(&cfg.ProxySocketEndpoint, "proxysocketendpoint", defaultProxySocketEndpointValue, "unix socket endpoint (if set, used instead of the TCP listener)") + flag.UintVar(&endpointFileMode, "proxysocketendpointfilemode", defaultProxySocketEndpointFileModeValue, "set the file mode of the unix socket endpoint") + flag.StringVar(&allowBindMountFromString, "allowbindmountfrom", defaultAllowBindMountFromValue, "allowed directories for bind mounts (comma-separated)") + flag.StringVar(&cfg.ProxyContainerName, "proxycontainername", defaultProxyContainerNameValue, "socket-proxy Docker container name") + for i := range methodAllowLists { + flag.Var(&methodAllowLists[i].regexStrings, "allow"+methodAllowLists[i].method, "regex for "+methodAllowLists[i].method+" requests (not set means method is not allowed)") } flag.Parse() @@ -213,6 +225,12 @@ func InitConfig() (*Config, error) { if proxyPort < 1 || proxyPort > 65535 { return nil, errors.New("port number has to be between 1 and 65535") } + if cfg.ShutdownGraceTime > math.MaxInt { + return nil, fmt.Errorf("shutdowngracetime has to be smaller than %d", math.MaxInt) // this maximum value has no practical significance + } + if cfg.WatchdogInterval > math.MaxInt { + return nil, fmt.Errorf("watchdoginterval has to be smaller than %d", math.MaxInt) // this maximum value has no practical significance + } ip := net.ParseIP(listenIP) if ip == nil { return nil, fmt.Errorf("invalid IP \"%s\" for listenip", listenIP) @@ -245,20 +263,23 @@ func InitConfig() (*Config, error) { cfg.ProxySocketEndpointFileMode = os.FileMode(uint32(endpointFileMode)) // compile regexes for default allowed requests - cfg.AllowLists.Default.AllowedRequests = make(map[string]*regexp.Regexp) - for _, rx := range mr { - if rx.regexStringFromParam != "" { - r, err := compileRegexp(rx.regexStringFromParam, rx.method, "command line parameter") - if err != nil { - return nil, err - } - cfg.AllowLists.Default.AllowedRequests[rx.method] = r - } else if rx.regexStringFromEnv != "" { - r, err := compileRegexp(rx.regexStringFromEnv, rx.method, "env variable") - if err != nil { - return nil, err + cfg.AllowLists.Default.AllowedRequests = make(map[string][]*regexp.Regexp) + for _, rx := range methodAllowLists { + for _, regexString := range effectiveMethodParams(rx.regexStrings) { + if regexString.value != "" { + location := "" + switch regexString.from { + case fromEnv: + location = "env variable" + case fromParam: + location = "command line parameter" + } + r, err := compileRegexp(regexString.value, rx.method, location) + if err != nil { + return nil, err + } + cfg.AllowLists.Default.AllowedRequests[rx.method] = append(cfg.AllowLists.Default.AllowedRequests[rx.method], r) } - cfg.AllowLists.Default.AllowedRequests[rx.method] = r } } @@ -287,7 +308,12 @@ func (cfg *Config) UpdateAllowLists() { slog.Error("failed to create Docker client", "error", err) return } - defer dockerClient.Close() + defer func(dockerClient *client.Client) { + err := dockerClient.Close() + if err != nil { + slog.Error("failed to close Docker client", "error", err) + } + }(dockerClient) err = cfg.AllowLists.initByIP(ctx, dockerClient) if err != nil { @@ -575,6 +601,25 @@ func compileRegexp(regex, method, configLocation string) (*regexp.Regexp, error) return r, nil } +// newMethodRegexes returns one methodRegex entry for each supported HTTP method. +func newMethodRegexes() []methodRegex { + methods := make([]methodRegex, 0, len(supportedHTTPMethods)) + for _, method := range supportedHTTPMethods { + methods = append(methods, methodRegex{method: method}) + } + return methods +} + +// effectiveMethodParams returns the parameters that should be applied for one +// HTTP method, preferring command-line values over environment values when both +// are present. +func effectiveMethodParams(params arrayParams) []param { + if slices.ContainsFunc(params, func(p param) bool { return p.from == fromParam }) { + return slices.DeleteFunc(slices.Clone(params), func(p param) bool { return p.from == fromEnv }) + } + return params +} + // parse bind mount from string into list of allowed bind mounts func parseAllowedBindMounts(allowBindMountFromString string) ([]string, error) { allowedBindMounts := strings.Split(allowBindMountFromString, ",") @@ -612,7 +657,12 @@ func getSocketProxyContainerSummary(socketPath, proxyContainerName string) (cont if err != nil { return container.Summary{}, err } - defer dockerClient.Close() + defer func(dockerClient *client.Client) { + err := dockerClient.Close() + if err != nil { + slog.Error("failed to close Docker client", "error", err) + } + }(dockerClient) ctx := context.Background() filter := filters.NewArgs() @@ -634,18 +684,19 @@ func getSocketProxyContainerSummary(socketPath, proxyContainerName string) (cont } // extract Docker container allowlist label data from the container summary -func extractLabelData(cntr container.Summary) (map[string]*regexp.Regexp, []string, error) { - allowedRequests := make(map[string]*regexp.Regexp) +func extractLabelData(cntr container.Summary) (map[string][]*regexp.Regexp, []string, error) { + allowedRequests := make(map[string][]*regexp.Regexp) var allowedBindMounts []string for labelName, labelValue := range cntr.Labels { if strings.HasPrefix(labelName, allowedDockerLabelPrefix) && labelValue != "" { allowSpec := strings.ToUpper(strings.TrimPrefix(labelName, allowedDockerLabelPrefix)) - if slices.ContainsFunc(mr, func(rx methodRegex) bool { return rx.method == allowSpec }) { - r, err := compileRegexp(labelValue, allowSpec, "docker container label") + method, _, _ := strings.Cut(allowSpec, ".") + if slices.Contains(supportedHTTPMethods, method) { + r, err := compileRegexp(labelValue, method, "docker container label") if err != nil { return nil, nil, err } - allowedRequests[allowSpec] = r + allowedRequests[method] = append(allowedRequests[method], r) } else if allowSpec == "BINDMOUNTFROM" { var err error allowedBindMounts, err = parseAllowedBindMounts(labelValue) diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..116408a --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,177 @@ +package config + +import ( + "flag" + "math" + "os" + "reflect" + "regexp" + "sort" + "strconv" + "testing" + + "github.com/wollomatic/socket-proxy/internal/docker/api/types/container" +) + +func resetFlagsForTest(t *testing.T, args []string) func() { + t.Helper() + + prevCommandLine := flag.CommandLine + prevArgs := os.Args + + flag.CommandLine = flag.NewFlagSet(args[0], flag.ContinueOnError) + flag.CommandLine.SetOutput(os.Stderr) + os.Args = args + + return func() { + flag.CommandLine = prevCommandLine + os.Args = prevArgs + } +} + +func Test_extractLabelData(t *testing.T) { + tests := []struct { + name string // description of this test case + // Named input parameters for target function. + cntr container.Summary + want map[string][]*regexp.Regexp + want2 []string + wantErr bool + }{ + { + name: "valid labels with multiple methods and regexes", + cntr: container.Summary{ + Labels: map[string]string{ + "socket-proxy.allow.get.0": "regex1", + "socket-proxy.allow.get.1": "regex2", + "socket-proxy.allow.post": "regex3", + }, + }, + want: map[string][]*regexp.Regexp{ + "GET": {regexp.MustCompile("^regex1$"), regexp.MustCompile("^regex2$")}, + "POST": {regexp.MustCompile("^regex3$")}, + }, + want2: nil, + wantErr: false, + }, + { + name: "invalid regex in label value", + cntr: container.Summary{ + Labels: map[string]string{ + "socket-proxy.allow.get": "invalid[regex", + }, + }, + want: nil, + want2: nil, + wantErr: true, + }, + { + name: "non-allow labels are ignored", + cntr: container.Summary{ + Labels: map[string]string{ + "socket-proxy.allow.get": "regex1", + "other.label": "value", + }, + }, + want: map[string][]*regexp.Regexp{ + "GET": {regexp.MustCompile("^regex1$")}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, got2, gotErr := extractLabelData(tt.cntr) + if gotErr != nil { + if !tt.wantErr { + t.Errorf("extractLabelData() failed: %v", gotErr) + } + return + } + if tt.wantErr { + t.Fatal("extractLabelData() succeeded unexpectedly") + } + if !regexMapsEqual(got, tt.want) { + t.Errorf("extractLabelData() = %v, want %v", got, tt.want) + } + if !reflect.DeepEqual(got2, tt.want2) { + t.Errorf("extractLabelData() = %v, want %v", got2, tt.want2) + } + }) + } +} + +func regexMapsEqual(a, b map[string][]*regexp.Regexp) bool { + if len(a) != len(b) { + return false + } + for method, aRegexes := range a { + bRegexes, ok := b[method] + if !ok || len(aRegexes) != len(bRegexes) { + return false + } + aRegexStrings := make([]string, 0, len(aRegexes)) + for _, ar := range aRegexes { + aRegexStrings = append(aRegexStrings, ar.String()) + } + bRegexStrings := make([]string, 0, len(bRegexes)) + for _, br := range bRegexes { + bRegexStrings = append(bRegexStrings, br.String()) + } + sort.Strings(aRegexStrings) + sort.Strings(bRegexStrings) + for i, ar := range aRegexStrings { + if ar != bRegexStrings[i] { + return false + } + } + } + return true +} + +func TestInitConfig_AllowMethodFlagOverridesEnv(t *testing.T) { + t.Setenv("SP_ALLOW_GET", "/from-env") + restore := resetFlagsForTest(t, []string{"socket-proxy", "-allowGET=/from-flag"}) + defer restore() + + cfg, err := InitConfig() + if err != nil { + t.Fatalf("InitConfig() error = %v", err) + } + + regexes := cfg.AllowLists.Default.AllowedRequests["GET"] + if len(regexes) != 1 { + t.Fatalf("expected 1 GET regex, got %d", len(regexes)) + } + if !regexes[0].MatchString("/from-flag") { + t.Fatalf("expected GET regex to match /from-flag, got %q", regexes[0].String()) + } + if regexes[0].MatchString("/from-env") { + t.Fatalf("expected env GET regex to be ignored when flag is present, got %q", regexes[0].String()) + } +} + +func TestInitConfig_ShutdownGraceTimeTooLarge(t *testing.T) { + restore := resetFlagsForTest(t, []string{ + "socket-proxy", + "-shutdowngracetime=" + strconv.FormatUint(uint64(math.MaxInt)+1, 10), + }) + defer restore() + + _, err := InitConfig() + if err == nil { + t.Fatal("InitConfig() unexpectedly succeeded") + } +} + +func TestInitConfig_WatchdogIntervalTooLarge(t *testing.T) { + restore := resetFlagsForTest(t, []string{ + "socket-proxy", + "-watchdoginterval=" + strconv.FormatUint(uint64(math.MaxInt)+1, 10), + }) + defer restore() + + _, err := InitConfig() + if err == nil { + t.Fatal("InitConfig() unexpectedly succeeded") + } +} diff --git a/internal/config/env.go b/internal/config/env.go new file mode 100644 index 0000000..a8727fa --- /dev/null +++ b/internal/config/env.go @@ -0,0 +1,28 @@ +package config + +import ( + "strings" +) + +const spAllowPrefix = "SP_ALLOW_" + +// getAllowFromEnv reads allowlist regex strings from environment variables. +// +// Environment variables should be of the form +// like SP_ALLOW_GET, SP_ALLOW_GET_0, SP_ALLOW_GET_1, SP_ALLOW_POST +// returning a map of method to list of regex strings. +// like: {"GET":[], "POST":[]} +func getAllowFromEnv(env []string) map[string][]string { + result := make(map[string][]string) + for _, v := range env { + if v, ok := strings.CutPrefix(v, spAllowPrefix); ok { + key, value, found := strings.Cut(v, "=") + if found { + // optional number suffix after method + method, _, _ := strings.Cut(key, "_") + result[method] = append(result[method], value) + } + } + } + return result +} diff --git a/internal/config/env_test.go b/internal/config/env_test.go new file mode 100644 index 0000000..aadb948 --- /dev/null +++ b/internal/config/env_test.go @@ -0,0 +1,49 @@ +package config + +import ( + "reflect" + "testing" +) + +func Test_getAllowFromEnv(t *testing.T) { + tests := []struct { + name string // description of this test case + // Named input parameters for target function. + env []string + want map[string][]string + }{ + { + name: "single method", + env: []string{"SP_ALLOW_GET=/allowed/path"}, + want: map[string][]string{"GET": {"/allowed/path"}}, + }, + { + name: "multiple methods", + env: []string{"SP_ALLOW_GET=/get/path", "SP_ALLOW_POST=/post/path"}, + want: map[string][]string{"GET": {"/get/path"}, "POST": {"/post/path"}}, + }, + { + name: "multiple entries for one method", + env: []string{"SP_ALLOW_GET=/path/one", "SP_ALLOW_GET_1=/path/two"}, + want: map[string][]string{"GET": {"/path/one", "/path/two"}}, + }, + { + name: "multiple entries for one method with non-sequential index", + env: []string{"SP_ALLOW_GET=/path/one", "SP_ALLOW_GET_2=/path/two"}, + want: map[string][]string{"GET": {"/path/one", "/path/two"}}, + }, + { + name: "no relevant env vars", + env: []string{"OTHER_ENV=some_value"}, + want: map[string][]string{}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := getAllowFromEnv(tt.env) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("getAllowFromEnv() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/internal/config/param.go b/internal/config/param.go new file mode 100644 index 0000000..3408e04 --- /dev/null +++ b/internal/config/param.go @@ -0,0 +1,36 @@ +package config + +import ( + "flag" + "strings" +) + +type from int + +const ( + fromEnv from = 1 + fromParam from = 2 +) + +type param struct { + value string + from from +} + +type arrayParams []param + +// ensure that arrayParams implements the flag.Value interface +var _ flag.Value = (*arrayParams)(nil) + +func (a *arrayParams) String() string { + var values []string + for _, p := range *a { + values = append(values, p.value) + } + return strings.Join(values, ", ") +} + +func (a *arrayParams) Set(value string) error { + *a = append(*a, param{value: value, from: fromParam}) + return nil +} diff --git a/internal/docker/api/types/events/events.go b/internal/docker/api/types/events/events.go index 7f24c66..d0a5950 100644 --- a/internal/docker/api/types/events/events.go +++ b/internal/docker/api/types/events/events.go @@ -10,21 +10,6 @@ import "github.com/wollomatic/socket-proxy/internal/docker/api/types/filters" // Type is used for event-types. type Type string -// List of known event types. -const ( - BuilderEventType Type = "builder" // BuilderEventType is the event type that the builder generates. - ConfigEventType Type = "config" // ConfigEventType is the event type that configs generate. - ContainerEventType Type = "container" // ContainerEventType is the event type that containers generate. - DaemonEventType Type = "daemon" // DaemonEventType is the event type that daemon generate. - ImageEventType Type = "image" // ImageEventType is the event type that images generate. - NetworkEventType Type = "network" // NetworkEventType is the event type that networks generate. - NodeEventType Type = "node" // NodeEventType is the event type that nodes generate. - PluginEventType Type = "plugin" // PluginEventType is the event type that plugins generate. - SecretEventType Type = "secret" // SecretEventType is the event type that secrets generate. - ServiceEventType Type = "service" // ServiceEventType is the event type that services generate. - VolumeEventType Type = "volume" // VolumeEventType is the event type that volumes generate. -) - // Action is used for event-actions. type Action string diff --git a/internal/docker/api/types/filters/parse.go b/internal/docker/api/types/filters/parse.go index cafebde..e40ee9d 100644 --- a/internal/docker/api/types/filters/parse.go +++ b/internal/docker/api/types/filters/parse.go @@ -24,11 +24,6 @@ type KeyValuePair struct { Value string } -// Arg creates a new KeyValuePair for initializing Args -func Arg(key, value string) KeyValuePair { - return KeyValuePair{Key: key, Value: value} -} - // NewArgs returns a new Args populated with the initial args func NewArgs(initialArgs ...KeyValuePair) Args { args := Args{fields: map[string]map[string]bool{}} @@ -64,30 +59,6 @@ func ToJSON(a Args) (string, error) { return string(buf), err } -// FromJSON decodes a JSON encoded string into Args -func FromJSON(p string) (Args, error) { - args := NewArgs() - - if p == "" { - return args, nil - } - - raw := []byte(p) - err := json.Unmarshal(raw, &args) - if err == nil { - return args, nil - } - - // Fallback to parsing arguments in the legacy slice format - deprecated := map[string][]string{} - if legacyErr := json.Unmarshal(raw, &deprecated); legacyErr != nil { - return args, &invalidFilter{} - } - - args.fields = deprecatedArgs(deprecated) - return args, nil -} - // UnmarshalJSON populates the Args from JSON encode bytes func (args Args) UnmarshalJSON(raw []byte) error { return json.Unmarshal(raw, &args.fields) @@ -291,15 +262,3 @@ func (args Args) Clone() (newArgs Args) { } return newArgs } - -func deprecatedArgs(d map[string][]string) map[string]map[string]bool { - m := map[string]map[string]bool{} - for k, v := range d { - values := map[string]bool{} - for _, vv := range v { - values[vv] = true - } - m[k] = values - } - return m -} diff --git a/internal/go-connections/sockets/sockets.go b/internal/go-connections/sockets/sockets.go index 0d889e8..5c4b8e1 100644 --- a/internal/go-connections/sockets/sockets.go +++ b/internal/go-connections/sockets/sockets.go @@ -8,7 +8,6 @@ package sockets import ( "context" - "errors" "fmt" "net" "net/http" @@ -21,9 +20,6 @@ const ( maxUnixSocketPathSize = len(syscall.RawSockaddrUnix{}.Path) ) -// ErrProtocolNotAvailable is returned when a given transport protocol is not provided by the operating system. -var ErrProtocolNotAvailable = errors.New("protocol not available") - // ConfigureTransport configures the specified [http.Transport] according to the specified proto // and addr. //