diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index 4389c21a..6b2c5161 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -41,7 +41,6 @@ jobs: uses: actions/setup-go@v5 with: go-version-file: 'go.mod' - architecture: 'x64' - name: Setup prerequisites if: ${{ matrix.prereqsCommand != '' }} diff --git a/AGENTS.md b/AGENTS.md index 98683a51..9fb63e05 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -38,6 +38,9 @@ This codebase implements several custom Kubernetes types and controllers. Implem ### Avoid variable reuse (especially for errors) - If a function invokes multiple error-returning functions, use a different variable name for each error to avoid confusion. +### Use sync.Mutex as pointers +- In this codebase, `sync.Mutex` instances are used as pointers (`*sync.Mutex`). Create them with `&sync.Mutex{}` and pass them around as pointer values. + ## Adhere to Code Placement Rules Place new code in the correct location according to the project's structure: - **API Definitions:** Go in `api/v1/`. @@ -55,6 +58,9 @@ Place new code in the correct location according to the project's structure: - Use `OpenFile()`, or (for temporary files) `OpenTempFile()` functions from github.com/microsoft/dcp/pkg/io package to open files. This function takes care of using appropriate file permissions in a cross-platform way. - Always close files after no longer needed, either by calling `Close()` from the method that opened the file (with `defer` statement), or when the lifetime context.Context of the file owner expires. +## Test patterns +- Avoid usage of time.Sleep in tests to enforce timing. Use test helpers and synchronization primitives to make the timing as deterministic as possible to avoid non-deterministic test failures. + ## Code generation - Run `make generate` after making changes to API definitions (files under `api/v1` folder). 
- Run `make generate-grpc` after making changes to protobuf definitions (files with `.proto` extension). diff --git a/Makefile b/Makefile index 7becf395..cbb594d7 100644 --- a/Makefile +++ b/Makefile @@ -114,6 +114,7 @@ DELAY_TOOL ?= $(TOOL_BIN)/delay$(exe_suffix) LFWRITER_TOOL ?= $(TOOL_BIN)/lfwriter$(exe_suffix) PARROT_TOOL ?= $(TOOL_BIN)/parrot$(exe_suffix) PARROT_TOOL_CONTAINER_BINARY ?= $(TOOL_BIN)/parrot_c +DEBUGGEE_TOOL ?= $(TOOL_BIN)/debuggee$(exe_suffix) GO_LICENSES ?= $(TOOL_BIN)/go-licenses$(exe_suffix) PROTOC ?= $(TOOL_BIN)/protoc/bin/protoc$(exe_suffix) @@ -371,9 +372,9 @@ endif ##@ Test targets ifeq (4.4,$(firstword $(sort $(MAKE_VERSION) 4.4))) -TEST_PREREQS := generate-grpc .WAIT build-dcp build-dcptun-containerexe delay-tool lfwriter-tool parrot-tool parrot-tool-containerexe +TEST_PREREQS := generate-grpc .WAIT build-dcp build-dcptun-containerexe delay-tool lfwriter-tool parrot-tool parrot-tool-containerexe debuggee-tool cache-delve else -TEST_PREREQS := generate-grpc build-dcp build-dcptun-containerexe delay-tool lfwriter-tool parrot-tool parrot-tool-containerexe +TEST_PREREQS := generate-grpc build-dcp build-dcptun-containerexe delay-tool lfwriter-tool parrot-tool parrot-tool-containerexe debuggee-tool cache-delve endif .PHONY: test-prereqs @@ -396,7 +397,7 @@ test: test-prereqs ## Run all tests in the repository .PHONY: test-ci test-ci: test-ci-prereqs ## Runs tests in a way appropriate for CI pipeline, with linting etc. - $(GO_BIN) test ./... $(TEST_OPTS) + $(GO_BIN) test -tags integration ./... $(TEST_OPTS) ## Development and test support targets @@ -466,6 +467,18 @@ else GOOS=linux $(GO_BIN) build -o $(PARROT_TOOL_CONTAINER_BINARY) github.com/microsoft/dcp/test/parrot endif +# debuggee tool is used for DAP proxy integration testing. +# CLEAR_GOARGS ensures it is built for the native architecture (required for Delve debugging). 
+.PHONY: debuggee-tool +debuggee-tool: $(DEBUGGEE_TOOL) +$(DEBUGGEE_TOOL): $(wildcard ./test/debuggee/*.go) | $(TOOL_BIN) + $(CLEAR_GOARGS) $(GO_BIN) build -gcflags="all=-N -l" -o $(DEBUGGEE_TOOL) github.com/microsoft/dcp/test/debuggee + +# cache-delve ensures the Delve debugger is downloaded for DAP tests +.PHONY: cache-delve +cache-delve: + @$(CLEAR_GOARGS) $(GOTOOL_BIN) github.com/go-delve/delve/cmd/dlv version + .PHONY: httpcontent-stream-repro httpcontent-stream-repro: dotnet build test/HttpContentStreamRepro.Server/HttpContentStreamRepro.Server.csproj diff --git a/NOTICE b/NOTICE index 4a9daa56..41784196 100644 --- a/NOTICE +++ b/NOTICE @@ -3440,6 +3440,216 @@ https://github.com/google/gnostic-models/blob/v0.7.0/LICENSE ---------------------------------------------------------- +github.com/google/go-dap v0.12.0 - Apache-2.0 +https://github.com/google/go-dap/blob/v0.12.0/LICENSE + + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. 
+ + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. 
You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. 
Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +---------------------------------------------------------- + +---------------------------------------------------------- + github.com/google/pprof/profile v0.0.0-20241029153458-d1b30febd7db - Apache-2.0 https://github.com/google/pprof/blob/d1b30febd7db/LICENSE diff --git a/controllers/container_network_tunnel_proxy_controller.go b/controllers/container_network_tunnel_proxy_controller.go index 22494cfd..6ec18787 100644 --- a/controllers/container_network_tunnel_proxy_controller.go +++ b/controllers/container_network_tunnel_proxy_controller.go @@ -1259,7 +1259,7 @@ func (r *ContainerNetworkTunnelProxyReconciler) startServerProxy( r.onServerProcessExit(tunnelProxy.NamespacedName(), pid, exitCode, err, stdoutFile, stderrFile) }) - pid, startTime, startWaitForExit, startErr := r.config.ProcessExecutor.StartProcess(context.Background(), cmd, exitHandler, process.CreationFlagsNone) + handle, startWaitForExit, startErr := r.config.ProcessExecutor.StartProcess(context.Background(), cmd, exitHandler, process.CreationFlagsNone) if startErr != nil { log.Error(startErr, "Failed to start server proxy process") startFailed = true @@ -1273,7 +1273,7 @@ func (r *ContainerNetworkTunnelProxyReconciler) startServerProxy( tc, tcErr := readServerProxyConfig(ctx, stdoutFile.Name()) if tcErr != nil { log.Error(tcErr, "Failed to read connection information from the server proxy") - stopProcessErr := 
r.config.ProcessExecutor.StopProcess(pid, startTime) + stopProcessErr := r.config.ProcessExecutor.StopProcess(handle) if stopProcessErr != nil { log.Error(stopProcessErr, "Failed to stop server proxy process after being unable to read its configuration") } @@ -1281,11 +1281,11 @@ func (r *ContainerNetworkTunnelProxyReconciler) startServerProxy( return false } - dcpproc.RunProcessWatcher(r.config.ProcessExecutor, pid, startTime, log) + dcpproc.RunProcessWatcher(r.config.ProcessExecutor, handle, log) - pointers.SetValue(&pd.ServerProxyProcessID, int64(pid)) + pointers.SetValue(&pd.ServerProxyProcessID, int64(handle.Pid)) pd.ServerProxyControlPort = tc.ServerControlPort - pd.ServerProxyStartupTimestamp = metav1.NewMicroTime(startTime) + pd.ServerProxyStartupTimestamp = metav1.NewMicroTime(handle.IdentityTime) pd.ServerProxyStdOutFile = stdoutFile.Name() pd.ServerProxyStdErrFile = stderrFile.Name() @@ -1377,7 +1377,7 @@ func (r *ContainerNetworkTunnelProxyReconciler) cleanupProxyPair( // The process may have already exited because the client container has been stopped. 
- stopErr := r.config.ProcessExecutor.StopProcess(pid, startTime) + stopErr := r.config.ProcessExecutor.StopProcess(process.NewHandle(pid, startTime)) if stopErr != nil && !errors.Is(stopErr, process.ErrorProcessNotFound) { log.Error(stopErr, "Failed to stop server proxy process") } else { diff --git a/controllers/controller_harvest.go b/controllers/controller_harvest.go index dbb46927..4d93dc29 100644 --- a/controllers/controller_harvest.go +++ b/controllers/controller_harvest.go @@ -277,19 +277,19 @@ func (rh *resourceHarvester) harvestAbandonedNetworks( return removeErr } -func (rh *resourceHarvester) isRunningDCPProcess(pid process.Pid_t, startTime time.Time) bool { - if running, exists := rh.processes[pid]; exists { +func (rh *resourceHarvester) isRunningDCPProcess(handle process.ProcessHandle) bool { + if running, exists := rh.processes[handle.Pid]; exists { return running } // If the process is not in the cache, we need to check if it is running. - _, findErr := process.FindProcess(pid, startTime) + _, findErr := process.FindProcess(handle) if findErr != nil { return false // Process not found, so it's not running. } // We found the process, so cache it as running. - rh.processes[pid] = true + rh.processes[handle.Pid] = true return true } @@ -299,7 +299,7 @@ func (rh *resourceHarvester) creatorStillRunning(labels map[string]string) bool creatorPID, _ := process.StringToPidT(labels[CreatorProcessIdLabel]) creatorStartTime, _ := time.Parse(osutil.RFC3339MiliTimestampFormat, labels[CreatorProcessStartTimeLabel]) - return rh.isRunningDCPProcess(creatorPID, creatorStartTime) + return rh.isRunningDCPProcess(process.NewHandle(creatorPID, creatorStartTime)) } // Checks for the presence of the creator process ID and start time labels. 
diff --git a/controllers/executable_controller.go b/controllers/executable_controller.go index b9424493..c789be2f 100644 --- a/controllers/executable_controller.go +++ b/controllers/executable_controller.go @@ -289,6 +289,7 @@ func ensureExecutableRunningState( // Ensure the status matches the current state. change |= runInfo.ApplyTo(exe, log) r.enableEndpointsAndHealthProbes(ctx, exe, runInfo, log) + return change } @@ -340,6 +341,7 @@ func ensureExecutableFinalState( change |= runInfo.ApplyTo(exe, log) // Ensure the status matches the current state. r.disableEndpointsAndHealthProbes(ctx, exe, runInfo, log) + return change } @@ -885,12 +887,6 @@ func (r *ExecutableReconciler) validateExistingEndpoints( return existing, nil, nil } -// Environment variables starting with these prefixes will never be applied to Executables. -var suppressVarPrefixes = []string{ - "DEBUG_SESSION", - "DCP_", -} - // Computes the effective set of environment variables for the Executable run and stores it in Status.EffectiveEnv. func (r *ExecutableReconciler) computeEffectiveEnvironment( ctx context.Context, @@ -900,20 +896,12 @@ func (r *ExecutableReconciler) computeEffectiveEnvironment( ) error { // Start with ambient environment. 
var envMap maps.StringKeyMap[string] - if osutil.IsWindows() { - envMap = maps.NewStringKeyMap[string](maps.StringMapModeCaseInsensitive) - } else { - envMap = maps.NewStringKeyMap[string](maps.StringMapModeCaseSensitive) - } switch exe.Spec.AmbientEnvironment.Behavior { case "", apiv1.EnvironmentBehaviorInherit: - envMap.Apply(maps.SliceToMap(os.Environ(), func(envStr string) (string, string) { - parts := strings.SplitN(envStr, "=", 2) - return parts[0], parts[1] - })) + envMap = osutil.NewFilteredAmbientEnv() case apiv1.EnvironmentBehaviorDoNotInherit: - // Noop + envMap = osutil.NewPlatformStringMap[string]() default: return fmt.Errorf("unknown environment behavior: %s", exe.Spec.AmbientEnvironment.Behavior) } @@ -949,9 +937,7 @@ func (r *ExecutableReconciler) computeEffectiveEnvironment( envMap.Set(key, effectiveValue) } - for _, prefix := range suppressVarPrefixes { - envMap.DeletePrefix(prefix) - } + osutil.SuppressEnvVarPrefixes(envMap) exe.Status.EffectiveEnv = maps.MapToSlice[apiv1.EnvVar](envMap.Data(), func(key string, value string) apiv1.EnvVar { return apiv1.EnvVar{Name: key, Value: value} diff --git a/debug-bridge-aspire-plan.md b/debug-bridge-aspire-plan.md new file mode 100644 index 00000000..0c0292a7 --- /dev/null +++ b/debug-bridge-aspire-plan.md @@ -0,0 +1,628 @@ +# Implement 2026-02-01 Debug Bridge Protocol in dotnet/aspire + +## TL;DR + +DCP now supports a "debug bridge" mode (protocol version `2026-02-01`) where it launches debug adapters and proxies DAP messages through a Unix domain socket. Instead of VS Code launching its own debug adapter process, it connects to DCP's bridge socket, tells DCP which adapter to launch (via a length-prefixed JSON handshake), and then communicates DAP messages through that same socket. This requires changes to the IDE execution spec, the VS Code extension's session endpoint, debug adapter descriptor factory, and protocol capabilities. + +Currently, `protocols_supported` tops out at `"2025-10-01"`. 
No `2026-02-01` or `debug_bridge` references exist anywhere in the aspire repo. + +--- + +## Architecture + +``` +┌──────────────────────────────────────────────────────────────────────────┐ +│ IDE (VS Code) │ +│ └─ Debug Adapter Client │ +│ └─ Connects to Unix socket provided by DCP in run session response │ +└──────────────────────────────────┬───────────────────────────────────────┘ + │ DAP messages (Unix socket) + │ + Initial handshake (token + session ID + run ID + adapter config) + ▼ +┌──────────────────────────────────────────────────────────────────────────┐ +│ DCP DAP Bridge (BridgeManager + DapBridge) │ +│ ├─ PrivateUnixSocketListener for IDE connections │ +│ ├─ Handshake validation (session ID + token) │ +│ ├─ Sequence number remapping (IDE ↔ Adapter seq isolation) │ +│ ├─ RawMessage forwarding (transparent proxy for unknown DAP messages) │ +│ ├─ Interception layer: │ +│ │ ├─ initialize: ensure supportsRunInTerminalRequest = true │ +│ │ ├─ runInTerminal: handle locally, launch process, capture stdio │ +│ │ └─ output events: capture for logging (unless runInTerminal used) │ +│ ├─ Inline runInTerminal handling (exec.Command via process.Executor) │ +│ └─ Output routing (BridgeConnectionHandler → OutputHandler + writers) │ +└──────────────────────────────────┬───────────────────────────────────────┘ + │ DAP messages (stdio/TCP) + ▼ +┌──────────────────────────────────────────────────────────────────────────┐ +│ Debug Adapter (launched by DCP) │ +│ └─ coreclr, debugpy, etc. 
│ +└──────────────────────────────────────────────────────────────────────────┘ +``` + +### How It Differs from the Current Flow + +| Aspect | Current (no bridge) | New (bridge mode, 2026-02-01+) | +|--------|---------------------|-------------------------------| +| Who launches the debug adapter | VS Code (via `vscode.debug.startDebugging`) | DCP (via bridge, using config from IDE) | +| DAP transport | VS Code manages directly | Unix socket through DCP bridge | +| `runInTerminal` handling | VS Code handles | DCP handles locally (IDE never sees it) | +| stdout/stderr capture | Adapter tracker sends `serviceLogs` | DCP captures from process pipes or output events | +| IDE role | Full debug orchestrator | DAP client connected through bridge socket | + +--- + +## Step-by-Step Implementation + +### Step 1: Update the IDE Execution Spec + +**File:** `docs/specs/IDE-execution.md` + +Add the `2026-02-01` protocol version under **Protocol Versioning → Well-known protocol versions**: + +> **`2026-02-01`** +> Changes: +> - Adds debug bridge support. When this version (or later) is negotiated, the `PUT /run_session` payload may include `debug_bridge_socket_path` and `debug_session_id` fields. + +Add the two new fields to the **Create Session Request** payload documentation: + +| Property | Description | Type | +|----------|-------------|------| +| `debug_bridge_socket_path` | Unix domain socket path that the IDE should connect to for DAP bridging. Present only when API version ≥ `2026-02-01`. | `string` (optional) | +| `debug_session_id` | A unique session identifier the IDE must include in the debug bridge handshake. | `string` (optional) | + +Add a new section **"Debug Bridge Protocol"** describing the full protocol (see [Appendix A](#appendix-a-debug-bridge-protocol-specification) below for the complete spec text). 
+ +--- + +### Step 2: Update Protocol Capabilities + +**File:** `extension/src/capabilities.ts` (~line 55) + +Add `"2026-02-01"` to the `protocols_supported` array: + +```ts +export function getRunSessionInfo(): RunSessionInfo { + return { + protocols_supported: ["2024-03-03", "2024-04-23", "2025-10-01", "2026-02-01"], + supported_launch_configurations: getSupportedCapabilities() + }; +} +``` + +--- + +### Step 3: Update TypeScript Types + +**File:** `extension/src/dcp/types.ts` + +Add the new fields to the run session payload type, and add new types for the handshake: + +```ts +// Add to existing RunSessionPayload (or equivalent) interface: +debug_bridge_socket_path?: string; +debug_session_id?: string; + +// New types for the bridge protocol: +export interface DebugAdapterConfig { + args: string[]; + mode?: "stdio" | "tcp-callback" | "tcp-connect"; + env?: Array<{ name: string; value: string }>; + connectionTimeout?: string; // Go duration format, e.g. "10s" +} + +export interface DebugBridgeHandshakeRequest { + token: string; + session_id: string; + run_id?: string; + debug_adapter_config: DebugAdapterConfig; +} + +export interface DebugBridgeHandshakeResponse { + success: boolean; + error?: string; +} +``` + +--- + +### Step 4: Create a Debug Bridge Client Module + +**New file:** `extension/src/debugger/debugBridgeClient.ts` + +Implement the IDE side of the bridge connection: + +```ts +export async function connectToDebugBridge( + socketPath: string, + token: string, + sessionId: string, + runId: string, + adapterConfig: DebugAdapterConfig +): Promise<net.Socket> +``` + +This function should: + +1. Connect to the Unix domain socket at `socketPath` using `net.connect({ path: socketPath })` +2. Send the handshake request as **length-prefixed JSON**: + - Write a 4-byte big-endian `uint32` containing the JSON payload length + - Write the UTF-8 encoded JSON bytes of `DebugBridgeHandshakeRequest` (including `run_id` for output routing) +3. 
Read the handshake response: + - Read 4 bytes → big-endian `uint32` length + - Read that many bytes → parse as `DebugBridgeHandshakeResponse` +4. If `success === true`, return the connected socket +5. If `success === false`, throw an error with the `error` message + +**Important constraints:** +- Max handshake message size: **64 KB** (65536 bytes) +- Handshake timeout: **30 seconds** (DCP closes the connection if the handshake isn't received in time) + +--- + +### Step 5: Map Launch Configuration Types to Debug Adapter Configs + +The `debug_adapter_config` in the handshake tells DCP what debug adapter binary to launch. The IDE must determine this from the launch configuration type. + +The mapping information already exists in `extension/src/debugger/debuggerExtensions.ts` and the language-specific files: + +| Launch Config Type | Debug Adapter | Source Extension | +|-------------------|---------------|-----------------| +| `project` | `coreclr` | `ms-dotnettools.csharp` | +| `python` | `debugpy` | `ms-python.python` | + +Add a method to `ResourceDebuggerExtension` (or a standalone utility) that returns a `DebugAdapterConfig`: + +```ts +export interface ResourceDebuggerExtension { + // ... existing fields ... + getDebugAdapterConfig?(launchConfig: LaunchConfiguration): DebugAdapterConfig; +} +``` + +For each resource type: +- **`project` / `coreclr`**: Resolve the path to the C# debug adapter executable from the `ms-dotnettools.csharp` extension. Set `mode: "stdio"`. The `args` array should be the command line to launch the adapter (e.g., `["/path/to/Microsoft.CodeAnalysis.LanguageServer", "--debug"]` or whatever the coreclr adapter binary is). +- **`python` / `debugpy`**: Resolve the path to the debugpy adapter. Set `mode: "stdio"` or `"tcp-connect"` as appropriate. For `"tcp-connect"`, use `{{port}}` as a placeholder in `args` — DCP will replace it with an actual port number. 
+ +This is the **key integration point** — the extension needs to locate the actual debug adapter binary that would normally be launched by VS Code's built-in debug infrastructure and package it as an `args` array for the handshake. + +--- + +### Step 6: Update `PUT /run_session` Handler + +**File:** `extension/src/dcp/AspireDcpServer.ts` (~lines 84-120) + +Modify the `PUT /run_session` handler: + +``` +Parse request body + ↓ +Extract debug_bridge_socket_path and debug_session_id + ↓ +┌─ If BOTH fields are present (bridge mode): +│ 1. Resolve DebugAdapterConfig for the launch configuration type (Step 5) +│ 2. Call connectToDebugBridge() with socket path, bearer token, session ID, adapter config +│ 3. Get back the connected net.Socket +│ 4. Create a DebugBridgeAdapter wrapping the socket (Step 7) +│ 5. Start a VS Code debug session using this adapter +│ 6. Respond 201 Created + Location header +│ +└─ If fields are ABSENT (legacy mode): + Follow existing flow unchanged +``` + +--- + +### Step 7: Create a Bridge Debug Adapter + +**New file:** `extension/src/debugger/debugBridgeAdapter.ts` + +Create a custom `vscode.DebugAdapter` that proxies DAP messages to/from the connected Unix socket: + +```ts +export class DebugBridgeAdapter implements vscode.DebugAdapter { + private sendMessage: vscode.EventEmitter<vscode.DebugProtocolMessage>; + onDidSendMessage: vscode.Event<vscode.DebugProtocolMessage>; + + constructor(private socket: net.Socket) { ... } + + // Called by VS Code when it wants to send a DAP message to the adapter + handleMessage(message: vscode.DebugProtocolMessage): void { + // Write as DAP-framed message (Content-Length header + JSON) to the socket + } + + // Read DAP-framed messages from the socket and emit via onDidSendMessage + + dispose(): void { + // Close the socket + } +} +``` + +**Why not `DebugAdapterNamedPipeServer`?** The handshake must complete before DAP messages flow. `DebugAdapterNamedPipeServer` would try to send DAP messages immediately on connect, bypassing the handshake. 
The inline adapter approach gives full control over the connection lifecycle. + +Then update `AspireDebugAdapterDescriptorFactory` to return a `DebugAdapterInlineImplementation` wrapping this adapter for bridge sessions: + +```ts +return new vscode.DebugAdapterInlineImplementation(new DebugBridgeAdapter(connectedSocket)); +``` + +--- + +### Step 8: Update Debug Session Lifecycle + +**File:** `extension/src/debugger/AspireDebugSession.ts` + +For bridge-mode sessions: +- The `launch` request handler should **not** spawn `aspire run --start-debug-session` (DCP already manages the process) +- Track whether this is a bridge session (e.g., via a flag or session metadata) +- On `disconnect`/`terminate`, close the bridge socket connection +- Teardown should notify DCP via the existing WebSocket notification path (`sessionTerminated`) + +--- + +### Step 9: Update Adapter Tracker for Bridge Sessions + +**File:** `extension/src/debugger/adapterTracker.ts` + +For bridge sessions: +- DCP captures stdout/stderr directly from the debug adapter's output events and from `runInTerminal` process pipes — the tracker should **not** send duplicate `serviceLogs` notifications for output that DCP is already capturing +- The tracker should still send `processRestarted` and `sessionTerminated` notifications +- Consider skipping tracker registration entirely for bridge sessions, or adding a bridge-mode flag that suppresses log forwarding + +--- + +### Step 10: Update C# Models (if needed) + +**Files in:** `src/Aspire.Hosting/Dcp/Model/` + +If the app host or dashboard reads the run session payload structure, update any C# deserialization models to include the new optional fields for forward compatibility. Check: +- `RunSessionInfo.cs` +- Any request/response models that mirror the run session payload + +This may not be strictly necessary if the C# side doesn't interact with these fields — DCP adds them server-side. But it's good practice for model completeness. 
+ +--- + +## Error Reporting + +### Current State (Implemented in DCP) + +The DCP bridge now sends meaningful DAP error information to the IDE when errors occur after the handshake. The implementation uses `OutputEvent` (category: `"stderr"`) followed by `TerminatedEvent` to communicate errors through the standard DAP protocol. + +### Error Scenarios and Behavior + +| Scenario | What IDE Sees | +|----------|---------------| +| Handshake failure (bad token, invalid session, missing config) | Handshake error JSON response — handled cleanly | +| Handshake read failure (malformed data, timeout) | Raw connection drop — no DAP-level error possible (pre-handshake) | +| Debug adapter fails to launch (bad command, missing binary) | `OutputEvent` (stderr) with error text + `TerminatedEvent` | +| Adapter connection timeout (TCP modes) | `OutputEvent` (stderr) with error text + `TerminatedEvent` | +| Adapter crashes before sending `TerminatedEvent` | Synthesized `TerminatedEvent` (with optional `OutputEvent` if transport error) | +| Transport read/write failure mid-session | `OutputEvent` (stderr) + synthesized `TerminatedEvent` | + +### DCP Implementation Details + +#### 1. DAP message helpers in `internal/dap/message.go` + +Unexported helper functions synthesize DAP messages for error reporting: + +```go +// newOutputEvent creates an OutputEvent for sending error/info text to the IDE. +func newOutputEvent(seq int, category, output string) *dap.OutputEvent + +// newTerminatedEvent creates a TerminatedEvent to signal session end. +func newTerminatedEvent(seq int) *dap.TerminatedEvent +``` + +Note: `NewErrorResponse` was considered but not implemented — `OutputEvent` + `TerminatedEvent` is sufficient for all error scenarios. + +#### 2. Error delivery via `sendErrorToIDE()` in `bridge.go` + +When errors occur after the IDE transport is established, `sendErrorToIDE()` sends an `OutputEvent` with `category: "stderr"` followed by a `TerminatedEvent`. 
Sequence numbers for bridge-originated messages use `b.ideSeqCounter` (an atomic counter separate from the IDE's own sequence numbers): + +```go +func (b *DapBridge) sendErrorToIDE(message string) { + outputEvent := newOutputEvent(int(b.ideSeqCounter.Add(1)), "stderr", message+"\n") + b.ideTransport.WriteMessage(outputEvent) + b.sendTerminatedToIDE() +} +``` + +#### 3. Adapter launch failure + +When `launchAdapterWithConfig` fails, `sendErrorToIDE()` is called before returning the error: + +```go +launchErr := b.launchAdapterWithConfig(ctx, adapterConfig) +if launchErr != nil { + b.sendErrorToIDE(fmt.Sprintf("Failed to launch debug adapter: %v", launchErr)) + return fmt.Errorf("failed to launch debug adapter: %w", launchErr) +} +``` + +#### 4. Unexpected adapter exit + +When `<-b.adapter.Done()` fires and the adapter did NOT send a `TerminatedEvent` (tracked via `terminatedEventSeen` flag), the bridge synthesizes one. If the exit was due to a transport error (as opposed to clean EOF/cancellation), an `OutputEvent` with the error text is sent first. + +#### 5. Transport failures + +When read/write errors occur in the message loop, the bridge attempts to send an `OutputEvent` describing the failure before closing the connection. + +### Required Changes — IDE/Aspire Side + +#### 1. Handle handshake failures in `debugBridgeClient.ts` + +When `connectToDebugBridge()` receives `{"success": false, "error": "..."}`, throw an error that includes the error message. The VS Code extension should surface this to the user via: +- A `vscode.window.showErrorMessage()` call with the error text +- A `sessionMessage` notification (level: `error`) sent to DCP via the WebSocket notification stream +- Clean termination of the debug session + +#### 2. Handle DAP error events in `DebugBridgeAdapter` + +The `DebugBridgeAdapter` (Step 7 in the main plan) should watch for `OutputEvent` messages with `category: "stderr"` that arrive before the first `InitializeResponse`. 
These indicate adapter launch errors from DCP. The adapter should: +- Forward them to VS Code (which will display them in the Debug Console) +- If followed by a `TerminatedEvent`, terminate the session cleanly + +#### 3. Handle unexpected connection drops + +If the Unix socket closes unexpectedly (without a `TerminatedEvent` or `DisconnectResponse`), the `DebugBridgeAdapter` should: +- Fire a `TerminatedEvent` to VS Code so the debug session ends cleanly +- Optionally display an error message indicating the debug bridge connection was lost + +--- + +## Key Decisions + +| Decision | Rationale | +|----------|-----------| +| **Inline adapter over named pipe descriptor** | The handshake must complete before DAP messages flow, so we need a `DebugAdapterInlineImplementation` wrapping a custom adapter that manages the socket lifecycle | +| **Token reuse** | The same bearer token used for HTTP authentication (`DEBUG_SESSION_TOKEN`) is reused as the bridge handshake token — no new credential needed | +| **IDE decides adapter** | DCP does NOT tell the IDE which adapter to use; the IDE determines this from the launch configuration type and sends the adapter binary path + args back in the handshake's `debug_adapter_config` | +| **Backward compatible** | When `debug_bridge_socket_path` is absent from the run session request, the existing non-bridge flow is used unchanged | +| **DAP-level error reporting** | DCP sends `OutputEvent` (category: stderr) + `TerminatedEvent` to the IDE when errors occur after handshake, so the IDE can display meaningful errors instead of a silent connection drop | +| **Single `BridgeManager`** | Session management, socket listening, and bridge lifecycle are combined into one `BridgeManager` type rather than separate `BridgeSessionManager` and `BridgeSocketManager` — simpler lifecycle management with a single mutex | +| **Sequence number remapping** | Bridge-assigned seq numbers prevent collisions between IDE-originated and bridge-originated (e.g., 
`runInTerminal` response) messages; a `seqMap` restores original seq values on responses | +| **`RawMessage` fallback** | Unknown/proprietary DAP messages that the `go-dap` library can't decode are wrapped in `RawMessage` and forwarded transparently, enabling support for custom debug adapter extensions | +| **`PrivateUnixSocketListener`** | Uses the project's `networking.PrivateUnixSocketListener` instead of a plain Unix domain socket for enhanced security | +| **Environment filtering on adapter launch** | Adapter processes inherit the DCP environment but with `DEBUG_SESSION*` and `DCP_*` variables removed, preventing credential leakage to debug adapters | + +--- + +## Verification + +1. **Unit tests**: Test `connectToDebugBridge()` with a mock Unix socket server that validates the length-prefixed JSON format, token, and session ID +2. **Integration test**: Start a DCP instance with debug bridge enabled, verify the extension: + - Reports `"2026-02-01"` in `protocols_supported` + - Connects to the bridge socket when `debug_bridge_socket_path` is in the run request + - Sends a valid handshake with correct adapter config + - Successfully forwards DAP messages through the bridge +3. **Error scenario tests**: + - Handshake failure (bad token) → extension shows meaningful error, session terminates cleanly + - Adapter launch failure (bad binary path) → extension receives `OutputEvent` with error text and `TerminatedEvent`, session terminates cleanly + - Unexpected connection drop → extension fires synthetic `TerminatedEvent` to VS Code, session ends without hang +4. **Regression**: Ensure the existing (non-bridge) flow still works when DCP negotiates an older API version +5. **Manual test**: Debug a .NET Aspire app with the updated extension and verify breakpoints, stepping, variable inspection all work through the bridge + +--- + +## DCP Implementation Details + +These sections document key aspects of the DCP-side implementation that the IDE extension should be aware of. 
+ +### Sequence Number Remapping + +The bridge remaps DAP sequence numbers to prevent collisions between IDE-originated messages and bridge-originated messages (such as `RunInTerminalResponse` or synthesized error events). This is implemented in `bridge.go` with three components: + +- **`adapterSeqCounter`** (atomic `int64`): Generates monotonically increasing `seq` numbers for all messages sent to the adapter. When forwarding an IDE message, the bridge replaces `seq` with a bridge-assigned value and records the mapping. +- **`ideSeqCounter`** (atomic `int64`): Generates `seq` numbers for bridge-originated messages sent to the IDE (synthesized `OutputEvent`, `TerminatedEvent`, `RunInTerminalResponse`). +- **`seqMap`** (`syncmap.Map[int, int]`): Maps bridge-assigned seq numbers → original IDE seq numbers. When a response comes back from the adapter, the bridge looks up the `request_seq` in this map and restores the original IDE seq value before forwarding. + +This is transparent to the IDE — the IDE sees its own seq numbers on all responses. + +### RawMessage and MessageEnvelope + +`message.go` contains two key types that enable transparent proxying of all DAP messages, including those unknown to the `go-dap` library: + +**`RawMessage`**: Wraps the raw JSON bytes of a DAP message that couldn't be decoded by `go-dap`. This enables the bridge to transparently forward proprietary/custom DAP messages (e.g., custom commands from language-specific debug adapters) without needing to understand their schema. + +**`MessageEnvelope`**: A wrapper that provides uniform access to DAP header fields (`seq`, `type`, `request_seq`, `command`, `event`) across both typed `go-dap` messages and `RawMessage` instances. It supports: +- Lazy extraction of header fields at creation time +- Free modification of `Seq`, `RequestSeq`, etc. 
+- `Finalize()` to apply changes back — zero-cost for typed messages, single JSON field patch for raw messages + +The bridge uses `ReadMessageWithFallback` / `WriteMessageWithFallback` instead of the standard `go-dap` reader/writer. These functions attempt standard decoding first, falling back to `RawMessage` for unrecognized message types. + +### Output Routing + +When a bridge connection is established, `BridgeManager` invokes a `BridgeConnectionHandler` callback to resolve output routing: + +```go +type BridgeConnectionHandler func(sessionID string, runID string) (OutputHandler, io.Writer, io.Writer) +``` + +This returns: +- An `OutputHandler` interface (`HandleOutput(category, output string)`) for routing DAP `OutputEvent` messages +- `io.Writer` instances for stdout and stderr (used as sinks for `runInTerminal` process pipes) + +The `run_id` field in the handshake request is what connects the bridge session to the correct executable's output files. In `internal/exerunners/`, the `bridgeOutputHandler` implementation routes: +- `"stdout"` and `"console"` category events → stdout writer +- `"stderr"` category events → stderr writer +- Other categories → silently discarded + +Output routing only captures via `OutputHandler` when `runInTerminal` was NOT used (tracked by `runInTerminalUsed` flag). When `runInTerminal` launches a process, DCP captures stdout/stderr directly from the process pipes, avoiding double-capture. + +### BridgeManager Lifecycle + +`BridgeManager` is the single orchestrator for all bridge sessions. It combines session registration, socket management, and bridge lifecycle: + +1. **Creation**: `NewBridgeManager(BridgeManagerConfig{Logger, ConnectionHandler})` — requires a `BridgeConnectionHandler` callback +2. **Start**: `Start(ctx)` creates a `PrivateUnixSocketListener`, signals readiness via `Ready()` channel, then enters an accept loop +3. 
**Session registration**: `RegisterSession(sessionID, token)` creates a `BridgeSession` in `BridgeSessionStateCreated` state. Session ID is typically `string(exe.UID)`. +4. **Connection handling**: Each accepted connection goes through handshake, validation, `markSessionConnected()`, then `runBridge()`. If anything fails between marking connected and running, `markSessionDisconnected()` rolls back to allow retry. +5. **Bridge construction**: Creates a `DapBridge` via `NewDapBridge(BridgeConfig{...})` where `BridgeConfig` includes `SessionID`, `AdapterConfig`, `Executor`, `Logger`, `OutputHandler`, `StdoutWriter`, `StderrWriter` +6. **Termination**: Session moves to `BridgeSessionStateTerminated` (success) or `BridgeSessionStateError` (failure) when the bridge's `RunWithConnection` returns + +### DapBridge Lifecycle + +The `DapBridge` handles a single debug session's message forwarding: + +1. `RunWithConnection(ctx, ideConn)` creates an IDE transport and calls `launchAdapterWithConfig` +2. On adapter launch failure: `sendErrorToIDE()` → return error +3. On success: enters `runMessageLoop(ctx)` +4. Message loop starts two goroutines (`forwardIDEToAdapter`, `forwardAdapterToIDE`) and watches for adapter process exit via `<-b.adapter.Done()` +5. On adapter exit without `TerminatedEvent`: synthesizes one (optionally preceded by an error `OutputEvent`) +6. Cleanup: closes both transports, waits for goroutines, collects errors + +### Adapter Launch Environment Filtering + +When launching a debug adapter process, `buildFilteredEnv()` in `adapter_launcher.go`: +1. Inherits the DCP process's full environment +2. Removes variables with `DEBUG_SESSION` or `DCP_` prefixes (case-insensitive on Windows) +3. Applies any environment variables specified in the `DebugAdapterConfig.Env` array on top + +Additionally, all adapter modes capture the adapter's stderr via a pipe and log it for diagnostic purposes. 
+ +--- + +## Appendix A: Debug Bridge Protocol Specification + +### Overview + +When API version `2026-02-01` or later is negotiated, DCP may include debug bridge fields in the `PUT /run_session` request. When present, the IDE should connect to the provided Unix domain socket and use DCP as a DAP bridge instead of launching its own debug adapter. + +### Connection Flow + +1. IDE receives `PUT /run_session` with `debug_bridge_socket_path` and `debug_session_id` +2. IDE responds `201 Created` with `Location` header (as normal) +3. IDE connects to the Unix domain socket at `debug_bridge_socket_path` +4. IDE sends a handshake request (length-prefixed JSON) +5. DCP validates and responds with a handshake response +6. On success, standard DAP messages flow over the same socket connection +7. DCP launches the debug adapter specified in the handshake and bridges messages bidirectionally + +### Handshake Wire Format + +All handshake messages use **length-prefixed JSON**: +``` +[4 bytes: big-endian uint32 payload length][JSON payload bytes] +``` + +Maximum message size: **65536 bytes** (64 KB). + +### Handshake Request (IDE → DCP) + +```json +{ + "token": "<bearer token>", + "session_id": "<debug session id>", + "run_id": "<run id>", + "debug_adapter_config": { + "args": ["/path/to/debug-adapter", "--arg1", "value1"], + "mode": "stdio", + "env": [ + { "name": "VAR_NAME", "value": "var_value" } + ], + "connectionTimeout": "10s" + } +} +``` + +| Field | Type | Required | Description | +|-------|------|----------|-------------| +| `token` | `string` | Yes | The same bearer token used for HTTP authentication | +| `session_id` | `string` | Yes | The `debug_session_id` from the run session request | +| `run_id` | `string` | No | Correlates the bridge session with the executable's output writers for log routing | +| `debug_adapter_config` | `object` | Yes | Configuration for launching the debug adapter | +| `debug_adapter_config.args` | `string[]` | Yes | Command + arguments to launch the adapter.
First element is the executable path. | +| `debug_adapter_config.mode` | `string` | No | `"stdio"` (default), `"tcp-callback"`, or `"tcp-connect"` | +| `debug_adapter_config.env` | `array` | No | Environment variables as `[{"name":"N","value":"V"}]` (uses `apiv1.EnvVar` type on DCP side) | +| `debug_adapter_config.connectionTimeout` | `string` | No | Timeout for TCP connections as a Go duration string, e.g. `"10s"` (default: 10 seconds) | + +### Debug Adapter Modes + +| Mode | Description | +|------|-------------| +| `stdio` (default) | DCP launches the adapter and communicates via stdin/stdout | +| `tcp-callback` | DCP starts a TCP listener, substitutes `{{port}}` in `args` with the listener port, then launches the adapter. The adapter connects back to DCP on that port. | +| `tcp-connect` | DCP allocates a port, replaces `{{port}}` placeholder in `args`, launches the adapter (which listens on that port), then DCP connects to it. | + +### Handshake Response (DCP → IDE) + +Success: +```json +{ + "success": true +} +``` + +Failure: +```json +{ + "success": false, + "error": "error description" +} +``` + +### Handshake Validation + +DCP validates the handshake in this order: +1. Session ID exists → otherwise `"bridge session not found"` (`ErrBridgeSessionNotFound`) +2. Token matches the registered session token → otherwise `"invalid session token"` (`ErrBridgeSessionInvalidToken`) +3. `debug_adapter_config` is present → otherwise `"debug adapter configuration is required"` +4. Session not already connected → otherwise `"session already connected"` (`ErrBridgeSessionAlreadyConnected`) (only one IDE connection per session allowed) + +If connection fails after marking the session as connected (between step 4 and running the bridge), the connected state is rolled back via `markSessionDisconnected()` so the session can be retried. 
+ +### Timeouts + +| Timeout | Duration | Description | +|---------|----------|-------------| +| Handshake | 30 seconds | DCP closes the connection if the handshake request isn't received within this time | +| Adapter connection (TCP modes) | 10 seconds (configurable) | Time to establish TCP connection to/from adapter | + +### DAP Message Flow After Handshake + +After a successful handshake, standard DAP messages flow over the Unix socket using the standard DAP wire format (`Content-Length: N\r\n\r\n{JSON}`). + +DCP intercepts the following messages: +- **`initialize` request** (IDE → Adapter): DCP forces `supportsRunInTerminalRequest = true` in the arguments before forwarding +- **`runInTerminal` reverse request** (Adapter → IDE): DCP handles this locally by launching the process. The IDE will **never** receive `runInTerminal` requests. +- **`output` events** (Adapter → IDE): DCP captures these for logging purposes, then forwards to the IDE + +All other DAP messages are forwarded transparently in both directions. 
+ +### Output Capture + +| Scenario | stdout/stderr source | Output events | +|----------|---------------------|---------------| +| No `runInTerminal` | Captured from DAP `output` events | Logged by DCP + forwarded to IDE | +| With `runInTerminal` | Captured from process pipes by DCP | Forwarded to IDE (not logged from events) | + +--- + +## Appendix B: Relevant DCP Source Files + +These files in the `microsoft/dcp` repo implement the DCP side of the bridge protocol, for reference: + +### `internal/dap/` — Core Bridge Package + +| File | Purpose | +|------|---------| +| `internal/dap/doc.go` | Package-level documentation | +| `internal/dap/bridge.go` | Core `DapBridge` — bidirectional message forwarding with interception, sequence number remapping, inline `runInTerminal` handling (`handleRunInTerminalRequest`), and error reporting via `sendErrorToIDE()` | +| `internal/dap/bridge_handshake.go` | Length-prefixed JSON handshake protocol: `HandshakeRequest`/`HandshakeResponse` types, `HandshakeReader`/`HandshakeWriter`, `performClientHandshake()` convenience function, `maxHandshakeMessageSize` (64 KB) constant | +| `internal/dap/bridge_manager.go` | `BridgeManager` — combined session management, `PrivateUnixSocketListener` socket lifecycle, handshake processing, and bridge lifecycle. 
Contains `BridgeSession` with states (`Created`, `Connected`, `Terminated`, `Error`), session registration/rollback, and `BridgeConnectionHandler` callback type | +| `internal/dap/adapter_types.go` | `DebugAdapterConfig` struct (args, mode, env as `[]apiv1.EnvVar`, connectionTimeout as `*metav1.Duration`) and `DebugAdapterMode` constants (`stdio`, `tcp-callback`, `tcp-connect`) | +| `internal/dap/adapter_launcher.go` | `LaunchDebugAdapter()` — starts adapter processes in all 3 modes, environment filtering (`buildFilteredEnv()` removes `DEBUG_SESSION*`/`DCP_*` variables), adapter stderr capture via pipe, `LaunchedAdapter` struct with transport + process handle + done channel | +| `internal/dap/transport.go` | `Transport` interface with a single `connTransport` backing implementation shared by three factory functions: `NewTCPTransportWithContext`, `NewStdioTransportWithContext`, `NewUnixTransportWithContext`. Uses `dcpio.NewContextReader` for cancellation-aware reads | +| `internal/dap/message.go` | `RawMessage` (transparent forwarding of unknown/proprietary DAP messages), `MessageEnvelope` (uniform header access with lazy seq patching), `ReadMessageWithFallback`/`WriteMessageWithFallback`, unexported helpers `newOutputEvent`/`newTerminatedEvent` | + +### `internal/exerunners/` — Integration Points + +| File | Purpose | +|------|---------| +| `internal/exerunners/ide_executable_runner.go` | Integration point — creates `BridgeManager` when `SupportsDebugBridge()`, registers bridge sessions using `exe.UID` as session ID, includes `debug_bridge_socket_path` and `debug_session_id` in run session requests | +| `internal/exerunners/ide_requests_responses.go` | Protocol types, API version definitions (`version20260201 = "2026-02-01"`), `ideRunSessionRequestV1` with bridge fields (`DebugBridgeSocketPath`, `DebugSessionID`) | +| `internal/exerunners/ide_connection_info.go` | Version negotiation, `SupportsDebugBridge()` helper (checks `>= version20260201`) | +| 
`internal/exerunners/bridge_output_handler.go` | `bridgeOutputHandler` implementing `dap.OutputHandler` — routes DAP output events by category (`"stdout"`/`"console"` → stdout writer, `"stderr"` → stderr writer) | diff --git a/go.mod b/go.mod index add9df2b..d64bb92d 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/go-logr/logr v1.4.3 github.com/go-logr/zapr v1.3.0 github.com/google/go-cmp v0.7.0 + github.com/google/go-dap v0.12.0 github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674 github.com/joho/godotenv v1.5.1 @@ -48,9 +49,13 @@ require ( github.com/beorn7/perks v1.0.1 // indirect github.com/blang/semver/v4 v4.0.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/cilium/ebpf v0.11.0 // indirect github.com/coreos/go-semver v0.3.1 // indirect github.com/coreos/go-systemd/v22 v22.5.0 // indirect + github.com/cosiner/argv v0.1.0 // indirect + github.com/cpuguy83/go-md2man/v2 v2.0.6 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/derekparker/trie/v3 v3.2.0 // indirect github.com/ebitengine/purego v0.9.0 // indirect github.com/emicklei/go-restful/v3 v3.12.2 // indirect github.com/evanphx/json-patch/v5 v5.9.11 // indirect @@ -58,6 +63,8 @@ require ( github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/fxamacker/cbor/v2 v2.9.0 // indirect + github.com/go-delve/delve v1.26.0 // indirect + github.com/go-delve/liner v1.2.3-0.20231231155935-4726ab1d7f62 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-ole/go-ole v1.2.6 // indirect github.com/go-openapi/jsonpointer v0.21.1 // indirect @@ -81,6 +88,7 @@ require ( github.com/mailru/easyjson v0.9.0 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mattn/go-runewidth v0.0.13 // indirect github.com/modern-go/concurrent 
v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect @@ -91,6 +99,8 @@ require ( github.com/prometheus/client_model v0.6.2 // indirect github.com/prometheus/common v0.63.0 // indirect github.com/prometheus/procfs v0.16.1 // indirect + github.com/rivo/uniseg v0.2.0 // indirect + github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/stoewer/go-strcase v1.3.0 // indirect github.com/stretchr/objx v0.5.2 // indirect github.com/tklauser/numcpus v0.10.0 // indirect @@ -105,14 +115,17 @@ require ( go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.34.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.34.0 // indirect go.opentelemetry.io/proto/otlp v1.5.0 // indirect + go.starlark.net v0.0.0-20231101134539-556fd59b42f6 // indirect go.uber.org/multierr v1.11.0 // indirect go.yaml.in/yaml/v2 v2.4.2 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect + golang.org/x/arch v0.11.0 // indirect golang.org/x/crypto v0.45.0 // indirect golang.org/x/exp v0.0.0-20241217172543-b2144cdd0a67 // indirect golang.org/x/mod v0.29.0 // indirect golang.org/x/oauth2 v0.30.0 // indirect golang.org/x/sync v0.18.0 // indirect + golang.org/x/telemetry v0.0.0-20251008203120-078029d740a8 // indirect golang.org/x/term v0.37.0 // indirect golang.org/x/time v0.11.0 // indirect golang.org/x/tools v0.38.0 // indirect @@ -139,6 +152,7 @@ require ( ) tool ( + github.com/go-delve/delve/cmd/dlv github.com/josephspurrier/goversioninfo/cmd/goversioninfo google.golang.org/grpc/cmd/protoc-gen-go-grpc google.golang.org/protobuf/cmd/protoc-gen-go diff --git a/go.sum b/go.sum index 404d24f8..1a91249b 100644 --- a/go.sum +++ b/go.sum @@ -22,17 +22,26 @@ github.com/chromedp/sysutil v1.0.0/go.mod h1:kgWmDdq8fTzXYcKIBqIYvRRTnYb9aNS9moA github.com/chzyer/logex v1.2.1/go.mod h1:JLbx6lG2kDbNRFnfkgvh4eRJRPX1QCoOIWomwysCBrQ= 
github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk= github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8= +github.com/cilium/ebpf v0.11.0 h1:V8gS/bTCCjX9uUnkUFUpPsksM8n1lXBAvHcpiFk1X2Y= +github.com/cilium/ebpf v0.11.0/go.mod h1:WE7CZAnqOL2RouJ4f1uyNhqr2P4CCvXFIqdRDUgWsVs= github.com/coreos/go-semver v0.3.1 h1:yi21YpKnrx1gt5R+la8n5WgS0kCrsPp33dmEyHReZr4= github.com/coreos/go-semver v0.3.1/go.mod h1:irMmmIw/7yzSRPWryHsK7EYSg09caPQL03VsM8rvUec= github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/cosiner/argv v0.1.0 h1:BVDiEL32lwHukgJKP87btEPenzrrHUjajs/8yzaqcXg= +github.com/cosiner/argv v0.1.0/go.mod h1:EusR6TucWKX+zFgtdUsKT2Cvg45K5rtpCcWz4hK06d8= +github.com/cpuguy83/go-md2man/v2 v2.0.6 h1:XJtiaUW6dEEqVuZiMTn1ldk455QWwEIsMIJlo5vtkx0= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= +github.com/creack/pty v1.1.20 h1:VIPb/a2s17qNeQgDnkfZC35RScx+blkKF8GV68n80J4= +github.com/creack/pty v1.1.20/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davidwartell/go-onecontext v1.0.2 h1:LfnYCXKsN24jQze/vmfbXrP84AtejOQQxlpUlAenFKs= github.com/davidwartell/go-onecontext v1.0.2/go.mod h1:pIqzkTZw5tV74x9mRCH/u9GtyiufWx2WKzLWArQt06I= +github.com/derekparker/trie/v3 v3.2.0 h1:fET3Qbp9xSB7yc7tz6Y2GKMNl0SycYFo3cmiRI3Gpf0= +github.com/derekparker/trie/v3 v3.2.0/go.mod h1:P94lW0LPgiaMgKAEQD59IDZD2jMK9paKok8Nli/nQbE= 
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/ebitengine/purego v0.9.0 h1:mh0zpKBIXDceC63hpvPuGLiJ8ZAa3DfrFTudmfi8A4k= @@ -51,10 +60,16 @@ github.com/felixge/fgprof v0.9.5 h1:8+vR6yu2vvSKn08urWyEuxx75NWPEvybbkBirEpsbVY= github.com/felixge/fgprof v0.9.5/go.mod h1:yKl+ERSa++RYOs32d8K6WEXCB4uXdLls4ZaZPpayhMM= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= +github.com/frankban/quicktest v1.14.5 h1:dfYrrRyLtiqT9GyKXgdh+k4inNeTvmGbuSgZ3lx3GhA= +github.com/frankban/quicktest v1.14.5/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= github.com/fxamacker/cbor/v2 v2.9.0 h1:NpKPmjDBgUfBms6tr6JZkTHtfFGcMKsw3eGcmD/sapM= github.com/fxamacker/cbor/v2 v2.9.0/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ= +github.com/go-delve/delve v1.26.0 h1:YZT1kXD76mxba4/wr+tyUa/tSmy7qzoDsmxutT42PIs= +github.com/go-delve/delve v1.26.0/go.mod h1:8BgFFOXTi1y1M+d/4ax1LdFw0mlqezQiTZQpbpwgBxo= +github.com/go-delve/liner v1.2.3-0.20231231155935-4726ab1d7f62 h1:IGtvsNyIuRjl04XAOFGACozgUD7A82UffYxZt4DWbvA= +github.com/go-delve/liner v1.2.3-0.20231231155935-4726ab1d7f62/go.mod h1:biJCRbqp51wS+I92HMqn5H8/A0PAhxn2vyOT+JqhiGI= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= @@ -93,6 +108,8 @@ github.com/google/gnostic-models v0.7.0/go.mod h1:whL5G0m6dmc5cPxKc5bdKdEN3UjI7O github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= 
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/go-dap v0.12.0 h1:rVcjv3SyMIrpaOoTAdFDyHs99CwVOItIJGKLQFQhNeM= +github.com/google/go-dap v0.12.0/go.mod h1:tNjCASCm5cqePi/RVXXWEVqtnNLV1KTWtYOqu6rZNzc= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= @@ -145,6 +162,9 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-runewidth v0.0.3/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU= +github.com/mattn/go-runewidth v0.0.13 h1:lTGmDsbAYt5DmK6OnoV7EuIF1wEIFAcxld6ypU4OSgU= +github.com/mattn/go-runewidth v0.0.13/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -177,8 +197,11 @@ github.com/prometheus/common v0.63.0 h1:YR/EIY1o3mEFP/kZCD7iDMnLPlGyuU2Gb3HIcXnA github.com/prometheus/common v0.63.0/go.mod h1:VVFF/fBIoToEnWRVkYoXEkq3R3paCoxG9PXP74SnV18= github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg= github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is= +github.com/rivo/uniseg v0.2.0 
h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= +github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= +github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/shirou/gopsutil/v4 v4.25.10 h1:at8lk/5T1OgtuCp+AwrDofFRjnvosn0nkN2OLQ6g8tA= github.com/shirou/gopsutil/v4 v4.25.10/go.mod h1:+kSwyC8DRUD9XXEHCAFjK+0nuArFJM0lva+StQAcskM= @@ -261,6 +284,8 @@ go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJr go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs= go.opentelemetry.io/proto/otlp v1.5.0 h1:xJvq7gMzB31/d406fB8U5CBdyQGw4P399D1aQWU/3i4= go.opentelemetry.io/proto/otlp v1.5.0/go.mod h1:keN8WnHxOy8PG0rQZjJJ5A2ebUoafqWp0eVQ4yIXvJ4= +go.starlark.net v0.0.0-20231101134539-556fd59b42f6 h1:+eC0F/k4aBLC4szgOcjd7bDTEnpxADJyWJE0yowgM3E= +go.starlark.net v0.0.0-20231101134539-556fd59b42f6/go.mod h1:LcLNIzVOMp4oV+uusnpk+VU+SzXaJakUuBjoCSWH5dM= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= @@ -271,6 +296,8 @@ go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI= go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU= go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= +golang.org/x/arch v0.11.0 h1:KXV8WWKCXm6tRpLirl2szsO5j/oOODwZf4hATmGVNs4= +golang.org/x/arch v0.11.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= 
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= @@ -300,11 +327,14 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20211117180635-dee7805ff2e1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/telemetry v0.0.0-20251008203120-078029d740a8 h1:LvzTn0GQhWuvKH/kVRS3R3bVAsdQWI7hvfLHGgh9+lU= +golang.org/x/telemetry v0.0.0-20251008203120-078029d740a8/go.mod h1:Pi4ztBfryZoJEkyFTI5/Ocsu2jXyDr6iSdgJiYE/uwE= golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU= golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/internal/commands/monitor.go b/internal/commands/monitor.go index 9ca16389..3bc99c8a 100644 --- a/internal/commands/monitor.go +++ b/internal/commands/monitor.go @@ -29,22 +29,21 @@ func AddMonitorFlags(cmd *cobra.Command) 
{ cmd.Flags().Uint8VarP(&monitorInterval, "monitor-interval", "i", 0, "If present, specifies the time in seconds between checks for the monitor PID.") } -// Starts monitoring a process with a given PID and (optional) start time. +// Starts monitoring a process identified by the given handle. // Returns a context that will be cancelled when the monitored process exits, or if the returned cancellation function is called. // The returned context (and the cancellation function) is valid even if an error occurs (e.g. the process cannot be found), // but it will be already cancelled in that case. func MonitorPid( ctx context.Context, - pid process.Pid_t, - expectedProcessStartTime time.Time, + handle process.ProcessHandle, pollInterval uint8, logger logr.Logger, ) (context.Context, context.CancelFunc, error) { monitorCtx, monitorCtxCancel := context.WithCancel(ctx) - monitorProc, monitorProcErr := process.FindWaitableProcess(pid, expectedProcessStartTime) + monitorProc, monitorProcErr := process.FindWaitableProcess(handle) if monitorProcErr != nil { - logger.Info("Error finding process", "PID", pid) + logger.Info("Error finding process", "PID", handle.Pid) monitorCtxCancel() return monitorCtx, monitorCtxCancel, monitorProcErr } @@ -57,12 +56,12 @@ func MonitorPid( defer monitorCtxCancel() if waitErr := monitorProc.Wait(monitorCtx); waitErr != nil { if errors.Is(waitErr, context.Canceled) { - logger.V(1).Info("Monitoring cancelled by context", "PID", pid) + logger.V(1).Info("Monitoring cancelled by context", "PID", handle.Pid) } else { - logger.Error(waitErr, "Error waiting for process", "PID", pid) + logger.Error(waitErr, "Error waiting for process", "PID", handle.Pid) } } else { - logger.Info("Monitor process exited, shutting down", "PID", pid) + logger.Info("Monitor process exited, shutting down", "PID", handle.Pid) } }() @@ -83,6 +82,6 @@ func GetMonitorContextFromFlags(ctx context.Context, logger logr.Logger) (contex } // Ignore errors as they're logged by MonitorPid 
and we always return a valid context - monitorCtx, monitorCtxCancel, _ := MonitorPid(ctx, monitorPid, monitorProcessStartTime, monitorInterval, logger) + monitorCtx, monitorCtxCancel, _ := MonitorPid(ctx, process.NewHandle(monitorPid, monitorProcessStartTime), monitorInterval, logger) return monitorCtx, monitorCtxCancel } diff --git a/internal/contextdata/contextdata.go b/internal/contextdata/contextdata.go index fd950045..c002eea8 100644 --- a/internal/contextdata/contextdata.go +++ b/internal/contextdata/contextdata.go @@ -9,7 +9,6 @@ import ( "context" "fmt" "os/exec" - "time" "github.com/go-logr/logr" "github.com/microsoft/dcp/pkg/process" @@ -58,16 +57,16 @@ func GetProcessExecutor(ctx context.Context) process.Executor { type dummyProcessExecutor struct{} -func (*dummyProcessExecutor) StartProcess(_ context.Context, _ *exec.Cmd, _ process.ProcessExitHandler, _ process.ProcessCreationFlag) (process.Pid_t, time.Time, func(), error) { - return process.UnknownPID, time.Time{}, nil, fmt.Errorf("there is no process executor configured, no processes can be started") +func (*dummyProcessExecutor) StartProcess(_ context.Context, _ *exec.Cmd, _ process.ProcessExitHandler, _ process.ProcessCreationFlag) (process.ProcessHandle, func(), error) { + return process.ProcessHandle{Pid: process.UnknownPID}, nil, fmt.Errorf("there is no process executor configured, no processes can be started") } -func (*dummyProcessExecutor) StopProcess(_ process.Pid_t, _ time.Time) error { +func (*dummyProcessExecutor) StopProcess(_ process.ProcessHandle) error { return fmt.Errorf("there is no process executor configured, no processes can be stopped") } -func (*dummyProcessExecutor) StartAndForget(_ *exec.Cmd, _ process.ProcessCreationFlag) (process.Pid_t, time.Time, error) { - return process.UnknownPID, time.Time{}, fmt.Errorf("there is no process executor configured, no processes can be started") +func (*dummyProcessExecutor) StartAndForget(_ *exec.Cmd, _ process.ProcessCreationFlag) 
(process.ProcessHandle, error) { + return process.ProcessHandle{Pid: process.UnknownPID}, fmt.Errorf("there is no process executor configured, no processes can be started") } func (*dummyProcessExecutor) Dispose() { diff --git a/internal/dap/adapter_launcher.go b/internal/dap/adapter_launcher.go new file mode 100644 index 00000000..f8d94a19 --- /dev/null +++ b/internal/dap/adapter_launcher.go @@ -0,0 +1,444 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See LICENSE in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package dap + +import ( + "context" + "errors" + "fmt" + "net" + "os/exec" + "strconv" + "strings" + "sync" + "time" + + apiv1 "github.com/microsoft/dcp/api/v1" + "github.com/microsoft/dcp/internal/dcpproc" + "github.com/microsoft/dcp/internal/networking" + "github.com/microsoft/dcp/pkg/maps" + "github.com/microsoft/dcp/pkg/osutil" + "github.com/microsoft/dcp/pkg/process" + + "github.com/go-logr/logr" +) + +// PortPlaceholder is the placeholder in adapter args that will be replaced with allocated port. +const PortPlaceholder = "{{port}}" + +// ErrInvalidAdapterConfig is returned when the debug adapter configuration is invalid. +var ErrInvalidAdapterConfig = errors.New("invalid debug adapter configuration: Args must have at least one element") + +// ErrAdapterConnectionTimeout is returned when the adapter fails to connect within the timeout. +var ErrAdapterConnectionTimeout = errors.New("debug adapter connection timeout") + +// LaunchedAdapter represents a running debug adapter process with its transport. +type LaunchedAdapter struct { + // Transport provides DAP message I/O with the debug adapter. + Transport Transport + + // handle identifies the debug adapter process. 
+	handle process.ProcessHandle
+
+	// listener is the TCP listener for callback mode (nil for other modes).
+	listener net.Listener
+
+	// done signals when the process has exited.
+	done chan struct{}
+
+	// exitCode contains the process exit code (if any).
+	exitCode int32
+
+	// exitErr contains the process exit error (if any).
+	exitErr error
+
+	// mu protects exitCode and exitErr.
+	mu *sync.Mutex
+}
+
+// Pid returns the process ID of the debug adapter.
+func (la *LaunchedAdapter) Pid() process.Pid_t {
+	return la.handle.Pid
+}
+
+// Done returns a channel that is closed when the debug adapter process exits.
+func (la *LaunchedAdapter) Done() <-chan struct{} {
+	return la.done
+}
+
+// Close cleans up the adapter resources.
+// This closes the transport and listener, but does NOT stop the process.
+// The process is stopped automatically when the context passed to LaunchDebugAdapter is cancelled.
+func (la *LaunchedAdapter) Close() error {
+	var errs []error
+	if la.listener != nil {
+		if err := la.listener.Close(); err != nil {
+			errs = append(errs, err)
+		}
+	}
+	if la.Transport != nil {
+		if err := la.Transport.Close(); err != nil {
+			errs = append(errs, err)
+		}
+	}
+	return errors.Join(errs...)
+}
+
+// LaunchDebugAdapter launches a debug adapter process using the provided configuration.
+// The process lifetime is tied to the provided context - when the context is cancelled,
+// the process will be killed by the executor.
+//
+// The returned LaunchedAdapter provides:
+// - Transport: for DAP message I/O with the adapter
+// - Close(): to release the transport and listener (the process is stopped via context cancellation)
+// - Done(): a channel that closes when the process exits
+// - Pid(): the process ID
+//
+// The caller must close the Transport when done.
+func LaunchDebugAdapter(ctx context.Context, executor process.Executor, config *DebugAdapterConfig, log logr.Logger) (*LaunchedAdapter, error) { + if config == nil || len(config.Args) == 0 { + return nil, ErrInvalidAdapterConfig + } + + switch config.EffectiveMode() { + case DebugAdapterModeStdio: + return launchStdioAdapter(ctx, executor, config, log) + case DebugAdapterModeTCPCallback: + return launchTCPCallbackAdapter(ctx, executor, config, log) + case DebugAdapterModeTCPConnect: + return launchTCPConnectAdapter(ctx, executor, config, log) + default: + return launchStdioAdapter(ctx, executor, config, log) + } +} + +// launchStdioAdapter launches an adapter in stdio mode. +func launchStdioAdapter(ctx context.Context, executor process.Executor, config *DebugAdapterConfig, log logr.Logger) (*LaunchedAdapter, error) { + cmd := exec.Command(config.Args[0], config.Args[1:]...) + cmd.Env = buildFilteredEnv(config) + + stdin, stdinErr := cmd.StdinPipe() + if stdinErr != nil { + return nil, fmt.Errorf("failed to create stdin pipe: %w", stdinErr) + } + + stdout, stdoutErr := cmd.StdoutPipe() + if stdoutErr != nil { + stdin.Close() + return nil, fmt.Errorf("failed to create stdout pipe: %w", stdoutErr) + } + + stderr, stderrErr := cmd.StderrPipe() + if stderrErr != nil { + stdin.Close() + stdout.Close() + return nil, fmt.Errorf("failed to create stderr pipe: %w", stderrErr) + } + + adapter := &LaunchedAdapter{ + mu: &sync.Mutex{}, + done: make(chan struct{}), + exitCode: process.UnknownExitCode, + } + + exitHandler := process.ProcessExitHandlerFunc(func(pid process.Pid_t, exitCode int32, err error) { + adapter.mu.Lock() + adapter.exitCode = exitCode + adapter.exitErr = err + adapter.mu.Unlock() + close(adapter.done) + + if err != nil { + log.V(1).Info("Debug adapter process exited with error", + "pid", pid, + "exitCode", exitCode, + "error", err) + } else { + log.V(1).Info("Debug adapter process exited", + "pid", pid, + "exitCode", exitCode) + } + }) + + handle, 
startWaitForExit, startErr := executor.StartProcess(ctx, cmd, exitHandler, process.CreationFlagEnsureKillOnDispose) + if startErr != nil { + stdin.Close() + stdout.Close() + stderr.Close() + return nil, fmt.Errorf("failed to start debug adapter: %w", startErr) + } + + // Start process monitor to ensure cleanup if DCP crashes + dcpproc.RunProcessWatcher(executor, handle, log) + + // Start waiting for process exit + startWaitForExit() + + go logStderr(stderr, log) + + log.Info("Launched debug adapter process (stdio mode)", + "command", config.Args[0], + "args", config.Args[1:], + "pid", handle.Pid) + + adapter.Transport = NewStdioTransportWithContext(ctx, stdout, stdin) + adapter.handle = handle + + return adapter, nil +} + +// launchTCPCallbackAdapter launches an adapter in TCP callback mode. +// We start a listener and the adapter connects to us. +func launchTCPCallbackAdapter(ctx context.Context, executor process.Executor, config *DebugAdapterConfig, log logr.Logger) (*LaunchedAdapter, error) { + // Start a listener on a free port + listener, listenErr := net.Listen("tcp", networking.AddressAndPort(networking.IPv4LocalhostDefaultAddress, 0)) + if listenErr != nil { + return nil, fmt.Errorf("failed to create listener: %w", listenErr) + } + + listenerAddr := listener.Addr().String() + log.Info("Listening for debug adapter callback", "address", listenerAddr) + + // Substitute {{port}} placeholder with our listening port + _, portStr, _ := net.SplitHostPort(listenerAddr) + args := substitutePort(config.Args, portStr) + + cmd := exec.Command(args[0], args[1:]...) 
+ cmd.Env = buildFilteredEnv(config) + + stderr, stderrErr := cmd.StderrPipe() + if stderrErr != nil { + listener.Close() + return nil, fmt.Errorf("failed to create stderr pipe: %w", stderrErr) + } + + adapter := &LaunchedAdapter{ + mu: &sync.Mutex{}, + listener: listener, + done: make(chan struct{}), + exitCode: process.UnknownExitCode, + } + + exitHandler := process.ProcessExitHandlerFunc(func(pid process.Pid_t, exitCode int32, err error) { + adapter.mu.Lock() + adapter.exitCode = exitCode + adapter.exitErr = err + adapter.mu.Unlock() + close(adapter.done) + + if err != nil { + log.V(1).Info("Debug adapter process exited with error", + "pid", pid, + "exitCode", exitCode, + "error", err) + } else { + log.V(1).Info("Debug adapter process exited", + "pid", pid, + "exitCode", exitCode) + } + }) + + handle, startWaitForExit, startErr := executor.StartProcess(ctx, cmd, exitHandler, process.CreationFlagEnsureKillOnDispose) + if startErr != nil { + listener.Close() + stderr.Close() + return nil, fmt.Errorf("failed to start debug adapter: %w", startErr) + } + + // Start process monitor to ensure cleanup if DCP crashes + dcpproc.RunProcessWatcher(executor, handle, log) + + // Start waiting for process exit + startWaitForExit() + + go logStderr(stderr, log) + + log.Info("Launched debug adapter process (tcp-callback mode)", + "command", args[0], + "args", args[1:], + "pid", handle.Pid, + "listenAddress", listenerAddr) + + adapter.handle = handle + + // Wait for adapter to connect + timeout := config.GetConnectionTimeout() + + connCh := make(chan net.Conn, 1) + errCh := make(chan error, 1) + go func() { + conn, acceptErr := listener.Accept() + if acceptErr != nil { + errCh <- acceptErr + return + } + connCh <- conn + }() + + var conn net.Conn + select { + case conn = <-connCh: + log.Info("Debug adapter connected", "remoteAddr", conn.RemoteAddr().String()) + case acceptErr := <-errCh: + _ = executor.StopProcess(adapter.handle) + listener.Close() + return nil, 
fmt.Errorf("failed to accept adapter connection: %w", acceptErr) + case <-time.After(timeout): + _ = executor.StopProcess(adapter.handle) + listener.Close() + return nil, ErrAdapterConnectionTimeout + case <-ctx.Done(): + // Executor will handle stopping the process when context is cancelled + listener.Close() + return nil, ctx.Err() + } + + adapter.Transport = NewTCPTransportWithContext(ctx, conn) + return adapter, nil +} + +// launchTCPConnectAdapter launches an adapter in TCP connect mode. +// The adapter listens on a port and we connect to it. +func launchTCPConnectAdapter(ctx context.Context, executor process.Executor, config *DebugAdapterConfig, log logr.Logger) (*LaunchedAdapter, error) { + // Allocate a free port for the adapter + port, portErr := networking.GetFreePort(apiv1.TCP, networking.IPv4LocalhostDefaultAddress, log) + if portErr != nil { + return nil, fmt.Errorf("failed to allocate port: %w", portErr) + } + + portStr := strconv.Itoa(int(port)) + args := substitutePort(config.Args, portStr) + + cmd := exec.Command(args[0], args[1:]...) 
+ cmd.Env = buildFilteredEnv(config) + + stderr, stderrErr := cmd.StderrPipe() + if stderrErr != nil { + return nil, fmt.Errorf("failed to create stderr pipe: %w", stderrErr) + } + + adapter := &LaunchedAdapter{ + mu: &sync.Mutex{}, + done: make(chan struct{}), + exitCode: process.UnknownExitCode, + } + + exitHandler := process.ProcessExitHandlerFunc(func(pid process.Pid_t, exitCode int32, err error) { + adapter.mu.Lock() + adapter.exitCode = exitCode + adapter.exitErr = err + adapter.mu.Unlock() + close(adapter.done) + + if err != nil { + log.V(1).Info("Debug adapter process exited with error", + "pid", pid, + "exitCode", exitCode, + "error", err) + } else { + log.V(1).Info("Debug adapter process exited", + "pid", pid, + "exitCode", exitCode) + } + }) + + handle, startWaitForExit, startErr := executor.StartProcess(ctx, cmd, exitHandler, process.CreationFlagEnsureKillOnDispose) + if startErr != nil { + stderr.Close() + return nil, fmt.Errorf("failed to start debug adapter: %w", startErr) + } + + // Start process monitor to ensure cleanup if DCP crashes + dcpproc.RunProcessWatcher(executor, handle, log) + + // Start waiting for process exit + startWaitForExit() + + go logStderr(stderr, log) + + log.Info("Launched debug adapter process (tcp-connect mode)", + "command", args[0], + "args", args[1:], + "pid", handle.Pid, + "port", port) + + adapter.handle = handle + + // Connect to the adapter with retry + timeout := config.GetConnectionTimeout() + + addr := fmt.Sprintf("127.0.0.1:%d", port) + var conn net.Conn + var connectErr error + deadline := time.Now().Add(timeout) + + for time.Now().Before(deadline) { + select { + case <-ctx.Done(): + // Executor will handle stopping the process when context is cancelled + return nil, ctx.Err() + case <-adapter.done: + // Process exited before we could connect + return nil, fmt.Errorf("debug adapter process exited before connection could be established") + default: + } + + conn, connectErr = net.DialTimeout("tcp", addr, 
time.Second) + if connectErr == nil { + break + } + time.Sleep(100 * time.Millisecond) + } + + if connectErr != nil { + _ = executor.StopProcess(adapter.handle) + return nil, fmt.Errorf("%w: failed to connect to adapter at %s: %v", ErrAdapterConnectionTimeout, addr, connectErr) + } + + log.Info("Connected to debug adapter", "address", addr) + + adapter.Transport = NewTCPTransportWithContext(ctx, conn) + return adapter, nil +} + +// substitutePort replaces {{port}} placeholder in args with the actual port. +func substitutePort(args []string, port string) []string { + result := make([]string, len(args)) + for i, arg := range args { + result[i] = strings.ReplaceAll(arg, PortPlaceholder, port) + } + return result +} + +// buildFilteredEnv builds the environment for the adapter process by inheriting +// the ambient (current process) environment, removing variables with suppressed +// prefixes (DCP_ and DEBUG_SESSION), and then applying the config-specified +// environment variables on top. +func buildFilteredEnv(config *DebugAdapterConfig) []string { + envMap := osutil.NewFilteredAmbientEnv() + + for _, e := range config.Env { + envMap.Override(e.Name, e.Value) + } + + return maps.MapToSlice[string](envMap.Data(), func(key string, value string) string { + return key + "=" + value + }) +} + +// logStderr reads and logs stderr from the adapter. +func logStderr(stderr interface{ Read([]byte) (int, error) }, log logr.Logger) { + buf := make([]byte, 1024) + for { + n, readErr := stderr.Read(buf) + if n > 0 { + log.Info("Debug adapter stderr", "output", string(buf[:n])) + } + if readErr != nil { + return + } + } +} diff --git a/internal/dap/adapter_launcher_test.go b/internal/dap/adapter_launcher_test.go new file mode 100644 index 00000000..517cc3e7 --- /dev/null +++ b/internal/dap/adapter_launcher_test.go @@ -0,0 +1,120 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. 
All rights reserved. + * Licensed under the MIT License. See LICENSE in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package dap + +import ( + "os" + "strings" + "testing" + + apiv1 "github.com/microsoft/dcp/api/v1" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBuildFilteredEnv_SuppressesDCPPrefix(t *testing.T) { + t.Setenv("DCP_TEST_VAR", "should-be-removed") + t.Setenv("DCP_ANOTHER", "also-removed") + + config := &DebugAdapterConfig{} + env := buildFilteredEnv(config) + + envMap := sliceToEnvMap(env) + assert.NotContains(t, envMap, "DCP_TEST_VAR") + assert.NotContains(t, envMap, "DCP_ANOTHER") +} + +func TestBuildFilteredEnv_SuppressesDebugSessionPrefix(t *testing.T) { + t.Setenv("DEBUG_SESSION_ID", "should-be-removed") + t.Setenv("DEBUG_SESSION_TOKEN", "also-removed") + + config := &DebugAdapterConfig{} + env := buildFilteredEnv(config) + + envMap := sliceToEnvMap(env) + assert.NotContains(t, envMap, "DEBUG_SESSION_ID") + assert.NotContains(t, envMap, "DEBUG_SESSION_TOKEN") +} + +func TestBuildFilteredEnv_InheritsNonSuppressedVars(t *testing.T) { + t.Setenv("MY_APP_VAR", "keep-this") + + config := &DebugAdapterConfig{} + env := buildFilteredEnv(config) + + envMap := sliceToEnvMap(env) + assert.Equal(t, "keep-this", envMap["MY_APP_VAR"]) +} + +func TestBuildFilteredEnv_ConfigEnvVarsAreApplied(t *testing.T) { + config := &DebugAdapterConfig{ + Env: []apiv1.EnvVar{ + {Name: "CUSTOM_VAR", Value: "custom-value"}, + {Name: "ANOTHER_VAR", Value: "another-value"}, + }, + } + env := buildFilteredEnv(config) + + envMap := sliceToEnvMap(env) + assert.Equal(t, "custom-value", envMap["CUSTOM_VAR"]) + assert.Equal(t, "another-value", envMap["ANOTHER_VAR"]) +} + +func TestBuildFilteredEnv_ConfigOverridesAmbient(t *testing.T) { + t.Setenv("OVERRIDE_ME", "original") + + config := &DebugAdapterConfig{ + Env: []apiv1.EnvVar{ + {Name: 
"OVERRIDE_ME", Value: "overridden"}, + }, + } + env := buildFilteredEnv(config) + + envMap := sliceToEnvMap(env) + assert.Equal(t, "overridden", envMap["OVERRIDE_ME"]) +} + +func TestBuildFilteredEnv_ConfigCanSetSuppressedPrefixVars(t *testing.T) { + // Even though DCP_ vars are suppressed from the ambient environment, + // the config should be able to explicitly set them. + t.Setenv("DCP_AMBIENT", "should-be-removed") + + config := &DebugAdapterConfig{ + Env: []apiv1.EnvVar{ + {Name: "DCP_EXPLICIT", Value: "explicitly-set"}, + }, + } + env := buildFilteredEnv(config) + + envMap := sliceToEnvMap(env) + assert.NotContains(t, envMap, "DCP_AMBIENT") + assert.Equal(t, "explicitly-set", envMap["DCP_EXPLICIT"]) +} + +func TestBuildFilteredEnv_InheritsPath(t *testing.T) { + // PATH should be inherited since it doesn't match any suppressed prefix. + pathVal := os.Getenv("PATH") + require.NotEmpty(t, pathVal, "PATH should be set in the test environment") + + config := &DebugAdapterConfig{} + env := buildFilteredEnv(config) + + envMap := sliceToEnvMap(env) + assert.Equal(t, pathVal, envMap["PATH"]) +} + +// sliceToEnvMap converts a []string of "KEY=VALUE" entries to a map. +func sliceToEnvMap(envSlice []string) map[string]string { + result := make(map[string]string, len(envSlice)) + for _, entry := range envSlice { + parts := strings.SplitN(entry, "=", 2) + if len(parts) == 2 { + result[parts[0]] = parts[1] + } + } + return result +} diff --git a/internal/dap/adapter_types.go b/internal/dap/adapter_types.go new file mode 100644 index 00000000..1e277290 --- /dev/null +++ b/internal/dap/adapter_types.go @@ -0,0 +1,74 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See LICENSE in the project root for license information. 
+ *--------------------------------------------------------------------------------------------*/ + +package dap + +import ( + "time" + + apiv1 "github.com/microsoft/dcp/api/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +// DefaultAdapterConnectionTimeout is the default timeout for connecting to the debug adapter. +const DefaultAdapterConnectionTimeout = 10 * time.Second + +// DebugAdapterMode specifies how the debug adapter communicates. +type DebugAdapterMode string + +const ( + // DebugAdapterModeStdio indicates the adapter uses stdin/stdout for DAP communication. + DebugAdapterModeStdio DebugAdapterMode = "stdio" + + // DebugAdapterModeTCPCallback indicates we start a listener and adapter connects to us. + // Pass our address to the adapter via --client-addr or similar. + DebugAdapterModeTCPCallback DebugAdapterMode = "tcp-callback" + + // DebugAdapterModeTCPConnect indicates we specify a port, adapter listens, we connect. + // Use {{port}} placeholder in args which is replaced with allocated port. + DebugAdapterModeTCPConnect DebugAdapterMode = "tcp-connect" +) + +// DebugAdapterConfig holds the configuration for launching a debug adapter. +// It is sent as part of the handshake request from the IDE and used internally +// to launch the adapter process. +type DebugAdapterConfig struct { + // Args contains the command and arguments to launch the debug adapter. + // The first element is the executable path, subsequent elements are arguments. + // May contain "{{port}}" placeholder for TCP modes. + Args []string `json:"args"` + + // Mode specifies how the adapter communicates. + // Valid values: "stdio" (default), "tcp-callback", "tcp-connect". + // An empty string is treated as "stdio". + Mode DebugAdapterMode `json:"mode,omitempty"` + + // Env contains environment variables to set for the adapter process. + Env []apiv1.EnvVar `json:"env,omitempty"` + + // ConnectionTimeout is the timeout for connecting to the adapter in TCP modes. 
+ // If nil or zero, DefaultAdapterConnectionTimeout is used. + ConnectionTimeout *metav1.Duration `json:"connectionTimeout,omitempty"` +} + +// GetConnectionTimeout returns the connection timeout as a time.Duration. +// If ConnectionTimeout is nil or non-positive, DefaultAdapterConnectionTimeout is returned. +func (c *DebugAdapterConfig) GetConnectionTimeout() time.Duration { + if c.ConnectionTimeout != nil && c.ConnectionTimeout.Duration > 0 { + return c.ConnectionTimeout.Duration + } + return DefaultAdapterConnectionTimeout +} + +// EffectiveMode returns the adapter mode, defaulting to DebugAdapterModeStdio +// if Mode is empty or unrecognized. +func (c *DebugAdapterConfig) EffectiveMode() DebugAdapterMode { + switch c.Mode { + case DebugAdapterModeStdio, DebugAdapterModeTCPCallback, DebugAdapterModeTCPConnect: + return c.Mode + default: + return DebugAdapterModeStdio + } +} diff --git a/internal/dap/bridge.go b/internal/dap/bridge.go new file mode 100644 index 00000000..4528d1a6 --- /dev/null +++ b/internal/dap/bridge.go @@ -0,0 +1,571 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See LICENSE in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package dap + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "os/exec" + "sync" + "sync/atomic" + + "github.com/go-logr/logr" + "github.com/google/go-dap" + "github.com/microsoft/dcp/pkg/process" +) + +// BridgeConfig contains configuration for creating a DapBridge. +type BridgeConfig struct { + // SessionID is the session identifier for this bridge. + SessionID string + + // AdapterConfig contains the configuration for launching the debug adapter. 
+	AdapterConfig *DebugAdapterConfig
+
+	// Executor is the process executor for managing debug adapter processes.
+	// If nil, a new OS executor will be created for this purpose.
+	Executor process.Executor
+
+	// Logger for bridge operations.
+	Logger logr.Logger
+
+	// OutputHandler is called when output events are received from the debug adapter,
+	// unless runInTerminal was used (in which case output is captured directly from the debuggee
+	// process). If nil, output events are only forwarded without additional processing.
+	OutputHandler OutputHandler
+
+	// StdoutWriter is where debuggee process stdout (from runInTerminal) will be written.
+	// If nil, stdout is discarded.
+	StdoutWriter io.Writer
+
+	// StderrWriter is where debuggee process stderr (from runInTerminal) will be written.
+	// If nil, stderr is discarded.
+	StderrWriter io.Writer
+}
+
+// OutputHandler is called when output events are received from the debug adapter.
+type OutputHandler interface {
+	// HandleOutput is called for each output event.
+	// category is "stdout", "stderr", "console", etc.
+	// output is the actual output text.
+	HandleOutput(category string, output string)
+}
+
+// DapBridge provides a bridge between an IDE's debug adapter client and a debug adapter host.
+// It can either listen on a Unix domain socket for the IDE to connect (via Run),
+// or accept an already-connected connection (via RunWithConnection).
type DapBridge struct {
	config   BridgeConfig
	executor process.Executor
	log      logr.Logger

	// ideTransport is the transport to the IDE
	ideTransport Transport

	// adapter is the launched debug adapter
	adapter *LaunchedAdapter

	// runInTerminalUsed tracks whether runInTerminal was invoked
	runInTerminalUsed atomic.Bool

	// terminatedEventSeen tracks whether the adapter sent a TerminatedEvent
	terminatedEventSeen atomic.Bool

	// terminateCh is closed when the bridge terminates
	terminateCh chan struct{}

	// terminateOnce ensures terminateCh is closed only once
	terminateOnce sync.Once

	// adapterPipe is the FIFO message pipe for messages sent to the debug adapter.
	// It assigns monotonically increasing sequence numbers at write time and
	// maintains a seqMap of virtualSeq→originalIDESeq for response correlation.
	adapterPipe *MessagePipe

	// idePipe is the FIFO message pipe for messages sent to the IDE.
	// It assigns monotonically increasing sequence numbers at write time.
	idePipe *MessagePipe

	// fallbackIDESeqCounter is used for IDE-bound seq assignment when idePipe
	// has not yet been created (e.g., adapter launch failure before message loop).
	fallbackIDESeqCounter atomic.Int64
}

// NewDapBridge creates a new DAP bridge with the given configuration.
func NewDapBridge(config BridgeConfig) *DapBridge {
	// Default to a no-op logger when the caller did not supply one.
	log := config.Logger
	if log.GetSink() == nil {
		log = logr.Discard()
	}

	// Default to the real OS process executor when none was injected
	// (tests may provide a fake executor through BridgeConfig).
	executor := config.Executor
	if executor == nil {
		executor = process.NewOSExecutor(log)
	}

	return &DapBridge{
		config:      config,
		executor:    executor,
		log:         log,
		terminateCh: make(chan struct{}),
	}
}

// RunWithConnection runs the bridge with an already-connected IDE connection.
// This is the main entry point when using BridgeManager.
// The handshake must have already been performed by the caller.
//
// The bridge will:
//  1. Launch the debug adapter using the provided config
//  2. Forward DAP messages bidirectionally
//  3. Terminate when the context is cancelled or errors occur
func (b *DapBridge) RunWithConnection(ctx context.Context, ideConn net.Conn) error {
	return b.runWithConnectionAndConfig(ctx, ideConn, b.config.AdapterConfig)
}

// runWithConnectionAndConfig is the internal implementation that accepts an adapter config.
// It owns the lifetime of the adapter process and always signals termination on exit.
func (b *DapBridge) runWithConnectionAndConfig(ctx context.Context, ideConn net.Conn, adapterConfig *DebugAdapterConfig) error {
	defer b.terminate()

	b.log.Info("Bridge starting with pre-connected IDE", "sessionID", b.config.SessionID)

	// Create transport for IDE connection
	b.ideTransport = NewUnixTransportWithContext(ctx, ideConn)

	// Launch debug adapter. On failure the IDE is notified directly via the
	// transport (the message pipes do not exist yet at this point).
	b.log.V(1).Info("Launching debug adapter")
	launchErr := b.launchAdapterWithConfig(ctx, adapterConfig)
	if launchErr != nil {
		b.sendErrorToIDE(fmt.Sprintf("Failed to launch debug adapter: %v", launchErr))
		return fmt.Errorf("failed to launch debug adapter: %w", launchErr)
	}
	defer b.adapter.Close()

	b.log.Info("Debug adapter launched", "pid", b.adapter.Pid())

	// Start message forwarding
	b.log.V(1).Info("Bridge connected, starting message loop")
	return b.runMessageLoop(ctx)
}

// launchAdapterWithConfig launches the debug adapter with the specified config.
func (b *DapBridge) launchAdapterWithConfig(ctx context.Context, config *DebugAdapterConfig) error {
	var launchErr error
	b.adapter, launchErr = LaunchDebugAdapter(ctx, b.executor, config, b.log)
	return launchErr
}

// runMessageLoop runs the bidirectional message forwarding loop.
func (b *DapBridge) runMessageLoop(ctx context.Context) error {
	// Create an independent context for the pipes. This must NOT be derived
	// from ctx because the ordered shutdown sequence needs the pipes to
	// remain alive after ctx is cancelled so that queued messages (including
	// shutdown events) can drain. The normal shutdown path uses CloseInput
	// on each pipe for a graceful drain; pipeCancel is a fallback safety net.
	pipeCtx, pipeCancel := context.WithCancel(context.Background())
	defer pipeCancel()

	// Create message pipes for both directions.
	b.adapterPipe = NewMessagePipe(pipeCtx, b.adapter.Transport, "adapterPipe", b.log)
	b.idePipe = NewMessagePipe(pipeCtx, b.ideTransport, "idePipe", b.log)

	// Track each goroutine independently so the shutdown sequence can
	// wait for specific goroutines in the correct order.
	var (
		adapterPipeResult   error
		idePipeResult       error
		adapterReaderResult error
		ideReaderResult     error
	)

	adapterPipeDone := make(chan struct{})
	idePipeDone := make(chan struct{})
	adapterReaderDone := make(chan struct{})
	ideReaderDone := make(chan struct{})

	// errCh collects the first error for the initial select trigger.
	// Buffered to 4 so all goroutines can report without blocking.
	errCh := make(chan error, 4)

	// Pipe writers
	go func() {
		adapterPipeResult = b.adapterPipe.Run(pipeCtx)
		close(adapterPipeDone)
		errCh <- adapterPipeResult
	}()
	go func() {
		idePipeResult = b.idePipe.Run(pipeCtx)
		close(idePipeDone)
		errCh <- idePipeResult
	}()

	// IDE → Adapter reader
	go func() {
		ideReaderResult = b.forwardIDEToAdapter(ctx)
		close(ideReaderDone)
		errCh <- ideReaderResult
	}()

	// Adapter → IDE reader
	go func() {
		adapterReaderResult = b.forwardAdapterToIDE(ctx)
		close(adapterReaderDone)
		errCh <- adapterReaderResult
	}()

	// Wait for first error, context cancellation, or adapter exit
	var loopErr error
	select {
	case <-ctx.Done():
		b.log.V(1).Info("Context cancelled, shutting down")
	case loopErr = <-errCh:
		if loopErr != nil && !isExpectedShutdownErr(loopErr) {
			b.log.Error(loopErr, "Message forwarding error")
		}
	case <-b.adapter.Done():
		b.log.V(1).Info("Debug adapter exited")
	}

	// === Ordered graceful shutdown ===
	//
	// The goal is to let the IDE-bound pipe (idePipe) drain any queued
	// messages (e.g., a disconnect response, terminated event) before
	// tearing down the IDE transport. The shutdown proceeds in dependency order:
	//
	// 1. adapter transport closed → adapter reader unblocked
	// 2. adapter reader done → no more external idePipe.Send() calls
	// 3. shutdown messages enqueued into idePipe (via Send)
	// 4. idePipe input closed → graceful drain → all messages written
	// 5. IDE transport closed → IDE reader unblocked
	// 6. IDE reader done → no more adapterPipe.Send()
	// 7. adapterPipe input closed → drain → done

	// Step 1: Close adapter transport to unblock the adapter→IDE reader.
	b.adapter.Transport.Close()

	// Step 2: Wait for adapter reader to finish. After this, no goroutine
	// will call idePipe.Send().
	<-adapterReaderDone

	// Step 3: Enqueue any final shutdown messages (e.g., TerminatedEvent)
	// into idePipe so they are written in-order by the pipe's writer goroutine.
	terminated := b.terminatedEventSeen.Load()
	if !terminated {
		if loopErr != nil && !isExpectedShutdownErr(loopErr) {
			b.sendErrorToIDE(fmt.Sprintf("Debug session ended unexpectedly: %v", loopErr))
		} else {
			b.sendTerminatedToIDE()
		}
	}

	// Step 4: Close idePipe input. The UnboundedChan drains all buffered
	// messages (including shutdown messages just enqueued) to its output
	// channel, and Run() writes them to the IDE transport.
	b.idePipe.CloseInput()
	<-idePipeDone

	// Step 5: Close IDE transport to unblock the IDE→adapter reader.
	b.ideTransport.Close()

	// Step 6: Wait for IDE reader to finish.
	<-ideReaderDone

	// Step 7: Close adapterPipe input and wait for drain.
	b.adapterPipe.CloseInput()
	<-adapterPipeDone

	// Collect errors from all goroutines.
	var errs []error
	for _, goroutineErr := range []error{adapterReaderResult, ideReaderResult, adapterPipeResult, idePipeResult} {
		if goroutineErr != nil && !isExpectedShutdownErr(goroutineErr) {
			errs = append(errs, goroutineErr)
		}
	}

	if len(errs) > 0 {
		return errors.Join(errs...)
	}
	return nil
}

// forwardIDEToAdapter reads messages from the IDE, intercepts as needed,
// and enqueues them to the adapterPipe for ordered writing.
func (b *DapBridge) forwardIDEToAdapter(ctx context.Context) error {
	for {
		// Non-blocking cancellation check between reads.
		select {
		case <-ctx.Done():
			return ctx.Err()
		default:
		}

		msg, readErr := b.ideTransport.ReadMessage()
		if readErr != nil {
			return fmt.Errorf("failed to read from IDE: %w", readErr)
		}

		env := NewMessageEnvelope(msg)
		b.logEnvelopeMessage("IDE -> Adapter: received message from IDE", env)

		// Intercept and potentially modify the message
		modifiedMsg, forward := b.interceptUpstreamMessage(msg)
		if !forward {
			b.logEnvelopeMessage("IDE -> Adapter: message not forwarded (handled locally)", env)
			continue
		}

		// Re-wrap if intercept returned a different message (e.g., modified typed message).
		if modifiedMsg != msg {
			env = NewMessageEnvelope(modifiedMsg)
		}

		b.logEnvelopeMessage("IDE -> Adapter: enqueueing message for adapter", env)
		b.adapterPipe.Send(env)
	}
}

// forwardAdapterToIDE reads messages from the debug adapter, intercepts as needed,
// remaps response seq values, and enqueues them to the idePipe for ordered writing.
+func (b *DapBridge) forwardAdapterToIDE(ctx context.Context) error { + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + msg, readErr := b.adapter.Transport.ReadMessage() + if readErr != nil { + return fmt.Errorf("failed to read from adapter: %w", readErr) + } + + env := NewMessageEnvelope(msg) + b.logEnvelopeMessage("Adapter -> IDE: received message from adapter", env) + + // Intercept and potentially handle the message + modifiedMsg, forward, asyncResponse := b.interceptDownstreamMessage(ctx, msg) + + // If there's an async response (e.g., RunInTerminalResponse), enqueue it + // to the adapter pipe so it gets a proper sequence number. + if asyncResponse != nil { + asyncEnv := NewMessageEnvelope(asyncResponse) + b.logEnvelopeMessage("Adapter -> IDE: enqueueing async response for adapter", asyncEnv) + b.adapterPipe.Send(asyncEnv) + } + + if !forward { + b.logEnvelopeMessage("Adapter -> IDE: message not forwarded (handled locally)", env) + continue + } + + // Re-wrap if intercept returned a different message. + if modifiedMsg != msg { + env = NewMessageEnvelope(modifiedMsg) + } + + // For response messages, restore the original IDE sequence number in + // request_seq so the IDE can correlate the response with its request. + b.adapterPipe.RemapResponseSeq(env) + + b.logEnvelopeMessage("Adapter -> IDE: enqueueing message for IDE", env) + b.idePipe.Send(env) + } +} + +// interceptUpstreamMessage intercepts messages from the IDE to the adapter. +// Returns the (possibly modified) message and whether to forward it. 
+func (b *DapBridge) interceptUpstreamMessage(msg dap.Message) (dap.Message, bool) { + switch req := msg.(type) { + case *dap.InitializeRequest: + // Ensure supportsRunInTerminalRequest is true + req.Arguments.SupportsRunInTerminalRequest = true + b.log.V(1).Info("Set supportsRunInTerminalRequest=true on InitializeRequest") + return req, true + + default: + return msg, true + } +} + +// interceptDownstreamMessage intercepts messages from the adapter to the IDE. +// Returns the (possibly modified) message, whether to forward it, and an optional async response. +func (b *DapBridge) interceptDownstreamMessage(ctx context.Context, msg dap.Message) (dap.Message, bool, dap.Message) { + switch m := msg.(type) { + case *dap.TerminatedEvent: + b.terminatedEventSeen.Store(true) + return msg, true, nil + + case *dap.OutputEvent: + // Capture output for logging if not using runInTerminal + b.handleOutputEvent(m) + return msg, true, nil + + case *dap.RunInTerminalRequest: + // Handle runInTerminal locally, don't forward to IDE + response := b.handleRunInTerminalRequest(ctx, m) + return nil, false, response + + default: + return msg, true, nil + } +} + +// handleOutputEvent processes output events from the debug adapter. +func (b *DapBridge) handleOutputEvent(event *dap.OutputEvent) { + // Only capture output if runInTerminal wasn't used + // (if runInTerminal was used, we capture directly from the process) + if !b.runInTerminalUsed.Load() && b.config.OutputHandler != nil { + b.config.OutputHandler.HandleOutput(event.Body.Category, event.Body.Output) + } +} + +// handleRunInTerminalRequest handles the runInTerminal reverse request. +// Returns the response to send back to the debug adapter. +// The response's Seq field is set to 0 because the adapterPipe will assign +// the actual sequence number when the message is dequeued for writing. 
func (b *DapBridge) handleRunInTerminalRequest(ctx context.Context, req *dap.RunInTerminalRequest) *dap.RunInTerminalResponse {
	b.log.Info("Handling RunInTerminal request",
		"seq", req.Seq,
		"kind", req.Arguments.Kind,
		"title", req.Arguments.Title,
		"cwd", req.Arguments.Cwd,
		"args", req.Arguments.Args,
		"envCount", len(req.Arguments.Env))

	// Mark that runInTerminal was used
	b.runInTerminalUsed.Store(true)

	// Build the command
	if len(req.Arguments.Args) == 0 {
		b.log.Error(fmt.Errorf("runInTerminal request has no arguments"), "RunInTerminal failed",
			"requestSeq", req.Seq)
		// Success is deliberately left at its zero value (false), marking this
		// response as a failure per the DAP response contract.
		return &dap.RunInTerminalResponse{
			Response: dap.Response{
				ProtocolMessage: dap.ProtocolMessage{
					Type: "response",
				},
				RequestSeq: req.Seq,
				Command:    req.Command,
				Message:    "runInTerminal requires at least one argument",
			},
		}
	}

	cmd := exec.Command(req.Arguments.Args[0], req.Arguments.Args[1:]...)
	cmd.Dir = req.Arguments.Cwd
	cmd.Stdout = b.config.StdoutWriter
	cmd.Stderr = b.config.StderrWriter

	// Set environment from the request only (do not inherit current process environment)
	// NOTE(review): when the request carries no env vars, cmd.Env stays nil and the
	// child inherits this process's environment — confirm that fallback is intended,
	// since it contradicts the "request only" statement above.
	if len(req.Arguments.Env) > 0 {
		env := make([]string, 0, len(req.Arguments.Env))
		for k, v := range req.Arguments.Env {
			// Non-string env values are silently dropped.
			if strVal, ok := v.(string); ok {
				env = append(env, fmt.Sprintf("%s=%s", k, strVal))
			}
		}
		cmd.Env = env
	}

	// Launch without waiting; the debuggee's lifetime is not tied to this call.
	handle, startErr := b.executor.StartAndForget(cmd, process.CreationFlagsNone)

	response := &dap.RunInTerminalResponse{
		Response: dap.Response{
			ProtocolMessage: dap.ProtocolMessage{
				Type: "response",
			},
			RequestSeq: req.Seq,
			Command:    req.Command,
			Success:    startErr == nil,
		},
	}

	if startErr == nil {
		// Report the spawned process id so the adapter can attach/track it.
		response.Body.ProcessId = int(handle.Pid)
		b.log.Info("RunInTerminal succeeded",
			"requestSeq", req.Seq,
			"processId", handle.Pid)
	} else {
		response.Message = startErr.Error()
		b.log.Error(startErr, "RunInTerminal failed",
			"requestSeq", req.Seq)
	}

	return response
}

// sendErrorToIDE sends an OutputEvent with category "stderr" followed by a TerminatedEvent
// to the IDE. When the idePipe is available, messages are enqueued through it so that
// sequence numbering and write serialization are handled by the pipe's writer goroutine.
// When the idePipe is not yet created (e.g., adapter launch failure before the message loop),
// messages are written directly to the IDE transport with a fallback sequence counter.
func (b *DapBridge) sendErrorToIDE(message string) {
	outputEvent := newOutputEvent(0, "stderr", message+"\n")

	if b.idePipe != nil {
		b.idePipe.Send(NewMessageEnvelope(outputEvent))
		b.sendTerminatedToIDE()
		return
	}

	// No pipe and no transport yet: nothing can be delivered.
	if b.ideTransport == nil {
		return
	}

	// Direct-write fallback path: assign seq numbers from the fallback counter.
	outputEvent.Seq = int(b.fallbackIDESeqCounter.Add(1))
	if writeErr := b.ideTransport.WriteMessage(outputEvent); writeErr != nil {
		b.log.V(1).Info("Failed to send error OutputEvent to IDE", "error", writeErr)
		return
	}

	b.sendTerminatedToIDE()
}

// sendTerminatedToIDE sends a TerminatedEvent to the IDE so it knows the debug session has ended.
// When the idePipe is available, the event is enqueued through it; otherwise it is written
// directly to the IDE transport.
func (b *DapBridge) sendTerminatedToIDE() {
	terminatedEvent := newTerminatedEvent(0)

	if b.idePipe != nil {
		b.idePipe.Send(NewMessageEnvelope(terminatedEvent))
		return
	}

	if b.ideTransport == nil {
		return
	}

	terminatedEvent.Seq = int(b.fallbackIDESeqCounter.Add(1))
	if writeErr := b.ideTransport.WriteMessage(terminatedEvent); writeErr != nil {
		b.log.V(1).Info("Failed to send TerminatedEvent to IDE", "error", writeErr)
	}
}

// terminate marks the bridge as terminated.
// Idempotent: terminateOnce guarantees terminateCh is closed exactly once.
func (b *DapBridge) terminate() {
	b.terminateOnce.Do(func() {
		close(b.terminateCh)
	})
}

// logEnvelopeMessage logs a DAP message envelope at V(1) level, including raw JSON.
// Additional key-value pairs can be appended via extraKeysAndValues.
+func (b *DapBridge) logEnvelopeMessage(logMsg string, env *MessageEnvelope, extraKeysAndValues ...any) { + if !b.log.V(1).Enabled() { + return + } + keysAndValues := []any{"message", env.Describe()} + if raw, ok := env.Inner.(*RawMessage); ok { + keysAndValues = append(keysAndValues, "rawJSON", string(raw.Data)) + } else if jsonBytes, marshalErr := json.Marshal(env.Inner); marshalErr == nil { + keysAndValues = append(keysAndValues, "rawJSON", string(jsonBytes)) + } + keysAndValues = append(keysAndValues, extraKeysAndValues...) + b.log.V(1).Info(logMsg, keysAndValues...) +} diff --git a/internal/dap/bridge_handshake.go b/internal/dap/bridge_handshake.go new file mode 100644 index 00000000..686196ef --- /dev/null +++ b/internal/dap/bridge_handshake.go @@ -0,0 +1,208 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See LICENSE in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package dap + +import ( + "encoding/binary" + "encoding/json" + "errors" + "fmt" + "io" + "net" +) + +// NOTE: Handshake validation (token/session verification) is handled directly +// by BridgeManager.validateHandshake, called from BridgeManager.handleConnection. +// No separate validator interface is needed since there is only one validation strategy. + +// HandshakeRequest is sent by the IDE after connecting to the Unix socket. +// It contains authentication credentials, session identification, and debug adapter configuration. +type HandshakeRequest struct { + // Token is the authentication token that must match the IDE session token. + Token string `json:"token"` + + // SessionID identifies the debug session to connect to. + SessionID string `json:"session_id"` + + // RunID is the IDE run session identifier. 
+ // This is used to correlate the debug bridge with the executable's output writers + // so that debug adapter output can be captured to the correct log files. + RunID string `json:"run_id,omitempty"` + + // DebugAdapterConfig contains the configuration for launching the debug adapter. + // This is provided by the IDE during the handshake. + DebugAdapterConfig *DebugAdapterConfig `json:"debug_adapter_config,omitempty"` +} + +// HandshakeResponse is sent by the bridge after validating the handshake request. +type HandshakeResponse struct { + // Success indicates whether the handshake was successful. + Success bool `json:"success"` + + // Error contains the error message if Success is false. + Error string `json:"error,omitempty"` +} + +// ErrHandshakeFailed is returned when the handshake fails. +var ErrHandshakeFailed = errors.New("handshake failed") + +// maxHandshakeMessageSize is the maximum size of a handshake message (64KB). +// This prevents denial-of-service attacks via large messages. +const maxHandshakeMessageSize = 64 * 1024 + +// HandshakeReader reads handshake messages from a connection. +// Messages are length-prefixed: 4-byte big-endian length followed by JSON payload. +type HandshakeReader struct { + conn net.Conn +} + +// NewHandshakeReader creates a new HandshakeReader for the given connection. +func NewHandshakeReader(conn net.Conn) *HandshakeReader { + return &HandshakeReader{conn: conn} +} + +// ReadRequest reads a HandshakeRequest from the connection. +func (r *HandshakeReader) ReadRequest() (*HandshakeRequest, error) { + data, readErr := r.readMessage() + if readErr != nil { + return nil, fmt.Errorf("failed to read handshake request: %w", readErr) + } + + var req HandshakeRequest + if unmarshalErr := json.Unmarshal(data, &req); unmarshalErr != nil { + return nil, fmt.Errorf("failed to unmarshal handshake request: %w", unmarshalErr) + } + + return &req, nil +} + +// readMessage reads a length-prefixed message from the connection. 
+func (r *HandshakeReader) readMessage() ([]byte, error) { + // Read 4-byte length prefix (big-endian) + var lengthBuf [4]byte + if _, readErr := io.ReadFull(r.conn, lengthBuf[:]); readErr != nil { + return nil, fmt.Errorf("failed to read message length: %w", readErr) + } + + length := binary.BigEndian.Uint32(lengthBuf[:]) + if length == 0 { + return nil, errors.New("message length is zero") + } + if length > maxHandshakeMessageSize { + return nil, fmt.Errorf("message length %d exceeds maximum %d", length, maxHandshakeMessageSize) + } + + // Read the message body + data := make([]byte, length) + if _, readErr := io.ReadFull(r.conn, data); readErr != nil { + return nil, fmt.Errorf("failed to read message body: %w", readErr) + } + + return data, nil +} + +// HandshakeWriter writes handshake messages to a connection. +// Messages are length-prefixed: 4-byte big-endian length followed by JSON payload. +type HandshakeWriter struct { + conn net.Conn +} + +// NewHandshakeWriter creates a new HandshakeWriter for the given connection. +func NewHandshakeWriter(conn net.Conn) *HandshakeWriter { + return &HandshakeWriter{conn: conn} +} + +// WriteResponse writes a HandshakeResponse to the connection. +func (w *HandshakeWriter) WriteResponse(resp *HandshakeResponse) error { + data, marshalErr := json.Marshal(resp) + if marshalErr != nil { + return fmt.Errorf("failed to marshal handshake response: %w", marshalErr) + } + + return w.writeMessage(data) +} + +// WriteRequest writes a HandshakeRequest to the connection. +// This is used by the client side (IDE) to initiate the handshake. +func (w *HandshakeWriter) WriteRequest(req *HandshakeRequest) error { + data, marshalErr := json.Marshal(req) + if marshalErr != nil { + return fmt.Errorf("failed to marshal handshake request: %w", marshalErr) + } + + return w.writeMessage(data) +} + +// writeMessage writes a length-prefixed message to the connection. 
+func (w *HandshakeWriter) writeMessage(data []byte) error { + if len(data) > maxHandshakeMessageSize { + return fmt.Errorf("message length %d exceeds maximum %d", len(data), maxHandshakeMessageSize) + } + + // Write 4-byte length prefix (big-endian) + var lengthBuf [4]byte + binary.BigEndian.PutUint32(lengthBuf[:], uint32(len(data))) + + if _, writeErr := w.conn.Write(lengthBuf[:]); writeErr != nil { + return fmt.Errorf("failed to write message length: %w", writeErr) + } + + // Write the message body + if _, writeErr := w.conn.Write(data); writeErr != nil { + return fmt.Errorf("failed to write message body: %w", writeErr) + } + + return nil +} + +// ReadResponse reads a HandshakeResponse from the connection. +// This is used by the client side (IDE) to receive the handshake result. +func (r *HandshakeReader) ReadResponse() (*HandshakeResponse, error) { + data, readErr := r.readMessage() + if readErr != nil { + return nil, fmt.Errorf("failed to read handshake response: %w", readErr) + } + + var resp HandshakeResponse + if unmarshalErr := json.Unmarshal(data, &resp); unmarshalErr != nil { + return nil, fmt.Errorf("failed to unmarshal handshake response: %w", unmarshalErr) + } + + return &resp, nil +} + +// performClientHandshake sends a handshake request and waits for the response. +// This is a convenience function for the client side (IDE). +// Returns nil on success, or an error on failure. 
+func performClientHandshake(conn net.Conn, token, sessionID, runID string) error { + writer := NewHandshakeWriter(conn) + reader := NewHandshakeReader(conn) + + // Send the handshake request + req := &HandshakeRequest{ + Token: token, + SessionID: sessionID, + RunID: runID, + } + if writeErr := writer.WriteRequest(req); writeErr != nil { + return fmt.Errorf("failed to send handshake request: %w", writeErr) + } + + // Read the response + resp, readErr := reader.ReadResponse() + if readErr != nil { + return fmt.Errorf("failed to read handshake response: %w", readErr) + } + + if !resp.Success { + if resp.Error != "" { + return fmt.Errorf("%w: %s", ErrHandshakeFailed, resp.Error) + } + return ErrHandshakeFailed + } + + return nil +} diff --git a/internal/dap/bridge_handshake_test.go b/internal/dap/bridge_handshake_test.go new file mode 100644 index 00000000..8b54ba76 --- /dev/null +++ b/internal/dap/bridge_handshake_test.go @@ -0,0 +1,140 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See LICENSE in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package dap + +import ( + "net" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// setupHandshakeConn creates a Unix socket pair for handshake testing. 
+func setupHandshakeConn(t *testing.T, suffix string) (net.Conn, net.Conn) { + t.Helper() + + socketPath := uniqueSocketPath(t, suffix) + + listener, listenErr := net.Listen("unix", socketPath) + require.NoError(t, listenErr) + defer listener.Close() + + var wg sync.WaitGroup + var serverConn net.Conn + var acceptErr error + + wg.Add(1) + go func() { + defer wg.Done() + serverConn, acceptErr = listener.Accept() + }() + + clientConn, dialErr := net.Dial("unix", socketPath) + require.NoError(t, dialErr) + + wg.Wait() + require.NoError(t, acceptErr) + + t.Cleanup(func() { + clientConn.Close() + serverConn.Close() + }) + + return clientConn, serverConn +} + +func TestHandshakeWriteAndReadRequest(t *testing.T) { + t.Parallel() + + clientConn, serverConn := setupHandshakeConn(t, "hs-rq") + + clientWriter := NewHandshakeWriter(clientConn) + serverReader := NewHandshakeReader(serverConn) + + req := &HandshakeRequest{ + Token: "test-token-123", + SessionID: "session-456", + } + + writeErr := clientWriter.WriteRequest(req) + require.NoError(t, writeErr) + + receivedReq, readErr := serverReader.ReadRequest() + require.NoError(t, readErr) + + assert.Equal(t, req.Token, receivedReq.Token) + assert.Equal(t, req.SessionID, receivedReq.SessionID) +} + +func TestHandshakeWriteAndReadResponse(t *testing.T) { + t.Parallel() + + clientConn, serverConn := setupHandshakeConn(t, "hs-rs") + + serverWriter := NewHandshakeWriter(serverConn) + clientReader := NewHandshakeReader(clientConn) + + resp := &HandshakeResponse{ + Success: true, + } + + writeErr := serverWriter.WriteResponse(resp) + require.NoError(t, writeErr) + + receivedResp, readErr := clientReader.ReadResponse() + require.NoError(t, readErr) + + assert.True(t, receivedResp.Success) + assert.Empty(t, receivedResp.Error) +} + +func TestHandshakeWriteAndReadErrorResponse(t *testing.T) { + t.Parallel() + + clientConn, serverConn := setupHandshakeConn(t, "hs-er") + + serverWriter := NewHandshakeWriter(serverConn) + clientReader := 
NewHandshakeReader(clientConn) + + resp := &HandshakeResponse{ + Success: false, + Error: "authentication failed", + } + + writeErr := serverWriter.WriteResponse(resp) + require.NoError(t, writeErr) + + receivedResp, readErr := clientReader.ReadResponse() + require.NoError(t, readErr) + + assert.False(t, receivedResp.Success) + assert.Equal(t, "authentication failed", receivedResp.Error) +} + +func TestHandshakeRejectsOversizedMessage(t *testing.T) { + t.Parallel() + + clientConn, _ := setupHandshakeConn(t, "hs-sz") + + writer := NewHandshakeWriter(clientConn) + + // Create a request with a very long token + largeToken := make([]byte, maxHandshakeMessageSize+1) + for i := range largeToken { + largeToken[i] = 'a' + } + + req := &HandshakeRequest{ + Token: string(largeToken), + SessionID: "session", + } + + // Writing should fail due to size limit + writeErr := writer.WriteRequest(req) + assert.Error(t, writeErr) +} diff --git a/internal/dap/bridge_integration_test.go b/internal/dap/bridge_integration_test.go new file mode 100644 index 00000000..057addad --- /dev/null +++ b/internal/dap/bridge_integration_test.go @@ -0,0 +1,832 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See LICENSE in the project root for license information. 
+ *--------------------------------------------------------------------------------------------*/ + +package dap + +import ( + "bytes" + "context" + "fmt" + "io" + "net" + "os" + "path/filepath" + "runtime" + "testing" + "time" + + "github.com/go-logr/logr" + "github.com/google/go-dap" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "k8s.io/apimachinery/pkg/util/wait" + + apiv1 "github.com/microsoft/dcp/api/v1" + "github.com/microsoft/dcp/internal/testutil" + "github.com/microsoft/dcp/pkg/osutil" + "github.com/microsoft/dcp/pkg/process" + pkgtestutil "github.com/microsoft/dcp/pkg/testutil" +) + +// ===== Integration Tests ===== + +func TestBridge_RunWithConnection(t *testing.T) { + t.Parallel() + + // Test that RunWithConnection works correctly with an already-connected net.Conn + // This simulates the flow where BridgeSocketManager has already performed handshake + + // We'll use a pipe to simulate the connection + serverConn, clientConn := net.Pipe() + defer serverConn.Close() + defer clientConn.Close() + + sessionID := "test-session" + + config := BridgeConfig{ + SessionID: sessionID, + AdapterConfig: &DebugAdapterConfig{ + Args: []string{"echo", "hello"}, // Simple command that exits immediately + Mode: DebugAdapterModeStdio, + }, + Logger: logr.Discard(), + } + + bridge := NewDapBridge(config) + + ctx, cancel := pkgtestutil.GetTestContext(t, 5*time.Second) + defer cancel() + + // Drain clientConn so the bridge can write error messages to the IDE transport + // without blocking on the pipe. 
+ go func() { + _, _ = io.Copy(io.Discard, clientConn) + }() + + // Run the bridge in a goroutine - it will fail to launch the adapter since we're using a fake command + // but this tests the basic flow + go func() { + _ = bridge.RunWithConnection(ctx, serverConn) + }() + + // Wait for the bridge to terminate (it will fail to launch the fake adapter and exit) + select { + case <-bridge.terminateCh: + // Expected - bridge terminated after failing to launch adapter + case <-time.After(5 * time.Second): + cancel() + t.Fatal("bridge did not terminate in time") + } +} + +func TestBridgeManager_HandshakeValidation(t *testing.T) { + t.Parallel() + + // Test that BridgeManager correctly validates handshakes + + socketDir := shortTempDir(t) + manager := NewBridgeManager(BridgeManagerConfig{ + SocketDir: socketDir, + HandshakeTimeout: 2 * time.Second, + }, logr.Discard()) + + // Register a session with a token + session, regErr := manager.RegisterSession("valid-session", "test-token") + require.NoError(t, regErr) + require.NotNil(t, session) + + ctx, cancel := pkgtestutil.GetTestContext(t, 5*time.Second) + defer cancel() + + // Start bridge manager in background + go func() { + _ = manager.Start(ctx) + }() + + // Wait for it to be ready + select { + case <-manager.Ready(): + // Good + case <-time.After(2 * time.Second): + t.Fatal("bridge manager failed to become ready") + } + + socketPath, socketPathErr := manager.SocketPath(ctx) + require.NoError(t, socketPathErr) + + // Connect with wrong token - should fail + ideConn, dialErr := net.Dial("unix", socketPath) + require.NoError(t, dialErr) + defer ideConn.Close() + + handshakeErr := performClientHandshake(ideConn, "wrong-token", "valid-session", "") + require.Error(t, handshakeErr, "handshake should fail with wrong token") + assert.ErrorIs(t, handshakeErr, ErrHandshakeFailed) + + cancel() +} + +func TestBridgeManager_SessionNotFound(t *testing.T) { + t.Parallel() + + // Test handshake failure when session doesn't exist + + 
socketDir := shortTempDir(t) + manager := NewBridgeManager(BridgeManagerConfig{ + SocketDir: socketDir, + HandshakeTimeout: 2 * time.Second, + }, logr.Discard()) + + ctx, cancel := pkgtestutil.GetTestContext(t, 5*time.Second) + defer cancel() + + // Start bridge manager in background + go func() { + _ = manager.Start(ctx) + }() + + // Wait for it to be ready + select { + case <-manager.Ready(): + // Good + case <-time.After(2 * time.Second): + t.Fatal("bridge manager failed to become ready") + } + + socketPath, socketPathErr := manager.SocketPath(ctx) + require.NoError(t, socketPathErr) + + // Connect with non-existent session - should fail + ideConn, dialErr := net.Dial("unix", socketPath) + require.NoError(t, dialErr) + defer ideConn.Close() + + handshakeErr := performClientHandshake(ideConn, "any-token", "nonexistent-session", "") + require.Error(t, handshakeErr, "handshake should fail with unknown session") + assert.ErrorIs(t, handshakeErr, ErrHandshakeFailed) + + cancel() +} + +func TestBridgeManager_HandshakeTimeout(t *testing.T) { + t.Parallel() + + socketDir := shortTempDir(t) + manager := NewBridgeManager(BridgeManagerConfig{ + SocketDir: socketDir, + HandshakeTimeout: 200 * time.Millisecond, // Short timeout + }, logr.Discard()) + _, _ = manager.RegisterSession("timeout-session", "test-token") + + ctx, cancel := pkgtestutil.GetTestContext(t, 5*time.Second) + defer cancel() + + // Start bridge manager in background + go func() { + _ = manager.Start(ctx) + }() + + // Wait for it to be ready + select { + case <-manager.Ready(): + // Good + case <-time.After(2 * time.Second): + t.Fatal("bridge manager failed to become ready") + } + + socketPath, socketPathErr := manager.SocketPath(ctx) + require.NoError(t, socketPathErr) + + // Connect but don't send handshake - should timeout and close connection + ideConn, dialErr := net.Dial("unix", socketPath) + require.NoError(t, dialErr) + defer ideConn.Close() + + // Poll until the server closes our connection due to 
handshake timeout + pollErr := wait.PollUntilContextCancel(ctx, 100*time.Millisecond, true, func(_ context.Context) (bool, error) { + _ = ideConn.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) + buf := make([]byte, 1) + _, readErr := ideConn.Read(buf) + // Connection closed when read returns a non-timeout error (EOF, closed, etc.) + if readErr != nil { + if netErr, ok := readErr.(net.Error); ok && netErr.Timeout() { + return false, nil // Still open, keep polling + } + return true, nil // Non-timeout error means connection was closed + } + return false, nil + }) + require.NoError(t, pollErr, "connection should be closed by server after handshake timeout") + + cancel() +} + +func TestBridge_OutputEventCapture(t *testing.T) { + t.Parallel() + + // This test verifies that output events are captured when runInTerminal is not used. + // We use a simpler approach: directly test the handleOutputEvent function behavior. + + stdoutBuf := &bytes.Buffer{} + stderrBuf := &bytes.Buffer{} + + config := BridgeConfig{ + SessionID: "session", + StdoutWriter: stdoutBuf, + StderrWriter: stderrBuf, + OutputHandler: &testOutputHandler{ + stdout: stdoutBuf, + stderr: stderrBuf, + }, + } + + bridge := NewDapBridge(config) + + // Initially runInTerminal not used + assert.False(t, bridge.runInTerminalUsed.Load()) + + // Simulate handling an output event + outputEvent := &dap.OutputEvent{ + Event: dap.Event{ + ProtocolMessage: dap.ProtocolMessage{ + Seq: 1, + Type: "event", + }, + Event: "output", + }, + Body: dap.OutputEventBody{ + Category: "stdout", + Output: "Hello from debug adapter\n", + }, + } + + bridge.handleOutputEvent(outputEvent) + + // Output should have been captured + assert.Contains(t, stdoutBuf.String(), "Hello from debug adapter") +} + +// testOutputHandler captures output for testing. 
+type testOutputHandler struct { + stdout io.Writer + stderr io.Writer +} + +func (h *testOutputHandler) HandleOutput(category string, output string) { + if category == "stdout" && h.stdout != nil { + _, _ = h.stdout.Write([]byte(output)) + } else if category == "stderr" && h.stderr != nil { + _, _ = h.stderr.Write([]byte(output)) + } +} + +func TestBridge_InitializeInterception(t *testing.T) { + t.Parallel() + + // Test that the bridge intercepts initialize requests to force supportsRunInTerminalRequest=true + + config := BridgeConfig{ + SessionID: "session", + } + + bridge := NewDapBridge(config) + + // Create an initialize request with supportsRunInTerminalRequest=false + initReq := &dap.InitializeRequest{ + Request: dap.Request{ + ProtocolMessage: dap.ProtocolMessage{ + Seq: 1, + Type: "request", + }, + Command: "initialize", + }, + Arguments: dap.InitializeRequestArguments{ + ClientID: "test", + SupportsRunInTerminalRequest: false, // IDE says it doesn't support it + }, + } + + // Apply upstream interception + modified, forward := bridge.interceptUpstreamMessage(initReq) + + assert.True(t, forward, "initialize should be forwarded") + modifiedInit, ok := modified.(*dap.InitializeRequest) + require.True(t, ok) + assert.True(t, modifiedInit.Arguments.SupportsRunInTerminalRequest, + "supportsRunInTerminalRequest should be forced to true") +} + +func TestBridge_RunInTerminalInterception(t *testing.T) { + t.Parallel() + + // Test that runInTerminal requests are intercepted and not forwarded + + config := BridgeConfig{ + SessionID: "session", + } + + bridge := NewDapBridge(config) + + // Create a runInTerminal request + ritReq := &dap.RunInTerminalRequest{ + Request: dap.Request{ + ProtocolMessage: dap.ProtocolMessage{ + Seq: 1, + Type: "request", + }, + Command: "runInTerminal", + }, + Arguments: dap.RunInTerminalRequestArguments{ + Kind: "integrated", + Title: "Debug", + Cwd: "/tmp", + Args: []string{"echo", "hello"}, + }, + } + + ctx, cancel := 
pkgtestutil.GetTestContext(t, 5*time.Second) + defer cancel() + + // Apply downstream interception + _, forward, asyncResponse := bridge.interceptDownstreamMessage(ctx, ritReq) + + assert.False(t, forward, "runInTerminal should NOT be forwarded to IDE") + assert.NotNil(t, asyncResponse, "should return an async response") + + // The response should be a RunInTerminalResponse + ritResp, ok := asyncResponse.(*dap.RunInTerminalResponse) + require.True(t, ok, "async response should be RunInTerminalResponse") + assert.Equal(t, "runInTerminal", ritResp.Command) + assert.Equal(t, 1, ritResp.RequestSeq) + + // runInTerminalUsed should now be true + assert.True(t, bridge.runInTerminalUsed.Load()) +} + +func TestBridge_MessageForwarding(t *testing.T) { + t.Parallel() + + // Test that non-intercepted messages are forwarded unchanged + + config := BridgeConfig{ + SessionID: "session", + } + + bridge := NewDapBridge(config) + + // Test upstream message (setBreakpoints - should pass through) + setBreakpointsReq := &dap.SetBreakpointsRequest{ + Request: dap.Request{ + ProtocolMessage: dap.ProtocolMessage{ + Seq: 1, + Type: "request", + }, + Command: "setBreakpoints", + }, + } + + modified, forward := bridge.interceptUpstreamMessage(setBreakpointsReq) + assert.True(t, forward, "setBreakpoints should be forwarded") + assert.Equal(t, setBreakpointsReq, modified, "message should not be modified") + + // Test downstream message (stopped event - should pass through) + stoppedEvent := &dap.StoppedEvent{ + Event: dap.Event{ + ProtocolMessage: dap.ProtocolMessage{ + Seq: 2, + Type: "event", + }, + Event: "stopped", + }, + Body: dap.StoppedEventBody{ + Reason: "breakpoint", + ThreadId: 1, + }, + } + + ctx, cancel := pkgtestutil.GetTestContext(t, 5*time.Second) + defer cancel() + + modifiedDown, forwardDown, asyncResp := bridge.interceptDownstreamMessage(ctx, stoppedEvent) + assert.True(t, forwardDown, "stopped event should be forwarded") + assert.Equal(t, stoppedEvent, modifiedDown, 
"message should not be modified") + assert.Nil(t, asyncResp, "no async response expected") +} + +func TestBridge_OutputEventForwarding(t *testing.T) { + t.Parallel() + + // Test that output events are forwarded even when captured + + stdoutBuf := &bytes.Buffer{} + + config := BridgeConfig{ + SessionID: "session", + OutputHandler: &testOutputHandler{ + stdout: stdoutBuf, + }, + } + + bridge := NewDapBridge(config) + + outputEvent := &dap.OutputEvent{ + Event: dap.Event{ + ProtocolMessage: dap.ProtocolMessage{ + Seq: 1, + Type: "event", + }, + Event: "output", + }, + Body: dap.OutputEventBody{ + Category: "stdout", + Output: "test output", + }, + } + + ctx, cancel := pkgtestutil.GetTestContext(t, 5*time.Second) + defer cancel() + + modified, forward, asyncResp := bridge.interceptDownstreamMessage(ctx, outputEvent) + + // Output event should still be forwarded to IDE + assert.True(t, forward, "output event should be forwarded") + assert.Equal(t, outputEvent, modified) + assert.Nil(t, asyncResp) + + // And should have been captured + assert.Contains(t, stdoutBuf.String(), "test output") +} + +func TestBridge_OutputEventNotCapturedWhenRunInTerminalUsed(t *testing.T) { + t.Parallel() + + // Test that output events are NOT captured when runInTerminal was used + + stdoutBuf := &bytes.Buffer{} + + config := BridgeConfig{ + SessionID: "session", + OutputHandler: &testOutputHandler{ + stdout: stdoutBuf, + }, + } + + bridge := NewDapBridge(config) + + // Simulate runInTerminal being used + bridge.runInTerminalUsed.Store(true) + + outputEvent := &dap.OutputEvent{ + Event: dap.Event{ + ProtocolMessage: dap.ProtocolMessage{ + Seq: 1, + Type: "event", + }, + Event: "output", + }, + Body: dap.OutputEventBody{ + Category: "stdout", + Output: "should not be captured", + }, + } + + ctx, cancel := pkgtestutil.GetTestContext(t, 5*time.Second) + defer cancel() + + _, forward, _ := bridge.interceptDownstreamMessage(ctx, outputEvent) + + // Output event should still be forwarded + 
assert.True(t, forward) + + // But should NOT have been captured (buffer should be empty) + assert.Empty(t, stdoutBuf.String(), "output should not be captured when runInTerminal was used") +} + +func TestBridge_TerminatedEventTracking(t *testing.T) { + t.Parallel() + + // Test that interceptDownstreamMessage tracks TerminatedEvent + + config := BridgeConfig{ + SessionID: "session", + } + + bridge := NewDapBridge(config) + + // Initially terminatedEventSeen should be false + assert.False(t, bridge.terminatedEventSeen.Load()) + + terminatedEvent := &dap.TerminatedEvent{ + Event: dap.Event{ + ProtocolMessage: dap.ProtocolMessage{ + Seq: 1, + Type: "event", + }, + Event: "terminated", + }, + } + + ctx, cancel := pkgtestutil.GetTestContext(t, 5*time.Second) + defer cancel() + + modified, forward, asyncResp := bridge.interceptDownstreamMessage(ctx, terminatedEvent) + + assert.True(t, forward, "terminated event should be forwarded to IDE") + assert.Equal(t, terminatedEvent, modified) + assert.Nil(t, asyncResp) + + // terminatedEventSeen should now be true + assert.True(t, bridge.terminatedEventSeen.Load(), "bridge should track that TerminatedEvent was seen") +} + +func TestBridge_SendErrorToIDE(t *testing.T) { + t.Parallel() + + // Test that sendErrorToIDE sends an OutputEvent followed by a TerminatedEvent + // through the IDE transport + + serverConn, clientConn := net.Pipe() + defer serverConn.Close() + defer clientConn.Close() + + config := BridgeConfig{ + SessionID: "session", + Logger: logr.Discard(), + } + + bridge := NewDapBridge(config) + + ctx, cancel := pkgtestutil.GetTestContext(t, 5*time.Second) + defer cancel() + + bridge.ideTransport = NewUnixTransportWithContext(ctx, serverConn) + + // Read messages from the client side in a goroutine + clientTransport := NewUnixTransportWithContext(ctx, clientConn) + msgCh := make(chan dap.Message, 2) + go func() { + for i := 0; i < 2; i++ { + msg, readErr := clientTransport.ReadMessage() + if readErr != nil { + return + } 
+ msgCh <- msg + } + }() + + bridge.sendErrorToIDE("adapter crashed unexpectedly") + + // Should receive OutputEvent first + msg1 := <-msgCh + outputEvent, ok := msg1.(*dap.OutputEvent) + require.True(t, ok, "first message should be OutputEvent, got %T", msg1) + assert.Equal(t, "stderr", outputEvent.Body.Category) + assert.Contains(t, outputEvent.Body.Output, "adapter crashed unexpectedly") + + // Then TerminatedEvent + msg2 := <-msgCh + _, ok = msg2.(*dap.TerminatedEvent) + require.True(t, ok, "second message should be TerminatedEvent, got %T", msg2) +} + +func TestBridge_SendTerminatedToIDE(t *testing.T) { + t.Parallel() + + // Test that sendTerminatedToIDE sends only a TerminatedEvent + + serverConn, clientConn := net.Pipe() + defer serverConn.Close() + defer clientConn.Close() + + config := BridgeConfig{ + SessionID: "session", + Logger: logr.Discard(), + } + + bridge := NewDapBridge(config) + + ctx, cancel := pkgtestutil.GetTestContext(t, 5*time.Second) + defer cancel() + + bridge.ideTransport = NewUnixTransportWithContext(ctx, serverConn) + + // Read messages from the client side + clientTransport := NewUnixTransportWithContext(ctx, clientConn) + msgCh := make(chan dap.Message, 1) + go func() { + msg, readErr := clientTransport.ReadMessage() + if readErr != nil { + return + } + msgCh <- msg + }() + + bridge.sendTerminatedToIDE() + + msg := <-msgCh + _, ok := msg.(*dap.TerminatedEvent) + require.True(t, ok, "message should be TerminatedEvent, got %T", msg) +} + +func TestBridge_SendErrorToIDE_NilTransport(t *testing.T) { + t.Parallel() + + // Test that sendErrorToIDE is a no-op when ideTransport is nil (no panic) + + config := BridgeConfig{ + SessionID: "session", + Logger: logr.Discard(), + } + + bridge := NewDapBridge(config) + + // Should not panic + bridge.sendErrorToIDE("some error") + bridge.sendTerminatedToIDE() +} + +// performHandshakeWithAdapterConfig sends a full handshake request including +// debug adapter configuration, and reads the response. 
+// This is needed because performClientHandshake does not include adapter config, +// making it insufficient for end-to-end tests through BridgeSocketManager. +func performHandshakeWithAdapterConfig( + conn net.Conn, + token, sessionID, runID string, + adapterConfig *DebugAdapterConfig, +) error { + writer := NewHandshakeWriter(conn) + reader := NewHandshakeReader(conn) + + req := &HandshakeRequest{ + Token: token, + SessionID: sessionID, + RunID: runID, + DebugAdapterConfig: adapterConfig, + } + if writeErr := writer.WriteRequest(req); writeErr != nil { + return fmt.Errorf("failed to send handshake request: %w", writeErr) + } + + resp, readErr := reader.ReadResponse() + if readErr != nil { + return fmt.Errorf("failed to read handshake response: %w", readErr) + } + + if !resp.Success { + if resp.Error != "" { + return fmt.Errorf("%w: %s", ErrHandshakeFailed, resp.Error) + } + return ErrHandshakeFailed + } + + return nil +} + +// resolveDebuggeeSourcePath returns the absolute path to test/debuggee/debuggee.go. +func resolveDebuggeeSourcePath(t *testing.T) string { + t.Helper() + rootDir, findErr := osutil.FindRootFor(osutil.FileTarget, "test", "debuggee", "debuggee.go") + require.NoError(t, findErr, "could not find repo root containing test/debuggee/debuggee.go") + return filepath.Join(rootDir, "test", "debuggee", "debuggee.go") +} + +func TestBridge_DelveEndToEnd(t *testing.T) { + t.Parallel() + + // Locate the debuggee binary (built by 'make test-prereqs' with debug symbols). + toolDir, toolDirErr := testutil.GetTestToolDir("debuggee") + if toolDirErr != nil { + t.Skip("debuggee binary not found (run 'make test-prereqs' first):", toolDirErr) + } + debuggeeName := "debuggee" + if runtime.GOOS == "windows" { + debuggeeName += ".exe" + } + debuggeeBinary := filepath.Join(toolDir, debuggeeName) + + // Resolve the source file path for setting breakpoints. 
+ debuggeeSource := resolveDebuggeeSourcePath(t) + breakpointLine := 18 // result := compute(10) + + ctx, cancel := pkgtestutil.GetTestContext(t, 30*time.Second) + defer cancel() + + log := logr.Discard() + executor := process.NewOSExecutor(log) + defer executor.Dispose() + + // Set up bridge manager and register a session. + socketDir := shortTempDir(t) + manager := NewBridgeManager(BridgeManagerConfig{ + SocketDir: socketDir, + Executor: executor, + HandshakeTimeout: 5 * time.Second, + }, log) + + token := "test-delve-token" + sessionID := "delve-e2e-session" + session, regErr := manager.RegisterSession(sessionID, token) + require.NoError(t, regErr) + require.NotNil(t, session) + + // Start bridge manager in background. + go func() { + _ = manager.Start(ctx) + }() + + select { + case <-manager.Ready(): + case <-time.After(5 * time.Second): + t.Fatal("bridge manager failed to become ready") + } + + socketPath, socketPathErr := manager.SocketPath(ctx) + require.NoError(t, socketPathErr) + require.NotEmpty(t, socketPath) + + // Connect to the Unix socket as the IDE. + ideConn, dialErr := net.Dial("unix", socketPath) + require.NoError(t, dialErr) + defer ideConn.Close() + + // Perform handshake with dlv dap adapter config (tcp-callback: bridge listens, dlv connects). + // The adapter process does not inherit the current process environment, so we must + // explicitly pass environment variables needed by the Go toolchain. + adapterEnv := envVarsFromOS("PATH", "HOME", "GOPATH", "GOROOT", "GOMODCACHE") + handshakeErr := performHandshakeWithAdapterConfig(ideConn, token, sessionID, "delve-run-id", &DebugAdapterConfig{ + Args: []string{ + "go", "tool", "github.com/go-delve/delve/cmd/dlv", + "dap", "--client-addr=127.0.0.1:{{port}}", + }, + Mode: DebugAdapterModeTCPCallback, + Env: adapterEnv, + }) + require.NoError(t, handshakeErr, "handshake with adapter config should succeed") + + // Create the DAP test client over the connected Unix socket. 
+ ideTransport := NewUnixTransportWithContext(ctx, ideConn) + client := NewTestClient(ctx, ideTransport) + defer client.Close() + + // === DAP Protocol Sequence === + // dlv sends the 'initialized' event after receiving the 'launch' request, + // so the sequence is: initialize → launch → initialized → setBreakpoints → configurationDone. + + // 1. Initialize + initResp, initErr := client.Initialize(ctx) + require.NoError(t, initErr, "initialize should succeed") + require.NotNil(t, initResp) + assert.True(t, initResp.Body.SupportsConfigurationDoneRequest, + "dlv should support configurationDone") + + // 2. Launch the debuggee binary (exec mode — dlv runs the pre-built binary directly). + launchErr := client.Launch(ctx, debuggeeBinary, false) + require.NoError(t, launchErr, "launch should succeed") + + // 3. Wait for the 'initialized' event from dlv (sent after launch). + _, initializedErr := client.WaitForEvent("initialized", 10*time.Second) + require.NoError(t, initializedErr, "should receive initialized event from dlv") + + // 4. Set breakpoints on the debuggee source. + bpResp, bpErr := client.SetBreakpoints(ctx, debuggeeSource, []int{breakpointLine}) + require.NoError(t, bpErr, "setBreakpoints should succeed") + require.Len(t, bpResp.Body.Breakpoints, 1) + assert.True(t, bpResp.Body.Breakpoints[0].Verified, + "breakpoint at line %d should be verified", breakpointLine) + + // 5. Signal configuration is complete — program begins executing. + configDoneErr := client.ConfigurationDone(ctx) + require.NoError(t, configDoneErr, "configurationDone should succeed") + + // 6. Wait for the program to hit the breakpoint. + stoppedEvent, stoppedErr := client.WaitForStoppedEvent(10 * time.Second) + require.NoError(t, stoppedErr, "should receive stopped event at breakpoint") + assert.Equal(t, "breakpoint", stoppedEvent.Body.Reason) + assert.Greater(t, stoppedEvent.Body.ThreadId, 0, "thread ID should be positive") + + // 7. Continue execution — program runs to completion. 
+ continueErr := client.Continue(ctx, stoppedEvent.Body.ThreadId) + require.NoError(t, continueErr, "continue should succeed") + + // 8. Wait for the program to terminate. + terminatedErr := client.WaitForTerminatedEvent(10 * time.Second) + require.NoError(t, terminatedErr, "should receive terminated event") + + // 9. Disconnect from the debug adapter. + disconnectErr := client.Disconnect(ctx, true) + require.NoError(t, disconnectErr, "disconnect should succeed") +} + +// envVarsFromOS returns apiv1.EnvVar entries for the given environment variable names, +// including only those that are set in the current process environment. +func envVarsFromOS(names ...string) []apiv1.EnvVar { + var envVars []apiv1.EnvVar + for _, name := range names { + if val, ok := os.LookupEnv(name); ok { + envVars = append(envVars, apiv1.EnvVar{Name: name, Value: val}) + } + } + return envVars +} diff --git a/internal/dap/bridge_manager.go b/internal/dap/bridge_manager.go new file mode 100644 index 00000000..26c1ebe4 --- /dev/null +++ b/internal/dap/bridge_manager.go @@ -0,0 +1,530 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See LICENSE in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package dap + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "sync" + "time" + + "github.com/go-logr/logr" + "github.com/microsoft/dcp/internal/networking" + "github.com/microsoft/dcp/pkg/process" +) + +const ( + // DefaultSocketNamePrefix is the default prefix for the DAP bridge socket name. + // A random suffix is appended to support multiple DCP instances. + DefaultSocketNamePrefix = "dcp-dap-" + + // DefaultHandshakeTimeout is the default timeout for reading the handshake. 
+ DefaultHandshakeTimeout = 30 * time.Second +) + +// BridgeSessionState represents the current state of a bridge session. +type BridgeSessionState int + +const ( + // BridgeSessionStateCreated indicates the session has been registered but bridge not started. + BridgeSessionStateCreated BridgeSessionState = iota + + // BridgeSessionStateConnected indicates the IDE is connected and debugging is active. + BridgeSessionStateConnected + + // BridgeSessionStateTerminated indicates the session has ended. + BridgeSessionStateTerminated + + // BridgeSessionStateError indicates the session encountered an error. + BridgeSessionStateError +) + +// String returns a string representation of the session state. +func (s BridgeSessionState) String() string { + switch s { + case BridgeSessionStateCreated: + return "created" + case BridgeSessionStateConnected: + return "connected" + case BridgeSessionStateTerminated: + return "terminated" + case BridgeSessionStateError: + return "error" + default: + return "unknown" + } +} + +// BridgeSession holds the state for a debug bridge session. +type BridgeSession struct { + // ID is the unique identifier for this session. + ID string + + // Token is the authentication token for this session. + // This is the same token used for IDE authentication (reused, not generated). + Token string + + // State is the current session state. + State BridgeSessionState + + // Connected indicates whether an IDE has connected to this session. + // Only one connection is allowed per session. + Connected bool + + // CreatedAt is when the session was created. + CreatedAt time.Time + + // Error holds any error message if State is BridgeSessionStateError. + Error string +} + +// Error constants for session management. 
+var ( + ErrBridgeSessionNotFound = errors.New("bridge session not found") + ErrBridgeSessionAlreadyExists = errors.New("bridge session already exists") + ErrBridgeSessionInvalidToken = errors.New("invalid session token") + ErrBridgeSessionAlreadyConnected = errors.New("session already connected") + ErrBridgeSocketNotReady = errors.New("bridge socket is not ready") +) + +// BridgeConnectionHandler is called when a new bridge connection is established, +// after the handshake has been validated. It returns the OutputHandler and stdout/stderr +// writers to use for the bridge session. This allows the caller to wire debug adapter +// output into the appropriate log files for the executable resource. +// +// sessionID is the bridge session identifier (typically the Executable UID). +// runID is the IDE run session identifier provided during the handshake. +// +// If the handler returns a nil OutputHandler, output events from the debug adapter will +// not be captured (they are still forwarded to the IDE). If stdout/stderr writers are nil, +// runInTerminal process output will be discarded. +type BridgeConnectionHandler func(sessionID string, runID string) (OutputHandler, io.Writer, io.Writer) + +// BridgeManagerConfig contains configuration for the BridgeManager. +type BridgeManagerConfig struct { + // SocketDir is the root directory where the secure socket directory will be created. + // If empty, os.UserCacheDir() is used. + SocketDir string + + // SocketNamePrefix is the prefix for the socket file name. + // A random suffix is appended to support multiple DCP instances. + // If empty, DefaultSocketNamePrefix is used. + SocketNamePrefix string + + // Executor is the process executor for debug adapter processes. + // If nil, a new executor will be created. + Executor process.Executor + + // HandshakeTimeout is the timeout for reading the handshake from a connection. + // If zero, defaults to DefaultHandshakeTimeout. 
+ HandshakeTimeout time.Duration + + // ConnectionHandler is called when a bridge connection is established to resolve + // the OutputHandler and stdout/stderr writers for the session. If nil, output + // from debug sessions will not be captured to executable log files. + ConnectionHandler BridgeConnectionHandler +} + +// BridgeManager manages DAP bridge sessions and a shared Unix socket for IDE connections. +// It accepts incoming connections, performs handshake validation, and dispatches +// connections to the appropriate bridge sessions. +type BridgeManager struct { + config BridgeManagerConfig + listener *networking.PrivateUnixSocketListener + log logr.Logger + executor process.Executor + + // Socket configuration + socketDir string + socketPrefix string + readyCh chan struct{} + readyOnce *sync.Once + + // listenerCh is closed by Start() after the listener field has been set + // (whether successfully or not). SocketPath() blocks on this channel so + // that it never observes the listener before Start() has initialised it. + listenerCh chan struct{} + listenerOnce *sync.Once + + // mu protects sessions and activeBridges. + mu *sync.Mutex + sessions map[string]*BridgeSession + activeBridges map[string]*DapBridge +} + +// NewBridgeManager creates a new BridgeManager with the given configuration. 
+func NewBridgeManager(config BridgeManagerConfig, log logr.Logger) *BridgeManager { + executor := config.Executor + if executor == nil { + executor = process.NewOSExecutor(log) + } + + socketDir := config.SocketDir + socketPrefix := config.SocketNamePrefix + if socketPrefix == "" { + socketPrefix = DefaultSocketNamePrefix + } + + return &BridgeManager{ + config: config, + log: log, + executor: executor, + socketDir: socketDir, + socketPrefix: socketPrefix, + readyCh: make(chan struct{}), + readyOnce: &sync.Once{}, + listenerCh: make(chan struct{}), + listenerOnce: &sync.Once{}, + mu: &sync.Mutex{}, + sessions: make(map[string]*BridgeSession), + activeBridges: make(map[string]*DapBridge), + } +} + +// RegisterSession creates and registers a new bridge session. +// The token parameter should be the IDE session token (reused for bridge authentication). +// Returns the created session. +func (m *BridgeManager) RegisterSession(sessionID string, token string) (*BridgeSession, error) { + m.mu.Lock() + defer m.mu.Unlock() + + if _, exists := m.sessions[sessionID]; exists { + return nil, ErrBridgeSessionAlreadyExists + } + + session := &BridgeSession{ + ID: sessionID, + Token: token, + State: BridgeSessionStateCreated, + CreatedAt: time.Now(), + } + + m.sessions[sessionID] = session + m.log.Info("Registered bridge session", "sessionID", sessionID) + return session, nil +} + +// SocketPath returns the path to the Unix socket. +// It blocks until Start() has finished initialising the listener or ctx is cancelled. +func (m *BridgeManager) SocketPath(ctx context.Context) (string, error) { + select { + case <-m.listenerCh: + // Start() has set the listener field. + case <-ctx.Done(): + return "", fmt.Errorf("waiting for bridge socket: %w", ctx.Err()) + } + + if m.listener == nil { + return "", ErrBridgeSocketNotReady + } + return m.listener.SocketPath(), nil +} + +// Ready returns a channel that is closed when the socket is ready to accept connections. 
+func (m *BridgeManager) Ready() <-chan struct{} { + return m.readyCh +} + +// Start begins listening on the Unix socket and accepting connections. +// This method blocks until the context is cancelled. +// Connections are handled in separate goroutines. +func (m *BridgeManager) Start(ctx context.Context) error { + // Create the Unix socket listener + var listenerErr error + m.listener, listenerErr = networking.NewPrivateUnixSocketListener(m.socketDir, m.socketPrefix) + + // Signal that the listener field has been set so that SocketPath() can proceed. + m.listenerOnce.Do(func() { close(m.listenerCh) }) + + if listenerErr != nil { + return fmt.Errorf("failed to create socket listener: %w", listenerErr) + } + defer m.listener.Close() + + m.log.Info("Bridge manager listening", "socketPath", m.listener.SocketPath()) + + // Close the listener when the context is cancelled so that Accept() unblocks. + // PrivateUnixSocketListener.Close() is idempotent, so the deferred Close above + // is still safe. + go func() { + <-ctx.Done() + m.listener.Close() + }() + + // Signal that we're ready to accept connections + m.readyOnce.Do(func() { + close(m.readyCh) + }) + + // Accept connections in a loop + for { + // Accept the next connection + conn, acceptErr := m.listener.Accept() + if acceptErr != nil { + // Check if context was cancelled (listener was closed by the goroutine above) + select { + case <-ctx.Done(): + m.log.V(1).Info("Bridge manager shutting down") + return ctx.Err() + default: + } + m.log.Error(acceptErr, "Failed to accept connection") + continue + } + + // Handle the connection in a separate goroutine + go m.handleConnection(ctx, conn) + } +} + +// validateHandshake validates a handshake request against registered sessions. +// Returns the session if validation succeeds. 
+func (m *BridgeManager) validateHandshake(sessionID, token string) (*BridgeSession, error) { + m.mu.Lock() + defer m.mu.Unlock() + + session, exists := m.sessions[sessionID] + if !exists { + return nil, ErrBridgeSessionNotFound + } + + if session.Token != token { + return nil, ErrBridgeSessionInvalidToken + } + + return session, nil +} + +// markSessionConnected marks a session as having an active connection. +// Returns an error if the session is not found or already has a connection. +func (m *BridgeManager) markSessionConnected(sessionID string) error { + m.mu.Lock() + defer m.mu.Unlock() + + session, exists := m.sessions[sessionID] + if !exists { + return ErrBridgeSessionNotFound + } + + if session.Connected { + return fmt.Errorf("%w: session %s", ErrBridgeSessionAlreadyConnected, sessionID) + } + + session.Connected = true + m.log.V(1).Info("Marked session as connected", "sessionID", sessionID) + return nil +} + +// markSessionDisconnected resets the connected flag for a session. +// This is used to roll back markSessionConnected if later handshake steps fail. +// It is a no-op if the session does not exist. +func (m *BridgeManager) markSessionDisconnected(sessionID string) { + m.mu.Lock() + defer m.mu.Unlock() + + if session, exists := m.sessions[sessionID]; exists { + session.Connected = false + m.log.V(1).Info("Reset session connected state", "sessionID", sessionID) + } +} + +// IsSessionConnected returns whether the given session has an active connection. +// Returns false if the session does not exist. +func (m *BridgeManager) IsSessionConnected(sessionID string) bool { + m.mu.Lock() + defer m.mu.Unlock() + + session, exists := m.sessions[sessionID] + if !exists { + return false + } + return session.Connected +} + +// updateSessionState updates the state of a session. 
+func (m *BridgeManager) updateSessionState(sessionID string, state BridgeSessionState, errorMsg string) error { + m.mu.Lock() + defer m.mu.Unlock() + + session, exists := m.sessions[sessionID] + if !exists { + return ErrBridgeSessionNotFound + } + + oldState := session.State + session.State = state + session.Error = errorMsg + + m.log.V(1).Info("Bridge session state changed", + "sessionID", sessionID, + "oldState", oldState.String(), + "newState", state.String()) + + return nil +} + +// handleConnection processes a single incoming connection. +func (m *BridgeManager) handleConnection(ctx context.Context, conn net.Conn) { + defer func() { + if r := recover(); r != nil { + m.log.Error(fmt.Errorf("panic: %v", r), "Panic in connection handler") + conn.Close() + } + }() + + log := m.log.WithValues("remoteAddr", conn.RemoteAddr()) + log.V(1).Info("Accepted connection") + + // Set handshake timeout + timeout := m.config.HandshakeTimeout + if timeout == 0 { + timeout = DefaultHandshakeTimeout + } + if deadlineErr := conn.SetDeadline(time.Now().Add(timeout)); deadlineErr != nil { + log.Error(deadlineErr, "Failed to set handshake deadline") + conn.Close() + return + } + + // Read the handshake request + reader := NewHandshakeReader(conn) + writer := NewHandshakeWriter(conn) + + req, readErr := reader.ReadRequest() + if readErr != nil { + log.Error(readErr, "Failed to read handshake request") + conn.Close() + return + } + + log = log.WithValues("sessionID", req.SessionID) + log.V(1).Info("Received handshake request") + + // Validate token and session + session, validateErr := m.validateHandshake(req.SessionID, req.Token) + if validateErr != nil { + log.Error(validateErr, "Handshake validation failed") + resp := &HandshakeResponse{ + Success: false, + Error: validateErr.Error(), + } + _ = writer.WriteResponse(resp) + conn.Close() + return + } + + // Check if adapter config is provided in handshake + if req.DebugAdapterConfig == nil { + log.Error(nil, "Handshake missing debug 
adapter configuration") + resp := &HandshakeResponse{ + Success: false, + Error: "debug adapter configuration is required", + } + _ = writer.WriteResponse(resp) + conn.Close() + return + } + + // Try to mark the session as connected (prevents duplicate connections) + markErr := m.markSessionConnected(req.SessionID) + if markErr != nil { + log.Error(markErr, "Failed to mark session as connected") + resp := &HandshakeResponse{ + Success: false, + Error: markErr.Error(), + } + _ = writer.WriteResponse(resp) + conn.Close() + return + } + + // If anything fails between marking connected and handing off to runBridge, + // roll back the connected state so the session can be retried. + handedOff := false + defer func() { + if !handedOff { + m.markSessionDisconnected(req.SessionID) + } + }() + + // Send success response + resp := &HandshakeResponse{Success: true} + if writeErr := writer.WriteResponse(resp); writeErr != nil { + log.Error(writeErr, "Failed to send handshake response") + conn.Close() + return + } + + // Clear the deadline for normal operation + if deadlineErr := conn.SetDeadline(time.Time{}); deadlineErr != nil { + log.Error(deadlineErr, "Failed to clear handshake deadline") + conn.Close() + return + } + + log.Info("Handshake successful, starting bridge") + + // Disarm the rollback—runBridge now owns the session + handedOff = true + + // Create and run the bridge + m.runBridge(ctx, conn, session, req.RunID, req.DebugAdapterConfig, log) +} + +// runBridge creates and runs a DapBridge for the given connection and session. 
+func (m *BridgeManager) runBridge( + ctx context.Context, + conn net.Conn, + session *BridgeSession, + runID string, + adapterConfig *DebugAdapterConfig, + log logr.Logger, +) { + // Create the bridge configuration + bridgeConfig := BridgeConfig{ + SessionID: session.ID, + AdapterConfig: adapterConfig, + Executor: m.executor, + Logger: log.WithName("DapBridge"), + } + + // Resolve output handlers via the connection callback if configured + if m.config.ConnectionHandler != nil { + outputHandler, stdoutWriter, stderrWriter := m.config.ConnectionHandler(session.ID, runID) + bridgeConfig.OutputHandler = outputHandler + bridgeConfig.StdoutWriter = stdoutWriter + bridgeConfig.StderrWriter = stderrWriter + } + + // Create the bridge + bridge := NewDapBridge(bridgeConfig) + + // Track active bridge + m.mu.Lock() + m.activeBridges[session.ID] = bridge + m.mu.Unlock() + + defer func() { + m.mu.Lock() + delete(m.activeBridges, session.ID) + m.mu.Unlock() + }() + + // Update session state + _ = m.updateSessionState(session.ID, BridgeSessionStateConnected, "") + + // Run the bridge with the already-connected IDE connection + bridgeErr := bridge.RunWithConnection(ctx, conn) + if bridgeErr != nil && !isExpectedShutdownErr(bridgeErr) { + log.Error(bridgeErr, "Bridge terminated with error") + _ = m.updateSessionState(session.ID, BridgeSessionStateError, bridgeErr.Error()) + } else { + _ = m.updateSessionState(session.ID, BridgeSessionStateTerminated, "") + } +} diff --git a/internal/dap/bridge_manager_test.go b/internal/dap/bridge_manager_test.go new file mode 100644 index 00000000..bb25cdad --- /dev/null +++ b/internal/dap/bridge_manager_test.go @@ -0,0 +1,117 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See LICENSE in the project root for license information. 
+ *--------------------------------------------------------------------------------------------*/ + +package dap + +import ( + "testing" + + "github.com/go-logr/logr" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBridgeManager_RegisterSession(t *testing.T) { + t.Parallel() + + manager := NewBridgeManager(BridgeManagerConfig{}, logr.Discard()) + + session, err := manager.RegisterSession("test-session-1", "test-token-123") + require.NoError(t, err) + require.NotNil(t, session) + + assert.Equal(t, "test-session-1", session.ID) + assert.Equal(t, "test-token-123", session.Token) + assert.Equal(t, BridgeSessionStateCreated, session.State) + assert.False(t, session.Connected) +} + +func TestBridgeManager_RegisterSession_DuplicateID(t *testing.T) { + t.Parallel() + + manager := NewBridgeManager(BridgeManagerConfig{}, logr.Discard()) + + _, sessionErr := manager.RegisterSession("dup-session", "token1") + require.NoError(t, sessionErr) + + _, dupErr := manager.RegisterSession("dup-session", "token2") + assert.ErrorIs(t, dupErr, ErrBridgeSessionAlreadyExists) +} + +func TestBridgeManager_ValidateHandshake_InvalidToken(t *testing.T) { + t.Parallel() + + manager := NewBridgeManager(BridgeManagerConfig{}, logr.Discard()) + + _, regErr := manager.RegisterSession("token-session", "correct-token") + require.NoError(t, regErr) + + _, validateErr := manager.validateHandshake("token-session", "wrong-token") + assert.ErrorIs(t, validateErr, ErrBridgeSessionInvalidToken) +} + +func TestBridgeManager_ValidateHandshake_SessionNotFound(t *testing.T) { + t.Parallel() + + manager := NewBridgeManager(BridgeManagerConfig{}, logr.Discard()) + + _, validateErr := manager.validateHandshake("nonexistent", "any-token") + assert.ErrorIs(t, validateErr, ErrBridgeSessionNotFound) +} + +func TestBridgeManager_MarkSessionConnected(t *testing.T) { + t.Parallel() + + manager := NewBridgeManager(BridgeManagerConfig{}, logr.Discard()) + + session, regErr := 
manager.RegisterSession("connect-session", "test-token") + require.NoError(t, regErr) + assert.False(t, session.Connected) + + // First connection should succeed + connectErr := manager.markSessionConnected("connect-session") + require.NoError(t, connectErr) + assert.True(t, session.Connected) + + // Second connection attempt should fail + connectErr2 := manager.markSessionConnected("connect-session") + assert.ErrorIs(t, connectErr2, ErrBridgeSessionAlreadyConnected) +} + +func TestBridgeManager_MarkSessionConnected_NotFound(t *testing.T) { + t.Parallel() + + manager := NewBridgeManager(BridgeManagerConfig{}, logr.Discard()) + + connectErr := manager.markSessionConnected("nonexistent") + assert.ErrorIs(t, connectErr, ErrBridgeSessionNotFound) +} + +func TestBridgeManager_MarkSessionDisconnected(t *testing.T) { + t.Parallel() + + manager := NewBridgeManager(BridgeManagerConfig{}, logr.Discard()) + + _, regErr := manager.RegisterSession("disconnect-session", "test-token") + require.NoError(t, regErr) + + // Mark connected, then disconnect + connectErr := manager.markSessionConnected("disconnect-session") + require.NoError(t, connectErr) + + manager.markSessionDisconnected("disconnect-session") + + // Should be able to connect again after disconnect + reconnectErr := manager.markSessionConnected("disconnect-session") + assert.NoError(t, reconnectErr) +} + +func TestBridgeManager_MarkSessionDisconnected_NotFound(t *testing.T) { + t.Parallel() + + // Should be a no-op, not panic + manager := NewBridgeManager(BridgeManagerConfig{}, logr.Discard()) + manager.markSessionDisconnected("nonexistent") +} diff --git a/internal/dap/bridge_test.go b/internal/dap/bridge_test.go new file mode 100644 index 00000000..3b77f038 --- /dev/null +++ b/internal/dap/bridge_test.go @@ -0,0 +1,260 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. 
See LICENSE in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package dap + +import ( + "context" + "net" + "os" + "testing" + "time" + + "github.com/go-logr/logr" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "k8s.io/apimachinery/pkg/util/wait" + + "github.com/microsoft/dcp/pkg/testutil" +) + +// shortTempDir creates a short temporary directory for socket tests. +// macOS has a ~104 character limit for Unix socket paths. +func shortTempDir(t *testing.T) string { + t.Helper() + dir, dirErr := os.MkdirTemp("", "sck") + require.NoError(t, dirErr) + t.Cleanup(func() { os.RemoveAll(dir) }) + return dir +} + +func TestDapBridge_Creation(t *testing.T) { + t.Parallel() + + config := BridgeConfig{ + SessionID: "test-session", + } + + bridge := NewDapBridge(config) + + assert.NotNil(t, bridge) +} + +func TestDapBridge_RunWithConnection(t *testing.T) { + t.Parallel() + + // Test that RunWithConnection starts and can be cancelled + serverConn, clientConn := net.Pipe() + defer serverConn.Close() + defer clientConn.Close() + + config := BridgeConfig{ + SessionID: "session-456", + AdapterConfig: &DebugAdapterConfig{ + Args: []string{"echo", "test"}, // Simple command + Mode: DebugAdapterModeStdio, + }, + } + + bridge := NewDapBridge(config) + + ctx, cancel := testutil.GetTestContext(t, 500*time.Millisecond) + defer cancel() + + // Run bridge with pre-connected connection + // It will fail to properly run the adapter but will start correctly + errCh := make(chan error, 1) + go func() { + errCh <- bridge.RunWithConnection(ctx, serverConn) + }() + + // Cancel to shutdown + cancel() + + // Wait for bridge to finish + select { + case <-errCh: + // Good + case <-time.After(2 * time.Second): + t.Fatal("bridge did not shut down in time") + } +} + +func TestDapBridge_RunInTerminalUsed(t *testing.T) { + t.Parallel() + + config := BridgeConfig{ + SessionID: 
"session", + } + + bridge := NewDapBridge(config) + + // Initially false + assert.False(t, bridge.runInTerminalUsed.Load()) +} + +func TestDapBridge_Done(t *testing.T) { + t.Parallel() + + serverConn, clientConn := net.Pipe() + defer serverConn.Close() + defer clientConn.Close() + + config := BridgeConfig{ + SessionID: "session", + AdapterConfig: &DebugAdapterConfig{ + Args: []string{"echo"}, + Mode: DebugAdapterModeStdio, + }, + } + + bridge := NewDapBridge(config) + + // Done channel should not be closed initially + select { + case <-bridge.terminateCh: + t.Fatal("Done channel should not be closed before running") + default: + // Expected + } + + ctx, cancel := testutil.GetTestContext(t, 5*time.Second) + + // Start bridge + errCh := make(chan error, 1) + go func() { + errCh <- bridge.RunWithConnection(ctx, serverConn) + }() + + // Cancel to cause termination + cancel() + + // Wait for RunWithConnection to return + select { + case <-errCh: + // Expected + case <-time.After(2 * time.Second): + t.Fatal("RunWithConnection did not return after cancel") + } + + // Done channel should be closed after termination + select { + case <-bridge.terminateCh: + // Expected + case <-time.After(2 * time.Second): + t.Fatal("Done channel should be closed after termination") + } +} + +func TestBridgeManager_SocketPath(t *testing.T) { + t.Parallel() + + // Use a cancelled context so SocketPath() returns immediately + // rather than blocking waiting for Start(). 
+ cancelledCtx, cancel := context.WithCancel(context.Background()) + cancel() + manager := NewBridgeManager(BridgeManagerConfig{}, logr.Discard()) + + // Before Start(), SocketPath() returns an error since no listener exists yet + socketPath, socketErr := manager.SocketPath(cancelledCtx) + assert.Empty(t, socketPath) + assert.Error(t, socketErr) +} + +func TestBridgeManager_DefaultSocketNamePrefix(t *testing.T) { + t.Parallel() + + manager := NewBridgeManager(BridgeManagerConfig{}, logr.Discard()) + + // Should use default prefix + assert.Equal(t, DefaultSocketNamePrefix, manager.socketPrefix) +} + +func TestBridgeManager_StartAndReady(t *testing.T) { + t.Parallel() + + socketDir := shortTempDir(t) + + ctx, cancel := testutil.GetTestContext(t, 2*time.Second) + defer cancel() + + manager := NewBridgeManager(BridgeManagerConfig{ + SocketDir: socketDir, + }, logr.Discard()) + + // Start in background + go func() { + _ = manager.Start(ctx) + }() + + // Wait for ready + select { + case <-manager.Ready(): + // Expected — SocketPath should now be set + socketPath, socketErr := manager.SocketPath(ctx) + require.NoError(t, socketErr) + assert.NotEmpty(t, socketPath) + assert.Contains(t, socketPath, DefaultSocketNamePrefix) + case <-time.After(1 * time.Second): + t.Fatal("manager did not become ready in time") + } + + cancel() +} + +func TestBridgeManager_DuplicateSession(t *testing.T) { + t.Parallel() + + // Test that a second connection for the same session is rejected + + socketDir := shortTempDir(t) + ctx, cancel := testutil.GetTestContext(t, 5*time.Second) + defer cancel() + + manager := NewBridgeManager(BridgeManagerConfig{ + SocketDir: socketDir, + HandshakeTimeout: 2 * time.Second, + }, logr.Discard()) + _, _ = manager.RegisterSession("dup-session", "token") + + go func() { + _ = manager.Start(ctx) + }() + + <-manager.Ready() + + socketPath, socketErr := manager.SocketPath(ctx) + require.NoError(t, socketErr) + + // First connection with a valid adapter config so the 
handshake completes + // and markSessionConnected is called. The adapter will fail to launch but + // the session will remain marked as connected. + conn1, err1 := net.Dial("unix", socketPath) + require.NoError(t, err1) + defer conn1.Close() + + handshakeErr1 := performHandshakeWithAdapterConfig(conn1, "token", "dup-session", "", &DebugAdapterConfig{ + Args: []string{"echo", "dummy"}, + Mode: DebugAdapterModeStdio, + }) + require.NoError(t, handshakeErr1, "first handshake should succeed") + + // Wait until the first connection is processed and the session is marked connected + pollErr := wait.PollUntilContextCancel(ctx, 50*time.Millisecond, true, func(_ context.Context) (bool, error) { + return manager.IsSessionConnected("dup-session"), nil + }) + require.NoError(t, pollErr, "first connection should mark the session as connected") + + // Second connection for the same session + conn2, err2 := net.Dial("unix", socketPath) + require.NoError(t, err2) + defer conn2.Close() + + // This handshake should fail because session is already connected + handshakeErr := performClientHandshake(conn2, "token", "dup-session", "") + assert.Error(t, handshakeErr, "second connection should be rejected") + + cancel() +} diff --git a/internal/dap/doc.go b/internal/dap/doc.go new file mode 100644 index 00000000..65052114 --- /dev/null +++ b/internal/dap/doc.go @@ -0,0 +1,61 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See LICENSE in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +/* +Package dap provides Debug Adapter Protocol (DAP) infrastructure for debugging +executables managed by DCP. + +# Architecture Overview + +The package uses a bridge architecture to connect an IDE's debug adapter +client to a debug adapter launched by DCP. 
Communication occurs over Unix +domain sockets with a length-prefixed JSON handshake protocol. + +# Key Components + + - DapBridge: Main bridge implementation that manages the connection lifecycle + - BridgeManager: Manages active debug sessions, a shared Unix socket, and bridge lifecycle + - DebugAdapterConfig: Configuration for launching debug adapters + +# Connection Flow + + 1. DCP registers a debug session with BridgeManager + 2. BridgeManager listens on a shared Unix socket + 3. Socket path and authentication token are sent to the IDE + 4. IDE connects to the socket and performs handshake + 5. BridgeManager launches the debug adapter via DapBridge + 6. Bridge forwards DAP messages bidirectionally with interception + +The bridge intercepts: + - initialize requests: Forces supportsRunInTerminalRequest=true + - runInTerminal requests: Handles locally instead of forwarding to IDE + - output events: Captures stdout/stderr when runInTerminal is not used + +# Usage + +For debug session implementations, use DapBridge: + + // Create and start the bridge manager + manager := dap.NewBridgeManager(dap.BridgeManagerConfig{}, log) + + // Register a session and start the manager + session, _ := manager.RegisterSession(sessionID, token) + err := manager.Start(ctx) + +# Handshake Protocol + +The IDE must perform a handshake immediately after connecting to the Unix socket. 
+The handshake uses length-prefixed JSON messages (4-byte big-endian length prefix): + + Request: {"token": "...", "session_id": "..."} + Response: {"success": true} or {"success": false, "error": "..."} + +# Output Capture + +Output is captured differently based on whether runInTerminal is used: + - Without runInTerminal: Bridge captures from DAP output events + - With runInTerminal: Bridge captures from process stdout/stderr pipes +*/ +package dap diff --git a/internal/dap/message.go b/internal/dap/message.go new file mode 100644 index 00000000..84950440 --- /dev/null +++ b/internal/dap/message.go @@ -0,0 +1,306 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See LICENSE in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package dap + +import ( + "bufio" + "encoding/json" + "errors" + "fmt" + "io" + + "github.com/google/go-dap" +) + +// newOutputEvent creates a DAP OutputEvent for sending text to the IDE. +// category should be "stdout", "stderr", or "console". +func newOutputEvent(seq int, category, output string) *dap.OutputEvent { + return &dap.OutputEvent{ + Event: dap.Event{ + ProtocolMessage: dap.ProtocolMessage{ + Seq: seq, + Type: "event", + }, + Event: "output", + }, + Body: dap.OutputEventBody{ + Category: category, + Output: output, + }, + } +} + +// newTerminatedEvent creates a DAP TerminatedEvent to signal the debug session has ended. +func newTerminatedEvent(seq int) *dap.TerminatedEvent { + return &dap.TerminatedEvent{ + Event: dap.Event{ + ProtocolMessage: dap.ProtocolMessage{ + Seq: seq, + Type: "event", + }, + Event: "terminated", + }, + } +} + +// RawMessage represents a DAP message that could not be decoded into a known type. 
+// This is used for custom/proprietary messages that the go-dap library doesn't recognize. +type RawMessage struct { + // Data contains the raw JSON bytes of the message (without Content-Length header). + Data []byte + + // header caches the parsed header to avoid repeated JSON unmarshaling. + // It is invalidated (set to nil) when patchJSONFields modifies the raw data. + header *rawMessageHeader +} + +// rawMessageHeader contains the common fields present in all DAP protocol messages. +// It is used to extract header information from RawMessage instances. +type rawMessageHeader struct { + Seq int `json:"seq"` + Type string `json:"type"` + Command string `json:"command,omitempty"` + Event string `json:"event,omitempty"` + RequestSeq int `json:"request_seq,omitempty"` + Success *bool `json:"success,omitempty"` + Message string `json:"message,omitempty"` +} + +// parseHeader parses the raw JSON into a rawMessageHeader, caching the result. +// Subsequent calls return the cached header without re-parsing. +// The cache is invalidated when patchJSONFields modifies the raw data. +func (r *RawMessage) parseHeader() rawMessageHeader { + if r.header != nil { + return *r.header + } + var h rawMessageHeader + _ = json.Unmarshal(r.Data, &h) + r.header = &h + return h +} + +// GetSeq extracts the sequence number from the raw message, or returns 0 if not parseable. +func (r *RawMessage) GetSeq() int { + return r.parseHeader().Seq +} + +// patchJSONFields patches multiple numeric JSON fields in the raw data in a single +// unmarshal/marshal pass. This invalidates the cached header. 
+func (r *RawMessage) patchJSONFields(fields map[string]int) error { + if len(fields) == 0 { + return nil + } + var obj map[string]json.RawMessage + if unmarshalErr := json.Unmarshal(r.Data, &obj); unmarshalErr != nil { + return fmt.Errorf("unmarshal raw message for patching: %w", unmarshalErr) + } + for field, value := range fields { + valBytes, marshalErr := json.Marshal(value) + if marshalErr != nil { + return fmt.Errorf("marshal patch value for field %q: %w", field, marshalErr) + } + obj[field] = valBytes + } + patched, patchErr := json.Marshal(obj) + if patchErr != nil { + return fmt.Errorf("marshal patched raw message: %w", patchErr) + } + r.Data = patched + r.header = nil // invalidate cache + return nil +} + +// ReadMessageWithFallback reads a DAP message from the given reader. +// If the message cannot be decoded (e.g., unknown command), it returns a RawMessage +// containing the raw bytes, allowing the message to be forwarded transparently. +func ReadMessageWithFallback(reader *bufio.Reader) (dap.Message, error) { + content, readErr := dap.ReadBaseMessage(reader) + if readErr != nil { + return nil, readErr + } + + msg, decodeErr := dap.DecodeProtocolMessage(content) + if decodeErr != nil { + // Check if this is an "unknown command/event" error from go-dap. + // These errors indicate the message is valid DAP but uses a custom command. + var fieldErr *dap.DecodeProtocolMessageFieldError + if errors.As(decodeErr, &fieldErr) { + // Return the raw message bytes so it can be forwarded transparently. + return &RawMessage{Data: content}, nil + } + // Other decode errors (malformed JSON, etc.) should fail. + return nil, decodeErr + } + + return msg, nil +} + +// WriteMessageWithFallback writes a DAP message to the given writer. +// If the message is a RawMessage, it writes the raw bytes directly. +// Otherwise, it uses the standard dap.WriteProtocolMessage. 
+func WriteMessageWithFallback(writer io.Writer, msg dap.Message) error { + if raw, ok := msg.(*RawMessage); ok { + return dap.WriteBaseMessage(writer, raw.Data) + } + return dap.WriteProtocolMessage(writer, msg) +} + +// MessageEnvelope wraps a DAP message (typed or raw) and provides uniform access +// to common header fields. Header fields are extracted once at creation time and +// can be freely modified on the envelope. Changes are applied back to the underlying +// message in a single pass when Finalize is called, avoiding repeated +// serialization round trips. +type MessageEnvelope struct { + // Inner is the underlying DAP message (typed or *RawMessage). + Inner dap.Message + + // Seq is the message sequence number. + Seq int + + // Type is the message type: "request", "response", or "event". + Type string + + // Command is the command name (for requests and responses). + Command string + + // Event is the event name (for events). + Event string + + // RequestSeq is the sequence number of the corresponding request (for responses). + RequestSeq int + + // Success indicates whether a response was successful (nil for non-responses). + Success *bool + + // ErrorMessage is the error message for failed responses. + ErrorMessage string + + // isRaw tracks whether Inner is a *RawMessage. + isRaw bool + + // originalSeq and originalRequestSeq track the values at creation time + // so Finalize only patches fields that actually changed. + originalSeq int + originalRequestSeq int +} + +// NewMessageEnvelope creates a MessageEnvelope by extracting header fields from the +// given message. For typed messages this is a zero-cost struct field read. For +// *RawMessage it performs a single JSON unmarshal of the header (which is cached +// on the RawMessage for any subsequent parseHeader calls). 
+func NewMessageEnvelope(msg dap.Message) *MessageEnvelope { + env := &MessageEnvelope{Inner: msg} + + switch m := msg.(type) { + case *RawMessage: + env.isRaw = true + h := m.parseHeader() + env.Seq = h.Seq + env.Type = h.Type + env.Command = h.Command + env.Event = h.Event + env.RequestSeq = h.RequestSeq + env.Success = h.Success + env.ErrorMessage = h.Message + case dap.RequestMessage: + r := m.GetRequest() + env.Seq = r.Seq + env.Type = "request" + env.Command = r.Command + case dap.ResponseMessage: + r := m.GetResponse() + env.Seq = r.Seq + env.Type = "response" + env.Command = r.Command + env.RequestSeq = r.RequestSeq + env.Success = &r.Success + env.ErrorMessage = r.Message + case dap.EventMessage: + e := m.GetEvent() + env.Seq = e.Seq + env.Type = "event" + env.Event = e.Event + default: + env.Seq = msg.GetSeq() + } + + env.originalSeq = env.Seq + env.originalRequestSeq = env.RequestSeq + return env +} + +// GetSeq implements dap.Message. +func (e *MessageEnvelope) GetSeq() int { + return e.Seq +} + +// IsResponse returns true if the wrapped message is a response (typed or raw). +func (e *MessageEnvelope) IsResponse() bool { + return e.Type == "response" +} + +// Describe returns a human-readable description of the message for logging. +// It uses the pre-extracted header fields, so no additional parsing is required. 
+func (e *MessageEnvelope) Describe() string { + prefix := "" + if e.isRaw { + prefix = "raw " + } + + switch e.Type { + case "request": + return fmt.Sprintf("%srequest '%s' (seq=%d)", prefix, e.Command, e.Seq) + case "response": + success := e.Success != nil && *e.Success + if success { + return fmt.Sprintf("%sresponse '%s' (seq=%d, request_seq=%d, success=true)", prefix, e.Command, e.Seq, e.RequestSeq) + } + return fmt.Sprintf("%sresponse '%s' (seq=%d, request_seq=%d, success=false, message=%q)", prefix, e.Command, e.Seq, e.RequestSeq, e.ErrorMessage) + case "event": + return fmt.Sprintf("%sevent '%s' (seq=%d)", prefix, e.Event, e.Seq) + default: + if e.isRaw { + return fmt.Sprintf("raw %s (seq=%d)", e.Type, e.Seq) + } + return fmt.Sprintf("unknown(seq=%d, type=%T)", e.Seq, e.Inner) + } +} + +// Finalize applies any modified header fields back to the underlying message and +// returns it, ready for writing to a Transport. For typed messages this is a +// zero-cost struct field write. For *RawMessage, changed fields are applied in +// a single patchJSONFields call (one unmarshal + one marshal). If no fields were +// changed, the raw data is left untouched. +func (e *MessageEnvelope) Finalize() (dap.Message, error) { + if e.isRaw { + raw := e.Inner.(*RawMessage) + patches := make(map[string]int, 2) + if e.Seq != e.originalSeq { + patches["seq"] = e.Seq + } + if e.RequestSeq != e.originalRequestSeq { + patches["request_seq"] = e.RequestSeq + } + if patchErr := raw.patchJSONFields(patches); patchErr != nil { + return nil, fmt.Errorf("finalize raw message: %w", patchErr) + } + return raw, nil + } + + // Typed messages: apply changes via struct field writes. 
+ switch m := e.Inner.(type) { + case dap.RequestMessage: + m.GetRequest().Seq = e.Seq + case dap.ResponseMessage: + r := m.GetResponse() + r.Seq = e.Seq + r.RequestSeq = e.RequestSeq + case dap.EventMessage: + m.GetEvent().Seq = e.Seq + } + + return e.Inner, nil +} diff --git a/internal/dap/message_pipe.go b/internal/dap/message_pipe.go new file mode 100644 index 00000000..d7377100 --- /dev/null +++ b/internal/dap/message_pipe.go @@ -0,0 +1,129 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See LICENSE in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package dap + +import ( + "context" + "fmt" + "sync/atomic" + + "github.com/go-logr/logr" + "github.com/microsoft/dcp/pkg/concurrency" + "github.com/microsoft/dcp/pkg/syncmap" +) + +// MessagePipe provides a FIFO message queue with a dedicated writer goroutine +// that assigns monotonically increasing sequence numbers to messages as they +// are dequeued. This guarantees that sequence numbers on the wire are always +// in-order, even when multiple goroutines enqueue messages concurrently. +// +// Each pipe owns a SeqCounter (atomic, shared with shutdown writers) and a +// seqMap for tracking virtualSeq→originalSeq mappings so that response +// correlation can be performed by the opposite direction's reader. +type MessagePipe struct { + // transport is the write destination for messages. + transport Transport + + // ch is the unbounded channel used as the FIFO queue. + ch *concurrency.UnboundedChan[*MessageEnvelope] + + // SeqCounter generates monotonically increasing sequence numbers. + // It is atomic so that shutdown code can continue assigning seq values + // after the writer goroutine has stopped. 
+ SeqCounter atomic.Int64 + + // seqMap maps bridge-assigned sequence numbers to original sequence numbers. + // For the adapter-bound pipe, this maps virtualSeq→originalIDESeq so that + // the adapter-to-IDE reader can restore request_seq on responses. + seqMap syncmap.Map[int, int] + + // log is the logger for this pipe. + log logr.Logger + + // name identifies this pipe in log messages (e.g., "adapterPipe", "idePipe"). + name string +} + +// NewMessagePipe creates a new MessagePipe that writes to the given transport. +// The pipe's internal goroutine (for the UnboundedChan) is bound to ctx. +func NewMessagePipe(ctx context.Context, transport Transport, name string, log logr.Logger) *MessagePipe { + return &MessagePipe{ + transport: transport, + ch: concurrency.NewUnboundedChan[*MessageEnvelope](ctx), + log: log, + name: name, + } +} + +// Send enqueues a message to be written by the pipe's writer goroutine. +// This method never blocks for an extended period (UnboundedChan buffers +// internally). It is safe for concurrent use by multiple goroutines. +func (p *MessagePipe) Send(env *MessageEnvelope) { + p.ch.In <- env +} + +// CloseInput closes the pipe's input channel, signaling that no more messages +// will be sent. The pipe's Run goroutine will finish writing any buffered +// messages and then exit. The caller must ensure no goroutine calls Send after +// CloseInput returns. +func (p *MessagePipe) CloseInput() { + close(p.ch.In) +} + +// Run runs the writer loop, reading messages from the FIFO queue, assigning +// sequence numbers, and writing them to the transport. It returns when the +// context is cancelled (which closes the UnboundedChan's Out channel) or +// when a transport write error occurs. +func (p *MessagePipe) Run(ctx context.Context) error { + for env := range p.ch.Out { + // Assign the next sequence number. 
+ originalSeq := env.Seq + newSeq := int(p.SeqCounter.Add(1)) + env.Seq = newSeq + + // For request messages, store the mapping so the opposite direction's + // reader can remap request_seq on responses. + if env.Type == "request" { + p.seqMap.Store(newSeq, originalSeq) + } + + p.log.V(1).Info("Writing message", + "pipe", p.name, + "message", env.Describe(), + "originalSeq", originalSeq, + "assignedSeq", newSeq) + + finalizedMsg, finalizeErr := env.Finalize() + if finalizeErr != nil { + return fmt.Errorf("%s: failed to finalize message: %w", p.name, finalizeErr) + } + + writeErr := p.transport.WriteMessage(finalizedMsg) + if writeErr != nil { + return fmt.Errorf("%s: failed to write message: %w", p.name, writeErr) + } + } + + return ctx.Err() +} + +// RemapResponseSeq looks up the original sequence number for a response's +// request_seq field. If found, it updates env.RequestSeq to the original +// value and deletes the mapping. This should be called by the reader of +// the opposite direction before enqueueing a response to its own pipe. +func (p *MessagePipe) RemapResponseSeq(env *MessageEnvelope) { + if !env.IsResponse() { + return + } + if origSeq, found := p.seqMap.LoadAndDelete(env.RequestSeq); found { + p.log.V(1).Info("Remapping response request_seq", + "pipe", p.name, + "command", env.Command, + "virtualRequestSeq", env.RequestSeq, + "originalRequestSeq", origSeq) + env.RequestSeq = origSeq + } +} diff --git a/internal/dap/message_pipe_test.go b/internal/dap/message_pipe_test.go new file mode 100644 index 00000000..4deb519a --- /dev/null +++ b/internal/dap/message_pipe_test.go @@ -0,0 +1,312 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See LICENSE in the project root for license information. 
+ *--------------------------------------------------------------------------------------------*/ + +package dap + +import ( + "context" + "net" + "sync" + "testing" + "time" + + "github.com/go-logr/logr" + "github.com/google/go-dap" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/microsoft/dcp/pkg/testutil" +) + +func TestMessagePipe_FIFOOrderAndMonotonicSeq(t *testing.T) { + t.Parallel() + + serverConn, clientConn := net.Pipe() + defer serverConn.Close() + defer clientConn.Close() + + ctx, cancel := testutil.GetTestContext(t, 5*time.Second) + defer cancel() + + transport := NewUnixTransportWithContext(ctx, serverConn) + pipe := NewMessagePipe(ctx, transport, "test", logr.Discard()) + + // Start the writer goroutine. + errCh := make(chan error, 1) + go func() { + errCh <- pipe.Run(ctx) + }() + + // Enqueue several messages with arbitrary original seq values. + messageCount := 10 + for i := 0; i < messageCount; i++ { + env := NewMessageEnvelope(&dap.SetBreakpointsRequest{ + Request: dap.Request{ + ProtocolMessage: dap.ProtocolMessage{Seq: 100 + i, Type: "request"}, + Command: "setBreakpoints", + }, + }) + pipe.Send(env) + } + + // Read messages from the client side and verify ordering. 
+ clientTransport := NewUnixTransportWithContext(ctx, clientConn) + for i := 0; i < messageCount; i++ { + msg, readErr := clientTransport.ReadMessage() + require.NoError(t, readErr) + assert.Equal(t, i+1, msg.GetSeq(), "seq should be monotonically increasing starting at 1") + } + + cancel() + <-errCh +} + +func TestMessagePipe_ConcurrentSendAllWritten(t *testing.T) { + t.Parallel() + + serverConn, clientConn := net.Pipe() + defer serverConn.Close() + defer clientConn.Close() + + ctx, cancel := testutil.GetTestContext(t, 5*time.Second) + defer cancel() + + transport := NewUnixTransportWithContext(ctx, serverConn) + pipe := NewMessagePipe(ctx, transport, "test", logr.Discard()) + + errCh := make(chan error, 1) + go func() { + errCh <- pipe.Run(ctx) + }() + + // Send messages from multiple goroutines concurrently. + goroutineCount := 5 + messagesPerGoroutine := 10 + totalMessages := goroutineCount * messagesPerGoroutine + + var wg sync.WaitGroup + wg.Add(goroutineCount) + for g := 0; g < goroutineCount; g++ { + go func() { + defer wg.Done() + for i := 0; i < messagesPerGoroutine; i++ { + env := NewMessageEnvelope(&dap.ContinueRequest{ + Request: dap.Request{ + ProtocolMessage: dap.ProtocolMessage{Seq: 999, Type: "request"}, + Command: "continue", + }, + }) + pipe.Send(env) + } + }() + } + wg.Wait() + + // Read all messages and verify we got the right count and monotonic seq. + clientTransport := NewUnixTransportWithContext(ctx, clientConn) + seenSeqs := make([]int, 0, totalMessages) + for i := 0; i < totalMessages; i++ { + msg, readErr := clientTransport.ReadMessage() + require.NoError(t, readErr) + seenSeqs = append(seenSeqs, msg.GetSeq()) + } + + assert.Len(t, seenSeqs, totalMessages) + // Verify monotonically increasing. 
+ for i := 1; i < len(seenSeqs); i++ { + assert.Greater(t, seenSeqs[i], seenSeqs[i-1], + "seq values must be monotonically increasing: seq[%d]=%d, seq[%d]=%d", + i-1, seenSeqs[i-1], i, seenSeqs[i]) + } + + cancel() + <-errCh +} + +func TestMessagePipe_SeqMapPopulatedForRequests(t *testing.T) { + t.Parallel() + + serverConn, clientConn := net.Pipe() + defer serverConn.Close() + defer clientConn.Close() + + ctx, cancel := testutil.GetTestContext(t, 5*time.Second) + defer cancel() + + transport := NewUnixTransportWithContext(ctx, serverConn) + pipe := NewMessagePipe(ctx, transport, "test", logr.Discard()) + + errCh := make(chan error, 1) + go func() { + errCh <- pipe.Run(ctx) + }() + + // Send a request with original seq=42. + env := NewMessageEnvelope(&dap.InitializeRequest{ + Request: dap.Request{ + ProtocolMessage: dap.ProtocolMessage{Seq: 42, Type: "request"}, + Command: "initialize", + }, + }) + pipe.Send(env) + + // Also send an event (should NOT be stored in seqMap). + eventEnv := NewMessageEnvelope(&dap.StoppedEvent{ + Event: dap.Event{ + ProtocolMessage: dap.ProtocolMessage{Seq: 99, Type: "event"}, + Event: "stopped", + }, + }) + pipe.Send(eventEnv) + + // Drain both messages from the transport. + clientTransport := NewUnixTransportWithContext(ctx, clientConn) + msg1, readErr1 := clientTransport.ReadMessage() + require.NoError(t, readErr1) + msg2, readErr2 := clientTransport.ReadMessage() + require.NoError(t, readErr2) + + // The request should have been assigned seq=1, and the mapping 1→42 stored. + assert.Equal(t, 1, msg1.GetSeq()) + origSeq, found := pipe.seqMap.Load(1) + assert.True(t, found, "seqMap should contain mapping for request") + assert.Equal(t, 42, origSeq) + + // The event (seq=2) should NOT be in the seqMap. 
+ assert.Equal(t, 2, msg2.GetSeq()) + _, eventFound := pipe.seqMap.Load(2) + assert.False(t, eventFound, "seqMap should not contain mapping for events") + + cancel() + <-errCh +} + +func TestMessagePipe_RemapResponseSeq(t *testing.T) { + t.Parallel() + + ctx, cancel := testutil.GetTestContext(t, 5*time.Second) + defer cancel() + + // We don't need a real transport for this test — just the seqMap. + pipe := NewMessagePipe(ctx, nil, "test", logr.Discard()) + + // Manually populate the seqMap as if a request with virtualSeq=5 was written + // and the original IDE seq was 42. + pipe.seqMap.Store(5, 42) + + // Create a response envelope with request_seq=5 (the virtual seq). + env := NewMessageEnvelope(&dap.InitializeResponse{ + Response: dap.Response{ + ProtocolMessage: dap.ProtocolMessage{Seq: 1, Type: "response"}, + RequestSeq: 5, + Command: "initialize", + Success: true, + }, + }) + + pipe.RemapResponseSeq(env) + + assert.Equal(t, 42, env.RequestSeq, "request_seq should be remapped to original IDE seq") + + // The mapping should be consumed (deleted). + _, found := pipe.seqMap.Load(5) + assert.False(t, found, "seqMap entry should be deleted after remap") +} + +func TestMessagePipe_RemapResponseSeq_IgnoresNonResponses(t *testing.T) { + t.Parallel() + + ctx, cancel := testutil.GetTestContext(t, 5*time.Second) + defer cancel() + + pipe := NewMessagePipe(ctx, nil, "test", logr.Discard()) + pipe.seqMap.Store(1, 100) + + // Try to remap a request — should be a no-op. + env := NewMessageEnvelope(&dap.ContinueRequest{ + Request: dap.Request{ + ProtocolMessage: dap.ProtocolMessage{Seq: 1, Type: "request"}, + Command: "continue", + }, + }) + + pipe.RemapResponseSeq(env) + + // seqMap entry should still exist (not consumed). 
+ _, found := pipe.seqMap.Load(1) + assert.True(t, found, "seqMap entry should not be consumed for non-response messages") +} + +func TestMessagePipe_ContextCancellationStopsWriter(t *testing.T) { + t.Parallel() + + serverConn, clientConn := net.Pipe() + defer serverConn.Close() + defer clientConn.Close() + + ctx, cancel := testutil.GetTestContext(t, 5*time.Second) + + transport := NewUnixTransportWithContext(ctx, serverConn) + pipe := NewMessagePipe(ctx, transport, "test", logr.Discard()) + + errCh := make(chan error, 1) + go func() { + errCh <- pipe.Run(ctx) + }() + + // Cancel the context — the writer should stop. + cancel() + + select { + case runErr := <-errCh: + // Writer should return with context.Canceled (or nil if Out closed first). + if runErr != nil { + assert.ErrorIs(t, runErr, context.Canceled) + } + case <-time.After(2 * time.Second): + t.Fatal("writer goroutine did not stop after context cancellation") + } +} + +func TestMessagePipe_SeqCounterContinuesAfterStop(t *testing.T) { + t.Parallel() + + serverConn, clientConn := net.Pipe() + defer serverConn.Close() + defer clientConn.Close() + + ctx, cancel := testutil.GetTestContext(t, 5*time.Second) + + transport := NewUnixTransportWithContext(ctx, serverConn) + pipe := NewMessagePipe(ctx, transport, "test", logr.Discard()) + + errCh := make(chan error, 1) + go func() { + errCh <- pipe.Run(ctx) + }() + + // Send a couple of messages so counter reaches 2. + clientTransport := NewUnixTransportWithContext(ctx, clientConn) + for i := 0; i < 2; i++ { + env := NewMessageEnvelope(&dap.ContinueRequest{ + Request: dap.Request{ + ProtocolMessage: dap.ProtocolMessage{Seq: 1, Type: "request"}, + Command: "continue", + }, + }) + pipe.Send(env) + + _, readErr := clientTransport.ReadMessage() + require.NoError(t, readErr) + } + + // Stop the writer. + cancel() + <-errCh + + // SeqCounter should continue from where the writer left off. 
+ nextSeq := int(pipe.SeqCounter.Add(1)) + assert.Equal(t, 3, nextSeq, "SeqCounter should continue from where writer left off") +} diff --git a/internal/dap/message_test.go b/internal/dap/message_test.go new file mode 100644 index 00000000..be213a02 --- /dev/null +++ b/internal/dap/message_test.go @@ -0,0 +1,436 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See LICENSE in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package dap + +import ( + "bufio" + "bytes" + "testing" + + "github.com/google/go-dap" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestReadMessageWithFallbackKnownRequest(t *testing.T) { + t.Parallel() + + // Create a valid DAP message using WriteProtocolMessage + buf := new(bytes.Buffer) + initReq := &dap.InitializeRequest{ + Request: dap.Request{ + ProtocolMessage: dap.ProtocolMessage{Seq: 1, Type: "request"}, + Command: "initialize", + }, + } + writeErr := dap.WriteProtocolMessage(buf, initReq) + require.NoError(t, writeErr) + + reader := bufio.NewReader(buf) + msg, readErr := ReadMessageWithFallback(reader) + require.NoError(t, readErr) + + decoded, ok := msg.(*dap.InitializeRequest) + require.True(t, ok, "expected *dap.InitializeRequest, got %T", msg) + assert.Equal(t, 1, decoded.Seq) + assert.Equal(t, "initialize", decoded.Command) +} + +func TestReadMessageWithFallbackUnknownRequest(t *testing.T) { + t.Parallel() + + // Create a DAP message with unknown command + customJSON := `{"seq":2,"type":"request","command":"handshake","arguments":{"value":"test-value"}}` + content := "Content-Length: " + itoa(len(customJSON)) + "\r\n\r\n" + customJSON + + reader := bufio.NewReader(bytes.NewBufferString(content)) + msg, readErr := ReadMessageWithFallback(reader) + 
require.NoError(t, readErr) + + raw, ok := msg.(*RawMessage) + require.True(t, ok, "expected *RawMessage, got %T", msg) + assert.Equal(t, 2, raw.GetSeq()) + assert.Contains(t, string(raw.Data), `"command":"handshake"`) +} + +func TestReadMessageWithFallbackUnknownEvent(t *testing.T) { + t.Parallel() + + customJSON := `{"seq":5,"type":"event","event":"customEvent","body":{"data":123}}` + content := "Content-Length: " + itoa(len(customJSON)) + "\r\n\r\n" + customJSON + + reader := bufio.NewReader(bytes.NewBufferString(content)) + msg, readErr := ReadMessageWithFallback(reader) + require.NoError(t, readErr) + + raw, ok := msg.(*RawMessage) + require.True(t, ok, "expected *RawMessage, got %T", msg) + assert.Equal(t, 5, raw.GetSeq()) + assert.Contains(t, string(raw.Data), `"event":"customEvent"`) +} + +func TestReadMessageWithFallbackMalformedJSON(t *testing.T) { + t.Parallel() + + badJSON := `{"seq":1,"type":` + content := "Content-Length: " + itoa(len(badJSON)) + "\r\n\r\n" + badJSON + + reader := bufio.NewReader(bytes.NewBufferString(content)) + _, readErr := ReadMessageWithFallback(reader) + require.Error(t, readErr) +} + +func TestWriteMessageWithFallbackKnownMessage(t *testing.T) { + t.Parallel() + + buf := new(bytes.Buffer) + initReq := &dap.InitializeRequest{ + Request: dap.Request{ + ProtocolMessage: dap.ProtocolMessage{Seq: 1, Type: "request"}, + Command: "initialize", + }, + } + writeErr := WriteMessageWithFallback(buf, initReq) + require.NoError(t, writeErr) + + // Read it back + reader := bufio.NewReader(buf) + msg, readErr := dap.ReadProtocolMessage(reader) + require.NoError(t, readErr) + + decoded, ok := msg.(*dap.InitializeRequest) + require.True(t, ok) + assert.Equal(t, 1, decoded.Seq) +} + +func TestWriteMessageWithFallbackRawMessage(t *testing.T) { + t.Parallel() + + customJSON := `{"seq":2,"type":"request","command":"handshake","arguments":{"value":"test-value"}}` + raw := &RawMessage{Data: []byte(customJSON)} + + buf := new(bytes.Buffer) + writeErr 
:= WriteMessageWithFallback(buf, raw) + require.NoError(t, writeErr) + + // Expect Content-Length header followed by the raw JSON + result := buf.String() + assert.Contains(t, result, "Content-Length:") + assert.Contains(t, result, customJSON) +} + +func TestWriteMessageWithFallbackRoundtrip(t *testing.T) { + t.Parallel() + + originalJSON := `{"seq":3,"type":"request","command":"vsdbgHandshake","arguments":{"protocolVersion":1}}` + raw := &RawMessage{Data: []byte(originalJSON)} + + buf := new(bytes.Buffer) + writeErr := WriteMessageWithFallback(buf, raw) + require.NoError(t, writeErr) + + // Read it back using ReadMessageWithFallback + reader := bufio.NewReader(buf) + msg, readErr := ReadMessageWithFallback(reader) + require.NoError(t, readErr) + + readRaw, ok := msg.(*RawMessage) + require.True(t, ok, "expected *RawMessage, got %T", msg) + assert.Equal(t, originalJSON, string(readRaw.Data)) +} + +// itoa is a simple helper to convert int to string without importing strconv +func itoa(n int) string { + if n == 0 { + return "0" + } + var digits []byte + for n > 0 { + digits = append([]byte{byte('0' + n%10)}, digits...) 
+ n /= 10 + } + return string(digits) +} + +func TestMessageEnvelope_TypedRequest(t *testing.T) { + t.Parallel() + + msg := &dap.InitializeRequest{ + Request: dap.Request{ + ProtocolMessage: dap.ProtocolMessage{Seq: 1, Type: "request"}, + Command: "initialize", + }, + } + + env := NewMessageEnvelope(msg) + assert.Equal(t, 1, env.Seq) + assert.Equal(t, "request", env.Type) + assert.Equal(t, "initialize", env.Command) + assert.False(t, env.IsResponse()) + + // Modify seq + env.Seq = 100 + finalized, finalizeErr := env.Finalize() + require.NoError(t, finalizeErr) + assert.Equal(t, 100, finalized.GetSeq()) + assert.Equal(t, msg, finalized) // same pointer +} + +func TestMessageEnvelope_TypedResponse(t *testing.T) { + t.Parallel() + + msg := &dap.InitializeResponse{ + Response: dap.Response{ + ProtocolMessage: dap.ProtocolMessage{Seq: 2, Type: "response"}, + Command: "initialize", + RequestSeq: 1, + Success: true, + }, + } + + env := NewMessageEnvelope(msg) + assert.Equal(t, 2, env.Seq) + assert.Equal(t, "response", env.Type) + assert.Equal(t, 1, env.RequestSeq) + assert.True(t, env.IsResponse()) + require.NotNil(t, env.Success) + assert.True(t, *env.Success) + + // Modify both seq and request_seq + env.Seq = 200 + env.RequestSeq = 50 + finalized, finalizeErr := env.Finalize() + require.NoError(t, finalizeErr) + assert.Equal(t, 200, finalized.GetSeq()) + resp := finalized.(*dap.InitializeResponse) + assert.Equal(t, 50, resp.Response.RequestSeq) +} + +func TestMessageEnvelope_TypedEvent(t *testing.T) { + t.Parallel() + + msg := &dap.OutputEvent{ + Event: dap.Event{ + ProtocolMessage: dap.ProtocolMessage{Seq: 3, Type: "event"}, + Event: "output", + }, + } + + env := NewMessageEnvelope(msg) + assert.Equal(t, 3, env.Seq) + assert.Equal(t, "event", env.Type) + assert.Equal(t, "output", env.Event) + assert.False(t, env.IsResponse()) + + // Modify seq + env.Seq = 300 + finalized, finalizeErr := env.Finalize() + require.NoError(t, finalizeErr) + assert.Equal(t, 300, 
finalized.GetSeq()) +} + +func TestMessageEnvelope_RawRequest(t *testing.T) { + t.Parallel() + + raw := &RawMessage{Data: []byte(`{"seq":5,"type":"request","command":"handshake","arguments":{"v":1}}`)} + env := NewMessageEnvelope(raw) + + assert.Equal(t, 5, env.Seq) + assert.Equal(t, "request", env.Type) + assert.Equal(t, "handshake", env.Command) + assert.False(t, env.IsResponse()) + + // Modify seq + env.Seq = 500 + finalized, finalizeErr := env.Finalize() + require.NoError(t, finalizeErr) + + // Finalize returns the same RawMessage with patched JSON + patchedRaw, ok := finalized.(*RawMessage) + require.True(t, ok) + assert.Equal(t, 500, patchedRaw.GetSeq()) + assert.Contains(t, string(patchedRaw.Data), `"command":"handshake"`) + assert.Contains(t, string(patchedRaw.Data), `"arguments"`) +} + +func TestMessageEnvelope_RawResponse(t *testing.T) { + t.Parallel() + + raw := &RawMessage{Data: []byte(`{"seq":6,"type":"response","command":"handshake","request_seq":5,"success":true,"body":{"v":1}}`)} + env := NewMessageEnvelope(raw) + + assert.Equal(t, 6, env.Seq) + assert.Equal(t, "response", env.Type) + assert.Equal(t, 5, env.RequestSeq) + assert.True(t, env.IsResponse()) + require.NotNil(t, env.Success) + assert.True(t, *env.Success) + + // Modify both seq and request_seq — should produce a single patch pass + env.Seq = 100 + env.RequestSeq = 42 + finalized, finalizeErr := env.Finalize() + require.NoError(t, finalizeErr) + + patchedRaw, ok := finalized.(*RawMessage) + require.True(t, ok) + assert.Equal(t, 100, patchedRaw.GetSeq()) + h := patchedRaw.parseHeader() + assert.Equal(t, 42, h.RequestSeq) + assert.Equal(t, "handshake", h.Command) + assert.Contains(t, string(patchedRaw.Data), `"body"`) +} + +func TestMessageEnvelope_NoChanges(t *testing.T) { + t.Parallel() + + originalJSON := `{"seq":3,"type":"event","event":"custom","body":{"data":123}}` + raw := &RawMessage{Data: []byte(originalJSON)} + env := NewMessageEnvelope(raw) + + // Don't modify anything + 
finalized, finalizeErr := env.Finalize() + require.NoError(t, finalizeErr) + + patchedRaw, ok := finalized.(*RawMessage) + require.True(t, ok) + // Data should be untouched since nothing changed + assert.Equal(t, originalJSON, string(patchedRaw.Data)) +} + +func TestMessageEnvelopeDescribeTypedRequest(t *testing.T) { + t.Parallel() + + msg := &dap.InitializeRequest{ + Request: dap.Request{ + ProtocolMessage: dap.ProtocolMessage{Seq: 1, Type: "request"}, + Command: "initialize", + }, + } + env := NewMessageEnvelope(msg) + assert.Equal(t, "request 'initialize' (seq=1)", env.Describe()) +} + +func TestMessageEnvelopeDescribeTypedResponseSuccess(t *testing.T) { + t.Parallel() + + msg := &dap.InitializeResponse{ + Response: dap.Response{ + ProtocolMessage: dap.ProtocolMessage{Seq: 2, Type: "response"}, + Command: "initialize", + RequestSeq: 1, + Success: true, + }, + } + env := NewMessageEnvelope(msg) + assert.Equal(t, "response 'initialize' (seq=2, request_seq=1, success=true)", env.Describe()) +} + +func TestMessageEnvelopeDescribeRawRequest(t *testing.T) { + t.Parallel() + + msg := &RawMessage{Data: []byte(`{"seq":5,"type":"request","command":"vsdbgHandshake"}`)} + env := NewMessageEnvelope(msg) + assert.Equal(t, "raw request 'vsdbgHandshake' (seq=5)", env.Describe()) +} + +func TestMessageEnvelopeDescribeRawResponseSuccess(t *testing.T) { + t.Parallel() + + msg := &RawMessage{Data: []byte(`{"seq":6,"type":"response","command":"vsdbgHandshake","request_seq":5,"success":true}`)} + env := NewMessageEnvelope(msg) + assert.Equal(t, "raw response 'vsdbgHandshake' (seq=6, request_seq=5, success=true)", env.Describe()) +} + +func TestMessageEnvelopeDescribeRawResponseFailure(t *testing.T) { + t.Parallel() + + msg := &RawMessage{Data: []byte(`{"seq":7,"type":"response","command":"vsdbgHandshake","request_seq":5,"success":false,"message":"denied"}`)} + env := NewMessageEnvelope(msg) + assert.Equal(t, "raw response 'vsdbgHandshake' (seq=7, request_seq=5, success=false, 
message=\"denied\")", env.Describe()) +} + +func TestMessageEnvelopeDescribeRawEvent(t *testing.T) { + t.Parallel() + + msg := &RawMessage{Data: []byte(`{"seq":8,"type":"event","event":"customNotify"}`)} + env := NewMessageEnvelope(msg) + assert.Equal(t, "raw event 'customNotify' (seq=8)", env.Describe()) +} + +func TestMessageEnvelopeDescribeRawUnknownType(t *testing.T) { + t.Parallel() + + msg := &RawMessage{Data: []byte(`{"seq":9,"type":"weird"}`)} + env := NewMessageEnvelope(msg) + assert.Equal(t, "raw weird (seq=9)", env.Describe()) +} + +func TestMessageEnvelopeDescribeReflectsModifiedSeq(t *testing.T) { + t.Parallel() + + msg := &RawMessage{Data: []byte(`{"seq":5,"type":"request","command":"handshake"}`)} + env := NewMessageEnvelope(msg) + env.Seq = 99 + assert.Equal(t, "raw request 'handshake' (seq=99)", env.Describe()) +} + +func TestPatchJSONFieldsSingleField(t *testing.T) { + t.Parallel() + + raw := &RawMessage{Data: []byte(`{"seq":1,"type":"request","command":"test"}`)} + require.NoError(t, raw.patchJSONFields(map[string]int{"seq": 42})) + assert.Equal(t, 42, raw.GetSeq()) + assert.Contains(t, string(raw.Data), `"command":"test"`) +} + +func TestPatchJSONFieldsMultipleFields(t *testing.T) { + t.Parallel() + + raw := &RawMessage{Data: []byte(`{"seq":1,"type":"response","command":"test","request_seq":5,"success":true}`)} + require.NoError(t, raw.patchJSONFields(map[string]int{"seq": 100, "request_seq": 42})) + h := raw.parseHeader() + assert.Equal(t, 100, h.Seq) + assert.Equal(t, 42, h.RequestSeq) + assert.Equal(t, "test", h.Command) + require.NotNil(t, h.Success) + assert.True(t, *h.Success) +} + +func TestPatchJSONFieldsEmptyFieldsIsNoOp(t *testing.T) { + t.Parallel() + + original := `{"seq":1,"type":"request"}` + raw := &RawMessage{Data: []byte(original)} + require.NoError(t, raw.patchJSONFields(map[string]int{})) + assert.Equal(t, original, string(raw.Data)) +} + +func TestPatchJSONFieldsPreservesBody(t *testing.T) { + t.Parallel() + + raw := 
&RawMessage{Data: []byte(`{"seq":1,"type":"response","command":"test","request_seq":3,"success":true,"body":{"value":"test"}}`)} + require.NoError(t, raw.patchJSONFields(map[string]int{"seq": 42})) + assert.Contains(t, string(raw.Data), `"body"`) + assert.Contains(t, string(raw.Data), `"value":"test"`) +} + +func TestPatchJSONFieldsInvalidatesHeaderCache(t *testing.T) { + t.Parallel() + + raw := &RawMessage{Data: []byte(`{"seq":1,"type":"request","command":"test"}`)} + // Populate cache + h1 := raw.parseHeader() + assert.Equal(t, 1, h1.Seq) + assert.NotNil(t, raw.header) + // Patch + require.NoError(t, raw.patchJSONFields(map[string]int{"seq": 99})) + // Cache should be invalidated + assert.Nil(t, raw.header) + // Re-parse should reflect new value + h2 := raw.parseHeader() + assert.Equal(t, 99, h2.Seq) +} diff --git a/internal/dap/testclient_test.go b/internal/dap/testclient_test.go new file mode 100644 index 00000000..4d1c963c --- /dev/null +++ b/internal/dap/testclient_test.go @@ -0,0 +1,418 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See LICENSE in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package dap + +import ( + "context" + "encoding/json" + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/google/go-dap" + "github.com/microsoft/dcp/pkg/syncmap" +) + +// TestClient is a DAP client for testing purposes. +// It provides helper methods for common DAP operations. 
+type TestClient struct { + transport Transport + seq atomic.Int64 + + // eventChan receives events from the server + eventChan chan dap.Message + + // responseChans tracks pending requests waiting for responses + responseChans syncmap.Map[int, chan dap.Message] + + // ctx controls the client lifecycle + ctx context.Context + cancel context.CancelFunc + + // wg tracks reader goroutine + wg sync.WaitGroup +} + +// NewTestClient creates a new DAP test client with the given transport. +// The client's lifecycle is bound to the provided context. +func NewTestClient(ctx context.Context, transport Transport) *TestClient { + ctx, cancel := context.WithCancel(ctx) + c := &TestClient{ + transport: transport, + eventChan: make(chan dap.Message, 100), + ctx: ctx, + cancel: cancel, + } + + c.wg.Add(1) + go c.readLoop() + + return c +} + +// readLoop continuously reads messages from the transport and routes them. +func (c *TestClient) readLoop() { + defer c.wg.Done() + + for { + select { + case <-c.ctx.Done(): + return + default: + } + + msg, readErr := c.transport.ReadMessage() + if readErr != nil { + if c.ctx.Err() != nil { + return + } + // Log error and continue or return based on error type + return + } + + // Route based on message type + switch m := msg.(type) { + case dap.ResponseMessage: + resp := m.GetResponse() + if ch, ok := c.responseChans.LoadAndDelete(resp.RequestSeq); ok { + ch <- msg + } + + case dap.EventMessage: + select { + case c.eventChan <- msg: + default: + // Event channel full, drop oldest + select { + case <-c.eventChan: + default: + } + c.eventChan <- msg + } + } + } +} + +// nextSeq returns the next sequence number. +func (c *TestClient) nextSeq() int { + return int(c.seq.Add(1)) +} + +// responseError extracts a detailed error message from a DAP response. +// If the response is an ErrorResponse, it extracts the message and body error details. +// Otherwise, it returns a generic "unexpected response type" error. 
+func responseError(resp dap.Message, expectedType string) error { + if errResp, ok := resp.(*dap.ErrorResponse); ok { + if errResp.Body.Error != nil { + return fmt.Errorf("%s failed: %s (error %d: %s)", + expectedType, errResp.Message, errResp.Body.Error.Id, errResp.Body.Error.Format) + } + return fmt.Errorf("%s failed: %s", expectedType, errResp.Message) + } + return fmt.Errorf("unexpected response type for %s: %T", expectedType, resp) +} + +// sendRequest sends a request and waits for the response. +func (c *TestClient) sendRequest(ctx context.Context, req dap.RequestMessage) (dap.Message, error) { + request := req.GetRequest() + seq := c.nextSeq() + request.Seq = seq + + // Create response channel + respChan := make(chan dap.Message, 1) + c.responseChans.Store(seq, respChan) + + // Send request + if writeErr := c.transport.WriteMessage(req); writeErr != nil { + c.responseChans.Delete(seq) + return nil, fmt.Errorf("failed to send request: %w", writeErr) + } + + // Wait for response + select { + case resp := <-respChan: + return resp, nil + case <-ctx.Done(): + c.responseChans.Delete(seq) + return nil, ctx.Err() + } +} + +// Initialize sends an initialize request and returns the capabilities. 
+func (c *TestClient) Initialize(ctx context.Context) (*dap.InitializeResponse, error) { + req := &dap.InitializeRequest{ + Request: dap.Request{ + ProtocolMessage: dap.ProtocolMessage{Type: "request"}, + Command: "initialize", + }, + Arguments: dap.InitializeRequestArguments{ + ClientID: "test-client", + ClientName: "DAP Test Client", + AdapterID: "go", + Locale: "en-US", + LinesStartAt1: true, + ColumnsStartAt1: true, + PathFormat: "path", + SupportsRunInTerminalRequest: true, + }, + } + + resp, sendErr := c.sendRequest(ctx, req) + if sendErr != nil { + return nil, sendErr + } + + initResp, ok := resp.(*dap.InitializeResponse) + if !ok { + return nil, responseError(resp, "initialize") + } + + if !initResp.Success { + return nil, fmt.Errorf("initialize failed: %s", initResp.Message) + } + + return initResp, nil +} + +// Launch sends a launch request to debug the given program. +func (c *TestClient) Launch(ctx context.Context, program string, stopOnEntry bool) error { + args := map[string]interface{}{ + "mode": "exec", + "program": program, + "stopOnEntry": stopOnEntry, + } + argsJSON, marshalErr := json.Marshal(args) + if marshalErr != nil { + return fmt.Errorf("failed to marshal launch arguments: %w", marshalErr) + } + + req := &dap.LaunchRequest{ + Request: dap.Request{ + ProtocolMessage: dap.ProtocolMessage{Type: "request"}, + Command: "launch", + }, + Arguments: argsJSON, + } + + resp, sendErr := c.sendRequest(ctx, req) + if sendErr != nil { + return sendErr + } + + launchResp, ok := resp.(*dap.LaunchResponse) + if !ok { + return responseError(resp, "launch") + } + + if !launchResp.Success { + return fmt.Errorf("launch failed: %s", launchResp.Message) + } + + return nil +} + +// SetBreakpoints sets breakpoints in the given file at the specified lines. 
+func (c *TestClient) SetBreakpoints(ctx context.Context, file string, lines []int) (*dap.SetBreakpointsResponse, error) { + breakpoints := make([]dap.SourceBreakpoint, len(lines)) + for i, line := range lines { + breakpoints[i] = dap.SourceBreakpoint{Line: line} + } + + req := &dap.SetBreakpointsRequest{ + Request: dap.Request{ + ProtocolMessage: dap.ProtocolMessage{Type: "request"}, + Command: "setBreakpoints", + }, + Arguments: dap.SetBreakpointsArguments{ + Source: dap.Source{ + Path: file, + }, + Breakpoints: breakpoints, + }, + } + + resp, sendErr := c.sendRequest(ctx, req) + if sendErr != nil { + return nil, sendErr + } + + bpResp, ok := resp.(*dap.SetBreakpointsResponse) + if !ok { + return nil, responseError(resp, "setBreakpoints") + } + + if !bpResp.Success { + return nil, fmt.Errorf("setBreakpoints failed: %s", bpResp.Message) + } + + return bpResp, nil +} + +// ConfigurationDone signals that configuration is complete. +func (c *TestClient) ConfigurationDone(ctx context.Context) error { + req := &dap.ConfigurationDoneRequest{ + Request: dap.Request{ + ProtocolMessage: dap.ProtocolMessage{Type: "request"}, + Command: "configurationDone", + }, + } + + resp, sendErr := c.sendRequest(ctx, req) + if sendErr != nil { + return sendErr + } + + configResp, ok := resp.(*dap.ConfigurationDoneResponse) + if !ok { + return responseError(resp, "configurationDone") + } + + if !configResp.Success { + return fmt.Errorf("configurationDone failed: %s", configResp.Message) + } + + return nil +} + +// Continue resumes execution of all threads. 
+func (c *TestClient) Continue(ctx context.Context, threadID int) error { + req := &dap.ContinueRequest{ + Request: dap.Request{ + ProtocolMessage: dap.ProtocolMessage{Type: "request"}, + Command: "continue", + }, + Arguments: dap.ContinueArguments{ + ThreadId: threadID, + }, + } + + resp, sendErr := c.sendRequest(ctx, req) + if sendErr != nil { + return sendErr + } + + contResp, ok := resp.(*dap.ContinueResponse) + if !ok { + return responseError(resp, "continue") + } + + if !contResp.Success { + return fmt.Errorf("continue failed: %s", contResp.Message) + } + + return nil +} + +// Disconnect sends a disconnect request to terminate the debug session. +func (c *TestClient) Disconnect(ctx context.Context, terminateDebuggee bool) error { + req := &dap.DisconnectRequest{ + Request: dap.Request{ + ProtocolMessage: dap.ProtocolMessage{Type: "request"}, + Command: "disconnect", + }, + Arguments: &dap.DisconnectArguments{ + TerminateDebuggee: terminateDebuggee, + }, + } + + resp, sendErr := c.sendRequest(ctx, req) + if sendErr != nil { + return sendErr + } + + disconnResp, ok := resp.(*dap.DisconnectResponse) + if !ok { + return responseError(resp, "disconnect") + } + + if !disconnResp.Success { + return fmt.Errorf("disconnect failed: %s", disconnResp.Message) + } + + return nil +} + +// WaitForEvent waits for an event of the specified type. +// Returns the event or an error if timeout expires. +func (c *TestClient) WaitForEvent(eventType string, timeout time.Duration) (dap.Message, error) { + deadline := time.After(timeout) + + for { + select { + case msg := <-c.eventChan: + if event, ok := msg.(dap.EventMessage); ok { + if event.GetEvent().Event == eventType { + return msg, nil + } + } + // Not the event we're looking for, continue waiting + + case <-deadline: + return nil, fmt.Errorf("timeout waiting for event %q", eventType) + + case <-c.ctx.Done(): + return nil, c.ctx.Err() + } + } +} + +// WaitForStoppedEvent waits for a stopped event and returns the thread ID. 
+func (c *TestClient) WaitForStoppedEvent(timeout time.Duration) (*dap.StoppedEvent, error) { + msg, waitErr := c.WaitForEvent("stopped", timeout) + if waitErr != nil { + return nil, waitErr + } + + stoppedEvent, ok := msg.(*dap.StoppedEvent) + if !ok { + return nil, fmt.Errorf("unexpected event type: %T", msg) + } + + return stoppedEvent, nil +} + +// WaitForTerminatedEvent waits for a terminated event. +func (c *TestClient) WaitForTerminatedEvent(timeout time.Duration) error { + _, waitErr := c.WaitForEvent("terminated", timeout) + return waitErr +} + +// CollectEventsUntil collects all events until a specific event type is received. +// Returns the collected events in order, with the target event last. +// This is useful for verifying event ordering. +func (c *TestClient) CollectEventsUntil(targetEventType string, timeout time.Duration) ([]dap.Message, error) { + deadline := time.After(timeout) + var events []dap.Message + + for { + select { + case msg := <-c.eventChan: + events = append(events, msg) + if event, ok := msg.(dap.EventMessage); ok { + if event.GetEvent().Event == targetEventType { + return events, nil + } + } + + case <-deadline: + return events, fmt.Errorf("timeout waiting for event %q (collected %d events)", targetEventType, len(events)) + + case <-c.ctx.Done(): + return events, c.ctx.Err() + } + } +} + +// Close closes the client and its transport. +func (c *TestClient) Close() error { + c.cancel() + // Close the transport first to unblock any pending reads + closeErr := c.transport.Close() + // Then wait for goroutines to finish + c.wg.Wait() + return closeErr +} diff --git a/internal/dap/transport.go b/internal/dap/transport.go new file mode 100644 index 00000000..5a0a2546 --- /dev/null +++ b/internal/dap/transport.go @@ -0,0 +1,159 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. 
See LICENSE in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package dap + +import ( + "bufio" + "context" + "errors" + "fmt" + "io" + "net" + "sync" + "sync/atomic" + + "github.com/google/go-dap" + dcpio "github.com/microsoft/dcp/pkg/io" +) + +// ErrTransportClosed is returned when a read or write is attempted on a transport +// that has been intentionally closed via Close(). This distinguishes expected +// shutdown errors from unexpected connection failures. +var ErrTransportClosed = errors.New("transport closed") + +// Transport provides an abstraction for DAP message I/O over different connection types. +// Implementations must be safe for concurrent use by multiple goroutines for reading +// and writing, but individual reads and writes may not be concurrent with each other. +type Transport interface { + // ReadMessage reads the next DAP protocol message from the transport. + // Returns the message or an error if reading fails. + // This method blocks until a complete message is available. + ReadMessage() (dap.Message, error) + + // WriteMessage writes a DAP protocol message to the transport. + // Returns an error if writing fails. + WriteMessage(msg dap.Message) error + + // Close closes the transport, releasing any associated resources. + // After Close is called, any blocked ReadMessage or WriteMessage calls + // should return with an error. + Close() error +} + +// connTransport implements Transport over any connection that provides +// an io.Reader for incoming data and an io.Writer for outgoing data. +// It is used for TCP, Unix domain socket, and stdio-based transports. +type connTransport struct { + reader *bufio.Reader + writer *bufio.Writer + closer io.Closer + + // closed tracks whether Close() has been called. 
This is used to wrap + // subsequent read/write errors with ErrTransportClosed so callers can + // distinguish intentional shutdown from unexpected failures. + closed *atomic.Bool + + // writeMu serializes message writes. Each DAP message is sent as a + // content-length header followed by the message body in separate writes, + // then flushed. The mutex ensures this multi-write sequence is atomic + // so concurrent WriteMessage calls cannot interleave their bytes. + writeMu sync.Mutex +} + +// NewTCPTransportWithContext creates a new Transport backed by a TCP connection +// that respects context cancellation. When the context is cancelled, any blocked +// reads will be unblocked by closing the connection. +func NewTCPTransportWithContext(ctx context.Context, conn net.Conn) Transport { + return newConnTransport(ctx, conn, conn, conn) +} + +// NewStdioTransportWithContext creates a new Transport backed by stdin and stdout streams +// that respects context cancellation. When the context is cancelled, any blocked +// reads will be unblocked by closing the stdin stream. +func NewStdioTransportWithContext(ctx context.Context, stdin io.ReadCloser, stdout io.WriteCloser) Transport { + return newConnTransport(ctx, stdin, stdout, multiCloser{stdin, stdout}) +} + +// NewUnixTransportWithContext creates a new Transport backed by a Unix domain socket connection +// that respects context cancellation. When the context is cancelled, any blocked +// reads will be unblocked by closing the connection. +func NewUnixTransportWithContext(ctx context.Context, conn net.Conn) Transport { + return newConnTransport(ctx, conn, conn, conn) +} + +// newConnTransport creates a connTransport from separate read, write, and close resources. +// A ContextReader wraps the reader so that context cancellation unblocks pending reads. 
+func newConnTransport(ctx context.Context, r io.Reader, w io.Writer, closer io.Closer) Transport { + contextReader := dcpio.NewContextReader(ctx, r, true) + return &connTransport{ + reader: bufio.NewReader(contextReader), + writer: bufio.NewWriter(w), + closer: closer, + closed: &atomic.Bool{}, + } +} + +func (t *connTransport) ReadMessage() (dap.Message, error) { + msg, readErr := ReadMessageWithFallback(t.reader) + if readErr != nil { + if t.closed.Load() { + return nil, fmt.Errorf("%w: %w", ErrTransportClosed, readErr) + } + return nil, fmt.Errorf("failed to read DAP message: %w", readErr) + } + + return msg, nil +} + +func (t *connTransport) WriteMessage(msg dap.Message) error { + t.writeMu.Lock() + defer t.writeMu.Unlock() + + writeErr := WriteMessageWithFallback(t.writer, msg) + if writeErr != nil { + if t.closed.Load() { + return fmt.Errorf("%w: %w", ErrTransportClosed, writeErr) + } + return fmt.Errorf("failed to write DAP message: %w", writeErr) + } + + flushErr := t.writer.Flush() + if flushErr != nil { + if t.closed.Load() { + return fmt.Errorf("%w: %w", ErrTransportClosed, flushErr) + } + return fmt.Errorf("failed to flush DAP message: %w", flushErr) + } + + return nil +} + +func (t *connTransport) Close() error { + t.closed.Store(true) + return t.closer.Close() +} + +// isExpectedShutdownErr returns true if the error is expected during normal +// bridge shutdown — for example, when transports are intentionally closed, +// the context is cancelled, or the remote end disconnects cleanly. +func isExpectedShutdownErr(err error) bool { + return errors.Is(err, ErrTransportClosed) || + errors.Is(err, context.Canceled) || + isExpectedCloseErr(err) +} + +// multiCloser closes multiple io.Closers, joining all errors. +type multiCloser []io.Closer + +func (mc multiCloser) Close() error { + var errs []error + for _, c := range mc { + if closeErr := c.Close(); closeErr != nil { + errs = append(errs, closeErr) + } + } + return errors.Join(errs...) 
+} diff --git a/internal/dap/transport_test.go b/internal/dap/transport_test.go new file mode 100644 index 00000000..da753371 --- /dev/null +++ b/internal/dap/transport_test.go @@ -0,0 +1,436 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See LICENSE in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package dap + +import ( + "bytes" + "context" + "fmt" + "io" + "net" + "os" + "path/filepath" + "sync" + "testing" + "time" + + "github.com/google/go-dap" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/microsoft/dcp/pkg/testutil" +) + +// uniqueSocketPath generates a unique, short socket path for testing. +// macOS has a ~104 character limit for Unix socket paths, so we use +// the system temp directory with a short filename. +func uniqueSocketPath(t *testing.T, suffix string) string { + t.Helper() + socketPath := filepath.Join(os.TempDir(), fmt.Sprintf("dap-%s-%d.sock", suffix, time.Now().UnixNano())) + t.Cleanup(func() { os.Remove(socketPath) }) + return socketPath +} + +// setupTCPPair creates a connected TCP socket pair for testing. 
+func setupTCPPair(t *testing.T) (clientConn, serverConn net.Conn) { + t.Helper() + + listener, listenErr := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, listenErr) + defer listener.Close() + + var wg sync.WaitGroup + var acceptErr error + wg.Add(1) + go func() { + defer wg.Done() + serverConn, acceptErr = listener.Accept() + }() + + clientConn, dialErr := net.Dial("tcp", listener.Addr().String()) + require.NoError(t, dialErr) + + wg.Wait() + require.NoError(t, acceptErr) + require.NotNil(t, serverConn) + + t.Cleanup(func() { + clientConn.Close() + serverConn.Close() + }) + + return clientConn, serverConn +} + +func TestTCPTransportWriteAndReadMessage(t *testing.T) { + t.Parallel() + + clientConn, serverConn := setupTCPPair(t) + + ctx, cancel := testutil.GetTestContext(t, 5*time.Second) + defer cancel() + + clientTransport := NewTCPTransportWithContext(ctx, clientConn) + serverTransport := NewTCPTransportWithContext(ctx, serverConn) + + request := &dap.InitializeRequest{ + Request: dap.Request{ + ProtocolMessage: dap.ProtocolMessage{Seq: 1, Type: "request"}, + Command: "initialize", + }, + } + + writeErr := clientTransport.WriteMessage(request) + require.NoError(t, writeErr) + + received, readErr := serverTransport.ReadMessage() + require.NoError(t, readErr) + + initReq, ok := received.(*dap.InitializeRequest) + require.True(t, ok) + assert.Equal(t, 1, initReq.Seq) + assert.Equal(t, "initialize", initReq.Command) +} + +func TestTCPTransportClosePreventsFurtherOperations(t *testing.T) { + t.Parallel() + + clientConn, _ := setupTCPPair(t) + + ctx, cancel := testutil.GetTestContext(t, 5*time.Second) + defer cancel() + + clientTransport := NewTCPTransportWithContext(ctx, clientConn) + + closeErr := clientTransport.Close() + assert.NoError(t, closeErr) + + writeErr := clientTransport.WriteMessage(&dap.InitializeRequest{}) + assert.Error(t, writeErr) + + // Double close should not panic + _ = clientTransport.Close() +} + +// mockReadWriteCloser implements 
io.ReadWriteCloser for testing +type mockReadWriteCloser struct { + reader *bytes.Buffer + writer *bytes.Buffer + closed bool + closeErr error + mu sync.Mutex +} + +func newMockReadWriteCloser() *mockReadWriteCloser { + return &mockReadWriteCloser{ + reader: bytes.NewBuffer(nil), + writer: bytes.NewBuffer(nil), + } +} + +func (m *mockReadWriteCloser) Read(p []byte) (n int, err error) { + m.mu.Lock() + if m.closed { + m.mu.Unlock() + return 0, io.EOF + } + m.mu.Unlock() + return m.reader.Read(p) +} + +func (m *mockReadWriteCloser) Write(p []byte) (n int, err error) { + m.mu.Lock() + defer m.mu.Unlock() + if m.closed { + return 0, io.ErrClosedPipe + } + return m.writer.Write(p) +} + +func (m *mockReadWriteCloser) Close() error { + m.mu.Lock() + defer m.mu.Unlock() + m.closed = true + return m.closeErr +} + +func TestStdioTransportWriteAndReadMessage(t *testing.T) { + t.Parallel() + + // Create connected pipes + serverRead, clientWrite := io.Pipe() + clientRead, serverWrite := io.Pipe() + + ctx, cancel := testutil.GetTestContext(t, 5*time.Second) + defer cancel() + + clientTransport := NewStdioTransportWithContext(ctx, clientRead, clientWrite) + serverTransport := NewStdioTransportWithContext(ctx, serverRead, serverWrite) + + defer clientTransport.Close() + defer serverTransport.Close() + + // Send message from client to server + request := &dap.InitializeRequest{ + Request: dap.Request{ + ProtocolMessage: dap.ProtocolMessage{Seq: 1, Type: "request"}, + Command: "initialize", + }, + } + + var wg sync.WaitGroup + wg.Add(1) + + var received dap.Message + var readErr error + + go func() { + defer wg.Done() + received, readErr = serverTransport.ReadMessage() + }() + + writeErr := clientTransport.WriteMessage(request) + require.NoError(t, writeErr) + + wg.Wait() + + require.NoError(t, readErr) + initReq, ok := received.(*dap.InitializeRequest) + require.True(t, ok) + assert.Equal(t, 1, initReq.Seq) +} + +func TestStdioTransportClosePreventsFurtherOperations(t *testing.T) { 
+ t.Parallel() + + stdin := newMockReadWriteCloser() + stdout := newMockReadWriteCloser() + + ctx, cancel := testutil.GetTestContext(t, 5*time.Second) + defer cancel() + + transport := NewStdioTransportWithContext(ctx, stdin, stdout) + + closeErr := transport.Close() + assert.NoError(t, closeErr) + + writeErr := transport.WriteMessage(&dap.InitializeRequest{}) + assert.Error(t, writeErr) + + // Double close should be safe + closeErr = transport.Close() + assert.NoError(t, closeErr) +} + +// setupUnixPair creates a connected Unix socket pair for testing. +func setupUnixPair(t *testing.T, suffix string) (clientConn, serverConn net.Conn) { + t.Helper() + + socketPath := uniqueSocketPath(t, suffix) + + listener, listenErr := net.Listen("unix", socketPath) + require.NoError(t, listenErr) + defer listener.Close() + + var wg sync.WaitGroup + var acceptErr error + wg.Add(1) + go func() { + defer wg.Done() + serverConn, acceptErr = listener.Accept() + }() + + clientConn, dialErr := net.Dial("unix", socketPath) + require.NoError(t, dialErr) + + wg.Wait() + require.NoError(t, acceptErr) + require.NotNil(t, serverConn) + + t.Cleanup(func() { + clientConn.Close() + serverConn.Close() + }) + + return clientConn, serverConn +} + +func TestUnixTransportWriteAndReadMessage(t *testing.T) { + t.Parallel() + + clientConn, serverConn := setupUnixPair(t, "ut-wr") + + ctx, cancel := testutil.GetTestContext(t, 5*time.Second) + defer cancel() + + clientTransport := NewUnixTransportWithContext(ctx, clientConn) + serverTransport := NewUnixTransportWithContext(ctx, serverConn) + + request := &dap.InitializeRequest{ + Request: dap.Request{ + ProtocolMessage: dap.ProtocolMessage{Seq: 1, Type: "request"}, + Command: "initialize", + }, + } + + writeErr := clientTransport.WriteMessage(request) + require.NoError(t, writeErr) + + received, readErr := serverTransport.ReadMessage() + require.NoError(t, readErr) + + initReq, ok := received.(*dap.InitializeRequest) + require.True(t, ok) + 
assert.Equal(t, 1, initReq.Seq) + assert.Equal(t, "initialize", initReq.Command) +} + +func TestUnixTransportClosePreventsFurtherOperations(t *testing.T) { + t.Parallel() + + clientConn, _ := setupUnixPair(t, "ut-cl") + + ctx, cancel := testutil.GetTestContext(t, 5*time.Second) + defer cancel() + + clientTransport := NewUnixTransportWithContext(ctx, clientConn) + + closeErr := clientTransport.Close() + assert.NoError(t, closeErr) + + writeErr := clientTransport.WriteMessage(&dap.InitializeRequest{}) + assert.Error(t, writeErr) + + // Double close should not panic + _ = clientTransport.Close() +} + +func TestUnixTransportWithContext(t *testing.T) { + t.Parallel() + + socketPath := uniqueSocketPath(t, "ctx") + + // Create listener + listener, listenErr := net.Listen("unix", socketPath) + require.NoError(t, listenErr) + defer listener.Close() + + // Accept connection in goroutine + var serverConn net.Conn + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + serverConn, _ = listener.Accept() + }() + + // Connect + clientConn, dialErr := net.Dial("unix", socketPath) + require.NoError(t, dialErr) + + wg.Wait() + require.NotNil(t, serverConn) + defer serverConn.Close() + + // Create transport with cancellable context + ctx, cancel := testutil.GetTestContext(t, 5*time.Second) + clientTransport := NewUnixTransportWithContext(ctx, clientConn) + + // Start a blocking read, signalling when the goroutine is about to block + readStarted := make(chan struct{}) + readDone := make(chan struct{}) + go func() { + defer close(readDone) + close(readStarted) + _, _ = clientTransport.ReadMessage() + }() + + // Wait for the read goroutine to be running before cancelling + <-readStarted + + // Cancel context should unblock the read + cancel() + + select { + case <-readDone: + // Success - read was unblocked + case <-time.After(2 * time.Second): + t.Fatal("read was not unblocked after context cancellation") + } +} + +func TestIsExpectedShutdownErr(t *testing.T) { + 
t.Parallel() + + tests := []struct { + name string + err error + expected bool + }{ + {"nil error", nil, false}, + {"arbitrary error", fmt.Errorf("something went wrong"), false}, + {"ErrTransportClosed", ErrTransportClosed, true}, + {"wrapped ErrTransportClosed", fmt.Errorf("failed to read: %w", ErrTransportClosed), true}, + {"context.Canceled", context.Canceled, true}, + {"wrapped context.Canceled", fmt.Errorf("read failed: %w", context.Canceled), true}, + {"io.EOF", io.EOF, true}, + {"wrapped io.EOF", fmt.Errorf("read: %w", io.EOF), true}, + {"net.ErrClosed", net.ErrClosed, true}, + {"wrapped net.ErrClosed", fmt.Errorf("read: %w", net.ErrClosed), true}, + {"io.ErrClosedPipe", io.ErrClosedPipe, true}, + {"wrapped io.ErrClosedPipe", fmt.Errorf("write: %w", io.ErrClosedPipe), true}, + {"double wrapped ErrTransportClosed", fmt.Errorf("outer: %w", fmt.Errorf("inner: %w", ErrTransportClosed)), true}, + } + + for _, tc := range tests { + result := isExpectedShutdownErr(tc.err) + assert.Equal(t, tc.expected, result, tc.name) + } +} + +func TestTransportClosedReturnsErrTransportClosed(t *testing.T) { + t.Parallel() + + // Create a pair of connected Unix sockets + socketPath := uniqueSocketPath(t, "closed") + + listener, listenErr := net.Listen("unix", socketPath) + require.NoError(t, listenErr) + defer listener.Close() + + var serverConn net.Conn + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + serverConn, _ = listener.Accept() + }() + + clientConn, dialErr := net.Dial("unix", socketPath) + require.NoError(t, dialErr) + wg.Wait() + require.NotNil(t, serverConn) + defer serverConn.Close() + + ctx, cancel := testutil.GetTestContext(t, 5*time.Second) + defer cancel() + + transport := NewUnixTransportWithContext(ctx, clientConn) + + // Close the transport, then attempt to read + closeErr := transport.Close() + require.NoError(t, closeErr) + + _, readErr := transport.ReadMessage() + require.Error(t, readErr) + assert.ErrorIs(t, readErr, ErrTransportClosed, 
"ReadMessage after Close should return ErrTransportClosed") + assert.True(t, isExpectedShutdownErr(readErr), "error from closed transport should be an expected shutdown error") + + writeErr := transport.WriteMessage(&dap.InitializeRequest{}) + require.Error(t, writeErr) + assert.ErrorIs(t, writeErr, ErrTransportClosed, "WriteMessage after Close should return ErrTransportClosed") + assert.True(t, isExpectedShutdownErr(writeErr), "error from closed transport should be an expected shutdown error") +} diff --git a/internal/dap/transport_unix.go b/internal/dap/transport_unix.go new file mode 100644 index 00000000..8e618bcd --- /dev/null +++ b/internal/dap/transport_unix.go @@ -0,0 +1,25 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See LICENSE in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +//go:build !windows + +package dap + +import ( + "errors" + "io" + "net" + "syscall" +) + +// isExpectedCloseErr returns true if the error is expected when a network +// connection or pipe is closed. This is used to suppress error-level logging +// for errors that occur as a normal consequence of shutting down transports. +func isExpectedCloseErr(err error) bool { + return errors.Is(err, net.ErrClosed) || + errors.Is(err, io.ErrClosedPipe) || + errors.Is(err, io.EOF) || + errors.Is(err, syscall.ECONNRESET) +} diff --git a/internal/dap/transport_windows.go b/internal/dap/transport_windows.go new file mode 100644 index 00000000..91ffdb8b --- /dev/null +++ b/internal/dap/transport_windows.go @@ -0,0 +1,25 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. 
See LICENSE in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +//go:build windows + +package dap + +import ( + "errors" + "io" + "net" + "syscall" +) + +// isExpectedCloseErr returns true if the error is expected when a network +// connection or pipe is closed. This is used to suppress error-level logging +// for errors that occur as a normal consequence of shutting down transports. +func isExpectedCloseErr(err error) bool { + return errors.Is(err, net.ErrClosed) || + errors.Is(err, io.ErrClosedPipe) || + errors.Is(err, io.EOF) || + errors.Is(err, syscall.WSAECONNRESET) +} diff --git a/internal/dcp/bootstrap/dcp_run.go b/internal/dcp/bootstrap/dcp_run.go index f0dd0955..03d38e2d 100644 --- a/internal/dcp/bootstrap/dcp_run.go +++ b/internal/dcp/bootstrap/dcp_run.go @@ -200,14 +200,7 @@ func DcpRun( func createNotificationSource(lifetimeCtx context.Context, log logr.Logger) (notifications.UnixSocketNotificationSource, error) { const noNotifications = "Notifications will not be sent to controller process" - socketPath, socketPathErr := notifications.PrepareNotificationSocketPath("", "dcp-notify-sock-") - if socketPathErr != nil { - retErr := fmt.Errorf("failed to prepare notification socket path: %w", socketPathErr) - log.Error(socketPathErr, noNotifications) - return nil, retErr - } - - ns, nsErr := notifications.NewNotificationSource(lifetimeCtx, socketPath, log) + ns, nsErr := notifications.NewNotificationSource(lifetimeCtx, "", "dcp-notify-sock-", log) if nsErr != nil { retErr := fmt.Errorf("failed to create notification source: %w", nsErr) log.Error(nsErr, noNotifications) diff --git a/internal/dcpproc/commands/container.go b/internal/dcpproc/commands/container.go index 866c52c2..529f4967 100644 --- a/internal/dcpproc/commands/container.go +++ b/internal/dcpproc/commands/container.go @@ -108,7 +108,7 @@ func monitorContainer(log logr.Logger) func(cmd *cobra.Command, 
args []string) e } defer pe.Dispose() - monitorCtx, monitorCtxCancel, monitorCtxErr := cmds.MonitorPid(cmd.Context(), monitorPid, monitorProcessStartTime, monitorInterval, log) + monitorCtx, monitorCtxCancel, monitorCtxErr := cmds.MonitorPid(cmd.Context(), process.NewHandle(monitorPid, monitorProcessStartTime), monitorInterval, log) defer monitorCtxCancel() if monitorCtxErr != nil { if errors.Is(monitorCtxErr, os.ErrProcessDone) { diff --git a/internal/dcpproc/commands/process.go b/internal/dcpproc/commands/process.go index 90b3c36f..a44d53d0 100644 --- a/internal/dcpproc/commands/process.go +++ b/internal/dcpproc/commands/process.go @@ -65,14 +65,14 @@ func monitorProcess(log logr.Logger) func(cmd *cobra.Command, args []string) err log = log.WithValues(logger.RESOURCE_LOG_STREAM_ID, resourceId) } - monitorCtx, monitorCtxCancel, monitorCtxErr := cmds.MonitorPid(cmd.Context(), monitorPid, monitorProcessStartTime, monitorInterval, log) + monitorCtx, monitorCtxCancel, monitorCtxErr := cmds.MonitorPid(cmd.Context(), process.NewHandle(monitorPid, monitorProcessStartTime), monitorInterval, log) defer monitorCtxCancel() if monitorCtxErr != nil { if errors.Is(monitorCtxErr, os.ErrProcessDone) { // If the monitor process is already terminated, stop the service immediately log.Info("Monitored process already exited, shutting down child process...") executor := process.NewOSExecutor(log) - stopErr := executor.StopProcess(childPid, childProcessStartTime) + stopErr := executor.StopProcess(process.NewHandle(childPid, childProcessStartTime)) if stopErr != nil { log.Error(stopErr, "Failed to stop child process") return stopErr @@ -85,7 +85,7 @@ func monitorProcess(log logr.Logger) func(cmd *cobra.Command, args []string) err } } - childProcessCtx, childProcessCtxCancel, childMonitorErr := cmds.MonitorPid(cmd.Context(), childPid, childProcessStartTime, monitorInterval, log) + childProcessCtx, childProcessCtxCancel, childMonitorErr := cmds.MonitorPid(cmd.Context(), 
process.NewHandle(childPid, childProcessStartTime), monitorInterval, log) defer childProcessCtxCancel() if childMonitorErr != nil { // Log as Info--we might leak the child process if regular cleanup fails, but this should be rare. @@ -105,7 +105,7 @@ func monitorProcess(log logr.Logger) func(cmd *cobra.Command, args []string) err if childProcessCtx.Err() == nil { log.Info("Monitored process exited, shutting down child process") executor := process.NewOSExecutor(log) - stopErr := executor.StopProcess(childPid, childProcessStartTime) + stopErr := executor.StopProcess(process.NewHandle(childPid, childProcessStartTime)) if stopErr != nil { log.Error(stopErr, "Failed to stop child service process") return stopErr diff --git a/internal/dcpproc/commands/stop_process_tree.go b/internal/dcpproc/commands/stop_process_tree.go index 7c9b8cb8..ab3cbcc9 100644 --- a/internal/dcpproc/commands/stop_process_tree.go +++ b/internal/dcpproc/commands/stop_process_tree.go @@ -48,7 +48,7 @@ func stopProcessTree(log logr.Logger) func(cmd *cobra.Command, args []string) er "ProcessStartTime", stopProcessStartTime, ) - _, procErr := process.FindWaitableProcess(stopPid, stopProcessStartTime) + _, procErr := process.FindWaitableProcess(process.NewHandle(stopPid, stopProcessStartTime)) if procErr != nil { log.Error(procErr, "Could not find the process to stop") return procErr @@ -61,7 +61,7 @@ func stopProcessTree(log logr.Logger) func(cmd *cobra.Command, args []string) er } pe := process.NewOSExecutor(log) - stopErr := pe.StopProcess(stopPid, stopProcessStartTime) + stopErr := pe.StopProcess(process.NewHandle(stopPid, stopProcessStartTime)) if stopErr != nil { log.Error(stopErr, "Failed to stop process tree") return stopErr diff --git a/internal/dcpproc/dcpproc_api.go b/internal/dcpproc/dcpproc_api.go index 25979598..44f41d33 100644 --- a/internal/dcpproc/dcpproc_api.go +++ b/internal/dcpproc/dcpproc_api.go @@ -35,22 +35,21 @@ const ( // so monitoring DCPCTRL is a safe bet. 
func RunProcessWatcher( pe process.Executor, - childPid process.Pid_t, - childStartTime time.Time, + child process.ProcessHandle, log logr.Logger, ) { if _, found := os.LookupEnv(DCP_DISABLE_MONITOR_PROCESS); found { return } - log = log.WithValues("ChildPID", childPid) + log = log.WithValues("ChildPID", child.Pid) cmdArgs := []string{ "monitor-process", - "--child", strconv.FormatInt(int64(childPid), 10), + "--child", strconv.FormatInt(int64(child.Pid), 10), } - if !childStartTime.IsZero() { - cmdArgs = append(cmdArgs, "--child-identity-time", childStartTime.Format(osutil.RFC3339MiliTimestampFormat)) + if !child.IdentityTime.IsZero() { + cmdArgs = append(cmdArgs, "--child-identity-time", child.IdentityTime.Format(osutil.RFC3339MiliTimestampFormat)) } cmdArgs = append(cmdArgs, getMonitorCmdArgs()...) @@ -91,18 +90,17 @@ func RunContainerWatcher( func StopProcessTree( ctx context.Context, pe process.Executor, - rootPid process.Pid_t, - rootProcessStartTime time.Time, + root process.ProcessHandle, log logr.Logger, ) error { - log = log.WithValues("RootPID", rootPid) + log = log.WithValues("RootPID", root.Pid) cmdArgs := []string{ "stop-process-tree", - "--pid", strconv.FormatInt(int64(rootPid), 10), + "--pid", strconv.FormatInt(int64(root.Pid), 10), } - if !rootProcessStartTime.IsZero() { - cmdArgs = append(cmdArgs, "--process-start-time", rootProcessStartTime.Format(osutil.RFC3339MiliTimestampFormat)) + if !root.IdentityTime.IsZero() { + cmdArgs = append(cmdArgs, "--process-start-time", root.IdentityTime.Format(osutil.RFC3339MiliTimestampFormat)) } dcpPath, dcpPathErr := dcppaths.GetDcpExePath() @@ -114,14 +112,14 @@ func StopProcessTree( stopProcessTreeCmd.Env = os.Environ() // Use DCP CLI environment logger.WithSessionId(stopProcessTreeCmd) // Ensure the session ID is passed to the monitor command - exitCode, err := process.RunWithTimeout(ctx, pe, stopProcessTreeCmd) - if err != nil { - log.Error(err, "Failed to stop process tree", "ExitCode", exitCode) - return 
err + exitCode, runErr := process.RunWithTimeout(ctx, pe, stopProcessTreeCmd) + if runErr != nil { + log.Error(runErr, "Failed to stop process tree", "ExitCode", exitCode) + return runErr } else if exitCode != 0 { - err = fmt.Errorf("'dcp stop-process-tree --pid %d' command returned non-zero exit code: %d", rootPid, exitCode) - log.Error(err, "Failed to stop process tree", "ExitCode", exitCode) - return err + runErr = fmt.Errorf("'dcp stop-process-tree --pid %d' command returned non-zero exit code: %d", root.Pid, exitCode) + log.Error(runErr, "Failed to stop process tree", "ExitCode", exitCode) + return runErr } return nil @@ -151,7 +149,7 @@ func startDcpProc(pe process.Executor, cmdArgs []string) error { dcpProcCmd := exec.Command(dcpPath, cmdArgs...) dcpProcCmd.Env = os.Environ() // Use DCP CLI environment logger.WithSessionId(dcpProcCmd) // Ensure the session ID is passed to the monitor command - _, _, monitorErr := pe.StartAndForget(dcpProcCmd, process.CreationFlagsNone) + _, monitorErr := pe.StartAndForget(dcpProcCmd, process.CreationFlagsNone) return monitorErr } @@ -167,20 +165,21 @@ func SimulateStopProcessTreeCommand(pe *internal_testutil.ProcessExecution) int3 if pidErr != nil { return 3 // Invalid PID } - var startTime time.Time + var handle process.ProcessHandle + handle.Pid = pid i = slices.Index(pe.Cmd.Args, "--process-start-time") if i >= 0 && len(pe.Cmd.Args) > i+1 { - var startTimeErr error - startTime, startTimeErr = time.Parse(osutil.RFC3339MiliTimestampFormat, pe.Cmd.Args[i+1]) + startTime, startTimeErr := time.Parse(osutil.RFC3339MiliTimestampFormat, pe.Cmd.Args[i+1]) if startTimeErr != nil { return 4 // Invalid start time } + handle.IdentityTime = startTime } // We do not simulate stopping the whole process tree (or process parent-child relationships, for that matter). // We can consider adding it if we have tests that require it (currently none). 
- stopErr := pe.Executor.StopProcess(pid, startTime) + stopErr := pe.Executor.StopProcess(handle) if stopErr != nil { return 5 // Failed to stop the process } diff --git a/internal/dcpproc/dcpproc_api_test.go b/internal/dcpproc/dcpproc_api_test.go index b15d3f28..3f93d8cd 100644 --- a/internal/dcpproc/dcpproc_api_test.go +++ b/internal/dcpproc/dcpproc_api_test.go @@ -35,7 +35,7 @@ func TestRunProcessWatcher(t *testing.T) { testPid := process.Pid_t(28869) testStartTime := time.Now() - RunProcessWatcher(pe, testPid, testStartTime, log) + RunProcessWatcher(pe, process.NewHandle(testPid, testStartTime), log) dcpProc, dcpProcErr := findRunningDcp(pe) require.NoError(t, dcpProcErr) @@ -103,9 +103,9 @@ func TestStopProcessTree(t *testing.T) { }, }) - pid, startTime, startErr := pex.StartAndForget(testCmd, process.CreationFlagsNone) + handle, startErr := pex.StartAndForget(testCmd, process.CreationFlagsNone) require.NoError(t, startErr, "Could not simulate starting test process") - testProc, found := pex.FindByPid(pid) + testProc, found := pex.FindByPid(handle.Pid) require.True(t, found, "Could not find the started process") var dcpProc *internal_testutil.ProcessExecution @@ -119,7 +119,7 @@ func TestStopProcessTree(t *testing.T) { }, }) - stopProcessTreeErr := StopProcessTree(ctx, pex, pid, startTime, log) + stopProcessTreeErr := StopProcessTree(ctx, pex, handle, log) require.NoError(t, stopProcessTreeErr, "Could not stop the process tree") require.True(t, testProc.Finished(), "The test processed should have been stopped") @@ -129,9 +129,9 @@ func TestStopProcessTree(t *testing.T) { require.Equal(t, "stop-process-tree", dcpProc.Cmd.Args[1], "Should use 'stop-process-tree' subcommand") require.Equal(t, dcpProc.Cmd.Args[2], "--pid", "Should include --pid flag") - require.Equal(t, dcpProc.Cmd.Args[3], strconv.FormatInt(int64(pid), 10), "Should include test process ID") + require.Equal(t, dcpProc.Cmd.Args[3], strconv.FormatInt(int64(handle.Pid), 10), "Should include test 
process ID") require.Equal(t, dcpProc.Cmd.Args[4], "--process-start-time", "Should include --process-start-time flag") - require.Equal(t, dcpProc.Cmd.Args[5], startTime.Format(osutil.RFC3339MiliTimestampFormat), "Should include formatted process start time") + require.Equal(t, dcpProc.Cmd.Args[5], handle.IdentityTime.Format(osutil.RFC3339MiliTimestampFormat), "Should include formatted process start time") } func findRunningDcp(pe *internal_testutil.TestProcessExecutor) (*internal_testutil.ProcessExecution, error) { diff --git a/internal/docker/cli_orchestrator.go b/internal/docker/cli_orchestrator.go index e55011e9..a4bae512 100644 --- a/internal/docker/cli_orchestrator.go +++ b/internal/docker/cli_orchestrator.go @@ -677,7 +677,7 @@ func (dco *DockerCliOrchestrator) ExecContainer(ctx context.Context, options con } dco.log.V(1).Info("Running Docker command", "Command", cmd.String()) - _, _, startWaitForProcessExit, err := dco.executor.StartProcess(ctx, cmd, process.ProcessExitHandlerFunc(exitHandler), process.CreationFlagsNone) + _, startWaitForProcessExit, err := dco.executor.StartProcess(ctx, cmd, process.ProcessExitHandlerFunc(exitHandler), process.CreationFlagsNone) if err != nil { close(exitCh) return nil, errors.Join(err, fmt.Errorf("failed to start Docker command '%s'", "ExecContainer")) @@ -1097,13 +1097,13 @@ func (dco *DockerCliOrchestrator) doWatchContainers(watcherCtx context.Context, // Container events are delivered on best-effort basis. // If the "docker events" command fails unexpectedly, we will log the error, // but we won't try to restart it. 
- pid, startTime, startWaitForProcessExit, err := dco.executor.StartProcess(watcherCtx, cmd, peh, process.CreationFlagsNone) + handle, startWaitForProcessExit, err := dco.executor.StartProcess(watcherCtx, cmd, peh, process.CreationFlagsNone) if err != nil { dco.log.Error(err, "Could not execute 'docker events' command; container events unavailable") return } - dcpproc.RunProcessWatcher(dco.executor, pid, startTime, dco.log) + dcpproc.RunProcessWatcher(dco.executor, handle, dco.log) startWaitForProcessExit() @@ -1115,7 +1115,7 @@ func (dco *DockerCliOrchestrator) doWatchContainers(watcherCtx context.Context, } case <-watcherCtx.Done(): // We are asked to shut down - dco.log.V(1).Info("Stopping 'docker events' command", "pid", pid) + dco.log.V(1).Info("Stopping 'docker events' command", "pid", handle.Pid) } } @@ -1156,13 +1156,13 @@ func (dco *DockerCliOrchestrator) doWatchNetworks(watcherCtx context.Context, ss // Container events are delivered on best-effort basis. // If the "docker events" command fails unexpectedly, we will log the error, // but we won't try to restart it. 
- pid, startTime, startWaitForProcessExit, err := dco.executor.StartProcess(watcherCtx, cmd, peh, process.CreationFlagsNone) + handle, startWaitForProcessExit, err := dco.executor.StartProcess(watcherCtx, cmd, peh, process.CreationFlagsNone) if err != nil { dco.log.Error(err, "Could not execute 'docker events' command; network events unavailable") return } - dcpproc.RunProcessWatcher(dco.executor, pid, startTime, dco.log) + dcpproc.RunProcessWatcher(dco.executor, handle, dco.log) startWaitForProcessExit() @@ -1174,7 +1174,7 @@ func (dco *DockerCliOrchestrator) doWatchNetworks(watcherCtx context.Context, ss } case <-watcherCtx.Done(): // We are asked to shut down - dco.log.V(1).Info("Stopping 'docker events' command", "PID", pid) + dco.log.V(1).Info("Stopping 'docker events' command", "PID", handle.Pid) } } @@ -1209,14 +1209,14 @@ func (dco *DockerCliOrchestrator) streamDockerCommand( } dco.log.V(1).Info("Running Docker command", "Command", cmd.String()) - pid, startTime, startWaitForProcessExit, err := dco.executor.StartProcess(ctx, cmd, process.ProcessExitHandlerFunc(exitHandler), process.CreationFlagsNone) + handle, startWaitForProcessExit, err := dco.executor.StartProcess(ctx, cmd, process.ProcessExitHandlerFunc(exitHandler), process.CreationFlagsNone) if err != nil { close(exitCh) return nil, errors.Join(err, fmt.Errorf("failed to start Docker command '%s'", commandName)) } if opts&streamCommandOptionUseWatcher != 0 { - dcpproc.RunProcessWatcher(dco.executor, pid, startTime, dco.log) + dcpproc.RunProcessWatcher(dco.executor, handle, dco.log) } startWaitForProcessExit() diff --git a/internal/exerunners/bridge_output_handler.go b/internal/exerunners/bridge_output_handler.go new file mode 100644 index 00000000..047cbd03 --- /dev/null +++ b/internal/exerunners/bridge_output_handler.go @@ -0,0 +1,49 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. 
+ * Licensed under the MIT License. See LICENSE in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package exerunners + +import ( + "io" + + "github.com/microsoft/dcp/internal/dap" +) + +// bridgeOutputHandler routes DAP output events to the appropriate writers +// based on their category. It implements dap.OutputHandler. +type bridgeOutputHandler struct { + stdout io.Writer + stderr io.Writer +} + +var _ dap.OutputHandler = (*bridgeOutputHandler)(nil) + +// newBridgeOutputHandler creates a new bridgeOutputHandler that routes +// "stdout" and "console" output to the stdout writer, and "stderr" output +// to the stderr writer. Either writer may be nil, in which case output +// for that category is silently discarded. +func newBridgeOutputHandler(stdout, stderr io.Writer) *bridgeOutputHandler { + return &bridgeOutputHandler{ + stdout: stdout, + stderr: stderr, + } +} + +// HandleOutput routes the output to the appropriate writer based on category. +// "stdout" and "console" categories are written to the stdout writer. +// "stderr" category is written to the stderr writer. +// Other categories are silently discarded. +func (h *bridgeOutputHandler) HandleOutput(category string, output string) { + switch category { + case "stdout", "console": + if h.stdout != nil { + _, _ = h.stdout.Write([]byte(output)) + } + case "stderr": + if h.stderr != nil { + _, _ = h.stderr.Write([]byte(output)) + } + } +} diff --git a/internal/exerunners/bridge_output_handler_test.go b/internal/exerunners/bridge_output_handler_test.go new file mode 100644 index 00000000..1e4ba5ea --- /dev/null +++ b/internal/exerunners/bridge_output_handler_test.go @@ -0,0 +1,100 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. 
See LICENSE in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package exerunners + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBridgeOutputHandler_StdoutCategory(t *testing.T) { + t.Parallel() + + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + handler := newBridgeOutputHandler(stdout, stderr) + + handler.HandleOutput("stdout", "hello world\n") + + assert.Equal(t, "hello world\n", stdout.String()) + assert.Empty(t, stderr.String()) +} + +func TestBridgeOutputHandler_StderrCategory(t *testing.T) { + t.Parallel() + + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + handler := newBridgeOutputHandler(stdout, stderr) + + handler.HandleOutput("stderr", "error message\n") + + assert.Empty(t, stdout.String()) + assert.Equal(t, "error message\n", stderr.String()) +} + +func TestBridgeOutputHandler_ConsoleCategory(t *testing.T) { + t.Parallel() + + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + handler := newBridgeOutputHandler(stdout, stderr) + + handler.HandleOutput("console", "console output\n") + + assert.Equal(t, "console output\n", stdout.String()) + assert.Empty(t, stderr.String()) +} + +func TestBridgeOutputHandler_UnknownCategory(t *testing.T) { + t.Parallel() + + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + handler := newBridgeOutputHandler(stdout, stderr) + + handler.HandleOutput("telemetry", "telemetry data\n") + + assert.Empty(t, stdout.String()) + assert.Empty(t, stderr.String()) +} + +func TestBridgeOutputHandler_NilWriters(t *testing.T) { + t.Parallel() + + handler := newBridgeOutputHandler(nil, nil) + + // Should not panic with nil writers + handler.HandleOutput("stdout", "hello\n") + handler.HandleOutput("stderr", "error\n") + handler.HandleOutput("console", "console\n") +} + +func TestBridgeOutputHandler_NilStdoutOnly(t *testing.T) { + t.Parallel() + + stderr := &bytes.Buffer{} + 
handler := newBridgeOutputHandler(nil, stderr) + + handler.HandleOutput("stdout", "stdout\n") + handler.HandleOutput("stderr", "stderr\n") + + assert.Equal(t, "stderr\n", stderr.String()) +} + +func TestBridgeOutputHandler_NilStderrOnly(t *testing.T) { + t.Parallel() + + stdout := &bytes.Buffer{} + handler := newBridgeOutputHandler(stdout, nil) + + handler.HandleOutput("stdout", "stdout\n") + handler.HandleOutput("stderr", "stderr\n") + + assert.Equal(t, "stdout\n", stdout.String()) +} diff --git a/internal/exerunners/ide_connection_info.go b/internal/exerunners/ide_connection_info.go index 91c4c27c..1dbddf0a 100644 --- a/internal/exerunners/ide_connection_info.go +++ b/internal/exerunners/ide_connection_info.go @@ -137,7 +137,9 @@ func NewIdeConnectionInfo(lifetimeCtx context.Context, log logr.Logger) (*ideCon connInfo.supportedApiVersions = info.ProtocolsSupported // We will use the IDE endpoint ONLY IF we support at least one common API version - if slices.Contains(info.ProtocolsSupported, version20251001) { + if slices.Contains(info.ProtocolsSupported, version20260201) { + connInfo.apiVersion = version20260201 + } else if slices.Contains(info.ProtocolsSupported, version20251001) { connInfo.apiVersion = version20251001 } else if slices.Contains(info.ProtocolsSupported, version20240423) { connInfo.apiVersion = version20240423 @@ -209,3 +211,15 @@ func (connInfo *ideConnectionInfo) GetClient() *http.Client { func (connInfo *ideConnectionInfo) GetDialer() *websocket.Dialer { return connInfo.wsDialer } + +// GetToken returns the security token used for IDE authentication. +// This token is reused for debug bridge session authentication. +func (connInfo *ideConnectionInfo) GetToken() string { + return connInfo.tokenStr +} + +// SupportsDebugBridge returns true if the connected IDE supports the debug bridge feature. +// This is available in API version 2026-02-01 and later. 
+func (connInfo *ideConnectionInfo) SupportsDebugBridge() bool { + return equalOrNewer(connInfo.apiVersion, version20260201) +} diff --git a/internal/exerunners/ide_executable_runner.go b/internal/exerunners/ide_executable_runner.go index 7cb9ad4e..ab08d3ea 100644 --- a/internal/exerunners/ide_executable_runner.go +++ b/internal/exerunners/ide_executable_runner.go @@ -26,6 +26,7 @@ import ( apiv1 "github.com/microsoft/dcp/api/v1" "github.com/microsoft/dcp/controllers" + "github.com/microsoft/dcp/internal/dap" "github.com/microsoft/dcp/internal/logs" usvc_io "github.com/microsoft/dcp/pkg/io" "github.com/microsoft/dcp/pkg/osutil" @@ -58,6 +59,7 @@ type IdeExecutableRunner struct { lifetimeCtx context.Context // Lifetime context of the controller hosting this runner connectionInfo *ideConnectionInfo notificationHandler *ideNotificationHandler + bridgeManager *dap.BridgeManager // Manager for debug bridge sessions and shared socket } func NewIdeExecutableRunner(lifetimeCtx context.Context, log logr.Logger) (*IdeExecutableRunner, error) { @@ -75,6 +77,21 @@ func NewIdeExecutableRunner(lifetimeCtx context.Context, log logr.Logger) (*IdeE connectionInfo: connInfo, } + // Create and start the bridge manager if the IDE supports debug bridge + if connInfo.SupportsDebugBridge() { + r.bridgeManager = dap.NewBridgeManager(dap.BridgeManagerConfig{ + ConnectionHandler: r.handleBridgeConnection, + }, log.WithName("BridgeManager")) + + // Start the bridge manager in a background goroutine + go func() { + managerErr := r.bridgeManager.Start(lifetimeCtx) + if managerErr != nil && !errors.Is(managerErr, context.Canceled) { + log.Error(managerErr, "Bridge manager terminated with error") + } + }() + } + nh := NewIdeNotificationHandler(lifetimeCtx, r, connInfo, log) r.notificationHandler = nh return r, nil @@ -369,6 +386,42 @@ func (r *IdeExecutableRunner) prepareRunRequestV1(exe *apiv1.Executable) ([]byte Args: exe.Status.EffectiveArgs, } + // Set up debug bridge if IDE supports it and 
bridge manager is available + if r.connectionInfo.SupportsDebugBridge() && r.bridgeManager != nil { + // Wait for bridge manager to be ready (with timeout) + select { + case <-r.bridgeManager.Ready(): + // Bridge manager is ready + case <-time.After(5 * time.Second): + return nil, fmt.Errorf("timeout waiting for debug bridge manager to be ready") + case <-r.lifetimeCtx.Done(): + return nil, fmt.Errorf("context cancelled while waiting for bridge manager: %w", r.lifetimeCtx.Err()) + } + + sessionID := string(exe.UID) + ideToken := r.connectionInfo.GetToken() + + // Register the session with the IDE's token (reused for bridge authentication) + _, regErr := r.bridgeManager.RegisterSession(sessionID, ideToken) + if regErr != nil { + // If session already exists, that's okay - just continue + if !errors.Is(regErr, dap.ErrBridgeSessionAlreadyExists) { + return nil, fmt.Errorf("failed to register debug bridge session: %w", regErr) + } + } + + var socketErr error + isr.DebugBridgeSocketPath, socketErr = r.bridgeManager.SocketPath(r.lifetimeCtx) + if socketErr != nil { + return nil, fmt.Errorf("failed to get debug bridge socket path: %w", socketErr) + } + isr.DebugSessionID = sessionID + + r.log.Info("Debug bridge session registered", + "sessionID", sessionID, + "socketPath", isr.DebugBridgeSocketPath) + } + isrBody, marshalErr := json.Marshal(isr) if marshalErr != nil { return nil, fmt.Errorf("failed to create Executable run request body: %w", marshalErr) @@ -501,6 +554,26 @@ func (r *IdeExecutableRunner) ensureRunData(runID controllers.RunID) *runData { return rd } +// handleBridgeConnection is the BridgeConnectionHandler callback invoked by the +// BridgeManager when the IDE connects to the debug bridge. It resolves the run data +// for the given run ID and returns an OutputHandler and stdout/stderr writers that +// route debug adapter output into the executable's log files. 
+// +// The ensureRunData call handles out-of-order arrival: the bridge connection may +// arrive before doStartRun completes. The BufferedWrappingWriter in runData buffers +// output until SetOutputWriters wires up the temp files. +func (r *IdeExecutableRunner) handleBridgeConnection(sessionID string, runID string) (dap.OutputHandler, io.Writer, io.Writer) { + if runID == "" { + r.log.V(1).Info("Bridge connection without RunID, output will not be captured", + "sessionID", sessionID) + return nil, nil, nil + } + + rd := r.ensureRunData(controllers.RunID(runID)) + handler := newBridgeOutputHandler(rd.stdOut, rd.stdErr) + return handler, rd.stdOut, rd.stdErr +} + func (r *IdeExecutableRunner) makeRequest( requestPath string, httpMethod string, diff --git a/internal/exerunners/ide_requests_responses.go b/internal/exerunners/ide_requests_responses.go index cb59fafd..14199ad8 100644 --- a/internal/exerunners/ide_requests_responses.go +++ b/internal/exerunners/ide_requests_responses.go @@ -84,6 +84,14 @@ type ideRunSessionRequestV1 struct { Env []apiv1.EnvVar `json:"env,omitempty"` Args []string `json:"args,omitempty"` + + // Debug bridge fields (added in version 2026-02-01) + // When present, the IDE should connect to the Unix socket at DebugBridgeSocketPath + // and send a handshake message with the IDE session token and DebugSessionID. + // The IDE session token (used for this request's authentication) is reused for + // bridge handshake authentication. 
+ DebugBridgeSocketPath string `json:"debug_bridge_socket_path,omitempty"` + DebugSessionID string `json:"debug_session_id,omitempty"` } type launchConfigurationBase struct { @@ -148,6 +156,7 @@ const ( version20240303 apiVersion = "2024-03-03" version20240423 apiVersion = "2024-04-23" version20251001 apiVersion = "2025-10-01" + version20260201 apiVersion = "2026-02-01" // Added debug bridge support queryParamApiVersion = "api-version" instanceIdHeader = "Microsoft-Developer-DCP-Instance-ID" diff --git a/internal/exerunners/process_executable_runner.go b/internal/exerunners/process_executable_runner.go index d37982c5..4731cef4 100644 --- a/internal/exerunners/process_executable_runner.go +++ b/internal/exerunners/process_executable_runner.go @@ -31,10 +31,10 @@ import ( ) type processRunState struct { - identityTime time.Time - stdOutFile *os.File - stdErrFile *os.File - cmdInfo string // Command line used to start the process, for logging purposes + handle process.ProcessHandle + stdOutFile *os.File + stdErrFile *os.File + cmdInfo string // Command line used to start the process, for logging purposes } type ProcessExecutableRunner struct { @@ -106,7 +106,7 @@ func (r *ProcessExecutableRunner) StartRun( }) // We want to ensure that the service process tree is killed when DCP is stopped so that ports are released etc. - pid, processIdentityTime, startWaitForProcessExit, startErr := r.pe.StartProcess(ctx, cmd, processExitHandler, process.CreationFlagEnsureKillOnDispose) + handle, startWaitForProcessExit, startErr := r.pe.StartProcess(ctx, cmd, processExitHandler, process.CreationFlagEnsureKillOnDispose) if startErr != nil { startLog.Error(startErr, "Failed to start a process") result.CompletionTimestamp = metav1.NowMicro() @@ -127,19 +127,19 @@ func (r *ProcessExecutableRunner) StartRun( return result } else { // Use original log here, the watcher is a different process. 
- dcpproc.RunProcessWatcher(r.pe, pid, processIdentityTime, log) + dcpproc.RunProcessWatcher(r.pe, handle, log) - r.runningProcesses.Store(pidToRunID(pid), &processRunState{ - identityTime: processIdentityTime, - stdOutFile: stdOutFile, - stdErrFile: stdErrFile, - cmdInfo: cmd.String(), + r.runningProcesses.Store(pidToRunID(handle.Pid), &processRunState{ + handle: handle, + stdOutFile: stdOutFile, + stdErrFile: stdErrFile, + cmdInfo: cmd.String(), }) - result.RunID = pidToRunID(pid) - pointers.SetValue(&result.Pid, int64(pid)) + result.RunID = pidToRunID(handle.Pid) + pointers.SetValue(&result.Pid, int64(handle.Pid)) result.ExeState = apiv1.ExecutableStateRunning - result.CompletionTimestamp = metav1.NewMicroTime(process.StartTimeForProcess(pid)) + result.CompletionTimestamp = metav1.NewMicroTime(process.StartTimeForProcess(handle.Pid)) result.StartWaitForRunCompletion = startWaitForProcessExit runChangeHandler.OnStartupCompleted(exe.NamespacedName(), result) @@ -171,9 +171,9 @@ func (r *ProcessExecutableRunner) StopRun(ctx context.Context, runID controllers // This means we cannot send Ctrl-C to that process directly and need to use dcpproc StopProcessTree facility instead. 
stopCtx, stopCtxCancel := context.WithTimeout(ctx, ProcessStopTimeout) defer stopCtxCancel() - errCh <- dcpproc.StopProcessTree(stopCtx, r.pe, runIdToPID(runID), runState.identityTime, stopLog) + errCh <- dcpproc.StopProcessTree(stopCtx, r.pe, runState.handle, stopLog) } else { - errCh <- r.pe.StopProcess(runIdToPID(runID), runState.identityTime) + errCh <- r.pe.StopProcess(runState.handle) } }() @@ -216,16 +216,4 @@ func pidToRunID(pid process.Pid_t) controllers.RunID { return controllers.RunID(strconv.FormatInt(int64(pid), 10)) } -func runIdToPID(runID controllers.RunID) process.Pid_t { - pid64, err := strconv.ParseInt(string(runID), 10, 64) - if err != nil { - return process.UnknownPID - } - pid, err := process.Int64_ToPidT(pid64) - if err != nil { - return process.UnknownPID - } - return pid -} - var _ controllers.ExecutableRunner = (*ProcessExecutableRunner)(nil) diff --git a/internal/hosting/command_service.go b/internal/hosting/command_service.go index 0756d655..6ae48a14 100644 --- a/internal/hosting/command_service.go +++ b/internal/hosting/command_service.go @@ -79,7 +79,7 @@ func (s *CommandService) Run(ctx context.Context) error { pic := make(chan process.ProcessExitInfo, 1) peh := process.NewChannelProcessExitHandler(pic) - _, _, startWaitForProcessExit, startErr := s.executor.StartProcess(runCtx, s.cmd, peh, process.CreationFlagsNone) + _, startWaitForProcessExit, startErr := s.executor.StartProcess(runCtx, s.cmd, peh, process.CreationFlagsNone) if startErr != nil { return startErr } diff --git a/internal/networking/unix_socket.go b/internal/networking/unix_socket.go new file mode 100644 index 00000000..18fc783a --- /dev/null +++ b/internal/networking/unix_socket.go @@ -0,0 +1,195 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See LICENSE in the project root for license information. 
+ *--------------------------------------------------------------------------------------------*/ + +package networking + +import ( + "context" + "fmt" + "net" + "os" + "path/filepath" + "sync" + "sync/atomic" + + "github.com/microsoft/dcp/internal/dcppaths" + "github.com/microsoft/dcp/pkg/osutil" + "github.com/microsoft/dcp/pkg/randdata" + "github.com/microsoft/dcp/pkg/resiliency" +) + +// PrivateUnixSocketListener manages a Unix domain socket in a directory +// that enforces user-only access permissions. It handles user-only directory creation, +// random socket name generation (to support multiple DCP instances without +// collisions), and socket file lifecycle management. +// +// PrivateUnixSocketListener implements net.Listener and can be used as a drop-in +// replacement anywhere a net.Listener is expected (e.g., gRPC server Serve()). +type PrivateUnixSocketListener struct { + listener net.Listener + socketPath string + + closed atomic.Bool + closeErr error + mu *sync.Mutex +} + +var _ net.Listener = (*PrivateUnixSocketListener)(nil) + +// NewPrivateUnixSocketListener creates a new Unix domain socket listener in a +// user-private directory. The socket file name is generated by combining the given +// prefix with a random suffix to avoid collisions between multiple DCP instances. +// +// If the generated socket path already exists (e.g., belonging to another running +// instance), the function retries with a new random suffix up to +// maxSocketCreateAttempts times. Existing socket files are never removed, as they +// may be in active use by another process. +// +// If socketDir is empty, os.UserCacheDir() is used as the root directory. A "dcp-work" +// subdirectory is created (if it doesn't already exist) with owner-only permissions (0700). +// On Unix-like systems, the directory permissions are validated to ensure privacy. 
+// +// The socket file permissions are set to owner-only read/write (0600) on a best-effort +// basis — the chmod may not succeed on all platforms. +// +// The caller should call Close() when the listener is no longer needed. Close removes +// the socket file and closes the underlying listener. +func NewPrivateUnixSocketListener(socketDir string, socketNamePrefix string) (*PrivateUnixSocketListener, error) { + privateDir, privateDirErr := PreparePrivateUnixSocketDir(socketDir) + if privateDirErr != nil { + return nil, fmt.Errorf("failed to prepare user-only socket directory: %w", privateDirErr) + } + + // Retry with a new random suffix on path collisions. + return resiliency.RetryGetExponential(context.Background(), func() (*PrivateUnixSocketListener, error) { + suffix, suffixErr := randdata.MakeRandomString(8) + if suffixErr != nil { + return nil, resiliency.Permanent(fmt.Errorf("failed to generate random socket name suffix: %w", suffixErr)) + } + + socketPath := filepath.Join(privateDir, socketNamePrefix+string(suffix)) + + // If a file already exists at this path, it may belong to another running + // DCP instance. Skip this path and retry with a new random suffix. + if _, statErr := os.Stat(socketPath); statErr == nil { + return nil, fmt.Errorf("socket path %s already exists", socketPath) + } + + listener, listenErr := net.Listen("unix", socketPath) + if listenErr != nil { + // The path may have been created between the stat check and the listen call. + // Treat this as a collision and retry. + if os.IsExist(listenErr) { + return nil, fmt.Errorf("socket path %s already in use: %w", socketPath, listenErr) + } + return nil, resiliency.Permanent(fmt.Errorf("failed to create Unix socket listener at %s: %w", socketPath, listenErr)) + } + + // Best-effort: set socket file permissions to owner-only. + // This may not work on all platforms (e.g., Windows) but provides + // defense-in-depth on systems that support it. 
+ _ = os.Chmod(socketPath, osutil.PermissionOnlyOwnerReadWrite) + + return &PrivateUnixSocketListener{ + listener: listener, + socketPath: socketPath, + mu: &sync.Mutex{}, + }, nil + }) +} + +// Accept waits for and returns the next connection to the listener. +// Returns net.ErrClosed if the listener has been closed. +func (l *PrivateUnixSocketListener) Accept() (net.Conn, error) { + if l.closed.Load() { + return nil, net.ErrClosed + } + + conn, acceptErr := l.listener.Accept() + if acceptErr != nil { + // If the listener was closed while we were blocking on Accept(), + // return net.ErrClosed so the caller can distinguish a graceful + // shutdown from an unexpected error. + if l.closed.Load() { + return nil, net.ErrClosed + } + return nil, acceptErr + } + + return conn, nil +} + +// Close closes the listener and removes the socket file. +// Close is idempotent — subsequent calls return the original close error. +func (l *PrivateUnixSocketListener) Close() error { + l.mu.Lock() + defer l.mu.Unlock() + + if l.closed.Load() { + return l.closeErr + } + + l.closed.Store(true) + + l.closeErr = l.listener.Close() + + // Best effort removal of the socket file. + _ = os.Remove(l.socketPath) + + return l.closeErr +} + +// Addr returns the listener's network address. +func (l *PrivateUnixSocketListener) Addr() net.Addr { + return l.listener.Addr() +} + +// SocketPath returns the full path to the Unix socket file. +// The path includes the randomly generated suffix, so callers must use this +// method to discover the actual socket path after listener creation. +func (l *PrivateUnixSocketListener) SocketPath() string { + return l.socketPath +} + +// PreparePrivateUnixSocketDir ensures a directory exists for creating Unix domain sockets +// that is writable only by the current user. The directory is created under rootDir +// with owner-only traverse permissions (0700). +// +// If rootDir is empty, os.UserCacheDir() is used as the root. 
+// On non-Windows systems, the directory permissions are validated after creation +// to ensure they have not been tampered with or set incorrectly. +// +// Returns the path to the user-only directory. +func PreparePrivateUnixSocketDir(rootDir string) (string, error) { + if rootDir == "" { + cacheDir, cacheDirErr := os.UserCacheDir() + if cacheDirErr != nil { + return "", fmt.Errorf("failed to get user-only cache directory for socket: %w", cacheDirErr) + } + rootDir = cacheDir + } + + socketDir := filepath.Join(rootDir, dcppaths.DcpWorkDir) + if mkdirErr := os.MkdirAll(socketDir, osutil.PermissionOnlyOwnerReadWriteTraverse); mkdirErr != nil { + return "", fmt.Errorf("failed to create user-only socket directory: %w", mkdirErr) + } + + // On Windows the user cache directory always exists and is always private to the user, + // but on Unix-like systems, we need to verify the directory is private. + if !osutil.IsWindows() { + info, infoErr := os.Stat(socketDir) + if infoErr != nil { + return "", fmt.Errorf("failed to check permissions on socket directory: %w", infoErr) + } + if !info.IsDir() { + return "", fmt.Errorf("socket path %s is not a directory", socketDir) + } + if info.Mode().Perm() != osutil.PermissionOnlyOwnerReadWriteTraverse { + return "", fmt.Errorf("socket directory %s is not private to the user (permissions: %o)", socketDir, info.Mode().Perm()) + } + } + + return socketDir, nil +} diff --git a/internal/networking/unix_socket_test.go b/internal/networking/unix_socket_test.go new file mode 100644 index 00000000..4c0d9665 --- /dev/null +++ b/internal/networking/unix_socket_test.go @@ -0,0 +1,291 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See LICENSE in the project root for license information. 
+ *--------------------------------------------------------------------------------------------*/ + +package networking + +import ( + "fmt" + "net" + "os" + "path/filepath" + "runtime" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/microsoft/dcp/internal/dcppaths" + "github.com/microsoft/dcp/pkg/osutil" +) + +// shortTempDir creates a short temporary directory for socket tests. +// macOS has a ~104 character limit for Unix socket paths, so we use +// a short base path. +func shortTempDir(t *testing.T) string { + t.Helper() + dir, dirErr := os.MkdirTemp("", "sck") + require.NoError(t, dirErr) + t.Cleanup(func() { os.RemoveAll(dir) }) + return dir +} + +func TestPrepareSecureSocketDirCreatesDirectoryWithCorrectPermissions(t *testing.T) { + t.Parallel() + rootDir := shortTempDir(t) + + socketDir, prepareErr := PreparePrivateUnixSocketDir(rootDir) + require.NoError(t, prepareErr) + + expectedDir := filepath.Join(rootDir, dcppaths.DcpWorkDir) + assert.Equal(t, expectedDir, socketDir) + + info, statErr := os.Stat(socketDir) + require.NoError(t, statErr) + assert.True(t, info.IsDir()) + if runtime.GOOS != "windows" { + assert.Equal(t, osutil.PermissionOnlyOwnerReadWriteTraverse, info.Mode().Perm()) + } +} + +func TestPrepareSecureSocketDirIdempotentOnRepeatedCalls(t *testing.T) { + t.Parallel() + rootDir := shortTempDir(t) + + dir1, err1 := PreparePrivateUnixSocketDir(rootDir) + require.NoError(t, err1) + + dir2, err2 := PreparePrivateUnixSocketDir(rootDir) + require.NoError(t, err2) + + assert.Equal(t, dir1, dir2) +} + +func TestPrepareSecureSocketDirFallsBackToUserCacheDir(t *testing.T) { + t.Parallel() + + socketDir, prepareErr := PreparePrivateUnixSocketDir("") + require.NoError(t, prepareErr) + + cacheDir, cacheDirErr := os.UserCacheDir() + require.NoError(t, cacheDirErr) + + expectedDir := filepath.Join(cacheDir, dcppaths.DcpWorkDir) + assert.Equal(t, expectedDir, socketDir) +} + +func 
TestPrepareSecureSocketDirRejectsWrongPermissions(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("permission validation is skipped on Windows") + } + + t.Parallel() + rootDir := shortTempDir(t) + + // Pre-create the dcp-work directory with overly-permissive permissions + socketDir := filepath.Join(rootDir, dcppaths.DcpWorkDir) + mkdirErr := os.MkdirAll(socketDir, 0755) + require.NoError(t, mkdirErr) + + _, prepareErr := PreparePrivateUnixSocketDir(rootDir) + require.Error(t, prepareErr) + assert.Contains(t, prepareErr.Error(), "not private to the user") +} + +func TestPrivateUnixSocketListenerCreatesListenerWithRandomName(t *testing.T) { + t.Parallel() + rootDir := shortTempDir(t) + + listener, createErr := NewPrivateUnixSocketListener(rootDir, "test-") + require.NoError(t, createErr) + require.NotNil(t, listener) + defer listener.Close() + + socketPath := listener.SocketPath() + socketName := filepath.Base(socketPath) + + // Verify the socket name starts with the prefix and has the random suffix + assert.True(t, len(socketName) > len("test-"), "socket name should include random suffix") + assert.Equal(t, "test-", socketName[:len("test-")]) + + // Verify socket file was created + _, statErr := os.Stat(socketPath) + require.NoError(t, statErr) +} + +func TestPrivateUnixSocketListenerAcceptsConnections(t *testing.T) { + t.Parallel() + rootDir := shortTempDir(t) + + listener, createErr := NewPrivateUnixSocketListener(rootDir, "acc-") + require.NoError(t, createErr) + defer listener.Close() + + // Accept in background + var serverConn net.Conn + var acceptErr error + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + serverConn, acceptErr = listener.Accept() + }() + + // Connect client + clientConn, dialErr := net.Dial("unix", listener.SocketPath()) + require.NoError(t, dialErr) + defer clientConn.Close() + + wg.Wait() + require.NoError(t, acceptErr) + require.NotNil(t, serverConn) + defer serverConn.Close() + + // Verify we can exchange 
data + _, writeErr := clientConn.Write([]byte("hello")) + require.NoError(t, writeErr) + + buf := make([]byte, 5) + n, readErr := serverConn.Read(buf) + require.NoError(t, readErr) + assert.Equal(t, "hello", string(buf[:n])) +} + +func TestPrivateUnixSocketListenerCloseRemovesSocketFile(t *testing.T) { + t.Parallel() + rootDir := shortTempDir(t) + + listener, createErr := NewPrivateUnixSocketListener(rootDir, "cls-") + require.NoError(t, createErr) + + socketPath := listener.SocketPath() + + // Verify socket exists + _, statErr := os.Stat(socketPath) + require.NoError(t, statErr) + + closeErr := listener.Close() + assert.NoError(t, closeErr) + + // Verify socket was removed + _, statErr = os.Stat(socketPath) + assert.True(t, os.IsNotExist(statErr)) + + // Double close should be safe + closeErr = listener.Close() + assert.NoError(t, closeErr) +} + +func TestPrivateUnixSocketListenerDoesNotRemoveExistingSocketOnCollision(t *testing.T) { + t.Parallel() + rootDir := shortTempDir(t) + + // Create listeners that will occupy socket paths in the directory. + // A new listener should get a different path without removing these. + l1, err1 := NewPrivateUnixSocketListener(rootDir, "col-") + require.NoError(t, err1) + defer l1.Close() + + l2, err2 := NewPrivateUnixSocketListener(rootDir, "col-") + require.NoError(t, err2) + defer l2.Close() + + // The first listener's socket must still exist (not removed by the second). + _, statErr := os.Stat(l1.SocketPath()) + assert.NoError(t, statErr, "existing socket file should not be removed on collision") + + // Both listeners should have distinct paths. + assert.NotEqual(t, l1.SocketPath(), l2.SocketPath()) + + // Both should accept connections. 
+ for _, listener := range []*PrivateUnixSocketListener{l1, l2} { + conn, dialErr := net.Dial("unix", listener.SocketPath()) + require.NoError(t, dialErr) + conn.Close() + } +} + +func TestPrivateUnixSocketListenerConcurrentCloseReturnsErrClosed(t *testing.T) { + t.Parallel() + rootDir := shortTempDir(t) + + listener, createErr := NewPrivateUnixSocketListener(rootDir, "ccl-") + require.NoError(t, createErr) + + // Start Accept() in a goroutine; it will block until the listener is closed. + var acceptErr error + acceptDone := make(chan struct{}) + go func() { + defer close(acceptDone) + _, acceptErr = listener.Accept() + }() + + // Give the accept goroutine a moment to enter the blocking Accept() call. + runtime.Gosched() + + // Launch 10 goroutines that all race to call Close(). + const closerCount = 10 + closeErrs := make([]error, closerCount) + startCh := make(chan struct{}) + var closeWg sync.WaitGroup + closeWg.Add(closerCount) + for i := range closerCount { + go func() { + defer closeWg.Done() + <-startCh + closeErrs[i] = listener.Close() + }() + } + + // Signal all closers to race. + close(startCh) + closeWg.Wait() + + // Wait for Accept() to return. + <-acceptDone + + // Accept() must return net.ErrClosed so the caller can distinguish + // a graceful shutdown from an unexpected error. + assert.ErrorIs(t, acceptErr, net.ErrClosed) + + // All Close() calls must succeed (Close is idempotent). 
+ for i, closeErr := range closeErrs { + assert.NoError(t, closeErr, "Close() call %d returned an error", i) + } +} + +func TestPrivateUnixSocketListenerAddrReturnsValidAddress(t *testing.T) { + t.Parallel() + rootDir := shortTempDir(t) + + listener, createErr := NewPrivateUnixSocketListener(rootDir, "addr-") + require.NoError(t, createErr) + defer listener.Close() + + addr := listener.Addr() + require.NotNil(t, addr) + assert.Equal(t, "unix", addr.Network()) +} + +func TestPrivateUnixSocketListenerSocketFilePermissions(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("socket file permission check not applicable on Windows") + } + + t.Parallel() + rootDir := shortTempDir(t) + + listener, createErr := NewPrivateUnixSocketListener(rootDir, "perm-") + require.NoError(t, createErr) + defer listener.Close() + + info, statErr := os.Stat(listener.SocketPath()) + require.NoError(t, statErr) + // The socket file should have 0600 permissions (best-effort). + // On some systems the kernel may adjust socket permissions, so + // we check that at minimum the group/other write bits are not set. 
+ perm := info.Mode().Perm() + assert.Zero(t, perm&0077, fmt.Sprintf("socket should not be accessible by group/others, got %o", perm)) +} diff --git a/internal/notifications/notification_source.go b/internal/notifications/notification_source.go index 5f532ac9..48f01c92 100644 --- a/internal/notifications/notification_source.go +++ b/internal/notifications/notification_source.go @@ -9,7 +9,6 @@ import ( "context" "errors" "fmt" - "net" "sync" "sync/atomic" @@ -18,6 +17,7 @@ import ( "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/emptypb" + "github.com/microsoft/dcp/internal/networking" "github.com/microsoft/dcp/internal/notifications/proto" "github.com/microsoft/dcp/pkg/concurrency" "github.com/microsoft/dcp/pkg/grpcutil" @@ -40,7 +40,7 @@ type unixSocketNotificationSource struct { lock *sync.Mutex // The Unix domain socket listener for incoming connections. - listener *net.UnixListener + listener *networking.PrivateUnixSocketListener // Subscriptions are just long-lived gRPC calls returning a stream of notifications. // Each channel gets an unbounded channel for sending notifications to the client/subscriber. 
diff --git a/internal/notifications/notifications.go b/internal/notifications/notifications.go index 04e68ca0..8599437e 100644 --- a/internal/notifications/notifications.go +++ b/internal/notifications/notifications.go @@ -10,7 +10,6 @@ import ( "errors" "fmt" "net" - "os" "path/filepath" "sync" "time" @@ -20,11 +19,10 @@ import ( "google.golang.org/grpc" "google.golang.org/protobuf/types/known/durationpb" - "github.com/microsoft/dcp/internal/dcppaths" + "github.com/microsoft/dcp/internal/networking" "github.com/microsoft/dcp/internal/notifications/proto" "github.com/microsoft/dcp/pkg/concurrency" "github.com/microsoft/dcp/pkg/grpcutil" - "github.com/microsoft/dcp/pkg/osutil" "github.com/microsoft/dcp/pkg/randdata" ) @@ -129,38 +127,14 @@ func asNotification(nd *proto.NotificationData) (Notification, error) { } } -// A helper function that ensures the notification socket can be created +// PrepareNotificationSocketPath ensures the notification socket can be created // in a folder that is writable only by the current user, and that the path // is reasonably unique to the calling process. -// If the `rootDir` is empty, it will use the user's cache directory. +// If the rootDir is empty, it will use the user's cache directory. func PrepareNotificationSocketPath(rootDir string, socketNamePrefix string) (string, error) { - if rootDir == "" { - cacheDir, cacheDirErr := os.UserCacheDir() - if cacheDirErr != nil { - return "", fmt.Errorf("failed to get user cache directory when creating a notification socket: %w", cacheDirErr) - } else { - rootDir = cacheDir - } - } - - socketDir := filepath.Join(rootDir, dcppaths.DcpWorkDir) - if err := os.MkdirAll(socketDir, osutil.PermissionOnlyOwnerReadWriteTraverse); err != nil { - return "", fmt.Errorf("failed to create directory for notification socket: %w", err) - } - - // On Windows the user cache directory always exists and is always private to the user, - // but on Unix-like systems, we need to ensure the directory is private. 
- if !osutil.IsWindows() { - info, infoErr := os.Stat(socketDir) - if infoErr != nil { - return "", fmt.Errorf("failed to check permissions on the notification socket directory: %w", infoErr) - } - if !info.IsDir() { - return "", fmt.Errorf("notification socket path %s is not a directory", socketDir) - } - if info.Mode().Perm() != osutil.PermissionOnlyOwnerReadWriteTraverse { - return "", fmt.Errorf("notification socket directory %s is not private to the user", socketDir) - } + socketDir, dirErr := networking.PreparePrivateUnixSocketDir(rootDir) + if dirErr != nil { + return "", fmt.Errorf("failed to prepare notification socket directory: %w", dirErr) } suffix, suffixErr := randdata.MakeRandomString(8) @@ -218,18 +192,22 @@ type UnixSocketNotificationSource interface { SocketPath() string } -func NewNotificationSource(lifetimeCtx context.Context, socketPath string, log logr.Logger) (UnixSocketNotificationSource, error) { - listener, listenErr := net.ListenUnix("unix", &net.UnixAddr{Name: socketPath, Net: "unix"}) - if listenErr != nil { - return nil, fmt.Errorf("could not create notification socket at %s: %w", socketPath, listenErr) +// NewNotificationSource creates a notification source that listens on the given socket path. +// The socketDir and socketNamePrefix are used to create a secure Unix domain socket via +// the shared networking library. If socketDir is empty, os.UserCacheDir() is used. +// The actual socket path (including a random suffix) can be retrieved via SocketPath(). 
+func NewNotificationSource(lifetimeCtx context.Context, socketDir string, socketNamePrefix string, log logr.Logger) (UnixSocketNotificationSource, error) { + socketListener, listenerErr := networking.NewPrivateUnixSocketListener(socketDir, socketNamePrefix) + if listenerErr != nil { + return nil, fmt.Errorf("could not create notification socket: %w", listenerErr) } ns := &unixSocketNotificationSource{ lifetimeCtx: lifetimeCtx, log: log, - socketPath: socketPath, + socketPath: socketListener.SocketPath(), lock: &sync.Mutex{}, - listener: listener, + listener: socketListener, subscriptions: make(map[uint32]*concurrency.UnboundedChan[Notification]), dispose: concurrency.NewOneTimeJob[struct{}](), clientConnected: concurrency.NewSemaphore(), @@ -241,7 +219,7 @@ func NewNotificationSource(lifetimeCtx context.Context, socketPath string, log l proto.RegisterNotificationsServer(notifyServer, ns) go func() { - serverErr := notifyServer.Serve(ns.listener) + serverErr := notifyServer.Serve(socketListener) if serverErr != nil && !errors.Is(serverErr, net.ErrClosed) { ns.log.Error(serverErr, "Notification server encountered an error") } diff --git a/internal/notifications/notifications_test.go b/internal/notifications/notifications_test.go index 1a8e84ca..7eebcfd0 100644 --- a/internal/notifications/notifications_test.go +++ b/internal/notifications/notifications_test.go @@ -32,12 +32,11 @@ func TestNotificationSendReceive(t *testing.T) { ctx, cancel := testutil.GetTestContext(t, defaultNotificationsTestTimeout) defer cancel() - socketPath, socketPathErr := PrepareNotificationSocketPath(testutil.TestTempDir(), "test-notification-socket-") - require.NoError(t, socketPathErr) - nsi, nsErr := NewNotificationSource(ctx, socketPath, sourceLog) + nsi, nsErr := NewNotificationSource(ctx, testutil.TestTempDir(), "test-notification-socket-", sourceLog) require.NoError(t, nsErr) require.NotNil(t, nsi) usns := nsi.(*unixSocketNotificationSource) + socketPath := nsi.SocketPath() const 
numNotifications = 10 notes := make(chan Notification, numNotifications) @@ -85,12 +84,11 @@ func TestNotificationMultipleReceivers(t *testing.T) { ctx, cancel := testutil.GetTestContext(t, defaultNotificationsTestTimeout) defer cancel() - socketPath, socketPathErr := PrepareNotificationSocketPath(testutil.TestTempDir(), "test-notification-socket-") - require.NoError(t, socketPathErr) - ns, err := NewNotificationSource(ctx, socketPath, testLog) - require.NoError(t, err) + ns, nsCreateErr := NewNotificationSource(ctx, testutil.TestTempDir(), "test-notification-socket-", testLog) + require.NoError(t, nsCreateErr) require.NotNil(t, ns) usns := ns.(*unixSocketNotificationSource) + socketPath := ns.SocketPath() // Start with two receivers r1Ctx, r1CtxCancel := context.WithCancel(ctx) diff --git a/internal/podman/cli_orchestrator.go b/internal/podman/cli_orchestrator.go index 189c27be..0f300a33 100644 --- a/internal/podman/cli_orchestrator.go +++ b/internal/podman/cli_orchestrator.go @@ -663,7 +663,7 @@ func (pco *PodmanCliOrchestrator) ExecContainer(ctx context.Context, options con } pco.log.V(1).Info("Running Podman command", "Command", cmd.String()) - _, _, startWaitForProcessExit, err := pco.executor.StartProcess(ctx, cmd, process.ProcessExitHandlerFunc(exitHandler), process.CreationFlagsNone) + _, startWaitForProcessExit, err := pco.executor.StartProcess(ctx, cmd, process.ProcessExitHandlerFunc(exitHandler), process.CreationFlagsNone) if err != nil { close(exitCh) return nil, errors.Join(err, fmt.Errorf("failed to start Podman command '%s'", "ExecContainer")) @@ -1088,13 +1088,13 @@ func (pco *PodmanCliOrchestrator) doWatchContainers(watcherCtx context.Context, // Container events are delivered on best-effort basis. // If the "podman events" command fails unexpectedly, we will log the error, // but we won't try to restart it. 
- pid, startTime, startWaitForProcessExit, err := pco.executor.StartProcess(watcherCtx, cmd, peh, process.CreationFlagsNone) + handle, startWaitForProcessExit, err := pco.executor.StartProcess(watcherCtx, cmd, peh, process.CreationFlagsNone) if err != nil { pco.log.Error(err, "Could not execute 'podman events' command; container events unavailable") return } - dcpproc.RunProcessWatcher(pco.executor, pid, startTime, pco.log) + dcpproc.RunProcessWatcher(pco.executor, handle, pco.log) startWaitForProcessExit() @@ -1106,7 +1106,7 @@ func (pco *PodmanCliOrchestrator) doWatchContainers(watcherCtx context.Context, } case <-watcherCtx.Done(): // We are asked to shut down - pco.log.V(1).Info("Stopping 'podman events' command", "PID", pid) + pco.log.V(1).Info("Stopping 'podman events' command", "PID", handle.Pid) } } @@ -1148,13 +1148,13 @@ func (pco *PodmanCliOrchestrator) doWatchNetworks(watcherCtx context.Context, ss // Container events are delivered on best-effort basis. // If the "podman events" command fails unexpectedly, we will log the error, // but we won't try to restart it. 
- pid, startTime, startWaitForProcessExit, err := pco.executor.StartProcess(watcherCtx, cmd, peh, process.CreationFlagsNone) + handle, startWaitForProcessExit, err := pco.executor.StartProcess(watcherCtx, cmd, peh, process.CreationFlagsNone) if err != nil { pco.log.Error(err, "Could not execute 'podman events' command; network events unavailable") return } - dcpproc.RunProcessWatcher(pco.executor, pid, startTime, pco.log) + dcpproc.RunProcessWatcher(pco.executor, handle, pco.log) startWaitForProcessExit() @@ -1166,7 +1166,7 @@ func (pco *PodmanCliOrchestrator) doWatchNetworks(watcherCtx context.Context, ss } case <-watcherCtx.Done(): // We are asked to shut down - pco.log.V(1).Info("Stopping 'podman events' command", "PID", pid) + pco.log.V(1).Info("Stopping 'podman events' command", "PID", handle.Pid) } } @@ -1201,14 +1201,14 @@ func (pco *PodmanCliOrchestrator) streamPodmanCommand( } pco.log.V(1).Info("Running podman command", "Command", cmd.String()) - pid, startTime, startWaitForProcessExit, err := pco.executor.StartProcess(ctx, cmd, process.ProcessExitHandlerFunc(exitHandler), process.CreationFlagsNone) + handle, startWaitForProcessExit, err := pco.executor.StartProcess(ctx, cmd, process.ProcessExitHandlerFunc(exitHandler), process.CreationFlagsNone) if err != nil { close(exitCh) return nil, errors.Join(err, fmt.Errorf("failed to start podman command '%s'", commandName)) } if opts&streamCommandOptionUseWatcher != 0 { - dcpproc.RunProcessWatcher(pco.executor, pid, startTime, pco.log) + dcpproc.RunProcessWatcher(pco.executor, handle, pco.log) } startWaitForProcessExit() diff --git a/internal/testutil/ctrlutil/apiserver_start.go b/internal/testutil/ctrlutil/apiserver_start.go index 5d384b2d..8d9d83da 100644 --- a/internal/testutil/ctrlutil/apiserver_start.go +++ b/internal/testutil/ctrlutil/apiserver_start.go @@ -238,14 +238,14 @@ func StartApiServer( info.ApiServerExited.SetAndFreeze() }) - apiServerPID, _, startWaitForProcessExit, dcpStartErr := 
pe.StartProcess(testRunCtx, cmd, apiserverExitHandler, process.CreationFlagsNone) + apiServerHandle, startWaitForProcessExit, dcpStartErr := pe.StartProcess(testRunCtx, cmd, apiserverExitHandler, process.CreationFlagsNone) if dcpStartErr != nil { info.ApiServerExited.SetAndFreeze() cleanup() return nil, fmt.Errorf("failed to start the API server process: %w", dcpStartErr) } startWaitForProcessExit() - info.ApiServerPID = apiServerPID + info.ApiServerPID = apiServerHandle.Pid // Using generous timeout because AzDO pipeline machines can be very slow at times. const configCreationTimeout = 70 * time.Second diff --git a/internal/testutil/test_process_executor.go b/internal/testutil/test_process_executor.go index 871f06d7..c00c40da 100644 --- a/internal/testutil/test_process_executor.go +++ b/internal/testutil/test_process_executor.go @@ -79,11 +79,11 @@ func (e *TestProcessExecutor) StartProcess( cmd *exec.Cmd, handler process.ProcessExitHandler, _ process.ProcessCreationFlag, -) (process.Pid_t, time.Time, func(), error) { +) (process.ProcessHandle, func(), error) { pid64 := atomic.AddInt64(&e.nextPID, 1) - pid, err := process.Int64_ToPidT(pid64) - if err != nil { - return process.UnknownPID, time.Time{}, nil, err + pid, pidErr := process.Int64_ToPidT(pid64) + if pidErr != nil { + return process.ProcessHandle{Pid: process.UnknownPID}, nil, pidErr } e.m.Lock() @@ -130,17 +130,18 @@ func (e *TestProcessExecutor) StartProcess( } if autoExecutionErr := e.maybeAutoExecute(&pe); autoExecutionErr != nil { - return process.UnknownPID, time.Time{}, nil, autoExecutionErr + return process.ProcessHandle{Pid: process.UnknownPID}, nil, autoExecutionErr } - return pid, startTimestamp, startWaitingForExit, nil + handle := process.NewHandle(pid, startTimestamp) + return handle, startWaitingForExit, nil } -func (e *TestProcessExecutor) StartAndForget(cmd *exec.Cmd, _ process.ProcessCreationFlag) (process.Pid_t, time.Time, error) { +func (e *TestProcessExecutor) StartAndForget(cmd 
*exec.Cmd, _ process.ProcessCreationFlag) (process.ProcessHandle, error) { pid64 := atomic.AddInt64(&e.nextPID, 1) - pid, err := process.Int64_ToPidT(pid64) - if err != nil { - return process.UnknownPID, time.Time{}, err + pid, pidErr := process.Int64_ToPidT(pid64) + if pidErr != nil { + return process.ProcessHandle{Pid: process.UnknownPID}, pidErr } e.m.Lock() @@ -166,10 +167,11 @@ func (e *TestProcessExecutor) StartAndForget(cmd *exec.Cmd, _ process.ProcessCre e.Executions = append(e.Executions, &pe) if autoExecutionErr := e.maybeAutoExecute(&pe); autoExecutionErr != nil { - return process.UnknownPID, time.Time{}, autoExecutionErr + return process.ProcessHandle{Pid: process.UnknownPID}, autoExecutionErr } - return pid, startTimestamp, nil + handle := process.NewHandle(pid, startTimestamp) + return handle, nil } func (e *TestProcessExecutor) maybeAutoExecute(pe *ProcessExecution) error { @@ -192,7 +194,7 @@ func (e *TestProcessExecutor) maybeAutoExecute(pe *ProcessExecution) error { if !stopInitiated { // RunCommand() "ended on its own" (as opposed to being triggered by StopProcess() or SimulateProcessExit()), // so we need to do the resource cleanup. 
- stopProcessErr := e.stopProcessImpl(pe.PID, pe.StartedAt, exitCode) + stopProcessErr := e.stopProcessImpl(process.NewHandle(pe.PID, pe.StartedAt), exitCode) if stopProcessErr != nil && ae.StopError == nil { panic(fmt.Errorf("we should have an execution with PID=%d: %w", pe.PID, stopProcessErr)) } @@ -208,15 +210,15 @@ func (e *TestProcessExecutor) maybeAutoExecute(pe *ProcessExecution) error { } // Called by the controller (via Executor interface) -func (e *TestProcessExecutor) StopProcess(pid process.Pid_t, processStartTime time.Time) error { - return e.stopProcessImpl(pid, processStartTime, KilledProcessExitCode) +func (e *TestProcessExecutor) StopProcess(handle process.ProcessHandle) error { + return e.stopProcessImpl(handle, KilledProcessExitCode) } // Called by tests to simulate a process exit with specific exit code. func (e *TestProcessExecutor) SimulateProcessExit(t *testing.T, pid process.Pid_t, exitCode int32) { - err := e.stopProcessImpl(pid, time.Time{}, exitCode) - if err != nil { - require.Failf(t, "invalid PID (test issue)", err.Error()) + stopErr := e.stopProcessImpl(process.ProcessHandle{Pid: pid}, exitCode) + if stopErr != nil { + require.Failf(t, "invalid PID (test issue)", stopErr.Error()) } } @@ -303,21 +305,21 @@ func (e *TestProcessExecutor) findByPid(pid process.Pid_t) int { return NotFound } -func (e *TestProcessExecutor) stopProcessImpl(pid process.Pid_t, processStartTime time.Time, exitCode int32) error { +func (e *TestProcessExecutor) stopProcessImpl(handle process.ProcessHandle, exitCode int32) error { e.m.Lock() - i := e.findByPid(pid) + i := e.findByPid(handle.Pid) if i == NotFound { e.m.Unlock() - return fmt.Errorf("no process with PID %d found", pid) + return fmt.Errorf("no process with PID %d found", handle.Pid) } - if !processStartTime.IsZero() { - if !osutil.Within(processStartTime, e.Executions[i].StartedAt, process.ProcessIdentityTimeMaximumDifference) { + if !handle.IdentityTime.IsZero() { + if 
!osutil.Within(handle.IdentityTime, e.Executions[i].StartedAt, process.ProcessIdentityTimeMaximumDifference) { e.m.Unlock() return fmt.Errorf("process start time mismatch for PID %d: expected %s, actual %s", - pid, - processStartTime.Format(osutil.RFC3339MiliTimestampFormat), + handle.Pid, + handle.IdentityTime.Format(osutil.RFC3339MiliTimestampFormat), e.Executions[i].StartedAt.Format(osutil.RFC3339MiliTimestampFormat), ) } @@ -366,7 +368,7 @@ func (e *TestProcessExecutor) stopProcessImpl(pid process.Pid_t, processStartTim case <-e.lifetimeCtx.Done(): return case <-pe.startWaitingChan: - pe.ExitHandler.OnProcessExited(pid, exitCode, nil) + pe.ExitHandler.OnProcessExited(handle.Pid, exitCode, nil) } }() } diff --git a/pkg/osutil/env_suppression.go b/pkg/osutil/env_suppression.go new file mode 100644 index 00000000..11ef520e --- /dev/null +++ b/pkg/osutil/env_suppression.go @@ -0,0 +1,62 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See LICENSE in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package osutil + +import ( + "os" + "strings" + + "github.com/microsoft/dcp/pkg/maps" +) + +// SuppressedEnvVarPrefixes is the set of environment variable prefixes that should not +// be inherited from the ambient (DCP process) environment when launching child processes +// such as Executables and debug adapters. Variables whose names start with any of these +// prefixes are removed from the inherited environment. +var SuppressedEnvVarPrefixes = []string{ + "DEBUG_SESSION", + "DCP_", +} + +// NewFilteredAmbientEnv returns a StringKeyMap populated from the current process +// environment with all variables whose names match SuppressedEnvVarPrefixes removed. 
+// The returned map uses case-insensitive keys on Windows and case-sensitive keys +// on other platforms. +// +// Callers can overlay additional environment variables on top of the returned map +// (e.g. from configuration or spec) before converting it to the final []string +// used by exec.Cmd.Env. +func NewFilteredAmbientEnv() maps.StringKeyMap[string] { + envMap := NewPlatformStringMap[string]() + + envMap.Apply(maps.SliceToMap(os.Environ(), func(envStr string) (string, string) { + parts := strings.SplitN(envStr, "=", 2) + return parts[0], parts[1] + })) + + SuppressEnvVarPrefixes(envMap) + + return envMap +} + +// NewPlatformStringMap returns a new empty StringKeyMap with the key-comparison mode +// appropriate for the current platform (case-insensitive on Windows, case-sensitive +// elsewhere). +func NewPlatformStringMap[T any]() maps.StringKeyMap[T] { + if IsWindows() { + return maps.NewStringKeyMap[T](maps.StringMapModeCaseInsensitive) + } + return maps.NewStringKeyMap[T](maps.StringMapModeCaseSensitive) +} + +// SuppressEnvVarPrefixes removes all entries from envMap whose keys start with any +// of the SuppressedEnvVarPrefixes. This can be called at any point in an environment- +// building pipeline to strip DCP-internal variables. +func SuppressEnvVarPrefixes(envMap maps.StringKeyMap[string]) { + for _, prefix := range SuppressedEnvVarPrefixes { + envMap.DeletePrefix(prefix) + } +} diff --git a/pkg/osutil/env_suppression_test.go b/pkg/osutil/env_suppression_test.go new file mode 100644 index 00000000..545d7477 --- /dev/null +++ b/pkg/osutil/env_suppression_test.go @@ -0,0 +1,102 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See LICENSE in the project root for license information. 
+ *--------------------------------------------------------------------------------------------*/ + +package osutil + +import ( + "os" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewFilteredAmbientEnvExcludesSuppressedPrefixes(t *testing.T) { + // Set environment variables that should be suppressed. + t.Setenv("DCP_TEST_VAR", "should_be_removed") + t.Setenv("DCP_ANOTHER", "also_removed") + t.Setenv("DEBUG_SESSION_ID", "removed_too") + // Set a normal variable that should survive. + t.Setenv("MY_APP_SETTING", "keep_me") + + envMap := NewFilteredAmbientEnv() + + _, hasDcpTest := envMap.Get("DCP_TEST_VAR") + assert.False(t, hasDcpTest, "DCP_TEST_VAR should be suppressed") + + _, hasDcpAnother := envMap.Get("DCP_ANOTHER") + assert.False(t, hasDcpAnother, "DCP_ANOTHER should be suppressed") + + _, hasDebugSession := envMap.Get("DEBUG_SESSION_ID") + assert.False(t, hasDebugSession, "DEBUG_SESSION_ID should be suppressed") + + val, hasAppSetting := envMap.Get("MY_APP_SETTING") + assert.True(t, hasAppSetting, "MY_APP_SETTING should be present") + assert.Equal(t, "keep_me", val) +} + +func TestNewFilteredAmbientEnvContainsNormalVars(t *testing.T) { + // PATH should always exist and not be suppressed. 
+ pathVal, found := os.LookupEnv("PATH") + if !found { + t.Skip("PATH not set in test environment") + } + + envMap := NewFilteredAmbientEnv() + + got, ok := envMap.Get("PATH") + require.True(t, ok, "PATH should be present in the filtered env") + assert.Equal(t, pathVal, got) +} + +func TestNewFilteredAmbientEnvHasNoSuppressedKeys(t *testing.T) { + t.Setenv("DCP_SOME_KEY", "value") + t.Setenv("DEBUG_SESSION_TOKEN", "value") + + envMap := NewFilteredAmbientEnv() + + for key := range envMap.Data() { + for _, prefix := range SuppressedEnvVarPrefixes { + assert.Falsef(t, strings.HasPrefix(key, prefix), + "key %q should have been suppressed (prefix %q)", key, prefix) + } + } +} + +func TestSuppressEnvVarPrefixesRemovesMatchingKeys(t *testing.T) { + envMap := NewPlatformStringMap[string]() + envMap.Set("DCP_FOO", "1") + envMap.Set("DEBUG_SESSION_BAR", "2") + envMap.Set("KEEP_ME", "3") + + SuppressEnvVarPrefixes(envMap) + + _, hasDcp := envMap.Get("DCP_FOO") + assert.False(t, hasDcp) + + _, hasDebug := envMap.Get("DEBUG_SESSION_BAR") + assert.False(t, hasDebug) + + val, hasKeep := envMap.Get("KEEP_ME") + assert.True(t, hasKeep) + assert.Equal(t, "3", val) +} + +func TestNewPlatformStringMapMode(t *testing.T) { + m := NewPlatformStringMap[string]() + m.Set("TestKey", "value") + + if IsWindows() { + // Case-insensitive: looking up with different casing should succeed. + val, ok := m.Get("testkey") + assert.True(t, ok, "expected case-insensitive lookup on Windows") + assert.Equal(t, "value", val) + } else { + // Case-sensitive: different casing should NOT match. 
+ _, ok := m.Get("testkey") + assert.False(t, ok, "expected case-sensitive lookup on non-Windows") + } +} diff --git a/pkg/process/os_executor.go b/pkg/process/os_executor.go index e7134d5c..f1727cfe 100644 --- a/pkg/process/os_executor.go +++ b/pkg/process/os_executor.go @@ -48,32 +48,29 @@ type waitState struct { reason waitReason // The reason why are waiting on the process } -type WaitKey struct { - Pid Pid_t - StartedAt time.Time -} - func (e *OSExecutor) StartProcess( ctx context.Context, cmd *exec.Cmd, handler ProcessExitHandler, flags ProcessCreationFlag, -) (Pid_t, time.Time, func(), error) { +) (ProcessHandle, func(), error) { e.acquireLock() if e.disposed { e.releaseLock() - return UnknownPID, time.Time{}, nil, ErrDisposed + return ProcessHandle{Pid: UnknownPID}, nil, ErrDisposed } e.releaseLock() - pid, processIdentityTime, err := e.startProcess(cmd, flags) - if err != nil { - return UnknownPID, time.Time{}, nil, err + handle, startProcessErr := e.startProcess(cmd, flags) + if startProcessErr != nil { + return ProcessHandle{Pid: UnknownPID}, nil, startProcessErr } + pid := handle.Pid + // Get the wait result channel, but do not actually start waiting // This also has the effect of tying the wait for this process to the command that started it. - ws, _ := e.tryStartWaiting(pid, processIdentityTime, waitableCmd{cmd, flags}, waitReasonNone) + ws, _ := e.tryStartWaiting(handle, waitableCmd{cmd, flags}, waitReasonNone) // Start the goroutine that waits for the context to expire. 
go func() { @@ -88,7 +85,7 @@ func (e *OSExecutor) StartProcess( } case <-ctx.Done(): - _, shouldStopProcess := e.tryStartWaiting(pid, processIdentityTime, waitableCmd{cmd, flags}, waitReasonStopping) + _, shouldStopProcess := e.tryStartWaiting(handle, waitableCmd{cmd, flags}, waitReasonStopping) var stopProcessErr error = nil if shouldStopProcess { @@ -99,7 +96,7 @@ func (e *OSExecutor) StartProcess( "Args", cmd.Args[1:], ) log.Info("Context expired, stopping process...") - stopProcessErr = e.stopProcessInternal(pid, processIdentityTime, optIsResponsibleForStopping) + stopProcessErr = e.stopProcessInternal(handle, optIsResponsibleForStopping) if stopProcessErr != nil { log.Error(stopProcessErr, "Could not stop process upon context expiration") if handler != nil { @@ -122,28 +119,28 @@ func (e *OSExecutor) StartProcess( }() startWaitingForProcessExit := func() { - _, _ = e.tryStartWaiting(pid, processIdentityTime, waitableCmd{cmd, flags}, waitReasonMonitoring) + _, _ = e.tryStartWaiting(handle, waitableCmd{cmd, flags}, waitReasonMonitoring) } - return pid, processIdentityTime, startWaitingForProcessExit, nil + return handle, startWaitingForProcessExit, nil } -func (e *OSExecutor) StartAndForget(cmd *exec.Cmd, flags ProcessCreationFlag) (Pid_t, time.Time, error) { +func (e *OSExecutor) StartAndForget(cmd *exec.Cmd, flags ProcessCreationFlag) (ProcessHandle, error) { e.acquireLock() if e.disposed { e.releaseLock() - return UnknownPID, time.Time{}, ErrDisposed + return ProcessHandle{Pid: UnknownPID}, ErrDisposed } e.releaseLock() - pid, processStartTime, err := e.startProcess(cmd, flags) - if err != nil { - return UnknownPID, time.Time{}, err + handle, startProcessErr := e.startProcess(cmd, flags) + if startProcessErr != nil { + return ProcessHandle{Pid: UnknownPID}, startProcessErr } if cmd.Process == nil { e.log.V(1).Info("Process info is not available after successful start???", - "PID", pid, + "PID", handle.Pid, "Command", cmd.Path, "Args", cmd.Args[1:], ) @@ 
-155,10 +152,10 @@ func (e *OSExecutor) StartAndForget(cmd *exec.Cmd, flags ProcessCreationFlag) (P }(cmd.Process) } - return pid, processStartTime, nil + return handle, nil } -func (e *OSExecutor) StopProcess(pid Pid_t, processStartTime time.Time) error { +func (e *OSExecutor) StopProcess(handle ProcessHandle) error { e.acquireLock() if e.disposed { e.releaseLock() @@ -166,15 +163,15 @@ func (e *OSExecutor) StopProcess(pid Pid_t, processStartTime time.Time) error { } e.releaseLock() - return e.stopProcessInternal(pid, processStartTime, optNone) + return e.stopProcessInternal(handle, optNone) } -// Returns the PID, process identity time (to distinguish between process instances with the same PID), and error. -func (e *OSExecutor) startProcess(cmd *exec.Cmd, flags ProcessCreationFlag) (Pid_t, time.Time, error) { +// Returns a ProcessHandle identifying the started process, or an error. +func (e *OSExecutor) startProcess(cmd *exec.Cmd, flags ProcessCreationFlag) (ProcessHandle, error) { e.prepareProcessStart(cmd, flags) if err := cmd.Start(); err != nil { - return UnknownPID, time.Time{}, err + return ProcessHandle{Pid: UnknownPID}, err } osPid := cmd.Process.Pid @@ -187,23 +184,23 @@ func (e *OSExecutor) startProcess(cmd *exec.Cmd, flags ProcessCreationFlag) (Pid "CreationFlags", flags, ) - processIdentityTime := ProcessIdentityTime(pid) + handle := NewHandle(pid, ProcessIdentityTime(pid)) - startCompletionErr := e.completeProcessStart(cmd, pid, processIdentityTime, flags) + startCompletionErr := e.completeProcessStart(cmd, handle, flags) if startCompletionErr != nil { startLog.Error(startCompletionErr, "Could not complete process start") // If we could not complete the process start, we need to stop the process. // Do not try graceful stop (no optTrySignal), just kill it immediately. 
- if stopErr := e.stopProcessInternal(pid, processIdentityTime, optIsResponsibleForStopping); stopErr != nil { + if stopErr := e.stopProcessInternal(handle, optIsResponsibleForStopping); stopErr != nil { startLog.Error(stopErr, "Could not stop process after failed start") } - return UnknownPID, time.Time{}, fmt.Errorf("could not complete process start: %w", startCompletionErr) + return ProcessHandle{Pid: UnknownPID}, fmt.Errorf("could not complete process start: %w", startCompletionErr) } startLog.V(1).Info("Process started successfully", "PID", pid) - return pid, processIdentityTime, nil + return handle, nil } // Atomically starts waiting on the passed waitable if noting is already waiting in association with the process @@ -212,11 +209,11 @@ func (e *OSExecutor) startProcess(cmd *exec.Cmd, flags ProcessCreationFlag) (Pid // Returns the waitState object associated with the process, and a boolean indicating whether the caller // is the first one to indicate that the reason for the wait is "stopping the process", // and thus IT is the caller that must stop the process. 
-func (e *OSExecutor) tryStartWaiting(pid Pid_t, startTime time.Time, waitable Waitable, reason waitReason) (*waitState, bool) { +func (e *OSExecutor) tryStartWaiting(handle ProcessHandle, waitable Waitable, reason waitReason) (*waitState, bool) { e.acquireLock() defer e.releaseLock() - ws, found := e.procsWaiting[WaitKey{pid, startTime}] + ws, found := e.procsWaiting[handle] callerShouldStopProcess := false if found { @@ -230,7 +227,7 @@ func (e *OSExecutor) tryStartWaiting(pid Pid_t, startTime time.Time, waitable Wa mustStartWaiting := ws.reason == waitReasonNone && reason != waitReasonNone ws.reason |= reason if mustStartWaiting { - go e.doWait(ws, waitable, pid) + go e.doWait(ws, waitable, handle.Pid) } } else { callerShouldStopProcess = (reason & waitReasonStopping) != 0 @@ -239,9 +236,9 @@ func (e *OSExecutor) tryStartWaiting(pid Pid_t, startTime time.Time, waitable Wa waitEndedCh: make(chan struct{}), reason: reason, } - e.procsWaiting[WaitKey{pid, startTime}] = ws + e.procsWaiting[handle] = ws if reason != waitReasonNone { - go e.doWait(ws, waitable, pid) + go e.doWait(ws, waitable, handle.Pid) } } @@ -285,7 +282,7 @@ func (e *OSExecutor) acquireLock() { } // Only keep wait states that correspond to processes that are still running, or the ones that completed recently - e.procsWaiting = maps.Select(e.procsWaiting, func(_ WaitKey, ws *waitState) bool { + e.procsWaiting = maps.Select(e.procsWaiting, func(_ ProcessHandle, ws *waitState) bool { return ws.waitEnded.IsZero() || time.Since(ws.waitEnded) < maxCompletedDuration }) } @@ -294,16 +291,16 @@ func (e *OSExecutor) releaseLock() { e.lock.Unlock() } -func (e *OSExecutor) stopProcessInternal(pid Pid_t, processStartTime time.Time, opts processStoppingOpts) error { - tree, err := GetProcessTree(ProcessTreeItem{pid, processStartTime}) - if err != nil { - return fmt.Errorf("could not get process tree for process %d: %w", pid, err) +func (e *OSExecutor) stopProcessInternal(handle ProcessHandle, opts 
processStoppingOpts) error { + tree, treeErr := GetProcessTree(handle) + if treeErr != nil { + return fmt.Errorf("could not get process tree for process %d: %w", handle.Pid, treeErr) } - procTreeLog := e.log.WithValues("Root", pid) - procTreeLog.V(1).Info("Stopping process tree...", "Root", pid, "Tree", getIDs(tree)) + procTreeLog := e.log.WithValues("Root", handle.Pid) + procTreeLog.V(1).Info("Stopping process tree...", "Root", handle.Pid, "Tree", getIDs(tree)) - procEndedCh, stopErr := e.stopSingleProcess(pid, processStartTime, opts|optNotFoundIsError|optTrySignal|optWaitForStdio) + procEndedCh, stopErr := e.stopSingleProcess(handle, opts|optNotFoundIsError|optTrySignal|optWaitForStdio) if stopErr != nil && !errors.Is(stopErr, ErrTimedOutWaitingForProcessToStop) { // If the root process cannot be stopped (and it is not just a timeout error), don't bother with the rest of the tree. procTreeLog.Error(stopErr, "Could not stop root process") @@ -336,7 +333,7 @@ func (e *OSExecutor) stopProcessInternal(pid Pid_t, processStartTime time.Time, } procTreeLog.V(1).Info("Make sure children of the root processes are gone...") - childStoppingErrors := slices.MapConcurrent[error](tree, func(p ProcessTreeItem) error { + childStoppingErrors := slices.MapConcurrent[error](tree, func(p ProcessHandle) error { // Retry stopping the child process as we occasionally see transient "Access Denied" errors. 
const childStopTimeout = 2 * time.Second childLog := procTreeLog.WithValues("Child", p.Pid) @@ -344,7 +341,7 @@ func (e *OSExecutor) stopProcessInternal(pid Pid_t, processStartTime time.Time, retryErr := resiliency.RetryExponentialWithTimeout(context.Background(), childStopTimeout, func() error { childLog.V(1).Info("Stopping child process...") - _, childStopErr := e.stopSingleProcess(p.Pid, p.IdentityTime, opts&^optNotFoundIsError) + _, childStopErr := e.stopSingleProcess(p, opts&^optNotFoundIsError) if childStopErr != nil { childLog.V(1).Info("Error stopping child process", "Error", childStopErr.Error()) } else { @@ -355,7 +352,7 @@ func (e *OSExecutor) stopProcessInternal(pid Pid_t, processStartTime time.Time, }) if retryErr != nil { - childLog.Error(err, "Could not stop child process") + childLog.Error(retryErr, "Could not stop child process") } return retryErr @@ -427,13 +424,13 @@ func (e *OSExecutor) Dispose() { if flags&CreationFlagEnsureKillOnDispose == CreationFlagEnsureKillOnDispose { // Best effort to stop the process. e.log.V(1).Info("Stopping process during executor disposal...", "PID", wk.Pid, "Command", waitable.Info()) - stopErr := e.stopProcessInternal(wk.Pid, wk.StartedAt, optIsResponsibleForStopping|optTrySignal) + stopErr := e.stopProcessInternal(wk, optIsResponsibleForStopping|optTrySignal) if stopErr != nil { e.log.Error(stopErr, "Could not stop process during executor disposal", "PID", wk.Pid, "Command", waitable.Info()) } } else { // Just make sure we called wait() so the process does not become a zombie. 
- _, _ = e.tryStartWaiting(wk.Pid, wk.StartedAt, waitable, waitReasonMonitoring) + _, _ = e.tryStartWaiting(wk, waitable, waitReasonMonitoring) } }() } diff --git a/pkg/process/os_executor_unix.go b/pkg/process/os_executor_unix.go index d5a3c720..8b0317a1 100644 --- a/pkg/process/os_executor_unix.go +++ b/pkg/process/os_executor_unix.go @@ -1,10 +1,10 @@ +//go:build !windows + /*--------------------------------------------------------------------------------------------- * Copyright (c) Microsoft Corporation. All rights reserved. * Licensed under the MIT License. See LICENSE in the project root for license information. *--------------------------------------------------------------------------------------------*/ -//go:build !windows - // Copyright (c) Microsoft Corporation. All rights reserved. package process @@ -27,7 +27,7 @@ const ( ) type OSExecutor struct { - procsWaiting map[WaitKey]*waitState + procsWaiting map[ProcessHandle]*waitState disposed bool lock sync.Locker log logr.Logger @@ -35,33 +35,33 @@ type OSExecutor struct { func NewOSExecutor(log logr.Logger) Executor { return &OSExecutor{ - procsWaiting: make(map[WaitKey]*waitState), + procsWaiting: make(map[ProcessHandle]*waitState), disposed: false, lock: &sync.Mutex{}, log: log.WithName("os-executor"), } } -func (e *OSExecutor) stopSingleProcess(pid Pid_t, processStartTime time.Time, opts processStoppingOpts) (<-chan struct{}, error) { - proc, err := FindProcess(pid, processStartTime) +func (e *OSExecutor) stopSingleProcess(handle ProcessHandle, opts processStoppingOpts) (<-chan struct{}, error) { + proc, err := FindProcess(handle) if err != nil { e.acquireLock() alreadyEnded := false - ws, found := e.procsWaiting[WaitKey{pid, processStartTime}] + ws, found := e.procsWaiting[handle] if found { alreadyEnded = !ws.waitEnded.IsZero() } e.releaseLock() if (opts&optNotFoundIsError) != 0 && !alreadyEnded { - return nil, ErrProcessNotFound{Pid: pid, Inner: err} + return nil, ErrProcessNotFound{Pid: 
handle.Pid, Inner: err} } else { return makeClosedChan(), nil } } - waitable := makeWaitable(pid, proc) - ws, shouldStopProcess := e.tryStartWaiting(pid, processStartTime, waitable, waitReasonStopping) + waitable := makeWaitable(handle.Pid, proc) + ws, shouldStopProcess := e.tryStartWaiting(handle, waitable, waitReasonStopping) waitEndedCh := ws.waitEndedCh if opts&optWaitForStdio == 0 { @@ -79,22 +79,22 @@ func (e *OSExecutor) stopSingleProcess(pid Pid_t, processStartTime time.Time, op err = e.signalAndWaitForExit(proc, syscall.SIGTERM, ws) switch { case err == nil: - e.log.V(1).Info("Process stopped by SIGTERM", "PID", pid) + e.log.V(1).Info("Process stopped by SIGTERM", "PID", handle.Pid) return waitEndedCh, nil case !errors.Is(err, ErrTimedOutWaitingForProcessToStop): return nil, err default: - e.log.V(1).Info("Process did not stop upon SIGTERM", "PID", pid) + e.log.V(1).Info("Process did not stop upon SIGTERM", "PID", handle.Pid) } } - e.log.V(1).Info("Sending SIGKILL to process...", "PID", pid) + e.log.V(1).Info("Sending SIGKILL to process...", "PID", handle.Pid) err = e.signalAndWaitForExit(proc, syscall.SIGKILL, ws) if err != nil { return nil, err } - e.log.V(1).Info("Process stopped by SIGKILL", "PID", pid) + e.log.V(1).Info("Process stopped by SIGKILL", "PID", handle.Pid) return waitEndedCh, nil } @@ -132,7 +132,7 @@ func (e *OSExecutor) prepareProcessStart(_ *exec.Cmd, _ ProcessCreationFlag) { // No additional preparation needed for Unix-like systems. } -func (e *OSExecutor) completeProcessStart(_ *exec.Cmd, _ Pid_t, _ time.Time, _ ProcessCreationFlag) error { +func (e *OSExecutor) completeProcessStart(_ *exec.Cmd, _ ProcessHandle, _ ProcessCreationFlag) error { // No additional actions needed on process start for Unix-like systems. 
return nil } diff --git a/pkg/process/os_executor_windows.go b/pkg/process/os_executor_windows.go index 9eb1957a..99d20f7a 100644 --- a/pkg/process/os_executor_windows.go +++ b/pkg/process/os_executor_windows.go @@ -1,10 +1,10 @@ +//go:build windows + /*--------------------------------------------------------------------------------------------- * Copyright (c) Microsoft Corporation. All rights reserved. * Licensed under the MIT License. See LICENSE in the project root for license information. *--------------------------------------------------------------------------------------------*/ -//go:build windows - // Copyright (c) Microsoft Corporation. All rights reserved. package process @@ -36,7 +36,7 @@ var ( ) type OSExecutor struct { - procsWaiting map[WaitKey]*waitState + procsWaiting map[ProcessHandle]*waitState lock sync.Locker disposed bool log logr.Logger @@ -45,7 +45,7 @@ type OSExecutor struct { func NewOSExecutor(log logr.Logger) Executor { e := &OSExecutor{ - procsWaiting: make(map[WaitKey]*waitState), + procsWaiting: make(map[ProcessHandle]*waitState), lock: &sync.Mutex{}, disposed: false, log: log.WithName("os-executor"), @@ -54,26 +54,26 @@ func NewOSExecutor(log logr.Logger) Executor { return e } -func (e *OSExecutor) stopSingleProcess(pid Pid_t, processStartTime time.Time, opts processStoppingOpts) (<-chan struct{}, error) { - proc, err := FindProcess(pid, processStartTime) +func (e *OSExecutor) stopSingleProcess(handle ProcessHandle, opts processStoppingOpts) (<-chan struct{}, error) { + proc, err := FindProcess(handle) if err != nil { e.acquireLock() alreadyEnded := false - ws, found := e.procsWaiting[WaitKey{pid, processStartTime}] + ws, found := e.procsWaiting[handle] if found { alreadyEnded = !ws.waitEnded.IsZero() } e.releaseLock() if (opts&optNotFoundIsError) != 0 && !alreadyEnded { - return nil, ErrProcessNotFound{Pid: pid, Inner: err} + return nil, ErrProcessNotFound{Pid: handle.Pid, Inner: err} } else { return makeClosedChan(), nil } } - 
waitable := makeWaitable(pid, proc) - ws, shouldStopProcess := e.tryStartWaiting(pid, processStartTime, waitable, waitReasonStopping) + waitable := makeWaitable(handle.Pid, proc) + ws, shouldStopProcess := e.tryStartWaiting(handle, waitable, waitReasonStopping) waitEndedCh := ws.waitEndedCh if opts&optWaitForStdio == 0 { @@ -90,22 +90,22 @@ func (e *OSExecutor) stopSingleProcess(pid Pid_t, processStartTime time.Time, op err = e.signalAndWaitForExit(proc, windows.CTRL_BREAK_EVENT, ws) switch { case err == nil: - e.log.V(1).Info("Process stopped by CTRL_BREAK_EVENT", "PID", pid) + e.log.V(1).Info("Process stopped by CTRL_BREAK_EVENT", "PID", handle.Pid) return waitEndedCh, nil case !errors.Is(err, ErrTimedOutWaitingForProcessToStop): return nil, err default: - e.log.V(1).Info("Process did not stop upon CTRL_BREAK_EVENT", "PID", pid) + e.log.V(1).Info("Process did not stop upon CTRL_BREAK_EVENT", "PID", handle.Pid) } } - e.log.V(1).Info("Sending SIGKILL to process...", "PID", pid) + e.log.V(1).Info("Sending SIGKILL to process...", "PID", handle.Pid) err = proc.Kill() if err != nil && !errors.Is(err, os.ErrProcessDone) { return nil, err } - e.log.V(1).Info("Process stopped by SIGKILL", "PID", pid) + e.log.V(1).Info("Process stopped by SIGKILL", "PID", handle.Pid) return waitEndedCh, nil } @@ -162,7 +162,7 @@ func (e *OSExecutor) prepareProcessStart(cmd *exec.Cmd, flags ProcessCreationFla } } -func (e *OSExecutor) completeProcessStart(_ *exec.Cmd, pid Pid_t, _ time.Time, flags ProcessCreationFlag) error { +func (e *OSExecutor) completeProcessStart(_ *exec.Cmd, handle ProcessHandle, flags ProcessCreationFlag) error { if cleanupJobDisabled() || (flags&CreationFlagEnsureKillOnDispose) == 0 { return nil } @@ -180,9 +180,9 @@ func (e *OSExecutor) completeProcessStart(_ *exec.Cmd, pid Pid_t, _ time.Time, f // The AssignProcessToJobObject docs say PROCESS_TERMINATE and PROCESS_SET_QUOTA are sufficient to assign a process to a job object, // but in practice we need 
PROCESS_ALL_ACCESS to make it work. const access = windows.PROCESS_ALL_ACCESS - processHandle, processHandleErr := windows.OpenProcess(access, false, uint32(pid)) + processHandle, processHandleErr := windows.OpenProcess(access, false, uint32(handle.Pid)) if processHandleErr != nil { - e.log.V(1).Info("Could not open new process handle", "PID", pid, "Error", processHandleErr) + e.log.V(1).Info("Could not open new process handle", "PID", handle.Pid, "Error", processHandleErr) } else { defer tryCloseHandle(processHandle) @@ -192,15 +192,15 @@ func (e *OSExecutor) completeProcessStart(_ *exec.Cmd, pid Pid_t, _ time.Time, f jobAssignmentErr := windows.AssignProcessToJobObject(pcj, processHandle) if jobAssignmentErr != nil { - e.log.V(1).Info("Could not assign process to job object", "PID", pid, "Error", jobAssignmentErr) + e.log.V(1).Info("Could not assign process to job object", "PID", handle.Pid, "Error", jobAssignmentErr) } } } - resumptionErr := resumeNewSuspendedProcess(uint32(pid)) + resumptionErr := resumeNewSuspendedProcess(uint32(handle.Pid)) if resumptionErr != nil { - e.log.Error(resumptionErr, "Could not resume new suspended process", "PID", pid) - return fmt.Errorf("could not resume new suspended process with pid %d: %w", pid, resumptionErr) + e.log.Error(resumptionErr, "Could not resume new suspended process", "PID", handle.Pid) + return fmt.Errorf("could not resume new suspended process with pid %d: %w", handle.Pid, resumptionErr) } return nil diff --git a/pkg/process/process_handle.go b/pkg/process/process_handle.go new file mode 100644 index 00000000..35556c78 --- /dev/null +++ b/pkg/process/process_handle.go @@ -0,0 +1,62 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See LICENSE in the project root for license information. 
+ *--------------------------------------------------------------------------------------------*/ + +package process + +import ( + "os" + "os/exec" + "time" +) + +// ProcessHandle is a compound type representing a reference to a process. +// It holds the process ID and its identity time (used to distinguish between +// different instances of processes with the same PID after PID reuse). +// +// The IdentityTime may not be a valid wall-clock time on all platforms; on Linux +// it is expressed as ticks since boot to avoid issues with system clock changes. +// +// ProcessHandle is a value type and is safe to use as a map key. +type ProcessHandle struct { + Pid Pid_t + IdentityTime time.Time +} + +// NewHandle creates a ProcessHandle from a PID and an identity time. +func NewHandle(pid Pid_t, identityTime time.Time) ProcessHandle { + return ProcessHandle{ + Pid: pid, + IdentityTime: identityTime, + } +} + +// ProcessHandleFromCmd creates a ProcessHandle from a started exec.Cmd. +// The command must have been started (cmd.Process must be non-nil). +// The identity time is obtained via ProcessIdentityTime for stability across clock changes. +func ProcessHandleFromCmd(cmd *exec.Cmd) ProcessHandle { + if cmd.Process == nil { + return ProcessHandle{Pid: UnknownPID} + } + + pid := Uint32_ToPidT(uint32(cmd.Process.Pid)) + return ProcessHandle{ + Pid: pid, + IdentityTime: ProcessIdentityTime(pid), + } +} + +// ProcessHandleFromProcess creates a ProcessHandle from a running os.Process. +// The identity time is obtained via ProcessIdentityTime for stability across clock changes. 
+func ProcessHandleFromProcess(p *os.Process) ProcessHandle { + if p == nil { + return ProcessHandle{Pid: UnknownPID} + } + + pid := Uint32_ToPidT(uint32(p.Pid)) + return ProcessHandle{ + Pid: pid, + IdentityTime: ProcessIdentityTime(pid), + } +} diff --git a/pkg/process/process_handle_test.go b/pkg/process/process_handle_test.go new file mode 100644 index 00000000..2d09d707 --- /dev/null +++ b/pkg/process/process_handle_test.go @@ -0,0 +1,37 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See LICENSE in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package process + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestProcessHandle_Comparable(t *testing.T) { + t.Parallel() + + now := time.Now() + h1 := NewHandle(Uint32_ToPidT(100), now) + h2 := NewHandle(Uint32_ToPidT(100), now) + h3 := NewHandle(Uint32_ToPidT(200), now) + + assert.Equal(t, h1, h2) + assert.NotEqual(t, h1, h3) + + // Verify zero-value handle doesn't equal a handle with actual values + zeroHandle := ProcessHandle{Pid: UnknownPID} + assert.NotEqual(t, zeroHandle, h1) + + // Verify usable as map key (replaces WaitKey) + m := map[ProcessHandle]string{ + h1: "first", + h3: "second", + } + assert.Equal(t, "first", m[h2]) + assert.Equal(t, "second", m[h3]) +} diff --git a/pkg/process/process_test.go b/pkg/process/process_test.go index e22e4e1d..925a1762 100644 --- a/pkg/process/process_test.go +++ b/pkg/process/process_test.go @@ -197,7 +197,7 @@ func TestRunCancelled(t *testing.T) { ctx, cancelFn := context.WithCancel(context.Background()) go func() { - _, _, startWaitForExit, processStartErr := executor.StartProcess(ctx, cmd, onProcessExited, process.CreationFlagsNone) + _, startWaitForExit, processStartErr := 
executor.StartProcess(ctx, cmd, onProcessExited, process.CreationFlagsNone) startupNotification := process.NewProcessExitInfo() if processStartErr != nil { startupNotification.Err = processStartErr @@ -245,19 +245,17 @@ func TestChildrenTerminated(t *testing.T) { return process.ProcessTreeItem{pid, identityTime} }}, {"executor start, no wait", func(t *testing.T, cmd *exec.Cmd, e process.Executor) process.ProcessTreeItem { - pid, _, _, err := e.StartProcess(context.Background(), cmd, nil, process.CreationFlagsNone) + handle, _, err := e.StartProcess(context.Background(), cmd, nil, process.CreationFlagsNone) require.NoError(t, err, "could not start the 'delay' test program") - identityTime := process.ProcessIdentityTime(pid) - require.False(t, identityTime.IsZero(), "process identity time should not be zero") - return process.ProcessTreeItem{pid, identityTime} + require.False(t, handle.IdentityTime.IsZero(), "process identity time should not be zero") + return process.ProcessTreeItem{handle.Pid, handle.IdentityTime} }}, {"executor start with wait", func(t *testing.T, cmd *exec.Cmd, e process.Executor) process.ProcessTreeItem { - pid, _, startWaitForProcessExit, err := e.StartProcess(context.Background(), cmd, nil, process.CreationFlagsNone) + handle, startWaitForProcessExit, err := e.StartProcess(context.Background(), cmd, nil, process.CreationFlagsNone) require.NoError(t, err, "could not start the 'delay' test program") startWaitForProcessExit() - identityTime := process.ProcessIdentityTime(pid) - require.False(t, identityTime.IsZero(), "process identity time should not be zero") - return process.ProcessTreeItem{pid, identityTime} + require.False(t, handle.IdentityTime.IsZero(), "process identity time should not be zero") + return process.ProcessTreeItem{handle.Pid, handle.IdentityTime} }}, } @@ -289,7 +287,7 @@ func TestChildrenTerminated(t *testing.T) { processTree, err := process.GetProcessTree(rootP) require.NoError(t, err) - err = 
executor.StopProcess(rootP.Pid, rootP.IdentityTime) + err = executor.StopProcess(process.NewHandle(rootP.Pid, rootP.IdentityTime)) require.NoError(t, err) // Wait up to 10 seconds for all processes to exit. This guarantees that the test will only pass if StopProcess() @@ -313,7 +311,7 @@ func TestChildrenTerminatedOnDispose(t *testing.T) { cmd.Dir = delayToolDir processExited := make(chan struct{}) - _, _, startWaitForProcessExit, startErr := executor.StartProcess( + _, startWaitForProcessExit, startErr := executor.StartProcess( context.Background(), cmd, process.ProcessExitHandlerFunc(func(_ process.Pid_t, _ int32, err error) { @@ -353,7 +351,7 @@ func TestWatchCatchesProcessExit(t *testing.T) { require.NoError(t, err) pid := process.Uint32_ToPidT(uint32(cmd.Process.Pid)) - delayProc, err := process.FindWaitableProcess(pid, time.Time{}) + delayProc, err := process.FindWaitableProcess(process.NewHandle(pid, time.Time{})) require.NoError(t, err) err = delayProc.Wait(ctx) @@ -380,7 +378,7 @@ func TestContextCancelsWatch(t *testing.T) { require.NoError(t, err, "command should start without error") pid := process.Uint32_ToPidT(uint32(cmd.Process.Pid)) - delayProc, err := process.FindWaitableProcess(pid, time.Time{}) + delayProc, err := process.FindWaitableProcess(process.NewHandle(pid, time.Time{})) require.NoError(t, err, "find process should succeed without error") waitCtx, waitCancel := context.WithTimeout(context.Background(), time.Second*5) diff --git a/pkg/process/process_types.go b/pkg/process/process_types.go index cfbe8faa..f5878b43 100644 --- a/pkg/process/process_types.go +++ b/pkg/process/process_types.go @@ -10,7 +10,6 @@ import ( "errors" "fmt" "os/exec" - "time" ) const ( @@ -43,23 +42,23 @@ type Pid_t int64 type Executor interface { // Starts the process described by given command instance. // When the passed context is cancelled, the process is automatically terminated. 
- // Returns the process PID, process start time, and a function that enables process exit notifications + // Returns a ProcessHandle identifying the started process and a function that enables process exit notifications // delivered to the exit handler. StartProcess( ctx context.Context, cmd *exec.Cmd, exitHandler ProcessExitHandler, creationFlags ProcessCreationFlag, - ) (pid Pid_t, startTime time.Time, startWaitForProcessExit func(), err error) + ) (handle ProcessHandle, startWaitForProcessExit func(), err error) - // Stops the process with a given PID. - // The processStartTime, if provided (time.IsZero() returns false), is used to further validate the process to be stopped. + // Stops the process identified by the given ProcessHandle. + // The handle's IdentityTime, if provided (time.IsZero() returns false), is used to further validate the process to be stopped // (to protect against stopping a wrong process, if the PID was reused). - StopProcess(pid Pid_t, processStartTime time.Time) error + StopProcess(handle ProcessHandle) error // Starts a process that does not need to be tracked (the caller is not interested in its exit code), // minimizing resource usage. An error is returned if the process could not be started. - StartAndForget(cmd *exec.Cmd, creationFlags ProcessCreationFlag) (pid Pid_t, startTime time.Time, err error) + StartAndForget(cmd *exec.Cmd, creationFlags ProcessCreationFlag) (handle ProcessHandle, err error) // Disposes the executor. Processes started with CreationFlagEnsureKillOnDispose will be terminated. // Other processes will be waited on (so that they do not become zombies), but not terminated. 
diff --git a/pkg/process/process_unix_test.go b/pkg/process/process_unix_test.go index e6b5d9b3..88071743 100644 --- a/pkg/process/process_unix_test.go +++ b/pkg/process/process_unix_test.go @@ -1,10 +1,10 @@ +//go:build !windows + /*--------------------------------------------------------------------------------------------- * Copyright (c) Microsoft Corporation. All rights reserved. * Licensed under the MIT License. See LICENSE in the project root for license information. *--------------------------------------------------------------------------------------------*/ -//go:build !windows - // Copyright (c) Microsoft Corporation. All rights reserved. package process_test @@ -56,7 +56,7 @@ func TestStopProcessIgnoreSigterm(t *testing.T) { executor := process.NewOSExecutor(log) start := time.Now() - err = executor.StopProcess(pid, time.Time{}) + err = executor.StopProcess(process.NewHandle(pid, time.Time{})) require.NoError(t, err) elapsed := time.Since(start) elapsedStr := osutil.FormatDuration(elapsed) diff --git a/pkg/process/process_util.go b/pkg/process/process_util.go index f773c77e..4b19c144 100644 --- a/pkg/process/process_util.go +++ b/pkg/process/process_util.go @@ -24,41 +24,41 @@ import ( "github.com/microsoft/dcp/pkg/slices" ) -type ProcessTreeItem struct { - Pid Pid_t - IdentityTime time.Time // Used to distinguish between different instances of processes with the same PID, may not be a valid wall-clock time. -} +// ProcessTreeItem is an alias for ProcessHandle, retained for backward compatibility. +// +// Deprecated: Use ProcessHandle directly. +type ProcessTreeItem = ProcessHandle var ( - This func() (ProcessTreeItem, error) + This func() (ProcessHandle, error) // Essentially the same as ps.ErrorProcessNotRunning, but we do not want to // expose the ps package outside of this package. 
ErrorProcessNotFound = errors.New("process does not exist") ) -func getIDs(items []ProcessTreeItem) []Pid_t { - return slices.Map[Pid_t](items, func(item ProcessTreeItem) Pid_t { +func getIDs(items []ProcessHandle) []Pid_t { + return slices.Map[Pid_t](items, func(item ProcessHandle) Pid_t { return item.Pid }) } // Returns the list of ID for a given process and its children // The list is ordered starting with the root of the hierarchy, then the children, then the grandchildren etc. -func GetProcessTree(rootP ProcessTreeItem) ([]ProcessTreeItem, error) { - root, err := findPsProcess(rootP.Pid, rootP.IdentityTime) +func GetProcessTree(rootP ProcessHandle) ([]ProcessHandle, error) { + root, err := findPsProcess(rootP) if err != nil { return nil, err } - tree := []ProcessTreeItem{} + tree := []ProcessHandle{} next := []*ps.Process{root} for len(next) > 0 { current := next[0] next = next[1:] nextPid := Uint32_ToPidT(uint32(current.Pid)) - tree = append(tree, ProcessTreeItem{nextPid, processIdentityTime(current)}) + tree = append(tree, ProcessHandle{nextPid, processIdentityTime(current)}) children, childrenErr := current.Children() if childrenErr != nil { @@ -82,9 +82,9 @@ func RunToCompletion(ctx context.Context, executor Executor, cmd *exec.Cmd) (int pic := make(chan ProcessExitInfo, 1) peh := NewChannelProcessExitHandler(pic) - _, _, startWaitForProcessExit, err := executor.StartProcess(ctx, cmd, peh, CreationFlagsNone) - if err != nil { - return UnknownExitCode, err + _, startWaitForProcessExit, startProcessErr := executor.StartProcess(ctx, cmd, peh, CreationFlagsNone) + if startProcessErr != nil { + return UnknownExitCode, startProcessErr } startWaitForProcessExit() @@ -158,8 +158,8 @@ func ProcessIdentityTime(pid Pid_t) time.Time { return processIdentityTime(proc) } -func findPsProcess(pid Pid_t, expectedIdentityTime time.Time) (*ps.Process, error) { - osPid, err := PidT_ToUint32(pid) +func findPsProcess(handle ProcessHandle) (*ps.Process, error) { + osPid, err := 
PidT_ToUint32(handle.Pid) if err != nil { return nil, err } @@ -170,17 +170,17 @@ func findPsProcess(pid Pid_t, expectedIdentityTime time.Time) (*ps.Process, erro if !errors.Is(procErr, ps.ErrorProcessNotRunning) { return nil, procErr } else { - return nil, fmt.Errorf("process with pid %d does not exist: %w", pid, ErrorProcessNotFound) + return nil, fmt.Errorf("process with pid %d does not exist: %w", handle.Pid, ErrorProcessNotFound) } } - if !HasExpectedIdentityTime(proc, expectedIdentityTime) { + if !HasExpectedIdentityTime(proc, handle.IdentityTime) { actualIdentityTime := processIdentityTime(proc) return nil, fmt.Errorf( "process start time mismatch, pid might have been reused: pid %d, expected start time %s, actual start time %s", - pid, - expectedIdentityTime.Format(osutil.RFC3339MiliTimestampFormat), + handle.Pid, + handle.IdentityTime.Format(osutil.RFC3339MiliTimestampFormat), actualIdentityTime.Format(osutil.RFC3339MiliTimestampFormat), ) } @@ -188,17 +188,17 @@ func findPsProcess(pid Pid_t, expectedIdentityTime time.Time) (*ps.Process, erro return proc, nil } -// Returns the process with the given PID. If the expectedStartTime is not zero, +// Returns the process with the given PID. If the handle's IdentityTime is not zero, // the process start time is checked to match the expected start time. 
-func FindProcess(pid Pid_t, expectedStartTime time.Time) (*os.Process, error) { - proc, err := findPsProcess(pid, expectedStartTime) +func FindProcess(handle ProcessHandle) (*os.Process, error) { + proc, err := findPsProcess(handle) if err != nil { return nil, err } - process, err := os.FindProcess(int(proc.Pid)) - if err != nil { - return nil, err + process, findErr := os.FindProcess(int(proc.Pid)) + if findErr != nil { + return nil, findErr } return process, nil @@ -324,8 +324,8 @@ func makeWaitable(pid Pid_t, proc *os.Process) Waitable { func init() { ps.EnableBootTimeCache(true) - This = sync.OnceValues(func() (ProcessTreeItem, error) { - retval := ProcessTreeItem{ + This = sync.OnceValues(func() (ProcessHandle, error) { + retval := ProcessHandle{ Pid: UnknownPID, IdentityTime: time.Time{}, } diff --git a/pkg/process/waitable_process.go b/pkg/process/waitable_process.go index b49615b9..929769c2 100644 --- a/pkg/process/waitable_process.go +++ b/pkg/process/waitable_process.go @@ -27,8 +27,8 @@ type WaitableProcess struct { waitLock sync.Mutex } -func FindWaitableProcess(pid Pid_t, processStartTime time.Time) (*WaitableProcess, error) { - foundProcess, err := FindProcess(pid, processStartTime) +func FindWaitableProcess(handle ProcessHandle) (*WaitableProcess, error) { + foundProcess, err := FindProcess(handle) if err != nil { return nil, err } @@ -36,7 +36,7 @@ func FindWaitableProcess(pid Pid_t, processStartTime time.Time) (*WaitableProces dcpProcess := &WaitableProcess{ WaitPollInterval: defaultWaitPollInterval, process: foundProcess, - processStartTime: processStartTime, + processStartTime: handle.IdentityTime, err: nil, waitLock: sync.Mutex{}, } @@ -70,7 +70,7 @@ func (p *WaitableProcess) pollingWait(ctx context.Context) { case <-timer.C: pid := Uint32_ToPidT(uint32(p.process.Pid)) - _, pollErr := FindProcess(pid, p.processStartTime) + _, pollErr := FindProcess(ProcessHandle{Pid: pid, IdentityTime: p.processStartTime}) // We couldn't find the PID, so the 
process has exited if pollErr != nil { p.err = nil diff --git a/test/debuggee/debuggee.go b/test/debuggee/debuggee.go new file mode 100644 index 00000000..5f15943e --- /dev/null +++ b/test/debuggee/debuggee.go @@ -0,0 +1,29 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See LICENSE in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +// Package main provides a simple program for debugging tests. +// This program is used as a target for DAP proxy integration tests with Delve. +package main + +import ( + "fmt" + "os" +) + +func main() { + // This is a breakpoint target line - tests will set breakpoints here + result := compute(10) // Line 17 - breakpoint target (keep in sync with DAP proxy tests) + fmt.Printf("Result: %d\n", result) + os.Exit(0) +} + +// compute performs a simple computation that can be stepped through. +func compute(n int) int { + sum := 0 + for i := 1; i <= n; i++ { + sum += i // Line 26 - can step through loop iterations + } + return sum +}